mirror of
https://github.com/golang/go
synced 2024-11-23 11:00:08 -07:00
net/http: make Transport.CancelRequest work for requests blocked in a dial
Fixes #6951 LGTM=josharian R=golang-codereviews, josharian CC=golang-codereviews https://golang.org/cl/69280043
This commit is contained in:
parent
28cc8aa89e
commit
dc6bf295b9
@ -21,7 +21,7 @@ var ExportAppendTime = appendTime
|
||||
func (t *Transport) NumPendingRequestsForTesting() int {
|
||||
t.reqMu.Lock()
|
||||
defer t.reqMu.Unlock()
|
||||
return len(t.reqConn)
|
||||
return len(t.reqCanceler)
|
||||
}
|
||||
|
||||
func (t *Transport) IdleConnKeysForTesting() (keys []string) {
|
||||
|
@ -47,13 +47,13 @@ const DefaultMaxIdleConnsPerHost = 2
|
||||
// https, and http proxies (for either http or https with CONNECT).
|
||||
// Transport can also cache connections for future re-use.
|
||||
type Transport struct {
|
||||
idleMu sync.Mutex
|
||||
idleConn map[connectMethodKey][]*persistConn
|
||||
idleConnCh map[connectMethodKey]chan *persistConn
|
||||
reqMu sync.Mutex
|
||||
reqConn map[*Request]*persistConn
|
||||
altMu sync.RWMutex
|
||||
altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
|
||||
idleMu sync.Mutex
|
||||
idleConn map[connectMethodKey][]*persistConn
|
||||
idleConnCh map[connectMethodKey]chan *persistConn
|
||||
reqMu sync.Mutex
|
||||
reqCanceler map[*Request]func()
|
||||
altMu sync.RWMutex
|
||||
altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
|
||||
|
||||
// Proxy specifies a function to return a proxy for a given
|
||||
// Request. If the function returns a non-nil error, the
|
||||
@ -190,8 +190,9 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
|
||||
// host (for http or https), the http proxy, or the http proxy
|
||||
// pre-CONNECTed to https server. In any case, we'll be ready
|
||||
// to send it requests.
|
||||
pconn, err := t.getConn(cm)
|
||||
pconn, err := t.getConn(req, cm)
|
||||
if err != nil {
|
||||
t.setReqCanceler(req, nil)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -243,10 +244,10 @@ func (t *Transport) CloseIdleConnections() {
|
||||
// connection.
|
||||
func (t *Transport) CancelRequest(req *Request) {
|
||||
t.reqMu.Lock()
|
||||
pc := t.reqConn[req]
|
||||
cancel := t.reqCanceler[req]
|
||||
t.reqMu.Unlock()
|
||||
if pc != nil {
|
||||
pc.conn.Close()
|
||||
if cancel != nil {
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
@ -417,16 +418,16 @@ func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Transport) setReqConn(r *Request, pc *persistConn) {
|
||||
func (t *Transport) setReqCanceler(r *Request, fn func()) {
|
||||
t.reqMu.Lock()
|
||||
defer t.reqMu.Unlock()
|
||||
if t.reqConn == nil {
|
||||
t.reqConn = make(map[*Request]*persistConn)
|
||||
if t.reqCanceler == nil {
|
||||
t.reqCanceler = make(map[*Request]func())
|
||||
}
|
||||
if pc != nil {
|
||||
t.reqConn[r] = pc
|
||||
if fn != nil {
|
||||
t.reqCanceler[r] = fn
|
||||
} else {
|
||||
delete(t.reqConn, r)
|
||||
delete(t.reqCanceler, r)
|
||||
}
|
||||
}
|
||||
|
||||
@ -441,7 +442,7 @@ func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
|
||||
// specified in the connectMethod. This includes doing a proxy CONNECT
|
||||
// and/or setting up TLS. If this doesn't return an error, the persistConn
|
||||
// is ready to write requests to.
|
||||
func (t *Transport) getConn(cm connectMethod) (*persistConn, error) {
|
||||
func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
|
||||
if pc := t.getIdleConn(cm); pc != nil {
|
||||
return pc, nil
|
||||
}
|
||||
@ -451,6 +452,16 @@ func (t *Transport) getConn(cm connectMethod) (*persistConn, error) {
|
||||
err error
|
||||
}
|
||||
dialc := make(chan dialRes)
|
||||
|
||||
handlePendingDial := func() {
|
||||
if v := <-dialc; v.err == nil {
|
||||
t.putIdleConn(v.pc)
|
||||
}
|
||||
}
|
||||
|
||||
cancelc := make(chan struct{})
|
||||
t.setReqCanceler(req, func() { close(cancelc) })
|
||||
|
||||
go func() {
|
||||
pc, err := t.dialConn(cm)
|
||||
dialc <- dialRes{pc, err}
|
||||
@ -467,12 +478,11 @@ func (t *Transport) getConn(cm connectMethod) (*persistConn, error) {
|
||||
// else's dial that they didn't use.
|
||||
// But our dial is still going, so give it away
|
||||
// when it finishes:
|
||||
go func() {
|
||||
if v := <-dialc; v.err == nil {
|
||||
t.putIdleConn(v.pc)
|
||||
}
|
||||
}()
|
||||
go handlePendingDial()
|
||||
return pc, nil
|
||||
case <-cancelc:
|
||||
go handlePendingDial()
|
||||
return nil, errors.New("net/http: request canceled while waiting for connection")
|
||||
}
|
||||
}
|
||||
|
||||
@ -732,6 +742,10 @@ func (pc *persistConn) isBroken() bool {
|
||||
return b
|
||||
}
|
||||
|
||||
func (pc *persistConn) cancelRequest() {
|
||||
pc.conn.Close()
|
||||
}
|
||||
|
||||
var remoteSideClosedFunc func(error) bool // or nil to use default
|
||||
|
||||
func remoteSideClosed(err error) bool {
|
||||
@ -843,7 +857,7 @@ func (pc *persistConn) readLoop() {
|
||||
alive = <-waitForBodyRead
|
||||
}
|
||||
|
||||
pc.t.setReqConn(rc.req, nil)
|
||||
pc.t.setReqCanceler(rc.req, nil)
|
||||
|
||||
if !alive {
|
||||
pc.close()
|
||||
@ -910,7 +924,7 @@ var errTimeout error = &httpError{err: "net/http: timeout awaiting response head
|
||||
var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
|
||||
|
||||
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
|
||||
pc.t.setReqConn(req.Request, pc)
|
||||
pc.t.setReqCanceler(req.Request, pc.cancelRequest)
|
||||
pc.lk.Lock()
|
||||
pc.numExpectedResponses++
|
||||
headerFn := pc.mutateHeaderFunc
|
||||
@ -995,7 +1009,7 @@ WaitResponse:
|
||||
pc.lk.Unlock()
|
||||
|
||||
if re.err != nil {
|
||||
pc.t.setReqConn(req.Request, nil)
|
||||
pc.t.setReqCanceler(req.Request, nil)
|
||||
}
|
||||
return re.res, re.err
|
||||
}
|
||||
|
@ -11,9 +11,11 @@ import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
. "net/http"
|
||||
@ -1321,6 +1323,61 @@ func TestTransportCancelRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransportCancelRequestInDial(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
if testing.Short() {
|
||||
t.Skip("skipping test in -short mode")
|
||||
}
|
||||
var logbuf bytes.Buffer
|
||||
eventLog := log.New(&logbuf, "", 0)
|
||||
|
||||
unblockDial := make(chan bool)
|
||||
defer close(unblockDial)
|
||||
|
||||
inDial := make(chan bool)
|
||||
tr := &Transport{
|
||||
Dial: func(network, addr string) (net.Conn, error) {
|
||||
eventLog.Println("dial: blocking")
|
||||
inDial <- true
|
||||
<-unblockDial
|
||||
return nil, errors.New("nope")
|
||||
},
|
||||
}
|
||||
cl := &Client{Transport: tr}
|
||||
gotres := make(chan bool)
|
||||
req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
|
||||
go func() {
|
||||
_, err := cl.Do(req)
|
||||
eventLog.Printf("Get = %v", err)
|
||||
gotres <- true
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-inDial:
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout; never saw blocking dial")
|
||||
}
|
||||
|
||||
eventLog.Printf("canceling")
|
||||
tr.CancelRequest(req)
|
||||
|
||||
select {
|
||||
case <-gotres:
|
||||
case <-time.After(5 * time.Second):
|
||||
panic("hang. events are: " + logbuf.String())
|
||||
t.Fatal("timeout; cancel didn't work?")
|
||||
}
|
||||
|
||||
got := logbuf.String()
|
||||
want := `dial: blocking
|
||||
canceling
|
||||
Get = Get http://something.no-network.tld/: net/http: request canceled while waiting for connection
|
||||
`
|
||||
if got != want {
|
||||
t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
// golang.org/issue/3672 -- Client can't close HTTP stream
|
||||
// Calling Close on a Response.Body used to just read until EOF.
|
||||
// Now it actually closes the TCP connection.
|
||||
|
Loading…
Reference in New Issue
Block a user