1
0
mirror of https://github.com/golang/go synced 2024-11-21 22:24:40 -07:00

http: RoundTrippers shouldn't mutate Request

Fixes #2146

R=rsc
CC=golang-dev
https://golang.org/cl/5284041
This commit is contained in:
Brad Fitzpatrick 2011-10-14 14:16:43 -07:00
parent 236aff31c5
commit b9ad2787dd
4 changed files with 90 additions and 68 deletions

View File

@ -56,9 +56,10 @@ type RoundTripper interface {
// higher-level protocol details such as redirects,
// authentication, or cookies.
//
// RoundTrip may modify the request. The request Headers field is
// guaranteed to be initialized.
RoundTrip(req *Request) (resp *Response, err os.Error)
// RoundTrip should not modify the request, except for
// consuming the Body. The request's URL and Header fields
// are guaranteed to be initialized.
RoundTrip(*Request) (*Response, os.Error)
}
// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
@ -96,11 +97,15 @@ func send(req *Request, t RoundTripper) (resp *Response, err os.Error) {
if t == nil {
t = DefaultTransport
if t == nil {
err = os.NewError("no http.Client.Transport or http.DefaultTransport")
err = os.NewError("http: no Client.Transport or DefaultTransport")
return
}
}
if req.URL == nil {
return nil, os.NewError("http: nil Request.URL")
}
// Most the callers of send (Get, Post, et al) don't need
// Headers, leaving it uninitialized. We guarantee to the
// Transport that this has been initialized, though.

View File

@ -275,7 +275,7 @@ const defaultUserAgent = "Go http package"
// hasn't been set to "identity", Write adds "Transfer-Encoding:
// chunked" to the header. Body is closed after it is sent.
func (req *Request) Write(w io.Writer) os.Error {
return req.write(w, false)
return req.write(w, false, nil)
}
// WriteProxy is like Write but writes the request in the form
@ -285,7 +285,7 @@ func (req *Request) Write(w io.Writer) os.Error {
// either case, WriteProxy also writes a Host header, using either
// req.Host or req.URL.Host.
func (req *Request) WriteProxy(w io.Writer) os.Error {
return req.write(w, true)
return req.write(w, true, nil)
}
func (req *Request) dumpWrite(w io.Writer) os.Error {
@ -333,7 +333,8 @@ func (req *Request) dumpWrite(w io.Writer) os.Error {
return nil
}
func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
// extraHeaders may be nil
func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) os.Error {
host := req.Host
if host == "" {
if req.URL == nil {
@ -394,6 +395,13 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
return err
}
if extraHeaders != nil {
err = extraHeaders.Write(bw)
if err != nil {
return err
}
}
io.WriteString(bw, "\r\n")
// Write body and trailer

View File

@ -100,11 +100,28 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, os.Error) {
}
}
// transportRequest is a wrapper around a *Request that adds
// optional extra headers to write.
type transportRequest struct {
*Request // original request, not to be mutated
extra Header // extra headers to write, or nil
}
func (tr *transportRequest) extraHeaders() Header {
if tr.extra == nil {
tr.extra = make(Header)
}
return tr.extra
}
// RoundTrip implements the RoundTripper interface.
func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
if req.URL == nil {
return nil, os.NewError("http: nil Request.URL")
}
if req.Header == nil {
return nil, os.NewError("http: nil Request.Header")
}
if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
t.lk.Lock()
var rt RoundTripper
@ -117,8 +134,8 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
}
return rt.RoundTrip(req)
}
cm, err := t.connectMethodForRequest(req)
treq := &transportRequest{Request: req}
cm, err := t.connectMethodForRequest(treq)
if err != nil {
return nil, err
}
@ -132,7 +149,7 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
return nil, err
}
return pconn.roundTrip(req)
return pconn.roundTrip(treq)
}
// RegisterProtocol registers a new protocol with scheme.
@ -185,14 +202,14 @@ func getenvEitherCase(k string) string {
return os.Getenv(strings.ToLower(k))
}
func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) {
func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, os.Error) {
cm := &connectMethod{
targetScheme: req.URL.Scheme,
targetAddr: canonicalAddr(req.URL),
targetScheme: treq.URL.Scheme,
targetAddr: canonicalAddr(treq.URL),
}
if t.Proxy != nil {
var err os.Error
cm.proxyURL, err = t.Proxy(req)
cm.proxyURL, err = t.Proxy(treq.Request)
if err != nil {
return nil, err
}
@ -295,19 +312,15 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
conn: conn,
reqch: make(chan requestAndChan, 50),
}
newClientConnFunc := NewClientConn
switch {
case cm.proxyURL == nil:
// Do nothing.
case cm.targetScheme == "http":
newClientConnFunc = NewProxyClientConn
pconn.isProxy = true
if pa != "" {
pconn.mutateRequestFunc = func(req *Request) {
if req.Header == nil {
req.Header = make(Header)
}
req.Header.Set("Proxy-Authorization", pa)
pconn.mutateHeaderFunc = func(h Header) {
h.Set("Proxy-Authorization", pa)
}
}
case cm.targetScheme == "https":
@ -351,7 +364,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
}
pconn.br = bufio.NewReader(pconn.conn)
pconn.cc = newClientConnFunc(conn, pconn.br)
pconn.cc = NewClientConn(conn, pconn.br)
go pconn.readLoop()
return pconn, nil
}
@ -447,30 +460,21 @@ func (cm *connectMethod) tlsHost() string {
return h
}
type readResult struct {
res *Response // either res or err will be set
err os.Error
}
type writeRequest struct {
// Set by client (in pc.roundTrip)
req *Request
resch chan *readResult
// Set by writeLoop if an error writing headers.
writeErr os.Error
}
// persistConn wraps a connection, usually a persistent one
// (but may be used for non-keep-alive requests as well)
type persistConn struct {
t *Transport
cacheKey string // its connectMethod.String()
conn net.Conn
cc *ClientConn
br *bufio.Reader
reqch chan requestAndChan // written by roundTrip(); read by readLoop()
mutateRequestFunc func(*Request) // nil or func to modify each outbound request
t *Transport
cacheKey string // its connectMethod.String()
conn net.Conn
cc *ClientConn
br *bufio.Reader
reqch chan requestAndChan // written by roundTrip(); read by readLoop()
isProxy bool
// mutateHeaderFunc is an optional func to modify extra
// headers on each outbound request before it's written. (the
// original Request given to RoundTrip is not modified)
mutateHeaderFunc func(Header)
lk sync.Mutex // guards numExpectedResponses and broken
numExpectedResponses int
@ -526,9 +530,6 @@ func (pc *persistConn) readLoop() {
if err != nil || resp.ContentLength == 0 {
return resp, err
}
if rc.addedGzip {
forReq.Header.Del("Accept-Encoding")
}
if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" {
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
@ -604,9 +605,9 @@ type requestAndChan struct {
addedGzip bool
}
func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
if pc.mutateRequestFunc != nil {
pc.mutateRequestFunc(req)
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err os.Error) {
if pc.mutateHeaderFunc != nil {
pc.mutateHeaderFunc(req.extraHeaders())
}
// Ask for a compressed version if the caller didn't set their
@ -616,24 +617,28 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
requestedGzip := false
if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" {
// Request gzip only, not deflate. Deflate is ambiguous and
// as universally supported anyway.
// not as universally supported anyway.
// See: http://www.gzip.org/zlib/zlib_faq.html#faq38
requestedGzip = true
req.Header.Set("Accept-Encoding", "gzip")
req.extraHeaders().Set("Accept-Encoding", "gzip")
}
pc.lk.Lock()
pc.numExpectedResponses++
pc.lk.Unlock()
err = pc.cc.Write(req)
pc.cc.writeReq = func(r *Request, w io.Writer) os.Error {
return r.write(w, pc.isProxy, req.extra)
}
err = pc.cc.Write(req.Request)
if err != nil {
pc.close()
return
}
ch := make(chan responseAndError, 1)
pc.reqch <- requestAndChan{req, ch, requestedGzip}
pc.reqch <- requestAndChan{req.Request, ch, requestedGzip}
re := <-ch
pc.lk.Lock()
pc.numExpectedResponses--
@ -648,7 +653,7 @@ func (pc *persistConn) close() {
pc.broken = true
pc.cc.Close()
pc.conn.Close()
pc.mutateRequestFunc = nil
pc.mutateHeaderFunc = nil
}
var portMap = map[string]string{

View File

@ -372,7 +372,8 @@ var roundTripTests = []struct {
// Requests with other accept-encoding should pass through unmodified
{"foo", "foo", false},
// Requests with accept-encoding == gzip should be passed through
{"gzip", "gzip", true}}
{"gzip", "gzip", true},
}
// Test that the modification made to the Request by the RoundTripper is cleaned up
func TestRoundTripGzip(t *testing.T) {
@ -380,7 +381,8 @@ func TestRoundTripGzip(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
accept := req.Header.Get("Accept-Encoding")
if expect := req.FormValue("expect_accept"); accept != expect {
t.Errorf("Accept-Encoding = %q, want %q", accept, expect)
t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
req.FormValue("testnum"), accept, expect)
}
if accept == "gzip" {
rw.Header().Set("Content-Encoding", "gzip")
@ -396,8 +398,10 @@ func TestRoundTripGzip(t *testing.T) {
for i, test := range roundTripTests {
// Test basic request (no accept-encoding)
req, _ := NewRequest("GET", ts.URL+"?expect_accept="+test.expectAccept, nil)
req.Header.Set("Accept-Encoding", test.accept)
req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
if test.accept != "" {
req.Header.Set("Accept-Encoding", test.accept)
}
res, err := DefaultTransport.RoundTrip(req)
var body []byte
if test.compressed {
@ -409,16 +413,16 @@ func TestRoundTripGzip(t *testing.T) {
}
if err != nil {
t.Errorf("%d. Error: %q", i, err)
} else {
if g, e := string(body), responseBody; g != e {
t.Errorf("%d. body = %q; want %q", i, g, e)
}
if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
t.Errorf("%d. Accept-Encoding = %q; want %q", i, g, e)
}
if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
}
continue
}
if g, e := string(body), responseBody; g != e {
t.Errorf("%d. body = %q; want %q", i, g, e)
}
if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
}
if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
}
}