mirror of
https://github.com/golang/go
synced 2024-11-22 08:34:40 -07:00
net/http/httputil: avoid ReverseProxy data race on 1xx response and error
ReverseProxy uses a httptrace.ClientTrace.Got1xxResponse trace hook to capture 1xx response headers for proxying. This hook can be called asynchrnously after RoundTrip returns. (This should only happen when RoundTrip has failed for some reason.) Add synchronization so we don't attempt to modifying the ResponseWriter headers map from the hook after another goroutine has begun making use of it. Fixes #65123 Change-Id: I8b7ecb1a140f7ba7e37b9d27b8a20bca41a118b1 Reviewed-on: https://go-review.googlesource.com/c/go/+/567216 LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com> Reviewed-by: Jonathan Amsterdam <jba@google.com> Auto-Submit: Damien Neil <dneil@google.com>
This commit is contained in:
parent
d3e827d371
commit
960654be0c
@ -454,8 +454,19 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
outreq.Header.Set("User-Agent", "")
|
||||
}
|
||||
|
||||
var (
|
||||
roundTripMutex sync.Mutex
|
||||
roundTripDone bool
|
||||
)
|
||||
trace := &httptrace.ClientTrace{
|
||||
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||||
roundTripMutex.Lock()
|
||||
defer roundTripMutex.Unlock()
|
||||
if roundTripDone {
|
||||
// If RoundTrip has returned, don't try to further modify
|
||||
// the ResponseWriter's header map.
|
||||
return nil
|
||||
}
|
||||
h := rw.Header()
|
||||
copyHeader(h, http.Header(header))
|
||||
rw.WriteHeader(code)
|
||||
@ -468,6 +479,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||
outreq = outreq.WithContext(httptrace.WithClientTrace(outreq.Context(), trace))
|
||||
|
||||
res, err := transport.RoundTrip(outreq)
|
||||
roundTripMutex.Lock()
|
||||
roundTripDone = true
|
||||
roundTripMutex.Unlock()
|
||||
if err != nil {
|
||||
p.getErrorHandler()(rw, outreq, err)
|
||||
return
|
||||
|
@ -1687,6 +1687,47 @@ func TestReverseProxyRewriteReplacesOut(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test1xxHeadersNotModifiedAfterRoundTrip(t *testing.T) {
|
||||
// https://go.dev/issue/65123: We use httptrace.Got1xxResponse to capture 1xx responses
|
||||
// and proxy them. httptrace handlers can execute after RoundTrip returns, in particular
|
||||
// after experiencing connection errors. When this happens, we shouldn't modify the
|
||||
// ResponseWriter headers after ReverseProxy.ServeHTTP returns.
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
for i := 0; i < 5; i++ {
|
||||
w.WriteHeader(103)
|
||||
}
|
||||
}))
|
||||
defer backend.Close()
|
||||
backendURL, err := url.Parse(backend.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
proxyHandler := NewSingleHostReverseProxy(backendURL)
|
||||
proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests
|
||||
|
||||
rw := &testResponseWriter{}
|
||||
func() {
|
||||
// Cancel the request (and cause RoundTrip to return) immediately upon
|
||||
// seeing a 1xx response.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
|
||||
Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
|
||||
cancel()
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
req, _ := http.NewRequestWithContext(ctx, "GET", "http://go.dev/", nil)
|
||||
proxyHandler.ServeHTTP(rw, req)
|
||||
}()
|
||||
// Trigger data race while iterating over response headers.
|
||||
// When run with -race, this causes the condition in https://go.dev/issue/65123 often
|
||||
// enough to detect reliably.
|
||||
for _ = range rw.Header() {
|
||||
}
|
||||
}
|
||||
|
||||
func Test1xxResponses(t *testing.T) {
|
||||
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
h := w.Header()
|
||||
@ -1861,3 +1902,29 @@ func testReverseProxyQueryParameterSmuggling(t *testing.T, wantCleanQuery bool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type testResponseWriter struct {
|
||||
h http.Header
|
||||
writeHeader func(int)
|
||||
write func([]byte) (int, error)
|
||||
}
|
||||
|
||||
func (rw *testResponseWriter) Header() http.Header {
|
||||
if rw.h == nil {
|
||||
rw.h = make(http.Header)
|
||||
}
|
||||
return rw.h
|
||||
}
|
||||
|
||||
func (rw *testResponseWriter) WriteHeader(statusCode int) {
|
||||
if rw.writeHeader != nil {
|
||||
rw.writeHeader(statusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func (rw *testResponseWriter) Write(p []byte) (int, error) {
|
||||
if rw.write != nil {
|
||||
return rw.write(p)
|
||||
}
|
||||
return len(p), nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user