diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index 5c70f0d27bb..04248d5f531 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -454,8 +454,19 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.Header.Set("User-Agent", "") } + var ( + roundTripMutex sync.Mutex + roundTripDone bool + ) trace := &httptrace.ClientTrace{ Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + roundTripMutex.Lock() + defer roundTripMutex.Unlock() + if roundTripDone { + // If RoundTrip has returned, don't try to further modify + // the ResponseWriter's header map. + return nil + } h := rw.Header() copyHeader(h, http.Header(header)) rw.WriteHeader(code) @@ -468,6 +479,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace)) res, err := transport.RoundTrip(outreq) + roundTripMutex.Lock() + roundTripDone = true + roundTripMutex.Unlock() if err != nil { p.getErrorHandler()(rw, outreq, err) return diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index dd3330b615d..1bd64e65ba9 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -1687,6 +1687,47 @@ func TestReverseProxyRewriteReplacesOut(t *testing.T) { } } +func Test1xxHeadersNotModifiedAfterRoundTrip(t *testing.T) { + // https://go.dev/issue/65123: We use httptrace.Got1xxResponse to capture 1xx responses + // and proxy them. httptrace handlers can execute after RoundTrip returns, in particular + // after experiencing connection errors. When this happens, we shouldn't modify the + // ResponseWriter headers after ReverseProxy.ServeHTTP returns. + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + for i := 0; i < 5; i++ { + w.WriteHeader(103) + } + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + + rw := &testResponseWriter{} + func() { + // Cancel the request (and cause RoundTrip to return) immediately upon + // seeing a 1xx response. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + cancel() + return nil + }, + }) + + req, _ := http.NewRequestWithContext(ctx, "GET", "http://go.dev/", nil) + proxyHandler.ServeHTTP(rw, req) + }() + // Trigger data race while iterating over response headers. + // When run with -race, this causes the condition in https://go.dev/issue/65123 often + // enough to detect reliably. + for _ = range rw.Header() { + } +} + func Test1xxResponses(t *testing.T) { backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := w.Header() @@ -1861,3 +1902,29 @@ func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool, } } } + +type testResponseWriter struct { + h http.Header + writeHeader func(int) + write func([]byte) (int, error) +} + +func (rw *testResponseWriter) Header() http.Header { + if rw.h == nil { + rw.h = make(http.Header) + } + return rw.h +} + +func (rw *testResponseWriter) WriteHeader(statusCode int) { + if rw.writeHeader != nil { + rw.writeHeader(statusCode) + } +} + +func (rw *testResponseWriter) Write(p []byte) (int, error) { + if rw.write != nil { + return rw.write(p) + } + return len(p), nil +}