diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index 7c7b5d56675..7cdb51b05b0 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -133,9 +133,11 @@ func (t *Transport) IdleConnStrsForTesting_h2() []string { return ret } -func (t *Transport) IdleConnCountForTesting(cacheKey string) int { +func (t *Transport) IdleConnCountForTesting(scheme, addr string) int { t.idleMu.Lock() defer t.idleMu.Unlock() + key := connectMethodKey{"", scheme, addr} + cacheKey := key.String() for k, conns := range t.idleConn { if k.String() == cacheKey { return len(conns) @@ -160,13 +162,19 @@ func (t *Transport) RequestIdleConnChForTesting() { t.getIdleConnCh(connectMethod{nil, "http", "example.com"}) } -func (t *Transport) PutIdleTestConn() bool { +func (t *Transport) PutIdleTestConn(scheme, addr string) bool { c, _ := net.Pipe() + key := connectMethodKey{"", scheme, addr} + select { + case <-t.incHostConnCount(key): + default: + return false + } return t.tryPutIdleConn(&persistConn{ t: t, conn: c, // dummy closech: make(chan struct{}), // so it can be closed - cacheKey: connectMethodKey{"", "http", "example.com"}, + cacheKey: key, }) == nil } diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 59bffd0ae89..182390cf017 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -55,6 +55,15 @@ var DefaultTransport RoundTripper = &Transport{ // MaxIdleConnsPerHost. const DefaultMaxIdleConnsPerHost = 2 +// connsPerHostClosedCh is a closed channel used by MaxConnsPerHost +// for the property that receives from a closed channel return the +// zero value. +var connsPerHostClosedCh = make(chan struct{}) + +func init() { + close(connsPerHostClosedCh) +} + // Transport is an implementation of RoundTripper that supports HTTP, // HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT). // @@ -103,6 +112,10 @@ type Transport struct { altMu sync.Mutex // guards changing altProto only altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme + connCountMu sync.Mutex + connPerHostCount map[connectMethodKey]int + connPerHostAvailable map[connectMethodKey]chan struct{} + // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the // request is aborted with the provided error. @@ -183,6 +196,18 @@ type Transport struct { // DefaultMaxIdleConnsPerHost is used. MaxIdleConnsPerHost int + // MaxConnsPerHost optionally limits the total number of + // connections per host, including connections in the dialing, + // active, and idle states. On limit violation, dials will block. + // + // Zero means no limit. + // + // For HTTP/2, this currently only controls the number of new + // connections being created at a time, instead of the total + // number. In practice, hosts using HTTP/2 only have about one + // idle connection, though. + MaxConnsPerHost int + // IdleConnTimeout is the maximum amount of time an idle // (keep-alive) connection will remain idle before closing // itself. @@ -231,8 +256,6 @@ type Transport struct { // h2transport (via onceSetNextProtoDefaults) nextProtoOnce sync.Once h2transport *http2Transport // non-nil if http2 wired up - - // TODO: tunable on max per-host TCP dials in flight (Issue 13957) } // onceSetNextProtoDefaults initializes TLSNextProto. @@ -409,7 +432,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { var resp *Response if pconn.alt != nil { // HTTP/2 path. - t.setReqCanceler(req, nil) // not cancelable with CancelRequest + t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host + t.setReqCanceler(req, nil) // not cancelable with CancelRequest resp, err = pconn.alt.RoundTrip(req) } else { resp, err = pconn.roundTrip(treq) @@ -908,6 +932,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC err error } dialc := make(chan dialRes) + cmKey := cm.key() // Copy these hooks so we don't race on the postPendingDial in // the goroutine we launch. Issue 11136. @@ -919,6 +944,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC go func() { if v := <-dialc; v.err == nil { t.putOrCloseIdleConn(v.pc) + } else { + t.decHostConnCount(cmKey) } testHookPostPendingDial() }() @@ -927,6 +954,27 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC cancelc := make(chan error, 1) t.setReqCanceler(req, func(err error) { cancelc <- err }) + if t.MaxConnsPerHost > 0 { + select { + case <-t.incHostConnCount(cmKey): + // count below conn per host limit; proceed + case pc := <-t.getIdleConnCh(cm): + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{Conn: pc.conn, Reused: pc.isReused()}) + } + return pc, nil + case <-req.Cancel: + return nil, errRequestCanceledConn + case <-req.Context().Done(): + return nil, req.Context().Err() + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err + } + } + go func() { pc, err := t.dialConn(ctx, cm) dialc <- dialRes{pc, err} @@ -944,6 +992,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC } // Our dial failed. See why to return a nicer error // value. + t.decHostConnCount(cmKey) select { case <-req.Cancel: // It was an error due to cancelation, so prioritize that @@ -987,6 +1036,83 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC } } +// incHostConnCount increments the count of connections for a +// given host. It returns an already-closed channel if the count +// is not at its limit; otherwise it returns a channel which is +// notified when the count is below the limit. +func (t *Transport) incHostConnCount(cmKey connectMethodKey) <-chan struct{} { + if t.MaxConnsPerHost <= 0 { + return connsPerHostClosedCh + } + t.connCountMu.Lock() + defer t.connCountMu.Unlock() + if t.connPerHostCount[cmKey] == t.MaxConnsPerHost { + if t.connPerHostAvailable == nil { + t.connPerHostAvailable = make(map[connectMethodKey]chan struct{}) + } + ch, ok := t.connPerHostAvailable[cmKey] + if !ok { + ch = make(chan struct{}) + t.connPerHostAvailable[cmKey] = ch + } + return ch + } + if t.connPerHostCount == nil { + t.connPerHostCount = make(map[connectMethodKey]int) + } + t.connPerHostCount[cmKey]++ + // return a closed channel to avoid race: if decHostConnCount is called + // after incHostConnCount and during the nil check, decHostConnCount + // will delete the channel since it's not being listened on yet. + return connsPerHostClosedCh +} + +// decHostConnCount decrements the count of connections +// for a given host. +// See Transport.MaxConnsPerHost. +func (t *Transport) decHostConnCount(cmKey connectMethodKey) { + if t.MaxConnsPerHost <= 0 { + return + } + t.connCountMu.Lock() + defer t.connCountMu.Unlock() + t.connPerHostCount[cmKey]-- + select { + case t.connPerHostAvailable[cmKey] <- struct{}{}: + default: + // close channel before deleting avoids getConn waiting forever in + // case getConn has reference to channel but hasn't started waiting. + // This could lead to more than MaxConnsPerHost in the unlikely case + // that > 1 go routine has fetched the channel but none started waiting. + if t.connPerHostAvailable[cmKey] != nil { + close(t.connPerHostAvailable[cmKey]) + } + delete(t.connPerHostAvailable, cmKey) + } + if t.connPerHostCount[cmKey] == 0 { + delete(t.connPerHostCount, cmKey) + } +} + +// connCloseListener wraps a connection, the transport that dialed it +// and the connected-to host key so the host connection count can be +// transparently decremented by whatever closes the embedded connection. +type connCloseListener struct { + net.Conn + t *Transport + cmKey connectMethodKey + didClose int32 +} + +func (c *connCloseListener) Close() error { + if atomic.AddInt32(&c.didClose, 1) != 1 { + return nil + } + err := c.Conn.Close() + c.t.decHostConnCount(c.cmKey) + return err +} + // The connect method and the transport can both specify a TLS // Host name. The transport's name takes precedence if present. func chooseTLSHost(cm connectMethod, t *Transport) string { @@ -1184,6 +1310,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon } } + if t.MaxConnsPerHost > 0 { + pconn.conn = &connCloseListener{Conn: pconn.conn, t: t, cmKey: pconn.cacheKey} + } pconn.br = bufio.NewReader(pconn) pconn.bw = bufio.NewWriter(persistConnWriter{pconn}) go pconn.readLoop() diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 87361e81ca5..5145da0ae07 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -446,27 +446,95 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { if e, g := 1, len(keys); e != g { t.Fatalf("after first response, expected %d idle conn cache keys; got %d", e, g) } - cacheKey := "|http|" + ts.Listener.Addr().String() + addr := ts.Listener.Addr().String() + cacheKey := "|http|" + addr if keys[0] != cacheKey { t.Fatalf("Expected idle cache key %q; got %q", cacheKey, keys[0]) } - if e, g := 1, tr.IdleConnCountForTesting(cacheKey); e != g { + if e, g := 1, tr.IdleConnCountForTesting("http", addr); e != g { t.Errorf("after first response, expected %d idle conns; got %d", e, g) } resch <- "res2" <-donech - if g, w := tr.IdleConnCountForTesting(cacheKey), 2; g != w { + if g, w := tr.IdleConnCountForTesting("http", addr), 2; g != w { t.Errorf("after second response, idle conns = %d; want %d", g, w) } resch <- "res3" <-donech - if g, w := tr.IdleConnCountForTesting(cacheKey), maxIdleConnsPerHost; g != w { + if g, w := tr.IdleConnCountForTesting("http", addr), maxIdleConnsPerHost; g != w { t.Errorf("after third response, idle conns = %d; want %d", g, w) } } +func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + _, err := w.Write([]byte("foo")) + if err != nil { + t.Fatalf("Write: %v", err) + } + })) + defer ts.Close() + c := ts.Client() + tr := c.Transport.(*Transport) + dialStarted := make(chan struct{}) + stallDial := make(chan struct{}) + tr.Dial = func(network, addr string) (net.Conn, error) { + dialStarted <- struct{}{} + <-stallDial + return net.Dial(network, addr) + } + + tr.DisableKeepAlives = true + tr.MaxConnsPerHost = 1 + + preDial := make(chan struct{}) + reqComplete := make(chan struct{}) + doReq := func(reqId string) { + req, _ := NewRequest("GET", ts.URL, nil) + trace := &httptrace.ClientTrace{ + GetConn: func(hostPort string) { + preDial <- struct{}{} + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + resp, err := tr.RoundTrip(req) + if err != nil { + t.Errorf("unexpected error for request %s: %v", reqId, err) + } + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("unexpected error for request %s: %v", reqId, err) + } + reqComplete <- struct{}{} + } + // get req1 to dial-in-progress + go doReq("req1") + <-preDial + <-dialStarted + + // get req2 to waiting on conns per host to go down below max + go doReq("req2") + <-preDial + select { + case <-dialStarted: + t.Error("req2 dial started while req1 dial in progress") + return + default: + } + + // let req1 complete + stallDial <- struct{}{} + <-reqComplete + + // let req2 complete + <-dialStarted + stallDial <- struct{}{} + <-reqComplete +} + func TestTransportRemovesDeadIdleConnections(t *testing.T) { setParallel(t) defer afterTest(t) @@ -3118,7 +3186,7 @@ func TestRoundTripReturnsProxyError(t *testing.T) { func TestTransportCloseIdleConnsThenReturn(t *testing.T) { tr := &Transport{} wantIdle := func(when string, n int) bool { - got := tr.IdleConnCountForTesting("|http|example.com") // key used by PutIdleTestConn + got := tr.IdleConnCountForTesting("http", "example.com") // key used by PutIdleTestConn if got == n { return true } @@ -3126,10 +3194,10 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) { return false } wantIdle("start", 0) - if !tr.PutIdleTestConn() { + if !tr.PutIdleTestConn("http", "example.com") { t.Fatal("put failed") } - if !tr.PutIdleTestConn() { + if !tr.PutIdleTestConn("http", "example.com") { t.Fatal("second put failed") } wantIdle("after put", 2) @@ -3138,7 +3206,7 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) { t.Error("should be idle after CloseIdleConnections") } wantIdle("after close idle", 0) - if tr.PutIdleTestConn() { + if tr.PutIdleTestConn("http", "example.com") { t.Fatal("put didn't fail") } wantIdle("after second put", 0) @@ -3147,7 +3215,7 @@ func TestTransportCloseIdleConnsThenReturn(t *testing.T) { if tr.IsIdleForTesting() { t.Error("shouldn't be idle after RequestIdleConnChForTesting") } - if !tr.PutIdleTestConn() { + if !tr.PutIdleTestConn("http", "example.com") { t.Fatal("after re-activation") } wantIdle("after final put", 1)