1
0
mirror of https://github.com/golang/go synced 2024-11-15 02:50:31 -07:00

net/http: return correct error when reading from a canceled request body

CL 546676 inadvertently changed the error returned when reading
from the body of a canceled request. Fix it.

Rework various request cancelation tests to exercise all three ways
of canceling a request.

Fixes #67439

Change-Id: I14ecaf8bff9452eca4a05df923d57d768127a90c
Cq-Include-Trybots: luci.golang.try:gotip-linux-amd64-longtest,gotip-linux-amd64-longtest-race
Reviewed-on: https://go-review.googlesource.com/c/go/+/586315
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Damien Neil 2024-05-16 16:26:04 -07:00 committed by Gopher Robot
parent 2c635b68fd
commit a61729b880
2 changed files with 145 additions and 125 deletions

View File

@ -2333,7 +2333,7 @@ func (pc *persistConn) readLoop() {
}
case <-rc.treq.ctx.Done():
alive = false
pc.cancelRequest(errRequestCanceled)
pc.cancelRequest(context.Cause(rc.treq.ctx))
case <-pc.closech:
alive = false
}

View File

@ -2507,17 +2507,103 @@ func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
}
}
func TestTransportCancelRequest(t *testing.T) {
run(t, testTransportCancelRequest, []testMode{http1Mode})
// A cancelTest is a test of request cancellation.
type cancelTest struct {
mode testMode
newReq func(req *Request) *Request // prepare the request to cancel
cancel func(tr *Transport, req *Request) // cancel the request
checkErr func(when string, err error) // verify the expected error
}
func testTransportCancelRequest(t *testing.T, mode testMode) {
// runCancelTestTransport uses Transport.CancelRequest.
func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
t.Run("TransportCancel", func(t *testing.T) {
f(t, cancelTest{
mode: mode,
newReq: func(req *Request) *Request {
return req
},
cancel: func(tr *Transport, req *Request) {
tr.CancelRequest(req)
},
checkErr: func(when string, err error) {
if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
}
},
})
})
}
// runCancelTestChannel uses Request.Cancel.
func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
var cancelOnce sync.Once
cancelc := make(chan struct{})
f(t, cancelTest{
mode: mode,
newReq: func(req *Request) *Request {
req.Cancel = cancelc
return req
},
cancel: func(tr *Transport, req *Request) {
cancelOnce.Do(func() {
close(cancelc)
})
},
checkErr: func(when string, err error) {
if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
}
},
})
}
// runCancelTestContext uses a request context.
func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
ctx, cancel := context.WithCancel(context.Background())
f(t, cancelTest{
mode: mode,
newReq: func(req *Request) *Request {
return req.WithContext(ctx)
},
cancel: func(tr *Transport, req *Request) {
cancel()
},
checkErr: func(when string, err error) {
if !errors.Is(err, context.Canceled) {
t.Errorf("%v error = %v, want context.Canceled", when, err)
}
},
})
}
func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
run(t, func(t *testing.T, mode testMode) {
if mode == http1Mode {
t.Run("TransportCancel", func(t *testing.T) {
runCancelTestTransport(t, mode, f)
})
}
t.Run("RequestCancel", func(t *testing.T) {
runCancelTestChannel(t, mode, f)
})
t.Run("ContextCancel", func(t *testing.T) {
runCancelTestContext(t, mode, f)
})
}, opts...)
}
func TestTransportCancelRequest(t *testing.T) {
runCancelTest(t, testTransportCancelRequest)
}
func testTransportCancelRequest(t *testing.T, test cancelTest) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
const msg = "Hello"
unblockc := make(chan bool)
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, msg)
w.(Flusher).Flush() // send headers and some body
<-unblockc
@ -2528,6 +2614,7 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
tr := c.Transport.(*Transport)
req, _ := NewRequest("GET", ts.URL, nil)
req = test.newReq(req)
res, err := c.Do(req)
if err != nil {
t.Fatal(err)
@ -2537,13 +2624,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
if n != len(body) || !bytes.Equal(body, []byte(msg)) {
t.Errorf("Body = %q; want %q", body[:n], msg)
}
tr.CancelRequest(req)
test.cancel(tr, req)
tail, err := io.ReadAll(res.Body)
res.Body.Close()
if err != ExportErrRequestCanceled {
t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
} else if len(tail) > 0 {
test.checkErr("Body.Read", err)
if len(tail) > 0 {
t.Errorf("Spurious bytes from Body.Read: %q", tail)
}
@ -2561,12 +2647,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
})
}
func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
unblockc := make(chan bool)
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
<-unblockc
})).ts
defer close(unblockc)
@ -2576,6 +2662,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
donec := make(chan bool)
req, _ := NewRequest("GET", ts.URL, body)
req = test.newReq(req)
go func() {
defer close(donec)
c.Do(req)
@ -2583,7 +2670,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
unblockc <- true
waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
tr.CancelRequest(req)
test.cancel(tr, req)
select {
case <-donec:
return true
@ -2597,18 +2684,21 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
}
func TestTransportCancelRequestInDo(t *testing.T) {
run(t, func(t *testing.T, mode testMode) {
testTransportCancelRequestInDo(t, mode, nil)
}, []testMode{http1Mode})
runCancelTest(t, func(t *testing.T, test cancelTest) {
testTransportCancelRequestInDo(t, test, nil)
})
}
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
run(t, func(t *testing.T, mode testMode) {
testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
}, []testMode{http1Mode})
runCancelTest(t, func(t *testing.T, test cancelTest) {
testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
})
}
func TestTransportCancelRequestInDial(t *testing.T) {
runCancelTest(t, testTransportCancelRequestInDial)
}
func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
defer afterTest(t)
if testing.Short() {
t.Skip("skipping test in -short mode")
@ -2633,17 +2723,19 @@ func TestTransportCancelRequestInDial(t *testing.T) {
cl := &Client{Transport: tr}
gotres := make(chan bool)
req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
req = test.newReq(req)
go func() {
_, err := cl.Do(req)
eventLog.Printf("Get = %v", err)
eventLog.Printf("Get error = %v", err != nil)
test.checkErr("Get", err)
gotres <- true
}()
inDial <- true
eventLog.Printf("canceling")
tr.CancelRequest(req)
tr.CancelRequest(req) // used to panic on second call
test.cancel(tr, req)
test.cancel(tr, req) // used to panic on second call to Transport.Cancel
if d, ok := t.Deadline(); ok {
// When the test's deadline is about to expire, log the pending events for
@ -2659,80 +2751,25 @@ func TestTransportCancelRequestInDial(t *testing.T) {
got := logbuf.String()
want := `dial: blocking
canceling
Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
Get error = true
`
if got != want {
t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
}
}
func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }
func testCancelRequestWithChannel(t *testing.T, mode testMode) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
const msg = "Hello"
unblockc := make(chan struct{})
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, msg)
w.(Flusher).Flush() // send headers and some body
<-unblockc
})).ts
defer close(unblockc)
c := ts.Client()
tr := c.Transport.(*Transport)
req, _ := NewRequest("GET", ts.URL, nil)
cancel := make(chan struct{})
req.Cancel = cancel
res, err := c.Do(req)
if err != nil {
t.Fatal(err)
}
body := make([]byte, len(msg))
n, _ := io.ReadFull(res.Body, body)
if n != len(body) || !bytes.Equal(body, []byte(msg)) {
t.Errorf("Body = %q; want %q", body[:n], msg)
}
close(cancel)
tail, err := io.ReadAll(res.Body)
res.Body.Close()
if err != ExportErrRequestCanceled {
t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
} else if len(tail) > 0 {
t.Errorf("Spurious bytes from Body.Read: %q", tail)
}
// Verify no outstanding requests after readLoop/writeLoop
// goroutines shut down.
waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
n := tr.NumPendingRequestsForTesting()
if n > 0 {
if d > 0 {
t.Logf("pending requests = %d after %v (want 0)", n, d)
}
return false
}
return true
})
}
// Issue 51354
func TestCancelRequestWithBodyWithChannel(t *testing.T) {
run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode})
func TestTransportCancelRequestWithBody(t *testing.T) {
runCancelTest(t, testTransportCancelRequestWithBody)
}
func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
const msg = "Hello"
unblockc := make(chan struct{})
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, msg)
w.(Flusher).Flush() // send headers and some body
<-unblockc
@ -2743,8 +2780,7 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
tr := c.Transport.(*Transport)
req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
cancel := make(chan struct{})
req.Cancel = cancel
req = test.newReq(req)
res, err := c.Do(req)
if err != nil {
@ -2755,13 +2791,12 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
if n != len(body) || !bytes.Equal(body, []byte(msg)) {
t.Errorf("Body = %q; want %q", body[:n], msg)
}
close(cancel)
test.cancel(tr, req)
tail, err := io.ReadAll(res.Body)
res.Body.Close()
if err != ExportErrRequestCanceled {
t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
} else if len(tail) > 0 {
test.checkErr("Body.Read", err)
if len(tail) > 0 {
t.Errorf("Spurious bytes from Body.Read: %q", tail)
}
@ -2779,53 +2814,39 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
})
}
func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
func TestTransportCancelRequestBeforeDo(t *testing.T) {
// We can't cancel a request that hasn't started using Transport.CancelRequest.
run(t, func(t *testing.T, mode testMode) {
testCancelRequestWithChannelBeforeDo(t, mode, false)
t.Run("RequestCancel", func(t *testing.T) {
runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
})
t.Run("ContextCancel", func(t *testing.T) {
runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
})
})
}
func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
run(t, func(t *testing.T, mode testMode) {
testCancelRequestWithChannelBeforeDo(t, mode, true)
})
}
func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
unblockc := make(chan bool)
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
<-unblockc
})).ts
}))
defer close(unblockc)
c := ts.Client()
c := cst.ts.Client()
req, _ := NewRequest("GET", ts.URL, nil)
if withCtx {
ctx, cancel := context.WithCancel(context.Background())
cancel()
req = req.WithContext(ctx)
} else {
ch := make(chan struct{})
req.Cancel = ch
close(ch)
}
req, _ := NewRequest("GET", cst.ts.URL, nil)
req = test.newReq(req)
test.cancel(cst.tr, req)
_, err := c.Do(req)
if ue, ok := err.(*url.Error); ok {
err = ue.Err
}
if withCtx {
if err != context.Canceled {
t.Errorf("Do error = %v; want %v", err, context.Canceled)
}
} else {
if err == nil || !strings.Contains(err.Error(), "canceled") {
t.Errorf("Do error = %v; want cancellation", err)
}
}
test.checkErr("Do", err)
}
// Issue 11020. The returned error message should be errRequestCanceled
func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
}
func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
defer afterTest(t)
serverConnCh := make(chan net.Conn, 1)
@ -2839,6 +2860,7 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
defer tr.CloseIdleConnections()
errc := make(chan error, 1)
req, _ := NewRequest("GET", "http://example.com/", nil)
req = test.newReq(req)
go func() {
_, err := tr.RoundTrip(req)
errc <- err
@ -2854,15 +2876,13 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
}
defer sc.Close()
tr.CancelRequest(req)
test.cancel(tr, req)
err := <-errc
if err == nil {
t.Fatalf("unexpected success from RoundTrip")
}
if err != ExportErrRequestCanceled {
t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
}
test.checkErr("RoundTrip", err)
}
// golang.org/issue/3672 -- Client can't close HTTP stream