diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index 6f0a2418b3..1dddaa95a7 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -55,10 +55,23 @@ type ReverseProxy struct { // copying HTTP response bodies. BufferPool BufferPool - // ModifyResponse is an optional function that - // modifies the Response from the backend. - // If it returns an error, the proxy returns a StatusBadGateway error. + // ModifyResponse is an optional function that modifies the + // Response from the backend. It is called if the backend + // returns a response at all, with any HTTP status code. + // If the backend is unreachable, the optional ErrorHandler is + // called without any call to ModifyResponse. + // + // If ModifyResponse returns an error, ErrorHandler is called + // with its error value. If ErrorHandler is nil, its default + // implementation is used. ModifyResponse func(*http.Response) error + + // ErrorHandler is an optional function that handles errors + // reaching the backend or errors from ModifyResponse. + // + // If nil, the default is to log the provided error and return + // a 502 Status Bad Gateway response. + ErrorHandler func(http.ResponseWriter, *http.Request, error) } // A BufferPool is an interface for getting and returning temporary @@ -141,6 +154,18 @@ var hopHeaders = []string{ "Upgrade", } +func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) +} + +func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) { + if p.ErrorHandler != nil { + return p.ErrorHandler + } + return p.defaultErrorHandler +} + func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { transport := p.Transport if transport == nil { @@ -206,8 +231,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { res, err := transport.RoundTrip(outreq) if err != nil { - p.logf("http: proxy error: %v", err) - rw.WriteHeader(http.StatusBadGateway) + p.getErrorHandler()(rw, outreq, err) return } @@ -219,9 +243,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if p.ModifyResponse != nil { if err := p.ModifyResponse(res); err != nil { - p.logf("http: proxy error: %v", err) - rw.WriteHeader(http.StatusBadGateway) res.Body.Close() + p.getErrorHandler()(rw, outreq, err) return } } diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index 2a12e753b5..2f75b4e34e 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -637,6 +637,93 @@ func TestReverseProxyModifyResponse(t *testing.T) { } } +type failingRoundTripper struct{} + +func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return nil, errors.New("some error") +} + +type staticResponseRoundTripper struct{ res *http.Response } + +func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return rt.res, nil +} + +func TestReverseProxyErrorHandler(t *testing.T) { + tests := []struct { + name string + wantCode int + errorHandler func(http.ResponseWriter, *http.Request, error) + transport http.RoundTripper // defaults to failingRoundTripper + modifyResponse func(*http.Response) error + }{ + { + name: "default", + wantCode: http.StatusBadGateway, + }, + { + name: "errorhandler", + wantCode: http.StatusTeapot, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + }, + { + name: "modifyresponse_noerr", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return nil + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: 346, + }, + { + name: "modifyresponse_err", + transport: staticResponseRoundTripper{ + &http.Response{StatusCode: 345, Body: http.NoBody}, + }, + modifyResponse: func(res *http.Response) error { + res.StatusCode++ + return errors.New("some error to trigger errorHandler") + }, + errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) }, + wantCode: http.StatusTeapot, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + target := &url.URL{ + Scheme: "http", + Host: "dummy.tld", + Path: "/", + } + rproxy := NewSingleHostReverseProxy(target) + rproxy.Transport = tt.transport + rproxy.ModifyResponse = tt.modifyResponse + if rproxy.Transport == nil { + rproxy.Transport = failingRoundTripper{} + } + rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + if tt.errorHandler != nil { + rproxy.ErrorHandler = tt.errorHandler + } + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL + "/test") + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + resp.Body.Close() + }) + } +} + // Issue 16659: log errors from short read func TestReverseProxy_CopyBuffer(t *testing.T) { backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {