1
0
mirror of https://github.com/golang/go synced 2024-11-23 00:40:08 -07:00

net/http: close connection if OnProxyConnectResponse returns an error

Fixes #64804

Change-Id: Ibe56ab8d114b8826e477b0718470d0b9fbfef9b0
Reviewed-on: https://go-review.googlesource.com/c/go/+/560856
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
This commit is contained in:
Damien Neil 2024-02-02 16:19:37 -08:00
parent e17e5308fd
commit 62cebb2e91
2 changed files with 32 additions and 0 deletions

View File

@ -1761,6 +1761,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
if t.OnProxyConnectResponse != nil {
err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp)
if err != nil {
conn.Close()
return nil, err
}
}

View File

@ -1523,6 +1523,24 @@ func TestOnProxyConnectResponse(t *testing.T) {
c := proxy.Client()
var (
dials atomic.Int32
closes atomic.Int32
)
c.Transport.(*Transport).DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
conn, err := net.Dial(network, addr)
if err != nil {
return nil, err
}
dials.Add(1)
return noteCloseConn{
Conn: conn,
closeFunc: func() {
closes.Add(1)
},
}, nil
}
c.Transport.(*Transport).Proxy = ProxyURL(pu)
c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
if proxyURL.String() != pu.String() {
@ -1534,10 +1552,23 @@ func TestOnProxyConnectResponse(t *testing.T) {
}
return tcase.err
}
wantCloses := int32(0)
if _, err := c.Head(ts.URL); err != nil {
wantCloses = 1
if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
t.Errorf("got %v, want %v", err, tcase.err)
}
} else {
if tcase.err != nil {
t.Errorf("got %v, want nil", err)
}
}
if got, want := dials.Load(), int32(1); got != want {
t.Errorf("got %v dials, want %v", got, want)
}
// #64804: If OnProxyConnectResponse returns an error, we should close the conn.
if got, want := closes.Load(), wantCloses; got != want {
t.Errorf("got %v closes, want %v", got, want)
}
}
}