mirror of
https://github.com/golang/go
synced 2024-11-23 17:40:03 -07:00
net/http: add DialTLSContext hook to Transport
Fixes #21526 Change-Id: I2f8215cd671641cddfa8499f8a8c0130db93dbc6 Reviewed-on: https://go-review.googlesource.com/c/go/+/61291 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
c9a4b01f42
commit
eb55a0c864
@ -142,15 +142,24 @@ type Transport struct {
|
||||
// If both are set, DialContext takes priority.
|
||||
Dial func(network, addr string) (net.Conn, error)
|
||||
|
||||
// DialTLS specifies an optional dial function for creating
|
||||
// DialTLSContext specifies an optional dial function for creating
|
||||
// TLS connections for non-proxied HTTPS requests.
|
||||
//
|
||||
// If DialTLS is nil, Dial and TLSClientConfig are used.
|
||||
// If DialTLSContext is nil (and the deprecated DialTLS below is also nil),
|
||||
// DialContext and TLSClientConfig are used.
|
||||
//
|
||||
// If DialTLS is set, the Dial hook is not used for HTTPS
|
||||
// If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS
|
||||
// requests and the TLSClientConfig and TLSHandshakeTimeout
|
||||
// are ignored. The returned net.Conn is assumed to already be
|
||||
// past the TLS handshake.
|
||||
DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
|
||||
|
||||
// DialTLS specifies an optional dial function for creating
|
||||
// TLS connections for non-proxied HTTPS requests.
|
||||
//
|
||||
// Deprecated: Use DialTLSContext instead, which allows the transport
|
||||
// to cancel dials as soon as they are no longer needed.
|
||||
// If both are set, DialTLSContext takes priority.
|
||||
DialTLS func(network, addr string) (net.Conn, error)
|
||||
|
||||
// TLSClientConfig specifies the TLS configuration to use with
|
||||
@ -286,6 +295,7 @@ func (t *Transport) Clone() *Transport {
|
||||
DialContext: t.DialContext,
|
||||
Dial: t.Dial,
|
||||
DialTLS: t.DialTLS,
|
||||
DialTLSContext: t.DialTLSContext,
|
||||
TLSHandshakeTimeout: t.TLSHandshakeTimeout,
|
||||
DisableKeepAlives: t.DisableKeepAlives,
|
||||
DisableCompression: t.DisableCompression,
|
||||
@ -324,6 +334,10 @@ type h2Transport interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
func (t *Transport) hasCustomTLSDialer() bool {
|
||||
return t.DialTLS != nil || t.DialTLSContext != nil
|
||||
}
|
||||
|
||||
// onceSetNextProtoDefaults initializes TLSNextProto.
|
||||
// It must be called via t.nextProtoOnce.Do.
|
||||
func (t *Transport) onceSetNextProtoDefaults() {
|
||||
@ -352,7 +366,7 @@ func (t *Transport) onceSetNextProtoDefaults() {
|
||||
// Transport.
|
||||
return
|
||||
}
|
||||
if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialTLS != nil || t.DialContext != nil) {
|
||||
if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()) {
|
||||
// Be conservative and don't automatically enable
|
||||
// http2 if they've specified a custom TLS config or
|
||||
// custom dialers. Let them opt-in themselves via
|
||||
@ -1185,6 +1199,18 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) {
|
||||
if t.DialTLSContext != nil {
|
||||
conn, err = t.DialTLSContext(ctx, network, addr)
|
||||
} else {
|
||||
conn, err = t.DialTLS(network, addr)
|
||||
}
|
||||
if conn == nil && err == nil {
|
||||
err = errors.New("net/http: Transport.DialTLS or DialTLSContext returned (nil, nil)")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// getConn dials and creates a new persistConn to the target as
|
||||
// specified in the connectMethod. This includes doing a proxy CONNECT
|
||||
// and/or setting up TLS. If this doesn't return an error, the persistConn
|
||||
@ -1435,15 +1461,12 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
|
||||
}
|
||||
return err
|
||||
}
|
||||
if cm.scheme() == "https" && t.DialTLS != nil {
|
||||
if cm.scheme() == "https" && t.hasCustomTLSDialer() {
|
||||
var err error
|
||||
pconn.conn, err = t.DialTLS("tcp", cm.addr())
|
||||
pconn.conn, err = t.customDialTLS(ctx, "tcp", cm.addr())
|
||||
if err != nil {
|
||||
return nil, wrapErr(err)
|
||||
}
|
||||
if pconn.conn == nil {
|
||||
return nil, wrapErr(errors.New("net/http: Transport.DialTLS returned (nil, nil)"))
|
||||
}
|
||||
if tc, ok := pconn.conn.(*tls.Conn); ok {
|
||||
// Handshake here, in case DialTLS didn't. TLSNextProto below
|
||||
// depends on it for knowing the connection state.
|
||||
|
@ -3506,6 +3506,90 @@ func TestTransportDialTLS(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportDialContext(t *testing.T) {
|
||||
setParallel(t)
|
||||
defer afterTest(t)
|
||||
var mu sync.Mutex // guards following
|
||||
var gotReq bool
|
||||
var receivedContext context.Context
|
||||
|
||||
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
mu.Lock()
|
||||
gotReq = true
|
||||
mu.Unlock()
|
||||
}))
|
||||
defer ts.Close()
|
||||
c := ts.Client()
|
||||
c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||||
mu.Lock()
|
||||
receivedContext = ctx
|
||||
mu.Unlock()
|
||||
return net.Dial(netw, addr)
|
||||
}
|
||||
|
||||
req, err := NewRequest("GET", ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), "some-key", "some-value")
|
||||
res, err := c.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
mu.Lock()
|
||||
if !gotReq {
|
||||
t.Error("didn't get request")
|
||||
}
|
||||
if receivedContext != ctx {
|
||||
t.Error("didn't receive correct context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportDialTLSContext(t *testing.T) {
|
||||
setParallel(t)
|
||||
defer afterTest(t)
|
||||
var mu sync.Mutex // guards following
|
||||
var gotReq bool
|
||||
var receivedContext context.Context
|
||||
|
||||
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
mu.Lock()
|
||||
gotReq = true
|
||||
mu.Unlock()
|
||||
}))
|
||||
defer ts.Close()
|
||||
c := ts.Client()
|
||||
c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
|
||||
mu.Lock()
|
||||
receivedContext = ctx
|
||||
mu.Unlock()
|
||||
c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, c.Handshake()
|
||||
}
|
||||
|
||||
req, err := NewRequest("GET", ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ctx := context.WithValue(context.Background(), "some-key", "some-value")
|
||||
res, err := c.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
mu.Lock()
|
||||
if !gotReq {
|
||||
t.Error("didn't get request")
|
||||
}
|
||||
if receivedContext != ctx {
|
||||
t.Error("didn't receive correct context")
|
||||
}
|
||||
}
|
||||
|
||||
// Test for issue 8755
|
||||
// Ensure that if a proxy returns an error, it is exposed by RoundTrip
|
||||
func TestRoundTripReturnsProxyError(t *testing.T) {
|
||||
@ -5577,6 +5661,7 @@ func TestTransportClone(t *testing.T) {
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
|
||||
Dial: func(network, addr string) (net.Conn, error) { panic("") },
|
||||
DialTLS: func(network, addr string) (net.Conn, error) { panic("") },
|
||||
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { panic("") },
|
||||
TLSClientConfig: new(tls.Config),
|
||||
TLSHandshakeTimeout: time.Second,
|
||||
DisableKeepAlives: true,
|
||||
|
Loading…
Reference in New Issue
Block a user