diff --git a/src/pkg/net/http/serve_test.go b/src/pkg/net/http/serve_test.go index 36832140b40..21cd67f9dca 100644 --- a/src/pkg/net/http/serve_test.go +++ b/src/pkg/net/http/serve_test.go @@ -2270,6 +2270,12 @@ func TestServerConnState(t *testing.T) { c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello.")) c.Close() }, + "/hijack-panic": func(w ResponseWriter, r *Request) { + c, _, _ := w.(Hijacker).Hijack() + c.Write([]byte("HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello.")) + c.Close() + panic("intentional panic") + }, } ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { handler[r.URL.Path](w, r) @@ -2280,6 +2286,7 @@ func TestServerConnState(t *testing.T) { var stateLog = map[int][]ConnState{} var connID = map[net.Conn]int{} + ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) ts.Config.ConnState = func(c net.Conn, state ConnState) { if c == nil { t.Error("nil conn seen in state %s", state) @@ -2303,11 +2310,56 @@ func TestServerConnState(t *testing.T) { mustGet(t, ts.URL+"/", "Connection", "close") mustGet(t, ts.URL+"/hijack") + mustGet(t, ts.URL+"/hijack-panic") + + // New->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + c.Close() + } + + // New->Active->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil { + t.Fatal(err) + } + c.Close() + } + + // New->Idle->Closed + { + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + if _, err := io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n"); err != nil { + t.Fatal(err) + } + res, err := ReadResponse(bufio.NewReader(c), nil) + if err != nil { + t.Fatal(err) + } + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatal(err) + } + c.Close() + } want := map[int][]ConnState{ 1: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed}, 2: []ConnState{StateNew, StateActive, StateIdle, StateActive, StateClosed}, 3: []ConnState{StateNew, StateActive, StateHijacked}, + 4: []ConnState{StateNew, StateActive, StateHijacked}, + 5: []ConnState{StateNew, StateClosed}, + 6: []ConnState{StateNew, StateActive, StateClosed}, + 7: []ConnState{StateNew, StateActive, StateIdle, StateClosed}, } logString := func(m map[int][]ConnState) string { var b bytes.Buffer @@ -2316,6 +2368,7 @@ func TestServerConnState(t *testing.T) { for _, s := range l { fmt.Fprintf(&b, "%s ", s) } + b.WriteString("\n") } return b.String() } diff --git a/src/pkg/net/http/server.go b/src/pkg/net/http/server.go index ffe5838a063..273d5964f14 100644 --- a/src/pkg/net/http/server.go +++ b/src/pkg/net/http/server.go @@ -139,6 +139,7 @@ func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { buf = c.buf c.rwc = nil c.buf = nil + c.setState(rwc, StateHijacked) return } @@ -497,6 +498,10 @@ func (srv *Server) maxHeaderBytes() int { return DefaultMaxHeaderBytes } +func (srv *Server) initialLimitedReaderSize() int64 { + return int64(srv.maxHeaderBytes()) + 4096 // bufio slop +} + // wrapper around io.ReaderCloser which on first read, sends an // HTTP/1.1 100 Continue header type expectContinueReader struct { @@ -567,7 +572,7 @@ func (c *conn) readRequest() (w *response, err error) { }() } - c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */ + c.lr.N = c.server.initialLimitedReaderSize() var req *Request if req, err = ReadRequest(c.buf.Reader); err != nil { if c.lr.N == 0 { @@ -1127,10 +1132,10 @@ func (c *conn) serve() { for { w, err := c.readRequest() - // TODO(bradfitz): could push this StateActive - // earlier, but in practice header will be all in one - // packet/Read: - c.setState(c.rwc, StateActive) + if c.lr.N != c.server.initialLimitedReaderSize() { + // If we read any bytes off the wire, we're active. + c.setState(c.rwc, StateActive) + } if err != nil { if err == errTooLarge { // Their HTTP client may or may not be @@ -1176,7 +1181,6 @@ func (c *conn) serve() { // in parallel even if their responses need to be serialized. serverHandler{c.server}.ServeHTTP(w, w.req) if c.hijacked() { - c.setState(origConn, StateHijacked) return } w.finishRequest()