1
0
mirror of https://github.com/golang/go synced 2024-11-13 17:50:23 -07:00

net/http: make the MaxBytesReader.Read error sticky

Fixes #14981

Change-Id: I39b906d119ca96815801a0fbef2dbe524a3246ff
Reviewed-on: https://go-review.googlesource.com/23009
Reviewed-by: Andrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Brad Fitzpatrick 2016-05-10 15:09:23 -07:00
parent 20e362dae7
commit 4d8031cf3c
2 changed files with 78 additions and 50 deletions

View File

@ -885,68 +885,56 @@ func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser {
}
type maxBytesReader struct {
w ResponseWriter
r io.ReadCloser // underlying reader
n int64 // max bytes remaining
stopped bool
sawEOF bool
w ResponseWriter
r io.ReadCloser // underlying reader
n int64 // max bytes remaining
err error // sticky error
}
func (l *maxBytesReader) tooLarge() (n int, err error) {
if !l.stopped {
l.stopped = true
// The server code and client code both use
// maxBytesReader. This "requestTooLarge" check is
// only used by the server code. To prevent binaries
// which only using the HTTP Client code (such as
// cmd/go) from also linking in the HTTP server, don't
// use a static type assertion to the server
// "*response" type. Check this interface instead:
type requestTooLarger interface {
requestTooLarge()
}
if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
}
}
return 0, errors.New("http: request body too large")
l.err = errors.New("http: request body too large")
return 0, l.err
}
func (l *maxBytesReader) Read(p []byte) (n int, err error) {
toRead := l.n
if l.n == 0 {
if l.sawEOF {
return l.tooLarge()
}
// The underlying io.Reader may not return (0, io.EOF)
// at EOF if the requested size is 0, so read 1 byte
// instead. The io.Reader docs are a bit ambiguous
// about the return value of Read when 0 bytes are
// requested, and {bytes,strings}.Reader gets it wrong
// too (it returns (0, nil) even at EOF).
toRead = 1
if l.err != nil {
return 0, l.err
}
if int64(len(p)) > toRead {
p = p[:toRead]
if len(p) == 0 {
return 0, nil
}
// If they asked for a 32KB byte read but only 5 bytes are
// remaining, no need to read 32KB. 6 bytes will answer the
// question of the whether we hit the limit or go past it.
if int64(len(p)) > l.n+1 {
p = p[:l.n+1]
}
n, err = l.r.Read(p)
if err == io.EOF {
l.sawEOF = true
if int64(n) <= l.n {
l.n -= int64(n)
l.err = err
return n, err
}
if l.n == 0 {
// If we had zero bytes to read remaining (but hadn't seen EOF)
// and we get a byte here, that means we went over our limit.
if n > 0 {
return l.tooLarge()
}
return 0, err
n = int(l.n)
l.n = 0
// The server code and client code both use
// maxBytesReader. This "requestTooLarge" check is
// only used by the server code. To prevent binaries
// which only using the HTTP Client code (such as
// cmd/go) from also linking in the HTTP server, don't
// use a static type assertion to the server
// "*response" type. Check this interface instead:
type requestTooLarger interface {
requestTooLarge()
}
l.n -= int64(n)
if l.n < 0 {
l.n = 0
if res, ok := l.w.(requestTooLarger); ok {
res.requestTooLarge()
}
return
l.err = errors.New("http: request body too large")
return n, l.err
}
func (l *maxBytesReader) Close() error {

View File

@ -679,6 +679,46 @@ func TestIssue10884_MaxBytesEOF(t *testing.T) {
}
}
// Issue 14981: MaxBytesReader's return error wasn't sticky. It
// doesn't technically need to be, but people expected it to be.
func TestMaxBytesReaderStickyError(t *testing.T) {
isSticky := func(r io.Reader) error {
var log bytes.Buffer
buf := make([]byte, 1000)
var firstErr error
for {
n, err := r.Read(buf)
fmt.Fprintf(&log, "Read(%d) = %d, %v\n", len(buf), n, err)
if err == nil {
continue
}
if firstErr == nil {
firstErr = err
continue
}
if !reflect.DeepEqual(err, firstErr) {
return fmt.Errorf("non-sticky error. got log:\n%s", log.Bytes())
}
t.Logf("Got log: %s", log.Bytes())
return nil
}
}
tests := [...]struct {
readable int
limit int64
}{
0: {99, 100},
1: {100, 100},
2: {101, 100},
}
for i, tt := range tests {
rc := MaxBytesReader(nil, ioutil.NopCloser(bytes.NewReader(make([]byte, tt.readable))), tt.limit)
if err := isSticky(rc); err != nil {
t.Errorf("%d. error: %v", i, err)
}
}
}
func testMissingFile(t *testing.T, req *Request) {
f, fh, err := req.FormFile("missing")
if f != nil {