diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index 76f23bcf9a9..f18dd886cc1 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -260,12 +260,40 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { if p.BufferPool != nil { buf = p.BufferPool.Get() } - io.CopyBuffer(dst, src, buf) + p.copyBuffer(dst, src, buf) if p.BufferPool != nil { p.BufferPool.Put(buf) } } +func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF { + p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + return written, rerr + } + } +} + func (p *ReverseProxy) logf(format string, args ...interface{}) { if p.ErrorLog != nil { p.ErrorLog.Printf(format, args...) diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index 8b5bd797a7e..b3270a1a63f 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -10,6 +10,7 @@ import ( "bufio" "bytes" "errors" + "fmt" "io" "io/ioutil" "log" @@ -581,3 +582,44 @@ func TestReverseProxy_NilBody(t *testing.T) { t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) } } + +// Issue 16659: log errors from short read +func TestReverseProxy_CopyBuffer(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.UnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + var proxyLog bytes.Buffer + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL) + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + defer resp.Body.Close() + + if _, err := ioutil.ReadAll(resp.Body); err == nil { + t.Fatalf("want non-nil error") + } + expected := []string{ + "EOF", + "read", + } + for _, phrase := range expected { + if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { + t.Errorf("expected log to contain phrase %q", phrase) + } + } +}