1
0
mirror of https://github.com/golang/go synced 2024-11-26 22:11:25 -07:00

net/http: follow certain redirects after POST requests

Fixes #4145

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/6923055
This commit is contained in:
Brad Fitzpatrick 2012-12-12 11:09:55 -08:00
parent 3906706297
commit 08ce7f1d5c
2 changed files with 69 additions and 6 deletions

View File

@ -120,7 +120,10 @@ func (c *Client) send(req *Request) (*Response, error) {
// Generally Get, Post, or PostForm will be used instead of Do.
func (c *Client) Do(req *Request) (resp *Response, err error) {
if req.Method == "GET" || req.Method == "HEAD" {
return c.doFollowingRedirects(req)
return c.doFollowingRedirects(req, shouldRedirectGet)
}
if req.Method == "POST" || req.Method == "PUT" {
return c.doFollowingRedirects(req, shouldRedirectPost)
}
return c.send(req)
}
@ -166,7 +169,7 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) {
// True if the specified HTTP status code is one for which the Get utility should
// automatically redirect.
func shouldRedirect(statusCode int) bool {
func shouldRedirectGet(statusCode int) bool {
switch statusCode {
case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect:
return true
@ -174,6 +177,16 @@ func shouldRedirect(statusCode int) bool {
return false
}
// True if the specified HTTP status code is one for which the Post utility should
// automatically redirect.
func shouldRedirectPost(statusCode int) bool {
switch statusCode {
case StatusFound, StatusSeeOther:
return true
}
return false
}
// Get issues a GET to the specified URL. If the response is one of the following
// redirect codes, Get follows the redirect, up to a maximum of 10 redirects:
//
@ -214,10 +227,10 @@ func (c *Client) Get(url string) (resp *Response, err error) {
if err != nil {
return nil, err
}
return c.doFollowingRedirects(req)
return c.doFollowingRedirects(req, shouldRedirectGet)
}
func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error) {
func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) {
// TODO: if/when we add cookie support, the redirected request shouldn't
// necessarily supply the same cookies as the original.
var base *url.URL
@ -238,6 +251,9 @@ func (c *Client) doFollowingRedirects(ireq *Request) (resp *Response, err error)
if redirect != 0 {
req = new(Request)
req.Method = ireq.Method
if ireq.Method == "POST" || ireq.Method == "PUT" {
req.Method = "GET"
}
req.Header = make(Header)
req.URL, err = base.Parse(urlStr)
if err != nil {
@ -321,7 +337,7 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Respon
return nil, err
}
req.Header.Set("Content-Type", bodyType)
return c.send(req)
return c.doFollowingRedirects(req, shouldRedirectPost)
}
// PostForm issues a POST to the specified URL, with data's keys and
@ -371,5 +387,5 @@ func (c *Client) Head(url string) (resp *Response, err error) {
if err != nil {
return nil, err
}
return c.doFollowingRedirects(req)
return c.doFollowingRedirects(req, shouldRedirectGet)
}

View File

@ -7,6 +7,7 @@
package http_test
import (
"bytes"
"crypto/tls"
"crypto/x509"
"errors"
@ -246,6 +247,52 @@ func TestRedirects(t *testing.T) {
}
}
func TestPostRedirects(t *testing.T) {
var log struct {
sync.Mutex
bytes.Buffer
}
var ts *httptest.Server
ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
log.Lock()
fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI)
log.Unlock()
if v := r.URL.Query().Get("code"); v != "" {
code, _ := strconv.Atoi(v)
if code/100 == 3 {
w.Header().Set("Location", ts.URL)
}
w.WriteHeader(code)
}
}))
tests := []struct {
suffix string
want int // response code
}{
{"/", 200},
{"/?code=301", 301},
{"/?code=302", 200},
{"/?code=303", 200},
{"/?code=404", 404},
}
for _, tt := range tests {
res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content"))
if err != nil {
t.Fatal(err)
}
if res.StatusCode != tt.want {
t.Errorf("POST %s: status code = %d; want %d", tt.suffix, res.StatusCode, tt.want)
}
}
log.Lock()
got := log.String()
log.Unlock()
want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 "
if got != want {
t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want)
}
}
var expectedCookies = []*Cookie{
{Name: "ChocolateChip", Value: "tasty"},
{Name: "First", Value: "Hit"},