diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index 21b10355a9..22047c5826 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -1974,6 +1974,27 @@ func http2summarizeFrame(f http2Frame) string { return buf.String() } +type http2contextContext interface { + context.Context +} + +func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx http2contextContext, cancel func()) { + ctx, cancel = context.WithCancel(context.Background()) + ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr()) + if hs := opts.baseConfig(); hs != nil { + ctx = context.WithValue(ctx, ServerContextKey, hs) + } + return +} + +func http2contextWithCancel(ctx http2contextContext) (_ http2contextContext, cancel func()) { + return context.WithCancel(ctx) +} + +func http2requestWithContext(req *Request, ctx http2contextContext) *Request { + return req.WithContext(ctx) +} + type http2clientTrace httptrace.ClientTrace func http2reqContext(r *Request) context.Context { return r.Context() } @@ -2994,10 +3015,14 @@ func (o *http2ServeConnOpts) handler() Handler { // // The opts parameter is optional. If nil, default values are used. func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { + baseCtx, cancel := http2serverConnBaseContext(c, opts) + defer cancel() + sc := &http2serverConn{ srv: s, hs: opts.baseConfig(), conn: c, + baseCtx: baseCtx, remoteAddrStr: c.RemoteAddr().String(), bw: http2newBufferedWriter(c), handler: opts.handler(), @@ -3016,6 +3041,7 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { serveG: http2newGoroutineLock(), pushEnabled: true, } + sc.flow.add(http2initialWindowSize) sc.inflow.add(http2initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) @@ -3088,6 +3114,7 @@ type http2serverConn struct { conn net.Conn bw *http2bufferedWriter // writing to conn handler Handler + baseCtx http2contextContext framer *http2Framer doneServing chan struct{} // closed when serverConn.serve ends readFrameCh chan http2readFrameResult // written by serverConn.readFrames @@ -3151,10 +3178,12 @@ func (sc *http2serverConn) maxHeaderListSize() uint32 { // responseWriter's state field. type http2stream struct { // immutable: - sc *http2serverConn - id uint32 - body *http2pipe // non-nil if expecting DATA frames - cw http2closeWaiter // closed wait stream transitions to closed state + sc *http2serverConn + id uint32 + body *http2pipe // non-nil if expecting DATA frames + cw http2closeWaiter // closed wait stream transitions to closed state + ctx http2contextContext + cancelCtx func() // owned by serverConn's serve loop: bodyBytes int64 // body bytes seen so far @@ -3818,6 +3847,7 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { } if st != nil { st.gotReset = true + st.cancelCtx() sc.closeStream(st, http2StreamError{f.StreamID, f.ErrCode}) } return nil @@ -3997,10 +4027,13 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { } sc.maxStreamID = id + ctx, cancelCtx := http2contextWithCancel(sc.baseCtx) st = &http2stream{ - sc: sc, - id: id, - state: http2stateOpen, + sc: sc, + id: id, + state: http2stateOpen, + ctx: ctx, + cancelCtx: cancelCtx, } if f.StreamEnded() { st.state = http2stateHalfClosedRemote @@ -4208,6 +4241,7 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead Body: body, Trailer: trailer, } + req = http2requestWithContext(req, st.ctx) if bodyOpen { buf := make([]byte, http2initialWindowSize) @@ -4250,6 +4284,7 @@ func (sc *http2serverConn) getRequestBodyBuf() []byte { func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) { didPanic := true defer func() { + rw.rws.stream.cancelCtx() if didPanic { e := recover() // Same as net/http: diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index e398c92638..c32ff29902 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -4064,10 +4064,16 @@ func TestServerValidatesHeaders(t *testing.T) { } } -func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { +func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) { + testServerRequestContextCancel_ServeHTTPDone(t, h1Mode) +} +func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) { + testServerRequestContextCancel_ServeHTTPDone(t, h2Mode) +} +func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { defer afterTest(t) ctxc := make(chan context.Context, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() select { case <-ctx.Done(): @@ -4076,8 +4082,8 @@ func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { } ctxc <- ctx })) - defer ts.Close() - res, err := Get(ts.URL) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -4130,9 +4136,15 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) { } } -func TestServerContext_ServerContextKey(t *testing.T) { +func TestServerContext_ServerContextKey_h1(t *testing.T) { + testServerContext_ServerContextKey(t, h1Mode) +} +func TestServerContext_ServerContextKey_h2(t *testing.T) { + testServerContext_ServerContextKey(t, h2Mode) +} +func testServerContext_ServerContextKey(t *testing.T, h2 bool) { defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() got := ctx.Value(ServerContextKey) if _, ok := got.(*Server); !ok { @@ -4140,12 +4152,14 @@ func TestServerContext_ServerContextKey(t *testing.T) { } got = ctx.Value(LocalAddrContextKey) - if _, ok := got.(net.Addr); !ok { + if addr, ok := got.(net.Addr); !ok { t.Errorf("local addr value = %T; want net.Addr", got) + } else if fmt.Sprint(addr) != r.Host { + t.Errorf("local addr = %v; want %v", addr, r.Host) } })) - defer ts.Close() - res, err := Get(ts.URL) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) }