diff --git a/src/net/http/server.go b/src/net/http/server.go index 55fd4ae22f..e9b0b4d9bd 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -865,6 +865,28 @@ func (srv *Server) initialReadLimitSize() int64 { return int64(srv.maxHeaderBytes()) + 4096 // bufio slop } +// tlsHandshakeTimeout returns the time limit permitted for the TLS +// handshake, or zero for unlimited. +// +// It returns the minimum of any positive ReadHeaderTimeout, +// ReadTimeout, or WriteTimeout. +func (srv *Server) tlsHandshakeTimeout() time.Duration { + var ret time.Duration + for _, v := range [...]time.Duration{ + srv.ReadHeaderTimeout, + srv.ReadTimeout, + srv.WriteTimeout, + } { + if v <= 0 { + continue + } + if ret == 0 || v < ret { + ret = v + } + } + return ret +} + // wrapper around io.ReadCloser which on first read, sends an // HTTP/1.1 100 Continue header type expectContinueReader struct { @@ -1816,11 +1838,11 @@ func (c *conn) serve(ctx context.Context) { }() if tlsConn, ok := c.rwc.(*tls.Conn); ok { - if d := c.server.ReadTimeout; d > 0 { - c.rwc.SetReadDeadline(time.Now().Add(d)) - } - if d := c.server.WriteTimeout; d > 0 { - c.rwc.SetWriteDeadline(time.Now().Add(d)) + tlsTO := c.server.tlsHandshakeTimeout() + if tlsTO > 0 { + dl := time.Now().Add(tlsTO) + c.rwc.SetReadDeadline(dl) + c.rwc.SetWriteDeadline(dl) } if err := tlsConn.HandshakeContext(ctx); err != nil { // If the handshake failed due to the client not speaking @@ -1834,6 +1856,11 @@ func (c *conn) serve(ctx context.Context) { c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err) return } + // Restore Conn-level deadlines. + if tlsTO > 0 { + c.rwc.SetReadDeadline(time.Time{}) + c.rwc.SetWriteDeadline(time.Time{}) + } c.tlsState = new(tls.ConnectionState) *c.tlsState = tlsConn.ConnectionState() if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) { diff --git a/src/net/http/server_test.go b/src/net/http/server_test.go index 0132f3ba5f..d17c5c1e7e 100644 --- a/src/net/http/server_test.go +++ b/src/net/http/server_test.go @@ -9,8 +9,61 @@ package http import ( "fmt" "testing" + "time" ) +func TestServerTLSHandshakeTimeout(t *testing.T) { + tests := []struct { + s *Server + want time.Duration + }{ + { + s: &Server{}, + want: 0, + }, + { + s: &Server{ + ReadTimeout: -1, + }, + want: 0, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + }, + want: 5 * time.Second, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: -1, + }, + want: 5 * time.Second, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + WriteTimeout: 4 * time.Second, + }, + want: 4 * time.Second, + }, + { + s: &Server{ + ReadTimeout: 5 * time.Second, + ReadHeaderTimeout: 2 * time.Second, + WriteTimeout: 4 * time.Second, + }, + want: 2 * time.Second, + }, + } + for i, tt := range tests { + got := tt.s.tlsHandshakeTimeout() + if got != tt.want { + t.Errorf("%d. got %v; want %v", i, got, tt.want) + } + } +} + func BenchmarkServerMatch(b *testing.B) { fn := func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "OK")