mirror of
https://github.com/golang/go
synced 2024-11-23 13:40:04 -07:00
net/http: add Transport.TLSHandshakeTimeout; set it by default
Update #3362 LGTM=agl R=agl CC=golang-codereviews https://golang.org/cl/68150045
This commit is contained in:
parent
abe53f8766
commit
fd4b4b56c4
@ -36,6 +36,7 @@ var DefaultTransport RoundTripper = &Transport{
|
||||
Timeout: 30 * time.Second,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
// DefaultMaxIdleConnsPerHost is the default value of Transport's
|
||||
@ -69,6 +70,10 @@ type Transport struct {
|
||||
// tls.Client. If nil, the default configuration is used.
|
||||
TLSClientConfig *tls.Config
|
||||
|
||||
// TLSHandshakeTimeout specifies the maximum amount of time waiting to
|
||||
// wait for a TLS handshake. Zero means no timeout.
|
||||
TLSHandshakeTimeout time.Duration
|
||||
|
||||
// DisableKeepAlives, if true, prevents re-use of TCP connections
|
||||
// between different HTTP requests.
|
||||
DisableKeepAlives bool
|
||||
@ -542,16 +547,33 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
|
||||
cfg = &clone
|
||||
}
|
||||
}
|
||||
conn = tls.Client(conn, cfg)
|
||||
if err = conn.(*tls.Conn).Handshake(); err != nil {
|
||||
plainConn := conn
|
||||
tlsConn := tls.Client(plainConn, cfg)
|
||||
errc := make(chan error, 2)
|
||||
var timer *time.Timer // for canceling TLS handshake
|
||||
if d := t.TLSHandshakeTimeout; d != 0 {
|
||||
timer = time.AfterFunc(d, func() {
|
||||
errc <- tlsHandshakeTimeoutError{}
|
||||
})
|
||||
}
|
||||
go func() {
|
||||
err := tlsConn.Handshake()
|
||||
if timer != nil {
|
||||
timer.Stop()
|
||||
}
|
||||
errc <- err
|
||||
}()
|
||||
if err := <-errc; err != nil {
|
||||
plainConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !cfg.InsecureSkipVerify {
|
||||
if err = conn.(*tls.Conn).VerifyHostname(cfg.ServerName); err != nil {
|
||||
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
|
||||
plainConn.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
pconn.conn = conn
|
||||
pconn.conn = tlsConn
|
||||
}
|
||||
|
||||
pconn.br = bufio.NewReader(pconn.conn)
|
||||
@ -1084,3 +1106,9 @@ type readerAndCloser struct {
|
||||
io.Reader
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type tlsHandshakeTimeoutError struct{}
|
||||
|
||||
func (tlsHandshakeTimeoutError) Timeout() bool { return true }
|
||||
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
|
||||
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
|
||||
|
@ -1722,6 +1722,73 @@ func TestTransportClosesRequestBody(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportTLSHandshakeTimeout(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
if testing.Short() {
|
||||
t.Skip("skipping in short mode")
|
||||
}
|
||||
ln := newLocalListener(t)
|
||||
defer ln.Close()
|
||||
testdonec := make(chan struct{})
|
||||
defer close(testdonec)
|
||||
|
||||
go func() {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
<-testdonec
|
||||
c.Close()
|
||||
}()
|
||||
|
||||
getdonec := make(chan struct{})
|
||||
go func() {
|
||||
defer close(getdonec)
|
||||
tr := &Transport{
|
||||
Dial: func(_, _ string) (net.Conn, error) {
|
||||
return net.Dial("tcp", ln.Addr().String())
|
||||
},
|
||||
TLSHandshakeTimeout: 250 * time.Millisecond,
|
||||
}
|
||||
cl := &Client{Transport: tr}
|
||||
_, err := cl.Get("https://dummy.tld/")
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
ue, ok := err.(*url.Error)
|
||||
if !ok {
|
||||
t.Fatalf("expected url.Error; got %#v", err)
|
||||
}
|
||||
ne, ok := ue.Err.(net.Error)
|
||||
if !ok {
|
||||
t.Fatalf("expected net.Error; got %#v", err)
|
||||
}
|
||||
if !ne.Timeout() {
|
||||
t.Error("expected timeout error; got %v", err)
|
||||
}
|
||||
if !strings.Contains(err.Error(), "handshake timeout") {
|
||||
t.Error("expected 'handshake timeout' in error; got %v", err)
|
||||
}
|
||||
}()
|
||||
select {
|
||||
case <-getdonec:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Error("test timeout; TLS handshake hung?")
|
||||
}
|
||||
}
|
||||
|
||||
func newLocalListener(t *testing.T) net.Listener {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
ln, err = net.Listen("tcp6", "[::1]:0")
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return ln
|
||||
}
|
||||
|
||||
type countCloseReader struct {
|
||||
n *int
|
||||
io.Reader
|
||||
|
Loading…
Reference in New Issue
Block a user