diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 1d541a8e46f..80fcc8c4071 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -5232,7 +5232,8 @@ func testServerShutdown(t *testing.T, h2 bool) { defer afterTest(t) var doShutdown func() // set later var shutdownRes = make(chan error, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + var gotOnShutdown = make(chan struct{}, 1) + handler := HandlerFunc(func(w ResponseWriter, r *Request) { go doShutdown() // Shutdown is graceful, so it should not interrupt // this in-flight response. Add a tiny sleep here to @@ -5240,7 +5241,10 @@ func testServerShutdown(t *testing.T, h2 bool) { // bugs. time.Sleep(20 * time.Millisecond) io.WriteString(w, r.RemoteAddr) - })) + }) + cst := newClientServerTest(t, h2, handler, func(srv *httptest.Server) { + srv.Config.RegisterOnShutdown(func() { gotOnShutdown <- struct{}{} }) + }) defer cst.close() doShutdown = func() { @@ -5251,6 +5255,11 @@ func testServerShutdown(t *testing.T, h2 bool) { if err := <-shutdownRes; err != nil { t.Fatalf("Shutdown: %v", err) } + select { + case <-gotOnShutdown: + case <-time.After(5 * time.Second): + t.Errorf("onShutdown callback not called, RegisterOnShutdown broken?") + } res, err := cst.c.Get(cst.ts.URL) if err == nil { diff --git a/src/net/http/server.go b/src/net/http/server.go index 71f46a74f97..b60bd2481e4 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -2395,6 +2395,7 @@ type Server struct { listeners map[net.Listener]struct{} activeConn map[*conn]struct{} doneChan chan struct{} + onShutdown []func() } func (s *Server) getDoneChan() <-chan struct{} { @@ -2475,6 +2476,9 @@ func (srv *Server) Shutdown(ctx context.Context) error { srv.mu.Lock() lnerr := srv.closeListenersLocked() srv.closeDoneChanLocked() + for _, f := range srv.onShutdown { + go f() + } srv.mu.Unlock() ticker := time.NewTicker(shutdownPollInterval) @@ -2491,6 +2495,17 @@ func (srv *Server) Shutdown(ctx context.Context) error { } } +// RegisterOnShutdown registers a function to call on Shutdown. +// This can be used to gracefully shutdown connections that have +// undergone NPN/ALPN protocol upgrade or that have been hijacked. +// This function should start protocol-specific graceful shutdown, +// but should not wait for shutdown to complete. +func (srv *Server) RegisterOnShutdown(f func()) { + srv.mu.Lock() + srv.onShutdown = append(srv.onShutdown, f) + srv.mu.Unlock() +} + // closeIdleConns closes all idle connections and reports whether the // server is quiescent. func (s *Server) closeIdleConns() bool {