diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go index 4f85ff55d8..1c0c0f6987 100644 --- a/src/net/http/httptest/server.go +++ b/src/net/http/httptest/server.go @@ -317,21 +317,17 @@ func (s *Server) wrap() { s.mu.Lock() defer s.mu.Unlock() - // Keep Close from returning until the user's ConnState hook - // (if any) finishes. Without this, the call to forgetConn - // below might send the count to 0 before we run the hook. - s.wg.Add(1) - defer s.wg.Done() - switch cs { case http.StateNew: - s.wg.Add(1) if _, exists := s.conns[c]; exists { panic("invalid state transition") } if s.conns == nil { s.conns = make(map[net.Conn]http.ConnState) } + // Add c to the set of tracked conns and increment it to the + // waitgroup. + s.wg.Add(1) s.conns[c] = cs if s.closed { // Probably just a socket-late-binding dial from @@ -358,7 +354,14 @@ func (s *Server) wrap() { s.closeConn(c) } case http.StateHijacked, http.StateClosed: - s.forgetConn(c) + // Remove c from the set of tracked conns and decrement it from the + // waitgroup, unless it was previously removed. + if _, ok := s.conns[c]; ok { + delete(s.conns, c) + // Keep Close from returning until the user's ConnState hook + // (if any) finishes. + defer s.wg.Done() + } } if oldHook != nil { oldHook(c, cs) @@ -378,13 +381,3 @@ func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) { done <- struct{}{} } } - -// forgetConn removes c from the set of tracked conns and decrements it from the -// waitgroup, unless it was previously removed. -// s.mu must be held. -func (s *Server) forgetConn(c net.Conn) { - if _, ok := s.conns[c]; ok { - delete(s.conns, c) - s.wg.Done() - } -} diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go index 39568b358c..5313f65456 100644 --- a/src/net/http/httptest/server_test.go +++ b/src/net/http/httptest/server_test.go @@ -9,6 +9,7 @@ import ( "io" "net" "net/http" + "sync" "testing" ) @@ -203,6 +204,59 @@ func TestServerZeroValueClose(t *testing.T) { ts.Close() // tests that it doesn't panic } +// Issue 51799: test hijacking a connection and then closing it +// concurrently with closing the server. +func TestCloseHijackedConnection(t *testing.T) { + hijacked := make(chan net.Conn) + ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer close(hijacked) + hj, ok := w.(http.Hijacker) + if !ok { + t.Fatal("failed to hijack") + } + c, _, err := hj.Hijack() + if err != nil { + t.Fatal(err) + } + hijacked <- c + })) + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + req, err := http.NewRequest("GET", ts.URL, nil) + if err != nil { + t.Log(err) + } + // Use a client not associated with the Server. + var c http.Client + resp, err := c.Do(req) + if err != nil { + t.Log(err) + return + } + resp.Body.Close() + }() + + wg.Add(1) + conn := <-hijacked + go func(conn net.Conn) { + defer wg.Done() + // Close the connection and then inform the Server that + // we closed it. + conn.Close() + ts.Config.ConnState(conn, http.StateClosed) + }(conn) + + wg.Add(1) + go func() { + defer wg.Done() + ts.Close() + }() + wg.Wait() +} + func TestTLSServerWithHTTP2(t *testing.T) { modes := []struct { name string