diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index 69757bdca6..b656aa9731 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -82,6 +82,10 @@ func SetInstallConnClosedHook(f func()) { testHookPersistConnClosedGotRes = f } +func SetEnterRoundTripHook(f func()) { + testHookEnterRoundTrip = f +} + func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { f := func() <-chan time.Time { return ch diff --git a/src/net/http/transport.go b/src/net/http/transport.go index b754472be6..e31ae93e2a 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -475,6 +475,25 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) { } } +// replaceReqCanceler replaces an existing cancel function. If there is no cancel function +// for the request, we don't set the function and return false. +// Since CancelRequest will clear the canceler, we can use the return value to detect if +// the request was canceled since the last setReqCancel call. +func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool { + t.reqMu.Lock() + defer t.reqMu.Unlock() + _, ok := t.reqCanceler[r] + if !ok { + return false + } + if fn != nil { + t.reqCanceler[r] = fn + } else { + delete(t.reqCanceler, r) + } + return true +} + func (t *Transport) dial(network, addr string) (c net.Conn, err error) { if t.Dial != nil { return t.Dial(network, addr) @@ -491,6 +510,10 @@ var prePendingDial, postPendingDial func() // is ready to write requests to. func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) { if pc := t.getIdleConn(cm); pc != nil { + // set request canceler to some non-nil function so we + // can detect whether it was cleared between now and when + // we enter roundTrip + t.setReqCanceler(req, func() {}) return pc, nil } @@ -1063,10 +1086,20 @@ var errTimeout error = &httpError{err: "net/http: timeout awaiting response head var errClosed error = &httpError{err: "net/http: transport closed before response was received"} var errRequestCanceled = errors.New("net/http: request canceled") -var testHookPersistConnClosedGotRes func() // nil except for tests +// nil except for tests +var ( + testHookPersistConnClosedGotRes func() + testHookEnterRoundTrip func() +) func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { - pc.t.setReqCanceler(req.Request, pc.cancelRequest) + if hook := testHookEnterRoundTrip; hook != nil { + hook() + } + if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) { + pc.t.putIdleConn(pc) + return nil, errRequestCanceled + } pc.lk.Lock() pc.numExpectedResponses++ headerFn := pc.mutateHeaderFunc diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 2d52f17721..d20ba13208 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -2400,6 +2400,32 @@ func TestTransportResponseCancelRace(t *testing.T) { res.Body.Close() } +func TestTransportDialCancelRace(t *testing.T) { + defer afterTest(t) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + SetEnterRoundTripHook(func() { + tr.CancelRequest(req) + }) + defer SetEnterRoundTripHook(nil) + res, err := tr.RoundTrip(req) + if err != ExportErrRequestCanceled { + t.Errorf("expected canceled request error; got %v", err) + if err == nil { + res.Body.Close() + } + } +} + func wantBody(res *http.Response, err error, want string) error { if err != nil { return err