diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index 6efe191eb0b..485aff89790 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -603,6 +603,10 @@ func (pc *persistConn) readLoop() { // before we race and peek on the underlying bufio reader. if waitForBodyRead != nil { <-waitForBodyRead + } else if !alive { + // If waitForBodyRead is nil, and we're not alive, we + // must close the connection before we leave the loop. + pc.close() } } } diff --git a/src/pkg/net/http/transport_test.go b/src/pkg/net/http/transport_test.go index a9e401de58d..ebf4a8102d7 100644 --- a/src/pkg/net/http/transport_test.go +++ b/src/pkg/net/http/transport_test.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "io/ioutil" + "net" . "net/http" "net/http/httptest" "net/url" @@ -20,6 +21,7 @@ import ( "runtime" "strconv" "strings" + "sync" "testing" "time" ) @@ -35,6 +37,64 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte(r.RemoteAddr)) }) +type testCloseConn struct { + net.Conn + set *testConnSet +} + +func (conn *testCloseConn) Close() error { + conn.set.remove(conn) + return conn.Conn.Close() +} + +type testConnSet struct { + set map[net.Conn]bool + mutex sync.Mutex +} + +func (tcs *testConnSet) insert(c net.Conn) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + tcs.set[c] = true +} + +func (tcs *testConnSet) remove(c net.Conn) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + // just change to false, so we have a full set of opened connections + tcs.set[c] = false +} + +// some tests use this to manage raw tcp connections for later inspection +func makeTestDial() (*testConnSet, func(n, addr string) (net.Conn, error)) { + connSet := &testConnSet{ + set: make(map[net.Conn]bool), + } + dial := func(n, addr string) (net.Conn, error) { + c, err := net.Dial(n, addr) + if err != nil { + return nil, err + } + tc := &testCloseConn{c, connSet} + connSet.insert(tc) + return tc, nil + } + return connSet, dial +} + +func (tcs *testConnSet) countClosed() (closed, total int) { + tcs.mutex.Lock() + defer tcs.mutex.Unlock() + + total = len(tcs.set) + for _, open := range tcs.set { + if !open { + closed += 1 + } + } + return +} + // Two subsequent requests and verify their response is the same. // The response from the server is our own IP:port func TestTransportKeepAlives(t *testing.T) { @@ -72,8 +132,12 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { ts := httptest.NewServer(hostPortHandler) defer ts.Close() + connSet, testDial := makeTestDial() + for _, connectionClose := range []bool{false, true} { - tr := &Transport{} + tr := &Transport{ + Dial: testDial, + } c := &Client{Transport: tr} fetch := func(n int) string { @@ -107,6 +171,13 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", connectionClose, bodiesDiffer, body1, body2) } + + tr.CloseIdleConnections() + } + + closed, total := connSet.countClosed() + if closed < total { + t.Errorf("%d out of %d tcp connections were not closed", total-closed, total) } } @@ -114,8 +185,12 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { ts := httptest.NewServer(hostPortHandler) defer ts.Close() + connSet, testDial := makeTestDial() + for _, connectionClose := range []bool{false, true} { - tr := &Transport{} + tr := &Transport{ + Dial: testDial, + } c := &Client{Transport: tr} fetch := func(n int) string { @@ -149,6 +224,13 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q", connectionClose, bodiesDiffer, body1, body2) } + + tr.CloseIdleConnections() + } + + closed, total := connSet.countClosed() + if closed < total { + t.Errorf("%d out of %d tcp connections were not closed", total-closed, total) } }