From b9ad2787dd6ada95c0712072079146cd2f586d7c Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Fri, 14 Oct 2011 14:16:43 -0700 Subject: [PATCH] http: RoundTrippers shouldn't mutate Request Fixes #2146 R=rsc CC=golang-dev https://golang.org/cl/5284041 --- src/pkg/http/client.go | 13 +++-- src/pkg/http/request.go | 14 +++-- src/pkg/http/transport.go | 99 ++++++++++++++++++---------------- src/pkg/http/transport_test.go | 32 ++++++----- 4 files changed, 90 insertions(+), 68 deletions(-) diff --git a/src/pkg/http/client.go b/src/pkg/http/client.go index 8997a07923..bce9014c4b 100644 --- a/src/pkg/http/client.go +++ b/src/pkg/http/client.go @@ -56,9 +56,10 @@ type RoundTripper interface { // higher-level protocol details such as redirects, // authentication, or cookies. // - // RoundTrip may modify the request. The request Headers field is - // guaranteed to be initialized. - RoundTrip(req *Request) (resp *Response, err os.Error) + // RoundTrip should not modify the request, except for + // consuming the Body. The request's URL and Header fields + // are guaranteed to be initialized. + RoundTrip(*Request) (*Response, os.Error) } // Given a string of the form "host", "host:port", or "[ipv6::address]:port", @@ -96,11 +97,15 @@ func send(req *Request, t RoundTripper) (resp *Response, err os.Error) { if t == nil { t = DefaultTransport if t == nil { - err = os.NewError("no http.Client.Transport or http.DefaultTransport") + err = os.NewError("http: no Client.Transport or DefaultTransport") return } } + if req.URL == nil { + return nil, os.NewError("http: nil Request.URL") + } + // Most the callers of send (Get, Post, et al) don't need // Headers, leaving it uninitialized. We guarantee to the // Transport that this has been initialized, though. diff --git a/src/pkg/http/request.go b/src/pkg/http/request.go index 4f555ff575..02317e0c41 100644 --- a/src/pkg/http/request.go +++ b/src/pkg/http/request.go @@ -275,7 +275,7 @@ const defaultUserAgent = "Go http package" // hasn't been set to "identity", Write adds "Transfer-Encoding: // chunked" to the header. Body is closed after it is sent. func (req *Request) Write(w io.Writer) os.Error { - return req.write(w, false) + return req.write(w, false, nil) } // WriteProxy is like Write but writes the request in the form @@ -285,7 +285,7 @@ func (req *Request) Write(w io.Writer) os.Error { // either case, WriteProxy also writes a Host header, using either // req.Host or req.URL.Host. func (req *Request) WriteProxy(w io.Writer) os.Error { - return req.write(w, true) + return req.write(w, true, nil) } func (req *Request) dumpWrite(w io.Writer) os.Error { @@ -333,7 +333,8 @@ func (req *Request) dumpWrite(w io.Writer) os.Error { return nil } -func (req *Request) write(w io.Writer, usingProxy bool) os.Error { +// extraHeaders may be nil +func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) os.Error { host := req.Host if host == "" { if req.URL == nil { @@ -394,6 +395,13 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error { return err } + if extraHeaders != nil { + err = extraHeaders.Write(bw) + if err != nil { + return err + } + } + io.WriteString(bw, "\r\n") // Write body and trailer diff --git a/src/pkg/http/transport.go b/src/pkg/http/transport.go index d46d565677..b0aea97087 100644 --- a/src/pkg/http/transport.go +++ b/src/pkg/http/transport.go @@ -100,11 +100,28 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, os.Error) { } } +// transportRequest is a wrapper around a *Request that adds +// optional extra headers to write. +type transportRequest struct { + *Request // original request, not to be mutated + extra Header // extra headers to write, or nil +} + +func (tr *transportRequest) extraHeaders() Header { + if tr.extra == nil { + tr.extra = make(Header) + } + return tr.extra +} + // RoundTrip implements the RoundTripper interface. func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { if req.URL == nil { return nil, os.NewError("http: nil Request.URL") } + if req.Header == nil { + return nil, os.NewError("http: nil Request.Header") + } if req.URL.Scheme != "http" && req.URL.Scheme != "https" { t.lk.Lock() var rt RoundTripper @@ -117,8 +134,8 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { } return rt.RoundTrip(req) } - - cm, err := t.connectMethodForRequest(req) + treq := &transportRequest{Request: req} + cm, err := t.connectMethodForRequest(treq) if err != nil { return nil, err } @@ -132,7 +149,7 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) { return nil, err } - return pconn.roundTrip(req) + return pconn.roundTrip(treq) } // RegisterProtocol registers a new protocol with scheme. @@ -185,14 +202,14 @@ func getenvEitherCase(k string) string { return os.Getenv(strings.ToLower(k)) } -func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) { +func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, os.Error) { cm := &connectMethod{ - targetScheme: req.URL.Scheme, - targetAddr: canonicalAddr(req.URL), + targetScheme: treq.URL.Scheme, + targetAddr: canonicalAddr(treq.URL), } if t.Proxy != nil { var err os.Error - cm.proxyURL, err = t.Proxy(req) + cm.proxyURL, err = t.Proxy(treq.Request) if err != nil { return nil, err } @@ -295,19 +312,15 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { conn: conn, reqch: make(chan requestAndChan, 50), } - newClientConnFunc := NewClientConn switch { case cm.proxyURL == nil: // Do nothing. case cm.targetScheme == "http": - newClientConnFunc = NewProxyClientConn + pconn.isProxy = true if pa != "" { - pconn.mutateRequestFunc = func(req *Request) { - if req.Header == nil { - req.Header = make(Header) - } - req.Header.Set("Proxy-Authorization", pa) + pconn.mutateHeaderFunc = func(h Header) { + h.Set("Proxy-Authorization", pa) } } case cm.targetScheme == "https": @@ -351,7 +364,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { } pconn.br = bufio.NewReader(pconn.conn) - pconn.cc = newClientConnFunc(conn, pconn.br) + pconn.cc = NewClientConn(conn, pconn.br) go pconn.readLoop() return pconn, nil } @@ -447,30 +460,21 @@ func (cm *connectMethod) tlsHost() string { return h } -type readResult struct { - res *Response // either res or err will be set - err os.Error -} - -type writeRequest struct { - // Set by client (in pc.roundTrip) - req *Request - resch chan *readResult - - // Set by writeLoop if an error writing headers. - writeErr os.Error -} - // persistConn wraps a connection, usually a persistent one // (but may be used for non-keep-alive requests as well) type persistConn struct { - t *Transport - cacheKey string // its connectMethod.String() - conn net.Conn - cc *ClientConn - br *bufio.Reader - reqch chan requestAndChan // written by roundTrip(); read by readLoop() - mutateRequestFunc func(*Request) // nil or func to modify each outbound request + t *Transport + cacheKey string // its connectMethod.String() + conn net.Conn + cc *ClientConn + br *bufio.Reader + reqch chan requestAndChan // written by roundTrip(); read by readLoop() + isProxy bool + + // mutateHeaderFunc is an optional func to modify extra + // headers on each outbound request before it's written. (the + // original Request given to RoundTrip is not modified) + mutateHeaderFunc func(Header) lk sync.Mutex // guards numExpectedResponses and broken numExpectedResponses int @@ -526,9 +530,6 @@ func (pc *persistConn) readLoop() { if err != nil || resp.ContentLength == 0 { return resp, err } - if rc.addedGzip { - forReq.Header.Del("Accept-Encoding") - } if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" { resp.Header.Del("Content-Encoding") resp.Header.Del("Content-Length") @@ -604,9 +605,9 @@ type requestAndChan struct { addedGzip bool } -func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { - if pc.mutateRequestFunc != nil { - pc.mutateRequestFunc(req) +func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err os.Error) { + if pc.mutateHeaderFunc != nil { + pc.mutateHeaderFunc(req.extraHeaders()) } // Ask for a compressed version if the caller didn't set their @@ -616,24 +617,28 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { requestedGzip := false if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" { // Request gzip only, not deflate. Deflate is ambiguous and - // as universally supported anyway. + // not as universally supported anyway. // See: http://www.gzip.org/zlib/zlib_faq.html#faq38 requestedGzip = true - req.Header.Set("Accept-Encoding", "gzip") + req.extraHeaders().Set("Accept-Encoding", "gzip") } pc.lk.Lock() pc.numExpectedResponses++ pc.lk.Unlock() - err = pc.cc.Write(req) + pc.cc.writeReq = func(r *Request, w io.Writer) os.Error { + return r.write(w, pc.isProxy, req.extra) + } + + err = pc.cc.Write(req.Request) if err != nil { pc.close() return } ch := make(chan responseAndError, 1) - pc.reqch <- requestAndChan{req, ch, requestedGzip} + pc.reqch <- requestAndChan{req.Request, ch, requestedGzip} re := <-ch pc.lk.Lock() pc.numExpectedResponses-- @@ -648,7 +653,7 @@ func (pc *persistConn) close() { pc.broken = true pc.cc.Close() pc.conn.Close() - pc.mutateRequestFunc = nil + pc.mutateHeaderFunc = nil } var portMap = map[string]string{ diff --git a/src/pkg/http/transport_test.go b/src/pkg/http/transport_test.go index a5dfe5ee3c..f3162b9ede 100644 --- a/src/pkg/http/transport_test.go +++ b/src/pkg/http/transport_test.go @@ -372,7 +372,8 @@ var roundTripTests = []struct { // Requests with other accept-encoding should pass through unmodified {"foo", "foo", false}, // Requests with accept-encoding == gzip should be passed through - {"gzip", "gzip", true}} + {"gzip", "gzip", true}, +} // Test that the modification made to the Request by the RoundTripper is cleaned up func TestRoundTripGzip(t *testing.T) { @@ -380,7 +381,8 @@ func TestRoundTripGzip(t *testing.T) { ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { accept := req.Header.Get("Accept-Encoding") if expect := req.FormValue("expect_accept"); accept != expect { - t.Errorf("Accept-Encoding = %q, want %q", accept, expect) + t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q", + req.FormValue("testnum"), accept, expect) } if accept == "gzip" { rw.Header().Set("Content-Encoding", "gzip") @@ -396,8 +398,10 @@ func TestRoundTripGzip(t *testing.T) { for i, test := range roundTripTests { // Test basic request (no accept-encoding) - req, _ := NewRequest("GET", ts.URL+"?expect_accept="+test.expectAccept, nil) - req.Header.Set("Accept-Encoding", test.accept) + req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil) + if test.accept != "" { + req.Header.Set("Accept-Encoding", test.accept) + } res, err := DefaultTransport.RoundTrip(req) var body []byte if test.compressed { @@ -409,16 +413,16 @@ func TestRoundTripGzip(t *testing.T) { } if err != nil { t.Errorf("%d. Error: %q", i, err) - } else { - if g, e := string(body), responseBody; g != e { - t.Errorf("%d. body = %q; want %q", i, g, e) - } - if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { - t.Errorf("%d. Accept-Encoding = %q; want %q", i, g, e) - } - if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { - t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) - } + continue + } + if g, e := string(body), responseBody; g != e { + t.Errorf("%d. body = %q; want %q", i, g, e) + } + if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e { + t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e) + } + if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e { + t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e) } }