diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 88761909fd..ca97489eea 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -261,13 +261,14 @@ type Transport struct { // ReadBufferSize specifies the size of the read buffer used // when reading from the transport. - //If zero, a default (currently 4KB) is used. + // If zero, a default (currently 4KB) is used. ReadBufferSize int // nextProtoOnce guards initialization of TLSNextProto and // h2transport (via onceSetNextProtoDefaults) - nextProtoOnce sync.Once - h2transport h2Transport // non-nil if http2 wired up + nextProtoOnce sync.Once + h2transport h2Transport // non-nil if http2 wired up + tlsNextProtoWasNil bool // whether TLSNextProto was nil when the Once fired // ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero // TLSClientConfig or Dial, DialTLS or DialContext func is provided. By default, use of any those fields conservatively @@ -290,6 +291,40 @@ func (t *Transport) readBufferSize() int { return 4 << 10 } +// Clone returns a deep copy of t's exported fields. +func (t *Transport) Clone() *Transport { + t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + t2 := &Transport{ + Proxy: t.Proxy, + DialContext: t.DialContext, + Dial: t.Dial, + DialTLS: t.DialTLS, + TLSClientConfig: t.TLSClientConfig.Clone(), + TLSHandshakeTimeout: t.TLSHandshakeTimeout, + DisableKeepAlives: t.DisableKeepAlives, + DisableCompression: t.DisableCompression, + MaxIdleConns: t.MaxIdleConns, + MaxIdleConnsPerHost: t.MaxIdleConnsPerHost, + MaxConnsPerHost: t.MaxConnsPerHost, + IdleConnTimeout: t.IdleConnTimeout, + ResponseHeaderTimeout: t.ResponseHeaderTimeout, + ExpectContinueTimeout: t.ExpectContinueTimeout, + ProxyConnectHeader: t.ProxyConnectHeader.Clone(), + MaxResponseHeaderBytes: t.MaxResponseHeaderBytes, + ForceAttemptHTTP2: t.ForceAttemptHTTP2, + WriteBufferSize: t.WriteBufferSize, + ReadBufferSize: t.ReadBufferSize, + } + if !t.tlsNextProtoWasNil { + npm := map[string]func(authority string, c *tls.Conn) RoundTripper{} + for k, v := range t.TLSNextProto { + npm[k] = v + } + t2.TLSNextProto = npm + } + return t2 +} + // h2Transport is the interface we expect to be able to call from // net/http against an *http2.Transport that's either bundled into // h2_bundle.go or supplied by the user via x/net/http2. @@ -303,6 +338,7 @@ type h2Transport interface { // onceSetNextProtoDefaults initializes TLSNextProto. // It must be called via t.nextProtoOnce.Do. func (t *Transport) onceSetNextProtoDefaults() { + t.tlsNextProtoWasNil = (t.TLSNextProto == nil) if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") { return } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index dbfbd5792d..cf2bbe1189 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -20,6 +20,7 @@ import ( "encoding/binary" "errors" "fmt" + "go/token" "internal/nettrace" "internal/testenv" "io" @@ -5320,3 +5321,53 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { }) } } + +func TestTransportClone(t *testing.T) { + tr := &Transport{ + Proxy: func(*Request) (*url.URL, error) { panic("") }, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") }, + Dial: func(network, addr string) (net.Conn, error) { panic("") }, + DialTLS: func(network, addr string) (net.Conn, error) { panic("") }, + TLSClientConfig: new(tls.Config), + TLSHandshakeTimeout: time.Second, + DisableKeepAlives: true, + DisableCompression: true, + MaxIdleConns: 1, + MaxIdleConnsPerHost: 1, + MaxConnsPerHost: 1, + IdleConnTimeout: time.Second, + ResponseHeaderTimeout: time.Second, + ExpectContinueTimeout: time.Second, + ProxyConnectHeader: Header{}, + MaxResponseHeaderBytes: 1, + ForceAttemptHTTP2: true, + TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{ + "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") }, + }, + ReadBufferSize: 1, + WriteBufferSize: 1, + } + tr2 := tr.Clone() + rv := reflect.ValueOf(tr2).Elem() + rt := rv.Type() + for i := 0; i < rt.NumField(); i++ { + sf := rt.Field(i) + if !token.IsExported(sf.Name) { + continue + } + if rv.Field(i).IsZero() { + t.Errorf("cloned field t2.%s is zero", sf.Name) + } + } + + if _, ok := tr2.TLSNextProto["foo"]; !ok { + t.Errorf("cloned Transport lacked TLSNextProto 'foo' key") + } + + // But test that a nil TLSNextProto is kept nil: + tr = new(Transport) + tr2 = tr.Clone() + if tr2.TLSNextProto != nil { + t.Errorf("Transport.TLSNextProto unexpected non-nil") + } +}