1
0
mirror of https://github.com/golang/go synced 2024-11-21 21:34:40 -07:00

net/http: Set TLSClientConfig.ServerName on every HTTP request.

This makes SNI "just work" for callers using the standard http.Client.

Since we now have a test that depends on the httptest.Server cert, change
the cert to be a CA (keeping all other fields the same).

R=bradfitz
CC=agl, dsymonds, gobot, golang-dev
https://golang.org/cl/6448154
This commit is contained in:
Dave Borowitz 2012-08-22 09:15:41 -07:00 committed by Brad Fitzpatrick
parent c27a7dbf78
commit 2eb6a16e16
3 changed files with 66 additions and 8 deletions

View File

@ -8,6 +8,7 @@ package http_test
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
@ -470,3 +471,49 @@ func TestClientErrorWithRequestURI(t *testing.T) {
t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
}
}
func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
certs := x509.NewCertPool()
for _, c := range ts.TLS.Certificates {
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
if err != nil {
t.Fatalf("error parsing server's root cert: %v", err)
}
for _, root := range roots {
certs.AddCert(root)
}
}
return &Transport{
TLSClientConfig: &tls.Config{RootCAs: certs},
}
}
func TestClientWithCorrectTLSServerName(t *testing.T) {
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
if r.TLS.ServerName != "127.0.0.1" {
t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName)
}
}))
defer ts.Close()
c := &Client{Transport: newTLSTransport(t, ts)}
if _, err := c.Get(ts.URL); err != nil {
t.Fatalf("expected successful TLS connection, got error: %v", err)
}
}
func TestClientWithIncorrectTLSServerName(t *testing.T) {
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer ts.Close()
trans := newTLSTransport(t, ts)
trans.TLSClientConfig.ServerName = "badserver"
c := &Client{Transport: trans}
_, err := c.Get(ts.URL)
if err == nil {
t.Fatalf("expected an error")
}
if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
}
}

View File

@ -184,15 +184,15 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
// of ASN.1 time).
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
MIIBTTCB+qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7
qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL
8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw
DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG
CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7
+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM=
-----END CERTIFICATE-----
`)
8i1UQF6AzwIDAQABo2MwYTAOBgNVHQ8BAf8EBAMCACQwEgYDVR0TAQH/BAgwBgEB
/wIBATANBgNVHQ4EBgQEAQIDBDAPBgNVHSMECDAGgAQBAgMEMBsGA1UdEQQUMBKC
CTEyNy4wLjAuMYIFWzo6MV0wCwYJKoZIhvcNAQEFA0EAj1Jsn/h2KHy7dgqutZNB
nCGlNN+8vw263Bax9MklR85Ti6a0VWSvp/fDQZUADvmFTDkcXeA24pqmdUxeQDWw
Pg==
-----END CERTIFICATE-----`)
// localhostKey is the private key for localhostCert.
var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----

View File

@ -379,7 +379,18 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
if cm.targetScheme == "https" {
// Initiate TLS and check remote host name against certificate.
conn = tls.Client(conn, t.TLSClientConfig)
cfg := t.TLSClientConfig
if cfg == nil || cfg.ServerName == "" {
host, _, _ := net.SplitHostPort(cm.addr())
if cfg == nil {
cfg = &tls.Config{ServerName: host}
} else {
clone := *cfg // shallow clone
clone.ServerName = host
cfg = &clone
}
}
conn = tls.Client(conn, cfg)
if err = conn.(*tls.Conn).Handshake(); err != nil {
return nil, err
}