diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index 0fe555af38..df2a670aee 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -60,13 +60,6 @@ func pedanticReadAll(r io.Reader) (b []byte, err error) { } } -type chanWriter chan string - -func (w chanWriter) Write(p []byte) (n int, err error) { - w <- string(p) - return len(p), nil -} - func TestClient(t *testing.T) { run(t, testClient) } func testClient(t *testing.T, mode testMode) { ts := newClientServerTest(t, mode, robotsTxtHandler).ts @@ -827,12 +820,12 @@ func TestClientInsecureTransport(t *testing.T) { run(t, testClientInsecureTransport, []testMode{https1Mode, http2Mode}) } func testClientInsecureTransport(t *testing.T, mode testMode) { - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) - })).ts - errc := make(chanWriter, 10) // but only expecting 1 - ts.Config.ErrorLog = log.New(errc, "", 0) - defer ts.Close() + })) + ts := cst.ts + errLog := new(strings.Builder) + ts.Config.ErrorLog = log.New(errLog, "", 0) // TODO(bradfitz): add tests for skipping hostname checks too? // would require a new cert for testing, and probably @@ -851,15 +844,10 @@ func testClientInsecureTransport(t *testing.T, mode testMode) { } } - select { - case v := <-errc: - if !strings.Contains(v, "TLS handshake error") { - t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) - } - case <-time.After(5 * time.Second): - t.Errorf("timeout waiting for logged error") + cst.close() + if !strings.Contains(errLog.String(), "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", errLog) } - } func TestClientErrorWithRequestURI(t *testing.T) { @@ -897,9 +885,10 @@ func TestClientWithIncorrectTLSServerName(t *testing.T) { run(t, testClientWithIncorrectTLSServerName, []testMode{https1Mode, http2Mode}) } func testClientWithIncorrectTLSServerName(t *testing.T, mode testMode) { - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})).ts - errc := make(chanWriter, 10) // but only expecting 1 - ts.Config.ErrorLog = log.New(errc, "", 0) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {})) + ts := cst.ts + errLog := new(strings.Builder) + ts.Config.ErrorLog = log.New(errLog, "", 0) c := ts.Client() c.Transport.(*Transport).TLSClientConfig.ServerName = "badserver" @@ -910,13 +899,10 @@ func testClientWithIncorrectTLSServerName(t *testing.T, mode testMode) { if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) } - select { - case v := <-errc: - if !strings.Contains(v, "TLS handshake error") { - t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", v) - } - case <-time.After(5 * time.Second): - t.Errorf("timeout waiting for logged error") + + cst.close() + if !strings.Contains(errLog.String(), "TLS handshake error") { + t.Errorf("expected an error log message containing 'TLS handshake error'; got %q", errLog) } } diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 00230020e7..0c76f1bcc4 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -1400,27 +1400,28 @@ func TestTLSHandshakeTimeout(t *testing.T) { run(t, testTLSHandshakeTimeout, []testMode{https1Mode, http2Mode}) } func testTLSHandshakeTimeout(t *testing.T, mode testMode) { - errc := make(chanWriter, 10) // but only expecting 1 - ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), + errLog := new(strings.Builder) + cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), func(ts *httptest.Server) { ts.Config.ReadTimeout = 250 * time.Millisecond - ts.Config.ErrorLog = log.New(errc, "", 0) + ts.Config.ErrorLog = log.New(errLog, "", 0) }, - ).ts + ) + ts := cst.ts + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) if err != nil { t.Fatalf("Dial: %v", err) } - defer conn.Close() - var buf [1]byte n, err := conn.Read(buf[:]) if err == nil || n != 0 { t.Errorf("Read = %d, %v; want an error and no bytes", n, err) } + conn.Close() - v := <-errc - if !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") { + cst.close() + if v := errLog.String(); !strings.Contains(v, "timeout") && !strings.Contains(v, "TLS handshake") { t.Errorf("expected a TLS handshake timeout error; got %q", v) } }