1
0
mirror of https://github.com/golang/go synced 2024-11-12 04:30:22 -07:00

net/http/httptest: restore historic ResponseRecorder.HeaderMap behavior

In Go versions 1 up to and including Go 1.6,
ResponseRecorder.HeaderMap was both the map that handlers got access
to, and was the map tests checked their results against. That did not
mimic the behavior of the real HTTP server (Issue #8857), so HeaderMap
was changed to be a snapshot at the first write in
https://golang.org/cl/20047. But that broke cases where the Handler
never did a write (#15560), so revert the behavior.

Instead, introduce the ResponseWriter.Result method, returning an
*http.Response. It subsumes ResponseWriter.Trailers which was added
for Go 1.7 in CL 20047. Result().Header now contains the correct
answer, and HeaderMap is unchanged in behavior from previous Go
releases, so we don't break people's tests. People wanting the correct
behavior can use ResponseWriter.Result.

Fixes #15560
Updates #8857

Change-Id: I7ea9b56a6b843103784553d67f67847b5315b3d2
Reviewed-on: https://go-review.googlesource.com/23257
Reviewed-by: Damien Neil <dneil@google.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Brad Fitzpatrick 2016-05-19 18:05:10 +00:00
parent 3b50adbc4f
commit 0b80659832
2 changed files with 122 additions and 37 deletions

View File

@ -6,6 +6,7 @@ package httptest
import (
"bytes"
"io/ioutil"
"net/http"
)
@ -17,9 +18,8 @@ type ResponseRecorder struct {
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
Flushed bool
stagingMap http.Header // map that handlers manipulate to set headers
trailerMap http.Header // lazily filled when Trailers() is called
result *http.Response // cache of Result's return value
snapHeader http.Header // snapshot of HeaderMap at first Write
wroteHeader bool
}
@ -38,10 +38,10 @@ const DefaultRemoteAddr = "1.2.3.4"
// Header returns the response headers.
func (rw *ResponseRecorder) Header() http.Header {
m := rw.stagingMap
m := rw.HeaderMap
if m == nil {
m = make(http.Header)
rw.stagingMap = m
rw.HeaderMap = m
}
return m
}
@ -104,11 +104,17 @@ func (rw *ResponseRecorder) WriteHeader(code int) {
if rw.HeaderMap == nil {
rw.HeaderMap = make(http.Header)
}
for k, vv := range rw.stagingMap {
rw.snapHeader = cloneHeader(rw.HeaderMap)
}
func cloneHeader(h http.Header) http.Header {
h2 := make(http.Header, len(h))
for k, vv := range h {
vv2 := make([]string, len(vv))
copy(vv2, vv)
rw.HeaderMap[k] = vv2
h2[k] = vv2
}
return h2
}
// Flush sets rw.Flushed to true.
@ -119,32 +125,61 @@ func (rw *ResponseRecorder) Flush() {
rw.Flushed = true
}
// Trailers returns any trailers set by the handler. It must be called
// after the handler finished running.
func (rw *ResponseRecorder) Trailers() http.Header {
if rw.trailerMap != nil {
return rw.trailerMap
// Result returns the response generated by the handler.
//
// The returned Response will have at least its StatusCode,
// Header, Body, and optionally Trailer populated.
// More fields may be populated in the future, so callers should
// not DeepEqual the result in tests.
//
// The Response.Header is a snapshot of the headers at the time of the
// first write call, or at the time of this call, if the handler never
// did a write.
//
// Result must only be called after the handler has finished running.
func (rw *ResponseRecorder) Result() *http.Response {
if rw.result != nil {
return rw.result
}
trailers, ok := rw.HeaderMap["Trailer"]
if !ok {
rw.trailerMap = make(http.Header)
return rw.trailerMap
if rw.snapHeader == nil {
rw.snapHeader = cloneHeader(rw.HeaderMap)
}
rw.trailerMap = make(http.Header, len(trailers))
for _, k := range trailers {
switch k {
case "Transfer-Encoding", "Content-Length", "Trailer":
// Ignore since forbidden by RFC 2616 14.40.
continue
res := &http.Response{
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
StatusCode: rw.Code,
Header: rw.snapHeader,
}
rw.result = res
if res.StatusCode == 0 {
res.StatusCode = 200
}
res.Status = http.StatusText(res.StatusCode)
if rw.Body != nil {
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
}
if trailers, ok := rw.snapHeader["Trailer"]; ok {
res.Trailer = make(http.Header, len(trailers))
for _, k := range trailers {
// TODO: use http2.ValidTrailerHeader, but we can't
// get at it easily because it's bundled into net/http
// unexported. This is good enough for now:
switch k {
case "Transfer-Encoding", "Content-Length", "Trailer":
// Ignore since forbidden by RFC 2616 14.40.
continue
}
k = http.CanonicalHeaderKey(k)
vv, ok := rw.HeaderMap[k]
if !ok {
continue
}
vv2 := make([]string, len(vv))
copy(vv2, vv)
res.Trailer[k] = vv2
}
k = http.CanonicalHeaderKey(k)
vv, ok := rw.stagingMap[k]
if !ok {
continue
}
vv2 := make([]string, len(vv))
copy(vv2, vv)
rw.trailerMap[k] = vv2
}
return rw.trailerMap
return res
}

View File

@ -23,6 +23,14 @@ func TestRecorder(t *testing.T) {
return nil
}
}
hasResultStatus := func(wantCode int) checkFunc {
return func(rec *ResponseRecorder) error {
if rec.Result().StatusCode != wantCode {
return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
}
return nil
}
}
hasContents := func(want string) checkFunc {
return func(rec *ResponseRecorder) error {
if rec.Body.String() != want {
@ -39,10 +47,18 @@ func TestRecorder(t *testing.T) {
return nil
}
}
hasHeader := func(key, want string) checkFunc {
hasOldHeader := func(key, want string) checkFunc {
return func(rec *ResponseRecorder) error {
if got := rec.HeaderMap.Get(key); got != want {
return fmt.Errorf("header %s = %q; want %q", key, got, want)
return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
}
return nil
}
}
hasHeader := func(key, want string) checkFunc {
return func(rec *ResponseRecorder) error {
if got := rec.Result().Header.Get(key); got != want {
return fmt.Errorf("final header %s = %q; want %q", key, got, want)
}
return nil
}
@ -50,9 +66,9 @@ func TestRecorder(t *testing.T) {
hasNotHeaders := func(keys ...string) checkFunc {
return func(rec *ResponseRecorder) error {
for _, k := range keys {
_, ok := rec.HeaderMap[http.CanonicalHeaderKey(k)]
v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
if ok {
return fmt.Errorf("unexpected header %s", k)
return fmt.Errorf("unexpected header %s with value %q", k, v)
}
}
return nil
@ -60,7 +76,7 @@ func TestRecorder(t *testing.T) {
}
hasTrailer := func(key, want string) checkFunc {
return func(rec *ResponseRecorder) error {
if got := rec.Trailers().Get(key); got != want {
if got := rec.Result().Trailer.Get(key); got != want {
return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
}
return nil
@ -68,7 +84,7 @@ func TestRecorder(t *testing.T) {
}
hasNotTrailers := func(keys ...string) checkFunc {
return func(rec *ResponseRecorder) error {
trailers := rec.Trailers()
trailers := rec.Result().Trailer
for _, k := range keys {
_, ok := trailers[http.CanonicalHeaderKey(k)]
if ok {
@ -194,6 +210,40 @@ func TestRecorder(t *testing.T) {
hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
),
},
{
"Header set without any write", // Issue 15560
func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Foo", "1")
// Simulate somebody using
// new(ResponseRecorder) instead of
// using the constructor which sets
// this to 200
w.(*ResponseRecorder).Code = 0
},
check(
hasOldHeader("X-Foo", "1"),
hasStatus(0),
hasHeader("X-Foo", "1"),
hasResultStatus(200),
),
},
{
"HeaderMap vs FinalHeaders", // more for Issue 15560
func(w http.ResponseWriter, r *http.Request) {
h := w.Header()
h.Set("X-Foo", "1")
w.Write([]byte("hi"))
h.Set("X-Foo", "2")
h.Set("X-Bar", "2")
},
check(
hasOldHeader("X-Foo", "2"),
hasOldHeader("X-Bar", "2"),
hasHeader("X-Foo", "1"),
hasNotHeaders("X-Bar"),
),
},
}
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
for _, tt := range tests {