From 2f4d5c3b791b9b78c32ad587a70adfc1b46f29e0 Mon Sep 17 00:00:00 2001 From: cuiweixie Date: Wed, 2 Nov 2022 21:07:46 +0800 Subject: [PATCH] net/http: add Transport.OnProxyConnectResponse Fixes #54299 Change-Id: I3a29527bde7ac71f3824e771982db4257234e9ef Reviewed-on: https://go-review.googlesource.com/c/go/+/447216 Reviewed-by: Brad Fitzpatrick Run-TryBot: xie cui <523516579@qq.com> TryBot-Result: Gopher Robot Reviewed-by: Bryan Mills Reviewed-by: Damien Neil --- api/next/54299.txt | 1 + src/net/http/transport.go | 14 +++++ src/net/http/transport_test.go | 97 +++++++++++++++++++++++++++++++++- 3 files changed, 111 insertions(+), 1 deletion(-) create mode 100644 api/next/54299.txt diff --git a/api/next/54299.txt b/api/next/54299.txt new file mode 100644 index 00000000000..19bac0cf176 --- /dev/null +++ b/api/next/54299.txt @@ -0,0 +1 @@ +pkg net/http, type Transport struct, OnProxyConnectResponse func(context.Context, *url.URL, *Request, *Response) error #54299 \ No newline at end of file diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 671d9959ea6..2a508ec41b1 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -120,6 +120,11 @@ type Transport struct { // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy func(*Request) (*url.URL, error) + // OnProxyConnectResponse is called when the Transport gets an HTTP response from + // a proxy for a CONNECT request. It's called before the check for a 200 OK response. + // If it returns an error, the request fails with that error. + OnProxyConnectResponse func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error + // DialContext specifies the dial function for creating unencrypted TCP connections. // If DialContext is nil (and the deprecated Dial below is also nil), // then the transport dials using package net. @@ -309,6 +314,7 @@ func (t *Transport) Clone() *Transport { t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) t2 := &Transport{ Proxy: t.Proxy, + OnProxyConnectResponse: t.OnProxyConnectResponse, DialContext: t.DialContext, Dial: t.Dial, DialTLS: t.DialTLS, @@ -1716,6 +1722,14 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers conn.Close() return nil, err } + + if t.OnProxyConnectResponse != nil { + err = t.OnProxyConnectResponse(ctx, cm.proxyURL, connectReq, resp) + if err != nil { + return nil, err + } + } + if resp.StatusCode != 200 { _, text, ok := strings.Cut(resp.Status, " ") conn.Close() diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index a5818455160..b637e40cb42 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -1465,6 +1465,98 @@ func TestTransportProxy(t *testing.T) { } } +func TestOnProxyConnectResponse(t *testing.T) { + + var tcases = []struct { + proxyStatusCode int + err error + }{ + { + StatusOK, + nil, + }, + { + StatusForbidden, + errors.New("403"), + }, + } + for _, tcase := range tcases { + h1 := HandlerFunc(func(w ResponseWriter, r *Request) { + + }) + + h2 := HandlerFunc(func(w ResponseWriter, r *Request) { + // Implement an entire CONNECT proxy + if r.Method == "CONNECT" { + if tcase.proxyStatusCode != StatusOK { + w.WriteHeader(tcase.proxyStatusCode) + return + } + hijacker, ok := w.(Hijacker) + if !ok { + t.Errorf("hijack not allowed") + return + } + clientConn, _, err := hijacker.Hijack() + if err != nil { + t.Errorf("hijacking failed") + return + } + res := &Response{ + StatusCode: StatusOK, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + } + + targetConn, err := net.Dial("tcp", r.URL.Host) + if err != nil { + t.Errorf("net.Dial(%q) failed: %v", r.URL.Host, err) + return + } + + if err := res.Write(clientConn); err != nil { + t.Errorf("Writing 200 OK failed: %v", err) + return + } + + go io.Copy(targetConn, clientConn) + go func() { + io.Copy(clientConn, targetConn) + targetConn.Close() + }() + } + }) + ts := newClientServerTest(t, https1Mode, h1).ts + proxy := newClientServerTest(t, https1Mode, h2).ts + + pu, err := url.Parse(proxy.URL) + if err != nil { + t.Fatal(err) + } + + c := proxy.Client() + + c.Transport.(*Transport).Proxy = ProxyURL(pu) + c.Transport.(*Transport).OnProxyConnectResponse = func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error { + if proxyURL.String() != pu.String() { + t.Errorf("proxy url got %s, want %s", proxyURL, pu) + } + + if "https://"+connectReq.URL.String() != ts.URL { + t.Errorf("connect url got %s, want %s", connectReq.URL, ts.URL) + } + return tcase.err + } + if _, err := c.Head(ts.URL); err != nil { + if tcase.err != nil && !strings.Contains(err.Error(), tcase.err.Error()) { + t.Errorf("got %v, want %v", err, tcase.err) + } + } + } +} + // Issue 28012: verify that the Transport closes its TCP connection to http proxies // when they're slow to reply to HTTPS CONNECT responses. func TestTransportProxyHTTPSConnectLeak(t *testing.T) { @@ -5906,7 +5998,10 @@ func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) { func TestTransportClone(t *testing.T) { tr := &Transport{ - Proxy: func(*Request) (*url.URL, error) { panic("") }, + Proxy: func(*Request) (*url.URL, error) { panic("") }, + OnProxyConnectResponse: func(ctx context.Context, proxyURL *url.URL, connectReq *Request, connectRes *Response) error { + return nil + }, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, Dial: func(network, addr string) (net.Conn, error) { panic("") }, DialTLS: func(network, addr string) (net.Conn, error) { panic("") },