1
0
mirror of https://github.com/golang/go synced 2024-09-29 22:14:29 -06:00

net/http: also use Server.ReadHeaderTimeout for TLS handshake deadline

Fixes #48120

Change-Id: I72e89af8aaf3310e348d8ab639925ce0bf84204d
Reviewed-on: https://go-review.googlesource.com/c/go/+/355870
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
This commit is contained in:
Brad Fitzpatrick 2021-10-14 08:45:16 -07:00
parent 0400d536e4
commit b59467e036
2 changed files with 85 additions and 5 deletions

View File

@ -865,6 +865,28 @@ func (srv *Server) initialReadLimitSize() int64 {
return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
}
// tlsHandshakeTimeout returns the time limit permitted for the TLS
// handshake, or zero for unlimited.
//
// It returns the minimum of any positive ReadHeaderTimeout,
// ReadTimeout, or WriteTimeout.
func (srv *Server) tlsHandshakeTimeout() time.Duration {
var ret time.Duration
for _, v := range [...]time.Duration{
srv.ReadHeaderTimeout,
srv.ReadTimeout,
srv.WriteTimeout,
} {
if v <= 0 {
continue
}
if ret == 0 || v < ret {
ret = v
}
}
return ret
}
// wrapper around io.ReadCloser which on first read, sends an
// HTTP/1.1 100 Continue header
type expectContinueReader struct {
@ -1816,11 +1838,11 @@ func (c *conn) serve(ctx context.Context) {
}()
if tlsConn, ok := c.rwc.(*tls.Conn); ok {
if d := c.server.ReadTimeout; d > 0 {
c.rwc.SetReadDeadline(time.Now().Add(d))
}
if d := c.server.WriteTimeout; d > 0 {
c.rwc.SetWriteDeadline(time.Now().Add(d))
tlsTO := c.server.tlsHandshakeTimeout()
if tlsTO > 0 {
dl := time.Now().Add(tlsTO)
c.rwc.SetReadDeadline(dl)
c.rwc.SetWriteDeadline(dl)
}
if err := tlsConn.HandshakeContext(ctx); err != nil {
// If the handshake failed due to the client not speaking
@ -1834,6 +1856,11 @@ func (c *conn) serve(ctx context.Context) {
c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err)
return
}
// Restore Conn-level deadlines.
if tlsTO > 0 {
c.rwc.SetReadDeadline(time.Time{})
c.rwc.SetWriteDeadline(time.Time{})
}
c.tlsState = new(tls.ConnectionState)
*c.tlsState = tlsConn.ConnectionState()
if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) {

View File

@ -9,8 +9,61 @@ package http
import (
"fmt"
"testing"
"time"
)
func TestServerTLSHandshakeTimeout(t *testing.T) {
tests := []struct {
s *Server
want time.Duration
}{
{
s: &Server{},
want: 0,
},
{
s: &Server{
ReadTimeout: -1,
},
want: 0,
},
{
s: &Server{
ReadTimeout: 5 * time.Second,
},
want: 5 * time.Second,
},
{
s: &Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: -1,
},
want: 5 * time.Second,
},
{
s: &Server{
ReadTimeout: 5 * time.Second,
WriteTimeout: 4 * time.Second,
},
want: 4 * time.Second,
},
{
s: &Server{
ReadTimeout: 5 * time.Second,
ReadHeaderTimeout: 2 * time.Second,
WriteTimeout: 4 * time.Second,
},
want: 2 * time.Second,
},
}
for i, tt := range tests {
got := tt.s.tlsHandshakeTimeout()
if got != tt.want {
t.Errorf("%d. got %v; want %v", i, got, tt.want)
}
}
}
func BenchmarkServerMatch(b *testing.B) {
fn := func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "OK")