diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index 725ba0b70a..bc99797b33 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -8,6 +8,8 @@ import ( "bytes" "io/ioutil" "net/http" + "strconv" + "strings" ) // ResponseRecorder is an implementation of http.ResponseWriter that @@ -162,6 +164,7 @@ func (rw *ResponseRecorder) Result() *http.Response { if rw.Body != nil { res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes())) } + res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) if trailers, ok := rw.snapHeader["Trailer"]; ok { res.Trailer = make(http.Header, len(trailers)) @@ -186,3 +189,20 @@ func (rw *ResponseRecorder) Result() *http.Response { } return res } + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +// +// This a modified version of same function found in net/http/transfer.go. This +// one just ignores an invalid header. +func parseContentLength(cl string) int64 { + cl = strings.TrimSpace(cl) + if cl == "" { + return -1 + } + n, err := strconv.ParseInt(cl, 10, 64) + if err != nil { + return -1 + } + return n +} diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index d4e7137913..ff9b9911a8 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -94,6 +94,14 @@ func TestRecorder(t *testing.T) { return nil } } + hasContentLength := func(length int64) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Result().ContentLength; got != length { + return fmt.Errorf("ContentLength = %d; want %d", got, length) + } + return nil + } + } tests := []struct { name string @@ -141,7 +149,7 @@ func TestRecorder(t *testing.T) { w.(http.Flusher).Flush() // also sends a 200 w.WriteHeader(201) }, - check(hasStatus(200), hasFlush(true)), + check(hasStatus(200), hasFlush(true), hasContentLength(-1)), }, { "Content-Type detection", @@ -244,6 +252,16 @@ func TestRecorder(t *testing.T) { hasNotHeaders("X-Bar"), ), }, + { + "setting Content-Length header", + func(w http.ResponseWriter, r *http.Request) { + body := "Some body" + contentLength := fmt.Sprintf("%d", len(body)) + w.Header().Set("Content-Length", contentLength) + io.WriteString(w, body) + }, + check(hasStatus(200), hasContents("Some body"), hasContentLength(9)), + }, } r, _ := http.NewRequest("GET", "http://foo.com/", nil) for _, tt := range tests {