1
0
mirror of https://github.com/golang/go synced 2024-11-24 06:20:02 -07:00

net/http: distinguish between timeouts and client hangups in TimeoutHandler

Fixes #48948

Change-Id: I411e3be99c7979ae289fd937388aae63d81adb59
GitHub-Last-Rev: 14abd7e4d7
GitHub-Pull-Request: golang/go#48993
Reviewed-on: https://go-review.googlesource.com/c/go/+/356009
Reviewed-by: Damien Neil <dneil@google.com>
Trust: Damien Neil <dneil@google.com>
Trust: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
This commit is contained in:
Charlie Getzen 2021-11-05 17:27:35 +00:00 committed by Damien Neil
parent 091948a55f
commit 4c7cafdd03
3 changed files with 71 additions and 19 deletions

View File

@ -88,12 +88,7 @@ func SetPendingDialHooks(before, after func()) {
func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
ctx, cancel := context.WithCancel(context.Background())
go func() {
<-ch
cancel()
}()
func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler {
return &timeoutHandler{
handler: handler,
testContext: ctx,

View File

@ -2274,6 +2274,18 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
}
}
// cancelableTimeoutContext overwrites the error message to DeadlineExceeded
type cancelableTimeoutContext struct {
context.Context
}
func (c cancelableTimeoutContext) Err() error {
if c.Context.Err() != nil {
return context.DeadlineExceeded
}
return nil
}
func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) }
func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) }
func testTimeoutHandler(t *testing.T, h2 bool) {
@ -2286,8 +2298,9 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
timeout := make(chan time.Time, 1) // write to this to force timeouts
cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout))
ctx, cancel := context.WithCancel(context.Background())
h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
cst := newClientServerTest(t, h2, h)
defer cst.close()
// Succeed without timing out:
@ -2308,7 +2321,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
}
// Times out:
timeout <- time.Time{}
cancel()
res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
@ -2429,8 +2443,9 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
timeout := make(chan time.Time, 1) // write to this to force timeouts
cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout))
ctx, cancel := context.WithCancel(context.Background())
h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
cst := newClientServerTest(t, h1Mode, h)
defer cst.close()
// Succeed without timing out:
@ -2451,7 +2466,8 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
}
// Times out:
timeout <- time.Time{}
cancel()
res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
@ -2501,6 +2517,41 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
}
}
func TestTimeoutHandlerContextCanceled(t *testing.T) {
setParallel(t)
defer afterTest(t)
sendHi := make(chan bool, 1)
writeErrors := make(chan error, 1)
sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Type", "text/plain")
<-sendHi
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Hour)
h := NewTestTimeoutHandler(sayHi, ctx)
cancel()
cst := newClientServerTest(t, h1Mode, h)
defer cst.close()
// Succeed without timing out:
sendHi <- true
res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
}
if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
t.Errorf("got res.StatusCode %d; expected %d", g, e)
}
body, _ := io.ReadAll(res.Body)
if g, e := string(body), ""; g != e {
t.Errorf("got body %q; expected %q", g, e)
}
if g, e := <-writeErrors, context.Canceled; g != e {
t.Errorf("got unexpected Write error on first request: %v", g)
}
}
// https://golang.org/issue/15948
func TestTimeoutHandlerEmptyResponse(t *testing.T) {
setParallel(t)

View File

@ -3391,9 +3391,15 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
case <-ctx.Done():
tw.mu.Lock()
defer tw.mu.Unlock()
w.WriteHeader(StatusServiceUnavailable)
io.WriteString(w, h.errorBody())
tw.timedOut = true
switch err := ctx.Err(); err {
case context.DeadlineExceeded:
w.WriteHeader(StatusServiceUnavailable)
io.WriteString(w, h.errorBody())
tw.err = ErrHandlerTimeout
default:
w.WriteHeader(StatusServiceUnavailable)
tw.err = err
}
}
}
@ -3404,7 +3410,7 @@ type timeoutWriter struct {
req *Request
mu sync.Mutex
timedOut bool
err error
wroteHeader bool
code int
}
@ -3424,8 +3430,8 @@ func (tw *timeoutWriter) Header() Header { return tw.h }
func (tw *timeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock()
defer tw.mu.Unlock()
if tw.timedOut {
return 0, ErrHandlerTimeout
if tw.err != nil {
return 0, tw.err
}
if !tw.wroteHeader {
tw.writeHeaderLocked(StatusOK)
@ -3437,7 +3443,7 @@ func (tw *timeoutWriter) writeHeaderLocked(code int) {
checkWriteHeaderCode(code)
switch {
case tw.timedOut:
case tw.err != nil:
return
case tw.wroteHeader:
if tw.req != nil {