From aa1e063efd7376e268ee592ebe078c6d05b0bdf8 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 21 Oct 2016 12:03:41 +0100 Subject: [PATCH] net/http: add Request.GetBody func for 307/308 redirects Updates #10767 Change-Id: I197535f71bc2dc45e783f38d8031aa717d50fd80 Reviewed-on: https://go-review.googlesource.com/31733 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Emmanuel Odeke Reviewed-by: Brad Fitzpatrick --- src/net/http/client.go | 7 ++++++ src/net/http/request.go | 24 +++++++++++++++++++++ src/net/http/request_test.go | 41 ++++++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+) diff --git a/src/net/http/client.go b/src/net/http/client.go index 39c38bd8dd..9b60f35708 100644 --- a/src/net/http/client.go +++ b/src/net/http/client.go @@ -485,8 +485,15 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo Cancel: ireq.Cancel, ctx: ireq.ctx, } + if ireq.GetBody != nil { + req.Body, err = ireq.GetBody() + if err != nil { + return nil, uerr(err) + } + } if ireq.Method == "POST" || ireq.Method == "PUT" { req.Method = "GET" + req.Body = nil // TODO: fix this when 307/308 support happens } // Copy the initial request's Header values // (at least the safe ones). Do this before diff --git a/src/net/http/request.go b/src/net/http/request.go index 83d6c81de9..551310cab0 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -151,6 +151,14 @@ type Request struct { // Handler does not need to. Body io.ReadCloser + // GetBody defines an optional func to return a new copy of + // Body. It used for client requests when a redirect requires + // reading the body more than once. Use of GetBody still + // requires setting Body. + // + // For server requests it is unused. + GetBody func() (io.ReadCloser, error) + // ContentLength records the length of the associated content. // The value -1 indicates that the length is unknown. // Values >= 0 indicate that the given number of bytes may @@ -738,10 +746,25 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { switch v := body.(type) { case *bytes.Buffer: req.ContentLength = int64(v.Len()) + buf := v.Bytes() + req.GetBody = func() (io.ReadCloser, error) { + r := bytes.NewReader(buf) + return ioutil.NopCloser(r), nil + } case *bytes.Reader: req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return ioutil.NopCloser(&r), nil + } case *strings.Reader: req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return ioutil.NopCloser(&r), nil + } default: req.ContentLength = -1 // unknown } @@ -751,6 +774,7 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { // to set the Body to nil. if req.ContentLength == 0 { req.Body = nil + req.GetBody = nil } } diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go index f12b41cf1b..e463d79492 100644 --- a/src/net/http/request_test.go +++ b/src/net/http/request_test.go @@ -784,6 +784,47 @@ func TestMaxBytesReaderStickyError(t *testing.T) { } } +// verify that NewRequest sets Request.GetBody and that it works +func TestNewRequestGetBody(t *testing.T) { + tests := []struct { + r io.Reader + }{ + {r: strings.NewReader("hello")}, + {r: bytes.NewReader([]byte("hello"))}, + {r: bytes.NewBuffer([]byte("hello"))}, + } + for i, tt := range tests { + req, err := NewRequest("POST", "http://foo.tld/", tt.r) + if err != nil { + t.Errorf("test[%d]: %v", i, err) + continue + } + if req.Body == nil { + t.Errorf("test[%d]: Body = nil", i) + continue + } + if req.GetBody == nil { + t.Errorf("test[%d]: GetBody = nil", i) + continue + } + slurp1, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("test[%d]: ReadAll(Body) = %v", i, err) + } + newBody, err := req.GetBody() + if err != nil { + t.Errorf("test[%d]: GetBody = %v", i, err) + } + slurp2, err := ioutil.ReadAll(newBody) + if err != nil { + t.Errorf("test[%d]: ReadAll(GetBody()) = %v", i, err) + } + if string(slurp1) != string(slurp2) { + t.Errorf("test[%d]: Body %q != GetBody %q", i, slurp1, slurp2) + } + } +} + func testMissingFile(t *testing.T, req *Request) { f, fh, err := req.FormFile("missing") if f != nil {