From 747e1961e95c2eb3df62e045b90b111c2ceea337 Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Mon, 3 Oct 2022 16:07:48 -0700 Subject: [PATCH] net/http: refactor tests to run most in HTTP/1 and HTTP/2 modes Replace the ad-hoc approach to running tests in HTTP/1 and HTTP/2 modes with a 'run' function that executes a test in various modes. By default, these modes are HTTP/1 and HTTP/2, but tests can opt-in to HTTPS/1 as well. The 'run' function also takes care of post-test cleanup (running the afterTest function). The 'run' function runs tests in parallel by default. Tests which can't run in parallel (generally because they use global test hooks) pass a testNotParallel option to disable parallelism. Update clientServerTest to use t.Cleanup to clean up after itself, rather than leaving this up to tests to handle. Drop an unnecessary mutex in SetReadLoopBeforeNextReadHook. Test hooks can't be set in parallel, and we want the race detector to notify us if two simultaneous tests try to set a hook. Fixes #56032 Change-Id: I16be64913c426fc93d84abc6ad85dbd3bc191224 Reviewed-on: https://go-review.googlesource.com/c/go/+/438137 TryBot-Result: Gopher Robot Run-TryBot: Damien Neil Reviewed-by: Brad Fitzpatrick Reviewed-by: David Chase --- src/net/http/client_test.go | 337 ++++---- src/net/http/clientserver_test.go | 441 ++++++----- src/net/http/export_test.go | 4 +- src/net/http/fs_test.go | 201 +++-- src/net/http/request_test.go | 60 +- src/net/http/serve_test.go | 1186 ++++++++++++----------------- src/net/http/sniff_test.go | 39 +- src/net/http/transport_test.go | 1054 +++++++++++++------------ 8 files changed, 1503 insertions(+), 1819 deletions(-) diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index 44b532ae1f3..8b53c41687a 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -67,11 +67,9 @@ func (w chanWriter) Write(p []byte) (n int, err error) { return len(p), nil } -func TestClient(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(robotsTxtHandler) - defer ts.Close() +func TestClient(t *testing.T) { run(t, testClient) } +func testClient(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, robotsTxtHandler).ts c := ts.Client() r, err := c.Get(ts.URL) @@ -87,14 +85,9 @@ func TestClient(t *testing.T) { } } -func TestClientHead_h1(t *testing.T) { testClientHead(t, h1Mode) } -func TestClientHead_h2(t *testing.T) { testClientHead(t, h2Mode) } - -func testClientHead(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, robotsTxtHandler) - defer cst.close() - +func TestClientHead(t *testing.T) { run(t, testClientHead) } +func testClientHead(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, robotsTxtHandler) r, err := cst.c.Head(cst.ts.URL) if err != nil { t.Fatal(err) @@ -200,11 +193,10 @@ func TestPostFormRequestFormat(t *testing.T) { } } -func TestClientRedirects(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestClientRedirects(t *testing.T) { run(t, testClientRedirects) } +func testClientRedirects(t *testing.T, mode testMode) { var ts *httptest.Server - ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { n, _ := strconv.Atoi(r.FormValue("n")) // Test Referer header. (7 is arbitrary position to test at) if n == 7 { @@ -217,8 +209,7 @@ func TestClientRedirects(t *testing.T) { return } fmt.Fprintf(w, "n=%d", n) - })) - defer ts.Close() + })).ts c := ts.Client() _, err := c.Get(ts.URL) @@ -299,13 +290,11 @@ func TestClientRedirects(t *testing.T) { } // Tests that Client redirects' contexts are derived from the original request's context. -func TestClientRedirectContext(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientRedirectsContext(t *testing.T) { run(t, testClientRedirectsContext) } +func testClientRedirectsContext(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Redirect(w, r, "/", StatusTemporaryRedirect) - })) - defer ts.Close() + })).ts ctx, cancel := context.WithCancel(context.Background()) c := ts.Client() @@ -373,7 +362,9 @@ func TestPostRedirects(t *testing.T) { `POST /?code=404 "c404"`, } want := strings.Join(wantSegments, "\n") - testRedirectsByMethod(t, "POST", postRedirectTests, want) + run(t, func(t *testing.T, mode testMode) { + testRedirectsByMethod(t, mode, "POST", postRedirectTests, want) + }) } func TestDeleteRedirects(t *testing.T) { @@ -410,17 +401,18 @@ func TestDeleteRedirects(t *testing.T) { `DELETE /?code=404 "c404"`, } want := strings.Join(wantSegments, "\n") - testRedirectsByMethod(t, "DELETE", deleteRedirectTests, want) + run(t, func(t *testing.T, mode testMode) { + testRedirectsByMethod(t, mode, "DELETE", deleteRedirectTests, want) + }) } -func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, want string) { - defer afterTest(t) +func testRedirectsByMethod(t *testing.T, mode testMode, method string, table []redirectTest, want string) { var log struct { sync.Mutex bytes.Buffer } var ts *httptest.Server - ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { log.Lock() slurp, _ := io.ReadAll(r.Body) fmt.Fprintf(&log.Buffer, "%s %s %q", r.Method, r.RequestURI, slurp) @@ -445,8 +437,7 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa } w.WriteHeader(code) } - })) - defer ts.Close() + })).ts c := ts.Client() for _, tt := range table { @@ -491,12 +482,11 @@ func removeCommonLines(a, b string) (asuffix, bsuffix string, commonLines int) { } } -func TestClientRedirectUseResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestClientRedirectUseResponse(t *testing.T) { run(t, testClientRedirectUseResponse) } +func testClientRedirectUseResponse(t *testing.T, mode testMode) { const body = "Hello, world." var ts *httptest.Server - ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts = newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if strings.Contains(r.URL.Path, "/other") { io.WriteString(w, "wrong body") } else { @@ -504,8 +494,7 @@ func TestClientRedirectUseResponse(t *testing.T) { w.WriteHeader(StatusFound) io.WriteString(w, body) } - })) - defer ts.Close() + })).ts c := ts.Client() c.CheckRedirect = func(req *Request, via []*Request) error { @@ -533,18 +522,16 @@ func TestClientRedirectUseResponse(t *testing.T) { // Issues 17773 and 49281: don't follow a 3xx if the response doesn't // have a Location header. -func TestClientRedirectNoLocation(t *testing.T) { +func TestClientRedirectNoLocation(t *testing.T) { run(t, testClientRedirectNoLocation) } +func testClientRedirectNoLocation(t *testing.T, mode testMode) { for _, code := range []int{301, 308} { t.Run(fmt.Sprint(code), func(t *testing.T) { setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Foo", "Bar") w.WriteHeader(code) })) - defer ts.Close() - c := ts.Client() - res, err := c.Get(ts.URL) + res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } @@ -560,15 +547,13 @@ func TestClientRedirectNoLocation(t *testing.T) { } // Don't follow a 307/308 if we can't resent the request body. -func TestClientRedirect308NoGetBody(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestClientRedirect308NoGetBody(t *testing.T) { run(t, testClientRedirect308NoGetBody) } +func testClientRedirect308NoGetBody(t *testing.T, mode testMode) { const fakeURL = "https://localhost:1234/" // won't be hit - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Location", fakeURL) w.WriteHeader(308) - })) - defer ts.Close() + })).ts req, err := NewRequest("POST", ts.URL, strings.NewReader("some body")) if err != nil { t.Fatal(err) @@ -659,12 +644,10 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { return j.perURL[u.Host] } -func TestRedirectCookiesJar(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestRedirectCookiesJar(t *testing.T) { run(t, testRedirectCookiesJar) } +func testRedirectCookiesJar(t *testing.T, mode testMode) { var ts *httptest.Server - ts = httptest.NewServer(echoCookiesRedirectHandler) - defer ts.Close() + ts = newClientServerTest(t, mode, echoCookiesRedirectHandler).ts c := ts.Client() c.Jar = new(TestJar) u, _ := url.Parse(ts.URL) @@ -696,9 +679,9 @@ func matchReturnedCookies(t *testing.T, expected, given []*Cookie) { } } -func TestJarCalls(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestJarCalls(t *testing.T) { run(t, testJarCalls, []testMode{http1Mode}) } +func testJarCalls(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { pathSuffix := r.RequestURI[1:] if r.RequestURI == "/nosetcookie" { return // don't set cookies for this path @@ -707,8 +690,7 @@ func TestJarCalls(t *testing.T) { if r.RequestURI == "/" { Redirect(w, r, "http://secondhost.fake/secondpath", 302) } - })) - defer ts.Close() + })).ts jar := new(RecordingJar) c := ts.Client() c.Jar = jar @@ -757,20 +739,16 @@ func (j *RecordingJar) logf(format string, args ...any) { fmt.Fprintf(&j.log, format, args...) } -func TestStreamingGet_h1(t *testing.T) { testStreamingGet(t, h1Mode) } -func TestStreamingGet_h2(t *testing.T) { testStreamingGet(t, h2Mode) } - -func testStreamingGet(t *testing.T, h2 bool) { - defer afterTest(t) +func TestStreamingGet(t *testing.T) { run(t, testStreamingGet) } +func testStreamingGet(t *testing.T, mode testMode) { say := make(chan string) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() for str := range say { w.Write([]byte(str)) w.(Flusher).Flush() } })) - defer cst.close() c := cst.c res, err := c.Get(cst.ts.URL) @@ -811,11 +789,10 @@ func (c *writeCountingConn) Write(p []byte) (int, error) { // TestClientWrites verifies that client requests are buffered and we // don't send a TCP packet per line of the http request + body. -func TestClientWrites(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - })) - defer ts.Close() +func TestClientWrites(t *testing.T) { run(t, testClientWrites, []testMode{http1Mode}) } +func testClientWrites(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + })).ts writes := 0 dialer := func(netz string, addr string) (net.Conn, error) { @@ -847,11 +824,12 @@ func TestClientWrites(t *testing.T) { } func TestClientInsecureTransport(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testClientInsecureTransport, []testMode{https1Mode, http2Mode}) +} +func testClientInsecureTransport(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) - })) + })).ts errc := make(chanWriter, 10) // but only expecting 1 ts.Config.ErrorLog = log.New(errc, "", 0) defer ts.Close() @@ -898,15 +876,15 @@ func TestClientErrorWithRequestURI(t *testing.T) { } func TestClientWithCorrectTLSServerName(t *testing.T) { - defer afterTest(t) - + run(t, testClientWithCorrectTLSServerName, []testMode{https1Mode, http2Mode}) +} +func testClientWithCorrectTLSServerName(t *testing.T, mode testMode) { const serverName = "example.com" - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS.ServerName != serverName { t.Errorf("expected client to set ServerName %q, got: %q", serverName, r.TLS.ServerName) } - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).TLSClientConfig.ServerName = serverName @@ -916,9 +894,10 @@ func TestClientWithCorrectTLSServerName(t *testing.T) { } func TestClientWithIncorrectTLSServerName(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + run(t, testClientWithIncorrectTLSServerName, []testMode{https1Mode, http2Mode}) +} +func testClientWithIncorrectTLSServerName(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts errc := make(chanWriter, 10) // but only expecting 1 ts.Config.ErrorLog = log.New(errc, "", 0) @@ -951,11 +930,12 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { // // The httptest.Server has a cert with "example.com" as its name. func TestTransportUsesTLSConfigServerName(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportUsesTLSConfigServerName, []testMode{https1Mode, http2Mode}) +} +func testTransportUsesTLSConfigServerName(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -971,11 +951,12 @@ func TestTransportUsesTLSConfigServerName(t *testing.T) { } func TestResponseSetsTLSConnectionState(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testResponseSetsTLSConnectionState, []testMode{https1Mode}) +} +func testResponseSetsTLSConnectionState(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1001,10 +982,11 @@ func TestResponseSetsTLSConnectionState(t *testing.T) { // to determine that the server is speaking HTTP. // See golang.org/issue/11111. func TestHTTPSClientDetectsHTTPServer(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + run(t, testHTTPSClientDetectsHTTPServer, []testMode{http1Mode}) +} +func testHTTPSClientDetectsHTTPServer(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts ts.Config.ErrorLog = quietLog - defer ts.Close() _, err := Get(strings.Replace(ts.URL, "http", "https", 1)) if got := err.Error(); !strings.Contains(got, "HTTP response to HTTPS client") { @@ -1013,22 +995,13 @@ func TestHTTPSClientDetectsHTTPServer(t *testing.T) { } // Verify Response.ContentLength is populated. https://golang.org/issue/4126 -func TestClientHeadContentLength_h1(t *testing.T) { - testClientHeadContentLength(t, h1Mode) -} - -func TestClientHeadContentLength_h2(t *testing.T) { - testClientHeadContentLength(t, h2Mode) -} - -func testClientHeadContentLength(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientHeadContentLength(t *testing.T) { run(t, testClientHeadContentLength) } +func testClientHeadContentLength(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if v := r.FormValue("cl"); v != "" { w.Header().Set("Content-Length", v) } })) - defer cst.close() tests := []struct { suffix string want int64 @@ -1056,11 +1029,10 @@ func testClientHeadContentLength(t *testing.T, h2 bool) { } } -func TestEmptyPasswordAuth(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestEmptyPasswordAuth(t *testing.T) { run(t, testEmptyPasswordAuth) } +func testEmptyPasswordAuth(t *testing.T, mode testMode) { gopher := "gopher" - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { auth := r.Header.Get("Authorization") if strings.HasPrefix(auth, "Basic ") { encoded := auth[6:] @@ -1076,7 +1048,7 @@ func TestEmptyPasswordAuth(t *testing.T) { } else { t.Errorf("Invalid auth %q", auth) } - })) + })).ts defer ts.Close() req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -1205,19 +1177,14 @@ func TestStripPasswordFromError(t *testing.T) { } } -func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) } -func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) } - -func testClientTimeout(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestClientTimeout(t *testing.T) { run(t, testClientTimeout) } +func testClientTimeout(t *testing.T, mode testMode) { var ( mu sync.Mutex nonce string // a unique per-request string sawSlowNonce bool // true if the handler saw /slow?nonce= ) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _ = r.ParseForm() if r.URL.Path == "/" { Redirect(w, r, "/slow?nonce="+r.Form.Get("nonce"), StatusFound) @@ -1238,7 +1205,6 @@ func testClientTimeout(t *testing.T, h2 bool) { return } })) - defer cst.close() // Try to trigger a timeout after reading part of the response body. // The initial timeout is emprically usually long enough on a decently fast @@ -1308,18 +1274,13 @@ func testClientTimeout(t *testing.T, h2 bool) { } } -func TestClientTimeout_Headers_h1(t *testing.T) { testClientTimeout_Headers(t, h1Mode) } -func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h2Mode) } - // Client.Timeout firing before getting to the body -func testClientTimeout_Headers(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestClientTimeout_Headers(t *testing.T) { run(t, testClientTimeout_Headers) } +func testClientTimeout_Headers(t *testing.T, mode testMode) { donec := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-donec }), optQuietLog) - defer cst.close() // Note that we use a channel send here and not a close. // The race detector doesn't know that we're waiting for a timeout // and thinks that the waitgroup inside httptest.Server is added to concurrently @@ -1355,18 +1316,15 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) { // Issue 16094: if Client.Timeout is set but not hit, a Timeout error shouldn't be // returned. -func TestClientTimeoutCancel(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestClientTimeoutCancel(t *testing.T) { run(t, testClientTimeoutCancel) } +func testClientTimeoutCancel(t *testing.T, mode testMode) { testDone := make(chan struct{}) ctx, cancel := context.WithCancel(context.Background()) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() <-testDone })) - defer cst.close() defer close(testDone) cst.c.Timeout = 1 * time.Hour @@ -1383,18 +1341,12 @@ func TestClientTimeoutCancel(t *testing.T) { } } -func TestClientTimeoutDoesNotExpire_h1(t *testing.T) { testClientTimeoutDoesNotExpire(t, h1Mode) } -func TestClientTimeoutDoesNotExpire_h2(t *testing.T) { testClientTimeoutDoesNotExpire(t, h2Mode) } - // Issue 49366: if Client.Timeout is set but not hit, no error should be returned. -func testClientTimeoutDoesNotExpire(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientTimeoutDoesNotExpire(t *testing.T) { run(t, testClientTimeoutDoesNotExpire) } +func testClientTimeoutDoesNotExpire(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("body")) })) - defer cst.close() cst.c.Timeout = 1 * time.Hour req, _ := NewRequest("GET", cst.ts.URL, nil) @@ -1410,19 +1362,15 @@ func testClientTimeoutDoesNotExpire(t *testing.T, h2 bool) { } } -func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) } -func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) } -func testClientRedirectEatsBody(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestClientRedirectEatsBody_h1(t *testing.T) { run(t, testClientRedirectEatsBody) } +func testClientRedirectEatsBody(t *testing.T, mode testMode) { saw := make(chan string, 2) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { saw <- r.RemoteAddr if r.URL.Path == "/" { Redirect(w, r, "/foo", StatusFound) // which includes a body } })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1522,13 +1470,14 @@ func TestClientRedirectResponseWithoutRequest(t *testing.T) { } // Issue 4800: copy (some) headers when Client follows a redirect. -func TestClientCopyHeadersOnRedirect(t *testing.T) { +func TestClientCopyHeadersOnRedirect(t *testing.T) { run(t, testClientCopyHeadersOnRedirect) } +func testClientCopyHeadersOnRedirect(t *testing.T, mode testMode) { const ( ua = "some-agent/1.2" xfoo = "foo-val" ) var ts2URL string - ts1 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts1 := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { want := Header{ "User-Agent": []string{ua}, "X-Foo": []string{xfoo}, @@ -1543,12 +1492,10 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) { } else { w.Header().Set("Result", "ok") } - })) - defer ts1.Close() - ts2 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + })).ts + ts2 := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Redirect(w, r, ts1.URL, StatusFound) - })) - defer ts2.Close() + })).ts ts2URL = ts2.URL c := ts1.Client() @@ -1583,22 +1530,24 @@ func TestClientCopyHeadersOnRedirect(t *testing.T) { } // Issue 22233: copy host when Client follows a relative redirect. -func TestClientCopyHostOnRedirect(t *testing.T) { +func TestClientCopyHostOnRedirect(t *testing.T) { run(t, testClientCopyHostOnRedirect) } +func testClientCopyHostOnRedirect(t *testing.T, mode testMode) { // Virtual hostname: should not receive any request. - virtual := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + virtual := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Errorf("Virtual host received request %v", r.URL) w.WriteHeader(403) io.WriteString(w, "should not see this response") - })) + })).ts defer virtual.Close() virtualHost := strings.TrimPrefix(virtual.URL, "http://") + virtualHost = strings.TrimPrefix(virtualHost, "https://") t.Logf("Virtual host is %v", virtualHost) // Actual hostname: should not receive any request. const wantBody = "response body" var tsURL string var tsHost string - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { switch r.URL.Path { case "/": // Relative redirect. @@ -1630,10 +1579,10 @@ func TestClientCopyHostOnRedirect(t *testing.T) { t.Errorf("Serving unexpected path %q", r.URL.Path) w.WriteHeader(404) } - })) - defer ts.Close() + })).ts tsURL = ts.URL tsHost = strings.TrimPrefix(ts.URL, "http://") + tsHost = strings.TrimPrefix(tsHost, "https://") t.Logf("Server host is %v", tsHost) c := ts.Client() @@ -1653,7 +1602,8 @@ func TestClientCopyHostOnRedirect(t *testing.T) { } // Issue 17494: cookies should be altered when Client follows redirects. -func TestClientAltersCookiesOnRedirect(t *testing.T) { +func TestClientAltersCookiesOnRedirect(t *testing.T) { run(t, testClientAltersCookiesOnRedirect) } +func testClientAltersCookiesOnRedirect(t *testing.T, mode testMode) { cookieMap := func(cs []*Cookie) map[string][]string { m := make(map[string][]string) for _, c := range cs { @@ -1662,7 +1612,7 @@ func TestClientAltersCookiesOnRedirect(t *testing.T) { return m } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var want map[string][]string got := cookieMap(r.Cookies()) @@ -1717,8 +1667,7 @@ func TestClientAltersCookiesOnRedirect(t *testing.T) { if !reflect.DeepEqual(got, want) { t.Errorf("redirect %s, Cookie = %v, want %v", c.Value, got, want) } - })) - defer ts.Close() + })).ts jar, _ := cookiejar.New(nil) c := ts.Client() @@ -1790,10 +1739,8 @@ func TestShouldCopyHeaderOnRedirect(t *testing.T) { } } -func TestClientRedirectTypes(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestClientRedirectTypes(t *testing.T) { run(t, testClientRedirectTypes) } +func testClientRedirectTypes(t *testing.T, mode testMode) { tests := [...]struct { method string serverStatus int @@ -1838,11 +1785,10 @@ func TestClientRedirectTypes(t *testing.T) { handlerc := make(chan HandlerFunc, 1) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { h := <-handlerc h(rw, req) - })) - defer ts.Close() + })).ts c := ts.Client() for i, tt := range tests { @@ -1898,18 +1844,16 @@ func (b issue18239Body) Close() error { // Issue 18239: make sure the Transport doesn't retry requests with bodies // if Request.GetBody is not defined. -func TestTransportBodyReadError(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportBodyReadError(t *testing.T) { run(t, testTransportBodyReadError) } +func testTransportBodyReadError(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/ping" { return } buf := make([]byte, 1) n, err := r.Body.Read(buf) w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err)) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1993,22 +1937,13 @@ func TestClientPropagatesTimeoutToContext(t *testing.T) { c.Get("https://example.tld/") } -func TestClientDoCanceledVsTimeout_h1(t *testing.T) { - testClientDoCanceledVsTimeout(t, h1Mode) -} - -func TestClientDoCanceledVsTimeout_h2(t *testing.T) { - testClientDoCanceledVsTimeout(t, h2Mode) -} - // Issue 33545: lock-in the behavior promised by Client.Do's // docs about request cancellation vs timing out. -func testClientDoCanceledVsTimeout(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientDoCanceledVsTimeout(t *testing.T) { run(t, testClientDoCanceledVsTimeout) } +func testClientDoCanceledVsTimeout(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello, World!")) })) - defer cst.close() cases := []string{"timeout", "canceled"} @@ -2084,13 +2019,11 @@ func TestClientPopulatesNilResponseBody(t *testing.T) { } // Issue 40382: Client calls Close multiple times on Request.Body. -func TestClientCallsCloseOnlyOnce(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestClientCallsCloseOnlyOnce(t *testing.T) { run(t, testClientCallsCloseOnlyOnce) } +func testClientCallsCloseOnlyOnce(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) })) - defer cst.close() // Issue occurred non-deterministically: needed to occur after a successful // write (into TCP buffer) but before end of body. @@ -2140,17 +2073,15 @@ func (b *issue40382Body) Close() error { return nil } -func TestProbeZeroLengthBody(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestProbeZeroLengthBody(t *testing.T) { run(t, testProbeZeroLengthBody) } +func testProbeZeroLengthBody(t *testing.T, mode testMode) { reqc := make(chan struct{}) - cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(reqc) if _, err := io.Copy(w, r.Body); err != nil { t.Errorf("error copying request body: %v", err) } })) - defer cst.close() bodyr, bodyw := io.Pipe() var gotBody string diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index b472ca4b786..87e34cef855 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -35,8 +35,65 @@ import ( "time" ) +type testMode string + +const ( + http1Mode = testMode("h1") // HTTP/1.1 + https1Mode = testMode("https1") // HTTPS/1.1 + http2Mode = testMode("h2") // HTTP/2 +) + +type testNotParallelOpt struct{} + +var ( + testNotParallel = testNotParallelOpt{} +) + +type TBRun[T any] interface { + testing.TB + Run(string, func(T)) bool +} + +// run runs a client/server test in a variety of test configurations. +// +// Tests execute in HTTP/1.1 and HTTP/2 modes by default. +// To run in a different set of configurations, pass a []testMode option. +// +// Tests call t.Parallel() by default. +// To disable parallel execution, pass the testNotParallel option. +func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) { + t.Helper() + modes := []testMode{http1Mode, http2Mode} + parallel := true + for _, opt := range opts { + switch opt := opt.(type) { + case []testMode: + modes = opt + case testNotParallelOpt: + parallel = false + default: + t.Fatalf("unknown option type %T", opt) + } + } + if t, ok := any(t).(*testing.T); ok && parallel { + setParallel(t) + } + for _, mode := range modes { + t.Run(string(mode), func(t T) { + t.Helper() + if t, ok := any(t).(*testing.T); ok && parallel { + setParallel(t) + } + t.Cleanup(func() { + afterTest(t) + }) + f(t, mode) + }) + } +} + type clientServerTest struct { - t *testing.T + t testing.TB h2 bool h Handler ts *httptest.Server @@ -69,11 +126,6 @@ func (t *clientServerTest) scheme() string { return "http" } -const ( - h1Mode = false - h2Mode = true -) - var optQuietLog = func(ts *httptest.Server) { ts.Config.ErrorLog = quietLog } @@ -84,23 +136,33 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) { } } -func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientServerTest { - if h2 { +// newClientServerTest creates and starts an httptest.Server. +// +// The mode parameter selects the implementation to test: +// HTTP/1, HTTP/2, etc. Tests using newClientServerTest should use +// the 'run' function, which will start a subtests for each tested mode. +// +// The vararg opts parameter can include functions to configure the +// test server or transport. +// +// func(*httptest.Server) // run before starting the server +// func(*http.Transport) +func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest { + if mode == http2Mode { CondSkipHTTP2(t) } cst := &clientServerTest{ t: t, - h2: h2, + h2: mode == http2Mode, h: h, - tr: &Transport{}, } - cst.c = &Client{Transport: cst.tr} cst.ts = httptest.NewUnstartedServer(h) + var transportFuncs []func(*Transport) for _, opt := range opts { switch opt := opt.(type) { case func(*Transport): - opt(cst.tr) + transportFuncs = append(transportFuncs, opt) case func(*httptest.Server): opt(cst.ts) default: @@ -108,60 +170,84 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...any) *clientS } } - if !h2 { + switch mode { + case http1Mode: cst.ts.Start() - return cst + case https1Mode: + cst.ts.StartTLS() + case http2Mode: + ExportHttp2ConfigureServer(cst.ts.Config, nil) + cst.ts.TLS = cst.ts.Config.TLSConfig + cst.ts.StartTLS() + default: + t.Fatalf("unknown test mode %v", mode) } - ExportHttp2ConfigureServer(cst.ts.Config, nil) - cst.ts.TLS = cst.ts.Config.TLSConfig - cst.ts.StartTLS() - - cst.tr.TLSClientConfig = &tls.Config{ - InsecureSkipVerify: true, + cst.c = cst.ts.Client() + cst.tr = cst.c.Transport.(*Transport) + if mode == http2Mode { + if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { + t.Fatal(err) + } } - if err := ExportHttp2ConfigureTransport(cst.tr); err != nil { - t.Fatal(err) + for _, f := range transportFuncs { + f(cst.tr) } + t.Cleanup(func() { + cst.close() + }) return cst } // Testing the newClientServerTest helper itself. func TestNewClientServerTest(t *testing.T) { + run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testNewClientServerTest(t *testing.T, mode testMode) { var got struct { sync.Mutex - log []string + proto string + hasTLS bool } h := HandlerFunc(func(w ResponseWriter, r *Request) { got.Lock() defer got.Unlock() - got.log = append(got.log, r.Proto) + got.proto = r.Proto + got.hasTLS = r.TLS != nil }) - for _, v := range [2]bool{false, true} { - cst := newClientServerTest(t, v, h) - if _, err := cst.c.Head(cst.ts.URL); err != nil { - t.Fatal(err) - } - cst.close() + cst := newClientServerTest(t, mode, h) + if _, err := cst.c.Head(cst.ts.URL); err != nil { + t.Fatal(err) } - got.Lock() // no need to unlock - if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) { - t.Errorf("got %q; want %q", got.log, want) + var wantProto string + var wantTLS bool + switch mode { + case http1Mode: + wantProto = "HTTP/1.1" + wantTLS = false + case https1Mode: + wantProto = "HTTP/1.1" + wantTLS = true + case http2Mode: + wantProto = "HTTP/2.0" + wantTLS = true + } + if got.proto != wantProto { + t.Errorf("req.Proto = %q, want %q", got.proto, wantProto) + } + if got.hasTLS != wantTLS { + t.Errorf("req.TLS set: %v, want %v", got.hasTLS, wantTLS) } } -func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) } -func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) } - -func testChunkedResponseHeaders(t *testing.T, h2 bool) { - defer afterTest(t) +func TestChunkedResponseHeaders(t *testing.T) { run(t, testChunkedResponseHeaders) } +func testChunkedResponseHeaders(t *testing.T, mode testMode) { log.SetOutput(io.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted w.(Flusher).Flush() fmt.Fprintf(w, "I am a chunked response.") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -172,7 +258,7 @@ func testChunkedResponseHeaders(t *testing.T, h2 bool) { t.Errorf("expected ContentLength of %d; got %d", e, g) } wantTE := []string{"chunked"} - if h2 { + if mode == http2Mode { wantTE = nil } if !reflect.DeepEqual(res.TransferEncoding, wantTE) { @@ -204,9 +290,9 @@ func (tt h12Compare) reqFunc() reqFunc { func (tt h12Compare) run(t *testing.T) { setParallel(t) - cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...) + cst1 := newClientServerTest(t, http1Mode, HandlerFunc(tt.Handler), tt.Opts...) defer cst1.close() - cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...) + cst2 := newClientServerTest(t, http2Mode, HandlerFunc(tt.Handler), tt.Opts...) defer cst2.close() res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL) @@ -459,12 +545,9 @@ func TestH12_AutoGzip_Disabled(t *testing.T) { // Test304Responses verifies that 304s don't declare that they're // chunking in their response headers and aren't allowed to produce // output. -func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) } -func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) } - -func test304Responses(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func Test304Responses(t *testing.T) { run(t, test304Responses) } +func test304Responses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNotModified) _, err := w.Write([]byte("illegal body")) if err != ErrBodyNotAllowed { @@ -528,20 +611,17 @@ func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int6 // Tests that closing the Request.Cancel channel also while still // reading the response body. Issue 13159. -func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) } -func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) } -func testCancelRequestMidBody(t *testing.T, h2 bool) { - defer afterTest(t) +func TestCancelRequestMidBody(t *testing.T) { run(t, testCancelRequestMidBody) } +func testCancelRequestMidBody(t *testing.T, mode testMode) { unblock := make(chan bool) didFlush := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, "Hello") w.(Flusher).Flush() didFlush <- true <-unblock io.WriteString(w, ", world.") })) - defer cst.close() defer close(unblock) req, _ := NewRequest("GET", cst.ts.URL, nil) @@ -577,12 +657,9 @@ func testCancelRequestMidBody(t *testing.T, h2 bool) { } // Tests that clients can send trailers to a server and that the server can read them. -func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) } -func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) } - -func testTrailersClientToServer(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTrailersClientToServer(t *testing.T) { run(t, testTrailersClientToServer) } +func testTrailersClientToServer(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var decl []string for k := range r.Trailer { decl = append(decl, k) @@ -605,7 +682,6 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { r.Trailer.Get("Client-Trailer-B")) } })) - defer cst.close() var req *Request req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader( @@ -632,15 +708,20 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { } // Tests that servers send trailers to a client and that the client can read them. -func TestTrailersServerToClient_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, false) } -func TestTrailersServerToClient_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, false) } -func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) } -func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) } +func TestTrailersServerToClient(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTrailersServerToClient(t, mode, false) + }) +} +func TestTrailersServerToClientFlush(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTrailersServerToClient(t, mode, true) + }) +} -func testTrailersServerToClient(t *testing.T, h2, flush bool) { - defer afterTest(t) +func testTrailersServerToClient(t *testing.T, mode testMode, flush bool) { const body = "Some body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B") w.Header().Add("Trailer", "Server-Trailer-C") @@ -657,7 +738,6 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { w.Header().Set("Server-Trailer-C", "valuec") // skipping B w.Header().Set("Server-Trailer-NotDeclared", "should be omitted") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -668,7 +748,7 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { "Content-Type": {"text/plain; charset=utf-8"}, } wantLen := -1 - if h2 && !flush { + if mode == http2Mode && !flush { // In HTTP/1.1, any use of trailers forces HTTP/1.1 // chunking and a flush at the first write. That's // unnecessary with HTTP/2's framing, so the server @@ -708,16 +788,12 @@ func testTrailersServerToClient(t *testing.T, h2, flush bool) { } // Don't allow a Body.Read after Body.Close. Issue 13648. -func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) } -func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) } - -func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { - defer afterTest(t) +func TestResponseBodyReadAfterClose(t *testing.T) { run(t, testResponseBodyReadAfterClose) } +func testResponseBodyReadAfterClose(t *testing.T, mode testMode) { const body = "Some body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, body) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -729,13 +805,11 @@ func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { } } -func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) } -func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) } -func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { - defer afterTest(t) +func TestConcurrentReadWriteReqBody(t *testing.T) { run(t, testConcurrentReadWriteReqBody) } +func testConcurrentReadWriteReqBody(t *testing.T, mode testMode) { const reqBody = "some request body" const resBody = "some response body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { var wg sync.WaitGroup wg.Add(2) didRead := make(chan bool, 1) @@ -754,7 +828,7 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { // Write in another goroutine. go func() { defer wg.Done() - if !h2 { + if mode != http2Mode { // our HTTP/1 implementation intentionally // doesn't permit writes during read (mostly // due to it being undefined); if that is ever @@ -765,7 +839,6 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { }() wg.Wait() })) - defer cst.close() req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody)) req.Header.Add("Expect", "100-continue") // just to complicate things res, err := cst.c.Do(req) @@ -782,15 +855,12 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { } } -func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) } -func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) } -func testConnectRequest(t *testing.T, h2 bool) { - defer afterTest(t) +func TestConnectRequest(t *testing.T) { run(t, testConnectRequest) } +func testConnectRequest(t *testing.T, mode testMode) { gotc := make(chan *Request, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotc <- r })) - defer cst.close() u, err := url.Parse(cst.ts.URL) if err != nil { @@ -840,17 +910,14 @@ func testConnectRequest(t *testing.T, h2 bool) { } } -func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) } -func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) } -func testTransportUserAgent(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportUserAgent(t *testing.T) { run(t, testTransportUserAgent) } +func testTransportUserAgent(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%q", r.Header["User-Agent"]) })) - defer cst.close() either := func(a, b string) string { - if h2 { + if mode == http2Mode { return b } return a @@ -901,19 +968,22 @@ func testTransportUserAgent(t *testing.T, h2 bool) { } } -func TestStarRequestFoo_h1(t *testing.T) { testStarRequest(t, "FOO", h1Mode) } -func TestStarRequestFoo_h2(t *testing.T) { testStarRequest(t, "FOO", h2Mode) } -func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) } -func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) } -func testStarRequest(t *testing.T, method string, h2 bool) { - defer afterTest(t) +func TestStarRequestMethod(t *testing.T) { + for _, method := range []string{"FOO", "OPTIONS"} { + t.Run(method, func(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testStarRequest(t, method, mode) + }) + }) + } +} +func testStarRequest(t *testing.T, method string, mode testMode) { gotc := make(chan *Request, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("foo", "bar") gotc <- r w.(Flusher).Flush() })) - defer cst.close() u, err := url.Parse(cst.ts.URL) if err != nil { @@ -972,9 +1042,10 @@ func testStarRequest(t *testing.T, method string, h2 bool) { // Issue 13957 func TestTransportDiscardsUnneededConns(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportDiscardsUnneededConns, []testMode{http2Mode}) +} +func testTransportDiscardsUnneededConns(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) })) defer cst.close() @@ -1058,20 +1129,19 @@ func TestTransportDiscardsUnneededConns(t *testing.T) { } // tests that Transport doesn't retain a pointer to the provided request. -func TestTransportGCRequest_Body_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, true) } -func TestTransportGCRequest_Body_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, true) } -func TestTransportGCRequest_NoBody_h1(t *testing.T) { testTransportGCRequest(t, h1Mode, false) } -func TestTransportGCRequest_NoBody_h2(t *testing.T) { testTransportGCRequest(t, h2Mode, false) } -func testTransportGCRequest(t *testing.T, h2, body bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportGCRequest(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + t.Run("Body", func(t *testing.T) { testTransportGCRequest(t, mode, true) }) + t.Run("NoBody", func(t *testing.T) { testTransportGCRequest(t, mode, false) }) + }) +} +func testTransportGCRequest(t *testing.T, mode testMode, body bool) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.ReadAll(r.Body) if body { io.WriteString(w, "Hello.") } })) - defer cst.close() didGC := make(chan struct{}) (func() { @@ -1103,19 +1173,11 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { } } -func TestTransportRejectsInvalidHeaders_h1(t *testing.T) { - testTransportRejectsInvalidHeaders(t, h1Mode) -} -func TestTransportRejectsInvalidHeaders_h2(t *testing.T) { - testTransportRejectsInvalidHeaders(t, h2Mode) -} -func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportRejectsInvalidHeaders(t *testing.T) { run(t, testTransportRejectsInvalidHeaders) } +func testTransportRejectsInvalidHeaders(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Handler saw headers: %q", r.Header) }), optQuietLog) - defer cst.close() cst.tr.DisableKeepAlives = true tests := []struct { @@ -1161,27 +1223,22 @@ func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { } } -func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, "boom") } -func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, "boom") } -func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) } -func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) } -func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) { - testInterruptWithPanic(t, h1Mode, ErrAbortHandler) +func TestInterruptWithPanic(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + t.Run("boom", func(t *testing.T) { testInterruptWithPanic(t, mode, "boom") }) + t.Run("nil", func(t *testing.T) { testInterruptWithPanic(t, mode, nil) }) + t.Run("ErrAbortHandler", func(t *testing.T) { testInterruptWithPanic(t, mode, ErrAbortHandler) }) + }) } -func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) { - testInterruptWithPanic(t, h2Mode, ErrAbortHandler) -} -func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { - setParallel(t) +func testInterruptWithPanic(t *testing.T, mode testMode, panicValue any) { const msg = "hello" - defer afterTest(t) testDone := make(chan struct{}) defer close(testDone) var errorLog lockedBytesBuffer gotHeaders := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() @@ -1193,7 +1250,6 @@ func testInterruptWithPanic(t *testing.T, h2 bool, panicValue any) { }), func(ts *httptest.Server) { ts.Config.ErrorLog = log.New(&errorLog, "", 0) }) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1274,15 +1330,11 @@ func TestH12_AutoGzipWithDumpResponse(t *testing.T) { } // Issue 14607 -func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) } -func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) } -func testCloseIdleConnections(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestCloseIdleConnections(t *testing.T) { run(t, testCloseIdleConnections) } +func testCloseIdleConnections(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) })) - defer cst.close() get := func() string { res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1320,15 +1372,11 @@ func (r testErrorReader) Read(p []byte) (n int, err error) { return 0, io.EOF } -func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) } -func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) } - -func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestNoSniffExpectRequestBody(t *testing.T) { run(t, testNoSniffExpectRequestBody) } +func testNoSniffExpectRequestBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusUnauthorized) })) - defer cst.close() // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. cst.tr.ExpectContinueTimeout = 10 * time.Second @@ -1349,18 +1397,15 @@ func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { } } -func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) } -func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) } -func testServerUndeclaredTrailers(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerUndeclaredTrailers(t *testing.T) { run(t, testServerUndeclaredTrailers) } +func testServerUndeclaredTrailers(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Foo", "Bar") w.Header().Set("Trailer:Foo", "Baz") w.(Flusher).Flush() w.Header().Add("Trailer:Foo", "Baz2") w.Header().Set("Trailer:Bar", "Quux") })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1381,8 +1426,10 @@ func testServerUndeclaredTrailers(t *testing.T, h2 bool) { } func TestBadResponseAfterReadingBody(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testBadResponseAfterReadingBody, []testMode{http1Mode}) +} +func testBadResponseAfterReadingBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := io.Copy(io.Discard, r.Body) if err != nil { t.Fatal(err) @@ -1394,7 +1441,6 @@ func TestBadResponseAfterReadingBody(t *testing.T) { defer c.Close() fmt.Fprintln(c, "some bogus crap") })) - defer cst.close() closes := 0 res, err := cst.c.Post(cst.ts.URL, "text/plain", countCloseReader{&closes, strings.NewReader("hello")}) @@ -1407,12 +1453,10 @@ func TestBadResponseAfterReadingBody(t *testing.T) { } } -func TestWriteHeader0_h1(t *testing.T) { testWriteHeader0(t, h1Mode) } -func TestWriteHeader0_h2(t *testing.T) { testWriteHeader0(t, h2Mode) } -func testWriteHeader0(t *testing.T, h2 bool) { - defer afterTest(t) +func TestWriteHeader0(t *testing.T) { run(t, testWriteHeader0) } +func testWriteHeader0(t *testing.T, mode testMode) { gotpanic := make(chan bool, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(gotpanic) defer func() { if e := recover(); e != nil { @@ -1431,7 +1475,6 @@ func testWriteHeader0(t *testing.T, h2 bool) { }() w.WriteHeader(0) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1446,15 +1489,17 @@ func testWriteHeader0(t *testing.T, h2 bool) { // Issue 23010: don't be super strict checking WriteHeader's code if // it's not even valid to call WriteHeader then anyway. -func TestWriteHeaderNoCodeCheck_h1(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, false) } -func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { testWriteHeaderAfterWrite(t, h1Mode, true) } -func TestWriteHeaderNoCodeCheck_h2(t *testing.T) { testWriteHeaderAfterWrite(t, h2Mode, false) } -func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { - setParallel(t) - defer afterTest(t) - +func TestWriteHeaderNoCodeCheck(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testWriteHeaderAfterWrite(t, mode, false) + }) +} +func TestWriteHeaderNoCodeCheck_h1hijack(t *testing.T) { + testWriteHeaderAfterWrite(t, http1Mode, true) +} +func testWriteHeaderAfterWrite(t *testing.T, mode testMode, hijack bool) { var errorLog lockedBytesBuffer - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if hijack { conn, _, _ := w.(Hijacker).Hijack() defer conn.Close() @@ -1470,7 +1515,6 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { }), func(ts *httptest.Server) { ts.Config.ErrorLog = log.New(&errorLog, "", 0) }) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -1485,7 +1529,7 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { } // Also check the stderr output: - if h2 { + if mode == http2Mode { // TODO: also emit this log message for HTTP/2? // We historically haven't, so don't check. return @@ -1501,14 +1545,14 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { } func TestBidiStreamReverseProxy(t *testing.T) { - setParallel(t) - defer afterTest(t) - backend := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testBidiStreamReverseProxy, []testMode{http2Mode}) +} +func testBidiStreamReverseProxy(t *testing.T, mode testMode) { + backend := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if _, err := io.Copy(w, r.Body); err != nil { log.Printf("bidi backend copy: %v", err) } })) - defer backend.close() backURL, err := url.Parse(backend.ts.URL) if err != nil { @@ -1516,10 +1560,9 @@ func TestBidiStreamReverseProxy(t *testing.T) { } rp := httputil.NewSingleHostReverseProxy(backURL) rp.Transport = backend.tr - proxy := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + proxy := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { rp.ServeHTTP(w, r) })) - defer proxy.close() bodyRes := make(chan any, 1) // error or hash.Hash pr, pw := io.Pipe() @@ -1586,15 +1629,10 @@ func TestH12_WebSocketUpgrade(t *testing.T) { }.run(t) } -func TestIdentityTransferEncoding_h1(t *testing.T) { testIdentityTransferEncoding(t, h1Mode) } -func TestIdentityTransferEncoding_h2(t *testing.T) { testIdentityTransferEncoding(t, h2Mode) } - -func testIdentityTransferEncoding(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestIdentityTransferEncoding(t *testing.T) { run(t, testIdentityTransferEncoding) } +func testIdentityTransferEncoding(t *testing.T, mode testMode) { const body = "body" - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotBody, _ := io.ReadAll(r.Body) if got, want := string(gotBody), body; got != want { t.Errorf("got request body = %q; want %q", got, want) @@ -1604,7 +1642,6 @@ func testIdentityTransferEncoding(t *testing.T, h2 bool) { w.(Flusher).Flush() io.WriteString(w, body) })) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, strings.NewReader(body)) res, err := cst.c.Do(req) if err != nil { @@ -1620,14 +1657,11 @@ func testIdentityTransferEncoding(t *testing.T, h2 bool) { } } -func TestEarlyHintsRequest_h1(t *testing.T) { testEarlyHintsRequest(t, h1Mode) } -func TestEarlyHintsRequest_h2(t *testing.T) { testEarlyHintsRequest(t, h2Mode) } -func testEarlyHintsRequest(t *testing.T, h2 bool) { - defer afterTest(t) - +func TestEarlyHintsRequest(t *testing.T) { run(t, testEarlyHintsRequest) } +func testEarlyHintsRequest(t *testing.T, mode testMode) { var wg sync.WaitGroup wg.Add(1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { h := w.Header() h.Add("Content-Length", "123") // must be ignored @@ -1642,7 +1676,6 @@ func testEarlyHintsRequest(t *testing.T, h2 bool) { w.Write([]byte("Hello")) })) - defer cst.close() checkLinkHeaders := func(t *testing.T, expected, got []string) { t.Helper() diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index 205ca83f402..fb5ab9396aa 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -60,7 +60,7 @@ func init() { } } -func CondSkipHTTP2(t *testing.T) { +func CondSkipHTTP2(t testing.TB) { if omitBundledHTTP2 { t.Skip("skipping HTTP/2 test when nethttpomithttp2 build tag in use") } @@ -72,8 +72,6 @@ var ( ) func SetReadLoopBeforeNextReadHook(f func()) { - testHookMu.Lock() - defer testHookMu.Unlock() unnilTestHook(&f) testHookReadLoopBeforeNextRead = f } diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go index 71fc064367d..47526152b30 100644 --- a/src/net/http/fs_test.go +++ b/src/net/http/fs_test.go @@ -68,13 +68,11 @@ var ServeFileRangeTests = []struct { {r: "bytes=100-1000", code: StatusRequestedRangeNotSatisfiable}, } -func TestServeFile(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFile(t *testing.T) { run(t, testServeFile) } +func testServeFile(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") - })) - defer ts.Close() + })).ts c := ts.Client() var err error @@ -228,13 +226,12 @@ var fsRedirectTestData = []struct { {"/test/testdata/file/", "/test/testdata/file"}, } -func TestFSRedirect(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(StripPrefix("/test", FileServer(Dir(".")))) - defer ts.Close() +func TestFSRedirect(t *testing.T) { run(t, testFSRedirect) } +func testFSRedirect(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, StripPrefix("/test", FileServer(Dir(".")))).ts for _, data := range fsRedirectTestData { - res, err := Get(ts.URL + data.original) + res, err := ts.Client().Get(ts.URL + data.original) if err != nil { t.Fatal(err) } @@ -278,8 +275,8 @@ func TestFileServerCleans(t *testing.T) { } } -func TestFileServerEscapesNames(t *testing.T) { - defer afterTest(t) +func TestFileServerEscapesNames(t *testing.T) { run(t, testFileServerEscapesNames) } +func testFileServerEscapesNames(t *testing.T, mode testMode) { const dirListPrefix = "
\n"
 	const dirListSuffix = "\n
\n" tests := []struct { @@ -304,11 +301,10 @@ func TestFileServerEscapesNames(t *testing.T) { fs[fmt.Sprintf("/%d/%s", i, test.name)] = testFile } - ts := httptest.NewServer(FileServer(&fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(&fs)).ts for i, test := range tests { url := fmt.Sprintf("%s/%d", ts.URL, i) - res, err := Get(url) + res, err := ts.Client().Get(url) if err != nil { t.Fatalf("test %q: Get: %v", test.name, err) } @@ -327,8 +323,8 @@ func TestFileServerEscapesNames(t *testing.T) { } } -func TestFileServerSortsNames(t *testing.T) { - defer afterTest(t) +func TestFileServerSortsNames(t *testing.T) { run(t, testFileServerSortsNames) } +func testFileServerSortsNames(t *testing.T, mode testMode) { const contents = "I am a fake file" dirMod := time.Unix(123, 0).UTC() fileMod := time.Unix(1000000000, 0).UTC() @@ -351,10 +347,9 @@ func TestFileServerSortsNames(t *testing.T) { }, } - ts := httptest.NewServer(FileServer(&fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(&fs)).ts - res, err := Get(ts.URL) + res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatalf("Get: %v", err) } @@ -377,16 +372,15 @@ func mustRemoveAll(dir string) { } } -func TestFileServerImplicitLeadingSlash(t *testing.T) { - defer afterTest(t) +func TestFileServerImplicitLeadingSlash(t *testing.T) { run(t, testFileServerImplicitLeadingSlash) } +func testFileServerImplicitLeadingSlash(t *testing.T, mode testMode) { tempDir := t.TempDir() if err := os.WriteFile(filepath.Join(tempDir, "foo.txt"), []byte("Hello world"), 0644); err != nil { t.Fatalf("WriteFile: %v", err) } - ts := httptest.NewServer(StripPrefix("/bar/", FileServer(Dir(tempDir)))) - defer ts.Close() + ts := newClientServerTest(t, mode, StripPrefix("/bar/", FileServer(Dir(tempDir)))).ts get := func(suffix string) string { - res, err := Get(ts.URL + suffix) + res, err := ts.Client().Get(ts.URL + suffix) if err != nil { t.Fatalf("Get %s: %v", suffix, err) } @@ -405,11 +399,10 @@ func TestFileServerImplicitLeadingSlash(t *testing.T) { } } -func TestFileServerMethodOptions(t *testing.T) { - defer afterTest(t) +func TestFileServerMethodOptions(t *testing.T) { run(t, testFileServerMethodOptions) } +func testFileServerMethodOptions(t *testing.T, mode testMode) { const want = "GET, HEAD, OPTIONS" - ts := httptest.NewServer(FileServer(Dir("."))) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(Dir("."))).ts tests := []struct { method string @@ -496,10 +489,10 @@ func TestEmptyDirOpenCWD(t *testing.T) { test(Dir("./")) } -func TestServeFileContentType(t *testing.T) { - defer afterTest(t) +func TestServeFileContentType(t *testing.T) { run(t, testServeFileContentType) } +func testServeFileContentType(t *testing.T, mode testMode) { const ctype = "icecream/chocolate" - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { switch r.FormValue("override") { case "1": w.Header().Set("Content-Type", ctype) @@ -508,10 +501,9 @@ func TestServeFileContentType(t *testing.T) { w.Header()["Content-Type"] = []string{} } ServeFile(w, r, "testdata/file") - })) - defer ts.Close() + })).ts get := func(override string, want []string) { - resp, err := Get(ts.URL + "?override=" + override) + resp, err := ts.Client().Get(ts.URL + "?override=" + override) if err != nil { t.Fatal(err) } @@ -525,13 +517,12 @@ func TestServeFileContentType(t *testing.T) { get("2", nil) } -func TestServeFileMimeType(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileMimeType(t *testing.T) { run(t, testServeFileMimeType) } +func testServeFileMimeType(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/style.css") - })) - defer ts.Close() - resp, err := Get(ts.URL) + })).ts + resp, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -542,13 +533,12 @@ func TestServeFileMimeType(t *testing.T) { } } -func TestServeFileFromCWD(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileFromCWD(t *testing.T) { run(t, testServeFileFromCWD) } +func testServeFileFromCWD(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "fs_test.go") - })) - defer ts.Close() - r, err := Get(ts.URL) + })).ts + r, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -559,14 +549,13 @@ func TestServeFileFromCWD(t *testing.T) { } // Issue 13996 -func TestServeDirWithoutTrailingSlash(t *testing.T) { +func TestServeDirWithoutTrailingSlash(t *testing.T) { run(t, testServeDirWithoutTrailingSlash) } +func testServeDirWithoutTrailingSlash(t *testing.T, mode testMode) { e := "/testdata/" - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, ".") - })) - defer ts.Close() - r, err := Get(ts.URL + "/testdata") + })).ts + r, err := ts.Client().Get(ts.URL + "/testdata") if err != nil { t.Fatal(err) } @@ -578,11 +567,9 @@ func TestServeDirWithoutTrailingSlash(t *testing.T) { // Tests that ServeFile doesn't add a Content-Length if a Content-Encoding is // specified. -func TestServeFileWithContentEncoding_h1(t *testing.T) { testServeFileWithContentEncoding(t, h1Mode) } -func TestServeFileWithContentEncoding_h2(t *testing.T) { testServeFileWithContentEncoding(t, h2Mode) } -func testServeFileWithContentEncoding(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileWithContentEncoding(t *testing.T) { run(t, testServeFileWithContentEncoding) } +func testServeFileWithContentEncoding(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "foo") ServeFile(w, r, "testdata/file") @@ -595,7 +582,6 @@ func testServeFileWithContentEncoding(t *testing.T, h2 bool) { // Content-Length and test ServeFile only, flush here. w.(Flusher).Flush() })) - defer cst.close() resp, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -608,11 +594,9 @@ func testServeFileWithContentEncoding(t *testing.T, h2 bool) { // Tests that ServeFile does not generate representation metadata when // file has not been modified, as per RFC 7232 section 4.1. -func TestServeFileNotModified_h1(t *testing.T) { testServeFileNotModified(t, h1Mode) } -func TestServeFileNotModified_h2(t *testing.T) { testServeFileNotModified(t, h2Mode) } -func testServeFileNotModified(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServeFileNotModified(t *testing.T) { run(t, testServeFileNotModified) } +func testServeFileNotModified(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Encoding", "foo") w.Header().Set("Etag", `"123"`) @@ -627,7 +611,6 @@ func testServeFileNotModified(t *testing.T, h2 bool) { // Content-Length and test ServeFile only, flush here. w.(Flusher).Flush() })) - defer cst.close() req, err := NewRequest("GET", cst.ts.URL, nil) if err != nil { t.Fatal(err) @@ -660,9 +643,8 @@ func testServeFileNotModified(t *testing.T, h2 bool) { } } -func TestServeIndexHtml(t *testing.T) { - defer afterTest(t) - +func TestServeIndexHtml(t *testing.T) { run(t, testServeIndexHtml) } +func testServeIndexHtml(t *testing.T, mode testMode) { for i := 0; i < 2; i++ { var h Handler var name string @@ -676,11 +658,10 @@ func TestServeIndexHtml(t *testing.T) { } t.Run(name, func(t *testing.T) { const want = "index.html says hello\n" - ts := httptest.NewServer(h) - defer ts.Close() + ts := newClientServerTest(t, mode, h).ts for _, path := range []string{"/testdata/", "/testdata/index.html"} { - res, err := Get(ts.URL + path) + res, err := ts.Client().Get(ts.URL + path) if err != nil { t.Fatal(err) } @@ -697,14 +678,14 @@ func TestServeIndexHtml(t *testing.T) { } } -func TestServeIndexHtmlFS(t *testing.T) { - defer afterTest(t) +func TestServeIndexHtmlFS(t *testing.T) { run(t, testServeIndexHtmlFS) } +func testServeIndexHtmlFS(t *testing.T, mode testMode) { const want = "index.html says hello\n" - ts := httptest.NewServer(FileServer(Dir("."))) + ts := newClientServerTest(t, mode, FileServer(Dir("."))).ts defer ts.Close() for _, path := range []string{"/testdata/", "/testdata/index.html"} { - res, err := Get(ts.URL + path) + res, err := ts.Client().Get(ts.URL + path) if err != nil { t.Fatal(err) } @@ -719,10 +700,9 @@ func TestServeIndexHtmlFS(t *testing.T) { } } -func TestFileServerZeroByte(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(FileServer(Dir("."))) - defer ts.Close() +func TestFileServerZeroByte(t *testing.T) { run(t, testFileServerZeroByte) } +func testFileServerZeroByte(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, FileServer(Dir("."))).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -809,8 +789,8 @@ func (fsys fakeFS) Open(name string) (File, error) { return &fakeFile{ReadSeeker: strings.NewReader(f.contents), fi: f, path: name}, nil } -func TestDirectoryIfNotModified(t *testing.T) { - defer afterTest(t) +func TestDirectoryIfNotModified(t *testing.T) { run(t, testDirectoryIfNotModified) } +func testDirectoryIfNotModified(t *testing.T, mode testMode) { const indexContents = "I am a fake index.html file" fileMod := time.Unix(1000000000, 0).UTC() fileModStr := fileMod.Format(TimeFormat) @@ -829,10 +809,9 @@ func TestDirectoryIfNotModified(t *testing.T) { "/index.html": indexFile, } - ts := httptest.NewServer(FileServer(fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(fs)).ts - res, err := Get(ts.URL) + res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -884,8 +863,8 @@ func mustStat(t *testing.T, fileName string) fs.FileInfo { return fi } -func TestServeContent(t *testing.T) { - defer afterTest(t) +func TestServeContent(t *testing.T) { run(t, testServeContent) } +func testServeContent(t *testing.T, mode testMode) { type serveParam struct { name string modtime time.Time @@ -894,7 +873,7 @@ func TestServeContent(t *testing.T) { etag string } servec := make(chan serveParam, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { p := <-servec if p.etag != "" { w.Header().Set("ETag", p.etag) @@ -903,8 +882,7 @@ func TestServeContent(t *testing.T) { w.Header().Set("Content-Type", p.contentType) } ServeContent(w, r, p.name, p.modtime, p.content) - })) - defer ts.Close() + })).ts type testCase struct { // One of file or content must be set: @@ -1213,8 +1191,8 @@ type issue12991File struct{ File } func (issue12991File) Stat() (fs.FileInfo, error) { return nil, fs.ErrPermission } func (issue12991File) Close() error { return nil } -func TestServeContentErrorMessages(t *testing.T) { - defer afterTest(t) +func TestServeContentErrorMessages(t *testing.T) { run(t, testServeContentErrorMessages) } +func testServeContentErrorMessages(t *testing.T, mode testMode) { fs := fakeFS{ "/500": &fakeFileInfo{ err: errors.New("random error"), @@ -1223,8 +1201,7 @@ func TestServeContentErrorMessages(t *testing.T) { err: &fs.PathError{Err: fs.ErrPermission}, }, } - ts := httptest.NewServer(FileServer(fs)) - defer ts.Close() + ts := newClientServerTest(t, mode, FileServer(fs)).ts c := ts.Client() for _, code := range []int{403, 404, 500} { res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, code)) @@ -1342,20 +1319,20 @@ func TestLinuxSendfileChild(*testing.T) { // Issues 18984, 49552: tests that requests for paths beyond files return not-found errors func TestFileServerNotDirError(t *testing.T) { - defer afterTest(t) - t.Run("Dir", func(t *testing.T) { - testFileServerNotDirError(t, func(path string) FileSystem { return Dir(path) }) - }) - t.Run("FS", func(t *testing.T) { - testFileServerNotDirError(t, func(path string) FileSystem { return FS(os.DirFS(path)) }) + run(t, func(t *testing.T, mode testMode) { + t.Run("Dir", func(t *testing.T) { + testFileServerNotDirError(t, mode, func(path string) FileSystem { return Dir(path) }) + }) + t.Run("FS", func(t *testing.T) { + testFileServerNotDirError(t, mode, func(path string) FileSystem { return FS(os.DirFS(path)) }) + }) }) } -func testFileServerNotDirError(t *testing.T, newfs func(string) FileSystem) { - ts := httptest.NewServer(FileServer(newfs("testdata"))) - defer ts.Close() +func testFileServerNotDirError(t *testing.T, mode testMode, newfs func(string) FileSystem) { + ts := newClientServerTest(t, mode, FileServer(newfs("testdata"))).ts - res, err := Get(ts.URL + "/index.html/not-a-file") + res, err := ts.Client().Get(ts.URL + "/index.html/not-a-file") if err != nil { t.Fatal(err) } @@ -1459,19 +1436,11 @@ func Test_scanETag(t *testing.T) { // Issue 40940: Ensure that we only accept non-negative suffix-lengths // in "Range": "bytes=-N", and should reject "bytes=--2". -func TestServeFileRejectsInvalidSuffixLengths_h1(t *testing.T) { - testServeFileRejectsInvalidSuffixLengths(t, h1Mode) +func TestServeFileRejectsInvalidSuffixLengths(t *testing.T) { + run(t, testServeFileRejectsInvalidSuffixLengths, []testMode{http1Mode, https1Mode, http2Mode}) } -func TestServeFileRejectsInvalidSuffixLengths_h2(t *testing.T) { - testServeFileRejectsInvalidSuffixLengths(t, h2Mode) -} - -func testServeFileRejectsInvalidSuffixLengths(t *testing.T, h2 bool) { - defer afterTest(t) - cst := httptest.NewUnstartedServer(FileServer(Dir("testdata"))) - cst.EnableHTTP2 = h2 - cst.StartTLS() - defer cst.Close() +func testServeFileRejectsInvalidSuffixLengths(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, FileServer(Dir("testdata"))).ts tests := []struct { r string diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go index 27e9eb30eeb..686a8699fb0 100644 --- a/src/net/http/request_test.go +++ b/src/net/http/request_test.go @@ -15,7 +15,6 @@ import ( "math" "mime/multipart" . "net/http" - "net/http/httptest" "net/url" "os" "reflect" @@ -289,10 +288,11 @@ Content-Type: text/plain // the payload size and the internal leeway buffer size of 10MiB overflows, that we // correctly return an error. func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) { - defer afterTest(t) - + run(t, testMaxInt64ForMultipartFormMaxMemoryOverflow) +} +func testMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T, mode testMode) { payloadSize := 1 << 10 - cst := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { // The combination of: // MaxInt64 + payloadSize + (internal spare of 10MiB) // triggers the overflow. See issue https://golang.org/issue/40430/ @@ -300,8 +300,7 @@ func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) { Error(rw, err.Error(), StatusBadRequest) return } - })) - defer cst.Close() + })).ts fBuf := new(bytes.Buffer) mw := multipart.NewWriter(fBuf) mf, err := mw.CreateFormFile("file", "myfile.txt") @@ -329,11 +328,9 @@ func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) { } } -func TestRedirect_h1(t *testing.T) { testRedirect(t, h1Mode) } -func TestRedirect_h2(t *testing.T) { testRedirect(t, h2Mode) } -func testRedirect(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestRequestRedirect(t *testing.T) { run(t, testRequestRedirect) } +func testRequestRedirect(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { switch r.URL.Path { case "/": w.Header().Set("Location", "/foo/") @@ -344,7 +341,6 @@ func testRedirect(t *testing.T, h2 bool) { w.WriteHeader(StatusBadRequest) } })) - defer cst.close() var end = regexp.MustCompile("/foo/$") r, err := cst.c.Get(cst.ts.URL) @@ -1035,19 +1031,10 @@ func TestRequestCloneTransferEncoding(t *testing.T) { } } -func TestNoPanicOnRoundTripWithBasicAuth_h1(t *testing.T) { - testNoPanicWithBasicAuth(t, h1Mode) -} - -func TestNoPanicOnRoundTripWithBasicAuth_h2(t *testing.T) { - testNoPanicWithBasicAuth(t, h2Mode) -} - // Issue 34878: verify we don't panic when including basic auth (Go 1.13 regression) -func testNoPanicWithBasicAuth(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer cst.close() +func TestNoPanicOnRoundTripWithBasicAuth(t *testing.T) { run(t, testNoPanicWithBasicAuth) } +func testNoPanicWithBasicAuth(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})) u, err := url.Parse(cst.ts.URL) if err != nil { @@ -1328,11 +1315,6 @@ Host: localhost:8080 `) } -const ( - withTLS = true - noTLS = false -) - func BenchmarkFileAndServer_1KB(b *testing.B) { benchmarkFileAndServer(b, 1<<10) } @@ -1360,16 +1342,12 @@ func benchmarkFileAndServer(b *testing.B, n int64) { b.Fatalf("Failed to copy %d bytes: %v", n, err) } - b.Run("NoTLS", func(b *testing.B) { - runFileAndServerBenchmarks(b, noTLS, f, n) - }) - - b.Run("TLS", func(b *testing.B) { - runFileAndServerBenchmarks(b, withTLS, f, n) - }) + run(b, func(b *testing.B, mode testMode) { + runFileAndServerBenchmarks(b, mode, f, n) + }, []testMode{http1Mode, https1Mode, http2Mode}) } -func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int64) { +func runFileAndServerBenchmarks(b *testing.B, mode testMode, f *os.File, n int64) { handler := HandlerFunc(func(rw ResponseWriter, req *Request) { defer req.Body.Close() nc, err := io.Copy(io.Discard, req.Body) @@ -1382,14 +1360,8 @@ func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int6 } }) - var cst *httptest.Server - if tlsOption == withTLS { - cst = httptest.NewTLSServer(handler) - } else { - cst = httptest.NewServer(handler) - } + cst := newClientServerTest(b, mode, handler).ts - defer cst.Close() b.ResetTimer() for i := 0; i < b.N; i++ { // Perform some setup. diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 4fadc56c9eb..a93f6eff1bb 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -246,15 +246,13 @@ var vtests = []struct { {"http://someHost.com/someDir", "/someDir/"}, } -func TestHostHandlers(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestHostHandlers(t *testing.T) { run(t, testHostHandlers, []testMode{http1Mode}) } +func testHostHandlers(t *testing.T, mode testMode) { mux := NewServeMux() for _, h := range handlers { mux.Handle(h.pattern, stringHandler(h.msg)) } - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -487,9 +485,9 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { // properly sets the query string in the redirect URL. // See Issue 17841. func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { - setParallel(t) - defer afterTest(t) - + run(t, testServeWithSlashRedirectKeepsQueryString, []testMode{http1Mode}) +} +func testServeWithSlashRedirectKeepsQueryString(t *testing.T, mode testMode) { writeBackQuery := func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.URL.RawQuery) } @@ -502,8 +500,7 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { fmt.Fprintf(w, "%s:bar", r.URL.RawQuery) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts tests := [...]struct { path string @@ -546,7 +543,6 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { setParallel(t) - defer afterTest(t) mux := NewServeMux() mux.Handle("example.com/pkg/foo/", stringHandler("example.com/pkg/foo/")) @@ -578,9 +574,6 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { {"CONNECT", "http://example.com:3000/pkg/connect", 301, "/pkg/connect/", ""}, } - ts := httptest.NewServer(mux) - defer ts.Close() - for i, tt := range tests { req, _ := NewRequest(tt.method, tt.url, nil) w := httptest.NewRecorder() @@ -602,13 +595,10 @@ func TestServeWithSlashRedirectForHostPatterns(t *testing.T) { } } -func TestShouldRedirectConcurrency(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestShouldRedirectConcurrency(t *testing.T) { run(t, testShouldRedirectConcurrency) } +func testShouldRedirectConcurrency(t *testing.T, mode testMode) { mux := NewServeMux() - ts := httptest.NewServer(mux) - defer ts.Close() + newClientServerTest(t, mode, mux) mux.HandleFunc("/", func(w ResponseWriter, r *Request) {}) } @@ -656,13 +646,12 @@ func benchmarkServeMux(b *testing.B, runHandler bool) { } } -func TestServerTimeouts(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) } +func testServerTimeouts(t *testing.T, mode testMode) { // Try three times, with increasing timeouts. tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second} for i, timeout := range tries { - err := testServerTimeouts(timeout) + err := testServerTimeoutsWithTimeout(t, timeout, mode) if err == nil { return } @@ -674,16 +663,15 @@ func TestServerTimeouts(t *testing.T) { t.Fatal("all attempts failed") } -func testServerTimeouts(timeout time.Duration) error { +func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ fmt.Fprintf(res, "req=%d", reqNum) - })) - ts.Config.ReadTimeout = timeout - ts.Config.WriteTimeout = timeout - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadTimeout = timeout + ts.Config.WriteTimeout = timeout + }).ts // Hit the HTTP server successfully. c := ts.Client() @@ -749,22 +737,20 @@ func testServerTimeouts(timeout time.Duration) error { } // Test that the HTTP/2 server handles Server.WriteTimeout (Issue 18437) -func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { +func TestWriteDeadlineExtendedOnNewRequest(t *testing.T) { + run(t, testWriteDeadlineExtendedOnNewRequest) +} +func testWriteDeadlineExtendedOnNewRequest(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {})) - ts.Config.WriteTimeout = 250 * time.Millisecond - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) {}), + func(ts *httptest.Server) { + ts.Config.WriteTimeout = 250 * time.Millisecond + }, + ).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - t.Fatal(err) - } for i := 1; i <= 3; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -785,9 +771,6 @@ func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { t.Fatalf("http2 Get #%d: %v", i, err) } r.Body.Close() - if r.ProtoMajor != 2 { - t.Fatalf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } time.Sleep(ts.Config.WriteTimeout / 2) } } @@ -810,33 +793,31 @@ func tryTimeouts(t *testing.T, testFunc func(timeout time.Duration) error) { } // Test that the HTTP/2 server RSTs stream on slow write. -func TestHTTP2WriteDeadlineEnforcedPerStream(t *testing.T) { +func TestWriteDeadlineEnforcedPerStream(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } setParallel(t) - defer afterTest(t) - tryTimeouts(t, testHTTP2WriteDeadlineEnforcedPerStream) + run(t, func(t *testing.T, mode testMode) { + tryTimeouts(t, func(timeout time.Duration) error { + return testWriteDeadlineEnforcedPerStream(t, mode, timeout) + }) + }) } -func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { +func testWriteDeadlineEnforcedPerStream(t *testing.T, mode testMode, timeout time.Duration) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ if reqNum == 1 { return // first request succeeds } time.Sleep(timeout) // second request times out - })) - ts.Config.WriteTimeout = timeout / 2 - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.WriteTimeout = timeout / 2 + }).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) - } req, err := NewRequest("GET", ts.URL, nil) if err != nil { @@ -844,12 +825,9 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { } r, err := c.Do(req) if err != nil { - return fmt.Errorf("http2 Get #1: %v", err) + return fmt.Errorf("Get #1: %v", err) } r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } req, err = NewRequest("GET", ts.URL, nil) if err != nil { @@ -858,45 +836,42 @@ func testHTTP2WriteDeadlineEnforcedPerStream(timeout time.Duration) error { r, err = c.Do(req) if err == nil { r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } - return fmt.Errorf("http2 Get #2 expected error, got nil") + return fmt.Errorf("Get #2 expected error, got nil") } - expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3 - if !strings.Contains(err.Error(), expected) { - return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err) + if mode == http2Mode { + expected := "stream ID 3; INTERNAL_ERROR" // client IDs are odd, second stream should be 3 + if !strings.Contains(err.Error(), expected) { + return fmt.Errorf("http2 Get #2: expected error to contain %q, got %q", expected, err) + } } return nil } // Test that the HTTP/2 server does not send RST when WriteDeadline not set. -func TestHTTP2NoWriteDeadline(t *testing.T) { +func TestNoWriteDeadline(t *testing.T) { if testing.Short() { t.Skip("skipping in short mode") } setParallel(t) defer afterTest(t) - tryTimeouts(t, testHTTP2NoWriteDeadline) + run(t, func(t *testing.T, mode testMode) { + tryTimeouts(t, func(timeout time.Duration) error { + return testNoWriteDeadline(t, mode, timeout) + }) + }) } -func testHTTP2NoWriteDeadline(timeout time.Duration) error { +func testNoWriteDeadline(t *testing.T, mode testMode, timeout time.Duration) error { reqNum := 0 - ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(res ResponseWriter, req *Request) { reqNum++ if reqNum == 1 { return // first request succeeds } time.Sleep(timeout) // second request timesout - })) - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - defer ts.Close() + })).ts c := ts.Client() - if err := ExportHttp2ConfigureTransport(c.Transport.(*Transport)); err != nil { - return fmt.Errorf("ExportHttp2ConfigureTransport: %v", err) - } for i := 0; i < 2; i++ { req, err := NewRequest("GET", ts.URL, nil) @@ -905,12 +880,9 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error { } r, err := c.Do(req) if err != nil { - return fmt.Errorf("http2 Get #%d: %v", i, err) + return fmt.Errorf("Get #%d: %v", i, err) } r.Body.Close() - if r.ProtoMajor != 2 { - return fmt.Errorf("http2 Get expected HTTP/2.0, got %q", r.Proto) - } } return nil } @@ -918,15 +890,14 @@ func testHTTP2NoWriteDeadline(timeout time.Duration) error { // golang.org/issue/4741 -- setting only a write timeout that triggers // shouldn't cause a handler to block forever on reads (next HTTP // request) that will never happen. -func TestOnlyWriteTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestOnlyWriteTimeout(t *testing.T) { run(t, testOnlyWriteTimeout, []testMode{http1Mode}) } +func testOnlyWriteTimeout(t *testing.T, mode testMode) { var ( mu sync.RWMutex conn net.Conn ) var afterTimeoutErrc = make(chan error, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { buf := make([]byte, 512<<10) _, err := w.Write(buf) if err != nil { @@ -942,10 +913,9 @@ func TestOnlyWriteTimeout(t *testing.T) { conn.SetWriteDeadline(time.Now().Add(-30 * time.Second)) _, err = w.Write(buf) afterTimeoutErrc <- err - })) - ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn} - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Listener = trackLastConnListener{ts.Listener, &mu, &conn} + }).ts c := ts.Client() @@ -992,9 +962,12 @@ func (l trackLastConnListener) Accept() (c net.Conn, err error) { } // TestIdentityResponse verifies that a handler can unset -func TestIdentityResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestIdentityResponse(t *testing.T) { run(t, testIdentityResponse) } +func testIdentityResponse(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56019") + } + handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") rw.Header().Set("Transfer-Encoding", req.FormValue("te")) @@ -1012,9 +985,7 @@ func TestIdentityResponse(t *testing.T) { } }) - ts := httptest.NewServer(handler) - defer ts.Close() - + ts := newClientServerTest(t, mode, handler).ts c := ts.Client() // Note: this relies on the assumption (which is true) that @@ -1048,6 +1019,10 @@ func TestIdentityResponse(t *testing.T) { } res.Body.Close() + if mode != http1Mode { + return + } + // Verify that the connection is closed when the declared Content-Length // is larger than what the handler wrote. conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -1070,9 +1045,7 @@ func TestIdentityResponse(t *testing.T) { func testTCPConnectionCloses(t *testing.T, req string, h Handler) { setParallel(t) - defer afterTest(t) - s := httptest.NewServer(h) - defer s.Close() + s := newClientServerTest(t, http1Mode, h).ts conn, err := net.Dial("tcp", s.Listener.Addr().String()) if err != nil { @@ -1114,9 +1087,7 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) { setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(handler) - defer ts.Close() + ts := newClientServerTest(t, http1Mode, handler).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -1192,14 +1163,12 @@ func TestHTTP10KeepAlive304Response(t *testing.T) { } // Issue 15703 -func TestKeepAliveFinalChunkWithEOF(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, false /* h1 */, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestKeepAliveFinalChunkWithEOF(t *testing.T) { run(t, testKeepAliveFinalChunkWithEOF) } +func testKeepAliveFinalChunkWithEOF(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() // force chunked encoding w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}")) })) - defer cst.close() type data struct { Addr string } @@ -1222,16 +1191,11 @@ func TestKeepAliveFinalChunkWithEOF(t *testing.T) { } } -func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) } -func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) } - -func testSetsRemoteAddr(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestSetsRemoteAddr(t *testing.T) { run(t, testSetsRemoteAddr) } +func testSetsRemoteAddr(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -1276,17 +1240,18 @@ func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr { // Issue 12943 func TestServerAllowsBlockingRemoteAddr(t *testing.T) { - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "RA:%s", r.RemoteAddr) - })) + run(t, testServerAllowsBlockingRemoteAddr, []testMode{http1Mode}) +} +func testServerAllowsBlockingRemoteAddr(t *testing.T, mode testMode) { conns := make(chan net.Conn) - ts.Listener = &blockingRemoteAddrListener{ - Listener: ts.Listener, - conns: conns, - } - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "RA:%s", r.RemoteAddr) + }), func(ts *httptest.Server) { + ts.Listener = &blockingRemoteAddrListener{ + Listener: ts.Listener, + conns: conns, + } + }).ts c := ts.Client() c.Timeout = time.Second @@ -1351,13 +1316,9 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { // TestHeadResponses verifies that all MIME type sniffing and Content-Length // counting of GET requests also happens on HEAD requests. -func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) } -func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) } - -func testHeadResponses(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestHeadResponses(t *testing.T) { run(t, testHeadResponses) } +func testHeadResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("")) if err != nil { t.Errorf("ResponseWriter.Write: %v", err) @@ -1369,7 +1330,6 @@ func testHeadResponses(t *testing.T, h2 bool) { t.Errorf("Copy(ResponseWriter, ...): %v", err) } })) - defer cst.close() res, err := cst.c.Head(cst.ts.URL) if err != nil { t.Error(err) @@ -1393,14 +1353,16 @@ func testHeadResponses(t *testing.T, h2 bool) { } func TestTLSHandshakeTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode}) +} +func testTLSHandshakeTimeout(t *testing.T, mode testMode) { errc := make(chanWriter, 10) // but only expecting 1 - ts.Config.ReadTimeout = 250 * time.Millisecond - ts.Config.ErrorLog = log.New(errc, "", 0) - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), + func(ts *httptest.Server) { + ts.Config.ReadTimeout = 250 * time.Millisecond + ts.Config.ErrorLog = log.New(errc, "", 0) + }, + ).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -1423,19 +1385,18 @@ func TestTLSHandshakeTimeout(t *testing.T) { } } -func TestTLSServer(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTLSServer(t *testing.T) { run(t, testTLSServer, []testMode{https1Mode, http2Mode}) } +func testTLSServer(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { w.Header().Set("X-TLS-Set", "true") if r.TLS.HandshakeComplete { w.Header().Set("X-TLS-HandshakeComplete", "true") } } - })) - ts.Config.ErrorLog = log.New(io.Discard, "", 0) - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(io.Discard, "", 0) + }).ts // Connect an idle TCP connection to this server before we run // our real tests. This idle connection used to block forever @@ -1528,14 +1489,15 @@ func TestServeTLS(t *testing.T) { // Test that the HTTPS server nicely rejects plaintext HTTP/1.x requests. func TestTLSServerRejectHTTPRequests(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTLSServerRejectHTTPRequests, []testMode{https1Mode, http2Mode}) +} +func testTLSServerRejectHTTPRequests(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Error("unexpected HTTPS request") - })) - var errBuf bytes.Buffer - ts.Config.ErrorLog = log.New(&errBuf, "", 0) - defer ts.Close() + }), func(ts *httptest.Server) { + var errBuf bytes.Buffer + ts.Config.ErrorLog = log.New(&errBuf, "", 0) + }).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -1727,11 +1689,9 @@ var serverExpectTests = []serverExpectTest{ // Tests that the server responds to the "Expect" request header // correctly. -// http2 test: TestServer_Response_Automatic100Continue -func TestServerExpect(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerExpect(t *testing.T) { run(t, testServerExpect, []testMode{http1Mode}) } +func testServerExpect(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST // requests that would read from r.Body, which we only // conditionally want to do. @@ -1741,8 +1701,7 @@ func TestServerExpect(t *testing.T) { } else { w.WriteHeader(StatusUnauthorized) } - })) - defer ts.Close() + })).ts runTest := func(test serverExpectTest) { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -2287,11 +2246,8 @@ func (c cancelableTimeoutContext) Err() error { return nil } -func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } -func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } -func testTimeoutHandler(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandler(t *testing.T) { run(t, testTimeoutHandler) } +func testTimeoutHandler(t *testing.T, mode testMode) { sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2301,8 +2257,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) { }) ctx, cancel := context.WithCancel(context.Background()) h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) - cst := newClientServerTest(t, h2, h) - defer cst.close() + cst := newClientServerTest(t, mode, h) // Succeed without timing out: sendHi <- true @@ -2348,10 +2303,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) { } // See issues 8209 and 8414. -func TestTimeoutHandlerRace(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTimeoutHandlerRace(t *testing.T) { run(t, testTimeoutHandlerRace) } +func testTimeoutHandlerRace(t *testing.T, mode testMode) { delayHi := HandlerFunc(func(w ResponseWriter, r *Request) { ms, _ := strconv.Atoi(r.URL.Path[1:]) if ms == 0 { @@ -2363,8 +2316,7 @@ func TestTimeoutHandlerRace(t *testing.T) { } }) - ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(delayHi, 20*time.Millisecond, "")).ts c := ts.Client() @@ -2393,16 +2345,13 @@ func TestTimeoutHandlerRace(t *testing.T) { // See issues 8209 and 8414. // Both issues involved panics in the implementation of TimeoutHandler. -func TestTimeoutHandlerRaceHeader(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTimeoutHandlerRaceHeader(t *testing.T) { run(t, testTimeoutHandlerRaceHeader) } +func testTimeoutHandlerRaceHeader(t *testing.T, mode testMode) { delay204 := HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(204) }) - ts := httptest.NewServer(TimeoutHandler(delay204, time.Nanosecond, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(delay204, time.Nanosecond, "")).ts var wg sync.WaitGroup gate := make(chan bool, 50) @@ -2433,9 +2382,8 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { } // Issue 9162 -func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { run(t, testTimeoutHandlerRaceHeaderTimeout) } +func testTimeoutHandlerRaceHeaderTimeout(t *testing.T, mode testMode) { sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2446,8 +2394,7 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { }) ctx, cancel := context.WithCancel(context.Background()) h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx}) - cst := newClientServerTest(t, h1Mode, h) - defer cst.close() + cst := newClientServerTest(t, mode, h) // Succeed without timing out: sendHi <- true @@ -2491,15 +2438,17 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { // Issue 14568. func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { + run(t, testTimeoutHandlerStartTimerWhenServing) +} +func testTimeoutHandlerStartTimerWhenServing(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping sleeping test in -short mode") } - defer afterTest(t) var handler HandlerFunc = func(w ResponseWriter, _ *Request) { w.WriteHeader(StatusNoContent) } timeout := 300 * time.Millisecond - ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) + ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts defer ts.Close() c := ts.Client() @@ -2518,9 +2467,8 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { } } -func TestTimeoutHandlerContextCanceled(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerContextCanceled(t *testing.T) { run(t, testTimeoutHandlerContextCanceled) } +func testTimeoutHandlerContextCanceled(t *testing.T, mode testMode) { writeErrors := make(chan error, 1) sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Type", "text/plain") @@ -2540,7 +2488,7 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() h := NewTestTimeoutHandler(sayHi, ctx) - cst := newClientServerTest(t, h1Mode, h) + cst := newClientServerTest(t, mode, h) defer cst.close() res, err := cst.c.Get(cst.ts.URL) @@ -2560,15 +2508,13 @@ func TestTimeoutHandlerContextCanceled(t *testing.T) { } // https://golang.org/issue/15948 -func TestTimeoutHandlerEmptyResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTimeoutHandlerEmptyResponse(t *testing.T) { run(t, testTimeoutHandlerEmptyResponse) } +func testTimeoutHandlerEmptyResponse(t *testing.T, mode testMode) { var handler HandlerFunc = func(w ResponseWriter, _ *Request) { // No response. } timeout := 300 * time.Millisecond - ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) - defer ts.Close() + ts := newClientServerTest(t, mode, TimeoutHandler(handler, timeout, "")).ts c := ts.Client() @@ -2587,7 +2533,9 @@ func TestTimeoutHandlerPanicRecovery(t *testing.T) { wrapper := func(h Handler) Handler { return TimeoutHandler(h, time.Second, "") } - testHandlerPanic(t, false, false, wrapper, "intentional death for testing") + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, wrapper, "intentional death for testing") + }, testNotParallel) } func TestRedirectBadPath(t *testing.T) { @@ -2705,17 +2653,10 @@ func TestRedirectContentTypeAndBody(t *testing.T) { // connection immediately. But when it re-uses the connection, it typically closes // the previous request's body, which is not optimal for zero-lengthed bodies, // as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF. -func TestZeroLengthPostAndResponse_h1(t *testing.T) { - testZeroLengthPostAndResponse(t, h1Mode) -} -func TestZeroLengthPostAndResponse_h2(t *testing.T) { - testZeroLengthPostAndResponse(t, h2Mode) -} +func TestZeroLengthPostAndResponse(t *testing.T) { run(t, testZeroLengthPostAndResponse) } -func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { +func testZeroLengthPostAndResponse(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("handler ReadAll: %v", err) @@ -2725,7 +2666,6 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } rw.Header().Set("Content-Length", "0") })) - defer cst.close() req, err := NewRequest("POST", cst.ts.URL, strings.NewReader("")) if err != nil { @@ -2752,23 +2692,26 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } } -func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil, nil) } -func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil, nil) } - -func TestHandlerPanic_h1(t *testing.T) { - testHandlerPanic(t, false, h1Mode, nil, "intentional death for testing") +func TestHandlerPanicNil(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, nil, nil) + }, testNotParallel) } -func TestHandlerPanic_h2(t *testing.T) { - testHandlerPanic(t, false, h2Mode, nil, "intentional death for testing") + +func TestHandlerPanic(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, false, mode, nil, "intentional death for testing") + }, testNotParallel) } func TestHandlerPanicWithHijack(t *testing.T) { // Only testing HTTP/1, and our http2 server doesn't support hijacking. - testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing") + run(t, func(t *testing.T, mode testMode) { + testHandlerPanic(t, true, mode, nil, "intentional death for testing") + }, testNotParallel, []testMode{http1Mode}) } -func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) { - defer afterTest(t) +func testHandlerPanic(t *testing.T, withHijack bool, mode testMode, wrapper func(Handler) Handler, panicValue any) { // Unlike the other tests that set the log output to io.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: @@ -2803,8 +2746,7 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) H if wrapper != nil { handler = wrapper(handler) } - cst := newClientServerTest(t, h2, handler) - defer cst.close() + cst := newClientServerTest(t, mode, handler) // Do a blocking read on the log output pipe so its logging // doesn't bleed into the next test. But wait only 5 seconds @@ -2847,9 +2789,11 @@ func (w terrorWriter) Write(p []byte) (int, error) { // Issue 16456: allow writing 0 bytes on hijacked conn to test hijack // without any log spam. func TestServerWriteHijackZeroBytes(t *testing.T) { - defer afterTest(t) + run(t, testServerWriteHijackZeroBytes, []testMode{http1Mode}) +} +func testServerWriteHijackZeroBytes(t *testing.T, mode testMode) { done := make(chan struct{}) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) w.(Flusher).Flush() conn, _, err := w.(Hijacker).Hijack() @@ -2862,10 +2806,9 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { if err != ErrHijacked { t.Errorf("Write error = %v; want ErrHijacked", err) } - })) - ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) + }).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -2880,19 +2823,23 @@ func TestServerWriteHijackZeroBytes(t *testing.T) { } } -func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") } -func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") } -func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") } -func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") } +func TestServerNoDate(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testServerNoHeader(t, mode, "Date") + }) +} -func testServerNoHeader(t *testing.T, h2 bool, header string) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerContentType(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testServerNoHeader(t, mode, "Content-Type") + }) +} + +func testServerNoHeader(t *testing.T, mode testMode, header string) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()[header] = nil io.WriteString(w, "foo") // non-empty })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -2903,15 +2850,13 @@ func testServerNoHeader(t *testing.T, h2 bool, header string) { } } -func TestStripPrefix(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestStripPrefix(t *testing.T) { run(t, testStripPrefix) } +func testStripPrefix(t *testing.T, mode testMode) { h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) w.Header().Set("X-RawPath", r.URL.RawPath) }) - ts := httptest.NewServer(StripPrefix("/foo/bar", h)) - defer ts.Close() + ts := newClientServerTest(t, mode, StripPrefix("/foo/bar", h)).ts c := ts.Client() @@ -2961,15 +2906,11 @@ func TestStripPrefixNotModifyRequest(t *testing.T) { } } -func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) } -func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) } -func testRequestLimit(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestRequestLimit(t *testing.T) { run(t, testRequestLimit) } +func testRequestLimit(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") }), optQuietLog) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, nil) var bytesPerHeader = len("header12345: val12345\r\n") for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ { @@ -2979,7 +2920,7 @@ func testRequestLimit(t *testing.T, h2 bool) { if res != nil { defer res.Body.Close() } - if h2 { + if mode == http2Mode { // In HTTP/2, the result depends on a race. If the client has received the // server's SETTINGS before RoundTrip starts sending the request, then RoundTrip // will fail with an error. Otherwise, the client should receive a 431 from the @@ -3021,13 +2962,10 @@ func (cr countReader) Read(p []byte) (n int, err error) { return } -func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) } -func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) } -func testRequestBodyLimit(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestRequestBodyLimit(t *testing.T) { run(t, testRequestBodyLimit) } +func testRequestBodyLimit(t *testing.T, mode testMode) { const limit = 1 << 20 - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) n, err := io.Copy(io.Discard, r.Body) if err == nil { @@ -3044,7 +2982,6 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { t.Errorf("MaxBytesError.Limit = %d, want %d", mbErr.Limit, limit) } })) - defer cst.close() nWritten := new(int64) req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) @@ -3068,13 +3005,12 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { // TestClientWriteShutdown tests that if the client shuts down the write // side of their TCP connection, the server doesn't send a 400 Bad Request. -func TestClientWriteShutdown(t *testing.T) { +func TestClientWriteShutdown(t *testing.T) { run(t, testClientWriteShutdown) } +func testClientWriteShutdown(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/17906") } - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -3119,12 +3055,12 @@ func TestServerBufferedChunking(t *testing.T) { // closing the TCP connection, causing the client to get a RST. // See https://golang.org/issue/3595 func TestServerGracefulClose(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testServerGracefulClose, []testMode{http1Mode}) +} +func testServerGracefulClose(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, "bye", StatusUnauthorized) - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3162,11 +3098,9 @@ func TestServerGracefulClose(t *testing.T) { <-writeErr } -func TestCaseSensitiveMethod_h1(t *testing.T) { testCaseSensitiveMethod(t, h1Mode) } -func TestCaseSensitiveMethod_h2(t *testing.T) { testCaseSensitiveMethod(t, h2Mode) } -func testCaseSensitiveMethod(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) } +func testCaseSensitiveMethod(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "get" { t.Errorf(`Got method %q; want "get"`, r.Method) } @@ -3187,8 +3121,10 @@ func testCaseSensitiveMethod(t *testing.T, h2 bool) { // response, the net/http package adds a "Content-Length: 0" response // header. func TestContentLengthZero(t *testing.T) { - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {})) - defer ts.Close() + run(t, testContentLengthZero, []testMode{http1Mode}) +} +func testContentLengthZero(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {})).ts for _, version := range []string{"HTTP/1.0", "HTTP/1.1"} { conn, err := net.Dial("tcp", ts.Listener.Addr().String()) @@ -3215,15 +3151,17 @@ func TestContentLengthZero(t *testing.T) { } func TestCloseNotifier(t *testing.T) { - defer afterTest(t) + run(t, testCloseNotifier, []testMode{http1Mode}) +} +func testCloseNotifier(t *testing.T, mode testMode) { gotReq := make(chan bool, 1) sawClose := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gotReq <- true cc := rw.(CloseNotifier).CloseNotify() <-cc sawClose <- true - })) + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) @@ -3257,11 +3195,12 @@ For: // // Issue 13165 (where it used to deadlock), but behavior changed in Issue 23921. func TestCloseNotifierPipelined(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testCloseNotifierPipelined, []testMode{http1Mode}) +} +func testCloseNotifierPipelined(t *testing.T, mode testMode) { gotReq := make(chan bool, 2) sawClose := make(chan bool, 2) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gotReq <- true cc := rw.(CloseNotifier).CloseNotify() select { @@ -3270,8 +3209,7 @@ func TestCloseNotifierPipelined(t *testing.T) { case <-time.After(100 * time.Millisecond): } sawClose <- true - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("error dialing: %v", err) @@ -3341,12 +3279,14 @@ func TestCloseNotifierChanLeak(t *testing.T) { // Issue 9763. // HTTP/1-only test. (http2 doesn't have Hijack) func TestHijackAfterCloseNotifier(t *testing.T) { - defer afterTest(t) + run(t, testHijackAfterCloseNotifier, []testMode{http1Mode}) +} +func testHijackAfterCloseNotifier(t *testing.T, mode testMode) { script := make(chan string, 2) script <- "closenotify" script <- "hijack" close(script) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { plan := <-script switch plan { default: @@ -3369,13 +3309,12 @@ func TestHijackAfterCloseNotifier(t *testing.T) { c.Close() return } - })) - defer ts.Close() - res1, err := Get(ts.URL) + })).ts + res1, err := ts.Client().Get(ts.URL) if err != nil { log.Fatal(err) } - res2, err := Get(ts.URL) + res2, err := ts.Client().Get(ts.URL) if err != nil { log.Fatal(err) } @@ -3387,12 +3326,13 @@ func TestHijackAfterCloseNotifier(t *testing.T) { } func TestHijackBeforeRequestBodyRead(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testHijackBeforeRequestBodyRead, []testMode{http1Mode}) +} +func testHijackBeforeRequestBodyRead(t *testing.T, mode testMode) { var requestBody = bytes.Repeat([]byte("a"), 1<<20) bodyOkay := make(chan bool, 1) gotCloseNotify := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(bodyOkay) // caller will read false if nothing else reqBody := r.Body @@ -3419,8 +3359,7 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { case <-time.After(5 * time.Second): gotCloseNotify <- false } - })) - defer ts.Close() + })).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3440,14 +3379,14 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { } } -func TestOptions(t *testing.T) { +func TestOptions(t *testing.T) { run(t, testOptions, []testMode{http1Mode}) } +func testOptions(t *testing.T, mode testMode) { uric := make(chan string, 2) // only expect 1, but leave space for 2 mux := NewServeMux() mux.HandleFunc("/", func(w ResponseWriter, r *Request) { uric <- r.RequestURI }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3492,15 +3431,15 @@ func TestOptions(t *testing.T) { } } -func TestOptionsHandler(t *testing.T) { +func TestOptionsHandler(t *testing.T) { run(t, testOptionsHandler, []testMode{http1Mode}) } +func testOptionsHandler(t *testing.T, mode testMode) { rc := make(chan *Request, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { rc <- r - })) - ts.Config.DisableGeneralOptionsHandler = true - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.DisableGeneralOptionsHandler = true + }).ts conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -3804,12 +3743,12 @@ func TestDoubleHijack(t *testing.T) { // optimization and is pointless if dealing with a // badly behaved client. func TestHTTP10ConnectionHeader(t *testing.T) { - defer afterTest(t) - + run(t, testHTTP10ConnectionHeader, []testMode{http1Mode}) +} +func testHTTP10ConnectionHeader(t *testing.T, mode testMode) { mux := NewServeMux() mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {})) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts // net/http uses HTTP/1.1 for requests, so write requests manually tests := []struct { @@ -3856,14 +3795,11 @@ func TestHTTP10ConnectionHeader(t *testing.T) { } // See golang.org/issue/5660 -func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) } -func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) } -func testServerReaderFromOrder(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerReaderFromOrder(t *testing.T) { run(t, testServerReaderFromOrder) } +func testServerReaderFromOrder(t *testing.T, mode testMode) { pr, pw := io.Pipe() const size = 3 << 20 - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path done := make(chan bool) go func() { @@ -3883,7 +3819,6 @@ func testServerReaderFromOrder(t *testing.T, h2 bool) { pw.Close() <-done })) - defer cst.close() req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size)) if err != nil { @@ -3957,16 +3892,10 @@ func TestContentTypeOkayOn204(t *testing.T) { // proxy). So then two people own that Request.Body (both the server // and the http client), and both think they can close it on failure. // Therefore, all incoming server requests Bodies need to be thread-safe. -func TestTransportAndServerSharedBodyRace_h1(t *testing.T) { - testTransportAndServerSharedBodyRace(t, h1Mode) +func TestTransportAndServerSharedBodyRace(t *testing.T) { + run(t, testTransportAndServerSharedBodyRace) } -func TestTransportAndServerSharedBodyRace_h2(t *testing.T) { - testTransportAndServerSharedBodyRace(t, h2Mode) -} -func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) { const bodySize = 1 << 20 // errorf is like t.Errorf, but also writes to println. When @@ -3980,7 +3909,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { } unblockBackend := make(chan bool) - backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { gone := rw.(CloseNotifier).CloseNotify() didCopy := make(chan any) go func() { @@ -4007,7 +3936,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { backendRespc := make(chan *Response, 1) var proxy *clientServerTest - proxy = newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { req2, _ := NewRequest("POST", backend.ts.URL, req.Body) req2.ContentLength = bodySize cancel := make(chan struct{}) @@ -4027,7 +3956,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // Try to cause a race: Both the Transport and the proxy handler's Server // will try to read/close req.Body (aka req2.Body) - if h2 { + if mode == http2Mode { close(cancel) } else { proxy.c.Transport.(*Transport).CancelRequest(req2) @@ -4071,22 +4000,23 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { // cause the Handler goroutine's Request.Body.Close to block. // See issue 7121. func TestRequestBodyCloseDoesntBlock(t *testing.T) { + run(t, testRequestBodyCloseDoesntBlock, []testMode{http1Mode}) +} +func testRequestBodyCloseDoesntBlock(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in -short mode") } - defer afterTest(t) readErrCh := make(chan error, 1) errCh := make(chan error, 2) - server := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + server := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { go func(body io.Reader) { _, err := body.Read(make([]byte, 100)) readErrCh <- err }(req.Body) time.Sleep(500 * time.Millisecond) - })) - defer server.Close() + })).ts closeConn := make(chan bool) defer close(closeConn) @@ -4149,9 +4079,8 @@ func TestAppendTime(t *testing.T) { } } -func TestServerConnState(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerConnState(t *testing.T) { run(t, testServerConnState, []testMode{http1Mode}) } +func testServerConnState(t *testing.T, mode testMode) { handler := map[string]func(w ResponseWriter, r *Request){ "/": func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello.") @@ -4217,37 +4146,36 @@ func TestServerConnState(t *testing.T) { // next call to wantLog. } - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { handler[r.URL.Path](w, r) - })) + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(io.Discard, "", 0) + ts.Config.ConnState = func(c net.Conn, state ConnState) { + if c == nil { + t.Errorf("nil conn seen in state %s", state) + return + } + sl := <-activeLog + if sl.active == nil && state == StateNew { + sl.active = c + } else if sl.active != c { + t.Errorf("unexpected conn in state %s", state) + activeLog <- sl + return + } + sl.got = append(sl.got, state) + if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) { + close(sl.complete) + sl.complete = nil + } + activeLog <- sl + } + }).ts defer func() { activeLog <- &stateLog{} // If the test failed, allow any remaining ConnState callbacks to complete. ts.Close() }() - ts.Config.ErrorLog = log.New(io.Discard, "", 0) - ts.Config.ConnState = func(c net.Conn, state ConnState) { - if c == nil { - t.Errorf("nil conn seen in state %s", state) - return - } - sl := <-activeLog - if sl.active == nil && state == StateNew { - sl.active = c - } else if sl.active != c { - t.Errorf("unexpected conn in state %s", state) - activeLog <- sl - return - } - sl.got = append(sl.got, state) - if sl.complete != nil && (len(sl.got) >= len(sl.want) || !reflect.DeepEqual(sl.got, sl.want[:len(sl.got)])) { - close(sl.complete) - sl.complete = nil - } - activeLog <- sl - } - - ts.Start() c := ts.Client() mustGet := func(url string, headers ...string) { @@ -4329,13 +4257,15 @@ func TestServerConnState(t *testing.T) { }, StateNew, StateActive, StateIdle, StateClosed) } -func TestServerKeepAlivesEnabled(t *testing.T) { - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - ts.Config.SetKeepAlivesEnabled(false) - ts.Start() - defer ts.Close() - res, err := Get(ts.URL) +func TestServerKeepAlivesEnabledResultClose(t *testing.T) { + run(t, testServerKeepAlivesEnabledResultClose, []testMode{http1Mode}) +} +func testServerKeepAlivesEnabledResultClose(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + }), func(ts *httptest.Server) { + ts.Config.SetKeepAlivesEnabled(false) + }).ts + res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) } @@ -4346,16 +4276,12 @@ func TestServerKeepAlivesEnabled(t *testing.T) { } // golang.org/issue/7856 -func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) } -func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) } -func testServerEmptyBodyRace(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerEmptyBodyRace(t *testing.T) { run(t, testServerEmptyBodyRace) } +func testServerEmptyBodyRace(t *testing.T, mode testMode) { var n int32 - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { atomic.AddInt32(&n, 1) }), optQuietLog) - defer cst.close() var wg sync.WaitGroup const reqs = 20 for i := 0; i < reqs; i++ { @@ -4436,9 +4362,9 @@ func TestCloseWrite(t *testing.T) { // fixed. // // So add an explicit test for this. -func TestServerFlushAndHijack(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerFlushAndHijack(t *testing.T) { run(t, testServerFlushAndHijack, []testMode{http1Mode}) } +func testServerFlushAndHijack(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, "Hello, ") w.(Flusher).Flush() conn, buf, _ := w.(Hijacker).Hijack() @@ -4449,8 +4375,7 @@ func TestServerFlushAndHijack(t *testing.T) { if err := conn.Close(); err != nil { t.Error(err) } - })) - defer ts.Close() + })).ts res, err := Get(ts.URL) if err != nil { t.Fatal(err) @@ -4472,20 +4397,21 @@ func TestServerFlushAndHijack(t *testing.T) { // To test, verify we don't timeout or see fewer unique client // addresses (== unique connections) than requests. func TestServerKeepAliveAfterWriteError(t *testing.T) { + run(t, testServerKeepAliveAfterWriteError, []testMode{http1Mode}) +} +func testServerKeepAliveAfterWriteError(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in -short mode") } - defer afterTest(t) const numReq = 3 addrc := make(chan string, numReq) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addrc <- r.RemoteAddr time.Sleep(500 * time.Millisecond) w.(Flusher).Flush() - })) - ts.Config.WriteTimeout = 250 * time.Millisecond - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.WriteTimeout = 250 * time.Millisecond + }).ts errc := make(chan error, numReq) go func() { @@ -4529,12 +4455,13 @@ func TestServerKeepAliveAfterWriteError(t *testing.T) { // Issue 9987: shouldn't add automatic Content-Length (or // Content-Type) if a Transfer-Encoding was set by the handler. func TestNoContentLengthIfTransferEncoding(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testNoContentLengthIfTransferEncoding, []testMode{http1Mode}) +} +func testNoContentLengthIfTransferEncoding(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Transfer-Encoding", "foo") io.WriteString(w, "") - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) @@ -4682,15 +4609,12 @@ func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) { } } -func TestHandlerSetsBodyNil_h1(t *testing.T) { testHandlerSetsBodyNil(t, h1Mode) } -func TestHandlerSetsBodyNil_h2(t *testing.T) { testHandlerSetsBodyNil(t, h2Mode) } -func testHandlerSetsBodyNil(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestHandlerSetsBodyNil(t *testing.T) { run(t, testHandlerSetsBodyNil) } +func testHandlerSetsBodyNil(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = nil fmt.Fprintf(w, "%v", r.RemoteAddr) })) - defer cst.close() get := func() string { res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -4780,9 +4704,11 @@ func TestServerValidatesHostHeader(t *testing.T) { } func TestServerHandlersCanHandleH2PRI(t *testing.T) { + run(t, testServerHandlersCanHandleH2PRI, []testMode{http1Mode}) +} +func testServerHandlersCanHandleH2PRI(t *testing.T, mode testMode) { const upgradeResponse = "upgrade here" - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, br, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -4804,8 +4730,7 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) { return } io.WriteString(conn, upgradeResponse) - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -4872,17 +4797,12 @@ func TestServerValidatesHeaders(t *testing.T) { } } -func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) { - testServerRequestContextCancel_ServeHTTPDone(t, h1Mode) +func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) { + run(t, testServerRequestContextCancel_ServeHTTPDone) } -func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) { - testServerRequestContextCancel_ServeHTTPDone(t, h2Mode) -} -func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, mode testMode) { ctxc := make(chan context.Context, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() select { case <-ctx.Done(): @@ -4891,7 +4811,6 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { } ctxc <- ctx })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -4910,16 +4829,16 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { // is always blocked in a Read call so it notices the EOF from the client. // See issues 15927 and 15224. func TestServerRequestContextCancel_ConnClose(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testServerRequestContextCancel_ConnClose, []testMode{http1Mode}) +} +func testServerRequestContextCancel_ConnClose(t *testing.T, mode testMode) { inHandler := make(chan struct{}) handlerDone := make(chan struct{}) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(inHandler) <-r.Context().Done() close(handlerDone) - })) - defer ts.Close() + })).ts c, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatal(err) @@ -4931,23 +4850,17 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) { <-handlerDone } -func TestServerContext_ServerContextKey_h1(t *testing.T) { - testServerContext_ServerContextKey(t, h1Mode) +func TestServerContext_ServerContextKey(t *testing.T) { + run(t, testServerContext_ServerContextKey) } -func TestServerContext_ServerContextKey_h2(t *testing.T) { - testServerContext_ServerContextKey(t, h2Mode) -} -func testServerContext_ServerContextKey(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func testServerContext_ServerContextKey(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() got := ctx.Value(ServerContextKey) if _, ok := got.(*Server); !ok { t.Errorf("context value = %T; want *http.Server", got) } })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -4955,20 +4868,14 @@ func testServerContext_ServerContextKey(t *testing.T, h2 bool) { res.Body.Close() } -func TestServerContext_LocalAddrContextKey_h1(t *testing.T) { - testServerContext_LocalAddrContextKey(t, h1Mode) +func TestServerContext_LocalAddrContextKey(t *testing.T) { + run(t, testServerContext_LocalAddrContextKey) } -func TestServerContext_LocalAddrContextKey_h2(t *testing.T) { - testServerContext_LocalAddrContextKey(t, h2Mode) -} -func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func testServerContext_LocalAddrContextKey(t *testing.T, mode testMode) { ch := make(chan any, 1) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { ch <- r.Context().Value(LocalAddrContextKey) })) - defer cst.close() if _, err := cst.c.Head(cst.ts.URL); err != nil { t.Fatal(err) } @@ -5021,16 +4928,19 @@ func TestHandlerSetTransferEncodingGzip(t *testing.T) { } func BenchmarkClientServer(b *testing.B) { + run(b, benchmarkClientServer, []testMode{http1Mode, https1Mode, http2Mode}) +} +func benchmarkClientServer(b *testing.B, mode testMode) { b.ReportAllocs() b.StopTimer() - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { fmt.Fprintf(rw, "Hello world.\n") - })) - defer ts.Close() + })).ts b.StartTimer() + c := ts.Client() for i := 0; i < b.N; i++ { - res, err := Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { b.Fatal("Get:", err) } @@ -5048,33 +4958,21 @@ func BenchmarkClientServer(b *testing.B) { b.StopTimer() } -func BenchmarkClientServerParallel4(b *testing.B) { - benchmarkClientServerParallel(b, 4, false) -} - -func BenchmarkClientServerParallel64(b *testing.B) { - benchmarkClientServerParallel(b, 64, false) -} - -func BenchmarkClientServerParallelTLS4(b *testing.B) { - benchmarkClientServerParallel(b, 4, true) -} - -func BenchmarkClientServerParallelTLS64(b *testing.B) { - benchmarkClientServerParallel(b, 64, true) -} - -func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { - b.ReportAllocs() - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { - fmt.Fprintf(rw, "Hello world.\n") - })) - if useTLS { - ts.StartTLS() - } else { - ts.Start() +func BenchmarkClientServerParallel(b *testing.B) { + for _, parallelism := range []int{4, 64} { + b.Run(fmt.Sprint(parallelism), func(b *testing.B) { + run(b, func(b *testing.B, mode testMode) { + benchmarkClientServerParallel(b, parallelism, mode) + }, []testMode{http1Mode, https1Mode, http2Mode}) + }) } - defer ts.Close() +} + +func benchmarkClientServerParallel(b *testing.B, parallelism int, mode testMode) { + b.ReportAllocs() + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { + fmt.Fprintf(rw, "Hello world.\n") + })).ts b.ResetTimer() b.SetParallelism(parallelism) b.RunParallel(func(pb *testing.PB) { @@ -5464,15 +5362,15 @@ Host: golang.org } } -func BenchmarkCloseNotifier(b *testing.B) { +func BenchmarkCloseNotifier(b *testing.B) { run(b, benchmarkCloseNotifier, []testMode{http1Mode}) } +func benchmarkCloseNotifier(b *testing.B, mode testMode) { b.ReportAllocs() b.StopTimer() sawClose := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(b, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { <-rw.(CloseNotifier).CloseNotify() sawClose <- true - })) - defer ts.Close() + })).ts tot := time.NewTimer(5 * time.Second) defer tot.Stop() b.StartTimer() @@ -5508,20 +5406,18 @@ func TestConcurrentServerServe(t *testing.T) { } } -func TestServerIdleTimeout(t *testing.T) { +func TestServerIdleTimeout(t *testing.T) { run(t, testServerIdleTimeout, []testMode{http1Mode}) } +func testServerIdleTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) io.WriteString(w, r.RemoteAddr) - })) - ts.Config.ReadHeaderTimeout = 1 * time.Second - ts.Config.IdleTimeout = 2 * time.Second - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = 1 * time.Second + ts.Config.IdleTimeout = 2 * time.Second + }).ts c := ts.Client() get := func() string { @@ -5576,12 +5472,12 @@ func get(t *testing.T, c *Client, url string) string { // Tests that calls to Server.SetKeepAlivesEnabled(false) closes any // currently-open connections. func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testServerSetKeepAlivesEnabledClosesConns, []testMode{http1Mode}) +} +func testServerSetKeepAlivesEnabledClosesConns(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -5620,16 +5516,8 @@ func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { } } -func TestServerShutdown_h1(t *testing.T) { - testServerShutdown(t, h1Mode) -} -func TestServerShutdown_h2(t *testing.T) { - testServerShutdown(t, h2Mode) -} - -func testServerShutdown(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) +func TestServerShutdown(t *testing.T) { run(t, testServerShutdown) } +func testServerShutdown(t *testing.T, mode testMode) { var doShutdown func() // set later var doStateCount func() var shutdownRes = make(chan error, 1) @@ -5645,10 +5533,9 @@ func testServerShutdown(t *testing.T, h2 bool) { time.Sleep(20 * time.Millisecond) io.WriteString(w, r.RemoteAddr) }) - cst := newClientServerTest(t, h2, handler, func(srv *httptest.Server) { + cst := newClientServerTest(t, mode, handler, func(srv *httptest.Server) { srv.Config.RegisterOnShutdown(func() { gotOnShutdown <- struct{}{} }) }) - defer cst.close() doShutdown = func() { shutdownRes <- cst.ts.Config.Shutdown(context.Background()) @@ -5678,24 +5565,22 @@ func testServerShutdown(t *testing.T, h2 bool) { } } -func TestServerShutdownStateNew(t *testing.T) { +func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) } +func testServerShutdownStateNew(t *testing.T, mode testMode) { if testing.Short() { t.Skip("test takes 5-6 seconds; skipping in short mode") } - setParallel(t) - defer afterTest(t) - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - // nothing. - })) var connAccepted sync.WaitGroup - ts.Config.ConnState = func(conn net.Conn, state ConnState) { - if state == StateNew { - connAccepted.Done() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + // nothing. + }), func(ts *httptest.Server) { + ts.Config.ConnState = func(conn net.Conn, state ConnState) { + if state == StateNew { + connAccepted.Done() + } } - } - ts.Start() - defer ts.Close() + }).ts // Start a connection but never write to it. connAccepted.Add(1) @@ -5757,16 +5642,14 @@ func TestServerCloseDeadlock(t *testing.T) { // Issue 17717: tests that Server.SetKeepAlivesEnabled is respected by // both HTTP/1 and HTTP/2. -func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) } -func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) } -func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { - if h2 { +func TestServerKeepAlivesEnabled(t *testing.T) { run(t, testServerKeepAlivesEnabled, testNotParallel) } +func testServerKeepAlivesEnabled(t *testing.T, mode testMode) { + if mode == http2Mode { restore := ExportSetH2GoawayTimeout(10 * time.Millisecond) defer restore() } // Not parallel: messes with global variable. (http2goAwayTimeout) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {})) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})) defer cst.close() srv := cst.ts.Config srv.SetKeepAlivesEnabled(false) @@ -5803,9 +5686,8 @@ func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { // Issue 18447: test that the Server's ReadTimeout is stopped while // the server's doing its 1-byte background read between requests, // waiting for the connection to maybe close. -func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { run(t, testServerCancelsReadTimeoutWhenIdle) } +func testServerCancelsReadTimeoutWhenIdle(t *testing.T, mode testMode) { runTimeSensitiveTest(t, []time.Duration{ 10 * time.Millisecond, 50 * time.Millisecond, @@ -5813,17 +5695,16 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { time.Second, 2 * time.Second, }, func(t *testing.T, timeout time.Duration) error { - ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { select { case <-time.After(2 * timeout): fmt.Fprint(w, "ok") case <-r.Context().Done(): fmt.Fprint(w, r.Context().Err()) } - })) - ts.Config.ReadTimeout = timeout - ts.Start() - defer ts.Close() + }), func(ts *httptest.Server) { + ts.Config.ReadTimeout = timeout + }).ts c := ts.Client() @@ -5847,8 +5728,9 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { // beginning of a request has been received, rather than including time the // connection spent idle. func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testServerCancelsReadHeaderTimeoutWhenIdle, []testMode{http1Mode}) +} +func testServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T, mode testMode) { runTimeSensitiveTest(t, []time.Duration{ 10 * time.Millisecond, 50 * time.Millisecond, @@ -5856,11 +5738,10 @@ func TestServerCancelsReadHeaderTimeoutWhenIdle(t *testing.T) { time.Second, 2 * time.Second, }, func(t *testing.T, timeout time.Duration) error { - ts := httptest.NewUnstartedServer(serve(200)) - ts.Config.ReadHeaderTimeout = timeout - ts.Config.IdleTimeout = 0 // disable idle timeout - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, serve(200), func(ts *httptest.Server) { + ts.Config.ReadHeaderTimeout = timeout + ts.Config.IdleTimeout = 0 // disable idle timeout + }).ts // rather than using an http.Client, create a single connection, so that // we can ensure this connection is not closed. @@ -5912,13 +5793,13 @@ func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t * // Issue 18535: test that the Server doesn't try to do a background // read if it's already done one. func TestServerDuplicateBackgroundRead(t *testing.T) { + run(t, testServerDuplicateBackgroundRead, []testMode{http1Mode}) +} +func testServerDuplicateBackgroundRead(t *testing.T, mode testMode) { if runtime.GOOS == "netbsd" && runtime.GOARCH == "arm" { testenv.SkipFlaky(t, 24826) } - setParallel(t) - defer afterTest(t) - goroutines := 5 requests := 2000 if testing.Short() { @@ -5926,8 +5807,7 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { requests = 100 } - hts := httptest.NewServer(HandlerFunc(NotFound)) - defer hts.Close() + hts := newClientServerTest(t, mode, HandlerFunc(NotFound)).ts reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n") @@ -5970,14 +5850,15 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { // bufio.Reader.Buffered(), without resorting to Reading it // (potentially blocking) to get at it. func TestServerHijackGetsBackgroundByte(t *testing.T) { + run(t, testServerHijackGetsBackgroundByte, []testMode{http1Mode}) +} +func testServerHijackGetsBackgroundByte(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/18657") } - setParallel(t) - defer afterTest(t) done := make(chan struct{}) inHandler := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) // Tell the client to send more data after the GET request. @@ -6000,8 +5881,7 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) { t.Error("context unexpectedly canceled") default: } - })) - defer ts.Close() + })).ts cn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -6030,14 +5910,15 @@ func TestServerHijackGetsBackgroundByte(t *testing.T) { // immediate 1MB of data to the server to fill up the server's 4KB // buffer. func TestServerHijackGetsBackgroundByte_big(t *testing.T) { + run(t, testServerHijackGetsBackgroundByte_big, []testMode{http1Mode}) +} +func testServerHijackGetsBackgroundByte_big(t *testing.T, mode testMode) { if runtime.GOOS == "plan9" { t.Skip("skipping test; see https://golang.org/issue/18657") } - setParallel(t) - defer afterTest(t) done := make(chan struct{}) const size = 8 << 10 - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { defer close(done) conn, buf, err := w.(Hijacker).Hijack() @@ -6061,8 +5942,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) { } else if !allX { t.Errorf("read %q; want %d 'x'", slurp, size) } - })) - defer ts.Close() + })).ts cn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { @@ -6198,73 +6078,27 @@ func TestStripPortFromHost(t *testing.T) { } } -func TestServerContexts(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestServerContexts(t *testing.T) { run(t, testServerContexts) } +func testServerContexts(t *testing.T, mode testMode) { type baseKey struct{} type connKey struct{} ch := make(chan context.Context, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { ch <- r.Context() - })) - ts.Config.BaseContext = func(ln net.Listener) context.Context { - if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { - t.Errorf("unexpected onceClose listener type %T", ln) + }), func(ts *httptest.Server) { + ts.Config.BaseContext = func(ln net.Listener) context.Context { + if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { + t.Errorf("unexpected onceClose listener type %T", ln) + } + return context.WithValue(context.Background(), baseKey{}, "base") } - return context.WithValue(context.Background(), baseKey{}, "base") - } - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got, want := ctx.Value(baseKey{}), "base"; got != want { + t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) + } + return context.WithValue(ctx, connKey{}, "conn") } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.Start() - defer ts.Close() - res, err := ts.Client().Get(ts.URL) - if err != nil { - t.Fatal(err) - } - res.Body.Close() - ctx := <-ch - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("base context key = %#v; want %q", got, want) - } - if got, want := ctx.Value(connKey{}), "conn"; got != want { - t.Errorf("conn context key = %#v; want %q", got, want) - } -} - -func TestServerContextsHTTP2(t *testing.T) { - setParallel(t) - defer afterTest(t) - type baseKey struct{} - type connKey struct{} - ch := make(chan context.Context, 1) - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { - if r.ProtoMajor != 2 { - t.Errorf("unexpected HTTP/1.x request") - } - ch <- r.Context() - })) - ts.Config.BaseContext = func(ln net.Listener) context.Context { - if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") { - t.Errorf("unexpected onceClose listener type %T", ln) - } - return context.WithValue(context.Background(), baseKey{}, "base") - } - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got, want := ctx.Value(baseKey{}), "base"; got != want { - t.Errorf("in ConnContext, base context key = %#v; want %q", got, want) - } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.TLS = &tls.Config{ - NextProtos: []string{"h2", "http/1.1"}, - } - ts.StartTLS() - defer ts.Close() - ts.Client().Transport.(*Transport).ForceAttemptHTTP2 = true + }).ts res, err := ts.Client().Get(ts.URL) if err != nil { t.Fatal(err) @@ -6281,20 +6115,20 @@ func TestServerContextsHTTP2(t *testing.T) { // Issue 35750: check ConnContext not modifying context for other connections func TestConnContextNotModifyingAllContexts(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testConnContextNotModifyingAllContexts) +} +func testConnContextNotModifyingAllContexts(t *testing.T, mode testMode) { type connKey struct{} - ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { rw.Header().Set("Connection", "close") - })) - ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { - if got := ctx.Value(connKey{}); got != nil { - t.Errorf("in ConnContext, unexpected context key = %#v", got) + }), func(ts *httptest.Server) { + ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context { + if got := ctx.Value(connKey{}); got != nil { + t.Errorf("in ConnContext, unexpected context key = %#v", got) + } + return context.WithValue(ctx, connKey{}, "conn") } - return context.WithValue(ctx, connKey{}, "conn") - } - ts.Start() - defer ts.Close() + }).ts var res *Response var err error @@ -6315,10 +6149,12 @@ func TestConnContextNotModifyingAllContexts(t *testing.T) { // Issue 30710: ensure that as per the spec, a server responds // with 501 Not Implemented for unsupported transfer-encodings. func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { - cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testUnsupportedTransferEncodingsReturn501, []testMode{http1Mode}) +} +func testUnsupportedTransferEncodingsReturn501(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello, World!")) - })) - defer cst.Close() + })).ts serverURL, err := url.Parse(cst.URL) if err != nil { @@ -6353,19 +6189,9 @@ func TestUnsupportedTransferEncodingsReturn501(t *testing.T) { } } -func TestContentEncodingNoSniffing_h1(t *testing.T) { - testContentEncodingNoSniffing(t, h1Mode) -} - -func TestContentEncodingNoSniffing_h2(t *testing.T) { - testContentEncodingNoSniffing(t, h2Mode) -} - // Issue 31753: don't sniff when Content-Encoding is set -func testContentEncodingNoSniffing(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - +func TestContentEncodingNoSniffing(t *testing.T) { run(t, testContentEncodingNoSniffing) } +func testContentEncodingNoSniffing(t *testing.T, mode testMode) { type setting struct { name string body []byte @@ -6428,13 +6254,12 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { for _, tt := range settings { t.Run(tt.name, func(t *testing.T) { - cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, r *Request) { if tt.contentEncoding != nil { rw.Header().Set("Content-Encoding", tt.contentEncoding.(string)) } rw.Write(tt.body) })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -6460,13 +6285,13 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) { // Issue 30803: ensure that TimeoutHandler logs spurious // WriteHeader calls, for consistency with other Handlers. func TestTimeoutHandlerSuperfluousLogs(t *testing.T) { + run(t, testTimeoutHandlerSuperfluousLogs, []testMode{http1Mode}) +} +func testTimeoutHandlerSuperfluousLogs(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - pc, curFile, _, _ := runtime.Caller(0) curFileBaseName := filepath.Base(curFile) testFuncName := runtime.FuncForPC(pc).Name() @@ -6520,7 +6345,7 @@ func TestTimeoutHandlerSuperfluousLogs(t *testing.T) { dur = 10 * time.Second } th := TimeoutHandler(sh, dur, timeoutMsg) - cst := newClientServerTest(t, h1Mode /* the test is protocol-agnostic */, th, optWithServerLog(srvLog)) + cst := newClientServerTest(t, mode, th, optWithServerLog(srvLog)) defer cst.close() res, err := cst.c.Get(cst.ts.URL) @@ -6590,15 +6415,16 @@ func BenchmarkResponseStatusLine(b *testing.B) { } }) } + func TestDisableKeepAliveUpgrade(t *testing.T) { + run(t, testDisableKeepAliveUpgrade, []testMode{http1Mode}) +} +func testDisableKeepAliveUpgrade(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - setParallel(t) - defer afterTest(t) - - s := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + s := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "Upgrade") w.Header().Set("Upgrade", "someProto") w.WriteHeader(StatusSwitchingProtocols) @@ -6611,10 +6437,9 @@ func TestDisableKeepAliveUpgrade(t *testing.T) { // Copy from the *bufio.ReadWriter, which may contain buffered data. // Copy to the net.Conn, to avoid buffering the output. io.Copy(c, buf) - })) - s.Config.SetKeepAlivesEnabled(false) - s.Start() - defer s.Close() + }), func(ts *httptest.Server) { + ts.Config.SetKeepAlivesEnabled(false) + }).ts cl := s.Client() cl.Transport.(*Transport).DisableKeepAlives = true @@ -6683,21 +6508,21 @@ func TestQuerySemicolon(t *testing.T) { {"?a=1;x=good;x=bad", "", "good", true}, } - for _, tt := range tests { - t.Run(tt.query+"/allow=false", func(t *testing.T) { - allowSemicolons := false - testQuerySemicolon(t, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning) - }) - t.Run(tt.query+"/allow=true", func(t *testing.T) { - allowSemicolons, expectWarning := true, false - testQuerySemicolon(t, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning) - }) - } + run(t, func(t *testing.T, mode testMode) { + for _, tt := range tests { + t.Run(tt.query+"/allow=false", func(t *testing.T) { + allowSemicolons := false + testQuerySemicolon(t, mode, tt.query, tt.xNoSemicolons, allowSemicolons, tt.warning) + }) + t.Run(tt.query+"/allow=true", func(t *testing.T) { + allowSemicolons, expectWarning := true, false + testQuerySemicolon(t, mode, tt.query, tt.xWithSemicolons, allowSemicolons, expectWarning) + }) + } + }) } -func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolons, expectWarning bool) { - setParallel(t) - +func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string, allowSemicolons, expectWarning bool) { writeBackX := func(w ResponseWriter, r *Request) { x := r.URL.Query().Get("x") if expectWarning { @@ -6720,11 +6545,10 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon h = AllowQuerySemicolons(h) } - ts := httptest.NewUnstartedServer(h) logBuf := &strings.Builder{} - ts.Config.ErrorLog = log.New(logBuf, "", 0) - ts.Start() - defer ts.Close() + ts := newClientServerTest(t, mode, h, func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(logBuf, "", 0) + }).ts req, _ := NewRequest("GET", ts.URL+query, nil) res, err := ts.Client().Do(req) @@ -6759,13 +6583,15 @@ func TestMaxBytesHandler(t *testing.T) { for _, requestSize := range []int64{100, 1_000, 1_000_000} { t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize), func(t *testing.T) { - testMaxBytesHandler(t, maxSize, requestSize) + run(t, func(t *testing.T, mode testMode) { + testMaxBytesHandler(t, mode, maxSize, requestSize) + }) }) } } } -func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { +func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) { var ( handlerN int64 handlerErr error @@ -6776,7 +6602,7 @@ func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) { io.Copy(w, &buf) }) - ts := httptest.NewServer(MaxBytesHandler(echo, maxSize)) + ts := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)).ts defer ts.Close() c := ts.Client() @@ -6843,13 +6669,12 @@ func TestProcessing(t *testing.T) { } } -func TestParseFormCleanup_h1(t *testing.T) { testParseFormCleanup(t, h1Mode) } -func TestParseFormCleanup_h2(t *testing.T) { - t.Skip("https://go.dev/issue/20253") - testParseFormCleanup(t, h2Mode) -} +func TestParseFormCleanup(t *testing.T) { run(t, testParseFormCleanup) } +func testParseFormCleanup(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/20253") + } -func testParseFormCleanup(t *testing.T, h2 bool) { const maxMemory = 1024 const key = "file" @@ -6858,9 +6683,7 @@ func testParseFormCleanup(t *testing.T, h2 bool) { t.Skip("https://go.dev/issue/25965") } - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { r.ParseMultipartForm(maxMemory) f, _, err := r.FormFile(key) if err != nil { @@ -6874,7 +6697,6 @@ func testParseFormCleanup(t *testing.T, h2 bool) { } w.Write([]byte(of.Name())) })) - defer cst.close() fBuf := new(bytes.Buffer) mw := multipart.NewWriter(fBuf) @@ -6911,33 +6733,23 @@ func testParseFormCleanup(t *testing.T, h2 bool) { func TestHeadBody(t *testing.T) { const identityMode = false const chunkedMode = true - t.Run("h1", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "HEAD") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "HEAD") }) - }) - t.Run("h2", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "HEAD") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "HEAD") }) + run(t, func(t *testing.T, mode testMode) { + t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "HEAD") }) + t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "HEAD") }) }) } func TestGetBody(t *testing.T) { const identityMode = false const chunkedMode = true - t.Run("h1", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h1Mode, identityMode, "GET") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h1Mode, chunkedMode, "GET") }) - }) - t.Run("h2", func(t *testing.T) { - t.Run("identity", func(t *testing.T) { testHeadBody(t, h2Mode, identityMode, "GET") }) - t.Run("chunked", func(t *testing.T) { testHeadBody(t, h2Mode, chunkedMode, "GET") }) + run(t, func(t *testing.T, mode testMode) { + t.Run("identity", func(t *testing.T) { testHeadBody(t, mode, identityMode, "GET") }) + t.Run("chunked", func(t *testing.T) { testHeadBody(t, mode, chunkedMode, "GET") }) }) } -func testHeadBody(t *testing.T, h2, chunked bool, method string) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func testHeadBody(t *testing.T, mode testMode, chunked bool, method string) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { b, err := io.ReadAll(r.Body) if err != nil { t.Errorf("server reading body: %v", err) diff --git a/src/net/http/sniff_test.go b/src/net/http/sniff_test.go index e91335729af..d6ef40905e6 100644 --- a/src/net/http/sniff_test.go +++ b/src/net/http/sniff_test.go @@ -88,13 +88,9 @@ func TestDetectContentType(t *testing.T) { } } -func TestServerContentType_h1(t *testing.T) { testServerContentType(t, h1Mode) } -func TestServerContentType_h2(t *testing.T) { testServerContentType(t, h2Mode) } - -func testServerContentType(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerContentTypeSniff(t *testing.T) { run(t, testServerContentTypeSniff) } +func testServerContentTypeSniff(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { i, _ := strconv.Atoi(r.FormValue("i")) tt := sniffTests[i] n, err := w.Write(tt.data) @@ -134,15 +130,12 @@ func testServerContentType(t *testing.T, h2 bool) { // Issue 5953: shouldn't sniff if the handler set a Content-Type header, // even if it's the empty string. -func TestServerIssue5953_h1(t *testing.T) { testServerIssue5953(t, h1Mode) } -func TestServerIssue5953_h2(t *testing.T) { testServerIssue5953(t, h2Mode) } -func testServerIssue5953(t *testing.T, h2 bool) { - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestServerIssue5953(t *testing.T) { run(t, testServerIssue5953) } +func testServerIssue5953(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()["Content-Type"] = []string{""} fmt.Fprintf(w, "hi") })) - defer cst.close() resp, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -173,11 +166,8 @@ func (b *byteAtATimeReader) Read(p []byte) (n int, err error) { return 1, nil } -func TestContentTypeWithVariousSources_h1(t *testing.T) { testContentTypeWithVariousSources(t, h1Mode) } -func TestContentTypeWithVariousSources_h2(t *testing.T) { testContentTypeWithVariousSources(t, h2Mode) } -func testContentTypeWithVariousSources(t *testing.T, h2 bool) { - defer afterTest(t) - +func TestContentTypeWithVariousSources(t *testing.T) { run(t, testContentTypeWithVariousSources) } +func testContentTypeWithVariousSources(t *testing.T, mode testMode) { const ( input = "\n\n\t\n" expected = "text/html; charset=utf-8" @@ -239,8 +229,7 @@ func testContentTypeWithVariousSources(t *testing.T, h2 bool) { }, }} { t.Run(test.name, func(t *testing.T) { - cst := newClientServerTest(t, h2, HandlerFunc(test.handler)) - defer cst.close() + cst := newClientServerTest(t, mode, HandlerFunc(test.handler)) resp, err := cst.c.Get(cst.ts.URL) if err != nil { @@ -265,12 +254,9 @@ func testContentTypeWithVariousSources(t *testing.T, h2 bool) { } } -func TestSniffWriteSize_h1(t *testing.T) { testSniffWriteSize(t, h1Mode) } -func TestSniffWriteSize_h2(t *testing.T) { testSniffWriteSize(t, h2Mode) } -func testSniffWriteSize(t *testing.T, h2 bool) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestSniffWriteSize(t *testing.T) { run(t, testSniffWriteSize) } +func testSniffWriteSize(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { size, _ := strconv.Atoi(r.FormValue("size")) written, err := io.WriteString(w, strings.Repeat("a", size)) if err != nil { @@ -281,7 +267,6 @@ func testSniffWriteSize(t *testing.T, h2 bool) { t.Errorf("write of %d bytes wrote %d bytes", size, written) } })) - defer cst.close() for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} { res, err := cst.c.Get(fmt.Sprintf("%s/?size=%d", cst.ts.URL, size)) if err != nil { diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 26293befb4d..8748cf6f7b6 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -135,12 +135,11 @@ func (tcs *testConnSet) check(t *testing.T) { } } -func TestReuseRequest(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestReuseRequest(t *testing.T) { run(t, testReuseRequest) } +func testReuseRequest(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("{}")) - })) - defer ts.Close() + })).ts c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) @@ -165,10 +164,9 @@ func TestReuseRequest(t *testing.T) { // Two subsequent requests and verify their response is the same. // The response from the server is our own IP:port -func TestTransportKeepAlives(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() +func TestTransportKeepAlives(t *testing.T) { run(t, testTransportKeepAlives, []testMode{http1Mode}) } +func testTransportKeepAlives(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() for _, disableKeepAlive := range []bool{false, true} { @@ -197,9 +195,10 @@ func TestTransportKeepAlives(t *testing.T) { } func TestTransportConnectionCloseOnResponse(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportConnectionCloseOnResponse) +} +func testTransportConnectionCloseOnResponse(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts connSet, testDial := makeTestDial(t) @@ -253,9 +252,10 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { // describes the source source connection it got (remote port number + // address of its net.Conn). func TestTransportConnectionCloseOnRequest(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportConnectionCloseOnRequest, []testMode{http1Mode}) +} +func testTransportConnectionCloseOnRequest(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts connSet, testDial := makeTestDial(t) @@ -317,9 +317,10 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { // send Connection: close. // HTTP/1-only (Connection: close doesn't exist in h2) func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportConnectionCloseOnRequestDisableKeepAlive, []testMode{http1Mode}) +} +func testTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() c.Transport.(*Transport).DisableKeepAlives = true @@ -337,6 +338,9 @@ func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) { // Test that Transport only sends one "Connection: close", regardless of // how "close" was indicated. func TestTransportRespectRequestWantsClose(t *testing.T) { + run(t, testTransportRespectRequestWantsClose, []testMode{http1Mode}) +} +func testTransportRespectRequestWantsClose(t *testing.T, mode testMode) { tests := []struct { disableKeepAlives bool close bool @@ -350,9 +354,7 @@ func TestTransportRespectRequestWantsClose(t *testing.T) { for _, tc := range tests { t.Run(fmt.Sprintf("DisableKeepAlive=%v,RequestClose=%v", tc.disableKeepAlives, tc.close), func(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() c.Transport.(*Transport).DisableKeepAlives = tc.disableKeepAlives @@ -387,9 +389,10 @@ func TestTransportRespectRequestWantsClose(t *testing.T) { } func TestTransportIdleCacheKeys(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportIdleCacheKeys, []testMode{http1Mode}) +} +func testTransportIdleCacheKeys(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -420,12 +423,12 @@ func TestTransportIdleCacheKeys(t *testing.T) { // Tests that the HTTP transport re-uses connections when a client // reads to the end of a response Body without closing it. -func TestTransportReadToEndReusesConn(t *testing.T) { - defer afterTest(t) +func TestTransportReadToEndReusesConn(t *testing.T) { run(t, testTransportReadToEndReusesConn) } +func testTransportReadToEndReusesConn(t *testing.T, mode testMode) { const msg = "foobar" var addrSeen map[string]int - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addrSeen[r.RemoteAddr]++ if r.URL.Path == "/chunked/" { w.WriteHeader(200) @@ -435,16 +438,13 @@ func TestTransportReadToEndReusesConn(t *testing.T) { w.WriteHeader(200) } w.Write([]byte(msg)) - })) - defer ts.Close() - - buf := make([]byte, len(msg)) + })).ts for pi, path := range []string{"/content-length/", "/chunked/"} { wantLen := []int{len(msg), -1}[pi] addrSeen = make(map[string]int) for i := 0; i < 3; i++ { - res, err := Get(ts.URL + path) + res, err := ts.Client().Get(ts.URL + path) if err != nil { t.Errorf("Get %s: %v", path, err) continue @@ -459,9 +459,9 @@ func TestTransportReadToEndReusesConn(t *testing.T) { if res.ContentLength != int64(wantLen) { t.Errorf("%s res.ContentLength = %d; want %d", path, res.ContentLength, wantLen) } - n, err := res.Body.Read(buf) - if n != len(msg) || err != io.EOF { - t.Errorf("%s Read = %v, %v; want %d, EOF", path, n, err, len(msg)) + got, err := io.ReadAll(res.Body) + if string(got) != msg || err != nil { + t.Errorf("%s ReadAll(Body) = %q, %v; want %q, nil", path, string(got), err, msg) } } if len(addrSeen) != 1 { @@ -471,13 +471,15 @@ func TestTransportReadToEndReusesConn(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { - defer afterTest(t) + run(t, testTransportMaxPerHostIdleConns, []testMode{http1Mode}) +} +func testTransportMaxPerHostIdleConns(t *testing.T, mode testMode) { stop := make(chan struct{}) // stop marks the exit of main Test goroutine defer close(stop) resch := make(chan string) gotReq := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotReq <- true var msg string select { @@ -490,8 +492,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { t.Errorf("Write: %v", err) return } - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -559,14 +560,15 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportMaxConnsPerHostIncludeDialInProgress) +} +func testTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("foo")) if err != nil { t.Fatalf("Write: %v", err) } - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) dialStarted := make(chan struct{}) @@ -626,7 +628,9 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { } func TestTransportMaxConnsPerHost(t *testing.T) { - defer afterTest(t) + run(t, testTransportMaxConnsPerHost, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testTransportMaxConnsPerHost(t *testing.T, mode testMode) { CondSkipHTTP2(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -636,115 +640,101 @@ func TestTransportMaxConnsPerHost(t *testing.T) { } }) - testMaxConns := func(scheme string, ts *httptest.Server) { - defer ts.Close() - - c := ts.Client() - tr := c.Transport.(*Transport) - tr.MaxConnsPerHost = 1 - if err := ExportHttp2ConfigureTransport(tr); err != nil { - t.Fatalf("ExportHttp2ConfigureTransport: %v", err) - } - - mu := sync.Mutex{} - var conns []net.Conn - var dialCnt, gotConnCnt, tlsHandshakeCnt int32 - tr.Dial = func(network, addr string) (net.Conn, error) { - atomic.AddInt32(&dialCnt, 1) - c, err := net.Dial(network, addr) - mu.Lock() - defer mu.Unlock() - conns = append(conns, c) - return c, err - } - - doReq := func() { - trace := &httptrace.ClientTrace{ - GotConn: func(connInfo httptrace.GotConnInfo) { - if !connInfo.Reused { - atomic.AddInt32(&gotConnCnt, 1) - } - }, - TLSHandshakeStart: func() { - atomic.AddInt32(&tlsHandshakeCnt, 1) - }, - } - req, _ := NewRequest("GET", ts.URL, nil) - req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) - - resp, err := c.Do(req) - if err != nil { - t.Fatalf("request failed: %v", err) - } - defer resp.Body.Close() - _, err = io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read body failed: %v", err) - } - } - - wg := sync.WaitGroup{} - for i := 0; i < 10; i++ { - wg.Add(1) - go func() { - defer wg.Done() - doReq() - }() - } - wg.Wait() - - expected := int32(tr.MaxConnsPerHost) - if dialCnt != expected { - t.Errorf("round 1: too many dials (%s): %d != %d", scheme, dialCnt, expected) - } - if gotConnCnt != expected { - t.Errorf("round 1: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) - } - if ts.TLS != nil && tlsHandshakeCnt != expected { - t.Errorf("round 1: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) - } - - if t.Failed() { - t.FailNow() - } + ts := newClientServerTest(t, mode, h).ts + c := ts.Client() + tr := c.Transport.(*Transport) + tr.MaxConnsPerHost = 1 + mu := sync.Mutex{} + var conns []net.Conn + var dialCnt, gotConnCnt, tlsHandshakeCnt int32 + tr.Dial = func(network, addr string) (net.Conn, error) { + atomic.AddInt32(&dialCnt, 1) + c, err := net.Dial(network, addr) mu.Lock() - for _, c := range conns { - c.Close() - } - conns = nil - mu.Unlock() - tr.CloseIdleConnections() + defer mu.Unlock() + conns = append(conns, c) + return c, err + } - doReq() - expected++ - if dialCnt != expected { - t.Errorf("round 2: too many dials (%s): %d", scheme, dialCnt) + doReq := func() { + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + if !connInfo.Reused { + atomic.AddInt32(&gotConnCnt, 1) + } + }, + TLSHandshakeStart: func() { + atomic.AddInt32(&tlsHandshakeCnt, 1) + }, } - if gotConnCnt != expected { - t.Errorf("round 2: too many get connections (%s): %d != %d", scheme, gotConnCnt, expected) + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + resp, err := c.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) } - if ts.TLS != nil && tlsHandshakeCnt != expected { - t.Errorf("round 2: too many tls handshakes (%s): %d != %d", scheme, tlsHandshakeCnt, expected) + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read body failed: %v", err) } } - testMaxConns("http", httptest.NewServer(h)) - testMaxConns("https", httptest.NewTLSServer(h)) + wg := sync.WaitGroup{} + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + doReq() + }() + } + wg.Wait() - ts := httptest.NewUnstartedServer(h) - ts.TLS = &tls.Config{NextProtos: []string{"h2"}} - ts.StartTLS() - testMaxConns("http2", ts) + expected := int32(tr.MaxConnsPerHost) + if dialCnt != expected { + t.Errorf("round 1: too many dials: %d != %d", dialCnt, expected) + } + if gotConnCnt != expected { + t.Errorf("round 1: too many get connections: %d != %d", gotConnCnt, expected) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("round 1: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected) + } + + if t.Failed() { + t.FailNow() + } + + mu.Lock() + for _, c := range conns { + c.Close() + } + conns = nil + mu.Unlock() + tr.CloseIdleConnections() + + doReq() + expected++ + if dialCnt != expected { + t.Errorf("round 2: too many dials: %d", dialCnt) + } + if gotConnCnt != expected { + t.Errorf("round 2: too many get connections: %d != %d", gotConnCnt, expected) + } + if ts.TLS != nil && tlsHandshakeCnt != expected { + t.Errorf("round 2: too many tls handshakes: %d != %d", tlsHandshakeCnt, expected) + } } func TestTransportRemovesDeadIdleConnections(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportRemovesDeadIdleConnections, []testMode{http1Mode}) +} +func testTransportRemovesDeadIdleConnections(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -789,10 +779,10 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) { // Test that the Transport notices when a server hangs up on its // unexpectedly (a keep-alive connection is closed). func TestTransportServerClosingUnexpectedly(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(hostPortHandler) - defer ts.Close() + run(t, testTransportServerClosingUnexpectedly, []testMode{http1Mode}) +} +func testTransportServerClosingUnexpectedly(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, hostPortHandler).ts c := ts.Client() fetch := func(n, retries int) string { @@ -846,11 +836,13 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { // Test for https://golang.org/issue/2616 (appropriate issue number) // This fails pretty reliably with GOMAXPROCS=100 or something high. func TestStressSurpriseServerCloses(t *testing.T) { - defer afterTest(t) + run(t, testStressSurpriseServerCloses, []testMode{http1Mode}) +} +func testStressSurpriseServerCloses(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in short mode") } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "5") w.Header().Set("Content-Type", "text/plain") w.Write([]byte("Hello")) @@ -858,8 +850,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { conn, buf, _ := w.(Hijacker).Hijack() buf.Flush() conn.Close() - })) - defer ts.Close() + })).ts c := ts.Client() // Do a bunch of traffic from different goroutines. Send to activityc @@ -906,16 +897,15 @@ func TestStressSurpriseServerCloses(t *testing.T) { // TestTransportHeadResponses verifies that we deal with Content-Lengths // with no bodies properly -func TestTransportHeadResponses(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportHeadResponses(t *testing.T) { run(t, testTransportHeadResponses) } +func testTransportHeadResponses(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) } w.Header().Set("Content-Length", "123") w.WriteHeader(200) - })) - defer ts.Close() + })).ts c := ts.Client() for i := 0; i < 2; i++ { @@ -941,16 +931,17 @@ func TestTransportHeadResponses(t *testing.T) { // TestTransportHeadChunkedResponse verifies that we ignore chunked transfer-encoding // on responses to HEAD requests. func TestTransportHeadChunkedResponse(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportHeadChunkedResponse, []testMode{http1Mode}, testNotParallel) +} +func testTransportHeadChunkedResponse(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "HEAD" { panic("expected HEAD; got " + r.Method) } w.Header().Set("Transfer-Encoding", "chunked") // client should ignore w.Header().Set("x-client-ipport", r.RemoteAddr) w.WriteHeader(200) - })) - defer ts.Close() + })).ts c := ts.Client() // Ensure that we wait for the readLoop to complete before @@ -991,11 +982,10 @@ var roundTripTests = []struct { } // Test that the modification made to the Request by the RoundTripper is cleaned up -func TestRoundTripGzip(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestRoundTripGzip(t *testing.T) { run(t, testRoundTripGzip) } +func testRoundTripGzip(t *testing.T, mode testMode) { const responseBody = "test response body" - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") if expect := req.FormValue("expect_accept"); accept != expect { t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", @@ -1010,8 +1000,7 @@ func TestRoundTripGzip(t *testing.T) { rw.Header().Set("Content-Encoding", accept) rw.Write([]byte(responseBody)) } - })) - defer ts.Close() + })).ts tr := ts.Client().Transport.(*Transport) for i, test := range roundTripTests { @@ -1055,12 +1044,14 @@ func TestRoundTripGzip(t *testing.T) { } -func TestTransportGzip(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportGzip(t *testing.T) { run(t, testTransportGzip) } +func testTransportGzip(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56020") + } const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { if req.Method == "HEAD" { if g := req.Header.Get("Accept-Encoding"); g != "" { t.Errorf("HEAD request sent with Accept-Encoding of %q; want none", g) @@ -1087,8 +1078,7 @@ func TestTransportGzip(t *testing.T) { io.CopyN(gz, rand.Reader, nRandBytes) } gz.Close() - })) - defer ts.Close() + })).ts c := ts.Client() for _, chunked := range []string{"1", "0"} { @@ -1153,10 +1143,10 @@ func TestTransportGzip(t *testing.T) { // If a request has Expect:100-continue header, the request blocks sending body until the first response. // Premature consumption of the request body should not be occurred. func TestTransportExpect100Continue(t *testing.T) { - setParallel(t) - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + run(t, testTransportExpect100Continue, []testMode{http1Mode}) +} +func testTransportExpect100Continue(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { switch req.URL.Path { case "/100": // This endpoint implicitly responds 100 Continue and reads body. @@ -1194,8 +1184,7 @@ func TestTransportExpect100Continue(t *testing.T) { conn.Close() } - })) - defer ts.Close() + })).ts tests := []struct { path string @@ -1242,7 +1231,9 @@ func TestTransportExpect100Continue(t *testing.T) { } func TestSOCKS5Proxy(t *testing.T) { - defer afterTest(t) + run(t, testSOCKS5Proxy, []testMode{http1Mode, https1Mode, http2Mode}) +} +func testSOCKS5Proxy(t *testing.T, mode testMode) { ch := make(chan string, 1) l := newLocalListener(t) defer l.Close() @@ -1322,12 +1313,7 @@ func TestSOCKS5Proxy(t *testing.T) { }) for _, useTLS := range []bool{false, true} { t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) { - var ts *httptest.Server - if useTLS { - ts = httptest.NewTLSServer(h) - } else { - ts = httptest.NewServer(h) - } + ts := newClientServerTest(t, mode, h).ts go proxy(t) c := ts.Client() c.Transport.(*Transport).Proxy = ProxyURL(pu) @@ -1359,16 +1345,16 @@ func TestSOCKS5Proxy(t *testing.T) { func TestTransportProxy(t *testing.T) { defer afterTest(t) - testCases := []struct{ httpsSite, httpsProxy bool }{ - {false, false}, - {false, true}, - {true, false}, - {true, true}, + testCases := []struct{ siteMode, proxyMode testMode }{ + {http1Mode, http1Mode}, + {http1Mode, https1Mode}, + {https1Mode, http1Mode}, + {https1Mode, https1Mode}, } for _, testCase := range testCases { - httpsSite := testCase.httpsSite - httpsProxy := testCase.httpsProxy - t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) { + siteMode := testCase.siteMode + proxyMode := testCase.proxyMode + t.Run(fmt.Sprintf("site=%v/proxy=%v", siteMode, proxyMode), func(t *testing.T) { siteCh := make(chan *Request, 1) h1 := HandlerFunc(func(w ResponseWriter, r *Request) { siteCh <- r @@ -1414,18 +1400,8 @@ func TestTransportProxy(t *testing.T) { }() } }) - var ts *httptest.Server - if httpsSite { - ts = httptest.NewTLSServer(h1) - } else { - ts = httptest.NewServer(h1) - } - var proxy *httptest.Server - if httpsProxy { - proxy = httptest.NewTLSServer(h2) - } else { - proxy = httptest.NewServer(h2) - } + ts := newClientServerTest(t, siteMode, h1).ts + proxy := newClientServerTest(t, proxyMode, h2).ts pu, err := url.Parse(proxy.URL) if err != nil { @@ -1436,7 +1412,7 @@ func TestTransportProxy(t *testing.T) { // If only one server is HTTPS, c must be derived from that server in order // to ensure that it is configured to use the fake root CA from testcert.go. c := proxy.Client() - if httpsSite { + if siteMode == https1Mode { c = ts.Client() } @@ -1453,7 +1429,7 @@ func TestTransportProxy(t *testing.T) { c.Transport.(*Transport).CloseIdleConnections() ts.Close() proxy.Close() - if httpsSite { + if siteMode == https1Mode { // First message should be a CONNECT, asking for a socket to the real server, if got.Method != "CONNECT" { t.Errorf("Wrong method for secure proxying: %q", got.Method) @@ -1602,10 +1578,10 @@ func TestTransportDialPreservesNetOpProxyError(t *testing.T) { // (A bug caused dialConn to instead write the per-request Proxy-Authorization // header through to the shared Header instance, introducing a data race.) func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { - setParallel(t) - defer afterTest(t) - - proxy := httptest.NewTLSServer(NotFoundHandler()) + run(t, testTransportProxyDialDoesNotMutateProxyConnectHeader) +} +func testTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T, mode testMode) { + proxy := newClientServerTest(t, mode, NotFoundHandler()).ts defer proxy.Close() c := proxy.Client() @@ -1639,13 +1615,12 @@ func TestTransportProxyDialDoesNotMutateProxyConnectHeader(t *testing.T) { // client gets the same value back. This is more cute than anything, // but checks that we don't recurse forever, and checks that // Content-Encoding is removed. -func TestTransportGzipRecursive(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportGzipRecursive(t *testing.T) { run(t, testTransportGzipRecursive) } +func testTransportGzipRecursive(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write(rgz) - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -1667,13 +1642,12 @@ func TestTransportGzipRecursive(t *testing.T) { // golang.org/issue/7750: request fails when server replies with // a short gzip body -func TestTransportGzipShort(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportGzipShort(t *testing.T) { run(t, testTransportGzipShort) } +func testTransportGzipShort(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", "gzip") w.Write([]byte{0x1f, 0x8b}) - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -1703,19 +1677,20 @@ func waitNumGoroutine(nmax int) int { // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { + run(t, testTransportPersistConnLeak, testNotParallel) +} +func testTransportPersistConnLeak(t *testing.T, mode testMode) { // Not parallel: counts goroutines - defer afterTest(t) const numReq = 25 gotReqCh := make(chan bool, numReq) unblockCh := make(chan bool, numReq) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { gotReqCh <- true <-unblockCh w.Header().Set("Content-Length", "0") w.WriteHeader(204) - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1773,11 +1748,12 @@ func TestTransportPersistConnLeak(t *testing.T) { // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { + run(t, testTransportPersistConnLeakShortBody, testNotParallel) +} +func testTransportPersistConnLeakShortBody(t *testing.T, mode testMode) { // Not parallel: measures goroutines. - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - })) - defer ts.Close() + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -1851,9 +1827,10 @@ func (d *countingDialer) Read() (total, live int64) { } func TestTransportPersistConnLeakNeverIdle(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportPersistConnLeakNeverIdle, []testMode{http1Mode}) +} +func testTransportPersistConnLeakNeverIdle(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Close every connection so that it cannot be kept alive. conn, _, err := w.(Hijacker).Hijack() if err != nil { @@ -1861,8 +1838,7 @@ func TestTransportPersistConnLeakNeverIdle(t *testing.T) { return } conn.Close() - })) - defer ts.Close() + })).ts var d countingDialer c := ts.Client() @@ -1923,13 +1899,17 @@ func (cc *contextCounter) Read() (live int64) { } func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) { - defer afterTest(t) + run(t, testTransportPersistConnContextLeakMaxConnsPerHost) +} +func testTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("https://go.dev/issue/56021") + } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { runtime.Gosched() w.WriteHeader(StatusOK) - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).MaxConnsPerHost = 1 @@ -1979,16 +1959,15 @@ func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) { } // This used to crash; https://golang.org/issue/3266 -func TestTransportIdleConnCrash(t *testing.T) { - defer afterTest(t) +func TestTransportIdleConnCrash(t *testing.T) { run(t, testTransportIdleConnCrash) } +func testTransportIdleConnCrash(t *testing.T, mode testMode) { var tr *Transport unblockCh := make(chan bool, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockCh tr.CloseIdleConnections() - })) - defer ts.Close() + })).ts c := ts.Client() tr = c.Transport.(*Transport) @@ -2010,16 +1989,15 @@ func TestTransportIdleConnCrash(t *testing.T) { // before the response body has been read. This was a regression // which sadly lacked a triggering test. The large response body made // the old race easier to trigger. -func TestIssue3644(t *testing.T) { - defer afterTest(t) +func TestIssue3644(t *testing.T) { run(t, testIssue3644) } +func testIssue3644(t *testing.T, mode testMode) { const numFoos = 5000 - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") for i := 0; i < numFoos; i++ { w.Write([]byte("foo ")) } - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) if err != nil { @@ -2037,14 +2015,12 @@ func TestIssue3644(t *testing.T) { // Test that a client receives a server's reply, even if the server doesn't read // the entire request body. -func TestIssue3595(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestIssue3595(t *testing.T) { run(t, testIssue3595) } +func testIssue3595(t *testing.T, mode testMode) { const deniedMsg = "sorry, denied." - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, deniedMsg, StatusUnauthorized) - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a')) if err != nil { @@ -2062,12 +2038,11 @@ func TestIssue3595(t *testing.T) { // From https://golang.org/issue/4454 , // "client fails to handle requests with no body and chunked encoding" -func TestChunkedNoContent(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestChunkedNoContent(t *testing.T) { run(t, testChunkedNoContent) } +func testChunkedNoContent(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.WriteHeader(StatusNoContent) - })) - defer ts.Close() + })).ts c := ts.Client() for _, closeBody := range []bool{true, false} { @@ -2086,17 +2061,18 @@ func TestChunkedNoContent(t *testing.T) { } func TestTransportConcurrency(t *testing.T) { + run(t, testTransportConcurrency, testNotParallel, []testMode{http1Mode}) +} +func testTransportConcurrency(t *testing.T, mode testMode) { // Not parallel: uses global test hooks. - defer afterTest(t) maxProcs, numReqs := 16, 500 if testing.Short() { maxProcs, numReqs = 4, 50 } defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs)) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%v", r.FormValue("echo")) - })) - defer ts.Close() + })).ts var wg sync.WaitGroup wg.Add(numReqs) @@ -2147,16 +2123,14 @@ func TestTransportConcurrency(t *testing.T) { wg.Wait() } -func TestIssue4191_InfiniteGetTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestIssue4191_InfiniteGetTimeout(t *testing.T) { run(t, testIssue4191_InfiniteGetTimeout) } +func testIssue4191_InfiniteGetTimeout(t *testing.T, mode testMode) { const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { io.Copy(w, neverEnding('a')) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts timeout := 100 * time.Millisecond c := ts.Client() @@ -2206,8 +2180,9 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testIssue4191_InfiniteGetToPutTimeout, []testMode{http1Mode}) +} +func testIssue4191_InfiniteGetToPutTimeout(t *testing.T, mode testMode) { const debug = false mux := NewServeMux() mux.HandleFunc("/get", func(w ResponseWriter, r *Request) { @@ -2217,7 +2192,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { defer r.Body.Close() io.Copy(io.Discard, r.Body) }) - ts := httptest.NewServer(mux) + ts := newClientServerTest(t, mode, mux).ts timeout := 100 * time.Millisecond c := ts.Client() @@ -2270,9 +2245,8 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { ts.Close() } -func TestTransportResponseHeaderTimeout(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportResponseHeaderTimeout(t *testing.T) { run(t, testTransportResponseHeaderTimeout) } +func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping timeout test in -short mode") } @@ -2285,8 +2259,7 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { inHandler <- true time.Sleep(2 * time.Second) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts c := ts.Client() c.Transport.(*Transport).ResponseHeaderTimeout = 500 * time.Millisecond @@ -2342,18 +2315,18 @@ func TestTransportResponseHeaderTimeout(t *testing.T) { } func TestTransportCancelRequest(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportCancelRequest, []testMode{http1Mode}) +} +func testTransportCancelRequest(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello") w.(Flusher).Flush() // send headers and some body <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2395,17 +2368,14 @@ func TestTransportCancelRequest(t *testing.T) { } } -func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { - setParallel(t) - defer afterTest(t) +func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) { if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2432,11 +2402,15 @@ func testTransportCancelRequestInDo(t *testing.T, body io.Reader) { } func TestTransportCancelRequestInDo(t *testing.T) { - testTransportCancelRequestInDo(t, nil) + run(t, func(t *testing.T, mode testMode) { + testTransportCancelRequestInDo(t, mode, nil) + }, []testMode{http1Mode}) } func TestTransportCancelRequestWithBodyInDo(t *testing.T) { - testTransportCancelRequestInDo(t, bytes.NewBuffer([]byte{0})) + run(t, func(t *testing.T, mode testMode) { + testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0})) + }, []testMode{http1Mode}) } func TestTransportCancelRequestInDial(t *testing.T) { @@ -2497,19 +2471,17 @@ Get = Get "http://something.no-network.tld/": net/http: request canceled while w } } -func TestCancelRequestWithChannel(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) } +func testCancelRequestWithChannel(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping test in -short mode") } unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello") w.(Flusher).Flush() // send headers and some body <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2555,19 +2527,20 @@ func TestCancelRequestWithChannel(t *testing.T) { } func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { - testCancelRequestWithChannelBeforeDo(t, false) + run(t, func(t *testing.T, mode testMode) { + testCancelRequestWithChannelBeforeDo(t, mode, false) + }) } func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) { - testCancelRequestWithChannelBeforeDo(t, true) + run(t, func(t *testing.T, mode testMode) { + testCancelRequestWithChannelBeforeDo(t, mode, true) + }) } -func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { - setParallel(t) - defer afterTest(t) +func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) { unblockc := make(chan bool) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { <-unblockc - })) - defer ts.Close() + })).ts defer close(unblockc) c := ts.Client() @@ -2642,11 +2615,11 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) { // golang.org/issue/3672 -- Client can't close HTTP stream // Calling Close on a Response.Body used to just read until EOF. // Now it actually closes the TCP connection. -func TestTransportCloseResponseBody(t *testing.T) { - defer afterTest(t) +func TestTransportCloseResponseBody(t *testing.T) { run(t, testTransportCloseResponseBody) } +func testTransportCloseResponseBody(t *testing.T, mode testMode) { writeErr := make(chan error, 1) msg := []byte("young\n") - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { for { _, err := w.Write(msg) if err != nil { @@ -2655,8 +2628,7 @@ func TestTransportCloseResponseBody(t *testing.T) { } w.(Flusher).Flush() } - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -2761,10 +2733,8 @@ func TestTransportEmptyMethod(t *testing.T) { } } -func TestTransportSocketLateBinding(t *testing.T) { - setParallel(t) - defer afterTest(t) - +func TestTransportSocketLateBinding(t *testing.T) { run(t, testTransportSocketLateBinding) } +func testTransportSocketLateBinding(t *testing.T, mode testMode) { mux := NewServeMux() fooGate := make(chan bool, 1) mux.HandleFunc("/foo", func(w ResponseWriter, r *Request) { @@ -2775,8 +2745,7 @@ func TestTransportSocketLateBinding(t *testing.T) { mux.HandleFunc("/bar", func(w ResponseWriter, r *Request) { w.Header().Set("bar-ipport", r.RemoteAddr) }) - ts := httptest.NewServer(mux) - defer ts.Close() + ts := newClientServerTest(t, mode, mux).ts dialGate := make(chan bool, 1) c := ts.Client() @@ -2920,15 +2889,15 @@ Content-Length: %d // Issue 17739: the HTTP client must ignore any unknown 1xx // informational responses before the actual response. func TestTransportIgnore1xxResponses(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportIgnore1xxResponses, []testMode{http1Mode}) +} +func testTransportIgnore1xxResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) buf.Flush() conn.Close() })) - defer cst.close() cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway var got strings.Builder @@ -2954,9 +2923,10 @@ func TestTransportIgnore1xxResponses(t *testing.T) { } func TestTransportLimits1xxResponses(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportLimits1xxResponses, []testMode{http1Mode}) +} +func testTransportLimits1xxResponses(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() for i := 0; i < 10; i++ { buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) @@ -2965,7 +2935,6 @@ func TestTransportLimits1xxResponses(t *testing.T) { buf.Flush() conn.Close() })) - defer cst.close() cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway res, err := cst.c.Get(cst.ts.URL) @@ -2982,16 +2951,16 @@ func TestTransportLimits1xxResponses(t *testing.T) { // Issue 26161: the HTTP client must treat 101 responses // as the final response. func TestTransportTreat101Terminal(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportTreat101Terminal, []testMode{http1Mode}) +} +func testTransportTreat101Terminal(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 101 Switching Protocols\r\n\r\n")) buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) buf.Flush() conn.Close() })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -3123,16 +3092,18 @@ func TestProxyFromEnvironmentLowerCase(t *testing.T) { } func TestIdleConnChannelLeak(t *testing.T) { + run(t, testIdleConnChannelLeak, []testMode{http1Mode}, testNotParallel) +} +func testIdleConnChannelLeak(t *testing.T, mode testMode) { // Not parallel: uses global test hooks. var mu sync.Mutex var n int - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() n++ mu.Unlock() - })) - defer ts.Close() + })).ts const nReqs = 5 didRead := make(chan bool, nReqs) @@ -3180,11 +3151,12 @@ func TestIdleConnChannelLeak(t *testing.T) { // body into a ReadCloser if it's a Closer, and that the Transport // then closes it. func TestTransportClosesRequestBody(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportClosesRequestBody, []testMode{http1Mode}) +} +func testTransportClosesRequestBody(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) - })) - defer ts.Close() + })).ts c := ts.Client() @@ -3261,10 +3233,11 @@ func TestTransportTLSHandshakeTimeout(t *testing.T) { // Trying to repro golang.org/issue/3514 func TestTLSServerClosesConnection(t *testing.T) { - defer afterTest(t) - + run(t, testTLSServerClosesConnection, []testMode{https1Mode}) +} +func testTLSServerClosesConnection(t *testing.T, mode testMode) { closedc := make(chan bool, 1) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if strings.Contains(r.URL.Path, "/keep-alive-then-die") { conn, _, _ := w.(Hijacker).Hijack() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) @@ -3273,8 +3246,7 @@ func TestTLSServerClosesConnection(t *testing.T) { return } fmt.Fprintf(w, "hello") - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) @@ -3345,8 +3317,9 @@ func (c byteFromChanReader) Read(p []byte) (n int, err error) { // questionable state. // golang.org/issue/7569 func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportNoReuseAfterEarlyResponse, []testMode{http1Mode}) +} +func testTransportNoReuseAfterEarlyResponse(t *testing.T, mode testMode) { var sconn struct { sync.Mutex c net.Conn @@ -3365,7 +3338,7 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { } defer closeConn() - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method == "GET" { io.WriteString(w, "bar") return @@ -3376,8 +3349,7 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { sconn.Unlock() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive go io.Copy(io.Discard, conn) - })) - defer ts.Close() + })).ts c := ts.Client() const bodySize = 256 << 10 @@ -3410,9 +3382,9 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { // Tests that we don't leak Transport persistConn.readLoop goroutines // when a server hangs up immediately after saying it would keep-alive. -func TestTransportIssue10457(t *testing.T) { - defer afterTest(t) // used to fail in goroutine leak check - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportIssue10457(t *testing.T) { run(t, testTransportIssue10457, []testMode{http1Mode}) } +func testTransportIssue10457(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // Send a response with no body, keep-alive // (implicit), and then lie and immediately close the // connection. This forces the Transport's readLoop to @@ -3421,8 +3393,7 @@ func TestTransportIssue10457(t *testing.T) { conn, _, _ := w.(Hijacker).Hijack() conn.Write([]byte("HTTP/1.1 200 OK\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n")) // keep-alive conn.Close() - })) - defer ts.Close() + })).ts c := ts.Client() res, err := c.Get(ts.URL) @@ -3463,6 +3434,9 @@ func (c writerFuncConn) Write(p []byte) (n int, err error) { return c.write(p) } // This automatically prevents an infinite resend loop because we'll run out of // the cached keep-alive connections eventually. func TestRetryRequestsOnError(t *testing.T) { + run(t, testRetryRequestsOnError, testNotParallel, []testMode{http1Mode}) +} +func testRetryRequestsOnError(t *testing.T, mode testMode) { newRequest := func(method, urlStr string, body io.Reader) *Request { req, err := NewRequest(method, urlStr, body) if err != nil { @@ -3533,8 +3507,6 @@ func TestRetryRequestsOnError(t *testing.T) { for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - defer afterTest(t) - var ( mu sync.Mutex logbuf strings.Builder @@ -3546,11 +3518,10 @@ func TestRetryRequestsOnError(t *testing.T) { logbuf.WriteByte('\n') } - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { logf("Handler") w.Header().Set("X-Status", "ok") - })) - defer ts.Close() + })).ts var writeNumAtomic int32 c := ts.Client() @@ -3620,15 +3591,13 @@ Handler } // Issue 6981 -func TestTransportClosesBodyOnError(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportClosesBodyOnError(t *testing.T) { run(t, testTransportClosesBodyOnError) } +func testTransportClosesBodyOnError(t *testing.T, mode testMode) { readBody := make(chan error, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := io.ReadAll(r.Body) readBody <- err - })) - defer ts.Close() + })).ts c := ts.Client() fakeErr := errors.New("fake error") didClose := make(chan bool, 1) @@ -3668,17 +3637,17 @@ func TestTransportClosesBodyOnError(t *testing.T) { } func TestTransportDialTLS(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportDialTLS, []testMode{https1Mode, http2Mode}) +} +func testTransportDialTLS(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq, didDial bool - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).DialTLS = func(netw, addr string) (net.Conn, error) { mu.Lock() @@ -3705,19 +3674,17 @@ func TestTransportDialTLS(t *testing.T) { } } -func TestTransportDialContext(t *testing.T) { - setParallel(t) - defer afterTest(t) +func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) } +func testTransportDialContext(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq bool var receivedContext context.Context - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() @@ -3746,18 +3713,18 @@ func TestTransportDialContext(t *testing.T) { } func TestTransportDialTLSContext(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode}) +} +func testTransportDialTLSContext(t *testing.T, mode testMode) { var mu sync.Mutex // guards following var gotReq bool var receivedContext context.Context - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() gotReq = true mu.Unlock() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() @@ -3879,6 +3846,9 @@ func TestTransportTraceGotConnH2IdleConns(t *testing.T) { } func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { + run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode}) +} +func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } @@ -3888,8 +3858,7 @@ func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { tr.MaxIdleConnsPerHost = 1 tr.IdleConnTimeout = 10 * time.Millisecond } - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) - defer cst.close() + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc) if _, err := cst.c.Get(cst.ts.URL); err != nil { t.Fatalf("got error: %s", err) @@ -3920,13 +3889,12 @@ func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) { // implicitly ask for gzip support. If they want that, they need to do it // on their own. // golang.org/issue/8923 -func TestTransportRangeAndGzip(t *testing.T) { - defer afterTest(t) +func TestTransportRangeAndGzip(t *testing.T) { run(t, testTransportRangeAndGzip) } +func testTransportRangeAndGzip(t *testing.T, mode testMode) { reqc := make(chan *Request, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { reqc <- r - })) - defer ts.Close() + })).ts c := ts.Client() req, _ := NewRequest("GET", ts.URL, nil) @@ -3951,15 +3919,13 @@ func TestTransportRangeAndGzip(t *testing.T) { } // Test for issue 10474 -func TestTransportResponseCancelRace(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportResponseCancelRace(t *testing.T) { run(t, testTransportResponseCancelRace) } +func testTransportResponseCancelRace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // important that this response has a body. var b [1024]byte w.Write(b[:]) - })) - defer ts.Close() + })).ts tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) @@ -3991,19 +3957,19 @@ func TestTransportResponseCancelRace(t *testing.T) { // Test for issue 19248: Content-Encoding's value is case insensitive. func TestTransportContentEncodingCaseInsensitive(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportContentEncodingCaseInsensitive) +} +func testTransportContentEncodingCaseInsensitive(t *testing.T, mode testMode) { for _, ce := range []string{"gzip", "GZIP"} { ce := ce t.Run(ce, func(t *testing.T) { const encodedString = "Hello Gopher" - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Encoding", ce) gz := gzip.NewWriter(w) gz.Write([]byte(encodedString)) gz.Close() - })) - defer ts.Close() + })).ts res, err := ts.Client().Get(ts.URL) if err != nil { @@ -4024,10 +3990,10 @@ func TestTransportContentEncodingCaseInsensitive(t *testing.T) { } func TestTransportDialCancelRace(t *testing.T) { - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + run(t, testTransportDialCancelRace, testNotParallel, []testMode{http1Mode}) +} +func testTransportDialCancelRace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts tr := ts.Client().Transport.(*Transport) req, err := NewRequest("GET", ts.URL, nil) @@ -4140,13 +4106,12 @@ func TestTransportFlushesBodyChunks(t *testing.T) { } // Issue 22088: flush Transport request headers if we're not sure the body won't block on read. -func TestTransportFlushesRequestHeader(t *testing.T) { - defer afterTest(t) +func TestTransportFlushesRequestHeader(t *testing.T) { run(t, testTransportFlushesRequestHeader) } +func testTransportFlushesRequestHeader(t *testing.T, mode testMode) { gotReq := make(chan struct{}) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { close(gotReq) })) - defer cst.close() pr, pw := io.Pipe() req, err := NewRequest("POST", cst.ts.URL, pr) @@ -4175,20 +4140,21 @@ func TestTransportFlushesRequestHeader(t *testing.T) { // Issue 11745. func TestTransportPrefersResponseOverWriteError(t *testing.T) { + run(t, testTransportPrefersResponseOverWriteError) +} +func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - defer afterTest(t) const contentLengthLimit = 1024 * 1024 // 1MB - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.ContentLength >= contentLengthLimit { w.WriteHeader(StatusBadRequest) r.Body.Close() return } w.WriteHeader(StatusOK) - })) - defer ts.Close() + })).ts c := ts.Client() fail := 0 @@ -4296,12 +4262,13 @@ func testTransportAutoHTTP(t *testing.T, tr *Transport, wantH2 bool) { // Plus it's nice to be consistent and not have timing-dependent // behavior. func TestTransportReuseConnEmptyResponseBody(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportReuseConnEmptyResponseBody) +} +func testTransportReuseConnEmptyResponseBody(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) // Empty response body. })) - defer cst.close() n := 100 if testing.Short() { n = 10 @@ -4406,27 +4373,28 @@ func TestNoCrashReturningTransportAltConn(t *testing.T) { } func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) { - testTransportReuseConnection_Gzip(t, true) + run(t, func(t *testing.T, mode testMode) { + testTransportReuseConnection_Gzip(t, mode, true) + }) } func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { - testTransportReuseConnection_Gzip(t, false) + run(t, func(t *testing.T, mode testMode) { + testTransportReuseConnection_Gzip(t, mode, false) + }) } // Make sure we re-use underlying TCP connection for gzipped responses too. -func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { - setParallel(t) - defer afterTest(t) +func testTransportReuseConnection_Gzip(t *testing.T, mode testMode, chunked bool) { addr := make(chan string, 2) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { addr <- r.RemoteAddr w.Header().Set("Content-Encoding", "gzip") if chunked { w.(Flusher).Flush() } w.Write(rgz) // arbitrary gzip response - })) - defer ts.Close() + })).ts c := ts.Client() trace := &httptrace.ClientTrace{ @@ -4459,15 +4427,16 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { } } -func TestTransportResponseHeaderLength(t *testing.T) { - setParallel(t) - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestTransportResponseHeaderLength(t *testing.T) { run(t, testTransportResponseHeaderLength) } +func testTransportResponseHeaderLength(t *testing.T, mode testMode) { + if mode == http2Mode { + t.Skip("HTTP/2 Transport doesn't support MaxResponseHeaderBytes") + } + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/long" { w.Header().Set("Long", strings.Repeat("a", 1<<20)) } - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).MaxResponseHeaderBytes = 512 << 10 @@ -4493,18 +4462,23 @@ func TestTransportResponseHeaderLength(t *testing.T) { } } -func TestTransportEventTrace(t *testing.T) { testTransportEventTrace(t, h1Mode, false) } -func TestTransportEventTrace_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, false) } +func TestTransportEventTrace(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTransportEventTrace(t, mode, false) + }) +} // test a non-nil httptrace.ClientTrace but with all hooks set to zero. -func TestTransportEventTrace_NoHooks(t *testing.T) { testTransportEventTrace(t, h1Mode, true) } -func TestTransportEventTrace_NoHooks_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, true) } +func TestTransportEventTrace_NoHooks(t *testing.T) { + run(t, func(t *testing.T, mode testMode) { + testTransportEventTrace(t, mode, true) + }) +} -func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { - defer afterTest(t) +func testTransportEventTrace(t *testing.T, mode testMode, noHooks bool) { const resBody = "some body" gotWroteReqEvent := make(chan struct{}, 500) - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method == "GET" { // Do nothing for the second request. return @@ -4520,7 +4494,11 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { } } io.WriteString(w, resBody) - })) + }), func(tr *Transport) { + if tr.TLSClientConfig != nil { + tr.TLSClientConfig.InsecureSkipVerify = true + } + }) defer cst.close() cst.tr.ExpectContinueTimeout = 1 * time.Second @@ -4579,7 +4557,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { gotWroteReqEvent <- struct{}{} }, } - if h2 { + if mode == http2Mode { trace.TLSHandshakeStart = func() { logf("tls handshake start") } trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) { logf("tls handshake done. ConnectionState = %v \n err = %v", s, err) @@ -4636,7 +4614,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { wantOnceOrMore("connected to tcp " + addrStr + " = ") wantOnce("Reused:false WasIdle:false IdleTime:0s") wantOnce("first response byte") - if h2 { + if mode == http2Mode { wantOnce("tls handshake start") wantOnce("tls handshake done") } else { @@ -4684,6 +4662,9 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { } func TestTransportEventTraceTLSVerify(t *testing.T) { + run(t, testTransportEventTraceTLSVerify, []testMode{https1Mode, http2Mode}) +} +func testTransportEventTraceTLSVerify(t *testing.T, mode testMode) { var mu sync.Mutex var buf strings.Builder logf := func(format string, args ...any) { @@ -4693,14 +4674,14 @@ func TestTransportEventTraceTLSVerify(t *testing.T) { buf.WriteByte('\n') } - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Error("Unexpected request") - })) - defer ts.Close() - ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { - logf("%s", p) - return len(p), nil - }), "", 0) + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(funcWriter(func(p []byte) (int, error) { + logf("%s", p) + return len(p), nil + }), "", 0) + }).ts certpool := x509.NewCertPool() certpool.AddCert(ts.Certificate()) @@ -4834,9 +4815,10 @@ func TestTransportRejectsAlphaPort(t *testing.T) { // Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 // connections. The http2 test is done in TestTransportEventTrace_h2 func TestTLSHandshakeTrace(t *testing.T) { - defer afterTest(t) - ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) - defer ts.Close() + run(t, testTLSHandshakeTrace, []testMode{https1Mode, http2Mode}) +} +func testTLSHandshakeTrace(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts var mu sync.Mutex var start, done bool @@ -4879,11 +4861,12 @@ func TestTLSHandshakeTrace(t *testing.T) { } func TestTransportMaxIdleConns(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportMaxIdleConns, []testMode{http1Mode}) +} +func testTransportMaxIdleConns(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // No body for convenience. - })) - defer ts.Close() + })).ts c := ts.Client() tr := c.Transport.(*Transport) tr.MaxIdleConns = 4 @@ -4931,27 +4914,24 @@ func TestTransportMaxIdleConns(t *testing.T) { } } -func TestTransportIdleConnTimeout_h1(t *testing.T) { testTransportIdleConnTimeout(t, h1Mode) } -func TestTransportIdleConnTimeout_h2(t *testing.T) { testTransportIdleConnTimeout(t, h2Mode) } -func testTransportIdleConnTimeout(t *testing.T, h2 bool) { +func TestTransportIdleConnTimeout(t *testing.T) { run(t, testTransportIdleConnTimeout) } +func testTransportIdleConnTimeout(t *testing.T, mode testMode) { if testing.Short() { t.Skip("skipping in short mode") } - defer afterTest(t) const timeout = 1 * time.Second - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // No body for convenience. })) - defer cst.close() tr := cst.tr tr.IdleConnTimeout = timeout defer tr.CloseIdleConnections() c := &Client{Transport: tr} idleConns := func() []string { - if h2 { + if mode == http2Mode { return tr.IdleConnStrsForTesting_h2() } else { return tr.IdleConnStrsForTesting() @@ -5005,12 +4985,11 @@ func testTransportIdleConnTimeout(t *testing.T, h2 bool) { // real connection until after the RoundTrip saw the error. Then we // know the successful tls.Dial from DialTLS will need to go into the // idle pool. Then we give it a of time to explode. -func TestIdleConnH2Crash(t *testing.T) { - setParallel(t) - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { +func TestIdleConnH2Crash(t *testing.T) { run(t, testIdleConnH2Crash, []testMode{http2Mode}) } +func testIdleConnH2Crash(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { // nothing })) - defer cst.close() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -5101,21 +5080,18 @@ func TestTransportReturnsPeekError(t *testing.T) { } // Issue 13835: international domain names should work -func TestTransportIDNA_h1(t *testing.T) { testTransportIDNA(t, h1Mode) } -func TestTransportIDNA_h2(t *testing.T) { testTransportIDNA(t, h2Mode) } -func testTransportIDNA(t *testing.T, h2 bool) { - defer afterTest(t) - +func TestTransportIDNA(t *testing.T) { run(t, testTransportIDNA) } +func testTransportIDNA(t *testing.T, mode testMode) { const uniDomain = "гофер.го" const punyDomain = "xn--c1ae0ajs.xn--c1aw" var port string - cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { want := punyDomain + ":" + port if r.Host != want { t.Errorf("Host header = %q; want %q", r.Host, want) } - if h2 { + if mode == http2Mode { if r.TLS == nil { t.Errorf("r.TLS == nil") } else if r.TLS.ServerName != punyDomain { @@ -5123,8 +5099,11 @@ func testTransportIDNA(t *testing.T, h2 bool) { } } w.Header().Set("Hit-Handler", "1") - })) - defer cst.close() + }), func(tr *Transport) { + if tr.TLSClientConfig != nil { + tr.TLSClientConfig.InsecureSkipVerify = true + } + }) ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String()) if err != nil { @@ -5172,9 +5151,11 @@ func testTransportIDNA(t *testing.T, h2 bool) { // Issue 13290: send User-Agent in proxy CONNECT func TestTransportProxyConnectHeader(t *testing.T) { - defer afterTest(t) + run(t, testTransportProxyConnectHeader, []testMode{http1Mode}) +} +func testTransportProxyConnectHeader(t *testing.T, mode testMode) { reqc := make(chan *Request, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("method = %q; want CONNECT", r.Method) } @@ -5185,8 +5166,7 @@ func TestTransportProxyConnectHeader(t *testing.T) { return } c.Close() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { @@ -5216,9 +5196,11 @@ func TestTransportProxyConnectHeader(t *testing.T) { } func TestTransportProxyGetConnectHeader(t *testing.T) { - defer afterTest(t) + run(t, testTransportProxyGetConnectHeader, []testMode{http1Mode}) +} +func testTransportProxyGetConnectHeader(t *testing.T, mode testMode) { reqc := make(chan *Request, 1) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("method = %q; want CONNECT", r.Method) } @@ -5229,8 +5211,7 @@ func TestTransportProxyGetConnectHeader(t *testing.T) { return } c.Close() - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { @@ -5417,14 +5398,15 @@ func doFetchCheckPanic(tr *Transport, req *Request) (res *Response, err error, p // Issue 22330: do not allow the response body to be read when the status code // forbids a response body. func TestNoBodyOnChunked304Response(t *testing.T) { - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testNoBodyOnChunked304Response, []testMode{http1Mode}) +} +func testNoBodyOnChunked304Response(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 304 NOT MODIFIED\r\nTransfer-Encoding: chunked\r\n\r\n0\r\n\r\n")) buf.Flush() conn.Close() })) - defer cst.close() // Our test server above is sending back bogus data after the // response (the "0\r\n\r\n" part), which causes the Transport @@ -5477,11 +5459,12 @@ func TestTransportCheckContextDoneEarly(t *testing.T) { // This is the test variant that times out before the server replies with // any response headers. func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testClientTimeoutKillsConn_BeforeHeaders, []testMode{http1Mode}) +} +func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) { inHandler := make(chan net.Conn, 1) handlerReadReturned := make(chan bool, 1) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -5494,7 +5477,6 @@ func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { } handlerReadReturned <- true })) - defer cst.close() const timeout = 50 * time.Millisecond cst.c.Timeout = timeout @@ -5529,11 +5511,12 @@ func TestClientTimeoutKillsConn_BeforeHeaders(t *testing.T) { // This is the test variant that has the server send response headers // first, and time out during the write of the response body. func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode}) +} +func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) { inHandler := make(chan net.Conn, 1) handlerResult := make(chan error, 1) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "100") w.(Flusher).Flush() conn, _, err := w.(Hijacker).Hijack() @@ -5555,7 +5538,6 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { } handlerResult <- nil })) - defer cst.close() // Set Timeout to something very long but non-zero to exercise // the codepaths that check for it. But rather than wait for it to fire @@ -5601,11 +5583,12 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { } func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { - setParallel(t) - defer afterTest(t) + run(t, testTransportResponseBodyWritableOnProtocolSwitch, []testMode{http1Mode}) +} +func testTransportResponseBodyWritableOnProtocolSwitch(t *testing.T, mode testMode) { done := make(chan struct{}) defer close(done) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -5618,7 +5601,6 @@ func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { fmt.Fprintf(conn, "%s\n", strings.ToUpper(bs.Text())) <-done })) - defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, nil) req.Header.Set("Upgrade", "foo") @@ -5651,10 +5633,10 @@ func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) { } } -func TestTransportCONNECTBidi(t *testing.T) { - defer afterTest(t) +func TestTransportCONNECTBidi(t *testing.T) { run(t, testTransportCONNECTBidi, []testMode{http1Mode}) } +func testTransportCONNECTBidi(t *testing.T, mode testMode) { const target = "backend:443" - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if r.Method != "CONNECT" { t.Errorf("unexpected method %q", r.Method) w.WriteHeader(500) @@ -5685,7 +5667,6 @@ func TestTransportCONNECTBidi(t *testing.T) { brw.Flush() } })) - defer cst.close() pr, pw := io.Pipe() defer pw.Close() req, err := NewRequest("CONNECT", cst.ts.URL, pr) @@ -5782,7 +5763,8 @@ func (c *testMockTCPConn) ReadFrom(r io.Reader) (int64, error) { return c.TCPConn.ReadFrom(r) } -func TestTransportRequestWriteRoundTrip(t *testing.T) { +func TestTransportRequestWriteRoundTrip(t *testing.T) { run(t, testTransportRequestWriteRoundTrip) } +func testTransportRequestWriteRoundTrip(t *testing.T, mode testMode) { nBytes := int64(1 << 10) newFileFunc := func() (r io.Reader, done func(), err error) { f, err := os.CreateTemp("", "net-http-newfilefunc") @@ -5876,7 +5858,7 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { cst := newClientServerTest( t, - h1Mode, + mode, HandlerFunc(func(w ResponseWriter, r *Request) { io.Copy(io.Discard, r.Body) r.Body.Close() @@ -5884,7 +5866,6 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { }), trFunc, ) - defer cst.close() req, err := NewRequest("PUT", cst.ts.URL, r) if err != nil { @@ -5901,11 +5882,15 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { t.Fatalf("status code = %d; want 200", resp.StatusCode) } - if !tConn.ReadFromCalled && tc.expectedReadFrom { + expectedReadFrom := tc.expectedReadFrom + if mode != http1Mode { + expectedReadFrom = false + } + if !tConn.ReadFromCalled && expectedReadFrom { t.Fatalf("did not call ReadFrom") } - if tConn.ReadFromCalled && !tc.expectedReadFrom { + if tConn.ReadFromCalled && !expectedReadFrom { t.Fatalf("ReadFrom was unexpectedly invoked") } }) @@ -5985,17 +5970,17 @@ func TestIs408(t *testing.T) { } } -func TestTransportIgnores408(t *testing.T) { +func TestTransportIgnores408(t *testing.T) { run(t, testTransportIgnores408, []testMode{http1Mode}) } +func testTransportIgnores408(t *testing.T, mode testMode) { // Not parallel. Relies on mutating the log package's global Output. defer log.SetOutput(log.Writer()) var logout strings.Builder log.SetOutput(&logout) - defer afterTest(t) const target = "backend:443" - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { nc, _, err := w.(Hijacker).Hijack() if err != nil { t.Error(err) @@ -6005,7 +5990,6 @@ func TestTransportIgnores408(t *testing.T) { nc.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok")) nc.Write([]byte("HTTP/1.1 408 bye\r\n")) // changing 408 to 409 makes test fail })) - defer cst.close() req, err := NewRequest("GET", cst.ts.URL, nil) if err != nil { t.Fatal(err) @@ -6039,9 +6023,10 @@ func TestTransportIgnores408(t *testing.T) { } func TestInvalidHeaderResponse(t *testing.T) { - setParallel(t) - defer afterTest(t) - cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testInvalidHeaderResponse, []testMode{http1Mode}) +} +func testInvalidHeaderResponse(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { conn, buf, _ := w.(Hijacker).Hijack() buf.Write([]byte("HTTP/1.1 200 OK\r\n" + "Date: Wed, 30 Aug 2017 19:09:27 GMT\r\n" + @@ -6051,7 +6036,6 @@ func TestInvalidHeaderResponse(t *testing.T) { buf.Flush() conn.Close() })) - defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) @@ -6078,10 +6062,12 @@ func (bc *bodyCloser) Read(b []byte) (n int, err error) { // Issue 35015: ensure that Transport closes the body on any error // with an invalid request, as promised by Client.Do docs. func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { - cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportClosesBodyOnInvalidRequests) +} +func testTransportClosesBodyOnInvalidRequests(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { t.Errorf("Should not have been invoked") - })) - defer cst.Close() + })).ts u, _ := url.Parse(cst.URL) @@ -6146,7 +6132,7 @@ func TestTransportClosesBodyOnInvalidRequests(t *testing.T) { var bc bodyCloser req := tt.req req.Body = &bc - _, err := DefaultClient.Do(tt.req) + _, err := cst.Client().Do(tt.req) if err == nil { t.Fatal("Expected an error") } @@ -6183,8 +6169,10 @@ func (w *breakableConn) Write(b []byte) (n int, err error) { // Issue 34978: don't cache a broken HTTP/2 connection func TestDontCacheBrokenHTTP2Conn(t *testing.T) { - cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog) - defer cst.close() + run(t, testDontCacheBrokenHTTP2Conn, []testMode{http2Mode}) +} +func testDontCacheBrokenHTTP2Conn(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), optQuietLog) var brokenState brokenState @@ -6246,7 +6234,9 @@ func TestDontCacheBrokenHTTP2Conn(t *testing.T) { // http.http2noCachedConnError is reported on multiple requests. There should // only be one decrement regardless of the number of failures. func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { - defer afterTest(t) + run(t, testTransportDecrementConnWhenIdleConnRemoved, []testMode{http2Mode}) +} +func testTransportDecrementConnWhenIdleConnRemoved(t *testing.T, mode testMode) { CondSkipHTTP2(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -6256,17 +6246,11 @@ func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { } }) - ts := httptest.NewUnstartedServer(h) - ts.EnableHTTP2 = true - ts.StartTLS() - defer ts.Close() + ts := newClientServerTest(t, mode, h).ts c := ts.Client() tr := c.Transport.(*Transport) tr.MaxConnsPerHost = 1 - if err := ExportHttp2ConfigureTransport(tr); err != nil { - t.Fatalf("ExportHttp2ConfigureTransport: %v", err) - } errCh := make(chan error, 300) doReq := func() { @@ -6335,14 +6319,13 @@ type roundTripFunc func(r *Request) (*Response, error) func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } // Issue 32441: body is not reset after ErrSkipAltProtocol -func TestIssue32441(t *testing.T) { - defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { +func TestIssue32441(t *testing.T) { run(t, testIssue32441, []testMode{http1Mode}) } +func testIssue32441(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { if n, _ := io.Copy(io.Discard, r.Body); n == 0 { t.Error("body length is zero") } - })) - defer ts.Close() + })).ts c := ts.Client() c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { // Draining body to trigger failure condition on actual request to server. @@ -6359,11 +6342,13 @@ func TestIssue32441(t *testing.T) { // Issue 39017. Ensure that HTTP/1 transports reject Content-Length headers // that contain a sign (eg. "+3"), per RFC 2616, Section 14.13. func TestTransportRejectsSignInContentLength(t *testing.T) { - cst := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + run(t, testTransportRejectsSignInContentLength, []testMode{http1Mode}) +} +func testTransportRejectsSignInContentLength(t *testing.T, mode testMode) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "+3") w.Write([]byte("abc")) - })) - defer cst.Close() + })).ts c := cst.Client() res, err := c.Get(cst.URL) @@ -6477,14 +6462,16 @@ func TestErrorWriteLoopRace(t *testing.T) { // Test that a new request which uses the connection of an active request // cannot cause it to be canceled as well. func TestCancelRequestWhenSharingConnection(t *testing.T) { + run(t, testCancelRequestWhenSharingConnection, []testMode{http1Mode}) +} +func testCancelRequestWhenSharingConnection(t *testing.T, mode testMode) { reqc := make(chan chan struct{}, 2) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, req *Request) { + ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, req *Request) { ch := make(chan struct{}, 1) reqc <- ch <-ch w.Header().Add("Content-Length", "0") - })) - defer ts.Close() + })).ts client := ts.Client() transport := client.Transport.(*Transport) @@ -6549,15 +6536,12 @@ func TestCancelRequestWhenSharingConnection(t *testing.T) { wg.Wait() } -func TestHandlerAbortRacesBodyRead(t *testing.T) { - setParallel(t) - defer afterTest(t) - - ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { +func TestHandlerAbortRacesBodyRead(t *testing.T) { run(t, testHandlerAbortRacesBodyRead) } +func testHandlerAbortRacesBodyRead(t *testing.T, mode testMode) { + ts := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) { go io.Copy(io.Discard, req.Body) panic(ErrAbortHandler) - })) - defer ts.Close() + })).ts var wg sync.WaitGroup for i := 0; i < 2; i++ {