mirror of
https://github.com/golang/go
synced 2024-11-24 04:40:24 -07:00
net/http: fix data race when sharing request body between client and server
A server Handler (e.g. a proxy) can receive a Request, and then turn around and give a copy of that Request.Body out to the Transport. So then two goroutines own that Request.Body (the server and the http client), and both think they can close it on failure. Therefore, all incoming server requests bodies (always *http.body from transfer.go) need to be thread-safe. Fixes #6995 R=golang-codereviews, r CC=golang-codereviews https://golang.org/cl/46570043
This commit is contained in:
parent
39a396d2ba
commit
affab3f312
@ -14,6 +14,7 @@ import (
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
@ -406,8 +407,7 @@ func TestWriteResponse(t *testing.T) {
|
||||
t.Errorf("#%d: %v", i, err)
|
||||
continue
|
||||
}
|
||||
bout := bytes.NewBuffer(nil)
|
||||
err = resp.Write(bout)
|
||||
err = resp.Write(ioutil.Discard)
|
||||
if err != nil {
|
||||
t.Errorf("#%d: %v", i, err)
|
||||
continue
|
||||
@ -506,6 +506,9 @@ func TestReadResponseCloseInMiddle(t *testing.T) {
|
||||
rest, err := ioutil.ReadAll(bufr)
|
||||
checkErr(err, "ReadAll on remainder")
|
||||
if e, g := "Next Request Here", string(rest); e != g {
|
||||
g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string {
|
||||
return fmt.Sprintf("x(repeated x%d)", len(match))
|
||||
})
|
||||
fatalf("remainder = %q, expected %q", g, e)
|
||||
}
|
||||
}
|
||||
|
@ -2090,6 +2090,64 @@ func TestNoContentTypeOnNotModified(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// Issue 6995
|
||||
// A server Handler can receive a Request, and then turn around and
|
||||
// give a copy of that Request.Body out to the Transport (e.g. any
|
||||
// proxy). So then two people own that Request.Body (both the server
|
||||
// and the http client), and both think they can close it on failure.
|
||||
// Therefore, all incoming server requests Bodies need to be thread-safe.
|
||||
func TestTransportAndServerSharedBodyRace(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
|
||||
const bodySize = 1 << 20
|
||||
|
||||
unblockBackend := make(chan bool)
|
||||
backend := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||||
io.CopyN(rw, req.Body, bodySize/2)
|
||||
<-unblockBackend
|
||||
}))
|
||||
defer backend.Close()
|
||||
|
||||
backendRespc := make(chan *Response, 1)
|
||||
proxy := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||||
if req.RequestURI == "/foo" {
|
||||
rw.Write([]byte("bar"))
|
||||
return
|
||||
}
|
||||
req2, _ := NewRequest("POST", backend.URL, req.Body)
|
||||
req2.ContentLength = bodySize
|
||||
|
||||
bresp, err := DefaultClient.Do(req2)
|
||||
if err != nil {
|
||||
t.Errorf("Proxy outbound request: %v", err)
|
||||
return
|
||||
}
|
||||
_, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/4)
|
||||
if err != nil {
|
||||
t.Errorf("Proxy copy error: %v", err)
|
||||
return
|
||||
}
|
||||
backendRespc <- bresp // to close later
|
||||
|
||||
// Try to cause a race: Both the DefaultTransport and the proxy handler's Server
|
||||
// will try to read/close req.Body (aka req2.Body)
|
||||
DefaultTransport.(*Transport).CancelRequest(req2)
|
||||
rw.Write([]byte("OK"))
|
||||
}))
|
||||
defer proxy.Close()
|
||||
|
||||
req, _ := NewRequest("POST", proxy.URL, io.LimitReader(neverEnding('a'), bodySize))
|
||||
res, err := DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("Original request: %v", err)
|
||||
}
|
||||
|
||||
// Cleanup, so we don't leak goroutines.
|
||||
res.Body.Close()
|
||||
close(unblockBackend)
|
||||
(<-backendRespc).Body.Close()
|
||||
}
|
||||
|
||||
func TestResponseWriterWriteStringAllocs(t *testing.T) {
|
||||
ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
if r.URL.Path == "/s" {
|
||||
|
@ -14,6 +14,7 @@ import (
|
||||
"net/textproto"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// transferWriter inspects the fields of a user-supplied Request or Response,
|
||||
@ -331,17 +332,17 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
|
||||
if noBodyExpected(t.RequestMethod) {
|
||||
t.Body = eofReader
|
||||
} else {
|
||||
t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
|
||||
t.Body = &body{src: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
|
||||
}
|
||||
case realLength == 0:
|
||||
t.Body = eofReader
|
||||
case realLength > 0:
|
||||
t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close}
|
||||
t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close}
|
||||
default:
|
||||
// realLength < 0, i.e. "Content-Length" not mentioned in header
|
||||
if t.Close {
|
||||
// Close semantics (i.e. HTTP/1.0)
|
||||
t.Body = &body{Reader: r, closing: t.Close}
|
||||
t.Body = &body{src: r, closing: t.Close}
|
||||
} else {
|
||||
// Persistent connection (i.e. HTTP/1.1)
|
||||
t.Body = eofReader
|
||||
@ -514,11 +515,13 @@ func fixTrailer(header Header, te []string) (Header, error) {
|
||||
// Close ensures that the body has been fully read
|
||||
// and then reads the trailer if necessary.
|
||||
type body struct {
|
||||
io.Reader
|
||||
src io.Reader
|
||||
hdr interface{} // non-nil (Response or Request) value means read trailer
|
||||
r *bufio.Reader // underlying wire-format reader for the trailer
|
||||
closing bool // is the connection to be closed after reading body?
|
||||
closed bool
|
||||
|
||||
mu sync.Mutex // guards closed, and calls to Read and Close
|
||||
closed bool
|
||||
}
|
||||
|
||||
// ErrBodyReadAfterClose is returned when reading a Request or Response
|
||||
@ -528,10 +531,17 @@ type body struct {
|
||||
var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
|
||||
|
||||
func (b *body) Read(p []byte) (n int, err error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return 0, ErrBodyReadAfterClose
|
||||
}
|
||||
n, err = b.Reader.Read(p)
|
||||
return b.readLocked(p)
|
||||
}
|
||||
|
||||
// Must hold b.mu.
|
||||
func (b *body) readLocked(p []byte) (n int, err error) {
|
||||
n, err = b.src.Read(p)
|
||||
|
||||
if err == io.EOF {
|
||||
// Chunked case. Read the trailer.
|
||||
@ -543,7 +553,7 @@ func (b *body) Read(p []byte) (n int, err error) {
|
||||
} else {
|
||||
// If the server declared the Content-Length, our body is a LimitedReader
|
||||
// and we need to check whether this EOF arrived early.
|
||||
if lr, ok := b.Reader.(*io.LimitedReader); ok && lr.N > 0 {
|
||||
if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
}
|
||||
@ -618,6 +628,8 @@ func (b *body) readTrailer() error {
|
||||
}
|
||||
|
||||
func (b *body) Close() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
if b.closed {
|
||||
return nil
|
||||
}
|
||||
@ -629,12 +641,25 @@ func (b *body) Close() error {
|
||||
default:
|
||||
// Fully consume the body, which will also lead to us reading
|
||||
// the trailer headers after the body, if present.
|
||||
_, err = io.Copy(ioutil.Discard, b)
|
||||
_, err = io.Copy(ioutil.Discard, bodyLocked{b})
|
||||
}
|
||||
b.closed = true
|
||||
return err
|
||||
}
|
||||
|
||||
// bodyLocked is a io.Reader reading from a *body when its mutex is
|
||||
// already held.
|
||||
type bodyLocked struct {
|
||||
b *body
|
||||
}
|
||||
|
||||
func (bl bodyLocked) Read(p []byte) (n int, err error) {
|
||||
if bl.b.closed {
|
||||
return 0, ErrBodyReadAfterClose
|
||||
}
|
||||
return bl.b.readLocked(p)
|
||||
}
|
||||
|
||||
// parseContentLength trims whitespace from s and returns -1 if no value
|
||||
// is set, or the value if it's >= 0.
|
||||
func parseContentLength(cl string) (int64, error) {
|
||||
|
@ -12,9 +12,9 @@ import (
|
||||
|
||||
func TestBodyReadBadTrailer(t *testing.T) {
|
||||
b := &body{
|
||||
Reader: strings.NewReader("foobar"),
|
||||
hdr: true, // force reading the trailer
|
||||
r: bufio.NewReader(strings.NewReader("")),
|
||||
src: strings.NewReader("foobar"),
|
||||
hdr: true, // force reading the trailer
|
||||
r: bufio.NewReader(strings.NewReader("")),
|
||||
}
|
||||
buf := make([]byte, 7)
|
||||
n, err := b.Read(buf[:3])
|
||||
|
Loading…
Reference in New Issue
Block a user