From c69e6869c9793872cb0282008ea8ab643a92da65 Mon Sep 17 00:00:00 2001 From: Caio Marcelo de Oliveira Filho Date: Mon, 29 Feb 2016 17:46:48 -0300 Subject: [PATCH] net/http/httptest: record trailing headers in ResponseRecorder Trailers() returns the headers that were set by the handler after the headers were written "to the wire" (in this case HeaderMap) and that were also specified in a proper header called "Trailer". Neither HeaderMap or trailerMap (used for Trailers()) are manipulated by the handler code, instead a third stagingMap is given to the handler. This avoid a reference kept by handler to affect the recorded results. If a handler just modify the header but doesn't call any Write or Flush method from ResponseWriter (or Flusher) interface, HeaderMap will not be updated. In this case, calling Flush in the recorder is enough to get the HeaderMap filled. Fixes #14531. Fixes #8857. Change-Id: I42842341ec3e95c7b87d7e6f178c65cd03d63cc3 Reviewed-on: https://go-review.googlesource.com/20047 Reviewed-by: Brad Fitzpatrick --- src/net/http/httptest/recorder.go | 66 +++++++++++++++++++++----- src/net/http/httptest/recorder_test.go | 64 +++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 12 deletions(-) diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index 7c51af1867..4e3948dd91 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -18,6 +18,9 @@ type ResponseRecorder struct { Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to Flushed bool + stagingMap http.Header // map that handlers manipulate to set headers + trailerMap http.Header // lazily filled when Trailers() is called + wroteHeader bool } @@ -36,10 +39,10 @@ const DefaultRemoteAddr = "1.2.3.4" // Header returns the response headers. func (rw *ResponseRecorder) Header() http.Header { - m := rw.HeaderMap + m := rw.stagingMap if m == nil { m = make(http.Header) - rw.HeaderMap = m + rw.stagingMap = m } return m } @@ -59,16 +62,15 @@ func (rw *ResponseRecorder) writeHeader(b []byte, str string) { str = str[:512] } - _, hasType := rw.HeaderMap["Content-Type"] - hasTE := rw.HeaderMap.Get("Transfer-Encoding") != "" + m := rw.Header() + + _, hasType := m["Content-Type"] + hasTE := m.Get("Transfer-Encoding") != "" if !hasType && !hasTE { if b == nil { b = []byte(str) } - if rw.HeaderMap == nil { - rw.HeaderMap = make(http.Header) - } - rw.HeaderMap.Set("Content-Type", http.DetectContentType(b)) + m.Set("Content-Type", http.DetectContentType(b)) } rw.WriteHeader(200) @@ -92,11 +94,21 @@ func (rw *ResponseRecorder) WriteString(str string) (int, error) { return len(str), nil } -// WriteHeader sets rw.Code. +// WriteHeader sets rw.Code. After it is called, changing rw.Header +// will not affect rw.HeaderMap. func (rw *ResponseRecorder) WriteHeader(code int) { - if !rw.wroteHeader { - rw.Code = code - rw.wroteHeader = true + if rw.wroteHeader { + return + } + rw.Code = code + rw.wroteHeader = true + if rw.HeaderMap == nil { + rw.HeaderMap = make(http.Header) + } + for k, vv := range rw.stagingMap { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + rw.HeaderMap[k] = vv2 } } @@ -107,3 +119,33 @@ func (rw *ResponseRecorder) Flush() { } rw.Flushed = true } + +// Trailers returns any trailers set by the handler. It must be called +// after the handler finished running. +func (rw *ResponseRecorder) Trailers() http.Header { + if rw.trailerMap != nil { + return rw.trailerMap + } + trailers, ok := rw.HeaderMap["Trailer"] + if !ok { + rw.trailerMap = make(http.Header) + return rw.trailerMap + } + rw.trailerMap = make(http.Header, len(trailers)) + for _, k := range trailers { + switch k { + case "Transfer-Encoding", "Content-Length", "Trailer": + // Ignore since forbidden by RFC 2616 14.40. + continue + } + k = http.CanonicalHeaderKey(k) + vv, ok := rw.stagingMap[k] + if !ok { + continue + } + vv2 := make([]string, len(vv)) + copy(vv2, vv) + rw.trailerMap[k] = vv2 + } + return rw.trailerMap +} diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index c29b6d4cf9..19a37b6c54 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -47,6 +47,37 @@ func TestRecorder(t *testing.T) { return nil } } + hasNotHeaders := func(keys ...string) checkFunc { + return func(rec *ResponseRecorder) error { + for _, k := range keys { + _, ok := rec.HeaderMap[http.CanonicalHeaderKey(k)] + if ok { + return fmt.Errorf("unexpected header %s", k) + } + } + return nil + } + } + hasTrailer := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Trailers().Get(key); got != want { + return fmt.Errorf("trailer %s = %q; want %q", key, got, want) + } + return nil + } + } + hasNotTrailers := func(keys ...string) checkFunc { + return func(rec *ResponseRecorder) error { + trailers := rec.Trailers() + for _, k := range keys { + _, ok := trailers[http.CanonicalHeaderKey(k)] + if ok { + return fmt.Errorf("unexpected trailer %s", k) + } + } + return nil + } + } tests := []struct { name string @@ -130,6 +161,39 @@ func TestRecorder(t *testing.T) { }, check(hasHeader("Content-Type", "text/html; charset=utf-8")), }, + { + "Header is not changed after write", + func(w http.ResponseWriter, r *http.Request) { + hdr := w.Header() + hdr.Set("Key", "correct") + w.WriteHeader(200) + hdr.Set("Key", "incorrect") + }, + check(hasHeader("Key", "correct")), + }, + { + "Trailer headers are correctly recorded", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Non-Trailer", "correct") + w.Header().Set("Trailer", "Trailer-A") + w.Header().Add("Trailer", "Trailer-B") + w.Header().Add("Trailer", "Trailer-C") + io.WriteString(w, "") + w.Header().Set("Non-Trailer", "incorrect") + w.Header().Set("Trailer-A", "valuea") + w.Header().Set("Trailer-C", "valuec") + w.Header().Set("Trailer-NotDeclared", "should be omitted") + }, + check( + hasStatus(200), + hasHeader("Content-Type", "text/html; charset=utf-8"), + hasHeader("Non-Trailer", "correct"), + hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"), + hasTrailer("Trailer-A", "valuea"), + hasTrailer("Trailer-C", "valuec"), + hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"), + ), + }, } r, _ := http.NewRequest("GET", "http://foo.com/", nil) for _, tt := range tests {