mirror of
https://github.com/golang/go
synced 2024-11-24 10:50:13 -07:00
net/http/httptest: fix race in Server.Close
When run with race detector the test fails without the fix.
Fixes #51799
Change-Id: I273adb6d3a2b1e0d606b9c27ab4c6a9aa4aa8064
GitHub-Last-Rev: a5ddd146a2
GitHub-Pull-Request: golang/go#51805
Reviewed-on: https://go-review.googlesource.com/c/go/+/393974
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
3fd8b8627f
commit
1d19cea740
@ -317,21 +317,17 @@ func (s *Server) wrap() {
|
|||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
// Keep Close from returning until the user's ConnState hook
|
|
||||||
// (if any) finishes. Without this, the call to forgetConn
|
|
||||||
// below might send the count to 0 before we run the hook.
|
|
||||||
s.wg.Add(1)
|
|
||||||
defer s.wg.Done()
|
|
||||||
|
|
||||||
switch cs {
|
switch cs {
|
||||||
case http.StateNew:
|
case http.StateNew:
|
||||||
s.wg.Add(1)
|
|
||||||
if _, exists := s.conns[c]; exists {
|
if _, exists := s.conns[c]; exists {
|
||||||
panic("invalid state transition")
|
panic("invalid state transition")
|
||||||
}
|
}
|
||||||
if s.conns == nil {
|
if s.conns == nil {
|
||||||
s.conns = make(map[net.Conn]http.ConnState)
|
s.conns = make(map[net.Conn]http.ConnState)
|
||||||
}
|
}
|
||||||
|
// Add c to the set of tracked conns and increment it to the
|
||||||
|
// waitgroup.
|
||||||
|
s.wg.Add(1)
|
||||||
s.conns[c] = cs
|
s.conns[c] = cs
|
||||||
if s.closed {
|
if s.closed {
|
||||||
// Probably just a socket-late-binding dial from
|
// Probably just a socket-late-binding dial from
|
||||||
@ -358,7 +354,14 @@ func (s *Server) wrap() {
|
|||||||
s.closeConn(c)
|
s.closeConn(c)
|
||||||
}
|
}
|
||||||
case http.StateHijacked, http.StateClosed:
|
case http.StateHijacked, http.StateClosed:
|
||||||
s.forgetConn(c)
|
// Remove c from the set of tracked conns and decrement it from the
|
||||||
|
// waitgroup, unless it was previously removed.
|
||||||
|
if _, ok := s.conns[c]; ok {
|
||||||
|
delete(s.conns, c)
|
||||||
|
// Keep Close from returning until the user's ConnState hook
|
||||||
|
// (if any) finishes.
|
||||||
|
defer s.wg.Done()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if oldHook != nil {
|
if oldHook != nil {
|
||||||
oldHook(c, cs)
|
oldHook(c, cs)
|
||||||
@ -378,13 +381,3 @@ func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) {
|
|||||||
done <- struct{}{}
|
done <- struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// forgetConn removes c from the set of tracked conns and decrements it from the
|
|
||||||
// waitgroup, unless it was previously removed.
|
|
||||||
// s.mu must be held.
|
|
||||||
func (s *Server) forgetConn(c net.Conn) {
|
|
||||||
if _, ok := s.conns[c]; ok {
|
|
||||||
delete(s.conns, c)
|
|
||||||
s.wg.Done()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -203,6 +204,59 @@ func TestServerZeroValueClose(t *testing.T) {
|
|||||||
ts.Close() // tests that it doesn't panic
|
ts.Close() // tests that it doesn't panic
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Issue 51799: test hijacking a connection and then closing it
|
||||||
|
// concurrently with closing the server.
|
||||||
|
func TestCloseHijackedConnection(t *testing.T) {
|
||||||
|
hijacked := make(chan net.Conn)
|
||||||
|
ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
defer close(hijacked)
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
t.Fatal("failed to hijack")
|
||||||
|
}
|
||||||
|
c, _, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
hijacked <- c
|
||||||
|
}))
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
req, err := http.NewRequest("GET", ts.URL, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
}
|
||||||
|
// Use a client not associated with the Server.
|
||||||
|
var c http.Client
|
||||||
|
resp, err := c.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Log(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
conn := <-hijacked
|
||||||
|
go func(conn net.Conn) {
|
||||||
|
defer wg.Done()
|
||||||
|
// Close the connection and then inform the Server that
|
||||||
|
// we closed it.
|
||||||
|
conn.Close()
|
||||||
|
ts.Config.ConnState(conn, http.StateClosed)
|
||||||
|
}(conn)
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
ts.Close()
|
||||||
|
}()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func TestTLSServerWithHTTP2(t *testing.T) {
|
func TestTLSServerWithHTTP2(t *testing.T) {
|
||||||
modes := []struct {
|
modes := []struct {
|
||||||
name string
|
name string
|
||||||
|
Loading…
Reference in New Issue
Block a user