mirror of
https://github.com/golang/go
synced 2024-11-21 22:24:40 -07:00
http: RoundTrippers shouldn't mutate Request
Fixes #2146 R=rsc CC=golang-dev https://golang.org/cl/5284041
This commit is contained in:
parent
236aff31c5
commit
b9ad2787dd
@ -56,9 +56,10 @@ type RoundTripper interface {
|
||||
// higher-level protocol details such as redirects,
|
||||
// authentication, or cookies.
|
||||
//
|
||||
// RoundTrip may modify the request. The request Headers field is
|
||||
// guaranteed to be initialized.
|
||||
RoundTrip(req *Request) (resp *Response, err os.Error)
|
||||
// RoundTrip should not modify the request, except for
|
||||
// consuming the Body. The request's URL and Header fields
|
||||
// are guaranteed to be initialized.
|
||||
RoundTrip(*Request) (*Response, os.Error)
|
||||
}
|
||||
|
||||
// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
|
||||
@ -96,11 +97,15 @@ func send(req *Request, t RoundTripper) (resp *Response, err os.Error) {
|
||||
if t == nil {
|
||||
t = DefaultTransport
|
||||
if t == nil {
|
||||
err = os.NewError("no http.Client.Transport or http.DefaultTransport")
|
||||
err = os.NewError("http: no Client.Transport or DefaultTransport")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if req.URL == nil {
|
||||
return nil, os.NewError("http: nil Request.URL")
|
||||
}
|
||||
|
||||
// Most the callers of send (Get, Post, et al) don't need
|
||||
// Headers, leaving it uninitialized. We guarantee to the
|
||||
// Transport that this has been initialized, though.
|
||||
|
@ -275,7 +275,7 @@ const defaultUserAgent = "Go http package"
|
||||
// hasn't been set to "identity", Write adds "Transfer-Encoding:
|
||||
// chunked" to the header. Body is closed after it is sent.
|
||||
func (req *Request) Write(w io.Writer) os.Error {
|
||||
return req.write(w, false)
|
||||
return req.write(w, false, nil)
|
||||
}
|
||||
|
||||
// WriteProxy is like Write but writes the request in the form
|
||||
@ -285,7 +285,7 @@ func (req *Request) Write(w io.Writer) os.Error {
|
||||
// either case, WriteProxy also writes a Host header, using either
|
||||
// req.Host or req.URL.Host.
|
||||
func (req *Request) WriteProxy(w io.Writer) os.Error {
|
||||
return req.write(w, true)
|
||||
return req.write(w, true, nil)
|
||||
}
|
||||
|
||||
func (req *Request) dumpWrite(w io.Writer) os.Error {
|
||||
@ -333,7 +333,8 @@ func (req *Request) dumpWrite(w io.Writer) os.Error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
|
||||
// extraHeaders may be nil
|
||||
func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) os.Error {
|
||||
host := req.Host
|
||||
if host == "" {
|
||||
if req.URL == nil {
|
||||
@ -394,6 +395,13 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
|
||||
return err
|
||||
}
|
||||
|
||||
if extraHeaders != nil {
|
||||
err = extraHeaders.Write(bw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
io.WriteString(bw, "\r\n")
|
||||
|
||||
// Write body and trailer
|
||||
|
@ -100,11 +100,28 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, os.Error) {
|
||||
}
|
||||
}
|
||||
|
||||
// transportRequest is a wrapper around a *Request that adds
|
||||
// optional extra headers to write.
|
||||
type transportRequest struct {
|
||||
*Request // original request, not to be mutated
|
||||
extra Header // extra headers to write, or nil
|
||||
}
|
||||
|
||||
func (tr *transportRequest) extraHeaders() Header {
|
||||
if tr.extra == nil {
|
||||
tr.extra = make(Header)
|
||||
}
|
||||
return tr.extra
|
||||
}
|
||||
|
||||
// RoundTrip implements the RoundTripper interface.
|
||||
func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
|
||||
if req.URL == nil {
|
||||
return nil, os.NewError("http: nil Request.URL")
|
||||
}
|
||||
if req.Header == nil {
|
||||
return nil, os.NewError("http: nil Request.Header")
|
||||
}
|
||||
if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
|
||||
t.lk.Lock()
|
||||
var rt RoundTripper
|
||||
@ -117,8 +134,8 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
|
||||
}
|
||||
return rt.RoundTrip(req)
|
||||
}
|
||||
|
||||
cm, err := t.connectMethodForRequest(req)
|
||||
treq := &transportRequest{Request: req}
|
||||
cm, err := t.connectMethodForRequest(treq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -132,7 +149,7 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pconn.roundTrip(req)
|
||||
return pconn.roundTrip(treq)
|
||||
}
|
||||
|
||||
// RegisterProtocol registers a new protocol with scheme.
|
||||
@ -185,14 +202,14 @@ func getenvEitherCase(k string) string {
|
||||
return os.Getenv(strings.ToLower(k))
|
||||
}
|
||||
|
||||
func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) {
|
||||
func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, os.Error) {
|
||||
cm := &connectMethod{
|
||||
targetScheme: req.URL.Scheme,
|
||||
targetAddr: canonicalAddr(req.URL),
|
||||
targetScheme: treq.URL.Scheme,
|
||||
targetAddr: canonicalAddr(treq.URL),
|
||||
}
|
||||
if t.Proxy != nil {
|
||||
var err os.Error
|
||||
cm.proxyURL, err = t.Proxy(req)
|
||||
cm.proxyURL, err = t.Proxy(treq.Request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -295,19 +312,15 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
|
||||
conn: conn,
|
||||
reqch: make(chan requestAndChan, 50),
|
||||
}
|
||||
newClientConnFunc := NewClientConn
|
||||
|
||||
switch {
|
||||
case cm.proxyURL == nil:
|
||||
// Do nothing.
|
||||
case cm.targetScheme == "http":
|
||||
newClientConnFunc = NewProxyClientConn
|
||||
pconn.isProxy = true
|
||||
if pa != "" {
|
||||
pconn.mutateRequestFunc = func(req *Request) {
|
||||
if req.Header == nil {
|
||||
req.Header = make(Header)
|
||||
}
|
||||
req.Header.Set("Proxy-Authorization", pa)
|
||||
pconn.mutateHeaderFunc = func(h Header) {
|
||||
h.Set("Proxy-Authorization", pa)
|
||||
}
|
||||
}
|
||||
case cm.targetScheme == "https":
|
||||
@ -351,7 +364,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
|
||||
}
|
||||
|
||||
pconn.br = bufio.NewReader(pconn.conn)
|
||||
pconn.cc = newClientConnFunc(conn, pconn.br)
|
||||
pconn.cc = NewClientConn(conn, pconn.br)
|
||||
go pconn.readLoop()
|
||||
return pconn, nil
|
||||
}
|
||||
@ -447,30 +460,21 @@ func (cm *connectMethod) tlsHost() string {
|
||||
return h
|
||||
}
|
||||
|
||||
type readResult struct {
|
||||
res *Response // either res or err will be set
|
||||
err os.Error
|
||||
}
|
||||
|
||||
type writeRequest struct {
|
||||
// Set by client (in pc.roundTrip)
|
||||
req *Request
|
||||
resch chan *readResult
|
||||
|
||||
// Set by writeLoop if an error writing headers.
|
||||
writeErr os.Error
|
||||
}
|
||||
|
||||
// persistConn wraps a connection, usually a persistent one
|
||||
// (but may be used for non-keep-alive requests as well)
|
||||
type persistConn struct {
|
||||
t *Transport
|
||||
cacheKey string // its connectMethod.String()
|
||||
conn net.Conn
|
||||
cc *ClientConn
|
||||
br *bufio.Reader
|
||||
reqch chan requestAndChan // written by roundTrip(); read by readLoop()
|
||||
mutateRequestFunc func(*Request) // nil or func to modify each outbound request
|
||||
t *Transport
|
||||
cacheKey string // its connectMethod.String()
|
||||
conn net.Conn
|
||||
cc *ClientConn
|
||||
br *bufio.Reader
|
||||
reqch chan requestAndChan // written by roundTrip(); read by readLoop()
|
||||
isProxy bool
|
||||
|
||||
// mutateHeaderFunc is an optional func to modify extra
|
||||
// headers on each outbound request before it's written. (the
|
||||
// original Request given to RoundTrip is not modified)
|
||||
mutateHeaderFunc func(Header)
|
||||
|
||||
lk sync.Mutex // guards numExpectedResponses and broken
|
||||
numExpectedResponses int
|
||||
@ -526,9 +530,6 @@ func (pc *persistConn) readLoop() {
|
||||
if err != nil || resp.ContentLength == 0 {
|
||||
return resp, err
|
||||
}
|
||||
if rc.addedGzip {
|
||||
forReq.Header.Del("Accept-Encoding")
|
||||
}
|
||||
if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" {
|
||||
resp.Header.Del("Content-Encoding")
|
||||
resp.Header.Del("Content-Length")
|
||||
@ -604,9 +605,9 @@ type requestAndChan struct {
|
||||
addedGzip bool
|
||||
}
|
||||
|
||||
func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
|
||||
if pc.mutateRequestFunc != nil {
|
||||
pc.mutateRequestFunc(req)
|
||||
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err os.Error) {
|
||||
if pc.mutateHeaderFunc != nil {
|
||||
pc.mutateHeaderFunc(req.extraHeaders())
|
||||
}
|
||||
|
||||
// Ask for a compressed version if the caller didn't set their
|
||||
@ -616,24 +617,28 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
|
||||
requestedGzip := false
|
||||
if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" {
|
||||
// Request gzip only, not deflate. Deflate is ambiguous and
|
||||
// as universally supported anyway.
|
||||
// not as universally supported anyway.
|
||||
// See: http://www.gzip.org/zlib/zlib_faq.html#faq38
|
||||
requestedGzip = true
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
req.extraHeaders().Set("Accept-Encoding", "gzip")
|
||||
}
|
||||
|
||||
pc.lk.Lock()
|
||||
pc.numExpectedResponses++
|
||||
pc.lk.Unlock()
|
||||
|
||||
err = pc.cc.Write(req)
|
||||
pc.cc.writeReq = func(r *Request, w io.Writer) os.Error {
|
||||
return r.write(w, pc.isProxy, req.extra)
|
||||
}
|
||||
|
||||
err = pc.cc.Write(req.Request)
|
||||
if err != nil {
|
||||
pc.close()
|
||||
return
|
||||
}
|
||||
|
||||
ch := make(chan responseAndError, 1)
|
||||
pc.reqch <- requestAndChan{req, ch, requestedGzip}
|
||||
pc.reqch <- requestAndChan{req.Request, ch, requestedGzip}
|
||||
re := <-ch
|
||||
pc.lk.Lock()
|
||||
pc.numExpectedResponses--
|
||||
@ -648,7 +653,7 @@ func (pc *persistConn) close() {
|
||||
pc.broken = true
|
||||
pc.cc.Close()
|
||||
pc.conn.Close()
|
||||
pc.mutateRequestFunc = nil
|
||||
pc.mutateHeaderFunc = nil
|
||||
}
|
||||
|
||||
var portMap = map[string]string{
|
||||
|
@ -372,7 +372,8 @@ var roundTripTests = []struct {
|
||||
// Requests with other accept-encoding should pass through unmodified
|
||||
{"foo", "foo", false},
|
||||
// Requests with accept-encoding == gzip should be passed through
|
||||
{"gzip", "gzip", true}}
|
||||
{"gzip", "gzip", true},
|
||||
}
|
||||
|
||||
// Test that the modification made to the Request by the RoundTripper is cleaned up
|
||||
func TestRoundTripGzip(t *testing.T) {
|
||||
@ -380,7 +381,8 @@ func TestRoundTripGzip(t *testing.T) {
|
||||
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||||
accept := req.Header.Get("Accept-Encoding")
|
||||
if expect := req.FormValue("expect_accept"); accept != expect {
|
||||
t.Errorf("Accept-Encoding = %q, want %q", accept, expect)
|
||||
t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
|
||||
req.FormValue("testnum"), accept, expect)
|
||||
}
|
||||
if accept == "gzip" {
|
||||
rw.Header().Set("Content-Encoding", "gzip")
|
||||
@ -396,8 +398,10 @@ func TestRoundTripGzip(t *testing.T) {
|
||||
|
||||
for i, test := range roundTripTests {
|
||||
// Test basic request (no accept-encoding)
|
||||
req, _ := NewRequest("GET", ts.URL+"?expect_accept="+test.expectAccept, nil)
|
||||
req.Header.Set("Accept-Encoding", test.accept)
|
||||
req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
|
||||
if test.accept != "" {
|
||||
req.Header.Set("Accept-Encoding", test.accept)
|
||||
}
|
||||
res, err := DefaultTransport.RoundTrip(req)
|
||||
var body []byte
|
||||
if test.compressed {
|
||||
@ -409,16 +413,16 @@ func TestRoundTripGzip(t *testing.T) {
|
||||
}
|
||||
if err != nil {
|
||||
t.Errorf("%d. Error: %q", i, err)
|
||||
} else {
|
||||
if g, e := string(body), responseBody; g != e {
|
||||
t.Errorf("%d. body = %q; want %q", i, g, e)
|
||||
}
|
||||
if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
|
||||
t.Errorf("%d. Accept-Encoding = %q; want %q", i, g, e)
|
||||
}
|
||||
if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
|
||||
t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if g, e := string(body), responseBody; g != e {
|
||||
t.Errorf("%d. body = %q; want %q", i, g, e)
|
||||
}
|
||||
if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
|
||||
t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
|
||||
}
|
||||
if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
|
||||
t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user