mirror of
https://github.com/golang/go
synced 2024-11-18 15:54:42 -07:00
net/http: add Transport.DialTLS hook
Per discussions out of https://golang.org/cl/128930043/ and golang-nuts threads and with agl. Fixes #8522 LGTM=agl, adg R=agl, c, adg CC=c, golang-codereviews https://golang.org/cl/137940043
This commit is contained in:
parent
13d0b82bc8
commit
ae47e044a8
@ -43,8 +43,8 @@ var DefaultTransport RoundTripper = &Transport{
|
||||
// MaxIdleConnsPerHost.
|
||||
const DefaultMaxIdleConnsPerHost = 2
|
||||
|
||||
// Transport is an implementation of RoundTripper that supports http,
|
||||
// https, and http proxies (for either http or https with CONNECT).
|
||||
// Transport is an implementation of RoundTripper that supports HTTP,
|
||||
// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
|
||||
// Transport can also cache connections for future re-use.
|
||||
type Transport struct {
|
||||
idleMu sync.Mutex
|
||||
@ -61,11 +61,22 @@ type Transport struct {
|
||||
// If Proxy is nil or returns a nil *URL, no proxy is used.
|
||||
Proxy func(*Request) (*url.URL, error)
|
||||
|
||||
// Dial specifies the dial function for creating TCP
|
||||
// connections.
|
||||
// Dial specifies the dial function for creating unencrypted
|
||||
// TCP connections.
|
||||
// If Dial is nil, net.Dial is used.
|
||||
Dial func(network, addr string) (net.Conn, error)
|
||||
|
||||
// DialTLS specifies an optional dial function for creating
|
||||
// TLS connections for non-proxied HTTPS requests.
|
||||
//
|
||||
// If DialTLS is nil, Dial and TLSClientConfig are used.
|
||||
//
|
||||
// If DialTLS is set, the Dial hook is not used for HTTPS
|
||||
// requests and the TLSClientConfig and TLSHandshakeTimeout
|
||||
// are ignored. The returned net.Conn is assumed to already be
|
||||
// past the TLS handshake.
|
||||
DialTLS func(network, addr string) (net.Conn, error)
|
||||
|
||||
// TLSClientConfig specifies the TLS configuration to use with
|
||||
// tls.Client. If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
@ -504,44 +515,56 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
|
||||
}
|
||||
|
||||
func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
|
||||
conn, err := t.dial("tcp", cm.addr())
|
||||
if err != nil {
|
||||
if cm.proxyURL != nil {
|
||||
err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pa := cm.proxyAuth()
|
||||
|
||||
pconn := &persistConn{
|
||||
t: t,
|
||||
cacheKey: cm.key(),
|
||||
conn: conn,
|
||||
reqch: make(chan requestAndChan, 1),
|
||||
writech: make(chan writeRequest, 1),
|
||||
closech: make(chan struct{}),
|
||||
writeErrCh: make(chan error, 1),
|
||||
}
|
||||
tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
|
||||
if tlsDial {
|
||||
var err error
|
||||
pconn.conn, err = t.DialTLS("tcp", cm.addr())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tc, ok := pconn.conn.(*tls.Conn); ok {
|
||||
cs := tc.ConnectionState()
|
||||
pconn.tlsState = &cs
|
||||
}
|
||||
} else {
|
||||
conn, err := t.dial("tcp", cm.addr())
|
||||
if err != nil {
|
||||
if cm.proxyURL != nil {
|
||||
err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
pconn.conn = conn
|
||||
}
|
||||
|
||||
// Proxy setup.
|
||||
switch {
|
||||
case cm.proxyURL == nil:
|
||||
// Do nothing.
|
||||
// Do nothing. Not using a proxy.
|
||||
case cm.targetScheme == "http":
|
||||
pconn.isProxy = true
|
||||
if pa != "" {
|
||||
if pa := cm.proxyAuth(); pa != "" {
|
||||
pconn.mutateHeaderFunc = func(h Header) {
|
||||
h.Set("Proxy-Authorization", pa)
|
||||
}
|
||||
}
|
||||
case cm.targetScheme == "https":
|
||||
conn := pconn.conn
|
||||
connectReq := &Request{
|
||||
Method: "CONNECT",
|
||||
URL: &url.URL{Opaque: cm.targetAddr},
|
||||
Host: cm.targetAddr,
|
||||
Header: make(Header),
|
||||
}
|
||||
if pa != "" {
|
||||
if pa := cm.proxyAuth(); pa != "" {
|
||||
connectReq.Header.Set("Proxy-Authorization", pa)
|
||||
}
|
||||
connectReq.Write(conn)
|
||||
@ -562,7 +585,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if cm.targetScheme == "https" {
|
||||
if cm.targetScheme == "https" && !tlsDial {
|
||||
// Initiate TLS and check remote host name against certificate.
|
||||
cfg := t.TLSClientConfig
|
||||
if cfg == nil || cfg.ServerName == "" {
|
||||
@ -575,7 +598,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
|
||||
cfg = &clone
|
||||
}
|
||||
}
|
||||
plainConn := conn
|
||||
plainConn := pconn.conn
|
||||
tlsConn := tls.Client(plainConn, cfg)
|
||||
errc := make(chan error, 2)
|
||||
var timer *time.Timer // for canceling TLS handshake
|
||||
|
@ -2096,6 +2096,46 @@ func TestTransportClosesBodyOnError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportDialTLS(t *testing.T) {
|
||||
var mu sync.Mutex // guards following
|
||||
var gotReq, didDial bool
|
||||
|
||||
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
mu.Lock()
|
||||
gotReq = true
|
||||
mu.Unlock()
|
||||
}))
|
||||
defer ts.Close()
|
||||
tr := &Transport{
|
||||
DialTLS: func(netw, addr string) (net.Conn, error) {
|
||||
mu.Lock()
|
||||
didDial = true
|
||||
mu.Unlock()
|
||||
c, err := tls.Dial(netw, addr, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, c.Handshake()
|
||||
},
|
||||
}
|
||||
defer tr.CloseIdleConnections()
|
||||
client := &Client{Transport: tr}
|
||||
res, err := client.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
res.Body.Close()
|
||||
mu.Lock()
|
||||
if !gotReq {
|
||||
t.Error("didn't get request")
|
||||
}
|
||||
if !didDial {
|
||||
t.Error("didn't use dial hook")
|
||||
}
|
||||
}
|
||||
|
||||
func wantBody(res *http.Response, err error, want string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
|
Loading…
Reference in New Issue
Block a user