diff --git a/src/pkg/http/proxy_test.go b/src/pkg/http/proxy_test.go index 308bf44b48a..9b320b3aa5b 100644 --- a/src/pkg/http/proxy_test.go +++ b/src/pkg/http/proxy_test.go @@ -40,10 +40,8 @@ func TestUseProxy(t *testing.T) { no_proxy := "foobar.com, .barbaz.net" os.Setenv("NO_PROXY", no_proxy) - tr := &Transport{} - for _, test := range UseProxyTests { - if tr.useProxy(test.host+":80") != test.match { + if useProxy(test.host+":80") != test.match { t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) } } diff --git a/src/pkg/http/transport.go b/src/pkg/http/transport.go index fa912b1e18d..34bfbdd34a4 100644 --- a/src/pkg/http/transport.go +++ b/src/pkg/http/transport.go @@ -24,7 +24,7 @@ import ( // each call to Do and uses HTTP proxies as directed by the // $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy) // environment variables. -var DefaultTransport RoundTripper = &Transport{} +var DefaultTransport RoundTripper = &Transport{Proxy: ProxyFromEnvironment} // DefaultMaxIdleConnsPerHost is the default value of Transport's // MaxIdleConnsPerHost. @@ -41,7 +41,12 @@ type Transport struct { // TODO: tunable on timeout on cached connections // TODO: optional pipelining - IgnoreEnvironment bool // don't look at environment variables for proxy configuration + // Proxy optionally specifies a function to return a proxy for + // a given Request. If the function returns a non-nil error, + // the request is aborted with the provided error. If Proxy is + // nil or returns a nil *URL, no proxy is used. + Proxy func(*Request) (*URL, os.Error) + DisableKeepAlives bool DisableCompression bool @@ -51,6 +56,39 @@ type Transport struct { MaxIdleConnsPerHost int } +// ProxyFromEnvironment returns the URL of the proxy to use for a +// given request, as indicated by the environment variables +// $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy). +// Either URL or an error is returned. +func ProxyFromEnvironment(req *Request) (*URL, os.Error) { + proxy := getenvEitherCase("HTTP_PROXY") + if proxy == "" { + return nil, nil + } + if !useProxy(canonicalAddr(req.URL)) { + return nil, nil + } + proxyURL, err := ParseRequestURL(proxy) + if err != nil { + return nil, os.ErrorString("invalid proxy address") + } + if proxyURL.Host == "" { + proxyURL, err = ParseRequestURL("http://" + proxy) + if err != nil { + return nil, os.ErrorString("invalid proxy address") + } + } + return proxyURL, nil +} + +// ProxyURL returns a proxy function (for use in a Transport) +// that always returns the same URL. +func ProxyURL(url *URL) func(*Request) (*URL, os.Error) { + return func(*Request) (*URL, os.Error) { + return url, nil + } +} + // RoundTrip implements the RoundTripper interface. func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { if req.URL == nil { @@ -101,21 +139,11 @@ func (t *Transport) CloseIdleConnections() { // Private implementation past this point. // -func (t *Transport) getenvEitherCase(k string) string { - if t.IgnoreEnvironment { - return "" - } - if v := t.getenv(strings.ToUpper(k)); v != "" { +func getenvEitherCase(k string) string { + if v := os.Getenv(strings.ToUpper(k)); v != "" { return v } - return t.getenv(strings.ToLower(k)) -} - -func (t *Transport) getenv(k string) string { - if t.IgnoreEnvironment { - return "" - } - return os.Getenv(k) + return os.Getenv(strings.ToLower(k)) } func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) { @@ -123,20 +151,12 @@ func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Er targetScheme: req.URL.Scheme, targetAddr: canonicalAddr(req.URL), } - - proxy := t.getenvEitherCase("HTTP_PROXY") - if proxy != "" && t.useProxy(cm.targetAddr) { - proxyURL, err := ParseRequestURL(proxy) + if t.Proxy != nil { + var err os.Error + cm.proxyURL, err = t.Proxy(req) if err != nil { - return nil, os.ErrorString("invalid proxy address") + return nil, err } - if proxyURL.Host == "" { - proxyURL, err = ParseRequestURL("http://" + proxy) - if err != nil { - return nil, os.ErrorString("invalid proxy address") - } - } - cm.proxyURL = proxyURL } return cm, nil } @@ -296,7 +316,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { // useProxy returns true if requests to addr should use a proxy, // according to the NO_PROXY or no_proxy environment variable. // addr is always a canonicalAddr with a host and port. -func (t *Transport) useProxy(addr string) bool { +func useProxy(addr string) bool { if len(addr) == 0 { return true } @@ -313,7 +333,7 @@ func (t *Transport) useProxy(addr string) bool { } } - no_proxy := t.getenvEitherCase("NO_PROXY") + no_proxy := getenvEitherCase("NO_PROXY") if no_proxy == "*" { return false } diff --git a/src/pkg/http/transport_test.go b/src/pkg/http/transport_test.go index 13865505efc..9cd18ffecf4 100644 --- a/src/pkg/http/transport_test.go +++ b/src/pkg/http/transport_test.go @@ -478,6 +478,30 @@ func TestTransportGzip(t *testing.T) { } } +func TestTransportProxy(t *testing.T) { + ch := make(chan string, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ch <- "real server" + })) + defer ts.Close() + proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ch <- "proxy for " + r.URL.String() + })) + defer proxy.Close() + + pu, err := ParseURL(proxy.URL) + if err != nil { + t.Fatal(err) + } + c := &Client{Transport: &Transport{Proxy: ProxyURL(pu)}} + c.Head(ts.URL) + got := <-ch + want := "proxy for " + ts.URL + "/" + if got != want { + t.Errorf("want %q, got %q", want, got) + } +} + // TestTransportGzipRecursive sends a gzip quine and checks that the // client gets the same value back. This is more cute than anything, // but checks that we don't recurse forever, and checks that