mirror of
https://github.com/golang/go
synced 2024-11-23 04:20:03 -07:00
net/http/httptest: mimic the normal HTTP server's ResponseWriter more closely
Also adds tests, which didn't exist before. Fixes #4188 R=golang-dev, rsc CC=golang-dev https://golang.org/cl/6613062
This commit is contained in:
parent
d9953c9dde
commit
13576e3b65
@ -17,6 +17,8 @@ type ResponseRecorder struct {
|
||||
HeaderMap http.Header // the HTTP response headers
|
||||
Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
|
||||
Flushed bool
|
||||
|
||||
wroteHeader bool
|
||||
}
|
||||
|
||||
// NewRecorder returns an initialized ResponseRecorder.
|
||||
@ -24,6 +26,7 @@ func NewRecorder() *ResponseRecorder {
|
||||
return &ResponseRecorder{
|
||||
HeaderMap: make(http.Header),
|
||||
Body: new(bytes.Buffer),
|
||||
Code: 200,
|
||||
}
|
||||
}
|
||||
|
||||
@ -33,26 +36,37 @@ const DefaultRemoteAddr = "1.2.3.4"
|
||||
|
||||
// Header returns the response headers.
|
||||
func (rw *ResponseRecorder) Header() http.Header {
|
||||
return rw.HeaderMap
|
||||
m := rw.HeaderMap
|
||||
if m == nil {
|
||||
m = make(http.Header)
|
||||
rw.HeaderMap = m
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// Write always succeeds and writes to rw.Body, if not nil.
|
||||
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
|
||||
if !rw.wroteHeader {
|
||||
rw.WriteHeader(200)
|
||||
}
|
||||
if rw.Body != nil {
|
||||
rw.Body.Write(buf)
|
||||
}
|
||||
if rw.Code == 0 {
|
||||
rw.Code = http.StatusOK
|
||||
}
|
||||
return len(buf), nil
|
||||
}
|
||||
|
||||
// WriteHeader sets rw.Code.
|
||||
func (rw *ResponseRecorder) WriteHeader(code int) {
|
||||
if !rw.wroteHeader {
|
||||
rw.Code = code
|
||||
}
|
||||
rw.wroteHeader = true
|
||||
}
|
||||
|
||||
// Flush sets rw.Flushed to true.
|
||||
func (rw *ResponseRecorder) Flush() {
|
||||
if !rw.wroteHeader {
|
||||
rw.WriteHeader(200)
|
||||
}
|
||||
rw.Flushed = true
|
||||
}
|
||||
|
90
src/pkg/net/http/httptest/recorder_test.go
Normal file
90
src/pkg/net/http/httptest/recorder_test.go
Normal file
@ -0,0 +1,90 @@
|
||||
// Copyright 2012 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package httptest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRecorder(t *testing.T) {
|
||||
type checkFunc func(*ResponseRecorder) error
|
||||
check := func(fns ...checkFunc) []checkFunc { return fns }
|
||||
|
||||
hasStatus := func(wantCode int) checkFunc {
|
||||
return func(rec *ResponseRecorder) error {
|
||||
if rec.Code != wantCode {
|
||||
return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
hasContents := func(want string) checkFunc {
|
||||
return func(rec *ResponseRecorder) error {
|
||||
if rec.Body.String() != want {
|
||||
return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
hasFlush := func(want bool) checkFunc {
|
||||
return func(rec *ResponseRecorder) error {
|
||||
if rec.Flushed != want {
|
||||
return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
h func(w http.ResponseWriter, r *http.Request)
|
||||
checks []checkFunc
|
||||
}{
|
||||
{
|
||||
"200 default",
|
||||
func(w http.ResponseWriter, r *http.Request) {},
|
||||
check(hasStatus(200), hasContents("")),
|
||||
},
|
||||
{
|
||||
"first code only",
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(201)
|
||||
w.WriteHeader(202)
|
||||
w.Write([]byte("hi"))
|
||||
},
|
||||
check(hasStatus(201), hasContents("hi")),
|
||||
},
|
||||
{
|
||||
"write sends 200",
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte("hi first"))
|
||||
w.WriteHeader(201)
|
||||
w.WriteHeader(202)
|
||||
},
|
||||
check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
|
||||
},
|
||||
{
|
||||
"flush",
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
w.(http.Flusher).Flush() // also sends a 200
|
||||
w.WriteHeader(201)
|
||||
},
|
||||
check(hasStatus(200), hasFlush(true)),
|
||||
},
|
||||
}
|
||||
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
|
||||
for _, tt := range tests {
|
||||
h := http.HandlerFunc(tt.h)
|
||||
rec := NewRecorder()
|
||||
h.ServeHTTP(rec, r)
|
||||
for _, check := range tt.checks {
|
||||
if err := check(rec); err != nil {
|
||||
t.Errorf("%s: %v", tt.name, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user