1
0
mirror of https://github.com/golang/go synced 2024-10-03 06:21:21 -06:00

net/http: implement CloseNotifier

Fixes #2510

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6867050
This commit is contained in:
Brad Fitzpatrick 2012-12-05 19:25:43 -08:00
parent a034fc9855
commit 4fb78c3a16
2 changed files with 158 additions and 33 deletions

View File

@ -1252,6 +1252,42 @@ func TestContentLengthZero(t *testing.T) {
}
}
func TestCloseNotifier(t *testing.T) {
gotReq := make(chan bool, 1)
sawClose := make(chan bool, 1)
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
gotReq <- true
cc := rw.(CloseNotifier).CloseNotify()
<-cc
sawClose <- true
}))
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
if err != nil {
t.Fatalf("error dialing: %v", err)
}
diec := make(chan bool)
go func() {
_, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
if err != nil {
t.Fatal(err)
}
<-diec
conn.Close()
}()
For:
for {
select {
case <-gotReq:
diec <- true
case <-sawClose:
break For
case <-time.After(5 * time.Second):
t.Fatal("timeout")
}
}
ts.Close()
}
// goTimeout runs f, failing t if f takes more than ns to complete.
func goTimeout(t *testing.T, d time.Duration, f func()) {
ch := make(chan bool, 2)

View File

@ -93,16 +93,104 @@ type Hijacker interface {
Hijack() (net.Conn, *bufio.ReadWriter, error)
}
// The CloseNotifier interface is implemented by ResponseWriters which
// allow detecting when the underlying connection has gone away.
//
// This mechanism can be used to cancel long operations on the server
// if the client has disconnected before the response is ready.
type CloseNotifier interface {
// CloseNotify returns a channel that receives a single value
// when the client connection has gone away.
CloseNotify() <-chan bool
}
// A conn represents the server side of an HTTP connection.
type conn struct {
remoteAddr string // network address of remote side
server *Server // the Server on which the connection arrived
rwc net.Conn // i/o connection
lr *io.LimitedReader // io.LimitReader(rwc)
buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->rwc
hijacked bool // connection has been hijacked by handler
sr switchReader // where the LimitReader reads from; usually the rwc
lr *io.LimitedReader // io.LimitReader(sr)
buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc
tlsState *tls.ConnectionState // or nil when not using TLS
body []byte
mu sync.Mutex // guards the following
clientGone bool // if client has disconnected mid-request
closeNotifyc chan bool // made lazily
hijackedv bool // connection has been hijacked by handler
}
func (c *conn) hijacked() bool {
c.mu.Lock()
defer c.mu.Unlock()
return c.hijackedv
}
func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.hijackedv {
return nil, nil, ErrHijacked
}
if c.closeNotifyc != nil {
return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier")
}
c.hijackedv = true
rwc = c.rwc
buf = c.buf
c.rwc = nil
c.buf = nil
return
}
func (c *conn) closeNotify() <-chan bool {
c.mu.Lock()
defer c.mu.Unlock()
if c.closeNotifyc == nil {
c.closeNotifyc = make(chan bool)
if c.hijackedv {
// to obey the function signature, even though
// it'll never receive a value.
return c.closeNotifyc
}
pr, pw := io.Pipe()
readSource := c.sr.r
c.sr.Lock()
c.sr.r = pr
c.sr.Unlock()
go func() {
_, err := io.Copy(pw, readSource)
if err == nil {
err = io.EOF
}
pw.CloseWithError(err)
c.noteClientGone()
}()
}
return c.closeNotifyc
}
func (c *conn) noteClientGone() {
c.mu.Lock()
defer c.mu.Unlock()
if c.closeNotifyc != nil && !c.clientGone {
c.closeNotifyc <- true
}
c.clientGone = true
}
type switchReader struct {
sync.Mutex
r io.Reader
}
func (sr *switchReader) Read(p []byte) (n int, err error) {
sr.Lock()
r := sr.r
sr.Unlock()
return r.Read(p)
}
// A response represents the server side of an HTTP response.
@ -183,8 +271,9 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
if debugServerConnections {
c.rwc = newLoggingConn("server", c.rwc)
}
c.sr = switchReader{r: c.rwc}
c.body = make([]byte, sniffLen)
c.lr = io.LimitReader(c.rwc, noLimit).(*io.LimitedReader)
c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader)
br := bufio.NewReader(c.lr)
bw := bufio.NewWriter(c.rwc)
c.buf = bufio.NewReadWriter(br, bw)
@ -215,7 +304,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
if ecr.closed {
return 0, errors.New("http: Read after Close on request Body")
}
if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked {
if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() {
ecr.resp.wroteContinue = true
io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n")
ecr.resp.conn.buf.Flush()
@ -238,7 +327,7 @@ var errTooLarge = errors.New("http: request too large")
// Read next request from connection.
func (c *conn) readRequest() (w *response, err error) {
if c.hijacked {
if c.hijacked() {
return nil, ErrHijacked
}
c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
@ -279,7 +368,7 @@ func (w *response) Header() Header {
const maxPostHandlerReadBytes = 256 << 10
func (w *response) WriteHeader(code int) {
if w.conn.hijacked {
if w.conn.hijacked() {
log.Print("http: response.WriteHeader on hijacked connection")
return
}
@ -455,7 +544,7 @@ func (w *response) bodyAllowed() bool {
}
func (w *response) Write(data []byte) (n int, err error) {
if w.conn.hijacked {
if w.conn.hijacked() {
log.Print("http: response.Write on hijacked connection")
return 0, ErrHijacked
}
@ -673,21 +762,7 @@ func (c *conn) serve() {
}
req.Header.Del("Expect")
} else if req.Header.get("Expect") != "" {
// TODO(bradfitz): let ServeHTTP handlers handle
// requests with non-standard expectation[s]? Seems
// theoretical at best, and doesn't fit into the
// current ServeHTTP model anyway. We'd need to
// make the ResponseWriter an optional
// "ExpectReplier" interface or something.
//
// For now we'll just obey RFC 2616 14.20 which says
// "If a server receives a request containing an
// Expect field that includes an expectation-
// extension that it does not support, it MUST
// respond with a 417 (Expectation Failed) status."
w.Header().Set("Connection", "close")
w.WriteHeader(StatusExpectationFailed)
w.finishRequest()
w.sendExpectationFailed()
break
}
@ -702,7 +777,7 @@ func (c *conn) serve() {
// [*] Not strictly true: HTTP pipelining. We could let them all process
// in parallel even if their responses need to be serialized.
handler.ServeHTTP(w, w.req)
if c.hijacked {
if c.hijacked() {
return
}
w.finishRequest()
@ -716,18 +791,32 @@ func (c *conn) serve() {
c.close()
}
func (w *response) sendExpectationFailed() {
// TODO(bradfitz): let ServeHTTP handlers handle
// requests with non-standard expectation[s]? Seems
// theoretical at best, and doesn't fit into the
// current ServeHTTP model anyway. We'd need to
// make the ResponseWriter an optional
// "ExpectReplier" interface or something.
//
// For now we'll just obey RFC 2616 14.20 which says
// "If a server receives a request containing an
// Expect field that includes an expectation-
// extension that it does not support, it MUST
// respond with a 417 (Expectation Failed) status."
w.Header().Set("Connection", "close")
w.WriteHeader(StatusExpectationFailed)
w.finishRequest()
}
// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
// and a Hijacker.
func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
if w.conn.hijacked {
return nil, nil, ErrHijacked
}
w.conn.hijacked = true
rwc = w.conn.rwc
buf = w.conn.buf
w.conn.rwc = nil
w.conn.buf = nil
return
return w.conn.hijack()
}
func (w *response) CloseNotify() <-chan bool {
return w.conn.closeNotify()
}
// The HandlerFunc type is an adapter to allow the use of