diff --git a/src/net/http/request.go b/src/net/http/request.go index 1bde114909..45507d23d1 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -885,68 +885,56 @@ func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { } type maxBytesReader struct { - w ResponseWriter - r io.ReadCloser // underlying reader - n int64 // max bytes remaining - stopped bool - sawEOF bool + w ResponseWriter + r io.ReadCloser // underlying reader + n int64 // max bytes remaining + err error // sticky error } func (l *maxBytesReader) tooLarge() (n int, err error) { - if !l.stopped { - l.stopped = true - - // The server code and client code both use - // maxBytesReader. This "requestTooLarge" check is - // only used by the server code. To prevent binaries - // which only using the HTTP Client code (such as - // cmd/go) from also linking in the HTTP server, don't - // use a static type assertion to the server - // "*response" type. Check this interface instead: - type requestTooLarger interface { - requestTooLarge() - } - if res, ok := l.w.(requestTooLarger); ok { - res.requestTooLarge() - } - } - return 0, errors.New("http: request body too large") + l.err = errors.New("http: request body too large") + return 0, l.err } func (l *maxBytesReader) Read(p []byte) (n int, err error) { - toRead := l.n - if l.n == 0 { - if l.sawEOF { - return l.tooLarge() - } - // The underlying io.Reader may not return (0, io.EOF) - // at EOF if the requested size is 0, so read 1 byte - // instead. The io.Reader docs are a bit ambiguous - // about the return value of Read when 0 bytes are - // requested, and {bytes,strings}.Reader gets it wrong - // too (it returns (0, nil) even at EOF). - toRead = 1 + if l.err != nil { + return 0, l.err } - if int64(len(p)) > toRead { - p = p[:toRead] + if len(p) == 0 { + return 0, nil + } + // If they asked for a 32KB byte read but only 5 bytes are + // remaining, no need to read 32KB. 6 bytes will answer the + // question of the whether we hit the limit or go past it. + if int64(len(p)) > l.n+1 { + p = p[:l.n+1] } n, err = l.r.Read(p) - if err == io.EOF { - l.sawEOF = true + + if int64(n) <= l.n { + l.n -= int64(n) + l.err = err + return n, err } - if l.n == 0 { - // If we had zero bytes to read remaining (but hadn't seen EOF) - // and we get a byte here, that means we went over our limit. - if n > 0 { - return l.tooLarge() - } - return 0, err + + n = int(l.n) + l.n = 0 + + // The server code and client code both use + // maxBytesReader. This "requestTooLarge" check is + // only used by the server code. To prevent binaries + // which only using the HTTP Client code (such as + // cmd/go) from also linking in the HTTP server, don't + // use a static type assertion to the server + // "*response" type. Check this interface instead: + type requestTooLarger interface { + requestTooLarge() } - l.n -= int64(n) - if l.n < 0 { - l.n = 0 + if res, ok := l.w.(requestTooLarger); ok { + res.requestTooLarge() } - return + l.err = errors.New("http: request body too large") + return n, l.err } func (l *maxBytesReader) Close() error { diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go index 82c7af3cda..a4c88c0291 100644 --- a/src/net/http/request_test.go +++ b/src/net/http/request_test.go @@ -679,6 +679,46 @@ func TestIssue10884_MaxBytesEOF(t *testing.T) { } } +// Issue 14981: MaxBytesReader's return error wasn't sticky. It +// doesn't technically need to be, but people expected it to be. +func TestMaxBytesReaderStickyError(t *testing.T) { + isSticky := func(r io.Reader) error { + var log bytes.Buffer + buf := make([]byte, 1000) + var firstErr error + for { + n, err := r.Read(buf) + fmt.Fprintf(&log, "Read(%d) = %d, %v\n", len(buf), n, err) + if err == nil { + continue + } + if firstErr == nil { + firstErr = err + continue + } + if !reflect.DeepEqual(err, firstErr) { + return fmt.Errorf("non-sticky error. got log:\n%s", log.Bytes()) + } + t.Logf("Got log: %s", log.Bytes()) + return nil + } + } + tests := [...]struct { + readable int + limit int64 + }{ + 0: {99, 100}, + 1: {100, 100}, + 2: {101, 100}, + } + for i, tt := range tests { + rc := MaxBytesReader(nil, ioutil.NopCloser(bytes.NewReader(make([]byte, tt.readable))), tt.limit) + if err := isSticky(rc); err != nil { + t.Errorf("%d. error: %v", i, err) + } + } +} + func testMissingFile(t *testing.T, req *Request) { f, fh, err := req.FormFile("missing") if f != nil {