diff --git a/src/pkg/http/response.go b/src/pkg/http/response.go index b01a303a12..56c65b53c7 100644 --- a/src/pkg/http/response.go +++ b/src/pkg/http/response.go @@ -13,6 +13,7 @@ import ( "os" "strconv" "strings" + "url" ) var respExcludeHeader = map[string]bool{ @@ -77,6 +78,23 @@ func (r *Response) Cookies() []*Cookie { return readSetCookies(r.Header) } +var ErrNoLocation = os.NewError("http: no Location header in response") + +// Location returns the URL of the response's "Location" header, +// if present. Relative redirects are resolved relative to +// the Response's Request. ErrNoLocation is returned if no +// Location header is present. +func (r *Response) Location() (*url.URL, os.Error) { + lv := r.Header.Get("Location") + if lv == "" { + return nil, ErrNoLocation + } + if r.Request != nil && r.Request.URL != nil { + return r.Request.URL.Parse(lv) + } + return url.Parse(lv) +} + // ReadResponse reads and returns an HTTP response from r. The // req parameter specifies the Request that corresponds to // this Response. Clients must call resp.Body.Close when finished diff --git a/src/pkg/http/response_test.go b/src/pkg/http/response_test.go index 1d4a234235..86494bf4ae 100644 --- a/src/pkg/http/response_test.go +++ b/src/pkg/http/response_test.go @@ -15,6 +15,7 @@ import ( "io/ioutil" "reflect" "testing" + "url" ) type respTest struct { @@ -395,3 +396,52 @@ func diff(t *testing.T, prefix string, have, want interface{}) { } } } + +type responseLocationTest struct { + location string // Response's Location header or "" + requrl string // Response.Request.URL or "" + want string + wantErr os.Error +} + +var responseLocationTests = []responseLocationTest{ + {"/foo", "http://bar.com/baz", "http://bar.com/foo", nil}, + {"http://foo.com/", "http://bar.com/baz", "http://foo.com/", nil}, + {"", "http://bar.com/baz", "", ErrNoLocation}, +} + +func TestLocationResponse(t *testing.T) { + for i, tt := range responseLocationTests { + res := new(Response) + res.Header = make(Header) + res.Header.Set("Location", tt.location) + if tt.requrl != "" { + res.Request = &Request{} + var err os.Error + res.Request.URL, err = url.Parse(tt.requrl) + if err != nil { + t.Fatalf("bad test URL %q: %v", tt.requrl, err) + } + } + + got, err := res.Location() + if tt.wantErr != nil { + if err == nil { + t.Errorf("%d. err=nil; want %q", i, tt.wantErr) + continue + } + if g, e := err.String(), tt.wantErr.String(); g != e { + t.Errorf("%d. err=%q; want %q", i, g, e) + continue + } + continue + } + if err != nil { + t.Errorf("%d. err=%q", i, err) + continue + } + if g, e := got.String(), tt.want; g != e { + t.Errorf("%d. Location=%q; want %q", i, g, e) + } + } +}