diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 57c70e72f95..2a549a9576b 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -1761,6 +1761,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers if t.OnProxyConnectResponse != nil { err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp) if err != nil { + conn.Close() return nil, err } } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 3057024b764..698a43530ad 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -1523,6 +1523,24 @@ func TestOnProxyConnectResponse(t *testing.T) { c := proxy.Client() + var ( + dials atomic.Int32 + closes atomic.Int32 + ) + c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + conn, err := net.Dial(network, addr) + if err != nil { + return nil, err + } + dials.Add(1) + return noteCloseConn{ + Conn: conn, + closeFunc: func() { + closes.Add(1) + }, + }, nil + } + c.Transport.(*Transport).Proxy = ProxyURL(pu) c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error { if proxyURL.String() != pu.String() { @@ -1534,10 +1552,23 @@ func TestOnProxyConnectResponse(t *testing.T) { } return tcase.err } + wantCloses := int32(0) if _, err := c.Head(ts.URL); err != nil { + wantCloses = 1 if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) { t.Errorf("got %v, want %v", err, tcase.err) } + } else { + if tcase.err != nil { + t.Errorf("got %v, want nil", err) + } + } + if got, want := dials.Load(), int32(1); got != want { + t.Errorf("got %v dials, want %v", got, want) + } + // #64804: If OnProxyConnectResponse returns an error, we should close the conn. + if got, want := closes.Load(), wantCloses; got != want { + t.Errorf("got %v closes, want %v", got, want) } } }