mirror of
https://github.com/golang/go
synced 2024-11-18 14:04:45 -07:00
net/http/httptest: fill ContentLength in recorded Response
This change fills the ContentLength field in the http.Response returned by ResponseRecorder.Result. Fixes #16952. Change-Id: I9c49b1bf83e3719b5275b03a43aff5033156637d Reviewed-on: https://go-review.googlesource.com/28302 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
e69d63e807
commit
ea143c2990
@ -8,6 +8,8 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ResponseRecorder is an implementation of http.ResponseWriter that
|
// ResponseRecorder is an implementation of http.ResponseWriter that
|
||||||
@ -162,6 +164,7 @@ func (rw *ResponseRecorder) Result() *http.Response {
|
|||||||
if rw.Body != nil {
|
if rw.Body != nil {
|
||||||
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
|
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
|
||||||
}
|
}
|
||||||
|
res.ContentLength = parseContentLength(res.Header.Get("Content-Length"))
|
||||||
|
|
||||||
if trailers, ok := rw.snapHeader["Trailer"]; ok {
|
if trailers, ok := rw.snapHeader["Trailer"]; ok {
|
||||||
res.Trailer = make(http.Header, len(trailers))
|
res.Trailer = make(http.Header, len(trailers))
|
||||||
@ -186,3 +189,20 @@ func (rw *ResponseRecorder) Result() *http.Response {
|
|||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// parseContentLength trims whitespace from s and returns -1 if no value
|
||||||
|
// is set, or the value if it's >= 0.
|
||||||
|
//
|
||||||
|
// This a modified version of same function found in net/http/transfer.go. This
|
||||||
|
// one just ignores an invalid header.
|
||||||
|
func parseContentLength(cl string) int64 {
|
||||||
|
cl = strings.TrimSpace(cl)
|
||||||
|
if cl == "" {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
n, err := strconv.ParseInt(cl, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return -1
|
||||||
|
}
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
@ -94,6 +94,14 @@ func TestRecorder(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
hasContentLength := func(length int64) checkFunc {
|
||||||
|
return func(rec *ResponseRecorder) error {
|
||||||
|
if got := rec.Result().ContentLength; got != length {
|
||||||
|
return fmt.Errorf("ContentLength = %d; want %d", got, length)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
@ -141,7 +149,7 @@ func TestRecorder(t *testing.T) {
|
|||||||
w.(http.Flusher).Flush() // also sends a 200
|
w.(http.Flusher).Flush() // also sends a 200
|
||||||
w.WriteHeader(201)
|
w.WriteHeader(201)
|
||||||
},
|
},
|
||||||
check(hasStatus(200), hasFlush(true)),
|
check(hasStatus(200), hasFlush(true), hasContentLength(-1)),
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"Content-Type detection",
|
"Content-Type detection",
|
||||||
@ -244,6 +252,16 @@ func TestRecorder(t *testing.T) {
|
|||||||
hasNotHeaders("X-Bar"),
|
hasNotHeaders("X-Bar"),
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"setting Content-Length header",
|
||||||
|
func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body := "Some body"
|
||||||
|
contentLength := fmt.Sprintf("%d", len(body))
|
||||||
|
w.Header().Set("Content-Length", contentLength)
|
||||||
|
io.WriteString(w, body)
|
||||||
|
},
|
||||||
|
check(hasStatus(200), hasContents("Some body"), hasContentLength(9)),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
|
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
Loading…
Reference in New Issue
Block a user