diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index a98a4ccf3b..3e84f2e11d 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -97,6 +97,7 @@ func (c *rwTestConn) Close() error { } type testConn struct { + readMu sync.Mutex // for TestHandlerBodyClose readBuf bytes.Buffer writeBuf bytes.Buffer closec chan bool // if non-nil, send value to it on close @@ -104,6 +105,8 @@ type testConn struct { } func (c *testConn) Read(b []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() return c.readBuf.Read(b) } @@ -1246,15 +1249,21 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) { done := make(chan bool) + readBufLen := func() int { + conn.readMu.Lock() + defer conn.readMu.Unlock() + return conn.readBuf.Len() + } + ls := &oneConnListener{conn} go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { defer close(done) - if conn.readBuf.Len() < len(body)/2 { - t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len()) + if bufLen := readBufLen(); bufLen < len(body)/2 { + t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen) } rw.WriteHeader(200) rw.(Flusher).Flush() - if g, e := conn.readBuf.Len(), 0; g != e { + if g, e := readBufLen(), 0; g != e { t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e) } if c := rw.Header().Get("Connection"); c != "" { @@ -1430,15 +1439,21 @@ func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) { } conn.closec = make(chan bool, 1) + readBufLen := func() int { + conn.readMu.Lock() + defer conn.readMu.Unlock() + return conn.readBuf.Len() + } + ls := &oneConnListener{conn} var numReqs int var size0, size1 int go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { numReqs++ if numReqs == 1 { - size0 = conn.readBuf.Len() + size0 = readBufLen() req.Body.Close() - size1 = conn.readBuf.Len() + size1 = readBufLen() } })) <-conn.closec @@ -1538,7 +1553,9 @@ type slowTestConn struct { // over multiple calls to Read, time.Durations are slept, strings are read. script []interface{} closec chan bool - rd, wd time.Time // read, write deadline + + mu sync.Mutex // guards rd/wd + rd, wd time.Time // read, write deadline noopConn } @@ -1549,16 +1566,22 @@ func (c *slowTestConn) SetDeadline(t time.Time) error { } func (c *slowTestConn) SetReadDeadline(t time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() c.rd = t return nil } func (c *slowTestConn) SetWriteDeadline(t time.Time) error { + c.mu.Lock() + defer c.mu.Unlock() c.wd = t return nil } func (c *slowTestConn) Read(b []byte) (n int, err error) { + c.mu.Lock() + defer c.mu.Unlock() restart: if !c.rd.IsZero() && time.Now().After(c.rd) { return 0, syscall.ETIMEDOUT @@ -2330,6 +2353,49 @@ For: ts.Close() } +// Tests that a pipelined request causes the first request's Handler's CloseNotify +// channel to fire. Previously it deadlocked. +// +// Issue 13165 +func TestCloseNotifierPipelined(t *testing.T) { + defer afterTest(t) + gotReq := make(chan bool, 2) + sawClose := make(chan bool, 2) + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + gotReq <- true + cc := rw.(CloseNotifier).CloseNotify() + <-cc + sawClose <- true + })) + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error dialing: %v", err) + } + diec := make(chan bool, 2) + go func() { + const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n" + _, err = io.WriteString(conn, req+req) // two requests + if err != nil { + t.Fatal(err) + } + <-diec + conn.Close() + }() +For: + for { + select { + case <-gotReq: + diec <- true + case <-sawClose: + break For + case <-time.After(5 * time.Second): + ts.CloseClientConnections() + t.Fatal("timeout") + } + } + ts.Close() +} + func TestCloseNotifierChanLeak(t *testing.T) { defer afterTest(t) req := reqBytes("GET / HTTP/1.0\nHost: golang.org") @@ -2352,6 +2418,61 @@ func TestCloseNotifierChanLeak(t *testing.T) { } } +// Tests that we can use CloseNotifier in one request, and later call Hijack +// on a second request on the same connection. +// +// It also tests that the connReader stitches together its background +// 1-byte read for CloseNotifier when CloseNotifier doesn't fire with +// the rest of the second HTTP later. +// +// Issue 9763. +// HTTP/1-only test. (http2 doesn't have Hijack) +func TestHijackAfterCloseNotifier(t *testing.T) { + defer afterTest(t) + script := make(chan string, 2) + script <- "closenotify" + script <- "hijack" + close(script) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + plan := <-script + switch plan { + default: + panic("bogus plan; too many requests") + case "closenotify": + w.(CloseNotifier).CloseNotify() // discard result + w.Header().Set("X-Addr", r.RemoteAddr) + case "hijack": + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack in Handler: %v", err) + return + } + if _, ok := c.(*net.TCPConn); !ok { + // Verify it's not wrapped in some type. + // Not strictly a go1 compat issue, but in practice it probably is. + t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c) + } + fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr) + c.Close() + return + } + })) + defer ts.Close() + res1, err := Get(ts.URL) + if err != nil { + log.Fatal(err) + } + res2, err := Get(ts.URL) + if err != nil { + log.Fatal(err) + } + addr1 := res1.Header.Get("X-Addr") + addr2 := res2.Header.Get("X-Addr") + if addr1 == "" || addr1 != addr2 { + t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2) + } +} + func TestOptions(t *testing.T) { uric := make(chan string, 2) // only expect 1, but leave space for 2 mux := NewServeMux() @@ -2702,7 +2823,7 @@ func TestHTTP10ConnectionHeader(t *testing.T) { defer afterTest(t) mux := NewServeMux() - mux.Handle("/", HandlerFunc(func(resp ResponseWriter, req *Request) {})) + mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {})) ts := httptest.NewServer(mux) defer ts.Close() @@ -3248,10 +3369,7 @@ func (c *closeWriteTestConn) CloseWrite() error { func TestCloseWrite(t *testing.T) { var srv Server var testConn closeWriteTestConn - c, err := ExportServerNewConn(&srv, &testConn) - if err != nil { - t.Fatal(err) - } + c := ExportServerNewConn(&srv, &testConn) ExportCloseWriteAndWait(c) if !testConn.didCloseWrite { t.Error("didn't see CloseWrite call") diff --git a/src/net/http/server.go b/src/net/http/server.go index 35f41e734e..cd5f9cf34f 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -115,28 +115,76 @@ type Hijacker interface { // This mechanism can be used to cancel long operations on the server // if the client has disconnected before the response is ready. type CloseNotifier interface { - // CloseNotify returns a channel that receives a single value - // when the client connection has gone away. + // CloseNotify returns a channel that receives at most a + // single value (true) when the client connection has gone + // away. + // + // CloseNotify is undefined before Request.Body has been + // fully read. + // + // After the Handler has returned, there is no guarantee + // that the channel receives a value. + // + // If the protocol is HTTP/1.1 and CloseNotify is called while + // processing an idempotent request (such a GET) while + // HTTP/1.1 pipelining is in use, the arrival of a subsequent + // pipelined request will cause a value to be sent on the + // returned channel. In practice HTTP/1.1 pipelining is not + // enabled in browsers and not seen often in the wild. If this + // is a problem, use HTTP/2 or only use CloseNotify on methods + // such as POST. CloseNotify() <-chan bool } // A conn represents the server side of an HTTP connection. type conn struct { - remoteAddr string // network address of remote side - server *Server // the Server on which the connection arrived - rwc net.Conn // i/o connection - w io.Writer // checkConnErrorWriter's copy of wrc, not zeroed on Hijack - werr error // any errors writing to w - sr liveSwitchReader // where the LimitReader reads from; usually the rwc - lr *io.LimitedReader // io.LimitReader(sr) - buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc - tlsState *tls.ConnectionState // or nil when not using TLS - lastMethod string // method of previous request, or "" + // server is the server on which the connection arrived. + // Immutable; never nil. + server *Server - mu sync.Mutex // guards the following - clientGone bool // if client has disconnected mid-request - closeNotifyc chan bool // made lazily - hijackedv bool // connection has been hijacked by handler + // rwc is the underlying network connection. + // This is never wrapped by other types and is the value given out + // to CloseNotifier callers. It is usually of type *net.TCPConn or + // *tls.Conn. + rwc net.Conn + + // remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously + // inside the Listener's Accept goroutine, as some implementations block. + // It is populated immediately inside the (*conn).serve goroutine. + // This is the value of a Handler's (*Request).RemoteAddr. + remoteAddr string + + // tlsState is the TLS connection state when using TLS. + // nil means not TLS. + tlsState *tls.ConnectionState + + // werr is set to the first write error to rwc. + // It is set via checkConnErrorWriter{w}, where bufw writes. + werr error + + // r is bufr's read source. It's a wrapper around rwc that provides + // io.LimitedReader-style limiting (while reading request headers) + // and functionality to support CloseNotifier. See *connReader docs. + r *connReader + + // bufr reads from r. + // Users of bufr must hold mu. + bufr *bufio.Reader + + // bufw writes to checkConnErrorWriter{c}, which populates werr on error. + bufw *bufio.Writer + + // lastMethod is the method of the most recent request + // on this connection, if any. + lastMethod string + + // mu guards hijackedv, use of bufr, (*response).closeNotifyCh. + mu sync.Mutex + + // hijackedv is whether this connection has been hijacked + // by a Handler with the Hijacker interface. + // It is guarded by mu. + hijackedv bool } func (c *conn) hijacked() bool { @@ -145,83 +193,18 @@ func (c *conn) hijacked() bool { return c.hijackedv } -func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { - c.mu.Lock() - defer c.mu.Unlock() +// c.mu must be held. +func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if c.hijackedv { return nil, nil, ErrHijacked } - if c.closeNotifyc != nil { - return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier") - } c.hijackedv = true rwc = c.rwc - buf = c.buf - c.rwc = nil - c.buf = nil + buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc)) c.setState(rwc, StateHijacked) return } -func (c *conn) closeNotify() <-chan bool { - c.mu.Lock() - defer c.mu.Unlock() - if c.closeNotifyc == nil { - c.closeNotifyc = make(chan bool, 1) - if c.hijackedv { - // to obey the function signature, even though - // it'll never receive a value. - return c.closeNotifyc - } - pr, pw := io.Pipe() - - readSource := c.sr.r - c.sr.Lock() - c.sr.r = pr - c.sr.Unlock() - go func() { - bufp := copyBufPool.Get().(*[]byte) - defer copyBufPool.Put(bufp) - _, err := io.CopyBuffer(pw, readSource, *bufp) - if err == nil { - err = io.EOF - } - pw.CloseWithError(err) - c.noteClientGone() - }() - } - return c.closeNotifyc -} - -func (c *conn) noteClientGone() { - c.mu.Lock() - defer c.mu.Unlock() - if c.closeNotifyc != nil && !c.clientGone { - c.closeNotifyc <- true - } - c.clientGone = true -} - -// A switchWriter can have its Writer changed at runtime. -// It's not safe for concurrent Writes and switches. -type switchWriter struct { - io.Writer -} - -// A liveSwitchReader can have its Reader changed at runtime. It's -// safe for concurrent reads and switches, if its mutex is held. -type liveSwitchReader struct { - sync.Mutex - r io.Reader -} - -func (sr *liveSwitchReader) Read(p []byte) (n int, err error) { - sr.Lock() - r := sr.r - sr.Unlock() - return r.Read(p) -} - // This should be >= 512 bytes for DetectContentType, // but otherwise it's somewhat arbitrary. const bufferBeforeChunkingSize = 2048 @@ -268,15 +251,15 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) { return len(p), nil } if cw.chunking { - _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p)) + _, err = fmt.Fprintf(cw.res.conn.bufw, "%x\r\n", len(p)) if err != nil { cw.res.conn.rwc.Close() return } } - n, err = cw.res.conn.buf.Write(p) + n, err = cw.res.conn.bufw.Write(p) if cw.chunking && err == nil { - _, err = cw.res.conn.buf.Write(crlf) + _, err = cw.res.conn.bufw.Write(crlf) } if err != nil { cw.res.conn.rwc.Close() @@ -288,7 +271,7 @@ func (cw *chunkWriter) flush() { if !cw.wroteHeader { cw.writeHeader(nil) } - cw.res.conn.buf.Flush() + cw.res.conn.bufw.Flush() } func (cw *chunkWriter) close() { @@ -296,7 +279,7 @@ func (cw *chunkWriter) close() { cw.writeHeader(nil) } if cw.chunking { - bw := cw.res.conn.buf // conn's bufio writer + bw := cw.res.conn.bufw // conn's bufio writer // zero chunk to mark EOF bw.WriteString("0\r\n") if len(cw.res.trailers) > 0 { @@ -324,7 +307,6 @@ type response struct { w *bufio.Writer // buffers output in chunks to chunkWriter cw chunkWriter - sw *switchWriter // of the bufio.Writer, for return to putBufioWriter // handlerHeader is the Header that Handlers get access to, // which may be retained and mutated even after WriteHeader. @@ -363,6 +345,8 @@ type response struct { // Buffers for Date and Content-Length dateBuf [len(TimeFormat)]byte clenBuf [10]byte + + closeNotifyCh <-chan bool // guarded by conn.mu } // declareTrailer is called for each Trailer header when the @@ -462,28 +446,88 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) { return n, err } -// noLimit is an effective infinite upper bound for io.LimitedReader -const noLimit int64 = (1 << 63) - 1 - // debugServerConnections controls whether all server connections are wrapped // with a verbose logging wrapper. const debugServerConnections = false // Create new connection from rwc. -func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) { - c = new(conn) - c.server = srv - c.rwc = rwc - c.w = rwc +func (srv *Server) newConn(rwc net.Conn) *conn { + c := &conn{ + server: srv, + rwc: rwc, + } if debugServerConnections { c.rwc = newLoggingConn("server", c.rwc) } - c.sr.r = c.rwc - c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader) - br := newBufioReader(c.lr) - bw := newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) - c.buf = bufio.NewReadWriter(br, bw) - return c, nil + return c +} + +type readResult struct { + n int + err error + b byte // byte read, if n == 1 +} + +// connReader is the io.Reader wrapper used by *conn. It combines a +// selectively-activated io.LimitedReader (to bound request header +// read sizes) with support for selectively keeping an io.Reader.Read +// call blocked in a background goroutine to wait for activitiy and +// trigger a CloseNotifier channel. +type connReader struct { + r io.Reader + remain int64 // bytes remaining + + // ch is non-nil if a background read is in progress. + // It is guarded by conn.mu. + ch chan readResult +} + +func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } +func (cr *connReader) setInfiniteReadLimit() { cr.remain = 1<<63 - 1 } +func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } + +func (cr *connReader) Read(p []byte) (n int, err error) { + if cr.hitReadLimit() { + return 0, io.EOF + } + if len(p) == 0 { + return + } + if int64(len(p)) > cr.remain { + p = p[:cr.remain] + } + + // Is a background read (started by CloseNotifier) already in + // flight? If so, wait for it and use its result. + ch := cr.ch + if ch != nil { + cr.ch = nil + res := <-ch + if res.n == 1 { + p[0] = res.b + cr.remain -= 1 + } + return res.n, res.err + } + n, err = cr.r.Read(p) + cr.remain -= int64(n) + return +} + +func (cr *connReader) startBackgroundRead(onReadComplete func()) { + if cr.ch != nil { + // Background read already started. + return + } + cr.ch = make(chan readResult, 1) + go cr.closeNotifyAwaitActivityRead(cr.ch, onReadComplete) +} + +func (cr *connReader) closeNotifyAwaitActivityRead(ch chan<- readResult, onReadComplete func()) { + var buf [1]byte + n, err := cr.r.Read(buf[:1]) + onReadComplete() + ch <- readResult{n, err, buf[0]} } var ( @@ -556,7 +600,7 @@ func (srv *Server) maxHeaderBytes() int { return DefaultMaxHeaderBytes } -func (srv *Server) initialLimitedReaderSize() int64 { +func (srv *Server) initialReadLimitSize() int64 { return int64(srv.maxHeaderBytes()) + 4096 // bufio slop } @@ -575,8 +619,8 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { } if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() { ecr.resp.wroteContinue = true - ecr.resp.conn.buf.WriteString("HTTP/1.1 100 Continue\r\n\r\n") - ecr.resp.conn.buf.Flush() + ecr.resp.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n") + ecr.resp.conn.bufw.Flush() } n, err = ecr.readCloser.Read(p) if err == io.EOF { @@ -635,21 +679,23 @@ func (c *conn) readRequest() (w *response, err error) { }() } - c.lr.N = c.server.initialLimitedReaderSize() + c.r.setReadLimit(c.server.initialReadLimitSize()) + c.mu.Lock() // while using bufr if c.lastMethod == "POST" { // RFC 2616 section 4.1 tolerance for old buggy clients. - peek, _ := c.buf.Reader.Peek(4) // ReadRequest will get err below - c.buf.Reader.Discard(numLeadingCRorLF(peek)) + peek, _ := c.bufr.Peek(4) // ReadRequest will get err below + c.bufr.Discard(numLeadingCRorLF(peek)) } - var req *Request - if req, err = ReadRequest(c.buf.Reader); err != nil { - if c.lr.N == 0 { + req, err := ReadRequest(c.bufr) + c.mu.Unlock() + if err != nil { + if c.r.hitReadLimit() { return nil, errTooLarge } return nil, err } - c.lr.N = noLimit c.lastMethod = req.Method + c.r.setInfiniteReadLimit() req.RemoteAddr = c.remoteAddr req.TLS = c.tlsState @@ -768,7 +814,7 @@ func (h extraHeader) Write(w *bufio.Writer) { } // writeHeader finalizes the header sent to the client and writes it -// to cw.res.conn.buf. +// to cw.res.conn.bufw. // // p is not written by writeHeader, but is the first chunk of the body // that will be written. It is sniffed for a Content-Type if none is @@ -1009,10 +1055,10 @@ func (cw *chunkWriter) writeHeader(p []byte) { } } - w.conn.buf.WriteString(statusLine(w.req, code)) - cw.header.WriteSubset(w.conn.buf, excludeHeader) - setHeader.Write(w.conn.buf.Writer) - w.conn.buf.Write(crlf) + w.conn.bufw.WriteString(statusLine(w.req, code)) + cw.header.WriteSubset(w.conn.bufw, excludeHeader) + setHeader.Write(w.conn.bufw) + w.conn.bufw.Write(crlf) } // foreachHeaderElement splits v according to the "#rule" construction @@ -1166,7 +1212,7 @@ func (w *response) finishRequest() { w.w.Flush() putBufioWriter(w.w) w.cw.close() - w.conn.buf.Flush() + w.conn.bufw.Flush() // Close the body (regardless of w.closeAfterReply) so we can // re-use its bufio.Reader later safely. @@ -1219,28 +1265,26 @@ func (w *response) Flush() { } func (c *conn) finalFlush() { - if c.buf != nil { - c.buf.Flush() - + if c.bufr != nil { // Steal the bufio.Reader (~4KB worth of memory) and its associated // reader for a future connection. - putBufioReader(c.buf.Reader) + putBufioReader(c.bufr) + c.bufr = nil + } + if c.bufw != nil { + c.bufw.Flush() // Steal the bufio.Writer (~4KB worth of memory) and its associated // writer for a future connection. - putBufioWriter(c.buf.Writer) - - c.buf = nil + putBufioWriter(c.bufw) + c.bufw = nil } } // Close the connection. func (c *conn) close() { c.finalFlush() - if c.rwc != nil { - c.rwc.Close() - c.rwc = nil - } + c.rwc.Close() } // rstAvoidanceDelay is the amount of time we sleep after closing the @@ -1293,7 +1337,6 @@ func (c *conn) setState(nc net.Conn, state ConnState) { // Serve a new connection. func (c *conn) serve() { c.remoteAddr = c.rwc.RemoteAddr().String() - origConn := c.rwc // copy it before it's set nil on Close or Hijack defer func() { if err := recover(); err != nil { const size = 64 << 10 @@ -1303,7 +1346,7 @@ func (c *conn) serve() { } if !c.hijacked() { c.close() - c.setState(origConn, StateClosed) + c.setState(c.rwc, StateClosed) } }() @@ -1329,9 +1372,13 @@ func (c *conn) serve() { } } + c.r = &connReader{r: c.rwc} + c.bufr = newBufioReader(c.r) + c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) + for { w, err := c.readRequest() - if c.lr.N != c.server.initialLimitedReaderSize() { + if c.r.remain != c.server.initialReadLimitSize() { // If we read any bytes off the wire, we're active. c.setState(c.rwc, StateActive) } @@ -1344,14 +1391,16 @@ func (c *conn) serve() { // request. Undefined behavior. io.WriteString(c.rwc, "HTTP/1.1 431 Request Header Fields Too Large\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n431 Request Header Fields Too Large") c.closeWriteAndWait() - break - } else if err == io.EOF { - break // Don't reply - } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() { - break // Don't reply + return + } + if err == io.EOF { + return // don't reply + } + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + return // don't reply } io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request") - break + return } // Expect 100 Continue support @@ -1364,7 +1413,7 @@ func (c *conn) serve() { req.Header.Del("Expect") } else if req.Header.get("Expect") != "" { w.sendExpectationFailed() - break + return } // HTTP cannot have multiple simultaneous active requests.[*] @@ -1381,7 +1430,7 @@ func (c *conn) serve() { if w.requestBodyLimitHit || w.closedRequestBodyEarly() { c.closeWriteAndWait() } - break + return } c.setState(c.rwc, StateIdle) } @@ -1411,9 +1460,18 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if w.wroteHeader { w.cw.flush() } + + c := w.conn + c.mu.Lock() + defer c.mu.Unlock() + + if w.closeNotifyCh != nil { + return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier in same ServeHTTP call") + } + // Release the bufioWriter that writes to the chunk writer, it is not // used after a connection has been hijacked. - rwc, buf, err = w.conn.hijack() + rwc, buf, err = c.hijackLocked() if err == nil { putBufioWriter(w.w) w.w = nil @@ -1422,7 +1480,34 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { } func (w *response) CloseNotify() <-chan bool { - return w.conn.closeNotify() + c := w.conn + c.mu.Lock() + defer c.mu.Unlock() + + if w.closeNotifyCh != nil { + return w.closeNotifyCh + } + ch := make(chan bool, 1) + w.closeNotifyCh = ch + + if w.conn.hijackedv { + // CloseNotify is undefined after a hijack, but we have + // no place to return an error, so just return a channel, + // even though it'll never receive a value. + return ch + } + + var once sync.Once + notify := func() { once.Do(func() { ch <- true }) } + + if c.bufr.Buffered() > 0 { + // A pipelined request or unread request body data is available + // unread. Per the CloseNotifier docs, fire immediately. + notify() + } else { + c.r.startBackgroundRead(notify) + } + return ch } // The HandlerFunc type is an adapter to allow the use of @@ -1934,10 +2019,7 @@ func (srv *Server) Serve(l net.Listener) error { return e } tempDelay = 0 - c, err := srv.newConn(rw) - if err != nil { - continue - } + c := srv.newConn(rw) c.setState(c.rwc, StateNew) // before Serve can return go c.serve() } @@ -2336,7 +2418,7 @@ type checkConnErrorWriter struct { } func (w checkConnErrorWriter) Write(p []byte) (n int, err error) { - n, err = w.c.w.Write(p) // c.w == c.rwc, except after a hijack, when rwc is nil. + n, err = w.c.rwc.Write(p) if err != nil && w.c.werr == nil { w.c.werr = err }