mirror of
https://github.com/golang/go
synced 2024-11-18 14:54:40 -07:00
net/http: fix race between dialing and canceling
In the brief window between getConn and persistConn.roundTrip, a cancel could end up going missing. Fix by making it possible to inspect if a cancel function was cleared and checking if we were canceled before entering roundTrip. Fixes #10511 Change-Id: If6513e63fbc2edb703e36d6356ccc95a1dc33144 Reviewed-on: https://go-review.googlesource.com/9181 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
5fa2d9915f
commit
723f86537c
@ -82,6 +82,10 @@ func SetInstallConnClosedHook(f func()) {
|
||||
testHookPersistConnClosedGotRes = f
|
||||
}
|
||||
|
||||
func SetEnterRoundTripHook(f func()) {
|
||||
testHookEnterRoundTrip = f
|
||||
}
|
||||
|
||||
func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
|
||||
f := func() <-chan time.Time {
|
||||
return ch
|
||||
|
@ -475,6 +475,25 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) {
|
||||
}
|
||||
}
|
||||
|
||||
// replaceReqCanceler replaces an existing cancel function. If there is no cancel function
|
||||
// for the request, we don't set the function and return false.
|
||||
// Since CancelRequest will clear the canceler, we can use the return value to detect if
|
||||
// the request was canceled since the last setReqCancel call.
|
||||
func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
|
||||
t.reqMu.Lock()
|
||||
defer t.reqMu.Unlock()
|
||||
_, ok := t.reqCanceler[r]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if fn != nil {
|
||||
t.reqCanceler[r] = fn
|
||||
} else {
|
||||
delete(t.reqCanceler, r)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
|
||||
if t.Dial != nil {
|
||||
return t.Dial(network, addr)
|
||||
@ -491,6 +510,10 @@ var prePendingDial, postPendingDial func()
|
||||
// is ready to write requests to.
|
||||
func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
|
||||
if pc := t.getIdleConn(cm); pc != nil {
|
||||
// set request canceler to some non-nil function so we
|
||||
// can detect whether it was cleared between now and when
|
||||
// we enter roundTrip
|
||||
t.setReqCanceler(req, func() {})
|
||||
return pc, nil
|
||||
}
|
||||
|
||||
@ -1063,10 +1086,20 @@ var errTimeout error = &httpError{err: "net/http: timeout awaiting response head
|
||||
var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
|
||||
var errRequestCanceled = errors.New("net/http: request canceled")
|
||||
|
||||
var testHookPersistConnClosedGotRes func() // nil except for tests
|
||||
// nil except for tests
|
||||
var (
|
||||
testHookPersistConnClosedGotRes func()
|
||||
testHookEnterRoundTrip func()
|
||||
)
|
||||
|
||||
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
|
||||
pc.t.setReqCanceler(req.Request, pc.cancelRequest)
|
||||
if hook := testHookEnterRoundTrip; hook != nil {
|
||||
hook()
|
||||
}
|
||||
if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
|
||||
pc.t.putIdleConn(pc)
|
||||
return nil, errRequestCanceled
|
||||
}
|
||||
pc.lk.Lock()
|
||||
pc.numExpectedResponses++
|
||||
headerFn := pc.mutateHeaderFunc
|
||||
|
@ -2400,6 +2400,32 @@ func TestTransportResponseCancelRace(t *testing.T) {
|
||||
res.Body.Close()
|
||||
}
|
||||
|
||||
func TestTransportDialCancelRace(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
|
||||
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
|
||||
defer ts.Close()
|
||||
|
||||
tr := &Transport{}
|
||||
defer tr.CloseIdleConnections()
|
||||
|
||||
req, err := NewRequest("GET", ts.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
SetEnterRoundTripHook(func() {
|
||||
tr.CancelRequest(req)
|
||||
})
|
||||
defer SetEnterRoundTripHook(nil)
|
||||
res, err := tr.RoundTrip(req)
|
||||
if err != ExportErrRequestCanceled {
|
||||
t.Errorf("expected canceled request error; got %v", err)
|
||||
if err == nil {
|
||||
res.Body.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func wantBody(res *http.Response, err error, want string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
|
Loading…
Reference in New Issue
Block a user