1
0
mirror of https://github.com/golang/go synced 2024-11-19 16:44:43 -07:00

net/http: HTTPS proxies support

net/http already supports http proxies. This CL allows it to establish
a connection to the http proxy over https. See more at:
https://www.chromium.org/developers/design-documents/secure-web-proxy

Fixes golang/go#11332

Change-Id: If0e017df0e8f8c2c499a2ddcbbeb625c8fa2bb6b
Reviewed-on: https://go-review.googlesource.com/68550
Run-TryBot: Tom Bergan <tombergan@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Tom Bergan <tombergan@google.com>
This commit is contained in:
Ben Schwartz 2017-10-05 14:07:55 -04:00 committed by Tom Bergan
parent 897080d5cb
commit f5cd3868d5
2 changed files with 283 additions and 120 deletions

View File

@ -618,11 +618,6 @@ func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectM
if port := cm.proxyURL.Port(); !validPort(port) {
return cm, fmt.Errorf("invalid proxy URL port %q", port)
}
switch cm.proxyURL.Scheme {
case "http", "socks5":
default:
return cm, fmt.Errorf("invalid proxy URL scheme %q", cm.proxyURL.Scheme)
}
}
}
return cm, err
@ -1021,6 +1016,69 @@ func (d oneConnDialer) Dial(network, addr string) (net.Conn, error) {
}
}
// The connect method and the transport can both specify a TLS
// Host name. The transport's name takes precedence if present.
func chooseTLSHost(cm connectMethod, t *Transport) string {
tlsHost := ""
if t.TLSClientConfig != nil {
tlsHost = t.TLSClientConfig.ServerName
}
if tlsHost == "" {
tlsHost = cm.tlsHost()
}
return tlsHost
}
// Add TLS to a persistent connection, i.e. negotiate a TLS session. If pconn is already a TLS
// tunnel, this function establishes a nested TLS session inside the encrypted channel.
// The remote endpoint's name may be overridden by TLSClientConfig.ServerName.
func (pconn *persistConn) addTLS(name string, trace *httptrace.ClientTrace) error {
// Initiate TLS and check remote host name against certificate.
cfg := cloneTLSConfig(pconn.t.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = name
}
plainConn := pconn.conn
tlsConn := tls.Client(plainConn, cfg)
errc := make(chan error, 2)
var timer *time.Timer // for canceling TLS handshake
if d := pconn.t.TLSHandshakeTimeout; d != 0 {
timer = time.AfterFunc(d, func() {
errc <- tlsHandshakeTimeoutError{}
})
}
go func() {
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := tlsConn.Handshake()
if timer != nil {
timer.Stop()
}
errc <- err
}()
if err := <-errc; err != nil {
plainConn.Close()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
}
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
plainConn.Close()
return err
}
}
cs := tlsConn.ConnectionState()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(cs, nil)
}
pconn.tlsState = &cs
pconn.conn = tlsConn
return nil
}
func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistConn, error) {
pconn := &persistConn{
t: t,
@ -1032,15 +1090,21 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
writeLoopDone: make(chan struct{}),
}
trace := httptrace.ContextClientTrace(ctx)
tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
if tlsDial {
wrapErr := func(err error) error {
if cm.proxyURL != nil {
// Return a typed error, per Issue 16997
return &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err}
}
return err
}
if cm.scheme() == "https" && t.DialTLS != nil {
var err error
pconn.conn, err = t.DialTLS("tcp", cm.addr())
if err != nil {
return nil, err
return nil, wrapErr(err)
}
if pconn.conn == nil {
return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)")
return nil, wrapErr(errors.New("net/http: Transport.DialTLS returned (nil, nil)"))
}
if tc, ok := pconn.conn.(*tls.Conn); ok {
// Handshake here, in case DialTLS didn't. TLSNextProto below
@ -1064,13 +1128,18 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
} else {
conn, err := t.dial(ctx, "tcp", cm.addr())
if err != nil {
if cm.proxyURL != nil {
// Return a typed error, per Issue 16997:
err = &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err}
}
return nil, err
return nil, wrapErr(err)
}
pconn.conn = conn
if cm.scheme() == "https" {
var firstTLSHost string
if firstTLSHost, _, err = net.SplitHostPort(cm.addr()); err != nil {
return nil, wrapErr(err)
}
if err = pconn.addTLS(firstTLSHost, trace); err != nil {
return nil, wrapErr(err)
}
}
}
// Proxy setup.
@ -1134,50 +1203,10 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
}
}
if cm.targetScheme == "https" && !tlsDial {
// Initiate TLS and check remote host name against certificate.
cfg := cloneTLSConfig(t.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = cm.tlsHost()
}
plainConn := pconn.conn
tlsConn := tls.Client(plainConn, cfg)
errc := make(chan error, 2)
var timer *time.Timer // for canceling TLS handshake
if d := t.TLSHandshakeTimeout; d != 0 {
timer = time.AfterFunc(d, func() {
errc <- tlsHandshakeTimeoutError{}
})
}
go func() {
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := tlsConn.Handshake()
if timer != nil {
timer.Stop()
}
errc <- err
}()
if err := <-errc; err != nil {
plainConn.Close()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tls.ConnectionState{}, err)
}
if cm.proxyURL != nil && cm.targetScheme == "https" {
if err := pconn.addTLS(cm.tlsHost(), trace); err != nil {
return nil, err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
plainConn.Close()
return nil, err
}
}
cs := tlsConn.ConnectionState()
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(cs, nil)
}
pconn.tlsState = &cs
pconn.conn = tlsConn
}
if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
@ -1279,13 +1308,16 @@ func useProxy(addr string) bool {
// http://proxy.com|http http to proxy, http to anywhere after that
// socks5://proxy.com|http|foo.com socks5 to proxy, then http to foo.com
// socks5://proxy.com|https|foo.com socks5 to proxy, then https to foo.com
//
// Note: no support to https to the proxy yet.
// https://proxy.com|https|foo.com https to proxy, then CONNECT to foo.com
// https://proxy.com|http https to proxy, http to anywhere after that
//
type connectMethod struct {
proxyURL *url.URL // nil for no proxy, else full proxy URL
targetScheme string // "http" or "https"
targetAddr string // Not used if http proxy + http targetScheme (4th example in table)
// If proxyURL specifies an http or https proxy, and targetScheme is http (not https),
// then targetAddr is not included in the connect method key, because the socket can
// be reused for different targetAddr values.
targetAddr string
}
func (cm *connectMethod) key() connectMethodKey {
@ -1293,7 +1325,7 @@ func (cm *connectMethod) key() connectMethodKey {
targetAddr := cm.targetAddr
if cm.proxyURL != nil {
proxyStr = cm.proxyURL.String()
if strings.HasPrefix(cm.proxyURL.Scheme, "http") && cm.targetScheme == "http" {
if (cm.proxyURL.Scheme == "http" || cm.proxyURL.Scheme == "https") && cm.targetScheme == "http" {
targetAddr = ""
}
}
@ -1304,6 +1336,14 @@ func (cm *connectMethod) key() connectMethodKey {
}
}
// scheme returns the first hop scheme: http, https, or socks5
func (cm *connectMethod) scheme() string {
if cm.proxyURL != nil {
return cm.proxyURL.Scheme
}
return cm.targetScheme
}
// addr returns the first hop "host:port" to which we need to TCP connect.
func (cm *connectMethod) addr() string {
if cm.proxyURL != nil {

View File

@ -961,14 +961,10 @@ func TestTransportExpect100Continue(t *testing.T) {
func TestSocks5Proxy(t *testing.T) {
defer afterTest(t)
ch := make(chan string, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ch <- "real server"
}))
defer ts.Close()
l := newLocalListener(t)
defer l.Close()
go func() {
defer close(ch)
defer close(ch)
proxy := func(t *testing.T) {
s, err := l.Accept()
if err != nil {
t.Errorf("socks5 proxy Accept(): %v", err)
@ -1003,7 +999,8 @@ func TestSocks5Proxy(t *testing.T) {
case 4:
ipLen = 16
default:
t.Fatalf("socks5 proxy second read: unexpected address type %v", buf[4])
t.Errorf("socks5 proxy second read: unexpected address type %v", buf[4])
return
}
if _, err := io.ReadFull(s, buf[4:ipLen+6]); err != nil {
t.Errorf("socks5 proxy address read: %v", err)
@ -1016,71 +1013,197 @@ func TestSocks5Proxy(t *testing.T) {
t.Errorf("socks5 proxy connect write: %v", err)
return
}
done := make(chan struct{})
srv := &Server{Handler: HandlerFunc(func(w ResponseWriter, r *Request) {
done <- struct{}{}
})}
srv.Serve(&oneConnListener{conn: s})
<-done
srv.Shutdown(context.Background())
ch <- fmt.Sprintf("proxy for %s:%d", ip, port)
}()
// Implement proxying.
targetHost := net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))
targetConn, err := net.Dial("tcp", targetHost)
if err != nil {
t.Errorf("net.Dial failed")
return
}
go io.Copy(targetConn, s)
io.Copy(s, targetConn) // Wait for the client to close the socket.
targetConn.Close()
}
pu, err := url.Parse("socks5://" + l.Addr().String())
if err != nil {
t.Fatal(err)
}
c := ts.Client()
c.Transport.(*Transport).Proxy = ProxyURL(pu)
if _, err := c.Head(ts.URL); err != nil {
t.Error(err)
}
var got string
select {
case got = <-ch:
case <-time.After(5 * time.Second):
t.Fatal("timeout connecting to socks5 proxy")
}
tsu, err := url.Parse(ts.URL)
if err != nil {
t.Fatal(err)
}
want := "proxy for " + tsu.Host
if got != want {
t.Errorf("got %q, want %q", got, want)
sentinelHeader := "X-Sentinel"
sentinelValue := "12345"
h := HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set(sentinelHeader, sentinelValue)
})
for _, useTLS := range []bool{false, true} {
t.Run(fmt.Sprintf("useTLS=%v", useTLS), func(t *testing.T) {
var ts *httptest.Server
if useTLS {
ts = httptest.NewTLSServer(h)
} else {
ts = httptest.NewServer(h)
}
go proxy(t)
c := ts.Client()
c.Transport.(*Transport).Proxy = ProxyURL(pu)
r, err := c.Head(ts.URL)
if err != nil {
t.Fatal(err)
}
if r.Header.Get(sentinelHeader) != sentinelValue {
t.Errorf("Failed to retrieve sentinel value")
}
var got string
select {
case got = <-ch:
case <-time.After(5 * time.Second):
t.Fatal("timeout connecting to socks5 proxy")
}
ts.Close()
tsu, err := url.Parse(ts.URL)
if err != nil {
t.Fatal(err)
}
want := "proxy for " + tsu.Host
if got != want {
t.Errorf("got %q, want %q", got, want)
}
})
}
}
func TestTransportProxy(t *testing.T) {
defer afterTest(t)
ch := make(chan string, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ch <- "real server"
}))
defer ts.Close()
proxy := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
ch <- "proxy for " + r.URL.String()
}))
defer proxy.Close()
testCases := []struct{ httpsSite, httpsProxy bool }{
{false, false},
{false, true},
{true, false},
{true, true},
}
for _, testCase := range testCases {
httpsSite := testCase.httpsSite
httpsProxy := testCase.httpsProxy
t.Run(fmt.Sprintf("httpsSite=%v, httpsProxy=%v", httpsSite, httpsProxy), func(t *testing.T) {
siteCh := make(chan *Request, 1)
h1 := HandlerFunc(func(w ResponseWriter, r *Request) {
siteCh <- r
})
proxyCh := make(chan *Request, 1)
h2 := HandlerFunc(func(w ResponseWriter, r *Request) {
proxyCh <- r
// Implement an entire CONNECT proxy
if r.Method == "CONNECT" {
hijacker, ok := w.(Hijacker)
if !ok {
t.Errorf("hijack not allowed")
return
}
clientConn, _, err := hijacker.Hijack()
if err != nil {
t.Errorf("hijacking failed")
return
}
res := &Response{
StatusCode: StatusOK,
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(Header),
}
pu, err := url.Parse(proxy.URL)
if err != nil {
t.Fatal(err)
}
c := ts.Client()
c.Transport.(*Transport).Proxy = ProxyURL(pu)
if _, err := c.Head(ts.URL); err != nil {
t.Error(err)
}
var got string
select {
case got = <-ch:
case <-time.After(5 * time.Second):
t.Fatal("timeout connecting to http proxy")
}
want := "proxy for " + ts.URL + "/"
if got != want {
t.Errorf("got %q, want %q", got, want)
log.Printf("Dialing %s", r.URL.Host)
targetConn, err := net.Dial("tcp", r.URL.Host)
if err != nil {
t.Errorf("net.Dial failed")
return
}
if err := res.Write(clientConn); err != nil {
t.Errorf("Writing 200 OK failed")
return
}
go io.Copy(targetConn, clientConn)
go func() {
io.Copy(clientConn, targetConn)
targetConn.Close()
}()
}
})
var ts *httptest.Server
if httpsSite {
ts = httptest.NewTLSServer(h1)
} else {
ts = httptest.NewServer(h1)
}
var proxy *httptest.Server
if httpsProxy {
proxy = httptest.NewTLSServer(h2)
} else {
proxy = httptest.NewServer(h2)
}
pu, err := url.Parse(proxy.URL)
if err != nil {
t.Fatal(err)
}
// If neither server is HTTPS or both are, then c may be derived from either.
// If only one server is HTTPS, c must be derived from that server in order
// to ensure that it is configured to use the fake root CA from testcert.go.
c := proxy.Client()
if httpsSite {
c = ts.Client()
}
c.Transport.(*Transport).Proxy = ProxyURL(pu)
if _, err := c.Head(ts.URL); err != nil {
t.Error(err)
}
var got *Request
select {
case got = <-proxyCh:
case <-time.After(5 * time.Second):
t.Fatal("timeout connecting to http proxy")
}
c.Transport.(*Transport).CloseIdleConnections()
ts.Close()
proxy.Close()
if httpsSite {
// First message should be a CONNECT, asking for a socket to the real server,
if got.Method != "CONNECT" {
t.Errorf("Wrong method for secure proxying: %q", got.Method)
}
gotHost := got.URL.Host
pu, err := url.Parse(ts.URL)
if err != nil {
t.Fatal("Invalid site URL")
}
if wantHost := pu.Host; gotHost != wantHost {
t.Errorf("Got CONNECT host %q, want %q", gotHost, wantHost)
}
// The next message on the channel should be from the site's server.
next := <-siteCh
if next.Method != "HEAD" {
t.Errorf("Wrong method at destination: %s", next.Method)
}
if nextURL := next.URL.String(); nextURL != "/" {
t.Errorf("Wrong URL at destination: %s", nextURL)
}
} else {
if got.Method != "HEAD" {
t.Errorf("Wrong method for destination: %q", got.Method)
}
gotURL := got.URL.String()
wantURL := ts.URL + "/"
if gotURL != wantURL {
t.Errorf("Got URL %q, want %q", gotURL, wantURL)
}
}
})
}
}