1
0
mirror of https://github.com/golang/go synced 2024-11-23 20:00:04 -07:00

net/http: fix data race due to writeLoop goroutine left running

Fix a data race for clients that mutate requests after receiving a
response error which is caused by the writeLoop goroutine left
running, this can be seen on canceled requests.

Fixes #37669

Change-Id: I0e0e4fd63266326b32587d8596456760bf848b13
Reviewed-on: https://go-review.googlesource.com/c/go/+/232799
Reviewed-by: Bryan C. Mills <bcmills@google.com>
Run-TryBot: Bryan C. Mills <bcmills@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Steven Hartland 2020-05-07 21:12:21 +00:00 committed by Bryan C. Mills
parent 91a52de527
commit 5e1e8c4c9f
2 changed files with 101 additions and 1 deletions

View File

@ -1963,6 +1963,15 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte
return nil
}
// Wait for the writeLoop goroutine to terminate to avoid data
// races on callers who mutate the request on failure.
//
// When resc in pc.roundTrip and hence rc.ch receives a responseAndError
// with a non-nil error it implies that the persistConn is either closed
// or closing. Waiting on pc.writeLoopDone is hence safe as all callers
// close closech which in turn ensures writeLoop returns.
<-pc.writeLoopDone
// If the request was canceled, that's better than network
// failures that were likely the result of tearing down the
// connection.
@ -1988,7 +1997,6 @@ func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritte
return err
}
if pc.isBroken() {
<-pc.writeLoopDone
if pc.nwrite == startBytesWritten {
return nothingWrittenError{err}
}

View File

@ -25,6 +25,7 @@ import (
"io"
"io/ioutil"
"log"
mrand "math/rand"
"net"
. "net/http"
"net/http/httptest"
@ -6284,3 +6285,94 @@ func TestTransportRejectsSignInContentLength(t *testing.T) {
t.Fatalf("Error mismatch\nGot: %q\nWanted substring: %q", got, want)
}
}
// dumpConn is a net.Conn which writes to Writer and reads from Reader
type dumpConn struct {
io.Writer
io.Reader
}
func (c *dumpConn) Close() error { return nil }
func (c *dumpConn) LocalAddr() net.Addr { return nil }
func (c *dumpConn) RemoteAddr() net.Addr { return nil }
func (c *dumpConn) SetDeadline(t time.Time) error { return nil }
func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil }
func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil }
// delegateReader is a reader that delegates to another reader,
// once it arrives on a channel.
type delegateReader struct {
c chan io.Reader
r io.Reader // nil until received from c
}
func (r *delegateReader) Read(p []byte) (int, error) {
if r.r == nil {
r.r = <-r.c
}
return r.r.Read(p)
}
func testTransportRace(req *Request) {
save := req.Body
pr, pw := io.Pipe()
defer pr.Close()
defer pw.Close()
dr := &delegateReader{c: make(chan io.Reader)}
t := &Transport{
Dial: func(net, addr string) (net.Conn, error) {
return &dumpConn{pw, dr}, nil
},
}
defer t.CloseIdleConnections()
quitReadCh := make(chan struct{})
// Wait for the request before replying with a dummy response:
go func() {
defer close(quitReadCh)
req, err := ReadRequest(bufio.NewReader(pr))
if err == nil {
// Ensure all the body is read; otherwise
// we'll get a partial dump.
io.Copy(ioutil.Discard, req.Body)
req.Body.Close()
}
select {
case dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n"):
case quitReadCh <- struct{}{}:
}
}()
t.RoundTrip(req)
// Ensure the reader returns before we reset req.Body to prevent
// a data race on req.Body.
pw.Close()
<-quitReadCh
req.Body = save
}
// Issue 37669
// Test that a cancellation doesn't result in a data race due to the writeLoop
// goroutine being left running, if the caller mutates the processed Request
// upon completion.
func TestErrorWriteLoopRace(t *testing.T) {
for i := 0; i < 1000; i++ {
ctx, cancel := context.WithCancel(context.Background())
r := bytes.NewBuffer(make([]byte, 10000))
delay := time.Duration(mrand.Intn(5)) * time.Millisecond
go func() {
time.Sleep(delay)
cancel()
}()
req, err := NewRequestWithContext(ctx, MethodPost, "http://example.com", r)
if err != nil {
t.Fatal(err)
}
testTransportRace(req)
}
}