diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go index f8ba53288e..2adc06f39b 100644 --- a/src/go/build/deps_test.go +++ b/src/go/build/deps_test.go @@ -411,7 +411,7 @@ var pkgDeps = map[string][]string{ "net/http/cgi": {"L4", "NET", "OS", "crypto/tls", "net/http", "regexp"}, "net/http/cookiejar": {"L4", "NET", "net/http"}, "net/http/fcgi": {"L4", "NET", "OS", "net/http", "net/http/cgi"}, - "net/http/httptest": {"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal"}, + "net/http/httptest": {"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal", "crypto/x509"}, "net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal"}, "net/http/pprof": {"L4", "OS", "html/template", "net/http", "runtime/pprof", "runtime/trace"}, "net/rpc": {"L4", "NET", "encoding/gob", "html/template", "net/http"}, diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go index bd2c49642b..e3d392130e 100644 --- a/src/net/http/httptest/example_test.go +++ b/src/net/http/httptest/example_test.go @@ -54,3 +54,25 @@ func ExampleServer() { fmt.Printf("%s", greeting) // Output: Hello, client } + +func ExampleNewTLSServer() { + ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + client := ts.Client() + res, err := client.Get(ts.URL) + if err != nil { + log.Fatal(err) + } + + greeting, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + log.Fatal(err) + } + + fmt.Printf("%s", greeting) + // Output: Hello, client +} diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go index 711821433b..56ad18ee9b 100644 --- a/src/net/http/httptest/server.go +++ b/src/net/http/httptest/server.go @@ -9,6 +9,7 @@ package httptest import ( "bytes" "crypto/tls" + "crypto/x509" "flag" "fmt" "log" @@ -35,6 +36,9 @@ type Server struct { // before Start or StartTLS. Config *http.Server + // certificate is a parsed version of the TLS config certificate, if present. + certificate *x509.Certificate + // wg counts the number of outstanding HTTP requests on this server. // Close blocks until all requests are finished. wg sync.WaitGroup @@ -42,6 +46,10 @@ type Server struct { mu sync.Mutex // guards closed and conns closed bool conns map[net.Conn]http.ConnState // except terminal states + + // client is configured for use with the server. + // Its transport is automatically closed when Close is called. + client *http.Client } func newLocalListener() net.Listener { @@ -85,6 +93,7 @@ func NewUnstartedServer(handler http.Handler) *Server { return &Server{ Listener: newLocalListener(), Config: &http.Server{Handler: handler}, + client: &http.Client{}, } } @@ -124,6 +133,17 @@ func (s *Server) StartTLS() { if len(s.TLS.Certificates) == 0 { s.TLS.Certificates = []tls.Certificate{cert} } + s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0]) + if err != nil { + panic(fmt.Sprintf("httptest: NewTLSServer: %v", err)) + } + certpool := x509.NewCertPool() + certpool.AddCert(s.certificate) + s.client.Transport = &http.Transport{ + TLSClientConfig: &tls.Config{ + RootCAs: certpool, + }, + } s.Listener = tls.NewListener(s.Listener, s.TLS) s.URL = "https://" + s.Listener.Addr().String() s.wrap() @@ -186,6 +206,11 @@ func (s *Server) Close() { t.CloseIdleConnections() } + // Also close the client idle connections. + if t, ok := s.client.Transport.(closeIdleTransport); ok { + t.CloseIdleConnections() + } + s.wg.Wait() } @@ -228,6 +253,19 @@ func (s *Server) CloseClientConnections() { } } +// Certificate returns the certificate used by the server, or nil if +// the server doesn't use TLS. +func (s *Server) Certificate() *x509.Certificate { + return s.certificate +} + +// Client returns an HTTP client configured for making requests to the server. +// It is configured to trust the server's TLS test certificate and will +// close its idle connections on Server.Close. +func (s *Server) Client() *http.Client { + return s.client +} + func (s *Server) goServe() { s.wg.Add(1) go func() { diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go index d032c5983b..7d80fa15dd 100644 --- a/src/net/http/httptest/server_test.go +++ b/src/net/http/httptest/server_test.go @@ -22,6 +22,7 @@ func TestServer(t *testing.T) { t.Fatal(err) } got, err := ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { t.Fatal(err) } @@ -98,3 +99,25 @@ func TestServerCloseClientConnections(t *testing.T) { t.Fatalf("Unexpected response: %#v", res) } } + +// Tests that the Server.Client method works and returns an http.Client that can hit +// NewTLSServer without cert warnings. +func TestServerClient(t *testing.T) { + ts := NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("hello")) + })) + defer ts.Close() + client := ts.Client() + res, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + got, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(got) != "hello" { + t.Errorf("got %q, want hello", string(got)) + } +}