mirror of
https://github.com/golang/go
synced 2024-11-23 16:40:03 -07:00
net/http: add Request.GetBody func for 307/308 redirects
Updates #10767 Change-Id: I197535f71bc2dc45e783f38d8031aa717d50fd80 Reviewed-on: https://go-review.googlesource.com/31733 Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com> Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
ca4431a384
commit
aa1e063efd
@ -485,8 +485,15 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo
|
||||
Cancel: ireq.Cancel,
|
||||
ctx: ireq.ctx,
|
||||
}
|
||||
if ireq.GetBody != nil {
|
||||
req.Body, err = ireq.GetBody()
|
||||
if err != nil {
|
||||
return nil, uerr(err)
|
||||
}
|
||||
}
|
||||
if ireq.Method == "POST" || ireq.Method == "PUT" {
|
||||
req.Method = "GET"
|
||||
req.Body = nil // TODO: fix this when 307/308 support happens
|
||||
}
|
||||
// Copy the initial request's Header values
|
||||
// (at least the safe ones). Do this before
|
||||
|
@ -151,6 +151,14 @@ type Request struct {
|
||||
// Handler does not need to.
|
||||
Body io.ReadCloser
|
||||
|
||||
// GetBody defines an optional func to return a new copy of
|
||||
// Body. It used for client requests when a redirect requires
|
||||
// reading the body more than once. Use of GetBody still
|
||||
// requires setting Body.
|
||||
//
|
||||
// For server requests it is unused.
|
||||
GetBody func() (io.ReadCloser, error)
|
||||
|
||||
// ContentLength records the length of the associated content.
|
||||
// The value -1 indicates that the length is unknown.
|
||||
// Values >= 0 indicate that the given number of bytes may
|
||||
@ -738,10 +746,25 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
|
||||
switch v := body.(type) {
|
||||
case *bytes.Buffer:
|
||||
req.ContentLength = int64(v.Len())
|
||||
buf := v.Bytes()
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
r := bytes.NewReader(buf)
|
||||
return ioutil.NopCloser(r), nil
|
||||
}
|
||||
case *bytes.Reader:
|
||||
req.ContentLength = int64(v.Len())
|
||||
snapshot := *v
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
r := snapshot
|
||||
return ioutil.NopCloser(&r), nil
|
||||
}
|
||||
case *strings.Reader:
|
||||
req.ContentLength = int64(v.Len())
|
||||
snapshot := *v
|
||||
req.GetBody = func() (io.ReadCloser, error) {
|
||||
r := snapshot
|
||||
return ioutil.NopCloser(&r), nil
|
||||
}
|
||||
default:
|
||||
req.ContentLength = -1 // unknown
|
||||
}
|
||||
@ -751,6 +774,7 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
|
||||
// to set the Body to nil.
|
||||
if req.ContentLength == 0 {
|
||||
req.Body = nil
|
||||
req.GetBody = nil
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -784,6 +784,47 @@ func TestMaxBytesReaderStickyError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// verify that NewRequest sets Request.GetBody and that it works
|
||||
func TestNewRequestGetBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
r io.Reader
|
||||
}{
|
||||
{r: strings.NewReader("hello")},
|
||||
{r: bytes.NewReader([]byte("hello"))},
|
||||
{r: bytes.NewBuffer([]byte("hello"))},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
req, err := NewRequest("POST", "http://foo.tld/", tt.r)
|
||||
if err != nil {
|
||||
t.Errorf("test[%d]: %v", i, err)
|
||||
continue
|
||||
}
|
||||
if req.Body == nil {
|
||||
t.Errorf("test[%d]: Body = nil", i)
|
||||
continue
|
||||
}
|
||||
if req.GetBody == nil {
|
||||
t.Errorf("test[%d]: GetBody = nil", i)
|
||||
continue
|
||||
}
|
||||
slurp1, err := ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
t.Errorf("test[%d]: ReadAll(Body) = %v", i, err)
|
||||
}
|
||||
newBody, err := req.GetBody()
|
||||
if err != nil {
|
||||
t.Errorf("test[%d]: GetBody = %v", i, err)
|
||||
}
|
||||
slurp2, err := ioutil.ReadAll(newBody)
|
||||
if err != nil {
|
||||
t.Errorf("test[%d]: ReadAll(GetBody()) = %v", i, err)
|
||||
}
|
||||
if string(slurp1) != string(slurp2) {
|
||||
t.Errorf("test[%d]: Body %q != GetBody %q", i, slurp1, slurp2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testMissingFile(t *testing.T, req *Request) {
|
||||
f, fh, err := req.FormFile("missing")
|
||||
if f != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user