1
0
mirror of https://github.com/golang/go synced 2024-11-26 04:07:59 -07:00

net/http: add Transport.OnProxyConnectResponse

Fixes #54299

Change-Id: I3a29527bde7ac71f3824e771982db4257234e9ef
Reviewed-on: https://go-review.googlesource.com/c/go/+/447216
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: xie cui <523516579@qq.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
cuiweixie 2022-11-02 21:07:46 +08:00 committed by Brad Fitzpatrick
parent cd8d1bca2c
commit 2f4d5c3b79
3 changed files with 111 additions and 1 deletions

1
api/next/54299.txt Normal file
View File

@ -0,0 +1 @@
pkg net/http, type Transport struct, OnProxyConnectResponse func(context.Context, *url.URL, *Request, *Response) error #54299

View File

@ -120,6 +120,11 @@ type Transport struct {
// If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*Request) (*url.URL, error)
// OnProxyConnectResponse is called when the Transport gets an HTTP response from
// a proxy for a CONNECT request. It's called before the check for a 200 OK response.
// If it returns an error, the request fails with that error.
OnProxyConnectResponse func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error
// DialContext specifies the dial function for creating unencrypted TCP connections.
// If DialContext is nil (and the deprecated Dial below is also nil),
// then the transport dials using package net.
@ -309,6 +314,7 @@ func (t *Transport) Clone() *Transport {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
t2 := &Transport{
Proxy: t.Proxy,
OnProxyConnectResponse: t.OnProxyConnectResponse,
DialContext: t.DialContext,
Dial: t.Dial,
DialTLS: t.DialTLS,
@ -1716,6 +1722,14 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
conn.Close()
return nil, err
}
if t.OnProxyConnectResponse != nil {
err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp)
if err != nil {
return nil, err
}
}
if resp.StatusCode != 200 {
_, text, ok := strings.Cut(resp.Status, " ")
conn.Close()

View File

@ -1465,6 +1465,98 @@ func TestTransportProxy(t *testing.T) {
}
}
func TestOnProxyConnectResponse(t *testing.T) {
var tcases = []struct {
proxyStatusCode int
err error
}{
{
StatusOK,
nil,
},
{
StatusForbidden,
errors.New("403"),
},
}
for _, tcase := range tcases {
h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
})
h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
// Implement an entire CONNECT proxy
if r.Method == "CONNECT" {
if tcase.proxyStatusCode != StatusOK {
w.WriteHeader(tcase.proxyStatusCode)
return
}
hijacker, ok := w.(Hijacker)
if !ok {
t.Errorf("hijack not allowed")
return
}
clientConn, _, err := hijacker.Hijack()
if err != nil {
t.Errorf("hijacking failed")
return
}
res := &Response{
StatusCode: StatusOK,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(Header),
}
targetConn, err := net.Dial("tcp", r.URL.Host)
if err != nil {
t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err)
return
}
if err := res.Write(clientConn); err != nil {
t.Errorf("Writing 200 OK failed: %v", err)
return
}
go io.Copy(targetConn, clientConn)
go func() {
io.Copy(clientConn, targetConn)
targetConn.Close()
}()
}
})
ts := newClientServerTest(t, https1Mode, h1).ts
proxy := newClientServerTest(t, https1Mode, h2).ts
pu, err := url.Parse(proxy.URL)
if err != nil {
t.Fatal(err)
}
c := proxy.Client()
c.Transport.(*Transport).Proxy = ProxyURL(pu)
c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
if proxyURL.String() != pu.String() {
t.Errorf("proxy url got %s, want %s", proxyURL, pu)
}
if "https://"+connectReq.URL.String() != ts.URL {
t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL)
}
return tcase.err
}
if _, err := c.Head(ts.URL); err != nil {
if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) {
t.Errorf("got %v, want %v", err, tcase.err)
}
}
}
}
// Issue 28012: verify that the Transport closes its TCP connection to http proxies
// when they're slow to reply to HTTPS CONNECT responses.
func TestTransportProxyHTTPSConnectLeak(t *testing.T) {
@ -5906,7 +5998,10 @@ func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) {
func TestTransportClone(t *testing.T) {
tr := &Transport{
Proxy: func(*Request) (*url.URL, error) { panic("") },
Proxy: func(*Request) (*url.URL, error) { panic("") },
OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error {
return nil
},
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
Dial: func(network, addr string) (net.Conn, error) { panic("") },
DialTLS: func(network, addr string) (net.Conn, error) { panic("") },