1
0
mirror of https://github.com/golang/go synced 2024-11-23 11:00:08 -07:00

net/http: make Transport.CancelRequest work for requests blocked in a dial

Fixes #6951

LGTM=josharian
R=golang-codereviews, josharian
CC=golang-codereviews
https://golang.org/cl/69280043
This commit is contained in:
Brad Fitzpatrick 2014-02-27 13:32:40 -08:00
parent 28cc8aa89e
commit dc6bf295b9
3 changed files with 98 additions and 27 deletions

View File

@ -21,7 +21,7 @@ var ExportAppendTime = appendTime
func (t *Transport) NumPendingRequestsForTesting() int {
t.reqMu.Lock()
defer t.reqMu.Unlock()
return len(t.reqConn)
return len(t.reqCanceler)
}
func (t *Transport) IdleConnKeysForTesting() (keys []string) {

View File

@ -47,13 +47,13 @@ const DefaultMaxIdleConnsPerHost = 2
// https, and http proxies (for either http or https with CONNECT).
// Transport can also cache connections for future re-use.
type Transport struct {
idleMu sync.Mutex
idleConn map[connectMethodKey][]*persistConn
idleConnCh map[connectMethodKey]chan *persistConn
reqMu sync.Mutex
reqConn map[*Request]*persistConn
altMu sync.RWMutex
altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
idleMu sync.Mutex
idleConn map[connectMethodKey][]*persistConn
idleConnCh map[connectMethodKey]chan *persistConn
reqMu sync.Mutex
reqCanceler map[*Request]func()
altMu sync.RWMutex
altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
@ -190,8 +190,9 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
// host (for http or https), the http proxy, or the http proxy
// pre-CONNECTed to https server. In any case, we'll be ready
// to send it requests.
pconn, err := t.getConn(cm)
pconn, err := t.getConn(req, cm)
if err != nil {
t.setReqCanceler(req, nil)
return nil, err
}
@ -243,10 +244,10 @@ func (t *Transport) CloseIdleConnections() {
// connection.
func (t *Transport) CancelRequest(req *Request) {
t.reqMu.Lock()
pc := t.reqConn[req]
cancel := t.reqCanceler[req]
t.reqMu.Unlock()
if pc != nil {
pc.conn.Close()
if cancel != nil {
cancel()
}
}
@ -417,16 +418,16 @@ func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn) {
}
}
func (t *Transport) setReqConn(r *Request, pc *persistConn) {
func (t *Transport) setReqCanceler(r *Request, fn func()) {
t.reqMu.Lock()
defer t.reqMu.Unlock()
if t.reqConn == nil {
t.reqConn = make(map[*Request]*persistConn)
if t.reqCanceler == nil {
t.reqCanceler = make(map[*Request]func())
}
if pc != nil {
t.reqConn[r] = pc
if fn != nil {
t.reqCanceler[r] = fn
} else {
delete(t.reqConn, r)
delete(t.reqCanceler, r)
}
}
@ -441,7 +442,7 @@ func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
// specified in the connectMethod. This includes doing a proxy CONNECT
// and/or setting up TLS. If this doesn't return an error, the persistConn
// is ready to write requests to.
func (t *Transport) getConn(cm connectMethod) (*persistConn, error) {
func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
if pc := t.getIdleConn(cm); pc != nil {
return pc, nil
}
@ -451,6 +452,16 @@ func (t *Transport) getConn(cm connectMethod) (*persistConn, error) {
err error
}
dialc := make(chan dialRes)
handlePendingDial := func() {
if v := <-dialc; v.err == nil {
t.putIdleConn(v.pc)
}
}
cancelc := make(chan struct{})
t.setReqCanceler(req, func() { close(cancelc) })
go func() {
pc, err := t.dialConn(cm)
dialc <- dialRes{pc, err}
@ -467,12 +478,11 @@ func (t *Transport) getConn(cm connectMethod) (*persistConn, error) {
// else's dial that they didn't use.
// But our dial is still going, so give it away
// when it finishes:
go func() {
if v := <-dialc; v.err == nil {
t.putIdleConn(v.pc)
}
}()
go handlePendingDial()
return pc, nil
case <-cancelc:
go handlePendingDial()
return nil, errors.New("net/http: request canceled while waiting for connection")
}
}
@ -732,6 +742,10 @@ func (pc *persistConn) isBroken() bool {
return b
}
func (pc *persistConn) cancelRequest() {
pc.conn.Close()
}
var remoteSideClosedFunc func(error) bool // or nil to use default
func remoteSideClosed(err error) bool {
@ -843,7 +857,7 @@ func (pc *persistConn) readLoop() {
alive = <-waitForBodyRead
}
pc.t.setReqConn(rc.req, nil)
pc.t.setReqCanceler(rc.req, nil)
if !alive {
pc.close()
@ -910,7 +924,7 @@ var errTimeout error = &httpError{err: "net/http: timeout awaiting response head
var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
pc.t.setReqConn(req.Request, pc)
pc.t.setReqCanceler(req.Request, pc.cancelRequest)
pc.lk.Lock()
pc.numExpectedResponses++
headerFn := pc.mutateHeaderFunc
@ -995,7 +1009,7 @@ WaitResponse:
pc.lk.Unlock()
if re.err != nil {
pc.t.setReqConn(req.Request, nil)
pc.t.setReqCanceler(req.Request, nil)
}
return re.res, re.err
}

View File

@ -11,9 +11,11 @@ import (
"bytes"
"compress/gzip"
"crypto/rand"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
. "net/http"
@ -1321,6 +1323,61 @@ func TestTransportCancelRequest(t *testing.T) {
}
}
func TestTransportCancelRequestInDial(t *testing.T) {
defer afterTest(t)
if testing.Short() {
t.Skip("skipping test in -short mode")
}
var logbuf bytes.Buffer
eventLog := log.New(&logbuf, "", 0)
unblockDial := make(chan bool)
defer close(unblockDial)
inDial := make(chan bool)
tr := &Transport{
Dial: func(network, addr string) (net.Conn, error) {
eventLog.Println("dial: blocking")
inDial <- true
<-unblockDial
return nil, errors.New("nope")
},
}
cl := &Client{Transport: tr}
gotres := make(chan bool)
req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
go func() {
_, err := cl.Do(req)
eventLog.Printf("Get = %v", err)
gotres <- true
}()
select {
case <-inDial:
case <-time.After(5 * time.Second):
t.Fatal("timeout; never saw blocking dial")
}
eventLog.Printf("canceling")
tr.CancelRequest(req)
select {
case <-gotres:
case <-time.After(5 * time.Second):
panic("hang. events are: " + logbuf.String())
t.Fatal("timeout; cancel didn't work?")
}
got := logbuf.String()
want := `dial: blocking
canceling
Get = Get http://something.no-network.tld/: net/http: request canceled while waiting for connection
`
if got != want {
t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
}
}
// golang.org/issue/3672 -- Client can't close HTTP stream
// Calling Close on a Response.Body used to just read until EOF.
// Now it actually closes the TCP connection.