diff --git a/src/net/http/transport.go b/src/net/http/transport.go index f7a7092ef7..0d4332c344 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -2333,7 +2333,7 @@ func (pc *persistConn) readLoop() { } case <-rc.treq.ctx.Done(): alive = false - pc.cancelRequest(errRequestCanceled) + pc.cancelRequest(context.Cause(rc.treq.ctx)) case <-pc.closech: alive = false } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index fa147e164e..25876e8d16 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -2507,17 +2507,103 @@ func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) { } } -func TestTransportCancelRequest(t *testing.T) { - run(t, testTransportCancelRequest, []testMode{http1Mode}) +// A cancelTest is a test of request cancellation. +type cancelTest struct { + mode testMode + newReq func(req *Request) *Request // prepare the request to cancel + cancel func(tr *Transport, req *Request) // cancel the request + checkErr func(when string, err error) // verify the expected error } -func testTransportCancelRequest(t *testing.T, mode testMode) { + +// runCancelTestTransport uses Transport.CancelRequest. +func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) { + t.Run("TransportCancel", func(t *testing.T) { + f(t, cancelTest{ + mode: mode, + newReq: func(req *Request) *Request { + return req + }, + cancel: func(tr *Transport, req *Request) { + tr.CancelRequest(req) + }, + checkErr: func(when string, err error) { + if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) { + t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err) + } + }, + }) + }) +} + +// runCancelTestChannel uses Request.Cancel. +func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) { + var cancelOnce sync.Once + cancelc := make(chan struct{}) + f(t, cancelTest{ + mode: mode, + newReq: func(req *Request) *Request { + req.Cancel = cancelc + return req + }, + cancel: func(tr *Transport, req *Request) { + cancelOnce.Do(func() { + close(cancelc) + }) + }, + checkErr: func(when string, err error) { + if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) { + t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err) + } + }, + }) +} + +// runCancelTestContext uses a request context. +func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) { + ctx, cancel := context.WithCancel(context.Background()) + f(t, cancelTest{ + mode: mode, + newReq: func(req *Request) *Request { + return req.WithContext(ctx) + }, + cancel: func(tr *Transport, req *Request) { + cancel() + }, + checkErr: func(when string, err error) { + if !errors.Is(err, context.Canceled) { + t.Errorf("%v error = %v, want context.Canceled", when, err) + } + }, + }) +} + +func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) { + run(t, func(t *testing.T, mode testMode) { + if mode == http1Mode { + t.Run("TransportCancel", func(t *testing.T) { + runCancelTestTransport(t, mode, f) + }) + } + t.Run("RequestCancel", func(t *testing.T) { + runCancelTestChannel(t, mode, f) + }) + t.Run("ContextCancel", func(t *testing.T) { + runCancelTestContext(t, mode, f) + }) + }, opts...) +} + +func TestTransportCancelRequest(t *testing.T) { + runCancelTest(t, testTransportCancelRequest) +} +func testTransportCancelRequest(t *testing.T, test cancelTest) { if testing.Short() { t.Skip("skipping test in -short mode") } const msg = "Hello" unblockc := make(chan bool) - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() // send headers and some body <-unblockc @@ -2528,6 +2614,7 @@ func testTransportCancelRequest(t *testing.T, mode testMode) { tr := c.Transport.(*Transport) req, _ := NewRequest("GET", ts.URL, nil) + req = test.newReq(req) res, err := c.Do(req) if err != nil { t.Fatal(err) @@ -2537,13 +2624,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) { if n != len(body) || !bytes.Equal(body, []byte(msg)) { t.Errorf("Body = %q; want %q", body[:n], msg) } - tr.CancelRequest(req) + test.cancel(tr, req) tail, err := io.ReadAll(res.Body) res.Body.Close() - if err != ExportErrRequestCanceled { - t.Errorf("Body.Read error = %v; want errRequestCanceled", err) - } else if len(tail) > 0 { + test.checkErr("Body.Read", err) + if len(tail) > 0 { t.Errorf("Spurious bytes from Body.Read: %q", tail) } @@ -2561,12 +2647,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) { }) } -func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) { +func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) { if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc })).ts defer close(unblockc) @@ -2576,6 +2662,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) donec := make(chan bool) req, _ := NewRequest("GET", ts.URL, body) + req = test.newReq(req) go func() { defer close(donec) c.Do(req) @@ -2583,7 +2670,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) unblockc <- true waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { - tr.CancelRequest(req) + test.cancel(tr, req) select { case <-donec: return true @@ -2597,18 +2684,21 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) } func TestTransportCancelRequestInDo(t *testing.T) { - run(t, func(t *testing.T, mode testMode) { - testTransportCancelRequestInDo(t, mode, nil) - }, []testMode{http1Mode}) + runCancelTest(t, func(t *testing.T, test cancelTest) { + testTransportCancelRequestInDo(t, test, nil) + }) } func TestTransportCancelRequestWithBodyInDo(t *testing.T) { - run(t, func(t *testing.T, mode testMode) { - testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0})) - }, []testMode{http1Mode}) + runCancelTest(t, func(t *testing.T, test cancelTest) { + testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0})) + }) } func TestTransportCancelRequestInDial(t *testing.T) { + runCancelTest(t, testTransportCancelRequestInDial) +} +func testTransportCancelRequestInDial(t *testing.T, test cancelTest) { defer afterTest(t) if testing.Short() { t.Skip("skipping test in -short mode") @@ -2633,17 +2723,19 @@ func TestTransportCancelRequestInDial(t *testing.T) { cl := &Client{Transport: tr} gotres := make(chan bool) req, _ := NewRequest("GET", "http://something.no-network.tld/", nil) + req = test.newReq(req) go func() { _, err := cl.Do(req) - eventLog.Printf("Get = %v", err) + eventLog.Printf("Get error = %v", err != nil) + test.checkErr("Get", err) gotres <- true }() inDial <- true eventLog.Printf("canceling") - tr.CancelRequest(req) - tr.CancelRequest(req) // used to panic on second call + test.cancel(tr, req) + test.cancel(tr, req) // used to panic on second call to Transport.Cancel if d, ok := t.Deadline(); ok { // When the test's deadline is about to expire, log the pending events for @@ -2659,80 +2751,25 @@ func TestTransportCancelRequestInDial(t *testing.T) { got := logbuf.String() want := `dial: blocking canceling -Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection +Get error = true ` if got != want { t.Errorf("Got events:\n%s\nWant:\n%s", got, want) } } -func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) } -func testCancelRequestWithChannel(t *testing.T, mode testMode) { - if testing.Short() { - t.Skip("skipping test in -short mode") - } - - const msg = "Hello" - unblockc := make(chan struct{}) - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { - io.WriteString(w, msg) - w.(Flusher).Flush() // send headers and some body - <-unblockc - })).ts - defer close(unblockc) - - c := ts.Client() - tr := c.Transport.(*Transport) - - req, _ := NewRequest("GET", ts.URL, nil) - cancel := make(chan struct{}) - req.Cancel = cancel - - res, err := c.Do(req) - if err != nil { - t.Fatal(err) - } - body := make([]byte, len(msg)) - n, _ := io.ReadFull(res.Body, body) - if n != len(body) || !bytes.Equal(body, []byte(msg)) { - t.Errorf("Body = %q; want %q", body[:n], msg) - } - close(cancel) - - tail, err := io.ReadAll(res.Body) - res.Body.Close() - if err != ExportErrRequestCanceled { - t.Errorf("Body.Read error = %v; want errRequestCanceled", err) - } else if len(tail) > 0 { - t.Errorf("Spurious bytes from Body.Read: %q", tail) - } - - // Verify no outstanding requests after readLoop/writeLoop - // goroutines shut down. - waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool { - n := tr.NumPendingRequestsForTesting() - if n > 0 { - if d > 0 { - t.Logf("pending requests = %d after %v (want 0)", n, d) - } - return false - } - return true - }) -} - // Issue 51354 -func TestCancelRequestWithBodyWithChannel(t *testing.T) { - run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode}) +func TestTransportCancelRequestWithBody(t *testing.T) { + runCancelTest(t, testTransportCancelRequestWithBody) } -func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) { +func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) { if testing.Short() { t.Skip("skipping test in -short mode") } const msg = "Hello" unblockc := make(chan struct{}) - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() // send headers and some body <-unblockc @@ -2743,8 +2780,7 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) { tr := c.Transport.(*Transport) req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody")) - cancel := make(chan struct{}) - req.Cancel = cancel + req = test.newReq(req) res, err := c.Do(req) if err != nil { @@ -2755,13 +2791,12 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) { if n != len(body) || !bytes.Equal(body, []byte(msg)) { t.Errorf("Body = %q; want %q", body[:n], msg) } - close(cancel) + test.cancel(tr, req) tail, err := io.ReadAll(res.Body) res.Body.Close() - if err != ExportErrRequestCanceled { - t.Errorf("Body.Read error = %v; want errRequestCanceled", err) - } else if len(tail) > 0 { + test.checkErr("Body.Read", err) + if len(tail) > 0 { t.Errorf("Spurious bytes from Body.Read: %q", tail) } @@ -2779,53 +2814,39 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) { }) } -func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { +func TestTransportCancelRequestBeforeDo(t *testing.T) { + // We can't cancel a request that hasn't started using Transport.CancelRequest. run(t, func(t *testing.T, mode testMode) { - testCancelRequestWithChannelBeforeDo(t, mode, false) + t.Run("RequestCancel", func(t *testing.T) { + runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo) + }) + t.Run("ContextCancel", func(t *testing.T) { + runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo) + }) }) } -func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) { - run(t, func(t *testing.T, mode testMode) { - testCancelRequestWithChannelBeforeDo(t, mode, true) - }) -} -func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) { +func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) { unblockc := make(chan bool) - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc - })).ts + })) defer close(unblockc) - c := ts.Client() + c := cst.ts.Client() - req, _ := NewRequest("GET", ts.URL, nil) - if withCtx { - ctx, cancel := context.WithCancel(context.Background()) - cancel() - req = req.WithContext(ctx) - } else { - ch := make(chan struct{}) - req.Cancel = ch - close(ch) - } + req, _ := NewRequest("GET", cst.ts.URL, nil) + req = test.newReq(req) + test.cancel(cst.tr, req) _, err := c.Do(req) - if ue, ok := err.(*url.Error); ok { - err = ue.Err - } - if withCtx { - if err != context.Canceled { - t.Errorf("Do error = %v; want %v", err, context.Canceled) - } - } else { - if err == nil || !strings.Contains(err.Error(), "canceled") { - t.Errorf("Do error = %v; want cancellation", err) - } - } + test.checkErr("Do", err) } // Issue 11020. The returned error message should be errRequestCanceled -func TestTransportCancelBeforeResponseHeaders(t *testing.T) { +func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) { + runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode}) +} +func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) { defer afterTest(t) serverConnCh := make(chan net.Conn, 1) @@ -2839,6 +2860,7 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) { defer tr.CloseIdleConnections() errc := make(chan error, 1) req, _ := NewRequest("GET", "http://example.com/", nil) + req = test.newReq(req) go func() { _, err := tr.RoundTrip(req) errc <- err @@ -2854,15 +2876,13 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) { } defer sc.Close() - tr.CancelRequest(req) + test.cancel(tr, req) err := <-errc if err == nil { t.Fatalf("unexpected success from RoundTrip") } - if err != ExportErrRequestCanceled { - t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err) - } + test.checkErr("RoundTrip", err) } // golang.org/issue/3672 -- Client can't close HTTP stream