From 2eb6a16e16411a394527447b4f6ec0ba838b18e8 Mon Sep 17 00:00:00 2001 From: Dave Borowitz Date: Wed, 22 Aug 2012 09:15:41 -0700 Subject: [PATCH] net/http: Set TLSClientConfig.ServerName on every HTTP request. This makes SNI "just work" for callers using the standard http.Client. Since we now have a test that depends on the httptest.Server cert, change the cert to be a CA (keeping all other fields the same). R=bradfitz CC=agl, dsymonds, gobot, golang-dev https://golang.org/cl/6448154 --- src/pkg/net/http/client_test.go | 47 +++++++++++++++++++++++++++++ src/pkg/net/http/httptest/server.go | 14 ++++----- src/pkg/net/http/transport.go | 13 +++++++- 3 files changed, 66 insertions(+), 8 deletions(-) diff --git a/src/pkg/net/http/client_test.go b/src/pkg/net/http/client_test.go index da7a44da7a5..c61b17d289b 100644 --- a/src/pkg/net/http/client_test.go +++ b/src/pkg/net/http/client_test.go @@ -8,6 +8,7 @@ package http_test import ( "crypto/tls" + "crypto/x509" "errors" "fmt" "io" @@ -470,3 +471,49 @@ func TestClientErrorWithRequestURI(t *testing.T) { t.Errorf("wanted error mentioning RequestURI; got error: %v", err) } } + +func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport { + certs := x509.NewCertPool() + for _, c := range ts.TLS.Certificates { + roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) + if err != nil { + t.Fatalf("error parsing server's root cert: %v", err) + } + for _, root := range roots { + certs.AddCert(root) + } + } + return &Transport{ + TLSClientConfig: &tls.Config{RootCAs: certs}, + } +} + +func TestClientWithCorrectTLSServerName(t *testing.T) { + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.TLS.ServerName != "127.0.0.1" { + t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName) + } + })) + defer ts.Close() + + c := &Client{Transport: newTLSTransport(t, ts)} + if _, err := c.Get(ts.URL); err != nil { + t.Fatalf("expected successful TLS connection, got error: %v", err) + } +} + +func TestClientWithIncorrectTLSServerName(t *testing.T) { + ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer ts.Close() + + trans := newTLSTransport(t, ts) + trans.TLSClientConfig.ServerName = "badserver" + c := &Client{Transport: trans} + _, err := c.Get(ts.URL) + if err == nil { + t.Fatalf("expected an error") + } + if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") { + t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err) + } +} diff --git a/src/pkg/net/http/httptest/server.go b/src/pkg/net/http/httptest/server.go index 57cf0c9417d..165600e52be 100644 --- a/src/pkg/net/http/httptest/server.go +++ b/src/pkg/net/http/httptest/server.go @@ -184,15 +184,15 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end // of ASN.1 time). var localhostCert = []byte(`-----BEGIN CERTIFICATE----- -MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX +MIIBTTCB+qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7 qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL -8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw -DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG -CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7 -+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM= ------END CERTIFICATE----- -`) +8i1UQF6AzwIDAQABo2MwYTAOBgNVHQ8BAf8EBAMCACQwEgYDVR0TAQH/BAgwBgEB +/wIBATANBgNVHQ4EBgQEAQIDBDAPBgNVHSMECDAGgAQBAgMEMBsGA1UdEQQUMBKC +CTEyNy4wLjAuMYIFWzo6MV0wCwYJKoZIhvcNAQEFA0EAj1Jsn/h2KHy7dgqutZNB +nCGlNN+8vw263Bax9MklR85Ti6a0VWSvp/fDQZUADvmFTDkcXeA24pqmdUxeQDWw +Pg== +-----END CERTIFICATE-----`) // localhostKey is the private key for localhostCert. var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY----- diff --git a/src/pkg/net/http/transport.go b/src/pkg/net/http/transport.go index fe6318824e0..a33d787f25d 100644 --- a/src/pkg/net/http/transport.go +++ b/src/pkg/net/http/transport.go @@ -379,7 +379,18 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) { if cm.targetScheme == "https" { // Initiate TLS and check remote host name against certificate. - conn = tls.Client(conn, t.TLSClientConfig) + cfg := t.TLSClientConfig + if cfg == nil || cfg.ServerName == "" { + host, _, _ := net.SplitHostPort(cm.addr()) + if cfg == nil { + cfg = &tls.Config{ServerName: host} + } else { + clone := *cfg // shallow clone + clone.ServerName = host + cfg = &clone + } + } + conn = tls.Client(conn, cfg) if err = conn.(*tls.Conn).Handshake(); err != nil { return nil, err }