diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go index 90846a3659..fd21ae8fb1 100644 --- a/src/crypto/tls/common.go +++ b/src/crypto/tls/common.go @@ -219,7 +219,7 @@ type ConnectionState struct { CipherSuite uint16 // cipher suite in use (TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, ...) NegotiatedProtocol string // negotiated next protocol (not guaranteed to be from Config.NextProtos) NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server (client side only) - ServerName string // server name requested by client, if any (server side only) + ServerName string // server name requested by client, if any PeerCertificates []*x509.Certificate // certificate chain presented by remote peer VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates SignedCertificateTimestamps [][]byte // SCTs from the peer, if any @@ -520,6 +520,16 @@ type Config struct { // be considered but the verifiedChains argument will always be nil. VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error + // VerifyConnection, if not nil, is called after normal certificate + // verification and after VerifyPeerCertificate by either a TLS client + // or server. If it returns a non-nil error, the handshake is aborted + // and that error results. + // + // If normal verification fails then the handshake will abort before + // considering this callback. This callback will run for all connections + // regardless of InsecureSkipVerify or ClientAuth settings. + VerifyConnection func(ConnectionState) error + // RootCAs defines the set of root certificate authorities // that clients use when verifying server certificates. // If RootCAs is nil, TLS uses the host's root CA set. @@ -685,6 +695,7 @@ func (c *Config) Clone() *Config { GetClientCertificate: c.GetClientCertificate, GetConfigForClient: c.GetConfigForClient, VerifyPeerCertificate: c.VerifyPeerCertificate, + VerifyConnection: c.VerifyConnection, RootCAs: c.RootCAs, NextProtos: c.NextProtos, ServerName: c.ServerName, diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index d759986bb9..edcfecf81d 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -1379,35 +1379,34 @@ func (c *Conn) Handshake() error { func (c *Conn) ConnectionState() ConnectionState { c.handshakeMutex.Lock() defer c.handshakeMutex.Unlock() + return c.connectionStateLocked() +} +func (c *Conn) connectionStateLocked() ConnectionState { var state ConnectionState state.HandshakeComplete = c.handshakeComplete() + state.Version = c.vers + state.NegotiatedProtocol = c.clientProtocol + state.DidResume = c.didResume + state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback state.ServerName = c.serverName - - if state.HandshakeComplete { - state.Version = c.vers - state.NegotiatedProtocol = c.clientProtocol - state.DidResume = c.didResume - state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback - state.CipherSuite = c.cipherSuite - state.PeerCertificates = c.peerCertificates - state.VerifiedChains = c.verifiedChains - state.SignedCertificateTimestamps = c.scts - state.OCSPResponse = c.ocspResponse - if !c.didResume && c.vers != VersionTLS13 { - if c.clientFinishedIsFirst { - state.TLSUnique = c.clientFinished[:] - } else { - state.TLSUnique = c.serverFinished[:] - } - } - if c.config.Renegotiation != RenegotiateNever { - state.ekm = noExportedKeyingMaterial + state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains + state.SignedCertificateTimestamps = c.scts + state.OCSPResponse = c.ocspResponse + if !c.didResume && c.vers != VersionTLS13 { + if c.clientFinishedIsFirst { + state.TLSUnique = c.clientFinished[:] } else { - state.ekm = c.ekm + state.TLSUnique = c.serverFinished[:] } } - + if c.config.Renegotiation != RenegotiateNever { + state.ekm = noExportedKeyingMaterial + } else { + state.ekm = c.ekm + } return state } diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index 210eece26d..40c8e02c53 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -146,6 +146,7 @@ func (c *Conn) clientHandshake() (err error) { if err != nil { return err } + c.serverName = hello.serverName cacheKey, session, earlySecret, binderKey := c.loadSession(hello) if cacheKey != "" && session != nil { @@ -388,6 +389,7 @@ func (hs *clientHandshakeState) handshake() error { hs.finishedHash.Write(hs.serverHello.marshal()) c.buffering = true + c.didResume = isResume if isResume { if err := hs.establishKeys(); err != nil { return err @@ -399,6 +401,15 @@ func (hs *clientHandshakeState) handshake() error { return err } c.clientFinishedIsFirst = false + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } if err := hs.sendFinished(c.clientFinished[:]); err != nil { return err } @@ -428,7 +439,6 @@ func (hs *clientHandshakeState) handshake() error { } c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random) - c.didResume = isResume atomic.StoreUint32(&c.handshakeStatus, 1) return nil @@ -458,25 +468,6 @@ func (hs *clientHandshakeState) doFullHandshake() error { } hs.finishedHash.Write(certMsg.marshal()) - if c.handshakes == 0 { - // If this is the first handshake on a connection, process and - // (optionally) verify the server's certificates. - if err := c.verifyServerCertificate(certMsg.certificates); err != nil { - return err - } - } else { - // This is a renegotiation handshake. We require that the - // server's identity (i.e. leaf certificate) is unchanged and - // thus any previous trust decision is still valid. - // - // See https://mitls.org/pages/attacks/3SHAKE for the - // motivation behind this requirement. - if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { - c.sendAlert(alertBadCertificate) - return errors.New("tls: server's identity changed during renegotiation") - } - } - msg, err = c.readHandshake() if err != nil { return err @@ -505,6 +496,25 @@ func (hs *clientHandshakeState) doFullHandshake() error { } } + if c.handshakes == 0 { + // If this is the first handshake on a connection, process and + // (optionally) verify the server's certificates. + if err := c.verifyServerCertificate(certMsg.certificates); err != nil { + return err + } + } else { + // This is a renegotiation handshake. We require that the + // server's identity (i.e. leaf certificate) is unchanged and + // thus any previous trust decision is still valid. + // + // See https://mitls.org/pages/attacks/3SHAKE for the + // motivation behind this requirement. + if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) { + c.sendAlert(alertBadCertificate) + return errors.New("tls: server's identity changed during renegotiation") + } + } + keyAgreement := hs.suite.ka(c.vers) skx, ok := msg.(*serverKeyExchangeMsg) @@ -831,13 +841,6 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { } } - if c.config.VerifyPeerCertificate != nil { - if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { - c.sendAlert(alertBadCertificate) - return err - } - } - switch certs[0].PublicKey.(type) { case *rsa.PublicKey, *ecdsa.PublicKey, ed25519.PublicKey: break @@ -848,6 +851,20 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error { c.peerCertificates = certs + if c.config.VerifyPeerCertificate != nil { + if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + return nil } diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go index cd387dcc6c..313872ca76 100644 --- a/src/crypto/tls/handshake_client_test.go +++ b/src/crypto/tls/handshake_client_test.go @@ -907,6 +907,9 @@ func testResumption(t *testing.T, version uint16) { if didResume && (hs.PeerCertificates == nil || hs.VerifiedChains == nil) { t.Fatalf("expected non-nil certificates after resumption. Got peerCertificates: %#v, verifiedCertificates: %#v", hs.PeerCertificates, hs.VerifiedChains) } + if got, want := hs.ServerName, clientConfig.ServerName; got != want { + t.Errorf("%s: server name %s, want %s", test, got, want) + } } getTicket := func() []byte { @@ -1458,7 +1461,7 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) { sentinelErr := errors.New("TestVerifyPeerCertificate") - verifyCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + verifyPeerCertificateCallback := func(called *bool, rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { if l := len(rawCerts); l != 1 { return fmt.Errorf("got len(rawCerts) = %d, wanted 1", l) } @@ -1468,6 +1471,19 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) { *called = true return nil } + verifyConnectionCallback := func(called *bool, isClient bool, c ConnectionState) error { + if l := len(c.PeerCertificates); l != 1 { + return fmt.Errorf("got len(PeerCertificates) = %d, wanted 1", l) + } + if len(c.VerifiedChains) == 0 { + return fmt.Errorf("got len(VerifiedChains) = 0, wanted non-zero") + } + if isClient && len(c.OCSPResponse) == 0 { + return fmt.Errorf("got len(OCSPResponse) = 0, wanted non-zero") + } + *called = true + return nil + } tests := []struct { configureServer func(*Config, *bool) @@ -1478,13 +1494,13 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) { configureServer: func(config *Config, called *bool) { config.InsecureSkipVerify = false config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { - return verifyCallback(called, rawCerts, validatedChains) + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) } }, configureClient: func(config *Config, called *bool) { config.InsecureSkipVerify = false config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { - return verifyCallback(called, rawCerts, validatedChains) + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) } }, validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { @@ -1565,6 +1581,116 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) { } }, }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return verifyConnectionCallback(called, false, c) + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return verifyConnectionCallback(called, true, c) + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != nil { + t.Errorf("test[%d]: client handshake failed: %v", testNo, clientErr) + } + if serverErr != nil { + t.Errorf("test[%d]: server handshake failed: %v", testNo, serverErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + if !serverCalled { + t.Errorf("test[%d]: server did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = nil + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if serverErr != sentinelErr { + t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = nil + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != sentinelErr { + t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = nil + config.VerifyConnection = nil + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if serverErr != sentinelErr { + t.Errorf("#%d: got server error %v, wanted sentinelErr", testNo, serverErr) + } + if !serverCalled { + t.Errorf("test[%d]: server did not call callback", testNo) + } + }, + }, + { + configureServer: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = nil + config.VerifyConnection = nil + }, + configureClient: func(config *Config, called *bool) { + config.InsecureSkipVerify = false + config.VerifyPeerCertificate = func(rawCerts [][]byte, validatedChains [][]*x509.Certificate) error { + return verifyPeerCertificateCallback(called, rawCerts, validatedChains) + } + config.VerifyConnection = func(c ConnectionState) error { + return sentinelErr + } + }, + validate: func(t *testing.T, testNo int, clientCalled, serverCalled bool, clientErr, serverErr error) { + if clientErr != sentinelErr { + t.Errorf("#%d: got client error %v, wanted sentinelErr", testNo, clientErr) + } + if !clientCalled { + t.Errorf("test[%d]: client did not call callback", testNo) + } + }, + }, } for i, test := range tests { @@ -1580,6 +1706,11 @@ func testVerifyPeerCertificate(t *testing.T, version uint16) { config.ClientCAs = rootCAs config.Time = now config.MaxVersion = version + config.Certificates = make([]Certificate, 1) + config.Certificates[0].Certificate = [][]byte{testRSACertificate} + config.Certificates[0].PrivateKey = testRSAPrivateKey + config.Certificates[0].SignedCertificateTimestamps = [][]byte{[]byte("dummy sct 1"), []byte("dummy sct 2")} + config.Certificates[0].OCSPStaple = []byte("dummy ocsp") test.configureServer(config, &serverCalled) err = Server(s, config).Handshake() diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index 97122bd220..35a00f2f3a 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -407,6 +407,15 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error { // Either a PSK or a certificate is always used, but not both. // See RFC 8446, Section 4.1.1. if hs.usingPSK { + // Make sure the connection is still being verified whether or not this + // is a resumption. Resumptions currently don't reverify certificates so + // they don't call verifyServerCertificate. See Issue 31641. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } return nil } diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go index 4885c69568..6aacfa1ff6 100644 --- a/src/crypto/tls/handshake_server.go +++ b/src/crypto/tls/handshake_server.go @@ -68,6 +68,7 @@ func (hs *serverHandshakeState) handshake() error { c.buffering = true if hs.checkForResumption() { // The client has included a session ticket and so we do an abbreviated handshake. + c.didResume = true if err := hs.doResumeHandshake(); err != nil { return err } @@ -92,7 +93,6 @@ func (hs *serverHandshakeState) handshake() error { if err := hs.readFinished(nil); err != nil { return err } - c.didResume = true } else { // The client didn't include a session ticket, or it wasn't // valid so we do a full handshake. @@ -553,6 +553,15 @@ func (hs *serverHandshakeState) doFullHandshake() error { if err != nil { return err } + } else { + // Make sure the connection is still being verified whether or not + // the server requested a client certificate. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } } // Get client key exchange @@ -771,6 +780,19 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { c.verifiedChains = chains } + c.peerCertificates = certs + c.ocspResponse = certificate.OCSPStaple + c.scts = certificate.SignedCertificateTimestamps + + if len(certs) > 0 { + switch certs[0].PublicKey.(type) { + case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) + } + } + if c.config.VerifyPeerCertificate != nil { if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil { c.sendAlert(alertBadCertificate) @@ -778,20 +800,13 @@ func (c *Conn) processCertsFromClient(certificate Certificate) error { } } - if len(certs) == 0 { - return nil + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } } - switch certs[0].PublicKey.(type) { - case *ecdsa.PublicKey, *rsa.PublicKey, ed25519.PublicKey: - default: - c.sendAlert(alertUnsupportedCertificate) - return fmt.Errorf("tls: client certificate contains an unsupported public key of type %T", certs[0].PublicKey) - } - - c.peerCertificates = certs - c.ocspResponse = certificate.OCSPStaple - c.scts = certificate.SignedCertificateTimestamps return nil } diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go index 5432145de4..fb7f871390 100644 --- a/src/crypto/tls/handshake_server_tls13.go +++ b/src/crypto/tls/handshake_server_tls13.go @@ -306,6 +306,7 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { return errors.New("tls: invalid PSK binder") } + c.didResume = true if err := c.processCertsFromClient(sessionState.certificate); err != nil { return err } @@ -313,7 +314,6 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error { hs.hello.selectedIdentityPresent = true hs.hello.selectedIdentity = uint16(i) hs.usingPSK = true - c.didResume = true return nil } @@ -753,6 +753,14 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error { c := hs.c if !hs.requestClientCert() { + // Make sure the connection is still being verified whether or not + // the server requested a client certificate. + if c.config.VerifyConnection != nil { + if err := c.config.VerifyConnection(c.connectionStateLocked()); err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } return nil } diff --git a/src/crypto/tls/tls_test.go b/src/crypto/tls/tls_test.go index 85005d4950..9e340774b6 100644 --- a/src/crypto/tls/tls_test.go +++ b/src/crypto/tls/tls_test.go @@ -734,7 +734,7 @@ func TestWarningAlertFlood(t *testing.T) { } func TestCloneFuncFields(t *testing.T) { - const expectedCount = 5 + const expectedCount = 6 called := 0 c1 := Config{ @@ -758,6 +758,10 @@ func TestCloneFuncFields(t *testing.T) { called |= 1 << 4 return nil }, + VerifyConnection: func(ConnectionState) error { + called |= 1 << 5 + return nil + }, } c2 := c1.Clone() @@ -767,6 +771,7 @@ func TestCloneFuncFields(t *testing.T) { c2.GetClientCertificate(nil) c2.GetConfigForClient(nil) c2.VerifyPeerCertificate(nil, nil) + c2.VerifyConnection(ConnectionState{}) if called != (1<