diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 1f215bd8439..e71c5365e14 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -2967,15 +2967,36 @@ func (b neverEnding) Read(p []byte) (n int, err error) { return len(p), nil } -type countReader struct { - r io.Reader - n *int64 +type bodyLimitReader struct { + mu sync.Mutex + count int + limit int + closed chan struct{} } -func (cr countReader) Read(p []byte) (n int, err error) { - n, err = cr.r.Read(p) - atomic.AddInt64(cr.n, int64(n)) - return +func (r *bodyLimitReader) Read(p []byte) (int, error) { + r.mu.Lock() + defer r.mu.Unlock() + select { + case <-r.closed: + return 0, errors.New("closed") + default: + } + if r.count > r.limit { + return 0, errors.New("at limit") + } + r.count += len(p) + for i := range p { + p[i] = 'a' + } + return len(p), nil +} + +func (r *bodyLimitReader) Close() error { + r.mu.Lock() + defer r.mu.Unlock() + close(r.closed) + return nil } func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) } @@ -2999,8 +3020,11 @@ func testRequestBodyLimit(t *testing.T, mode testMode) { } })) - nWritten := new(int64) - req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) + body := &bodyLimitReader{ + closed: make(chan struct{}), + limit: limit * 200, + } + req, _ := NewRequest("POST", cst.ts.URL, body) // Send the POST, but don't care it succeeds or not. The // remote side is going to reply and then close the TCP @@ -3015,10 +3039,13 @@ func testRequestBodyLimit(t *testing.T, mode testMode) { if err == nil { resp.Body.Close() } + // Wait for the Transport to finish writing the request body. + // It will close the body when done. + <-body.closed - if atomic.LoadInt64(nWritten) > limit*100 { + if body.count > limit*100 { t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d", - limit, nWritten) + limit, body.count) } }