mirror of
https://github.com/golang/go
synced 2024-11-07 17:56:21 -07:00
http/http/httputil: add ReverseProxy.ErrorHandler
This permits specifying an ErrorHandler to customize the RoundTrip error handling if the backend fails to return a response. Fixes #22700 Fixes #21255 Change-Id: I8879f0956e2472a07f584660afa10105ef23bf11 Reviewed-on: https://go-review.googlesource.com/77410 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
86a0e67a03
commit
5201b1ad22
@ -55,10 +55,23 @@ type ReverseProxy struct {
|
|||||||
// copying HTTP response bodies.
|
// copying HTTP response bodies.
|
||||||
BufferPool BufferPool
|
BufferPool BufferPool
|
||||||
|
|
||||||
// ModifyResponse is an optional function that
|
// ModifyResponse is an optional function that modifies the
|
||||||
// modifies the Response from the backend.
|
// Response from the backend. It is called if the backend
|
||||||
// If it returns an error, the proxy returns a StatusBadGateway error.
|
// returns a response at all, with any HTTP status code.
|
||||||
|
// If the backend is unreachable, the optional ErrorHandler is
|
||||||
|
// called without any call to ModifyResponse.
|
||||||
|
//
|
||||||
|
// If ModifyResponse returns an error, ErrorHandler is called
|
||||||
|
// with its error value. If ErrorHandler is nil, its default
|
||||||
|
// implementation is used.
|
||||||
ModifyResponse func(*http.Response) error
|
ModifyResponse func(*http.Response) error
|
||||||
|
|
||||||
|
// ErrorHandler is an optional function that handles errors
|
||||||
|
// reaching the backend or errors from ModifyResponse.
|
||||||
|
//
|
||||||
|
// If nil, the default is to log the provided error and return
|
||||||
|
// a 502 Status Bad Gateway response.
|
||||||
|
ErrorHandler func(http.ResponseWriter, *http.Request, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// A BufferPool is an interface for getting and returning temporary
|
// A BufferPool is an interface for getting and returning temporary
|
||||||
@ -141,6 +154,18 @@ var hopHeaders = []string{
|
|||||||
"Upgrade",
|
"Upgrade",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *ReverseProxy) defaultErrorHandler(rw http.ResponseWriter, req *http.Request, err error) {
|
||||||
|
p.logf("http: proxy error: %v", err)
|
||||||
|
rw.WriteHeader(http.StatusBadGateway)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *ReverseProxy) getErrorHandler() func(http.ResponseWriter, *http.Request, error) {
|
||||||
|
if p.ErrorHandler != nil {
|
||||||
|
return p.ErrorHandler
|
||||||
|
}
|
||||||
|
return p.defaultErrorHandler
|
||||||
|
}
|
||||||
|
|
||||||
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
||||||
transport := p.Transport
|
transport := p.Transport
|
||||||
if transport == nil {
|
if transport == nil {
|
||||||
@ -206,8 +231,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
res, err := transport.RoundTrip(outreq)
|
res, err := transport.RoundTrip(outreq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
p.logf("http: proxy error: %v", err)
|
p.getErrorHandler()(rw, outreq, err)
|
||||||
rw.WriteHeader(http.StatusBadGateway)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,9 +243,8 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
|
|||||||
|
|
||||||
if p.ModifyResponse != nil {
|
if p.ModifyResponse != nil {
|
||||||
if err := p.ModifyResponse(res); err != nil {
|
if err := p.ModifyResponse(res); err != nil {
|
||||||
p.logf("http: proxy error: %v", err)
|
|
||||||
rw.WriteHeader(http.StatusBadGateway)
|
|
||||||
res.Body.Close()
|
res.Body.Close()
|
||||||
|
p.getErrorHandler()(rw, outreq, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -637,6 +637,93 @@ func TestReverseProxyModifyResponse(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type failingRoundTripper struct{}
|
||||||
|
|
||||||
|
func (failingRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
|
||||||
|
return nil, errors.New("some error")
|
||||||
|
}
|
||||||
|
|
||||||
|
type staticResponseRoundTripper struct{ res *http.Response }
|
||||||
|
|
||||||
|
func (rt staticResponseRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
|
||||||
|
return rt.res, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestReverseProxyErrorHandler(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
wantCode int
|
||||||
|
errorHandler func(http.ResponseWriter, *http.Request, error)
|
||||||
|
transport http.RoundTripper // defaults to failingRoundTripper
|
||||||
|
modifyResponse func(*http.Response) error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "default",
|
||||||
|
wantCode: http.StatusBadGateway,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "errorhandler",
|
||||||
|
wantCode: http.StatusTeapot,
|
||||||
|
errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modifyresponse_noerr",
|
||||||
|
transport: staticResponseRoundTripper{
|
||||||
|
&http.Response{StatusCode: 345, Body: http.NoBody},
|
||||||
|
},
|
||||||
|
modifyResponse: func(res *http.Response) error {
|
||||||
|
res.StatusCode++
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
|
||||||
|
wantCode: 346,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "modifyresponse_err",
|
||||||
|
transport: staticResponseRoundTripper{
|
||||||
|
&http.Response{StatusCode: 345, Body: http.NoBody},
|
||||||
|
},
|
||||||
|
modifyResponse: func(res *http.Response) error {
|
||||||
|
res.StatusCode++
|
||||||
|
return errors.New("some error to trigger errorHandler")
|
||||||
|
},
|
||||||
|
errorHandler: func(rw http.ResponseWriter, req *http.Request, err error) { rw.WriteHeader(http.StatusTeapot) },
|
||||||
|
wantCode: http.StatusTeapot,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
target := &url.URL{
|
||||||
|
Scheme: "http",
|
||||||
|
Host: "dummy.tld",
|
||||||
|
Path: "/",
|
||||||
|
}
|
||||||
|
rproxy := NewSingleHostReverseProxy(target)
|
||||||
|
rproxy.Transport = tt.transport
|
||||||
|
rproxy.ModifyResponse = tt.modifyResponse
|
||||||
|
if rproxy.Transport == nil {
|
||||||
|
rproxy.Transport = failingRoundTripper{}
|
||||||
|
}
|
||||||
|
rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests
|
||||||
|
if tt.errorHandler != nil {
|
||||||
|
rproxy.ErrorHandler = tt.errorHandler
|
||||||
|
}
|
||||||
|
frontendProxy := httptest.NewServer(rproxy)
|
||||||
|
defer frontendProxy.Close()
|
||||||
|
|
||||||
|
resp, err := http.Get(frontendProxy.URL + "/test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to reach proxy: %v", err)
|
||||||
|
}
|
||||||
|
if g, e := resp.StatusCode, tt.wantCode; g != e {
|
||||||
|
t.Errorf("got res.StatusCode %d; expected %d", g, e)
|
||||||
|
}
|
||||||
|
resp.Body.Close()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Issue 16659: log errors from short read
|
// Issue 16659: log errors from short read
|
||||||
func TestReverseProxy_CopyBuffer(t *testing.T) {
|
func TestReverseProxy_CopyBuffer(t *testing.T) {
|
||||||
backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
Loading…
Reference in New Issue
Block a user