mirror of
https://github.com/golang/go
synced 2024-11-17 07:54:41 -07:00
net/http: make Transport.MaxConnsPerHost work for HTTP/2
Treat HTTP/2 connections as an ongoing persistent connection. When we are told there is no cached connections, cleanup the associated connection and host connection count. Fixes #27753 Change-Id: I6b7bd915fc7819617cb5d3b35e46e225c75eda29 Reviewed-on: https://go-review.googlesource.com/c/go/+/140357 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
c706d42289
commit
43b9fcf6fe
@ -506,8 +506,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
|
||||
var resp *Response
|
||||
if pconn.alt != nil {
|
||||
// HTTP/2 path.
|
||||
t.decHostConnCount(cm.key()) // don't count cached http2 conns toward conns per host
|
||||
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
|
||||
t.putOrCloseIdleConn(pconn)
|
||||
t.setReqCanceler(req, nil) // not cancelable with CancelRequest
|
||||
resp, err = pconn.alt.RoundTrip(req)
|
||||
} else {
|
||||
resp, err = pconn.roundTrip(treq)
|
||||
@ -515,7 +515,10 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
if !pconn.shouldRetryRequest(req, err) {
|
||||
if http2isNoCachedConnError(err) {
|
||||
t.removeIdleConn(pconn)
|
||||
t.decHostConnCount(cm.key()) // clean up the persistent connection
|
||||
} else if !pconn.shouldRetryRequest(req, err) {
|
||||
// Issue 16465: return underlying net.Conn.Read error from peek,
|
||||
// as we've historically done.
|
||||
if e, ok := err.(transportReadFromServerError); ok {
|
||||
@ -778,9 +781,6 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
|
||||
if pconn.isBroken() {
|
||||
return errConnBroken
|
||||
}
|
||||
if pconn.alt != nil {
|
||||
return errNotCachingH2Conn
|
||||
}
|
||||
pconn.markReused()
|
||||
key := pconn.cacheKey
|
||||
|
||||
@ -829,7 +829,10 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
|
||||
if pconn.idleTimer != nil {
|
||||
pconn.idleTimer.Reset(t.IdleConnTimeout)
|
||||
} else {
|
||||
pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
|
||||
// idleTimer does not apply to HTTP/2
|
||||
if pconn.alt == nil {
|
||||
pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, pconn.closeConnIfStillIdle)
|
||||
}
|
||||
}
|
||||
}
|
||||
pconn.idleAt = time.Now()
|
||||
@ -1377,7 +1380,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
|
||||
|
||||
if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
|
||||
if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok {
|
||||
return &persistConn{alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
|
||||
return &persistConn{cacheKey: pconn.cacheKey, alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -588,6 +588,107 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) {
|
||||
<-reqComplete
|
||||
}
|
||||
|
||||
func TestTransportMaxConnsPerHost(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
if runtime.GOOS == "js" {
|
||||
t.Skipf("skipping test on js/wasm")
|
||||
}
|
||||
h := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
_, err := w.Write([]byte("foo"))
|
||||
if err != nil {
|
||||
t.Fatalf("Write: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
testMaxConns := func(scheme string, ts *httptest.Server) {
|
||||
defer ts.Close()
|
||||
|
||||
c := ts.Client()
|
||||
tr := c.Transport.(*Transport)
|
||||
tr.MaxConnsPerHost = 1
|
||||
if err := ExportHttp2ConfigureTransport(tr); err != nil {
|
||||
t.Fatalf("ExportHttp2ConfigureTransport: %v", err)
|
||||
}
|
||||
|
||||
connCh := make(chan net.Conn, 1)
|
||||
var dialCnt, gotConnCnt, tlsHandshakeCnt int32
|
||||
tr.Dial = func(network, addr string) (net.Conn, error) {
|
||||
atomic.AddInt32(&dialCnt, 1)
|
||||
c, err := net.Dial(network, addr)
|
||||
connCh <- c
|
||||
return c, err
|
||||
}
|
||||
|
||||
doReq := func() {
|
||||
trace := &httptrace.ClientTrace{
|
||||
GotConn: func(connInfo httptrace.GotConnInfo) {
|
||||
if !connInfo.Reused {
|
||||
atomic.AddInt32(&gotConnCnt, 1)
|
||||
}
|
||||
},
|
||||
TLSHandshakeStart: func() {
|
||||
atomic.AddInt32(&tlsHandshakeCnt, 1)
|
||||
},
|
||||
}
|
||||
req, _ := NewRequest("GET", ts.URL, nil)
|
||||
req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
|
||||
|
||||
resp, err := c.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
_, err = ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read body failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 10; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
doReq()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
expected := int32(tr.MaxConnsPerHost)
|
||||
if dialCnt != expected {
|
||||
t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
|
||||
}
|
||||
if gotConnCnt != expected {
|
||||
t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
|
||||
}
|
||||
if ts.TLS != nil && tlsHandshakeCnt != expected {
|
||||
t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
|
||||
}
|
||||
|
||||
(<-connCh).Close()
|
||||
|
||||
doReq()
|
||||
expected++
|
||||
if dialCnt != expected {
|
||||
t.Errorf("Too many dials (%s): %d", scheme, dialCnt)
|
||||
}
|
||||
if gotConnCnt != expected {
|
||||
t.Errorf("Too many get connections (%s): %d", scheme, gotConnCnt)
|
||||
}
|
||||
if ts.TLS != nil && tlsHandshakeCnt != expected {
|
||||
t.Errorf("Too many tls handshakes (%s): %d", scheme, tlsHandshakeCnt)
|
||||
}
|
||||
}
|
||||
|
||||
testMaxConns("http", httptest.NewServer(h))
|
||||
testMaxConns("https", httptest.NewTLSServer(h))
|
||||
|
||||
ts := httptest.NewUnstartedServer(h)
|
||||
ts.TLS = &tls.Config{NextProtos: []string{"h2"}}
|
||||
ts.StartTLS()
|
||||
testMaxConns("http2", ts)
|
||||
}
|
||||
|
||||
func TestTransportRemovesDeadIdleConnections(t *testing.T) {
|
||||
setParallel(t)
|
||||
defer afterTest(t)
|
||||
|
Loading…
Reference in New Issue
Block a user