mirror of
https://github.com/golang/go
synced 2024-11-20 04:54:44 -07:00
net/http: implement CloseNotifier
Fixes #2510 R=golang-dev, rsc CC=golang-dev https://golang.org/cl/6867050
This commit is contained in:
parent
a034fc9855
commit
4fb78c3a16
@ -1252,6 +1252,42 @@ func TestContentLengthZero(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCloseNotifier(t *testing.T) {
|
||||
gotReq := make(chan bool, 1)
|
||||
sawClose := make(chan bool, 1)
|
||||
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||||
gotReq <- true
|
||||
cc := rw.(CloseNotifier).CloseNotify()
|
||||
<-cc
|
||||
sawClose <- true
|
||||
}))
|
||||
conn, err := net.Dial("tcp", ts.Listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("error dialing: %v", err)
|
||||
}
|
||||
diec := make(chan bool)
|
||||
go func() {
|
||||
_, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
<-diec
|
||||
conn.Close()
|
||||
}()
|
||||
For:
|
||||
for {
|
||||
select {
|
||||
case <-gotReq:
|
||||
diec <- true
|
||||
case <-sawClose:
|
||||
break For
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout")
|
||||
}
|
||||
}
|
||||
ts.Close()
|
||||
}
|
||||
|
||||
// goTimeout runs f, failing t if f takes more than ns to complete.
|
||||
func goTimeout(t *testing.T, d time.Duration, f func()) {
|
||||
ch := make(chan bool, 2)
|
||||
|
@ -93,16 +93,104 @@ type Hijacker interface {
|
||||
Hijack() (net.Conn, *bufio.ReadWriter, error)
|
||||
}
|
||||
|
||||
// The CloseNotifier interface is implemented by ResponseWriters which
|
||||
// allow detecting when the underlying connection has gone away.
|
||||
//
|
||||
// This mechanism can be used to cancel long operations on the server
|
||||
// if the client has disconnected before the response is ready.
|
||||
type CloseNotifier interface {
|
||||
// CloseNotify returns a channel that receives a single value
|
||||
// when the client connection has gone away.
|
||||
CloseNotify() <-chan bool
|
||||
}
|
||||
|
||||
// A conn represents the server side of an HTTP connection.
|
||||
type conn struct {
|
||||
remoteAddr string // network address of remote side
|
||||
server *Server // the Server on which the connection arrived
|
||||
rwc net.Conn // i/o connection
|
||||
lr *io.LimitedReader // io.LimitReader(rwc)
|
||||
buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->rwc
|
||||
hijacked bool // connection has been hijacked by handler
|
||||
sr switchReader // where the LimitReader reads from; usually the rwc
|
||||
lr *io.LimitedReader // io.LimitReader(sr)
|
||||
buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc
|
||||
tlsState *tls.ConnectionState // or nil when not using TLS
|
||||
body []byte
|
||||
|
||||
mu sync.Mutex // guards the following
|
||||
clientGone bool // if client has disconnected mid-request
|
||||
closeNotifyc chan bool // made lazily
|
||||
hijackedv bool // connection has been hijacked by handler
|
||||
}
|
||||
|
||||
func (c *conn) hijacked() bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
return c.hijackedv
|
||||
}
|
||||
|
||||
func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.hijackedv {
|
||||
return nil, nil, ErrHijacked
|
||||
}
|
||||
if c.closeNotifyc != nil {
|
||||
return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier")
|
||||
}
|
||||
c.hijackedv = true
|
||||
rwc = c.rwc
|
||||
buf = c.buf
|
||||
c.rwc = nil
|
||||
c.buf = nil
|
||||
return
|
||||
}
|
||||
|
||||
func (c *conn) closeNotify() <-chan bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closeNotifyc == nil {
|
||||
c.closeNotifyc = make(chan bool)
|
||||
if c.hijackedv {
|
||||
// to obey the function signature, even though
|
||||
// it'll never receive a value.
|
||||
return c.closeNotifyc
|
||||
}
|
||||
pr, pw := io.Pipe()
|
||||
|
||||
readSource := c.sr.r
|
||||
c.sr.Lock()
|
||||
c.sr.r = pr
|
||||
c.sr.Unlock()
|
||||
go func() {
|
||||
_, err := io.Copy(pw, readSource)
|
||||
if err == nil {
|
||||
err = io.EOF
|
||||
}
|
||||
pw.CloseWithError(err)
|
||||
c.noteClientGone()
|
||||
}()
|
||||
}
|
||||
return c.closeNotifyc
|
||||
}
|
||||
|
||||
func (c *conn) noteClientGone() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closeNotifyc != nil && !c.clientGone {
|
||||
c.closeNotifyc <- true
|
||||
}
|
||||
c.clientGone = true
|
||||
}
|
||||
|
||||
type switchReader struct {
|
||||
sync.Mutex
|
||||
r io.Reader
|
||||
}
|
||||
|
||||
func (sr *switchReader) Read(p []byte) (n int, err error) {
|
||||
sr.Lock()
|
||||
r := sr.r
|
||||
sr.Unlock()
|
||||
return r.Read(p)
|
||||
}
|
||||
|
||||
// A response represents the server side of an HTTP response.
|
||||
@ -183,8 +271,9 @@ func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
|
||||
if debugServerConnections {
|
||||
c.rwc = newLoggingConn("server", c.rwc)
|
||||
}
|
||||
c.sr = switchReader{r: c.rwc}
|
||||
c.body = make([]byte, sniffLen)
|
||||
c.lr = io.LimitReader(c.rwc, noLimit).(*io.LimitedReader)
|
||||
c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader)
|
||||
br := bufio.NewReader(c.lr)
|
||||
bw := bufio.NewWriter(c.rwc)
|
||||
c.buf = bufio.NewReadWriter(br, bw)
|
||||
@ -215,7 +304,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
|
||||
if ecr.closed {
|
||||
return 0, errors.New("http: Read after Close on request Body")
|
||||
}
|
||||
if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked {
|
||||
if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() {
|
||||
ecr.resp.wroteContinue = true
|
||||
io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n")
|
||||
ecr.resp.conn.buf.Flush()
|
||||
@ -238,7 +327,7 @@ var errTooLarge = errors.New("http: request too large")
|
||||
|
||||
// Read next request from connection.
|
||||
func (c *conn) readRequest() (w *response, err error) {
|
||||
if c.hijacked {
|
||||
if c.hijacked() {
|
||||
return nil, ErrHijacked
|
||||
}
|
||||
c.lr.N = int64(c.server.maxHeaderBytes()) + 4096 /* bufio slop */
|
||||
@ -279,7 +368,7 @@ func (w *response) Header() Header {
|
||||
const maxPostHandlerReadBytes = 256 << 10
|
||||
|
||||
func (w *response) WriteHeader(code int) {
|
||||
if w.conn.hijacked {
|
||||
if w.conn.hijacked() {
|
||||
log.Print("http: response.WriteHeader on hijacked connection")
|
||||
return
|
||||
}
|
||||
@ -455,7 +544,7 @@ func (w *response) bodyAllowed() bool {
|
||||
}
|
||||
|
||||
func (w *response) Write(data []byte) (n int, err error) {
|
||||
if w.conn.hijacked {
|
||||
if w.conn.hijacked() {
|
||||
log.Print("http: response.Write on hijacked connection")
|
||||
return 0, ErrHijacked
|
||||
}
|
||||
@ -673,6 +762,36 @@ func (c *conn) serve() {
|
||||
}
|
||||
req.Header.Del("Expect")
|
||||
} else if req.Header.get("Expect") != "" {
|
||||
w.sendExpectationFailed()
|
||||
break
|
||||
}
|
||||
|
||||
handler := c.server.Handler
|
||||
if handler == nil {
|
||||
handler = DefaultServeMux
|
||||
}
|
||||
|
||||
// HTTP cannot have multiple simultaneous active requests.[*]
|
||||
// Until the server replies to this request, it can't read another,
|
||||
// so we might as well run the handler in this goroutine.
|
||||
// [*] Not strictly true: HTTP pipelining. We could let them all process
|
||||
// in parallel even if their responses need to be serialized.
|
||||
handler.ServeHTTP(w, w.req)
|
||||
if c.hijacked() {
|
||||
return
|
||||
}
|
||||
w.finishRequest()
|
||||
if w.closeAfterReply {
|
||||
if w.requestBodyLimitHit {
|
||||
c.closeWriteAndWait()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
c.close()
|
||||
}
|
||||
|
||||
func (w *response) sendExpectationFailed() {
|
||||
// TODO(bradfitz): let ServeHTTP handlers handle
|
||||
// requests with non-standard expectation[s]? Seems
|
||||
// theoretical at best, and doesn't fit into the
|
||||
@ -688,46 +807,16 @@ func (c *conn) serve() {
|
||||
w.Header().Set("Connection", "close")
|
||||
w.WriteHeader(StatusExpectationFailed)
|
||||
w.finishRequest()
|
||||
break
|
||||
}
|
||||
|
||||
handler := c.server.Handler
|
||||
if handler == nil {
|
||||
handler = DefaultServeMux
|
||||
}
|
||||
|
||||
// HTTP cannot have multiple simultaneous active requests.[*]
|
||||
// Until the server replies to this request, it can't read another,
|
||||
// so we might as well run the handler in this goroutine.
|
||||
// [*] Not strictly true: HTTP pipelining. We could let them all process
|
||||
// in parallel even if their responses need to be serialized.
|
||||
handler.ServeHTTP(w, w.req)
|
||||
if c.hijacked {
|
||||
return
|
||||
}
|
||||
w.finishRequest()
|
||||
if w.closeAfterReply {
|
||||
if w.requestBodyLimitHit {
|
||||
c.closeWriteAndWait()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
c.close()
|
||||
}
|
||||
|
||||
// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
|
||||
// and a Hijacker.
|
||||
func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
|
||||
if w.conn.hijacked {
|
||||
return nil, nil, ErrHijacked
|
||||
return w.conn.hijack()
|
||||
}
|
||||
w.conn.hijacked = true
|
||||
rwc = w.conn.rwc
|
||||
buf = w.conn.buf
|
||||
w.conn.rwc = nil
|
||||
w.conn.buf = nil
|
||||
return
|
||||
|
||||
func (w *response) CloseNotify() <-chan bool {
|
||||
return w.conn.closeNotify()
|
||||
}
|
||||
|
||||
// The HandlerFunc type is an adapter to allow the use of
|
||||
|
Loading…
Reference in New Issue
Block a user