1
0
mirror of https://github.com/golang/go synced 2024-11-05 15:26:15 -07:00

net/http/httptest: fill ContentLength in recorded Response

This change fills the ContentLength field in the http.Response returned by
ResponseRecorder.Result.

Fixes #16952.

Change-Id: I9c49b1bf83e3719b5275b03a43aff5033156637d
Reviewed-on: https://go-review.googlesource.com/28302
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Thomas de Zeeuw 2016-09-01 14:54:08 +02:00 committed by Brad Fitzpatrick
parent e69d63e807
commit ea143c2990
2 changed files with 39 additions and 1 deletions

View File

@ -8,6 +8,8 @@ import (
"bytes"
"io/ioutil"
"net/http"
"strconv"
"strings"
)
// ResponseRecorder is an implementation of http.ResponseWriter that
@ -162,6 +164,7 @@ func (rw *ResponseRecorder) Result() *http.Response {
if rw.Body != nil {
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
}
res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
if trailers, ok := rw.snapHeader["Trailer"]; ok {
res.Trailer = make(http.Header, len(trailers))
@ -186,3 +189,20 @@ func (rw *ResponseRecorder) Result() *http.Response {
}
return res
}
// parseContentLength trims whitespace from s and returns -1 if no value
// is set, or the value if it's >= 0.
//
// This a modified version of same function found in net/http/transfer.go. This
// one just ignores an invalid header.
func parseContentLength(cl string) int64 {
cl = strings.TrimSpace(cl)
if cl == "" {
return -1
}
n, err := strconv.ParseInt(cl, 10, 64)
if err != nil {
return -1
}
return n
}

View File

@ -94,6 +94,14 @@ func TestRecorder(t *testing.T) {
return nil
}
}
hasContentLength := func(length int64) checkFunc {
return func(rec *ResponseRecorder) error {
if got := rec.Result().ContentLength; got != length {
return fmt.Errorf("ContentLength = %d; want %d", got, length)
}
return nil
}
}
tests := []struct {
name string
@ -141,7 +149,7 @@ func TestRecorder(t *testing.T) {
w.(http.Flusher).Flush() // also sends a 200
w.WriteHeader(201)
},
check(hasStatus(200), hasFlush(true)),
check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
},
{
"Content-Type detection",
@ -244,6 +252,16 @@ func TestRecorder(t *testing.T) {
hasNotHeaders("X-Bar"),
),
},
{
"setting Content-Length header",
func(w http.ResponseWriter, r *http.Request) {
body := "Some body"
contentLength := fmt.Sprintf("%d", len(body))
w.Header().Set("Content-Length", contentLength)
io.WriteString(w, body)
},
check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
},
}
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
for _, tt := range tests {