From 3a41bfac9bfddc3663d81b8296dbe8904baef44e Mon Sep 17 00:00:00 2001 From: Kevin Burke Date: Thu, 7 Dec 2023 13:13:25 -0800 Subject: [PATCH] net/http/httptest: add NewRequestWithContext This matches the net/http API. Updates #59473. Change-Id: I99917cef3ed42a0b4a2b39230b492be00da8bbfd Reviewed-on: https://go-review.googlesource.com/c/go/+/548355 LUCI-TryBot-Result: Go LUCI Reviewed-by: Damien Neil Reviewed-by: Cherry Mui --- api/next/59473.txt | 1 + .../99-minor/net/http/httptest/59473.md | 2 ++ src/net/http/httptest/httptest.go | 11 +++++++-- src/net/http/httptest/httptest_test.go | 24 ++++++++++++++++++- 4 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 api/next/59473.txt create mode 100644 doc/next/6-stdlib/99-minor/net/http/httptest/59473.md diff --git a/api/next/59473.txt b/api/next/59473.txt new file mode 100644 index 0000000000..da6902d424 --- /dev/null +++ b/api/next/59473.txt @@ -0,0 +1 @@ +pkg net/http/httptest, func NewRequestWithContext(context.Context, string, string, io.Reader) *http.Request #59473 diff --git a/doc/next/6-stdlib/99-minor/net/http/httptest/59473.md b/doc/next/6-stdlib/99-minor/net/http/httptest/59473.md new file mode 100644 index 0000000000..65cc6076cf --- /dev/null +++ b/doc/next/6-stdlib/99-minor/net/http/httptest/59473.md @@ -0,0 +1,2 @@ +The new NewRequestWithContext method creates an incoming request with +a Context. diff --git a/src/net/http/httptest/httptest.go b/src/net/http/httptest/httptest.go index f0ca64362d..0c0dbb40e8 100644 --- a/src/net/http/httptest/httptest.go +++ b/src/net/http/httptest/httptest.go @@ -8,13 +8,19 @@ package httptest import ( "bufio" "bytes" + "context" "crypto/tls" "io" "net/http" "strings" ) -// NewRequest returns a new incoming server Request, suitable +// NewRequest wraps NewRequestWithContext using context.Background. +func NewRequest(method, target string, body io.Reader) *http.Request { + return NewRequestWithContext(context.Background(), method, target, body) +} + +// NewRequestWithContext returns a new incoming server Request, suitable // for passing to an [http.Handler] for testing. // // The target is the RFC 7230 "request-target": it may be either a @@ -37,7 +43,7 @@ import ( // // To generate a client HTTP request instead of a server request, see // the NewRequest function in the net/http package. -func NewRequest(method, target string, body io.Reader) *http.Request { +func NewRequestWithContext(ctx context.Context, method, target string, body io.Reader) *http.Request { if method == "" { method = "GET" } @@ -45,6 +51,7 @@ func NewRequest(method, target string, body io.Reader) *http.Request { if err != nil { panic("invalid NewRequest arguments; " + err.Error()) } + req = req.WithContext(ctx) // HTTP/1.0 was used above to avoid needing a Host field. Change it to 1.1 here. req.Proto = "HTTP/1.1" diff --git a/src/net/http/httptest/httptest_test.go b/src/net/http/httptest/httptest_test.go index 071add67ea..d5a4c3dc9d 100644 --- a/src/net/http/httptest/httptest_test.go +++ b/src/net/http/httptest/httptest_test.go @@ -5,6 +5,7 @@ package httptest import ( + "context" "crypto/tls" "io" "net/http" @@ -15,6 +16,26 @@ import ( ) func TestNewRequest(t *testing.T) { + got := NewRequest("GET", "/", nil) + want := &http.Request{ + Method: "GET", + Host: "example.com", + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "/", + } + got.Body = nil // before DeepEqual + want = want.WithContext(context.Background()) + if !reflect.DeepEqual(got, want) { + t.Errorf("Request mismatch:\n got: %#v\nwant: %#v", got, want) + } +} + +func TestNewRequestWithContext(t *testing.T) { for _, tt := range [...]struct { name string @@ -153,7 +174,7 @@ func TestNewRequest(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - got := NewRequest(tt.method, tt.uri, tt.body) + got := NewRequestWithContext(context.Background(), tt.method, tt.uri, tt.body) slurp, err := io.ReadAll(got.Body) if err != nil { t.Errorf("ReadAll: %v", err) @@ -161,6 +182,7 @@ func TestNewRequest(t *testing.T) { if string(slurp) != tt.wantBody { t.Errorf("Body = %q; want %q", slurp, tt.wantBody) } + tt.want = tt.want.WithContext(context.Background()) got.Body = nil // before DeepEqual if !reflect.DeepEqual(got.URL, tt.want.URL) { t.Errorf("Request.URL mismatch:\n got: %#v\nwant: %#v", got.URL, tt.want.URL)