mirror of
https://github.com/golang/go
synced 2024-11-21 21:34:40 -07:00
net/http: Set TLSClientConfig.ServerName on every HTTP request.
This makes SNI "just work" for callers using the standard http.Client. Since we now have a test that depends on the httptest.Server cert, change the cert to be a CA (keeping all other fields the same). R=bradfitz CC=agl, dsymonds, gobot, golang-dev https://golang.org/cl/6448154
This commit is contained in:
parent
c27a7dbf78
commit
2eb6a16e16
@ -8,6 +8,7 @@ package http_test
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -470,3 +471,49 @@ func TestClientErrorWithRequestURI(t *testing.T) {
|
||||
t.Errorf("wanted error mentioning RequestURI; got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
|
||||
certs := x509.NewCertPool()
|
||||
for _, c := range ts.TLS.Certificates {
|
||||
roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing server's root cert: %v", err)
|
||||
}
|
||||
for _, root := range roots {
|
||||
certs.AddCert(root)
|
||||
}
|
||||
}
|
||||
return &Transport{
|
||||
TLSClientConfig: &tls.Config{RootCAs: certs},
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientWithCorrectTLSServerName(t *testing.T) {
|
||||
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
|
||||
if r.TLS.ServerName != "127.0.0.1" {
|
||||
t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
c := &Client{Transport: newTLSTransport(t, ts)}
|
||||
if _, err := c.Get(ts.URL); err != nil {
|
||||
t.Fatalf("expected successful TLS connection, got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientWithIncorrectTLSServerName(t *testing.T) {
|
||||
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
|
||||
defer ts.Close()
|
||||
|
||||
trans := newTLSTransport(t, ts)
|
||||
trans.TLSClientConfig.ServerName = "badserver"
|
||||
c := &Client{Transport: trans}
|
||||
_, err := c.Get(ts.URL)
|
||||
if err == nil {
|
||||
t.Fatalf("expected an error")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "127.0.0.1") || !strings.Contains(err.Error(), "badserver") {
|
||||
t.Errorf("wanted error mentioning 127.0.0.1 and badserver; got error: %v", err)
|
||||
}
|
||||
}
|
||||
|
@ -184,15 +184,15 @@ func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
|
||||
// of ASN.1 time).
|
||||
var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIBOTCB5qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
|
||||
MIIBTTCB+qADAgECAgEAMAsGCSqGSIb3DQEBBTAAMB4XDTcwMDEwMTAwMDAwMFoX
|
||||
DTQ5MTIzMTIzNTk1OVowADBaMAsGCSqGSIb3DQEBAQNLADBIAkEAsuA5mAFMj6Q7
|
||||
qoBzcvKzIq4kzuT5epSp2AkcQfyBHm7K13Ws7u+0b5Vb9gqTf5cAiIKcrtrXVqkL
|
||||
8i1UQF6AzwIDAQABo08wTTAOBgNVHQ8BAf8EBAMCACQwDQYDVR0OBAYEBAECAwQw
|
||||
DwYDVR0jBAgwBoAEAQIDBDAbBgNVHREEFDASggkxMjcuMC4wLjGCBVs6OjFdMAsG
|
||||
CSqGSIb3DQEBBQNBAJH30zjLWRztrWpOCgJL8RQWLaKzhK79pVhAx6q/3NrF16C7
|
||||
+l1BRZstTwIGdoGId8BRpErK1TXkniFb95ZMynM=
|
||||
-----END CERTIFICATE-----
|
||||
`)
|
||||
8i1UQF6AzwIDAQABo2MwYTAOBgNVHQ8BAf8EBAMCACQwEgYDVR0TAQH/BAgwBgEB
|
||||
/wIBATANBgNVHQ4EBgQEAQIDBDAPBgNVHSMECDAGgAQBAgMEMBsGA1UdEQQUMBKC
|
||||
CTEyNy4wLjAuMYIFWzo6MV0wCwYJKoZIhvcNAQEFA0EAj1Jsn/h2KHy7dgqutZNB
|
||||
nCGlNN+8vw263Bax9MklR85Ti6a0VWSvp/fDQZUADvmFTDkcXeA24pqmdUxeQDWw
|
||||
Pg==
|
||||
-----END CERTIFICATE-----`)
|
||||
|
||||
// localhostKey is the private key for localhostCert.
|
||||
var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
|
@ -379,7 +379,18 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
|
||||
|
||||
if cm.targetScheme == "https" {
|
||||
// Initiate TLS and check remote host name against certificate.
|
||||
conn = tls.Client(conn, t.TLSClientConfig)
|
||||
cfg := t.TLSClientConfig
|
||||
if cfg == nil || cfg.ServerName == "" {
|
||||
host, _, _ := net.SplitHostPort(cm.addr())
|
||||
if cfg == nil {
|
||||
cfg = &tls.Config{ServerName: host}
|
||||
} else {
|
||||
clone := *cfg // shallow clone
|
||||
clone.ServerName = host
|
||||
cfg = &clone
|
||||
}
|
||||
}
|
||||
conn = tls.Client(conn, cfg)
|
||||
if err = conn.(*tls.Conn).Handshake(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user