mirror of
https://github.com/golang/go
synced 2024-09-25 11:20:13 -06:00
net/http: fix data race when sharing request body between client and server
A server Handler (e.g. a proxy) can receive a Request, and then turn around and give a copy of that Request.Body out to the Transport. So then two goroutines own that Request.Body (the server and the http client), and both think they can close it on failure. Therefore, all incoming server requests bodies (always *http.body from transfer.go) need to be thread-safe. Fixes #6995 R=golang-codereviews, r CC=golang-codereviews https://golang.org/cl/46570043
This commit is contained in:
parent
39a396d2ba
commit
affab3f312
@ -14,6 +14,7 @@ import (
|
|||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/url"
|
"net/url"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
@ -406,8 +407,7 @@ func TestWriteResponse(t *testing.T) {
|
|||||||
t.Errorf("#%d: %v", i, err)
|
t.Errorf("#%d: %v", i, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
bout := bytes.NewBuffer(nil)
|
err = resp.Write(ioutil.Discard)
|
||||||
err = resp.Write(bout)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("#%d: %v", i, err)
|
t.Errorf("#%d: %v", i, err)
|
||||||
continue
|
continue
|
||||||
@ -506,6 +506,9 @@ func TestReadResponseCloseInMiddle(t *testing.T) {
|
|||||||
rest, err := ioutil.ReadAll(bufr)
|
rest, err := ioutil.ReadAll(bufr)
|
||||||
checkErr(err, "ReadAll on remainder")
|
checkErr(err, "ReadAll on remainder")
|
||||||
if e, g := "Next Request Here", string(rest); e != g {
|
if e, g := "Next Request Here", string(rest); e != g {
|
||||||
|
g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string {
|
||||||
|
return fmt.Sprintf("x(repeated x%d)", len(match))
|
||||||
|
})
|
||||||
fatalf("remainder = %q, expected %q", g, e)
|
fatalf("remainder = %q, expected %q", g, e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2090,6 +2090,64 @@ func TestNoContentTypeOnNotModified(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Issue 6995
|
||||||
|
// A server Handler can receive a Request, and then turn around and
|
||||||
|
// give a copy of that Request.Body out to the Transport (e.g. any
|
||||||
|
// proxy). So then two people own that Request.Body (both the server
|
||||||
|
// and the http client), and both think they can close it on failure.
|
||||||
|
// Therefore, all incoming server requests Bodies need to be thread-safe.
|
||||||
|
func TestTransportAndServerSharedBodyRace(t *testing.T) {
|
||||||
|
defer afterTest(t)
|
||||||
|
|
||||||
|
const bodySize = 1 << 20
|
||||||
|
|
||||||
|
unblockBackend := make(chan bool)
|
||||||
|
backend := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||||||
|
io.CopyN(rw, req.Body, bodySize/2)
|
||||||
|
<-unblockBackend
|
||||||
|
}))
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
backendRespc := make(chan *Response, 1)
|
||||||
|
proxy := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||||||
|
if req.RequestURI == "/foo" {
|
||||||
|
rw.Write([]byte("bar"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
req2, _ := NewRequest("POST", backend.URL, req.Body)
|
||||||
|
req2.ContentLength = bodySize
|
||||||
|
|
||||||
|
bresp, err := DefaultClient.Do(req2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Proxy outbound request: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/4)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Proxy copy error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
backendRespc <- bresp // to close later
|
||||||
|
|
||||||
|
// Try to cause a race: Both the DefaultTransport and the proxy handler's Server
|
||||||
|
// will try to read/close req.Body (aka req2.Body)
|
||||||
|
DefaultTransport.(*Transport).CancelRequest(req2)
|
||||||
|
rw.Write([]byte("OK"))
|
||||||
|
}))
|
||||||
|
defer proxy.Close()
|
||||||
|
|
||||||
|
req, _ := NewRequest("POST", proxy.URL, io.LimitReader(neverEnding('a'), bodySize))
|
||||||
|
res, err := DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Original request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup, so we don't leak goroutines.
|
||||||
|
res.Body.Close()
|
||||||
|
close(unblockBackend)
|
||||||
|
(<-backendRespc).Body.Close()
|
||||||
|
}
|
||||||
|
|
||||||
func TestResponseWriterWriteStringAllocs(t *testing.T) {
|
func TestResponseWriterWriteStringAllocs(t *testing.T) {
|
||||||
ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
|
ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||||
if r.URL.Path == "/s" {
|
if r.URL.Path == "/s" {
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"net/textproto"
|
"net/textproto"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
)
|
)
|
||||||
|
|
||||||
// transferWriter inspects the fields of a user-supplied Request or Response,
|
// transferWriter inspects the fields of a user-supplied Request or Response,
|
||||||
@ -331,17 +332,17 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
|
|||||||
if noBodyExpected(t.RequestMethod) {
|
if noBodyExpected(t.RequestMethod) {
|
||||||
t.Body = eofReader
|
t.Body = eofReader
|
||||||
} else {
|
} else {
|
||||||
t.Body = &body{Reader: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
|
t.Body = &body{src: newChunkedReader(r), hdr: msg, r: r, closing: t.Close}
|
||||||
}
|
}
|
||||||
case realLength == 0:
|
case realLength == 0:
|
||||||
t.Body = eofReader
|
t.Body = eofReader
|
||||||
case realLength > 0:
|
case realLength > 0:
|
||||||
t.Body = &body{Reader: io.LimitReader(r, realLength), closing: t.Close}
|
t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close}
|
||||||
default:
|
default:
|
||||||
// realLength < 0, i.e. "Content-Length" not mentioned in header
|
// realLength < 0, i.e. "Content-Length" not mentioned in header
|
||||||
if t.Close {
|
if t.Close {
|
||||||
// Close semantics (i.e. HTTP/1.0)
|
// Close semantics (i.e. HTTP/1.0)
|
||||||
t.Body = &body{Reader: r, closing: t.Close}
|
t.Body = &body{src: r, closing: t.Close}
|
||||||
} else {
|
} else {
|
||||||
// Persistent connection (i.e. HTTP/1.1)
|
// Persistent connection (i.e. HTTP/1.1)
|
||||||
t.Body = eofReader
|
t.Body = eofReader
|
||||||
@ -514,10 +515,12 @@ func fixTrailer(header Header, te []string) (Header, error) {
|
|||||||
// Close ensures that the body has been fully read
|
// Close ensures that the body has been fully read
|
||||||
// and then reads the trailer if necessary.
|
// and then reads the trailer if necessary.
|
||||||
type body struct {
|
type body struct {
|
||||||
io.Reader
|
src io.Reader
|
||||||
hdr interface{} // non-nil (Response or Request) value means read trailer
|
hdr interface{} // non-nil (Response or Request) value means read trailer
|
||||||
r *bufio.Reader // underlying wire-format reader for the trailer
|
r *bufio.Reader // underlying wire-format reader for the trailer
|
||||||
closing bool // is the connection to be closed after reading body?
|
closing bool // is the connection to be closed after reading body?
|
||||||
|
|
||||||
|
mu sync.Mutex // guards closed, and calls to Read and Close
|
||||||
closed bool
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -528,10 +531,17 @@ type body struct {
|
|||||||
var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
|
var ErrBodyReadAfterClose = errors.New("http: invalid Read on closed Body")
|
||||||
|
|
||||||
func (b *body) Read(p []byte) (n int, err error) {
|
func (b *body) Read(p []byte) (n int, err error) {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
if b.closed {
|
if b.closed {
|
||||||
return 0, ErrBodyReadAfterClose
|
return 0, ErrBodyReadAfterClose
|
||||||
}
|
}
|
||||||
n, err = b.Reader.Read(p)
|
return b.readLocked(p)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must hold b.mu.
|
||||||
|
func (b *body) readLocked(p []byte) (n int, err error) {
|
||||||
|
n, err = b.src.Read(p)
|
||||||
|
|
||||||
if err == io.EOF {
|
if err == io.EOF {
|
||||||
// Chunked case. Read the trailer.
|
// Chunked case. Read the trailer.
|
||||||
@ -543,7 +553,7 @@ func (b *body) Read(p []byte) (n int, err error) {
|
|||||||
} else {
|
} else {
|
||||||
// If the server declared the Content-Length, our body is a LimitedReader
|
// If the server declared the Content-Length, our body is a LimitedReader
|
||||||
// and we need to check whether this EOF arrived early.
|
// and we need to check whether this EOF arrived early.
|
||||||
if lr, ok := b.Reader.(*io.LimitedReader); ok && lr.N > 0 {
|
if lr, ok := b.src.(*io.LimitedReader); ok && lr.N > 0 {
|
||||||
err = io.ErrUnexpectedEOF
|
err = io.ErrUnexpectedEOF
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -618,6 +628,8 @@ func (b *body) readTrailer() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *body) Close() error {
|
func (b *body) Close() error {
|
||||||
|
b.mu.Lock()
|
||||||
|
defer b.mu.Unlock()
|
||||||
if b.closed {
|
if b.closed {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -629,12 +641,25 @@ func (b *body) Close() error {
|
|||||||
default:
|
default:
|
||||||
// Fully consume the body, which will also lead to us reading
|
// Fully consume the body, which will also lead to us reading
|
||||||
// the trailer headers after the body, if present.
|
// the trailer headers after the body, if present.
|
||||||
_, err = io.Copy(ioutil.Discard, b)
|
_, err = io.Copy(ioutil.Discard, bodyLocked{b})
|
||||||
}
|
}
|
||||||
b.closed = true
|
b.closed = true
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// bodyLocked is a io.Reader reading from a *body when its mutex is
|
||||||
|
// already held.
|
||||||
|
type bodyLocked struct {
|
||||||
|
b *body
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bl bodyLocked) Read(p []byte) (n int, err error) {
|
||||||
|
if bl.b.closed {
|
||||||
|
return 0, ErrBodyReadAfterClose
|
||||||
|
}
|
||||||
|
return bl.b.readLocked(p)
|
||||||
|
}
|
||||||
|
|
||||||
// parseContentLength trims whitespace from s and returns -1 if no value
|
// parseContentLength trims whitespace from s and returns -1 if no value
|
||||||
// is set, or the value if it's >= 0.
|
// is set, or the value if it's >= 0.
|
||||||
func parseContentLength(cl string) (int64, error) {
|
func parseContentLength(cl string) (int64, error) {
|
||||||
|
@ -12,7 +12,7 @@ import (
|
|||||||
|
|
||||||
func TestBodyReadBadTrailer(t *testing.T) {
|
func TestBodyReadBadTrailer(t *testing.T) {
|
||||||
b := &body{
|
b := &body{
|
||||||
Reader: strings.NewReader("foobar"),
|
src: strings.NewReader("foobar"),
|
||||||
hdr: true, // force reading the trailer
|
hdr: true, // force reading the trailer
|
||||||
r: bufio.NewReader(strings.NewReader("")),
|
r: bufio.NewReader(strings.NewReader("")),
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user