diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go index daba3a89b0..d73cbc8550 100644 --- a/src/pkg/http/client.go +++ b/src/pkg/http/client.go @@ -22,6 +22,16 @@ import ( // Client is not yet very configurable. type Client struct { Transport RoundTripper // if nil, DefaultTransport is used + + // If CheckRedirect is not nil, the client calls it before + // following an HTTP redirect. The arguments req and via + // are the upcoming request and the requests made already, + // oldest first. If CheckRedirect returns an error, the client + // returns that error instead of issue the Request req. + // + // If CheckRedirect is nil, the Client uses its default policy, + // which is to stop after 10 consecutive requests. + CheckRedirect func(req *Request, via []*Request) os.Error } // DefaultClient is the default Client and is used by Get, Head, and Post. @@ -109,7 +119,7 @@ func shouldRedirect(statusCode int) bool { } // Get issues a GET to the specified URL. If the response is one of the following -// redirect codes, it follows the redirect, up to a maximum of 10 redirects: +// redirect codes, Get follows the redirect, up to a maximum of 10 redirects: // // 301 (Moved Permanently) // 302 (Found) @@ -126,35 +136,33 @@ func Get(url string) (r *Response, finalURL string, err os.Error) { return DefaultClient.Get(url) } -// Get issues a GET to the specified URL. If the response is one of the following -// redirect codes, it follows the redirect, up to a maximum of 10 redirects: +// Get issues a GET to the specified URL. If the response is one of the +// following redirect codes, Get follows the redirect after calling the +// Client's CheckRedirect function. // // 301 (Moved Permanently) // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) // -// finalURL is the URL from which the response was fetched -- identical to the -// input URL unless redirects were followed. +// finalURL is the URL from which the response was fetched -- identical +// to the input URL unless redirects were followed. // // Caller should close r.Body when done reading from it. func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { // TODO: if/when we add cookie support, the redirected request shouldn't // necessarily supply the same cookies as the original. - // TODO: set referrer header on redirects. var base *URL - // TODO: remove this hard-coded 10 and use the Client's policy - // (ClientConfig) instead. - for redirect := 0; ; redirect++ { - if redirect >= 10 { - err = os.ErrorString("stopped after 10 redirects") - break - } + redirectChecker := c.CheckRedirect + if redirectChecker == nil { + redirectChecker = defaultCheckRedirect + } + var via []*Request + for redirect := 0; ; redirect++ { var req Request req.Method = "GET" - req.ProtoMajor = 1 - req.ProtoMinor = 1 + req.Header = make(Header) if base == nil { req.URL, err = ParseURL(url) } else { @@ -163,6 +171,19 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { if err != nil { break } + if len(via) > 0 { + // Add the Referer header. + lastReq := via[len(via)-1] + if lastReq.URL.Scheme != "https" { + req.Referer = lastReq.URL.String() + } + + err = redirectChecker(&req, via) + if err != nil { + break + } + } + url = req.URL.String() if r, err = send(&req, c.Transport); err != nil { break @@ -174,6 +195,7 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { break } base = req.URL + via = append(via, &req) continue } finalURL = url @@ -184,6 +206,13 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { return } +func defaultCheckRedirect(req *Request, via []*Request) os.Error { + if len(via) >= 10 { + return os.ErrorString("stopped after 10 redirects") + } + return nil +} + // Post issues a POST to the specified URL. // // Caller should close r.Body when done reading from it. diff --git a/src/pkg/http/client_test.go b/src/pkg/http/client_test.go index 3a6f834253..59d62c1c9d 100644 --- a/src/pkg/http/client_test.go +++ b/src/pkg/http/client_test.go @@ -12,6 +12,7 @@ import ( "http/httptest" "io/ioutil" "os" + "strconv" "strings" "testing" ) @@ -75,3 +76,51 @@ func TestGetRequestFormat(t *testing.T) { t.Errorf("expected non-nil request Header") } } + +func TestRedirects(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + n, _ := strconv.Atoi(r.FormValue("n")) + // Test Referer header. (7 is arbitrary position to test at) + if n == 7 { + if g, e := r.Referer, ts.URL+"/?n=6"; e != g { + t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g) + } + } + if n < 15 { + Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound) + return + } + fmt.Fprintf(w, "n=%d", n) + })) + defer ts.Close() + + c := &Client{} + _, _, err := c.Get(ts.URL) + if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client, expected error %q, got %q", e, g) + } + + var checkErr os.Error + var lastVia []*Request + c = &Client{CheckRedirect: func(_ *Request, via []*Request) os.Error { + lastVia = via + return checkErr + }} + _, finalUrl, err := c.Get(ts.URL) + if e, g := "", fmt.Sprintf("%v", err); e != g { + t.Errorf("with custom client, expected error %q, got %q", e, g) + } + if !strings.HasSuffix(finalUrl, "/?n=15") { + t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl) + } + if e, g := 15, len(lastVia); e != g { + t.Errorf("expected lastVia to have contained %d elements; got %d", e, g) + } + + checkErr = os.NewError("no redirects allowed") + _, finalUrl, err = c.Get(ts.URL) + if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g { + t.Errorf("with redirects forbidden, expected error %q, got %q", e, g) + } +}