1
0
mirror of https://github.com/golang/go synced 2024-11-18 15:54:42 -07:00

net/http: add Transport.DialTLS hook

Per discussions out of https://golang.org/cl/128930043/
and golang-nuts threads and with agl.

Fixes #8522

LGTM=agl, adg
R=agl, c, adg
CC=c, golang-codereviews
https://golang.org/cl/137940043
This commit is contained in:
Brad Fitzpatrick 2014-09-07 20:48:40 -07:00
parent 13d0b82bc8
commit ae47e044a8
2 changed files with 83 additions and 20 deletions

View File

@ -43,8 +43,8 @@ var DefaultTransport RoundTripper = &Transport{
// MaxIdleConnsPerHost.
const DefaultMaxIdleConnsPerHost = 2
// Transport is an implementation of RoundTripper that supports http,
// https, and http proxies (for either http or https with CONNECT).
// Transport is an implementation of RoundTripper that supports HTTP,
// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
// Transport can also cache connections for future re-use.
type Transport struct {
idleMu sync.Mutex
@ -61,11 +61,22 @@ type Transport struct {
// If Proxy is nil or returns a nil *URL, no proxy is used.
Proxy func(*Request) (*url.URL, error)
// Dial specifies the dial function for creating TCP
// connections.
// Dial specifies the dial function for creating unencrypted
// TCP connections.
// If Dial is nil, net.Dial is used.
Dial func(network, addr string) (net.Conn, error)
// DialTLS specifies an optional dial function for creating
// TLS connections for non-proxied HTTPS requests.
//
// If DialTLS is nil, Dial and TLSClientConfig are used.
//
// If DialTLS is set, the Dial hook is not used for HTTPS
// requests and the TLSClientConfig and TLSHandshakeTimeout
// are ignored. The returned net.Conn is assumed to already be
// past the TLS handshake.
DialTLS func(network, addr string) (net.Conn, error)
// TLSClientConfig specifies the TLS configuration to use with
// tls.Client. If nil, the default configuration is used.
TLSClientConfig *tls.Config
@ -504,44 +515,56 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
}
func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
conn, err := t.dial("tcp", cm.addr())
if err != nil {
if cm.proxyURL != nil {
err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
}
return nil, err
}
pa := cm.proxyAuth()
pconn := &persistConn{
t: t,
cacheKey: cm.key(),
conn: conn,
reqch: make(chan requestAndChan, 1),
writech: make(chan writeRequest, 1),
closech: make(chan struct{}),
writeErrCh: make(chan error, 1),
}
tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
if tlsDial {
var err error
pconn.conn, err = t.DialTLS("tcp", cm.addr())
if err != nil {
return nil, err
}
if tc, ok := pconn.conn.(*tls.Conn); ok {
cs := tc.ConnectionState()
pconn.tlsState = &cs
}
} else {
conn, err := t.dial("tcp", cm.addr())
if err != nil {
if cm.proxyURL != nil {
err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err)
}
return nil, err
}
pconn.conn = conn
}
// Proxy setup.
switch {
case cm.proxyURL == nil:
// Do nothing.
// Do nothing. Not using a proxy.
case cm.targetScheme == "http":
pconn.isProxy = true
if pa != "" {
if pa := cm.proxyAuth(); pa != "" {
pconn.mutateHeaderFunc = func(h Header) {
h.Set("Proxy-Authorization", pa)
}
}
case cm.targetScheme == "https":
conn := pconn.conn
connectReq := &Request{
Method: "CONNECT",
URL: &url.URL{Opaque: cm.targetAddr},
Host: cm.targetAddr,
Header: make(Header),
}
if pa != "" {
if pa := cm.proxyAuth(); pa != "" {
connectReq.Header.Set("Proxy-Authorization", pa)
}
connectReq.Write(conn)
@ -562,7 +585,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
}
}
if cm.targetScheme == "https" {
if cm.targetScheme == "https" && !tlsDial {
// Initiate TLS and check remote host name against certificate.
cfg := t.TLSClientConfig
if cfg == nil || cfg.ServerName == "" {
@ -575,7 +598,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
cfg = &clone
}
}
plainConn := conn
plainConn := pconn.conn
tlsConn := tls.Client(plainConn, cfg)
errc := make(chan error, 2)
var timer *time.Timer // for canceling TLS handshake

View File

@ -2096,6 +2096,46 @@ func TestTransportClosesBodyOnError(t *testing.T) {
}
}
func TestTransportDialTLS(t *testing.T) {
var mu sync.Mutex // guards following
var gotReq, didDial bool
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
mu.Lock()
gotReq = true
mu.Unlock()
}))
defer ts.Close()
tr := &Transport{
DialTLS: func(netw, addr string) (net.Conn, error) {
mu.Lock()
didDial = true
mu.Unlock()
c, err := tls.Dial(netw, addr, &tls.Config{
InsecureSkipVerify: true,
})
if err != nil {
return nil, err
}
return c, c.Handshake()
},
}
defer tr.CloseIdleConnections()
client := &Client{Transport: tr}
res, err := client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}
res.Body.Close()
mu.Lock()
if !gotReq {
t.Error("didn't get request")
}
if !didDial {
t.Error("didn't use dial hook")
}
}
func wantBody(res *http.Response, err error, want string) error {
if err != nil {
return err