1
0
mirror of https://github.com/golang/go synced 2024-11-23 19:40:08 -07:00

net/http: fix race in TimeoutHandler

New implementation of TimeoutHandler: buffer everything to memory.

All or nothing: either the handler finishes completely within the
timeout (in which case the wrapper writes it all), or it misses the
timeout and none of it gets written, in which case handler wrapper can
reliably print the error response without fear that some of the
wrapped Handler's code already wrote to the output.

Now the goroutine running the wrapped Handler has its own write buffer
and Header copy.

Document the limitations.

Fixes #9162

Change-Id: Ia058c1d62cefd11843e7a2fc1ae1609d75de2441
Reviewed-on: https://go-review.googlesource.com/17752
Reviewed-by: David Symonds <dsymonds@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Brad Fitzpatrick 2015-12-14 01:04:07 +00:00
parent 24a7955c74
commit 0478f7b9d6
3 changed files with 111 additions and 22 deletions

View File

@ -58,10 +58,11 @@ func SetPendingDialHooks(before, after func()) {
func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn } func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
f := func() <-chan time.Time { return &timeoutHandler{
return ch handler: handler,
timeout: func() <-chan time.Time { return ch },
// (no body and nil cancelTimer)
} }
return &timeoutHandler{handler, f, ""}
} }
func ResetCachedEnvironment() { func ResetCachedEnvironment() {

View File

@ -1785,6 +1785,60 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) {
wg.Wait() wg.Wait()
} }
// Issue 9162
func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
defer afterTest(t)
sendHi := make(chan bool, 1)
writeErrors := make(chan error, 1)
sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Type", "text/plain")
<-sendHi
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
timeout := make(chan time.Time, 1) // write to this to force timeouts
cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout))
defer cst.close()
// Succeed without timing out:
sendHi <- true
res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
}
if g, e := res.StatusCode, StatusOK; g != e {
t.Errorf("got res.StatusCode %d; expected %d", g, e)
}
body, _ := ioutil.ReadAll(res.Body)
if g, e := string(body), "hi"; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
if g := <-writeErrors; g != nil {
t.Errorf("got unexpected Write error on first request: %v", g)
}
// Times out:
timeout <- time.Time{}
res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
}
if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
t.Errorf("got res.StatusCode %d; expected %d", g, e)
}
body, _ = ioutil.ReadAll(res.Body)
if !strings.Contains(string(body), "<title>Timeout</title>") {
t.Errorf("expected timeout body; got %q", string(body))
}
// Now make the previously-timed out handler speak again,
// which verifies the panic is handled:
sendHi <- true
if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
t.Errorf("expected Write error of %v; got %v", e, g)
}
}
// Verifies we don't path.Clean() on the wrong parts in redirects. // Verifies we don't path.Clean() on the wrong parts in redirects.
func TestRedirectMunging(t *testing.T) { func TestRedirectMunging(t *testing.T) {
req, _ := NewRequest("GET", "http://example.com/", nil) req, _ := NewRequest("GET", "http://example.com/", nil)

View File

@ -8,6 +8,7 @@ package http
import ( import (
"bufio" "bufio"
"bytes"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt" "fmt"
@ -2101,11 +2102,20 @@ func (srv *Server) onceSetNextProtoDefaults() {
// (If msg is empty, a suitable default message will be sent.) // (If msg is empty, a suitable default message will be sent.)
// After such a timeout, writes by h to its ResponseWriter will return // After such a timeout, writes by h to its ResponseWriter will return
// ErrHandlerTimeout. // ErrHandlerTimeout.
//
// TimeoutHandler buffers all Handler writes to memory and does not
// support the Hijacker or Flusher interfaces.
func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler { func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler {
f := func() <-chan time.Time { t := time.NewTimer(dt)
return time.After(dt) return &timeoutHandler{
handler: h,
body: msg,
// Effectively storing a *time.Timer, but decomposed
// for testing:
timeout: func() <-chan time.Time { return t.C },
cancelTimer: t.Stop,
} }
return &timeoutHandler{h, f, msg}
} }
// ErrHandlerTimeout is returned on ResponseWriter Write calls // ErrHandlerTimeout is returned on ResponseWriter Write calls
@ -2114,8 +2124,13 @@ var ErrHandlerTimeout = errors.New("http: Handler timeout")
type timeoutHandler struct { type timeoutHandler struct {
handler Handler handler Handler
timeout func() <-chan time.Time // returns channel producing a timeout
body string body string
// timeout returns the channel of a *time.Timer and
// cancelTimer cancels it. They're stored separately for
// testing purposes.
timeout func() <-chan time.Time // returns channel producing a timeout
cancelTimer func() bool // optional
} }
func (h *timeoutHandler) errorBody() string { func (h *timeoutHandler) errorBody() string {
@ -2126,46 +2141,61 @@ func (h *timeoutHandler) errorBody() string {
} }
func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
done := make(chan bool, 1) done := make(chan struct{})
tw := &timeoutWriter{w: w} tw := &timeoutWriter{
w: w,
h: make(Header),
}
go func() { go func() {
h.handler.ServeHTTP(tw, r) h.handler.ServeHTTP(tw, r)
done <- true close(done)
}() }()
select { select {
case <-done: case <-done:
return tw.mu.Lock()
defer tw.mu.Unlock()
dst := w.Header()
for k, vv := range tw.h {
dst[k] = vv
}
w.WriteHeader(tw.code)
w.Write(tw.wbuf.Bytes())
if h.cancelTimer != nil {
h.cancelTimer()
}
case <-h.timeout(): case <-h.timeout():
tw.mu.Lock() tw.mu.Lock()
defer tw.mu.Unlock() defer tw.mu.Unlock()
if !tw.wroteHeader { w.WriteHeader(StatusServiceUnavailable)
tw.w.WriteHeader(StatusServiceUnavailable) io.WriteString(w, h.errorBody())
tw.w.Write([]byte(h.errorBody()))
}
tw.timedOut = true tw.timedOut = true
return
} }
} }
type timeoutWriter struct { type timeoutWriter struct {
w ResponseWriter w ResponseWriter
h Header
wbuf bytes.Buffer
mu sync.Mutex mu sync.Mutex
timedOut bool timedOut bool
wroteHeader bool wroteHeader bool
code int
} }
func (tw *timeoutWriter) Header() Header { func (tw *timeoutWriter) Header() Header { return tw.h }
return tw.w.Header()
}
func (tw *timeoutWriter) Write(p []byte) (int, error) { func (tw *timeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock() tw.mu.Lock()
defer tw.mu.Unlock() defer tw.mu.Unlock()
tw.wroteHeader = true // implicitly at least
if tw.timedOut { if tw.timedOut {
return 0, ErrHandlerTimeout return 0, ErrHandlerTimeout
} }
return tw.w.Write(p) if !tw.wroteHeader {
tw.writeHeader(StatusOK)
}
return tw.wbuf.Write(p)
} }
func (tw *timeoutWriter) WriteHeader(code int) { func (tw *timeoutWriter) WriteHeader(code int) {
@ -2174,8 +2204,12 @@ func (tw *timeoutWriter) WriteHeader(code int) {
if tw.timedOut || tw.wroteHeader { if tw.timedOut || tw.wroteHeader {
return return
} }
tw.writeHeader(code)
}
func (tw *timeoutWriter) writeHeader(code int) {
tw.wroteHeader = true tw.wroteHeader = true
tw.w.WriteHeader(code) tw.code = code
} }
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted