mirror of
https://github.com/golang/go
synced 2024-11-22 04:24:39 -07:00
http: RoundTrippers shouldn't mutate Request
Fixes #2146 R=rsc CC=golang-dev https://golang.org/cl/5284041
This commit is contained in:
parent
236aff31c5
commit
b9ad2787dd
@ -56,9 +56,10 @@ type RoundTripper interface {
|
|||||||
// higher-level protocol details such as redirects,
|
// higher-level protocol details such as redirects,
|
||||||
// authentication, or cookies.
|
// authentication, or cookies.
|
||||||
//
|
//
|
||||||
// RoundTrip may modify the request. The request Headers field is
|
// RoundTrip should not modify the request, except for
|
||||||
// guaranteed to be initialized.
|
// consuming the Body. The request's URL and Header fields
|
||||||
RoundTrip(req *Request) (resp *Response, err os.Error)
|
// are guaranteed to be initialized.
|
||||||
|
RoundTrip(*Request) (*Response, os.Error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
|
// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
|
||||||
@ -96,11 +97,15 @@ func send(req *Request, t RoundTripper) (resp *Response, err os.Error) {
|
|||||||
if t == nil {
|
if t == nil {
|
||||||
t = DefaultTransport
|
t = DefaultTransport
|
||||||
if t == nil {
|
if t == nil {
|
||||||
err = os.NewError("no http.Client.Transport or http.DefaultTransport")
|
err = os.NewError("http: no Client.Transport or DefaultTransport")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if req.URL == nil {
|
||||||
|
return nil, os.NewError("http: nil Request.URL")
|
||||||
|
}
|
||||||
|
|
||||||
// Most the callers of send (Get, Post, et al) don't need
|
// Most the callers of send (Get, Post, et al) don't need
|
||||||
// Headers, leaving it uninitialized. We guarantee to the
|
// Headers, leaving it uninitialized. We guarantee to the
|
||||||
// Transport that this has been initialized, though.
|
// Transport that this has been initialized, though.
|
||||||
|
@ -275,7 +275,7 @@ const defaultUserAgent = "Go http package"
|
|||||||
// hasn't been set to "identity", Write adds "Transfer-Encoding:
|
// hasn't been set to "identity", Write adds "Transfer-Encoding:
|
||||||
// chunked" to the header. Body is closed after it is sent.
|
// chunked" to the header. Body is closed after it is sent.
|
||||||
func (req *Request) Write(w io.Writer) os.Error {
|
func (req *Request) Write(w io.Writer) os.Error {
|
||||||
return req.write(w, false)
|
return req.write(w, false, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WriteProxy is like Write but writes the request in the form
|
// WriteProxy is like Write but writes the request in the form
|
||||||
@ -285,7 +285,7 @@ func (req *Request) Write(w io.Writer) os.Error {
|
|||||||
// either case, WriteProxy also writes a Host header, using either
|
// either case, WriteProxy also writes a Host header, using either
|
||||||
// req.Host or req.URL.Host.
|
// req.Host or req.URL.Host.
|
||||||
func (req *Request) WriteProxy(w io.Writer) os.Error {
|
func (req *Request) WriteProxy(w io.Writer) os.Error {
|
||||||
return req.write(w, true)
|
return req.write(w, true, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (req *Request) dumpWrite(w io.Writer) os.Error {
|
func (req *Request) dumpWrite(w io.Writer) os.Error {
|
||||||
@ -333,7 +333,8 @@ func (req *Request) dumpWrite(w io.Writer) os.Error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
|
// extraHeaders may be nil
|
||||||
|
func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) os.Error {
|
||||||
host := req.Host
|
host := req.Host
|
||||||
if host == "" {
|
if host == "" {
|
||||||
if req.URL == nil {
|
if req.URL == nil {
|
||||||
@ -394,6 +395,13 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if extraHeaders != nil {
|
||||||
|
err = extraHeaders.Write(bw)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
io.WriteString(bw, "\r\n")
|
io.WriteString(bw, "\r\n")
|
||||||
|
|
||||||
// Write body and trailer
|
// Write body and trailer
|
||||||
|
@ -100,11 +100,28 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, os.Error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// transportRequest is a wrapper around a *Request that adds
|
||||||
|
// optional extra headers to write.
|
||||||
|
type transportRequest struct {
|
||||||
|
*Request // original request, not to be mutated
|
||||||
|
extra Header // extra headers to write, or nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tr *transportRequest) extraHeaders() Header {
|
||||||
|
if tr.extra == nil {
|
||||||
|
tr.extra = make(Header)
|
||||||
|
}
|
||||||
|
return tr.extra
|
||||||
|
}
|
||||||
|
|
||||||
// RoundTrip implements the RoundTripper interface.
|
// RoundTrip implements the RoundTripper interface.
|
||||||
func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
|
func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
|
||||||
if req.URL == nil {
|
if req.URL == nil {
|
||||||
return nil, os.NewError("http: nil Request.URL")
|
return nil, os.NewError("http: nil Request.URL")
|
||||||
}
|
}
|
||||||
|
if req.Header == nil {
|
||||||
|
return nil, os.NewError("http: nil Request.Header")
|
||||||
|
}
|
||||||
if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
|
if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
|
||||||
t.lk.Lock()
|
t.lk.Lock()
|
||||||
var rt RoundTripper
|
var rt RoundTripper
|
||||||
@ -117,8 +134,8 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
|
|||||||
}
|
}
|
||||||
return rt.RoundTrip(req)
|
return rt.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
treq := &transportRequest{Request: req}
|
||||||
cm, err := t.connectMethodForRequest(req)
|
cm, err := t.connectMethodForRequest(treq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -132,7 +149,7 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return pconn.roundTrip(req)
|
return pconn.roundTrip(treq)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterProtocol registers a new protocol with scheme.
|
// RegisterProtocol registers a new protocol with scheme.
|
||||||
@ -185,14 +202,14 @@ func getenvEitherCase(k string) string {
|
|||||||
return os.Getenv(strings.ToLower(k))
|
return os.Getenv(strings.ToLower(k))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) {
|
func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, os.Error) {
|
||||||
cm := &connectMethod{
|
cm := &connectMethod{
|
||||||
targetScheme: req.URL.Scheme,
|
targetScheme: treq.URL.Scheme,
|
||||||
targetAddr: canonicalAddr(req.URL),
|
targetAddr: canonicalAddr(treq.URL),
|
||||||
}
|
}
|
||||||
if t.Proxy != nil {
|
if t.Proxy != nil {
|
||||||
var err os.Error
|
var err os.Error
|
||||||
cm.proxyURL, err = t.Proxy(req)
|
cm.proxyURL, err = t.Proxy(treq.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -295,19 +312,15 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
|
|||||||
conn: conn,
|
conn: conn,
|
||||||
reqch: make(chan requestAndChan, 50),
|
reqch: make(chan requestAndChan, 50),
|
||||||
}
|
}
|
||||||
newClientConnFunc := NewClientConn
|
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case cm.proxyURL == nil:
|
case cm.proxyURL == nil:
|
||||||
// Do nothing.
|
// Do nothing.
|
||||||
case cm.targetScheme == "http":
|
case cm.targetScheme == "http":
|
||||||
newClientConnFunc = NewProxyClientConn
|
pconn.isProxy = true
|
||||||
if pa != "" {
|
if pa != "" {
|
||||||
pconn.mutateRequestFunc = func(req *Request) {
|
pconn.mutateHeaderFunc = func(h Header) {
|
||||||
if req.Header == nil {
|
h.Set("Proxy-Authorization", pa)
|
||||||
req.Header = make(Header)
|
|
||||||
}
|
|
||||||
req.Header.Set("Proxy-Authorization", pa)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
case cm.targetScheme == "https":
|
case cm.targetScheme == "https":
|
||||||
@ -351,7 +364,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pconn.br = bufio.NewReader(pconn.conn)
|
pconn.br = bufio.NewReader(pconn.conn)
|
||||||
pconn.cc = newClientConnFunc(conn, pconn.br)
|
pconn.cc = NewClientConn(conn, pconn.br)
|
||||||
go pconn.readLoop()
|
go pconn.readLoop()
|
||||||
return pconn, nil
|
return pconn, nil
|
||||||
}
|
}
|
||||||
@ -447,30 +460,21 @@ func (cm *connectMethod) tlsHost() string {
|
|||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
type readResult struct {
|
|
||||||
res *Response // either res or err will be set
|
|
||||||
err os.Error
|
|
||||||
}
|
|
||||||
|
|
||||||
type writeRequest struct {
|
|
||||||
// Set by client (in pc.roundTrip)
|
|
||||||
req *Request
|
|
||||||
resch chan *readResult
|
|
||||||
|
|
||||||
// Set by writeLoop if an error writing headers.
|
|
||||||
writeErr os.Error
|
|
||||||
}
|
|
||||||
|
|
||||||
// persistConn wraps a connection, usually a persistent one
|
// persistConn wraps a connection, usually a persistent one
|
||||||
// (but may be used for non-keep-alive requests as well)
|
// (but may be used for non-keep-alive requests as well)
|
||||||
type persistConn struct {
|
type persistConn struct {
|
||||||
t *Transport
|
t *Transport
|
||||||
cacheKey string // its connectMethod.String()
|
cacheKey string // its connectMethod.String()
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
cc *ClientConn
|
cc *ClientConn
|
||||||
br *bufio.Reader
|
br *bufio.Reader
|
||||||
reqch chan requestAndChan // written by roundTrip(); read by readLoop()
|
reqch chan requestAndChan // written by roundTrip(); read by readLoop()
|
||||||
mutateRequestFunc func(*Request) // nil or func to modify each outbound request
|
isProxy bool
|
||||||
|
|
||||||
|
// mutateHeaderFunc is an optional func to modify extra
|
||||||
|
// headers on each outbound request before it's written. (the
|
||||||
|
// original Request given to RoundTrip is not modified)
|
||||||
|
mutateHeaderFunc func(Header)
|
||||||
|
|
||||||
lk sync.Mutex // guards numExpectedResponses and broken
|
lk sync.Mutex // guards numExpectedResponses and broken
|
||||||
numExpectedResponses int
|
numExpectedResponses int
|
||||||
@ -526,9 +530,6 @@ func (pc *persistConn) readLoop() {
|
|||||||
if err != nil || resp.ContentLength == 0 {
|
if err != nil || resp.ContentLength == 0 {
|
||||||
return resp, err
|
return resp, err
|
||||||
}
|
}
|
||||||
if rc.addedGzip {
|
|
||||||
forReq.Header.Del("Accept-Encoding")
|
|
||||||
}
|
|
||||||
if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" {
|
if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" {
|
||||||
resp.Header.Del("Content-Encoding")
|
resp.Header.Del("Content-Encoding")
|
||||||
resp.Header.Del("Content-Length")
|
resp.Header.Del("Content-Length")
|
||||||
@ -604,9 +605,9 @@ type requestAndChan struct {
|
|||||||
addedGzip bool
|
addedGzip bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
|
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err os.Error) {
|
||||||
if pc.mutateRequestFunc != nil {
|
if pc.mutateHeaderFunc != nil {
|
||||||
pc.mutateRequestFunc(req)
|
pc.mutateHeaderFunc(req.extraHeaders())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ask for a compressed version if the caller didn't set their
|
// Ask for a compressed version if the caller didn't set their
|
||||||
@ -616,24 +617,28 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
|
|||||||
requestedGzip := false
|
requestedGzip := false
|
||||||
if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" {
|
if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" {
|
||||||
// Request gzip only, not deflate. Deflate is ambiguous and
|
// Request gzip only, not deflate. Deflate is ambiguous and
|
||||||
// as universally supported anyway.
|
// not as universally supported anyway.
|
||||||
// See: http://www.gzip.org/zlib/zlib_faq.html#faq38
|
// See: http://www.gzip.org/zlib/zlib_faq.html#faq38
|
||||||
requestedGzip = true
|
requestedGzip = true
|
||||||
req.Header.Set("Accept-Encoding", "gzip")
|
req.extraHeaders().Set("Accept-Encoding", "gzip")
|
||||||
}
|
}
|
||||||
|
|
||||||
pc.lk.Lock()
|
pc.lk.Lock()
|
||||||
pc.numExpectedResponses++
|
pc.numExpectedResponses++
|
||||||
pc.lk.Unlock()
|
pc.lk.Unlock()
|
||||||
|
|
||||||
err = pc.cc.Write(req)
|
pc.cc.writeReq = func(r *Request, w io.Writer) os.Error {
|
||||||
|
return r.write(w, pc.isProxy, req.extra)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = pc.cc.Write(req.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
pc.close()
|
pc.close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ch := make(chan responseAndError, 1)
|
ch := make(chan responseAndError, 1)
|
||||||
pc.reqch <- requestAndChan{req, ch, requestedGzip}
|
pc.reqch <- requestAndChan{req.Request, ch, requestedGzip}
|
||||||
re := <-ch
|
re := <-ch
|
||||||
pc.lk.Lock()
|
pc.lk.Lock()
|
||||||
pc.numExpectedResponses--
|
pc.numExpectedResponses--
|
||||||
@ -648,7 +653,7 @@ func (pc *persistConn) close() {
|
|||||||
pc.broken = true
|
pc.broken = true
|
||||||
pc.cc.Close()
|
pc.cc.Close()
|
||||||
pc.conn.Close()
|
pc.conn.Close()
|
||||||
pc.mutateRequestFunc = nil
|
pc.mutateHeaderFunc = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var portMap = map[string]string{
|
var portMap = map[string]string{
|
||||||
|
@ -372,7 +372,8 @@ var roundTripTests = []struct {
|
|||||||
// Requests with other accept-encoding should pass through unmodified
|
// Requests with other accept-encoding should pass through unmodified
|
||||||
{"foo", "foo", false},
|
{"foo", "foo", false},
|
||||||
// Requests with accept-encoding == gzip should be passed through
|
// Requests with accept-encoding == gzip should be passed through
|
||||||
{"gzip", "gzip", true}}
|
{"gzip", "gzip", true},
|
||||||
|
}
|
||||||
|
|
||||||
// Test that the modification made to the Request by the RoundTripper is cleaned up
|
// Test that the modification made to the Request by the RoundTripper is cleaned up
|
||||||
func TestRoundTripGzip(t *testing.T) {
|
func TestRoundTripGzip(t *testing.T) {
|
||||||
@ -380,7 +381,8 @@ func TestRoundTripGzip(t *testing.T) {
|
|||||||
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
|
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
|
||||||
accept := req.Header.Get("Accept-Encoding")
|
accept := req.Header.Get("Accept-Encoding")
|
||||||
if expect := req.FormValue("expect_accept"); accept != expect {
|
if expect := req.FormValue("expect_accept"); accept != expect {
|
||||||
t.Errorf("Accept-Encoding = %q, want %q", accept, expect)
|
t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
|
||||||
|
req.FormValue("testnum"), accept, expect)
|
||||||
}
|
}
|
||||||
if accept == "gzip" {
|
if accept == "gzip" {
|
||||||
rw.Header().Set("Content-Encoding", "gzip")
|
rw.Header().Set("Content-Encoding", "gzip")
|
||||||
@ -396,8 +398,10 @@ func TestRoundTripGzip(t *testing.T) {
|
|||||||
|
|
||||||
for i, test := range roundTripTests {
|
for i, test := range roundTripTests {
|
||||||
// Test basic request (no accept-encoding)
|
// Test basic request (no accept-encoding)
|
||||||
req, _ := NewRequest("GET", ts.URL+"?expect_accept="+test.expectAccept, nil)
|
req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
|
||||||
req.Header.Set("Accept-Encoding", test.accept)
|
if test.accept != "" {
|
||||||
|
req.Header.Set("Accept-Encoding", test.accept)
|
||||||
|
}
|
||||||
res, err := DefaultTransport.RoundTrip(req)
|
res, err := DefaultTransport.RoundTrip(req)
|
||||||
var body []byte
|
var body []byte
|
||||||
if test.compressed {
|
if test.compressed {
|
||||||
@ -409,16 +413,16 @@ func TestRoundTripGzip(t *testing.T) {
|
|||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("%d. Error: %q", i, err)
|
t.Errorf("%d. Error: %q", i, err)
|
||||||
} else {
|
continue
|
||||||
if g, e := string(body), responseBody; g != e {
|
}
|
||||||
t.Errorf("%d. body = %q; want %q", i, g, e)
|
if g, e := string(body), responseBody; g != e {
|
||||||
}
|
t.Errorf("%d. body = %q; want %q", i, g, e)
|
||||||
if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
|
}
|
||||||
t.Errorf("%d. Accept-Encoding = %q; want %q", i, g, e)
|
if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
|
||||||
}
|
t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
|
||||||
if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
|
}
|
||||||
t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
|
if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
|
||||||
}
|
t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user