1
0
mirror of https://github.com/golang/go synced 2024-11-18 09:44:50 -07:00

net/http: add Request.GetBody func for 307/308 redirects

Updates #10767

Change-Id: I197535f71bc2dc45e783f38d8031aa717d50fd80
Reviewed-on: https://go-review.googlesource.com/31733
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Brad Fitzpatrick 2016-10-21 12:03:41 +01:00
parent ca4431a384
commit aa1e063efd
3 changed files with 72 additions and 0 deletions

View File

@ -485,8 +485,15 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo
Cancel: ireq.Cancel, Cancel: ireq.Cancel,
ctx: ireq.ctx, ctx: ireq.ctx,
} }
if ireq.GetBody != nil {
req.Body, err = ireq.GetBody()
if err != nil {
return nil, uerr(err)
}
}
if ireq.Method == "POST" || ireq.Method == "PUT" { if ireq.Method == "POST" || ireq.Method == "PUT" {
req.Method = "GET" req.Method = "GET"
req.Body = nil // TODO: fix this when 307/308 support happens
} }
// Copy the initial request's Header values // Copy the initial request's Header values
// (at least the safe ones). Do this before // (at least the safe ones). Do this before

View File

@ -151,6 +151,14 @@ type Request struct {
// Handler does not need to. // Handler does not need to.
Body io.ReadCloser Body io.ReadCloser
// GetBody defines an optional func to return a new copy of
// Body. It used for client requests when a redirect requires
// reading the body more than once. Use of GetBody still
// requires setting Body.
//
// For server requests it is unused.
GetBody func() (io.ReadCloser, error)
// ContentLength records the length of the associated content. // ContentLength records the length of the associated content.
// The value -1 indicates that the length is unknown. // The value -1 indicates that the length is unknown.
// Values >= 0 indicate that the given number of bytes may // Values >= 0 indicate that the given number of bytes may
@ -738,10 +746,25 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
switch v := body.(type) { switch v := body.(type) {
case *bytes.Buffer: case *bytes.Buffer:
req.ContentLength = int64(v.Len()) req.ContentLength = int64(v.Len())
buf := v.Bytes()
req.GetBody = func() (io.ReadCloser, error) {
r := bytes.NewReader(buf)
return ioutil.NopCloser(r), nil
}
case *bytes.Reader: case *bytes.Reader:
req.ContentLength = int64(v.Len()) req.ContentLength = int64(v.Len())
snapshot := *v
req.GetBody = func() (io.ReadCloser, error) {
r := snapshot
return ioutil.NopCloser(&r), nil
}
case *strings.Reader: case *strings.Reader:
req.ContentLength = int64(v.Len()) req.ContentLength = int64(v.Len())
snapshot := *v
req.GetBody = func() (io.ReadCloser, error) {
r := snapshot
return ioutil.NopCloser(&r), nil
}
default: default:
req.ContentLength = -1 // unknown req.ContentLength = -1 // unknown
} }
@ -751,6 +774,7 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
// to set the Body to nil. // to set the Body to nil.
if req.ContentLength == 0 { if req.ContentLength == 0 {
req.Body = nil req.Body = nil
req.GetBody = nil
} }
} }

View File

@ -784,6 +784,47 @@ func TestMaxBytesReaderStickyError(t *testing.T) {
} }
} }
// verify that NewRequest sets Request.GetBody and that it works
func TestNewRequestGetBody(t *testing.T) {
tests := []struct {
r io.Reader
}{
{r: strings.NewReader("hello")},
{r: bytes.NewReader([]byte("hello"))},
{r: bytes.NewBuffer([]byte("hello"))},
}
for i, tt := range tests {
req, err := NewRequest("POST", "http://foo.tld/", tt.r)
if err != nil {
t.Errorf("test[%d]: %v", i, err)
continue
}
if req.Body == nil {
t.Errorf("test[%d]: Body = nil", i)
continue
}
if req.GetBody == nil {
t.Errorf("test[%d]: GetBody = nil", i)
continue
}
slurp1, err := ioutil.ReadAll(req.Body)
if err != nil {
t.Errorf("test[%d]: ReadAll(Body) = %v", i, err)
}
newBody, err := req.GetBody()
if err != nil {
t.Errorf("test[%d]: GetBody = %v", i, err)
}
slurp2, err := ioutil.ReadAll(newBody)
if err != nil {
t.Errorf("test[%d]: ReadAll(GetBody()) = %v", i, err)
}
if string(slurp1) != string(slurp2) {
t.Errorf("test[%d]: Body %q != GetBody %q", i, slurp1, slurp2)
}
}
}
func testMissingFile(t *testing.T, req *Request) { func testMissingFile(t *testing.T, req *Request) {
f, fh, err := req.FormFile("missing") f, fh, err := req.FormFile("missing")
if f != nil { if f != nil {