1
0
mirror of https://github.com/golang/go synced 2024-11-05 15:26:15 -07:00

net/http/httptest: fill ContentLength in recorded Response

This change fills the ContentLength field in the http.Response returned by
ResponseRecorder.Result.

Fixes #16952.

Change-Id: I9c49b1bf83e3719b5275b03a43aff5033156637d
Reviewed-on: https://go-review.googlesource.com/28302
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Thomas de Zeeuw 2016-09-01 14:54:08 +02:00 committed by Brad Fitzpatrick
parent e69d63e807
commit ea143c2990
2 changed files with 39 additions and 1 deletions

View File

@ -8,6 +8,8 @@ import (
"bytes" "bytes"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"strconv"
"strings"
) )
// ResponseRecorder is an implementation of http.ResponseWriter that // ResponseRecorder is an implementation of http.ResponseWriter that
@ -162,6 +164,7 @@ func (rw *ResponseRecorder) Result() *http.Response {
if rw.Body != nil { if rw.Body != nil {
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes())) res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
} }
res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
if trailers, ok := rw.snapHeader["Trailer"]; ok { if trailers, ok := rw.snapHeader["Trailer"]; ok {
res.Trailer = make(http.Header, len(trailers)) res.Trailer = make(http.Header, len(trailers))
@ -186,3 +189,20 @@ func (rw *ResponseRecorder) Result() *http.Response {
} }
return res return res
} }
// parseContentLength trims whitespace from s and returns -1 if no value
// is set, or the value if it's >= 0.
//
// This a modified version of same function found in net/http/transfer.go. This
// one just ignores an invalid header.
func parseContentLength(cl string) int64 {
cl = strings.TrimSpace(cl)
if cl == "" {
return -1
}
n, err := strconv.ParseInt(cl, 10, 64)
if err != nil {
return -1
}
return n
}

View File

@ -94,6 +94,14 @@ func TestRecorder(t *testing.T) {
return nil return nil
} }
} }
hasContentLength := func(length int64) checkFunc {
return func(rec *ResponseRecorder) error {
if got := rec.Result().ContentLength; got != length {
return fmt.Errorf("ContentLength = %d; want %d", got, length)
}
return nil
}
}
tests := []struct { tests := []struct {
name string name string
@ -141,7 +149,7 @@ func TestRecorder(t *testing.T) {
w.(http.Flusher).Flush() // also sends a 200 w.(http.Flusher).Flush() // also sends a 200
w.WriteHeader(201) w.WriteHeader(201)
}, },
check(hasStatus(200), hasFlush(true)), check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
}, },
{ {
"Content-Type detection", "Content-Type detection",
@ -244,6 +252,16 @@ func TestRecorder(t *testing.T) {
hasNotHeaders("X-Bar"), hasNotHeaders("X-Bar"),
), ),
}, },
{
"setting Content-Length header",
func(w http.ResponseWriter, r *http.Request) {
body := "Some body"
contentLength := fmt.Sprintf("%d", len(body))
w.Header().Set("Content-Length", contentLength)
io.WriteString(w, body)
},
check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
},
} }
r, _ := http.NewRequest("GET", "http://foo.com/", nil) r, _ := http.NewRequest("GET", "http://foo.com/", nil)
for _, tt := range tests { for _, tt := range tests {