From 99fb19194c03c618c0d8faa87b91ba419ae28ee3 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Tue, 1 Dec 2015 19:07:41 +0000 Subject: [PATCH] net/http: rework CloseNotifier implementation, clarify expectations in docs CloseNotifier wasn't well specified previously. This CL simplifies its implementation, clarifies the public documentation on CloseNotifier, clarifies internal documentation on conn, and fixes two CloseNotifier bugs in the process. The main change, though, is tightening the rules and expectations for using CloseNotifier: * the caller must consume the Request.Body first (old rule, unwritten) * the received value is the "true" value (old rule, unwritten) * no promises for channel sends after Handler returns (old rule, unwritten) * a subsequent pipelined request fires the CloseNotifier (new behavior; previously it never fired and thus effectively deadlocked as in #13165) * advise that it should only be used without HTTP/1.1 pipelining (use HTTP/2 or non-idempotent browsers). Not that browsers actually use pipelining. The main implementation change is that each Handler now gets its own CloseNotifier channel value, rather than sharing one between the whole conn. This means Handlers can't affect subsequent requests. This is how HTTP/2's Server works too. The old docs never clarified a behavior either way. The other side effect of each request getting its own CloseNotifier channel is that one handler can't "poison" the underlying conn preventing subsequent requests on the same connection from using CloseNotifier (this is #9763). In the old implementation, once any request on a connection used ClosedNotifier, the conn's underlying bufio.Reader source was switched from the TCPConn to the read side of the pipe being fed by a never-ending copy. Since it was impossible to abort that never-ending copy, we could never get back to a fresh state where it was possible to return the underlying TCPConn to callers of Hijack. Now, instead of a never-ending Copy, the background goroutine doing a Read from the TCPConn (or *tls.Conn) only reads a single byte. That single byte can be in the request body, a socket timeout error, io.EOF error, or the first byte of the second body. In any case, the new *connReader type stitches sync and async reads together like an io.MultiReader. To clarify the flow of Read data and combat the complexity of too many wrapper Reader types, the *connReader absorbs the io.LimitReader previously used for bounding request header reads. The liveSwitchReader type is removed. (an unused switchWriter type is also removed) Many fields on *conn are also documented more fully. Fixes #9763 (CloseNotify + Hijack together) Fixes #13165 (deadlock with CloseNotify + pipelined requests) Change-Id: I40abc0a1992d05b294d627d1838c33cbccb9dd65 Reviewed-on: https://go-review.googlesource.com/17750 Reviewed-by: Russ Cox Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- src/net/http/serve_test.go | 140 ++++++++++++-- src/net/http/server.go | 380 ++++++++++++++++++++++--------------- 2 files changed, 360 insertions(+), 160 deletions(-) diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index a98a4ccf3b..3e84f2e11d 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -97,6 +97,7 @@ func (c *rwTestConn) Close() error { } type testConn struct { + readMu sync.Mutex // for TestHandlerBodyClose readBuf bytes.Buffer writeBuf bytes.Buffer closec chan bool // if non-nil, send value to it on close @@ -104,6 +105,8 @@ type testConn struct { } func (c *testConn) Read(b []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() return c.readBuf.Read(b) } @@ -1246,15 +1249,21 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) { done := make(chan bool) + readBufLen := func() int { + conn.readMu.Lock() + defer conn.readMu.Unlock() + return conn.readBuf.Len() + } + ls := &oneConnListener{conn} go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { defer close(done) - if conn.readBuf.Len() < len(body)/2 { - t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len()) + if bufLen := readBufLen(); bufLen < len(body)/2 { + t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen) } rw.WriteHeader(200) rw.(Flusher).Flush() - if g, e := conn.readBuf.Len(), 0; g != e { + if g, e := readBufLen(), 0; g != e { t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e) } if c := rw.Header().Get("Connection"); c != "" { @@ -1430,15 +1439,21 @@ func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) { } conn.closec = make(chan bool, 1) + readBufLen := func() int { + conn.readMu.Lock() + defer conn.readMu.Unlock() + return conn.readBuf.Len() + } + ls := &oneConnListener{conn} var numReqs int var size0, size1 int go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { numReqs++ if numReqs == 1 { - size0 = conn.readBuf.Len() + size0 = readBufLen() req.Body.Close() - size1 = conn.readBuf.Len() + size1 = readBufLen() } })) <-conn.closec @@ -1538,7 +1553,9 @@ type slowTestConn struct { // over multiple calls to Read, time.Durations are slept, strings are read. script []interface{} closec chan bool - rd, wd time.Time // read, write deadline + + mu sync.Mutex // guards rd/wd + rd, wd time.Time // read, write deadline noopConn } @@ -1549,16 +1566,22 @@ func (c *slowTestConn) SetDeadline(t time.Time) error { } func (c *slowTestConn) SetReadDeadline(t time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() c.rd = t return nil } func (c *slowTestConn) SetWriteDeadline(t time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() c.wd = t return nil } func (c *slowTestConn) Read(b []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() restart: if !c.rd.IsZero() && time.Now().After(c.rd) { return 0, syscall.ETIMEDOUT @@ -2330,6 +2353,49 @@ For: ts.Close() } +// Tests that a pipelined request causes the first request's Handler's CloseNotify +// channel to fire. Previously it deadlocked. +// +// Issue 13165 +func TestCloseNotifierPipelined(t *testing.T) { + defer afterTest(t) + gotReq := make(chan bool, 2) + sawClose := make(chan bool, 2) + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + gotReq <- true + cc := rw.(CloseNotifier).CloseNotify() + <-cc + sawClose <- true + })) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + diec := make(chan bool, 2) + go func() { + const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n" + _, err = io.WriteString(conn, req+req) // two requests + if err != nil { + t.Fatal(err) + } + <-diec + conn.Close() + }() +For: + for { + select { + case <-gotReq: + diec <- true + case <-sawClose: + break For + case <-time.After(5 * time.Second): + ts.CloseClientConnections() + t.Fatal("timeout") + } + } + ts.Close() +} + func TestCloseNotifierChanLeak(t *testing.T) { defer afterTest(t) req := reqBytes("GET / HTTP/1.0\nHost: golang.org") @@ -2352,6 +2418,61 @@ func TestCloseNotifierChanLeak(t *testing.T) { } } +// Tests that we can use CloseNotifier in one request, and later call Hijack +// on a second request on the same connection. +// +// It also tests that the connReader stitches together its background +// 1-byte read for CloseNotifier when CloseNotifier doesn't fire with +// the rest of the second HTTP later. +// +// Issue 9763. +// HTTP/1-only test. (http2 doesn't have Hijack) +func TestHijackAfterCloseNotifier(t *testing.T) { + defer afterTest(t) + script := make(chan string, 2) + script <- "closenotify" + script <- "hijack" + close(script) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + plan := <-script + switch plan { + default: + panic("bogus plan; too many requests") + case "closenotify": + w.(CloseNotifier).CloseNotify() // discard result + w.Header().Set("X-Addr", r.RemoteAddr) + case "hijack": + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack in Handler: %v", err) + return + } + if _, ok := c.(*net.TCPConn); !ok { + // Verify it's not wrapped in some type. + // Not strictly a go1 compat issue, but in practice it probably is. + t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c) + } + fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr) + c.Close() + return + } + })) + defer ts.Close() + res1, err := Get(ts.URL) + if err != nil { + log.Fatal(err) + } + res2, err := Get(ts.URL) + if err != nil { + log.Fatal(err) + } + addr1 := res1.Header.Get("X-Addr") + addr2 := res2.Header.Get("X-Addr") + if addr1 == "" || addr1 != addr2 { + t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2) + } +} + func TestOptions(t *testing.T) { uric := make(chan string, 2) // only expect 1, but leave space for 2 mux := NewServeMux() @@ -2702,7 +2823,7 @@ func TestHTTP10ConnectionHeader(t *testing.T) { defer afterTest(t) mux := NewServeMux() - mux.Handle("/", HandlerFunc(func(resp ResponseWriter, req *Request) {})) + mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {})) ts := httptest.NewServer(mux) defer ts.Close() @@ -3248,10 +3369,7 @@ func (c *closeWriteTestConn) CloseWrite() error { func TestCloseWrite(t *testing.T) { var srv Server var testConn closeWriteTestConn - c, err := ExportServerNewConn(&srv, &testConn) - if err != nil { - t.Fatal(err) - } + c := ExportServerNewConn(&srv, &testConn) ExportCloseWriteAndWait(c) if !testConn.didCloseWrite { t.Error("didn't see CloseWrite call") diff --git a/src/net/http/server.go b/src/net/http/server.go index 35f41e734e..cd5f9cf34f 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -115,28 +115,76 @@ type Hijacker interface { // This mechanism can be used to cancel long operations on the server // if the client has disconnected before the response is ready. type CloseNotifier interface { - // CloseNotify returns a channel that receives a single value - // when the client connection has gone away. + // CloseNotify returns a channel that receives at most a + // single value (true) when the client connection has gone + // away. + // + // CloseNotify is undefined before Request.Body has been + // fully read. + // + // After the Handler has returned, there is no guarantee + // that the channel receives a value. + // + // If the protocol is HTTP/1.1 and CloseNotify is called while + // processing an idempotent request (such a GET) while + // HTTP/1.1 pipelining is in use, the arrival of a subsequent + // pipelined request will cause a value to be sent on the + // returned channel. In practice HTTP/1.1 pipelining is not + // enabled in browsers and not seen often in the wild. If this + // is a problem, use HTTP/2 or only use CloseNotify on methods + // such as POST. CloseNotify() <-chan bool } // A conn represents the server side of an HTTP connection. type conn struct { - remoteAddr string // network address of remote side - server *Server // the Server on which the connection arrived - rwc net.Conn // i/o connection - w io.Writer // checkConnErrorWriter's copy of wrc, not zeroed on Hijack - werr error // any errors writing to w - sr liveSwitchReader // where the LimitReader reads from; usually the rwc - lr *io.LimitedReader // io.LimitReader(sr) - buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc - tlsState *tls.ConnectionState // or nil when not using TLS - lastMethod string // method of previous request, or "" + // server is the server on which the connection arrived. + // Immutable; never nil. + server *Server - mu sync.Mutex // guards the following - clientGone bool // if client has disconnected mid-request - closeNotifyc chan bool // made lazily - hijackedv bool // connection has been hijacked by handler + // rwc is the underlying network connection. + // This is never wrapped by other types and is the value given out + // to CloseNotifier callers. It is usually of type *net.TCPConn or + // *tls.Conn. + rwc net.Conn + + // remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously + // inside the Listener's Accept goroutine, as some implementations block. + // It is populated immediately inside the (*conn).serve goroutine. + // This is the value of a Handler's (*Request).RemoteAddr. + remoteAddr string + + // tlsState is the TLS connection state when using TLS. + // nil means not TLS. + tlsState *tls.ConnectionState + + // werr is set to the first write error to rwc. + // It is set via checkConnErrorWriter{w}, where bufw writes. + werr error + + // r is bufr's read source. It's a wrapper around rwc that provides + // io.LimitedReader-style limiting (while reading request headers) + // and functionality to support CloseNotifier. See *connReader docs. + r *connReader + + // bufr reads from r. + // Users of bufr must hold mu. + bufr *bufio.Reader + + // bufw writes to checkConnErrorWriter{c}, which populates werr on error. + bufw *bufio.Writer + + // lastMethod is the method of the most recent request + // on this connection, if any. + lastMethod string + + // mu guards hijackedv, use of bufr, (*response).closeNotifyCh. + mu sync.Mutex + + // hijackedv is whether this connection has been hijacked + // by a Handler with the Hijacker interface. + // It is guarded by mu. + hijackedv bool } func (c *conn) hijacked() bool { @@ -145,83 +193,18 @@ func (c *conn) hijacked() bool { return c.hijackedv } -func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { - c.mu.Lock() - defer c.mu.Unlock() +// c.mu must be held. +func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if c.hijackedv { return nil, nil, ErrHijacked } - if c.closeNotifyc != nil { - return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier") - } c.hijackedv = true rwc = c.rwc - buf = c.buf - c.rwc = nil - c.buf = nil + buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc)) c.setState(rwc, StateHijacked) return } -func (c *conn) closeNotify() <-chan bool { - c.mu.Lock() - defer c.mu.Unlock() - if c.closeNotifyc == nil { - c.closeNotifyc = make(chan bool, 1) - if c.hijackedv { - // to obey the function signature, even though - // it'll never receive a value. - return c.closeNotifyc - } - pr, pw := io.Pipe() - - readSource := c.sr.r - c.sr.Lock() - c.sr.r = pr - c.sr.Unlock() - go func() { - bufp := copyBufPool.Get().(*[]byte) - defer copyBufPool.Put(bufp) - _, err := io.CopyBuffer(pw, readSource, *bufp) - if err == nil { - err = io.EOF - } - pw.CloseWithError(err) - c.noteClientGone() - }() - } - return c.closeNotifyc -} - -func (c *conn) noteClientGone() { - c.mu.Lock() - defer c.mu.Unlock() - if c.closeNotifyc != nil && !c.clientGone { - c.closeNotifyc <- true - } - c.clientGone = true -} - -// A switchWriter can have its Writer changed at runtime. -// It's not safe for concurrent Writes and switches. -type switchWriter struct { - io.Writer -} - -// A liveSwitchReader can have its Reader changed at runtime. It's -// safe for concurrent reads and switches, if its mutex is held. -type liveSwitchReader struct { - sync.Mutex - r io.Reader -} - -func (sr *liveSwitchReader) Read(p []byte) (n int, err error) { - sr.Lock() - r := sr.r - sr.Unlock() - return r.Read(p) -} - // This should be >= 512 bytes for DetectContentType, // but otherwise it's somewhat arbitrary. const bufferBeforeChunkingSize = 2048 @@ -268,15 +251,15 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { return len(p), nil } if cw.chunking { - _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p)) + _, err = fmt.Fprintf(cw.res.conn.bufw, "%x\r\n", len(p)) if err != nil { cw.res.conn.rwc.Close() return } } - n, err = cw.res.conn.buf.Write(p) + n, err = cw.res.conn.bufw.Write(p) if cw.chunking && err == nil { - _, err = cw.res.conn.buf.Write(crlf) + _, err = cw.res.conn.bufw.Write(crlf) } if err != nil { cw.res.conn.rwc.Close() @@ -288,7 +271,7 @@ func (cw *chunkWriter) flush() { if !cw.wroteHeader { cw.writeHeader(nil) } - cw.res.conn.buf.Flush() + cw.res.conn.bufw.Flush() } func (cw *chunkWriter) close() { @@ -296,7 +279,7 @@ func (cw *chunkWriter) close() { cw.writeHeader(nil) } if cw.chunking { - bw := cw.res.conn.buf // conn's bufio writer + bw := cw.res.conn.bufw // conn's bufio writer // zero chunk to mark EOF bw.WriteString("0\r\n") if len(cw.res.trailers) > 0 { @@ -324,7 +307,6 @@ type response struct { w *bufio.Writer // buffers output in chunks to chunkWriter cw chunkWriter - sw *switchWriter // of the bufio.Writer, for return to putBufioWriter // handlerHeader is the Header that Handlers get access to, // which may be retained and mutated even after WriteHeader. @@ -363,6 +345,8 @@ type response struct { // Buffers for Date and Content-Length dateBuf [len(TimeFormat)]byte clenBuf [10]byte + + closeNotifyCh <-chan bool // guarded by conn.mu } // declareTrailer is called for each Trailer header when the @@ -462,28 +446,88 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) { return n, err } -// noLimit is an effective infinite upper bound for io.LimitedReader -const noLimit int64 = (1 << 63) - 1 - // debugServerConnections controls whether all server connections are wrapped // with a verbose logging wrapper. const debugServerConnections = false // Create new connection from rwc. -func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { - c = new(conn) - c.server = srv - c.rwc = rwc - c.w = rwc +func (srv *Server) newConn(rwc net.Conn) *conn { + c := &conn{ + server: srv, + rwc: rwc, + } if debugServerConnections { c.rwc = newLoggingConn("server", c.rwc) } - c.sr.r = c.rwc - c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) - br := newBufioReader(c.lr) - bw := newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) - c.buf = bufio.NewReadWriter(br, bw) - return c, nil + return c +} + +type readResult struct { + n int + err error + b byte // byte read, if n == 1 +} + +// connReader is the io.Reader wrapper used by *conn. It combines a +// selectively-activated io.LimitedReader (to bound request header +// read sizes) with support for selectively keeping an io.Reader.Read +// call blocked in a background goroutine to wait for activitiy and +// trigger a CloseNotifier channel. +type connReader struct { + r io.Reader + remain int64 // bytes remaining + + // ch is non-nil if a background read is in progress. + // It is guarded by conn.mu. + ch chan readResult +} + +func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } +func (cr *connReader) setInfiniteReadLimit() { cr.remain = 1<<63 - 1 } +func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } + +func (cr *connReader) Read(p []byte) (n int, err error) { + if cr.hitReadLimit() { + return 0, io.EOF + } + if len(p) == 0 { + return + } + if int64(len(p)) > cr.remain { + p = p[:cr.remain] + } + + // Is a background read (started by CloseNotifier) already in + // flight? If so, wait for it and use its result. + ch := cr.ch + if ch != nil { + cr.ch = nil + res := <-ch + if res.n == 1 { + p[0] = res.b + cr.remain -= 1 + } + return res.n, res.err + } + n, err = cr.r.Read(p) + cr.remain -= int64(n) + return +} + +func (cr *connReader) startBackgroundRead(onReadComplete func()) { + if cr.ch != nil { + // Background read already started. + return + } + cr.ch = make(chan readResult, 1) + go cr.closeNotifyAwaitActivityRead(cr.ch, onReadComplete) +} + +func (cr *connReader) closeNotifyAwaitActivityRead(ch chan<- readResult, onReadComplete func()) { + var buf [1]byte + n, err := cr.r.Read(buf[:1]) + onReadComplete() + ch <- readResult{n, err, buf[0]} } var ( @@ -556,7 +600,7 @@ func (srv *Server) maxHeaderBytes() int { return DefaultMaxHeaderBytes } -func (srv *Server) initialLimitedReaderSize() int64 { +func (srv *Server) initialReadLimitSize() int64 { return int64(srv.maxHeaderBytes()) + 4096 // bufio slop } @@ -575,8 +619,8 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { } if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() { ecr.resp.wroteContinue = true - ecr.resp.conn.buf.WriteString("HTTP/1.1 100 Continue\r\n\r\n") - ecr.resp.conn.buf.Flush() + ecr.resp.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") + ecr.resp.conn.bufw.Flush() } n, err = ecr.readCloser.Read(p) if err == io.EOF { @@ -635,21 +679,23 @@ func (c *conn) readRequest() (w *response, err error) { }() } - c.lr.N = c.server.initialLimitedReaderSize() + c.r.setReadLimit(c.server.initialReadLimitSize()) + c.mu.Lock() // while using bufr if c.lastMethod == "POST" { // RFC 2616 section 4.1 tolerance for old buggy clients. - peek, _ := c.buf.Reader.Peek(4) // ReadRequest will get err below - c.buf.Reader.Discard(numLeadingCRorLF(peek)) + peek, _ := c.bufr.Peek(4) // ReadRequest will get err below + c.bufr.Discard(numLeadingCRorLF(peek)) } - var req *Request - if req, err = ReadRequest(c.buf.Reader); err != nil { - if c.lr.N == 0 { + req, err := ReadRequest(c.bufr) + c.mu.Unlock() + if err != nil { + if c.r.hitReadLimit() { return nil, errTooLarge } return nil, err } - c.lr.N = noLimit c.lastMethod = req.Method + c.r.setInfiniteReadLimit() req.RemoteAddr = c.remoteAddr req.TLS = c.tlsState @@ -768,7 +814,7 @@ func (h extraHeader) Write(w *bufio.Writer) { } // writeHeader finalizes the header sent to the client and writes it -// to cw.res.conn.buf. +// to cw.res.conn.bufw. // // p is not written by writeHeader, but is the first chunk of the body // that will be written. It is sniffed for a Content-Type if none is @@ -1009,10 +1055,10 @@ func (cw *chunkWriter) writeHeader(p []byte) { } } - w.conn.buf.WriteString(statusLine(w.req, code)) - cw.header.WriteSubset(w.conn.buf, excludeHeader) - setHeader.Write(w.conn.buf.Writer) - w.conn.buf.Write(crlf) + w.conn.bufw.WriteString(statusLine(w.req, code)) + cw.header.WriteSubset(w.conn.bufw, excludeHeader) + setHeader.Write(w.conn.bufw) + w.conn.bufw.Write(crlf) } // foreachHeaderElement splits v according to the "#rule" construction @@ -1166,7 +1212,7 @@ func (w *response) finishRequest() { w.w.Flush() putBufioWriter(w.w) w.cw.close() - w.conn.buf.Flush() + w.conn.bufw.Flush() // Close the body (regardless of w.closeAfterReply) so we can // re-use its bufio.Reader later safely. @@ -1219,28 +1265,26 @@ func (w *response) Flush() { } func (c *conn) finalFlush() { - if c.buf != nil { - c.buf.Flush() - + if c.bufr != nil { // Steal the bufio.Reader (~4KB worth of memory) and its associated // reader for a future connection. - putBufioReader(c.buf.Reader) + putBufioReader(c.bufr) + c.bufr = nil + } + if c.bufw != nil { + c.bufw.Flush() // Steal the bufio.Writer (~4KB worth of memory) and its associated // writer for a future connection. - putBufioWriter(c.buf.Writer) - - c.buf = nil + putBufioWriter(c.bufw) + c.bufw = nil } } // Close the connection. func (c *conn) close() { c.finalFlush() - if c.rwc != nil { - c.rwc.Close() - c.rwc = nil - } + c.rwc.Close() } // rstAvoidanceDelay is the amount of time we sleep after closing the @@ -1293,7 +1337,6 @@ func (c *conn) setState(nc net.Conn, state ConnState) { // Serve a new connection. func (c *conn) serve() { c.remoteAddr = c.rwc.RemoteAddr().String() - origConn := c.rwc // copy it before it's set nil on Close or Hijack defer func() { if err := recover(); err != nil { const size = 64 << 10 @@ -1303,7 +1346,7 @@ func (c *conn) serve() { } if !c.hijacked() { c.close() - c.setState(origConn, StateClosed) + c.setState(c.rwc, StateClosed) } }() @@ -1329,9 +1372,13 @@ func (c *conn) serve() { } } + c.r = &connReader{r: c.rwc} + c.bufr = newBufioReader(c.r) + c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) + for { w, err := c.readRequest() - if c.lr.N != c.server.initialLimitedReaderSize() { + if c.r.remain != c.server.initialReadLimitSize() { // If we read any bytes off the wire, we're active. c.setState(c.rwc, StateActive) } @@ -1344,14 +1391,16 @@ func (c *conn) serve() { // request. Undefined behavior. io.WriteString(c.rwc, "HTTP/1.1 431 Request Header Fields Too Large\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n431 Request Header Fields Too Large") c.closeWriteAndWait() - break - } else if err == io.EOF { - break // Don't reply - } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() { - break // Don't reply + return + } + if err == io.EOF { + return // don't reply + } + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + return // don't reply } io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request") - break + return } // Expect 100 Continue support @@ -1364,7 +1413,7 @@ func (c *conn) serve() { req.Header.Del("Expect") } else if req.Header.get("Expect") != "" { w.sendExpectationFailed() - break + return } // HTTP cannot have multiple simultaneous active requests.[*] @@ -1381,7 +1430,7 @@ func (c *conn) serve() { if w.requestBodyLimitHit || w.closedRequestBodyEarly() { c.closeWriteAndWait() } - break + return } c.setState(c.rwc, StateIdle) } @@ -1411,9 +1460,18 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if w.wroteHeader { w.cw.flush() } + + c := w.conn + c.mu.Lock() + defer c.mu.Unlock() + + if w.closeNotifyCh != nil { + return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier in same ServeHTTP call") + } + // Release the bufioWriter that writes to the chunk writer, it is not // used after a connection has been hijacked. - rwc, buf, err = w.conn.hijack() + rwc, buf, err = c.hijackLocked() if err == nil { putBufioWriter(w.w) w.w = nil @@ -1422,7 +1480,34 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { } func (w *response) CloseNotify() <-chan bool { - return w.conn.closeNotify() + c := w.conn + c.mu.Lock() + defer c.mu.Unlock() + + if w.closeNotifyCh != nil { + return w.closeNotifyCh + } + ch := make(chan bool, 1) + w.closeNotifyCh = ch + + if w.conn.hijackedv { + // CloseNotify is undefined after a hijack, but we have + // no place to return an error, so just return a channel, + // even though it'll never receive a value. + return ch + } + + var once sync.Once + notify := func() { once.Do(func() { ch <- true }) } + + if c.bufr.Buffered() > 0 { + // A pipelined request or unread request body data is available + // unread. Per the CloseNotifier docs, fire immediately. + notify() + } else { + c.r.startBackgroundRead(notify) + } + return ch } // The HandlerFunc type is an adapter to allow the use of @@ -1934,10 +2019,7 @@ func (srv *Server) Serve(l net.Listener) error { return e } tempDelay = 0 - c, err := srv.newConn(rw) - if err != nil { - continue - } + c := srv.newConn(rw) c.setState(c.rwc, StateNew) // before Serve can return go c.serve() } @@ -2336,7 +2418,7 @@ type checkConnErrorWriter struct { } func (w checkConnErrorWriter) Write(p []byte) (n int, err error) { - n, err = w.c.w.Write(p) // c.w == c.rwc, except after a hijack, when rwc is nil. + n, err = w.c.rwc.Write(p) if err != nil && w.c.werr == nil { w.c.werr = err }