mirror of
https://github.com/golang/go
synced 2024-11-23 19:40:08 -07:00
net/http/httptest: change Server to use http.Server.ConnState for accounting
With this CL, httptest.Server now uses connection-level accounting of outstanding requests instead of ServeHTTP-level accounting. This is more robust and results in a non-racy shutdown. This is much easier now that net/http.Server has the ConnState hook. Fixes #12789 Fixes #12781 Change-Id: I098cf334a6494316acb66cd07df90766df41764b Reviewed-on: https://go-review.googlesource.com/15151 Reviewed-by: Andrew Gerrand <adg@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
684218e135
commit
a3156aaa12
@ -7,13 +7,17 @@
|
||||
package httptest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A Server is an HTTP server listening on a system-chosen port on the
|
||||
@ -34,24 +38,10 @@ type Server struct {
|
||||
// wg counts the number of outstanding HTTP requests on this server.
|
||||
// Close blocks until all requests are finished.
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// historyListener keeps track of all connections that it's ever
|
||||
// accepted.
|
||||
type historyListener struct {
|
||||
net.Listener
|
||||
sync.Mutex // protects history
|
||||
history []net.Conn
|
||||
}
|
||||
|
||||
func (hs *historyListener) Accept() (c net.Conn, err error) {
|
||||
c, err = hs.Listener.Accept()
|
||||
if err == nil {
|
||||
hs.Lock()
|
||||
hs.history = append(hs.history, c)
|
||||
hs.Unlock()
|
||||
}
|
||||
return
|
||||
mu sync.Mutex // guards closed and conns
|
||||
closed bool
|
||||
conns map[net.Conn]http.ConnState // except terminal states
|
||||
}
|
||||
|
||||
func newLocalListener() net.Listener {
|
||||
@ -103,10 +93,9 @@ func (s *Server) Start() {
|
||||
if s.URL != "" {
|
||||
panic("Server already started")
|
||||
}
|
||||
s.Listener = &historyListener{Listener: s.Listener}
|
||||
s.URL = "http://" + s.Listener.Addr().String()
|
||||
s.wrapHandler()
|
||||
go s.Config.Serve(s.Listener)
|
||||
s.wrap()
|
||||
s.goServe()
|
||||
if *serve != "" {
|
||||
fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
|
||||
select {}
|
||||
@ -134,23 +123,10 @@ func (s *Server) StartTLS() {
|
||||
if len(s.TLS.Certificates) == 0 {
|
||||
s.TLS.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
tlsListener := tls.NewListener(s.Listener, s.TLS)
|
||||
|
||||
s.Listener = &historyListener{Listener: tlsListener}
|
||||
s.Listener = tls.NewListener(s.Listener, s.TLS)
|
||||
s.URL = "https://" + s.Listener.Addr().String()
|
||||
s.wrapHandler()
|
||||
go s.Config.Serve(s.Listener)
|
||||
}
|
||||
|
||||
func (s *Server) wrapHandler() {
|
||||
h := s.Config.Handler
|
||||
if h == nil {
|
||||
h = http.DefaultServeMux
|
||||
}
|
||||
s.Config.Handler = &waitGroupHandler{
|
||||
s: s,
|
||||
h: h,
|
||||
}
|
||||
s.wrap()
|
||||
s.goServe()
|
||||
}
|
||||
|
||||
// NewTLSServer starts and returns a new Server using TLS.
|
||||
@ -161,43 +137,139 @@ func NewTLSServer(handler http.Handler) *Server {
|
||||
return ts
|
||||
}
|
||||
|
||||
type closeIdleTransport interface {
|
||||
CloseIdleConnections()
|
||||
}
|
||||
|
||||
// Close shuts down the server and blocks until all outstanding
|
||||
// requests on this server have completed.
|
||||
func (s *Server) Close() {
|
||||
s.Listener.Close()
|
||||
s.wg.Wait()
|
||||
s.CloseClientConnections()
|
||||
if t, ok := http.DefaultTransport.(*http.Transport); ok {
|
||||
s.mu.Lock()
|
||||
if !s.closed {
|
||||
s.closed = true
|
||||
s.Listener.Close()
|
||||
s.Config.SetKeepAlivesEnabled(false)
|
||||
for c, st := range s.conns {
|
||||
if st == http.StateIdle {
|
||||
s.closeConn(c)
|
||||
}
|
||||
}
|
||||
// If this server doesn't shut down in 5 seconds, tell the user why.
|
||||
t := time.AfterFunc(5*time.Second, s.logCloseHangDebugInfo)
|
||||
defer t.Stop()
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
// Not part of httptest.Server's correctness, but assume most
|
||||
// users of httptest.Server will be using the standard
|
||||
// transport, so help them out and close any idle connections for them.
|
||||
if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
|
||||
t.CloseIdleConnections()
|
||||
}
|
||||
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
// CloseClientConnections closes any currently open HTTP connections
|
||||
// to the test Server.
|
||||
func (s *Server) logCloseHangDebugInfo() {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
|
||||
for c, st := range s.conns {
|
||||
fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
|
||||
}
|
||||
log.Print(buf.String())
|
||||
}
|
||||
|
||||
// CloseClientConnections closes any open HTTP connections to the test Server.
|
||||
func (s *Server) CloseClientConnections() {
|
||||
hl, ok := s.Listener.(*historyListener)
|
||||
if !ok {
|
||||
return
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for c := range s.conns {
|
||||
s.closeConn(c)
|
||||
}
|
||||
hl.Lock()
|
||||
for _, conn := range hl.history {
|
||||
conn.Close()
|
||||
}
|
||||
hl.Unlock()
|
||||
}
|
||||
|
||||
// waitGroupHandler wraps a handler, incrementing and decrementing a
|
||||
// sync.WaitGroup on each request, to enable Server.Close to block
|
||||
// until outstanding requests are finished.
|
||||
type waitGroupHandler struct {
|
||||
s *Server
|
||||
h http.Handler // non-nil
|
||||
func (s *Server) goServe() {
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
s.Config.Serve(s.Listener)
|
||||
}()
|
||||
}
|
||||
|
||||
func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
h.s.wg.Add(1)
|
||||
defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
|
||||
h.h.ServeHTTP(w, r)
|
||||
// wrap installs the connection state-tracking hook to know which
|
||||
// connections are idle.
|
||||
func (s *Server) wrap() {
|
||||
oldHook := s.Config.ConnState
|
||||
s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
switch cs {
|
||||
case http.StateNew:
|
||||
s.wg.Add(1)
|
||||
if _, exists := s.conns[c]; exists {
|
||||
panic("invalid state transition")
|
||||
}
|
||||
if s.conns == nil {
|
||||
s.conns = make(map[net.Conn]http.ConnState)
|
||||
}
|
||||
s.conns[c] = cs
|
||||
if s.closed {
|
||||
// Probably just a socket-late-binding dial from
|
||||
// the default transport that lost the race (and
|
||||
// thus this connection is now idle and will
|
||||
// never be used).
|
||||
s.closeConn(c)
|
||||
}
|
||||
case http.StateActive:
|
||||
if oldState, ok := s.conns[c]; ok {
|
||||
if oldState != http.StateNew && oldState != http.StateIdle {
|
||||
panic("invalid state transition")
|
||||
}
|
||||
s.conns[c] = cs
|
||||
}
|
||||
case http.StateIdle:
|
||||
if oldState, ok := s.conns[c]; ok {
|
||||
if oldState != http.StateActive {
|
||||
panic("invalid state transition")
|
||||
}
|
||||
s.conns[c] = cs
|
||||
}
|
||||
if s.closed {
|
||||
s.closeConn(c)
|
||||
}
|
||||
case http.StateHijacked, http.StateClosed:
|
||||
s.forgetConn(c)
|
||||
}
|
||||
if oldHook != nil {
|
||||
oldHook(c, cs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeConn closes c. Except on plan9, which is special. See comment below.
|
||||
// s.mu must be held.
|
||||
func (s *Server) closeConn(c net.Conn) {
|
||||
if runtime.GOOS == "plan9" {
|
||||
// Go's Plan 9 net package isn't great at unblocking reads when
|
||||
// their underlying TCP connections are closed. Don't trust
|
||||
// that that the ConnState state machine will get to
|
||||
// StateClosed. Instead, just go there directly. Plan 9 may leak
|
||||
// resources if the syscall doesn't end up returning. Oh well.
|
||||
s.forgetConn(c)
|
||||
}
|
||||
go c.Close()
|
||||
}
|
||||
|
||||
// forgetConn removes c from the set of tracked conns and decrements it from the
|
||||
// waitgroup, unless it was previously removed.
|
||||
// s.mu must be held.
|
||||
func (s *Server) forgetConn(c net.Conn) {
|
||||
if _, ok := s.conns[c]; ok {
|
||||
delete(s.conns, c)
|
||||
s.wg.Done()
|
||||
}
|
||||
}
|
||||
|
||||
// localhostCert is a PEM-encoded TLS cert with SAN IPs
|
||||
|
@ -27,3 +27,30 @@ func TestServer(t *testing.T) {
|
||||
t.Errorf("got %q, want hello", string(got))
|
||||
}
|
||||
}
|
||||
|
||||
// Issue 12781
|
||||
func TestGetAfterClose(t *testing.T) {
|
||||
ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("hello"))
|
||||
}))
|
||||
|
||||
res, err := http.Get(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(got) != "hello" {
|
||||
t.Fatalf("got %q, want hello", string(got))
|
||||
}
|
||||
|
||||
ts.Close()
|
||||
|
||||
res, err = http.Get(ts.URL)
|
||||
if err == nil {
|
||||
body, _ := ioutil.ReadAll(res.Body)
|
||||
t.Fatalf("Unexected response after close: %v, %v, %s", res.Status, res.Header, body)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user