mirror of
https://github.com/golang/go
synced 2024-11-19 05:14:50 -07:00
net/http: fix TimeoutHandler data races; hold lock longer
The existing lock needed to be held longer. If a timeout occured while writing (but after the guarded timeout check), the writes would clobber a future connection's buffer. Also remove a harmless warning by making Write also set the flag that headers were sent (implicitly), so we don't try to write headers later (a no-op + warning) on timeout after we've started writing. Fixes #8414 Fixes #8209 LGTM=ruiu, adg R=adg, ruiu CC=golang-codereviews https://golang.org/cl/123610043
This commit is contained in:
parent
339a24da66
commit
e38fa91648
@ -15,6 +15,7 @@ import (
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"net"
|
||||
. "net/http"
|
||||
"net/http/httptest"
|
||||
@ -1188,6 +1189,82 @@ func TestTimeoutHandler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// See issues 8209 and 8414.
|
||||
func TestTimeoutHandlerRace(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
|
||||
delayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
ms, _ := strconv.Atoi(r.URL.Path[1:])
|
||||
if ms == 0 {
|
||||
ms = 1
|
||||
}
|
||||
for i := 0; i < ms; i++ {
|
||||
w.Write([]byte("hi"))
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
})
|
||||
|
||||
ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, ""))
|
||||
defer ts.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
gate := make(chan bool, 10)
|
||||
n := 50
|
||||
if testing.Short() {
|
||||
n = 10
|
||||
gate = make(chan bool, 3)
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
gate <- true
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() { <-gate }()
|
||||
res, err := Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50)))
|
||||
if err == nil {
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
res.Body.Close()
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// See issues 8209 and 8414.
|
||||
func TestTimeoutHandlerRaceHeader(t *testing.T) {
|
||||
defer afterTest(t)
|
||||
|
||||
delay204 := HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
w.WriteHeader(204)
|
||||
})
|
||||
|
||||
ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, ""))
|
||||
defer ts.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
gate := make(chan bool, 50)
|
||||
n := 500
|
||||
if testing.Short() {
|
||||
n = 10
|
||||
}
|
||||
for i := 0; i < n; i++ {
|
||||
gate <- true
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() { <-gate }()
|
||||
res, err := Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer res.Body.Close()
|
||||
io.Copy(ioutil.Discard, res.Body)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// Verifies we don't path.Clean() on the wrong parts in redirects.
|
||||
func TestRedirectMunging(t *testing.T) {
|
||||
req, _ := NewRequest("GET", "http://example.com/", nil)
|
||||
|
@ -1916,9 +1916,9 @@ func (tw *timeoutWriter) Header() Header {
|
||||
|
||||
func (tw *timeoutWriter) Write(p []byte) (int, error) {
|
||||
tw.mu.Lock()
|
||||
timedOut := tw.timedOut
|
||||
tw.mu.Unlock()
|
||||
if timedOut {
|
||||
defer tw.mu.Unlock()
|
||||
tw.wroteHeader = true // implicitly at least
|
||||
if tw.timedOut {
|
||||
return 0, ErrHandlerTimeout
|
||||
}
|
||||
return tw.w.Write(p)
|
||||
@ -1926,12 +1926,11 @@ func (tw *timeoutWriter) Write(p []byte) (int, error) {
|
||||
|
||||
func (tw *timeoutWriter) WriteHeader(code int) {
|
||||
tw.mu.Lock()
|
||||
defer tw.mu.Unlock()
|
||||
if tw.timedOut || tw.wroteHeader {
|
||||
tw.mu.Unlock()
|
||||
return
|
||||
}
|
||||
tw.wroteHeader = true
|
||||
tw.mu.Unlock()
|
||||
tw.w.WriteHeader(code)
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user