mirror of
https://github.com/golang/go
synced 2024-11-23 23:40:13 -07:00
crypto/tls: add GetClientCertificate callback
Currently, the selection of a client certificate done internally based on the limitations given by the server's request and the certifcates in the Config. This means that it's not possible for an application to control that selection based on details of the request. This change adds a callback, GetClientCertificate, that is called by a Client during the handshake and which allows applications to select the best certificate at that time. (Based on https://golang.org/cl/25570/ by Bernd Fix.) Fixes #16626. Change-Id: Ia4cea03235d2aa3c9fd49c99c227593c8e86ddd9 Reviewed-on: https://go-review.googlesource.com/32115 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
parent
6c242c52d3
commit
81038d2e2b
@ -286,6 +286,21 @@ type ClientHelloInfo struct {
|
||||
Conn net.Conn
|
||||
}
|
||||
|
||||
// CertificateRequestInfo contains information from a server's
|
||||
// CertificateRequest message, which is used to demand a certificate and proof
|
||||
// of control from a client.
|
||||
type CertificateRequestInfo struct {
|
||||
// AcceptableCAs contains zero or more, DER-encoded, X.501
|
||||
// Distinguished Names. These are the names of root or intermediate CAs
|
||||
// that the server wishes the returned certificate to be signed by. An
|
||||
// empty slice indicates that the server has no preference.
|
||||
AcceptableCAs [][]byte
|
||||
|
||||
// SignatureSchemes lists the signature schemes that the server is
|
||||
// willing to verify.
|
||||
SignatureSchemes []SignatureScheme
|
||||
}
|
||||
|
||||
// RenegotiationSupport enumerates the different levels of support for TLS
|
||||
// renegotiation. TLS renegotiation is the act of performing subsequent
|
||||
// handshakes on a connection after the first. This significantly complicates
|
||||
@ -328,10 +343,11 @@ type Config struct {
|
||||
// If Time is nil, TLS uses time.Now.
|
||||
Time func() time.Time
|
||||
|
||||
// Certificates contains one or more certificate chains
|
||||
// to present to the other side of the connection.
|
||||
// Server configurations must include at least one certificate
|
||||
// or else set GetCertificate.
|
||||
// Certificates contains one or more certificate chains to present to
|
||||
// the other side of the connection. Server configurations must include
|
||||
// at least one certificate or else set GetCertificate. Clients doing
|
||||
// client-authentication may set either Certificates or
|
||||
// GetClientCertificate.
|
||||
Certificates []Certificate
|
||||
|
||||
// NameToCertificate maps from a certificate name to an element of
|
||||
@ -351,6 +367,21 @@ type Config struct {
|
||||
// first element of Certificates will be used.
|
||||
GetCertificate func(*ClientHelloInfo) (*Certificate, error)
|
||||
|
||||
// GetClientCertificate, if not nil, is called when a server requests a
|
||||
// certificate from a client. If set, the contents of Certificates will
|
||||
// be ignored.
|
||||
//
|
||||
// If GetClientCertificate returns an error, the handshake will be
|
||||
// aborted and that error will be returned. Otherwise
|
||||
// GetClientCertificate must return a non-nil Certificate. If
|
||||
// Certificate.Certificate is empty then no certificate will be sent to
|
||||
// the server. If this is unacceptable to the server then it may abort
|
||||
// the handshake.
|
||||
//
|
||||
// GetClientCertificate may be called multiple times for the same
|
||||
// connection if renegotiation occurs or if TLS 1.3 is in use.
|
||||
GetClientCertificate func(*CertificateRequestInfo) (*Certificate, error)
|
||||
|
||||
// GetConfigForClient, if not nil, is called after a ClientHello is
|
||||
// received from a client. It may return a non-nil Config in order to
|
||||
// change the Config that will be used to handle this connection. If
|
||||
|
@ -199,7 +199,7 @@ NextCipherSuite:
|
||||
// Otherwise, in a full handshake, if we don't have any certificates
|
||||
// configured then we will never send a CertificateVerify message and
|
||||
// thus no signatures are needed in that case either.
|
||||
if isResume || len(c.config.Certificates) == 0 {
|
||||
if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) {
|
||||
hs.finishedHash.discardHandshakeBuffer()
|
||||
}
|
||||
|
||||
@ -377,71 +377,11 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
||||
certReq, ok := msg.(*certificateRequestMsg)
|
||||
if ok {
|
||||
certRequested = true
|
||||
|
||||
// RFC 4346 on the certificateAuthorities field:
|
||||
// A list of the distinguished names of acceptable certificate
|
||||
// authorities. These distinguished names may specify a desired
|
||||
// distinguished name for a root CA or for a subordinate CA;
|
||||
// thus, this message can be used to describe both known roots
|
||||
// and a desired authorization space. If the
|
||||
// certificate_authorities list is empty then the client MAY
|
||||
// send any certificate of the appropriate
|
||||
// ClientCertificateType, unless there is some external
|
||||
// arrangement to the contrary.
|
||||
|
||||
hs.finishedHash.Write(certReq.marshal())
|
||||
|
||||
var rsaAvail, ecdsaAvail bool
|
||||
for _, certType := range certReq.certificateTypes {
|
||||
switch certType {
|
||||
case certTypeRSASign:
|
||||
rsaAvail = true
|
||||
case certTypeECDSASign:
|
||||
ecdsaAvail = true
|
||||
}
|
||||
}
|
||||
|
||||
// We need to search our list of client certs for one
|
||||
// where SignatureAlgorithm is acceptable to the server and the
|
||||
// Issuer is in certReq.certificateAuthorities
|
||||
findCert:
|
||||
for i, chain := range c.config.Certificates {
|
||||
if !rsaAvail && !ecdsaAvail {
|
||||
continue
|
||||
}
|
||||
|
||||
for j, cert := range chain.Certificate {
|
||||
x509Cert := chain.Leaf
|
||||
// parse the certificate if this isn't the leaf
|
||||
// node, or if chain.Leaf was nil
|
||||
if j != 0 || x509Cert == nil {
|
||||
if x509Cert, err = x509.ParseCertificate(cert); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA:
|
||||
case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA:
|
||||
default:
|
||||
continue findCert
|
||||
}
|
||||
|
||||
if len(certReq.certificateAuthorities) == 0 {
|
||||
// they gave us an empty list, so just take the
|
||||
// first cert from c.config.Certificates
|
||||
chainToSend = &chain
|
||||
break findCert
|
||||
}
|
||||
|
||||
for _, ca := range certReq.certificateAuthorities {
|
||||
if bytes.Equal(x509Cert.RawIssuer, ca) {
|
||||
chainToSend = &chain
|
||||
break findCert
|
||||
}
|
||||
}
|
||||
}
|
||||
if chainToSend, err = hs.getCertificate(certReq); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return err
|
||||
}
|
||||
|
||||
msg, err = c.readHandshake()
|
||||
@ -462,9 +402,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
||||
// certificate to send.
|
||||
if certRequested {
|
||||
certMsg = new(certificateMsg)
|
||||
if chainToSend != nil {
|
||||
certMsg.certificates = chainToSend.Certificate
|
||||
}
|
||||
certMsg.certificates = chainToSend.Certificate
|
||||
hs.finishedHash.Write(certMsg.marshal())
|
||||
if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
|
||||
return err
|
||||
@ -483,7 +421,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
|
||||
}
|
||||
}
|
||||
|
||||
if chainToSend != nil {
|
||||
if chainToSend != nil && len(chainToSend.Certificate) > 0 {
|
||||
certVerify := &certificateVerifyMsg{
|
||||
hasSignatureAndHash: c.vers >= VersionTLS12,
|
||||
}
|
||||
@ -727,6 +665,117 @@ func (hs *clientHandshakeState) sendFinished(out []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// tls11SignatureSchemes contains the signature schemes that we synthesise for
|
||||
// a TLS <= 1.1 connection, based on the supported certificate types.
|
||||
var tls11SignatureSchemes = []SignatureScheme{ECDSAWithP256AndSHA256, ECDSAWithP384AndSHA384, ECDSAWithP521AndSHA512, PKCS1WithSHA256, PKCS1WithSHA384, PKCS1WithSHA512, PKCS1WithSHA1}
|
||||
|
||||
const (
|
||||
// tls11SignatureSchemesNumECDSA is the number of initial elements of
|
||||
// tls11SignatureSchemes that use ECDSA.
|
||||
tls11SignatureSchemesNumECDSA = 3
|
||||
// tls11SignatureSchemesNumRSA is the number of trailing elements of
|
||||
// tls11SignatureSchemes that use RSA.
|
||||
tls11SignatureSchemesNumRSA = 4
|
||||
)
|
||||
|
||||
func (hs *clientHandshakeState) getCertificate(certReq *certificateRequestMsg) (*Certificate, error) {
|
||||
c := hs.c
|
||||
|
||||
var rsaAvail, ecdsaAvail bool
|
||||
for _, certType := range certReq.certificateTypes {
|
||||
switch certType {
|
||||
case certTypeRSASign:
|
||||
rsaAvail = true
|
||||
case certTypeECDSASign:
|
||||
ecdsaAvail = true
|
||||
}
|
||||
}
|
||||
|
||||
if c.config.GetClientCertificate != nil {
|
||||
var signatureSchemes []SignatureScheme
|
||||
|
||||
if !certReq.hasSignatureAndHash {
|
||||
// Prior to TLS 1.2, the signature schemes were not
|
||||
// included in the certificate request message. In this
|
||||
// case we use a plausible list based on the acceptable
|
||||
// certificate types.
|
||||
signatureSchemes = tls11SignatureSchemes
|
||||
if !ecdsaAvail {
|
||||
signatureSchemes = signatureSchemes[tls11SignatureSchemesNumECDSA:]
|
||||
}
|
||||
if !rsaAvail {
|
||||
signatureSchemes = signatureSchemes[:len(signatureSchemes)-tls11SignatureSchemesNumRSA]
|
||||
}
|
||||
} else {
|
||||
signatureSchemes = make([]SignatureScheme, 0, len(certReq.signatureAndHashes))
|
||||
for _, sah := range certReq.signatureAndHashes {
|
||||
signatureSchemes = append(signatureSchemes, SignatureScheme(sah.hash)<<8+SignatureScheme(sah.signature))
|
||||
}
|
||||
}
|
||||
|
||||
return c.config.GetClientCertificate(&CertificateRequestInfo{
|
||||
AcceptableCAs: certReq.certificateAuthorities,
|
||||
SignatureSchemes: signatureSchemes,
|
||||
})
|
||||
}
|
||||
|
||||
// RFC 4346 on the certificateAuthorities field: A list of the
|
||||
// distinguished names of acceptable certificate authorities.
|
||||
// These distinguished names may specify a desired
|
||||
// distinguished name for a root CA or for a subordinate CA;
|
||||
// thus, this message can be used to describe both known roots
|
||||
// and a desired authorization space. If the
|
||||
// certificate_authorities list is empty then the client MAY
|
||||
// send any certificate of the appropriate
|
||||
// ClientCertificateType, unless there is some external
|
||||
// arrangement to the contrary.
|
||||
|
||||
// We need to search our list of client certs for one
|
||||
// where SignatureAlgorithm is acceptable to the server and the
|
||||
// Issuer is in certReq.certificateAuthorities
|
||||
findCert:
|
||||
for i, chain := range c.config.Certificates {
|
||||
if !rsaAvail && !ecdsaAvail {
|
||||
continue
|
||||
}
|
||||
|
||||
for j, cert := range chain.Certificate {
|
||||
x509Cert := chain.Leaf
|
||||
// parse the certificate if this isn't the leaf
|
||||
// node, or if chain.Leaf was nil
|
||||
if j != 0 || x509Cert == nil {
|
||||
var err error
|
||||
if x509Cert, err = x509.ParseCertificate(cert); err != nil {
|
||||
c.sendAlert(alertInternalError)
|
||||
return nil, errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA:
|
||||
case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA:
|
||||
default:
|
||||
continue findCert
|
||||
}
|
||||
|
||||
if len(certReq.certificateAuthorities) == 0 {
|
||||
// they gave us an empty list, so just take the
|
||||
// first cert from c.config.Certificates
|
||||
return &chain, nil
|
||||
}
|
||||
|
||||
for _, ca := range certReq.certificateAuthorities {
|
||||
if bytes.Equal(x509Cert.RawIssuer, ca) {
|
||||
return &chain, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// No acceptable certificate found. Don't send a certificate.
|
||||
return new(Certificate), nil
|
||||
}
|
||||
|
||||
// clientSessionCacheKey returns a key used to cache sessionTickets that could
|
||||
// be used to resume previously negotiated TLS sessions with a server.
|
||||
func clientSessionCacheKey(serverAddr net.Addr, config *Config) string {
|
||||
|
@ -1408,3 +1408,137 @@ func TestHandshakeRace(t *testing.T) {
|
||||
<-readDone
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLS11SignatureSchemes(t *testing.T) {
|
||||
expected := tls11SignatureSchemesNumECDSA + tls11SignatureSchemesNumRSA
|
||||
if expected != len(tls11SignatureSchemes) {
|
||||
t.Errorf("expected to find %d TLS 1.1 signature schemes, but found %d", expected, len(tls11SignatureSchemes))
|
||||
}
|
||||
}
|
||||
|
||||
var getClientCertificateTests = []struct {
|
||||
setup func(*Config)
|
||||
expectedClientError string
|
||||
verify func(*testing.T, int, *ConnectionState)
|
||||
}{
|
||||
{
|
||||
func(clientConfig *Config) {
|
||||
// Returning a Certificate with no certificate data
|
||||
// should result in an empty message being sent to the
|
||||
// server.
|
||||
clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
|
||||
if len(cri.SignatureSchemes) == 0 {
|
||||
panic("empty SignatureSchemes")
|
||||
}
|
||||
return new(Certificate), nil
|
||||
}
|
||||
},
|
||||
"",
|
||||
func(t *testing.T, testNum int, cs *ConnectionState) {
|
||||
if l := len(cs.PeerCertificates); l != 0 {
|
||||
t.Errorf("#%d: expected no certificates but got %d", testNum, l)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
func(clientConfig *Config) {
|
||||
// With TLS 1.1, the SignatureSchemes should be
|
||||
// synthesised from the supported certificate types.
|
||||
clientConfig.MaxVersion = VersionTLS11
|
||||
clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
|
||||
if len(cri.SignatureSchemes) == 0 {
|
||||
panic("empty SignatureSchemes")
|
||||
}
|
||||
return new(Certificate), nil
|
||||
}
|
||||
},
|
||||
"",
|
||||
func(t *testing.T, testNum int, cs *ConnectionState) {
|
||||
if l := len(cs.PeerCertificates); l != 0 {
|
||||
t.Errorf("#%d: expected no certificates but got %d", testNum, l)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
func(clientConfig *Config) {
|
||||
// Returning an error should abort the handshake with
|
||||
// that error.
|
||||
clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
|
||||
return nil, errors.New("GetClientCertificate")
|
||||
}
|
||||
},
|
||||
"GetClientCertificate",
|
||||
func(t *testing.T, testNum int, cs *ConnectionState) {
|
||||
},
|
||||
},
|
||||
{
|
||||
func(clientConfig *Config) {
|
||||
clientConfig.GetClientCertificate = func(cri *CertificateRequestInfo) (*Certificate, error) {
|
||||
return &testConfig.Certificates[0], nil
|
||||
}
|
||||
},
|
||||
"",
|
||||
func(t *testing.T, testNum int, cs *ConnectionState) {
|
||||
if l := len(cs.VerifiedChains); l != 0 {
|
||||
t.Errorf("#%d: expected some verified chains, but found none", testNum)
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
func TestGetClientCertificate(t *testing.T) {
|
||||
issuer, err := x509.ParseCertificate(testRSACertificateIssuer)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
for i, test := range getClientCertificateTests {
|
||||
serverConfig := testConfig.Clone()
|
||||
serverConfig.ClientAuth = RequestClientCert
|
||||
serverConfig.RootCAs = x509.NewCertPool()
|
||||
serverConfig.RootCAs.AddCert(issuer)
|
||||
|
||||
clientConfig := testConfig.Clone()
|
||||
|
||||
test.setup(clientConfig)
|
||||
|
||||
type serverResult struct {
|
||||
cs ConnectionState
|
||||
err error
|
||||
}
|
||||
|
||||
c, s := net.Pipe()
|
||||
done := make(chan serverResult)
|
||||
|
||||
go func() {
|
||||
defer s.Close()
|
||||
server := Server(s, serverConfig)
|
||||
err := server.Handshake()
|
||||
|
||||
var cs ConnectionState
|
||||
if err == nil {
|
||||
cs = server.ConnectionState()
|
||||
}
|
||||
done <- serverResult{cs, err}
|
||||
}()
|
||||
|
||||
clientErr := Client(c, clientConfig).Handshake()
|
||||
c.Close()
|
||||
|
||||
result := <-done
|
||||
|
||||
if clientErr != nil {
|
||||
if len(test.expectedClientError) == 0 {
|
||||
t.Errorf("#%d: client error: %v", i, clientErr)
|
||||
} else if got := clientErr.Error(); got != test.expectedClientError {
|
||||
t.Errorf("#%d: expected client error %q, but got %q", i, test.expectedClientError, got)
|
||||
}
|
||||
} else if len(test.expectedClientError) > 0 {
|
||||
t.Errorf("#%d: expected client error %q, but got no error", i, test.expectedClientError)
|
||||
} else if err := result.err; err != nil {
|
||||
t.Errorf("#%d: server error: %v", i, err)
|
||||
} else {
|
||||
test.verify(t, i, &result.cs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -584,7 +584,7 @@ func TestClone(t *testing.T) {
|
||||
case "Rand":
|
||||
f.Set(reflect.ValueOf(io.Reader(os.Stdin)))
|
||||
continue
|
||||
case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate":
|
||||
case "Time", "GetCertificate", "GetConfigForClient", "VerifyPeerCertificate", "GetClientCertificate":
|
||||
// DeepEqual can't compare functions.
|
||||
continue
|
||||
case "Certificates":
|
||||
|
Loading…
Reference in New Issue
Block a user