From a524b8725374e4ebbb7fe3da85f407ee24141d51 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Tue, 14 May 2024 09:55:11 -0700 Subject: [PATCH] net/http: avoid panic when writing 100-continue after handler done When a request contains an "Expect: 100-continue" header, the first read from the request body causes the server to write a 100-continue status. This write caused a panic when performed after the server handler has exited. Disable the write when cleaning up after a handler exits. This also fixes a bug where an implicit 100-continue could be sent after a call to WriteHeader has sent a non-1xx header. This change drops tracking of whether we've written a 100-continue or not in response.wroteContinue. This tracking was used to determine whether we should consume the remaining request body in chunkWriter.writeHeader, but the discard-the-body path was only taken when the body was already consumed. (If the body is not consumed, we set closeAfterReply, and we don't consume the remaining body when closeAfterReply is set. If the body is consumed, then we may attempt to discard the remaining body, but there is obviously no body remaining.) Fixes #53808 Change-Id: I3542df26ad6cdfe93b50a45ae2d6e7ef031e46fa Reviewed-on: https://go-review.googlesource.com/c/go/+/585395 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- src/net/http/serve_test.go | 76 ++++++++++++++++++++++++++++++++++++++ src/net/http/server.go | 50 ++++++++++++++----------- 2 files changed, 105 insertions(+), 21 deletions(-) diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index e21af8b159..f454dcdbed 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -7166,3 +7166,79 @@ func TestError(t *testing.T) { t.Errorf("X-Content-Type-Options: %q, want %q", v, "nosniff") } } + +func TestServerReadAfterWriteHeader100Continue(t *testing.T) { + run(t, testServerReadAfterWriteHeader100Continue) +} +func testServerReadAfterWriteHeader100Continue(t *testing.T, mode testMode) { + body := []byte("body") + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(200) + NewResponseController(w).Flush() + io.ReadAll(r.Body) + w.Write(body) + })) + + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body")) + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err != nil { + t.Fatalf("Get(%q) = %v", cst.ts.URL, err) + } + defer res.Body.Close() + got, err := io.ReadAll(res.Body) + if err != nil { + t.Fatalf("io.ReadAll(res.Body) = %v", err) + } + if !bytes.Equal(got, body) { + t.Fatalf("response body = %q, want %q", got, body) + } +} + +func TestServerReadAfterHandlerDone100Continue(t *testing.T) { + run(t, testServerReadAfterHandlerDone100Continue) +} +func testServerReadAfterHandlerDone100Continue(t *testing.T, mode testMode) { + readyc := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + go func() { + <-readyc + io.ReadAll(r.Body) + <-readyc + }() + })) + + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body")) + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err != nil { + t.Fatalf("Get(%q) = %v", cst.ts.URL, err) + } + res.Body.Close() + readyc <- struct{}{} // server starts reading from the request body + readyc <- struct{}{} // server finishes reading from the request body +} + +func TestServerReadAfterHandlerAbort100Continue(t *testing.T) { + run(t, testServerReadAfterHandlerAbort100Continue) +} +func testServerReadAfterHandlerAbort100Continue(t *testing.T, mode testMode) { + readyc := make(chan struct{}) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + go func() { + <-readyc + io.ReadAll(r.Body) + <-readyc + }() + panic(ErrAbortHandler) + })) + + req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader("body")) + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err == nil { + res.Body.Close() + } + readyc <- struct{}{} // server starts reading from the request body + readyc <- struct{}{} // server finishes reading from the request body +} diff --git a/src/net/http/server.go b/src/net/http/server.go index a50b20b7da..b76c869567 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -425,7 +425,6 @@ type response struct { reqBody io.ReadCloser cancelCtx context.CancelFunc // when ServeHTTP exits wroteHeader bool // a non-1xx header has been (logically) written - wroteContinue bool // 100 Continue response was written wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" wantsClose bool // HTTP request has Connection "close" @@ -436,8 +435,8 @@ type response struct { // These two fields together synchronize the body reader (the // expectContinueReader, which wants to write 100 Continue) // against the main writer. - canWriteContinue atomic.Bool writeContinueMu sync.Mutex + canWriteContinue atomic.Bool w *bufio.Writer // buffers output in chunks to chunkWriter cw chunkWriter @@ -565,6 +564,14 @@ func (w *response) requestTooLarge() { } } +// disableWriteContinue stops Request.Body.Read from sending an automatic 100-Continue. +// If a 100-Continue is being written, it waits for it to complete before continuing. +func (w *response) disableWriteContinue() { + w.writeContinueMu.Lock() + w.canWriteContinue.Store(false) + w.writeContinueMu.Unlock() +} + // writerOnly hides an io.Writer value's optional ReadFrom method // from io.Copy. type writerOnly struct { @@ -917,8 +924,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { return 0, ErrBodyReadAfterClose } w := ecr.resp - if !w.wroteContinue && w.canWriteContinue.Load() && !w.conn.hijacked() { - w.wroteContinue = true + if w.canWriteContinue.Load() { w.writeContinueMu.Lock() if w.canWriteContinue.Load() { w.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") @@ -1159,18 +1165,17 @@ func (w *response) WriteHeader(code int) { } checkWriteHeaderCode(code) + if code < 101 || code > 199 { + // Sending a 100 Continue or any non-1xx header disables the + // automatically-sent 100 Continue from Request.Body.Read. + w.disableWriteContinue() + } + // Handle informational headers. // // We shouldn't send any further headers after 101 Switching Protocols, // so it takes the non-informational path. if code >= 100 && code <= 199 && code != StatusSwitchingProtocols { - // Prevent a potential race with an automatically-sent 100 Continue triggered by Request.Body.Read() - if code == 100 && w.canWriteContinue.Load() { - w.writeContinueMu.Lock() - w.canWriteContinue.Store(false) - w.writeContinueMu.Unlock() - } - writeStatusLine(w.conn.bufw, w.req.ProtoAtLeast(1, 1), code, w.statusBuf[:]) // Per RFC 8297 we must not clear the current header map @@ -1378,14 +1383,20 @@ func (cw *chunkWriter) writeHeader(p []byte) { // // If full duplex mode has been enabled with ResponseController.EnableFullDuplex, // then leave the request body alone. + // + // We don't take this path when w.closeAfterReply is set. + // We may not need to consume the request to get ready for the next one + // (since we're closing the conn), but a client which sends a full request + // before reading a response may deadlock in this case. + // This behavior has been present since CL 5268043 (2011), however, + // so it doesn't seem to be causing problems. if w.req.ContentLength != 0 && !w.closeAfterReply && !w.fullDuplex { var discard, tooBig bool switch bdy := w.req.Body.(type) { case *expectContinueReader: - if bdy.resp.wroteContinue { - discard = true - } + // We only get here if we have already fully consumed the request body + // (see above). case *body: bdy.mu.Lock() switch { @@ -1626,13 +1637,8 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er } if w.canWriteContinue.Load() { - // Body reader wants to write 100 Continue but hasn't yet. - // Tell it not to. The store must be done while holding the lock - // because the lock makes sure that there is not an active write - // this very moment. - w.writeContinueMu.Lock() - w.canWriteContinue.Store(false) - w.writeContinueMu.Unlock() + // Body reader wants to write 100 Continue but hasn't yet. Tell it not to. + w.disableWriteContinue() } if !w.wroteHeader { @@ -1900,6 +1906,7 @@ func (c *conn) serve(ctx context.Context) { } if inFlightResponse != nil { inFlightResponse.cancelCtx() + inFlightResponse.disableWriteContinue() } if !c.hijacked() { if inFlightResponse != nil { @@ -2106,6 +2113,7 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if w.handlerDone.Load() { panic("net/http: Hijack called after ServeHTTP finished") } + w.disableWriteContinue() if w.wroteHeader { w.cw.flush() }