1
0
mirror of https://github.com/golang/go synced 2024-11-19 03:54:42 -07:00

net/http: limit Transport's reading of response header bytes from servers

The default is 10MB, like http2, but can be configured with a new
field http.Transport.MaxResponseHeaderBytes.

Fixes #9115

Change-Id: I01808ac631ce4794ef2b0dfc391ed51cf951ceb1
Reviewed-on: https://go-review.googlesource.com/21329
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Andrew Gerrand <adg@golang.org>
This commit is contained in:
Brad Fitzpatrick 2016-03-31 00:06:27 -07:00
parent 7a4211bc1f
commit 36feb1a00a
4 changed files with 116 additions and 29 deletions

12
src/net/http/http.go Normal file
View File

@ -0,0 +1,12 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package http
// maxInt64 is the effective "infinite" value for the Server and
// Transport's byte-limiting readers.
const maxInt64 = 1<<63 - 1
// TODO(bradfitz): move common stuff here. The other files have accumulated
// generic http stuff in random places.

View File

@ -497,7 +497,7 @@ type connReader struct {
}
func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
func (cr *connReader) setInfiniteReadLimit() { cr.remain = 1<<63 - 1 }
func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 }
func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
func (cr *connReader) Read(p []byte) (n int, err error) {

View File

@ -146,6 +146,13 @@ type Transport struct {
// If TLSNextProto is nil, HTTP/2 support is enabled automatically.
TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper
// MaxResponseHeaderBytes specifies a limit on how many
// response bytes are allowed in the server's response
// header.
//
// Zero means to use a default limit.
MaxResponseHeaderBytes int64
// nextProtoOnce guards initialization of TLSNextProto and
// h2transport (via onceSetNextProtoDefaults)
nextProtoOnce sync.Once
@ -188,8 +195,23 @@ func (t *Transport) onceSetNextProtoDefaults() {
t2, err := http2configureTransport(t)
if err != nil {
log.Printf("Error enabling Transport HTTP/2 support: %v", err)
} else {
t.h2transport = t2
return
}
t.h2transport = t2
// Auto-configure the http2.Transport's MaxHeaderListSize from
// the http.Transport's MaxResponseHeaderBytes. They don't
// exactly mean the same thing, but they're close.
//
// TODO: also add this to x/net/http2.Configure Transport, behind
// a +build go1.7 build tag:
if limit1 := t.MaxResponseHeaderBytes; limit1 != 0 && t2.MaxHeaderListSize == 0 {
const h2max = 1<<32 - 1
if limit1 >= h2max {
t2.MaxHeaderListSize = h2max
} else {
t2.MaxHeaderListSize = uint32(limit1)
}
}
}
@ -351,7 +373,8 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
// resent on a new connection. The non-nil input error is the error from
// roundTrip, which might be wrapped in a beforeRespHeaderError error.
//
// The return value is err or the unwrapped error inside a
// The return value is either nil to retry the request, the provided
// err unmodified, or the unwrapped error inside a
// beforeRespHeaderError.
func checkTransportResend(err error, req *Request, pconn *persistConn) error {
brhErr, ok := err.(beforeRespHeaderError)
@ -864,7 +887,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
}
}
pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF})
pconn.br = bufio.NewReader(pconn)
pconn.bw = bufio.NewWriter(pconn.conn)
go pconn.readLoop()
go pconn.writeLoop()
@ -998,17 +1021,18 @@ type persistConn struct {
// If it's non-nil, the rest of the fields are unused.
alt RoundTripper
t *Transport
cacheKey connectMethodKey
conn net.Conn
tlsState *tls.ConnectionState
br *bufio.Reader // from conn
sawEOF bool // whether we've seen EOF from conn; owned by readLoop
bw *bufio.Writer // to conn
reqch chan requestAndChan // written by roundTrip; read by readLoop
writech chan writeRequest // written by roundTrip; read by writeLoop
closech chan struct{} // closed when conn closed
isProxy bool
t *Transport
cacheKey connectMethodKey
conn net.Conn
tlsState *tls.ConnectionState
br *bufio.Reader // from conn
bw *bufio.Writer // to conn
reqch chan requestAndChan // written by roundTrip; read by readLoop
writech chan writeRequest // written by roundTrip; read by writeLoop
closech chan struct{} // closed when conn closed
isProxy bool
sawEOF bool // whether we've seen EOF from conn; owned by readLoop
readLimit int64 // bytes allowed to be read; owned by readLoop
// writeErrCh passes the request write error (usually nil)
// from the writeLoop goroutine to the readLoop which passes
// it off to the res.Body reader, which then uses it to decide
@ -1027,6 +1051,28 @@ type persistConn struct {
mutateHeaderFunc func(Header)
}
func (pc *persistConn) maxHeaderResponseSize() int64 {
if v := pc.t.MaxResponseHeaderBytes; v != 0 {
return v
}
return 10 << 20 // conservative default; same as http2
}
func (pc *persistConn) Read(p []byte) (n int, err error) {
if pc.readLimit <= 0 {
return 0, fmt.Errorf("read limit of %d bytes exhausted", pc.maxHeaderResponseSize())
}
if int64(len(p)) > pc.readLimit {
p = p[:pc.readLimit]
}
n, err = pc.conn.Read(p)
if err == io.EOF {
pc.sawEOF = true
}
pc.readLimit -= int64(n)
return
}
// isBroken reports whether this connection is in a known broken state.
func (pc *persistConn) isBroken() bool {
pc.mu.Lock()
@ -1082,6 +1128,7 @@ func (pc *persistConn) readLoop() {
alive := true
for alive {
pc.readLimit = pc.maxHeaderResponseSize()
_, err := pc.br.Peek(1)
if err != nil {
err = beforeRespHeaderError{err}
@ -1103,6 +1150,9 @@ func (pc *persistConn) readLoop() {
}
if err != nil {
if pc.readLimit <= 0 {
err = fmt.Errorf("net/http: server response headers exceeded %d bytes; aborted", pc.maxHeaderResponseSize())
}
// If we won't be able to retry this request later (from the
// roundTrip goroutine), mark it as done now.
// BEFORE the send on rc.ch, as the client might re-use the
@ -1120,6 +1170,7 @@ func (pc *persistConn) readLoop() {
}
return
}
pc.readLimit = maxInt64 // effictively no limit for response bodies
pc.mu.Lock()
pc.numExpectedResponses--
@ -1251,6 +1302,7 @@ func (pc *persistConn) readResponse(rc requestAndChan) (resp *Response, err erro
}
}
if resp.StatusCode == 100 {
pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
resp, err = ReadResponse(pc.br, rc.req)
if err != nil {
return
@ -1706,19 +1758,6 @@ func (tlsHandshakeTimeoutError) Timeout() bool { return true }
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
type noteEOFReader struct {
r io.Reader
sawEOF *bool
}
func (nr noteEOFReader) Read(p []byte) (n int, err error) {
n, err = nr.r.Read(p)
if err == io.EOF {
*nr.sawEOF = true
}
return
}
// fakeLocker is a sync.Locker which does nothing. It's used to guard
// test-only fields when not under test, to avoid runtime atomic
// overhead.

View File

@ -3090,6 +3090,42 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) {
}
}
func TestTransportResponseHeaderLength(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.URL.Path == "/long" {
w.Header().Set("Long", strings.Repeat("a", 1<<20))
}
}))
defer ts.Close()
tr := &Transport{
MaxResponseHeaderBytes: 512 << 10,
}
defer tr.CloseIdleConnections()
c := &Client{Transport: tr}
if res, err := c.Get(ts.URL); err != nil {
t.Fatal(err)
} else {
res.Body.Close()
}
res, err := c.Get(ts.URL + "/long")
if err == nil {
defer res.Body.Close()
var n int64
for k, vv := range res.Header {
for _, v := range vv {
n += int64(len(k)) + int64(len(v))
}
}
t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
}
if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
t.Errorf("got error: %v; want %q", err, want)
}
}
var errFakeRoundTrip = errors.New("fake roundtrip")
type funcRoundTripper func()