mirror of
https://github.com/golang/go
synced 2024-11-12 04:30:22 -07:00
net/http/httptest: restore historic ResponseRecorder.HeaderMap behavior
In Go versions 1 up to and including Go 1.6, ResponseRecorder.HeaderMap was both the map that handlers got access to, and was the map tests checked their results against. That did not mimic the behavior of the real HTTP server (Issue #8857), so HeaderMap was changed to be a snapshot at the first write in https://golang.org/cl/20047. But that broke cases where the Handler never did a write (#15560), so revert the behavior. Instead, introduce the ResponseWriter.Result method, returning an *http.Response. It subsumes ResponseWriter.Trailers which was added for Go 1.7 in CL 20047. Result().Header now contains the correct answer, and HeaderMap is unchanged in behavior from previous Go releases, so we don't break people's tests. People wanting the correct behavior can use ResponseWriter.Result. Fixes #15560 Updates #8857 Change-Id: I7ea9b56a6b843103784553d67f67847b5315b3d2 Reviewed-on: https://go-review.googlesource.com/23257 Reviewed-by: Damien Neil <dneil@google.com> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
3b50adbc4f
commit
0b80659832
@ -6,6 +6,7 @@ package httptest
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
@ -17,9 +18,8 @@ type ResponseRecorder struct {
|
||||
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
|
||||
Flushed bool
|
||||
|
||||
stagingMap http.Header // map that handlers manipulate to set headers
|
||||
trailerMap http.Header // lazily filled when Trailers() is called
|
||||
|
||||
result *http.Response // cache of Result's return value
|
||||
snapHeader http.Header // snapshot of HeaderMap at first Write
|
||||
wroteHeader bool
|
||||
}
|
||||
|
||||
@ -38,10 +38,10 @@ const DefaultRemoteAddr = "1.2.3.4"
|
||||
|
||||
// Header returns the response headers.
|
||||
func (rw *ResponseRecorder) Header() http.Header {
|
||||
m := rw.stagingMap
|
||||
m := rw.HeaderMap
|
||||
if m == nil {
|
||||
m = make(http.Header)
|
||||
rw.stagingMap = m
|
||||
rw.HeaderMap = m
|
||||
}
|
||||
return m
|
||||
}
|
||||
@ -104,11 +104,17 @@ func (rw *ResponseRecorder) WriteHeader(code int) {
|
||||
if rw.HeaderMap == nil {
|
||||
rw.HeaderMap = make(http.Header)
|
||||
}
|
||||
for k, vv := range rw.stagingMap {
|
||||
rw.snapHeader = cloneHeader(rw.HeaderMap)
|
||||
}
|
||||
|
||||
func cloneHeader(h http.Header) http.Header {
|
||||
h2 := make(http.Header, len(h))
|
||||
for k, vv := range h {
|
||||
vv2 := make([]string, len(vv))
|
||||
copy(vv2, vv)
|
||||
rw.HeaderMap[k] = vv2
|
||||
h2[k] = vv2
|
||||
}
|
||||
return h2
|
||||
}
|
||||
|
||||
// Flush sets rw.Flushed to true.
|
||||
@ -119,32 +125,61 @@ func (rw *ResponseRecorder) Flush() {
|
||||
rw.Flushed = true
|
||||
}
|
||||
|
||||
// Trailers returns any trailers set by the handler. It must be called
|
||||
// after the handler finished running.
|
||||
func (rw *ResponseRecorder) Trailers() http.Header {
|
||||
if rw.trailerMap != nil {
|
||||
return rw.trailerMap
|
||||
// Result returns the response generated by the handler.
|
||||
//
|
||||
// The returned Response will have at least its StatusCode,
|
||||
// Header, Body, and optionally Trailer populated.
|
||||
// More fields may be populated in the future, so callers should
|
||||
// not DeepEqual the result in tests.
|
||||
//
|
||||
// The Response.Header is a snapshot of the headers at the time of the
|
||||
// first write call, or at the time of this call, if the handler never
|
||||
// did a write.
|
||||
//
|
||||
// Result must only be called after the handler has finished running.
|
||||
func (rw *ResponseRecorder) Result() *http.Response {
|
||||
if rw.result != nil {
|
||||
return rw.result
|
||||
}
|
||||
trailers, ok := rw.HeaderMap["Trailer"]
|
||||
if !ok {
|
||||
rw.trailerMap = make(http.Header)
|
||||
return rw.trailerMap
|
||||
if rw.snapHeader == nil {
|
||||
rw.snapHeader = cloneHeader(rw.HeaderMap)
|
||||
}
|
||||
rw.trailerMap = make(http.Header, len(trailers))
|
||||
for _, k := range trailers {
|
||||
switch k {
|
||||
case "Transfer-Encoding", "Content-Length", "Trailer":
|
||||
// Ignore since forbidden by RFC 2616 14.40.
|
||||
continue
|
||||
res := &http.Response{
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
StatusCode: rw.Code,
|
||||
Header: rw.snapHeader,
|
||||
}
|
||||
rw.result = res
|
||||
if res.StatusCode == 0 {
|
||||
res.StatusCode = 200
|
||||
}
|
||||
res.Status = http.StatusText(res.StatusCode)
|
||||
if rw.Body != nil {
|
||||
res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes()))
|
||||
}
|
||||
|
||||
if trailers, ok := rw.snapHeader["Trailer"]; ok {
|
||||
res.Trailer = make(http.Header, len(trailers))
|
||||
for _, k := range trailers {
|
||||
// TODO: use http2.ValidTrailerHeader, but we can't
|
||||
// get at it easily because it's bundled into net/http
|
||||
// unexported. This is good enough for now:
|
||||
switch k {
|
||||
case "Transfer-Encoding", "Content-Length", "Trailer":
|
||||
// Ignore since forbidden by RFC 2616 14.40.
|
||||
continue
|
||||
}
|
||||
k = http.CanonicalHeaderKey(k)
|
||||
vv, ok := rw.HeaderMap[k]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
vv2 := make([]string, len(vv))
|
||||
copy(vv2, vv)
|
||||
res.Trailer[k] = vv2
|
||||
}
|
||||
k = http.CanonicalHeaderKey(k)
|
||||
vv, ok := rw.stagingMap[k]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
vv2 := make([]string, len(vv))
|
||||
copy(vv2, vv)
|
||||
rw.trailerMap[k] = vv2
|
||||
}
|
||||
return rw.trailerMap
|
||||
return res
|
||||
}
|
||||
|
@ -23,6 +23,14 @@ func TestRecorder(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
hasResultStatus := func(wantCode int) checkFunc {
|
||||
return func(rec *ResponseRecorder) error {
|
||||
if rec.Result().StatusCode != wantCode {
|
||||
return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
hasContents := func(want string) checkFunc {
|
||||
return func(rec *ResponseRecorder) error {
|
||||
if rec.Body.String() != want {
|
||||
@ -39,10 +47,18 @@ func TestRecorder(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
hasHeader := func(key, want string) checkFunc {
|
||||
hasOldHeader := func(key, want string) checkFunc {
|
||||
return func(rec *ResponseRecorder) error {
|
||||
if got := rec.HeaderMap.Get(key); got != want {
|
||||
return fmt.Errorf("header %s = %q; want %q", key, got, want)
|
||||
return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
hasHeader := func(key, want string) checkFunc {
|
||||
return func(rec *ResponseRecorder) error {
|
||||
if got := rec.Result().Header.Get(key); got != want {
|
||||
return fmt.Errorf("final header %s = %q; want %q", key, got, want)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -50,9 +66,9 @@ func TestRecorder(t *testing.T) {
|
||||
hasNotHeaders := func(keys ...string) checkFunc {
|
||||
return func(rec *ResponseRecorder) error {
|
||||
for _, k := range keys {
|
||||
_, ok := rec.HeaderMap[http.CanonicalHeaderKey(k)]
|
||||
v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)]
|
||||
if ok {
|
||||
return fmt.Errorf("unexpected header %s", k)
|
||||
return fmt.Errorf("unexpected header %s with value %q", k, v)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@ -60,7 +76,7 @@ func TestRecorder(t *testing.T) {
|
||||
}
|
||||
hasTrailer := func(key, want string) checkFunc {
|
||||
return func(rec *ResponseRecorder) error {
|
||||
if got := rec.Trailers().Get(key); got != want {
|
||||
if got := rec.Result().Trailer.Get(key); got != want {
|
||||
return fmt.Errorf("trailer %s = %q; want %q", key, got, want)
|
||||
}
|
||||
return nil
|
||||
@ -68,7 +84,7 @@ func TestRecorder(t *testing.T) {
|
||||
}
|
||||
hasNotTrailers := func(keys ...string) checkFunc {
|
||||
return func(rec *ResponseRecorder) error {
|
||||
trailers := rec.Trailers()
|
||||
trailers := rec.Result().Trailer
|
||||
for _, k := range keys {
|
||||
_, ok := trailers[http.CanonicalHeaderKey(k)]
|
||||
if ok {
|
||||
@ -194,6 +210,40 @@ func TestRecorder(t *testing.T) {
|
||||
hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"),
|
||||
),
|
||||
},
|
||||
{
|
||||
"Header set without any write", // Issue 15560
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Foo", "1")
|
||||
|
||||
// Simulate somebody using
|
||||
// new(ResponseRecorder) instead of
|
||||
// using the constructor which sets
|
||||
// this to 200
|
||||
w.(*ResponseRecorder).Code = 0
|
||||
},
|
||||
check(
|
||||
hasOldHeader("X-Foo", "1"),
|
||||
hasStatus(0),
|
||||
hasHeader("X-Foo", "1"),
|
||||
hasResultStatus(200),
|
||||
),
|
||||
},
|
||||
{
|
||||
"HeaderMap vs FinalHeaders", // more for Issue 15560
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
h := w.Header()
|
||||
h.Set("X-Foo", "1")
|
||||
w.Write([]byte("hi"))
|
||||
h.Set("X-Foo", "2")
|
||||
h.Set("X-Bar", "2")
|
||||
},
|
||||
check(
|
||||
hasOldHeader("X-Foo", "2"),
|
||||
hasOldHeader("X-Bar", "2"),
|
||||
hasHeader("X-Foo", "1"),
|
||||
hasNotHeaders("X-Bar"),
|
||||
),
|
||||
},
|
||||
}
|
||||
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
|
||||
for _, tt := range tests {
|
||||
|
Loading…
Reference in New Issue
Block a user