diff --git a/src/net/http/transport.go b/src/net/http/transport.go index c07352b018e..d30eb79508a 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -2248,7 +2248,7 @@ func (pc *persistConn) readLoop() { } case <-rc.req.Cancel: alive = false - pc.t.CancelRequest(rc.req) + pc.t.cancelRequest(rc.cancelKey, errRequestCanceled) case <-rc.req.Context().Done(): alive = false pc.t.cancelRequest(rc.cancelKey, rc.req.Context().Err()) diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 028fecc9611..bcc26aa58e0 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -2440,6 +2440,7 @@ func testTransportCancelRequest(t *testing.T, mode testMode) { if d > 0 { t.Logf("pending requests = %d after %v (want 0)", n, d) } + return false } return true }) @@ -2599,6 +2600,65 @@ func testCancelRequestWithChannel(t *testing.T, mode testMode) { if d > 0 { t.Logf("pending requests = %d after %v (want 0)", n, d) } + return false + } + return true + }) +} + +// Issue 51354 +func TestCancelRequestWithBodyWithChannel(t *testing.T) { + run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode}) +} +func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) { + if testing.Short() { + t.Skip("skipping test in -short mode") + } + + const msg = "Hello" + unblockc := make(chan struct{}) + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, msg) + w.(Flusher).Flush() // send headers and some body + <-unblockc + })).ts + defer close(unblockc) + + c := ts.Client() + tr := c.Transport.(*Transport) + + req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody")) + cancel := make(chan struct{}) + req.Cancel = cancel + + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + body := make([]byte, len(msg)) + n, _ := io.ReadFull(res.Body, body) + if n != len(body) || !bytes.Equal(body, []byte(msg)) { + t.Errorf("Body = %q; want %q", body[:n], msg) + } + close(cancel) + + tail, err := io.ReadAll(res.Body) + res.Body.Close() + if err != ExportErrRequestCanceled { + t.Errorf("Body.Read error = %v; want errRequestCanceled", err) + } else if len(tail) > 0 { + t.Errorf("Spurious bytes from Body.Read: %q", tail) + } + + // Verify no outstanding requests after readLoop/writeLoop + // goroutines shut down. + waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { + n := tr.NumPendingRequestsForTesting() + if n > 0 { + if d > 0 { + t.Logf("pending requests = %d after %v (want 0)", n, d) + } + return false } return true })