mirror of
https://github.com/golang/go
synced 2024-11-19 03:54:42 -07:00
net/http: limit Transport's reading of response header bytes from servers
The default is 10MB, like http2, but can be configured with a new field http.Transport.MaxResponseHeaderBytes. Fixes #9115 Change-Id: I01808ac631ce4794ef2b0dfc391ed51cf951ceb1 Reviewed-on: https://go-review.googlesource.com/21329 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Andrew Gerrand <adg@golang.org>
This commit is contained in:
parent
7a4211bc1f
commit
36feb1a00a
12
src/net/http/http.go
Normal file
12
src/net/http/http.go
Normal file
@ -0,0 +1,12 @@
|
||||
// Copyright 2016 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package http
|
||||
|
||||
// maxInt64 is the effective "infinite" value for the Server and
|
||||
// Transport's byte-limiting readers.
|
||||
const maxInt64 = 1<<63 - 1
|
||||
|
||||
// TODO(bradfitz): move common stuff here. The other files have accumulated
|
||||
// generic http stuff in random places.
|
@ -497,7 +497,7 @@ type connReader struct {
|
||||
}
|
||||
|
||||
func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
|
||||
func (cr *connReader) setInfiniteReadLimit() { cr.remain = 1<<63 - 1 }
|
||||
func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 }
|
||||
func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
|
||||
|
||||
func (cr *connReader) Read(p []byte) (n int, err error) {
|
||||
|
@ -146,6 +146,13 @@ type Transport struct {
|
||||
// If TLSNextProto is nil, HTTP/2 support is enabled automatically.
|
||||
TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper
|
||||
|
||||
// MaxResponseHeaderBytes specifies a limit on how many
|
||||
// response bytes are allowed in the server's response
|
||||
// header.
|
||||
//
|
||||
// Zero means to use a default limit.
|
||||
MaxResponseHeaderBytes int64
|
||||
|
||||
// nextProtoOnce guards initialization of TLSNextProto and
|
||||
// h2transport (via onceSetNextProtoDefaults)
|
||||
nextProtoOnce sync.Once
|
||||
@ -188,8 +195,23 @@ func (t *Transport) onceSetNextProtoDefaults() {
|
||||
t2, err := http2configureTransport(t)
|
||||
if err != nil {
|
||||
log.Printf("Error enabling Transport HTTP/2 support: %v", err)
|
||||
} else {
|
||||
t.h2transport = t2
|
||||
return
|
||||
}
|
||||
t.h2transport = t2
|
||||
|
||||
// Auto-configure the http2.Transport's MaxHeaderListSize from
|
||||
// the http.Transport's MaxResponseHeaderBytes. They don't
|
||||
// exactly mean the same thing, but they're close.
|
||||
//
|
||||
// TODO: also add this to x/net/http2.Configure Transport, behind
|
||||
// a +build go1.7 build tag:
|
||||
if limit1 := t.MaxResponseHeaderBytes; limit1 != 0 && t2.MaxHeaderListSize == 0 {
|
||||
const h2max = 1<<32 - 1
|
||||
if limit1 >= h2max {
|
||||
t2.MaxHeaderListSize = h2max
|
||||
} else {
|
||||
t2.MaxHeaderListSize = uint32(limit1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -351,7 +373,8 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
|
||||
// resent on a new connection. The non-nil input error is the error from
|
||||
// roundTrip, which might be wrapped in a beforeRespHeaderError error.
|
||||
//
|
||||
// The return value is err or the unwrapped error inside a
|
||||
// The return value is either nil to retry the request, the provided
|
||||
// err unmodified, or the unwrapped error inside a
|
||||
// beforeRespHeaderError.
|
||||
func checkTransportResend(err error, req *Request, pconn *persistConn) error {
|
||||
brhErr, ok := err.(beforeRespHeaderError)
|
||||
@ -864,7 +887,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
|
||||
}
|
||||
}
|
||||
|
||||
pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF})
|
||||
pconn.br = bufio.NewReader(pconn)
|
||||
pconn.bw = bufio.NewWriter(pconn.conn)
|
||||
go pconn.readLoop()
|
||||
go pconn.writeLoop()
|
||||
@ -998,17 +1021,18 @@ type persistConn struct {
|
||||
// If it's non-nil, the rest of the fields are unused.
|
||||
alt RoundTripper
|
||||
|
||||
t *Transport
|
||||
cacheKey connectMethodKey
|
||||
conn net.Conn
|
||||
tlsState *tls.ConnectionState
|
||||
br *bufio.Reader // from conn
|
||||
sawEOF bool // whether we've seen EOF from conn; owned by readLoop
|
||||
bw *bufio.Writer // to conn
|
||||
reqch chan requestAndChan // written by roundTrip; read by readLoop
|
||||
writech chan writeRequest // written by roundTrip; read by writeLoop
|
||||
closech chan struct{} // closed when conn closed
|
||||
isProxy bool
|
||||
t *Transport
|
||||
cacheKey connectMethodKey
|
||||
conn net.Conn
|
||||
tlsState *tls.ConnectionState
|
||||
br *bufio.Reader // from conn
|
||||
bw *bufio.Writer // to conn
|
||||
reqch chan requestAndChan // written by roundTrip; read by readLoop
|
||||
writech chan writeRequest // written by roundTrip; read by writeLoop
|
||||
closech chan struct{} // closed when conn closed
|
||||
isProxy bool
|
||||
sawEOF bool // whether we've seen EOF from conn; owned by readLoop
|
||||
readLimit int64 // bytes allowed to be read; owned by readLoop
|
||||
// writeErrCh passes the request write error (usually nil)
|
||||
// from the writeLoop goroutine to the readLoop which passes
|
||||
// it off to the res.Body reader, which then uses it to decide
|
||||
@ -1027,6 +1051,28 @@ type persistConn struct {
|
||||
mutateHeaderFunc func(Header)
|
||||
}
|
||||
|
||||
func (pc *persistConn) maxHeaderResponseSize() int64 {
|
||||
if v := pc.t.MaxResponseHeaderBytes; v != 0 {
|
||||
return v
|
||||
}
|
||||
return 10 << 20 // conservative default; same as http2
|
||||
}
|
||||
|
||||
func (pc *persistConn) Read(p []byte) (n int, err error) {
|
||||
if pc.readLimit <= 0 {
|
||||
return 0, fmt.Errorf("read limit of %d bytes exhausted", pc.maxHeaderResponseSize())
|
||||
}
|
||||
if int64(len(p)) > pc.readLimit {
|
||||
p = p[:pc.readLimit]
|
||||
}
|
||||
n, err = pc.conn.Read(p)
|
||||
if err == io.EOF {
|
||||
pc.sawEOF = true
|
||||
}
|
||||
pc.readLimit -= int64(n)
|
||||
return
|
||||
}
|
||||
|
||||
// isBroken reports whether this connection is in a known broken state.
|
||||
func (pc *persistConn) isBroken() bool {
|
||||
pc.mu.Lock()
|
||||
@ -1082,6 +1128,7 @@ func (pc *persistConn) readLoop() {
|
||||
|
||||
alive := true
|
||||
for alive {
|
||||
pc.readLimit = pc.maxHeaderResponseSize()
|
||||
_, err := pc.br.Peek(1)
|
||||
if err != nil {
|
||||
err = beforeRespHeaderError{err}
|
||||
@ -1103,6 +1150,9 @@ func (pc *persistConn) readLoop() {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if pc.readLimit <= 0 {
|
||||
err = fmt.Errorf("net/http: server response headers exceeded %d bytes; aborted", pc.maxHeaderResponseSize())
|
||||
}
|
||||
// If we won't be able to retry this request later (from the
|
||||
// roundTrip goroutine), mark it as done now.
|
||||
// BEFORE the send on rc.ch, as the client might re-use the
|
||||
@ -1120,6 +1170,7 @@ func (pc *persistConn) readLoop() {
|
||||
}
|
||||
return
|
||||
}
|
||||
pc.readLimit = maxInt64 // effictively no limit for response bodies
|
||||
|
||||
pc.mu.Lock()
|
||||
pc.numExpectedResponses--
|
||||
@ -1251,6 +1302,7 @@ func (pc *persistConn) readResponse(rc requestAndChan) (resp *Response, err erro
|
||||
}
|
||||
}
|
||||
if resp.StatusCode == 100 {
|
||||
pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
|
||||
resp, err = ReadResponse(pc.br, rc.req)
|
||||
if err != nil {
|
||||
return
|
||||
@ -1706,19 +1758,6 @@ func (tlsHandshakeTimeoutError) Timeout() bool { return true }
|
||||
func (tlsHandshakeTimeoutError) Temporary() bool { return true }
|
||||
func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" }
|
||||
|
||||
type noteEOFReader struct {
|
||||
r io.Reader
|
||||
sawEOF *bool
|
||||
}
|
||||
|
||||
func (nr noteEOFReader) Read(p []byte) (n int, err error) {
|
||||
n, err = nr.r.Read(p)
|
||||
if err == io.EOF {
|
||||
*nr.sawEOF = true
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// fakeLocker is a sync.Locker which does nothing. It's used to guard
|
||||
// test-only fields when not under test, to avoid runtime atomic
|
||||
// overhead.
|
||||
|
@ -3090,6 +3090,42 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportResponseHeaderLength(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
if r.URL.Path == "/long" {
|
||||
w.Header().Set("Long", strings.Repeat("a", 1<<20))
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
tr := &Transport{
|
||||
MaxResponseHeaderBytes: 512 << 10,
|
||||
}
|
||||
defer tr.CloseIdleConnections()
|
||||
c := &Client{Transport: tr}
|
||||
if res, err := c.Get(ts.URL); err != nil {
|
||||
t.Fatal(err)
|
||||
} else {
|
||||
res.Body.Close()
|
||||
}
|
||||
|
||||
res, err := c.Get(ts.URL + "/long")
|
||||
if err == nil {
|
||||
defer res.Body.Close()
|
||||
var n int64
|
||||
for k, vv := range res.Header {
|
||||
for _, v := range vv {
|
||||
n += int64(len(k)) + int64(len(v))
|
||||
}
|
||||
}
|
||||
t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
|
||||
}
|
||||
if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
|
||||
t.Errorf("got error: %v; want %q", err, want)
|
||||
}
|
||||
}
|
||||
|
||||
var errFakeRoundTrip = errors.New("fake roundtrip")
|
||||
|
||||
type funcRoundTripper func()
|
||||
|
Loading…
Reference in New Issue
Block a user