diff --git a/src/cmd/pprof/internal/fetch/fetch.go b/src/cmd/pprof/internal/fetch/fetch.go index 3ed16bb50d..45e02f2cd7 100644 --- a/src/cmd/pprof/internal/fetch/fetch.go +++ b/src/cmd/pprof/internal/fetch/fetch.go @@ -52,7 +52,8 @@ func FetchURL(source string, timeout time.Duration) (io.ReadCloser, error) { return nil, fmt.Errorf("http fetch: %v", err) } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("server response: %s", resp.Status) + defer resp.Body.Close() + return nil, statusCodeError(resp) } return resp.Body, nil @@ -64,13 +65,24 @@ func PostURL(source, post string) ([]byte, error) { if err != nil { return nil, fmt.Errorf("http post %s: %v", source, err) } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("server response: %s", resp.Status) - } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return nil, statusCodeError(resp) + } return ioutil.ReadAll(resp.Body) } +func statusCodeError(resp *http.Response) error { + if resp.Header.Get("X-Go-Pprof") != "" && strings.Contains(resp.Header.Get("Content-Type"), "text/plain") { + // error is from pprof endpoint + body, err := ioutil.ReadAll(resp.Body) + if err == nil { + return fmt.Errorf("server response: %s - %s", resp.Status, body) + } + } + return fmt.Errorf("server response: %s", resp.Status) +} + // httpGet is a wrapper around http.Get; it is defined as a variable // so it can be redefined during for testing. var httpGet = func(source string, timeout time.Duration) (*http.Response, error) { diff --git a/src/net/http/pprof/pprof.go b/src/net/http/pprof/pprof.go index 126e9eaaa7..6930df531b 100644 --- a/src/net/http/pprof/pprof.go +++ b/src/net/http/pprof/pprof.go @@ -90,6 +90,11 @@ func sleep(w http.ResponseWriter, d time.Duration) { } } +func durationExceedsWriteTimeout(r *http.Request, seconds float64) bool { + srv, ok := r.Context().Value(http.ServerContextKey).(*http.Server) + return ok && srv.WriteTimeout != 0 && seconds >= srv.WriteTimeout.Seconds() +} + // Profile responds with the pprof-formatted cpu profile. // The package initialization registers it as /debug/pprof/profile. func Profile(w http.ResponseWriter, r *http.Request) { @@ -98,6 +103,14 @@ func Profile(w http.ResponseWriter, r *http.Request) { sec = 30 } + if durationExceedsWriteTimeout(r, float64(sec)) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintln(w, "profile duration exceeds server's WriteTimeout") + return + } + // Set Content Type assuming StartCPUProfile will work, // because if it does it starts writing. w.Header().Set("Content-Type", "application/octet-stream") @@ -106,6 +119,7 @@ func Profile(w http.ResponseWriter, r *http.Request) { // Can change header back to text content // and send error code. w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") w.WriteHeader(http.StatusInternalServerError) fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err) return @@ -123,6 +137,14 @@ func Trace(w http.ResponseWriter, r *http.Request) { sec = 1 } + if durationExceedsWriteTimeout(r, sec) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") + w.WriteHeader(http.StatusBadRequest) + fmt.Fprintln(w, "profile duration exceeds server's WriteTimeout") + return + } + // Set Content Type assuming trace.Start will work, // because if it does it starts writing. w.Header().Set("Content-Type", "application/octet-stream") @@ -130,6 +152,7 @@ func Trace(w http.ResponseWriter, r *http.Request) { // trace.Start failed, so no writes yet. // Can change header back to text content and send error code. w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.Header().Set("X-Go-Pprof", "1") w.WriteHeader(http.StatusInternalServerError) fmt.Fprintf(w, "Could not enable tracing: %s\n", err) return