From fb86c70bdad8aaa9756d6740885714c5aeff5989 Mon Sep 17 00:00:00 2001 From: Katie Hockman Date: Thu, 4 Jun 2020 10:52:24 -0400 Subject: [PATCH] crypto/tls: set CipherSuite for VerifyConnection The ConnectionState's CipherSuite was not set prior to the VerifyConnection callback in TLS 1.2 servers, both for full handshakes and resumptions. Change-Id: Iab91783eff84d1b42ca09c8df08e07861e18da30 Reviewed-on: https://go-review.googlesource.com/c/go/+/236558 Run-TryBot: Katie Hockman TryBot-Result: Gobot Gobot Reviewed-by: Filippo Valsorda --- src/crypto/tls/handshake_client_test.go | 31 ++++++++++++++----------- src/crypto/tls/handshake_server.go | 3 ++- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go index 88c974f83d..1cda90190c 100644 --- a/src/crypto/tls/handshake_client_test.go +++ b/src/crypto/tls/handshake_client_test.go @@ -1470,25 +1470,28 @@ func TestVerifyConnection(t *testing.T) { } func testVerifyConnection(t *testing.T, version uint16) { - checkFields := func(c ConnectionState, called *int) error { + checkFields := func(c ConnectionState, called *int, errorType string) error { if c.Version != version { - return fmt.Errorf("got Version %v, want %v", c.Version, version) + return fmt.Errorf("%s: got Version %v, want %v", errorType, c.Version, version) } if c.HandshakeComplete { - return fmt.Errorf("got HandshakeComplete, want false") + return fmt.Errorf("%s: got HandshakeComplete, want false", errorType) } if c.ServerName != "example.golang" { - return fmt.Errorf("got ServerName %s, want %s", c.ServerName, "example.golang") + return fmt.Errorf("%s: got ServerName %s, want %s", errorType, c.ServerName, "example.golang") } if c.NegotiatedProtocol != "protocol1" { - return fmt.Errorf("got NegotiatedProtocol %s, want %s", c.NegotiatedProtocol, "protocol1") + return fmt.Errorf("%s: got NegotiatedProtocol %s, want %s", errorType, c.NegotiatedProtocol, "protocol1") + } + if c.CipherSuite == 0 { + return fmt.Errorf("%s: got CipherSuite 0, want non-zero", errorType) } wantDidResume := false if *called == 2 { // if this is the second time, then it should be a resumption wantDidResume = true } if c.DidResume != wantDidResume { - return fmt.Errorf("got DidResume %t, want %t", c.DidResume, wantDidResume) + return fmt.Errorf("%s: got DidResume %t, want %t", errorType, c.DidResume, wantDidResume) } return nil } @@ -1510,7 +1513,7 @@ func testVerifyConnection(t *testing.T, version uint16) { if len(c.VerifiedChains) == 0 { return fmt.Errorf("server: got len(VerifiedChains) = 0, wanted non-zero") } - return checkFields(c, called) + return checkFields(c, called, "server") } }, configureClient: func(config *Config, called *int) { @@ -1533,7 +1536,7 @@ func testVerifyConnection(t *testing.T, version uint16) { if len(c.SignedCertificateTimestamps) == 0 { return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") } - return checkFields(c, called) + return checkFields(c, called, "client") } }, }, @@ -1550,7 +1553,7 @@ func testVerifyConnection(t *testing.T, version uint16) { if c.VerifiedChains != nil { return fmt.Errorf("server: got Verified Chains %v, want nil", c.VerifiedChains) } - return checkFields(c, called) + return checkFields(c, called, "server") } }, configureClient: func(config *Config, called *int) { @@ -1574,7 +1577,7 @@ func testVerifyConnection(t *testing.T, version uint16) { if len(c.SignedCertificateTimestamps) == 0 { return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") } - return checkFields(c, called) + return checkFields(c, called, "client") } }, }, @@ -1584,13 +1587,13 @@ func testVerifyConnection(t *testing.T, version uint16) { config.ClientAuth = NoClientCert config.VerifyConnection = func(c ConnectionState) error { *called++ - return checkFields(c, called) + return checkFields(c, called, "server") } }, configureClient: func(config *Config, called *int) { config.VerifyConnection = func(c ConnectionState) error { *called++ - return checkFields(c, called) + return checkFields(c, called, "client") } }, }, @@ -1600,7 +1603,7 @@ func testVerifyConnection(t *testing.T, version uint16) { config.ClientAuth = RequestClientCert config.VerifyConnection = func(c ConnectionState) error { *called++ - return checkFields(c, called) + return checkFields(c, called, "server") } }, configureClient: func(config *Config, called *int) { @@ -1624,7 +1627,7 @@ func testVerifyConnection(t *testing.T, version uint16) { if len(c.SignedCertificateTimestamps) == 0 { return fmt.Errorf("client: got len(SignedCertificateTimestamps) = 0, wanted non-zero") } - return checkFields(c, called) + return checkFields(c, called, "client") } }, }, diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go index 2c2f0a4879..16d3e643f0 100644 --- a/src/crypto/tls/handshake_server.go +++ b/src/crypto/tls/handshake_server.go @@ -308,6 +308,7 @@ func (hs *serverHandshakeState) pickCipherSuite() error { c.sendAlert(alertHandshakeFailure) return errors.New("tls: no cipher suite supported by both client and server") } + c.cipherSuite = hs.suite.id for _, id := range hs.clientHello.cipherSuites { if id == TLS_FALLBACK_SCSV { @@ -407,6 +408,7 @@ func (hs *serverHandshakeState) doResumeHandshake() error { c := hs.c hs.hello.cipherSuite = hs.suite.id + c.cipherSuite = hs.suite.id // We echo the client's session ID in the ServerHello to let it know // that we're doing a resumption. hs.hello.sessionId = hs.clientHello.sessionId @@ -743,7 +745,6 @@ func (hs *serverHandshakeState) sendFinished(out []byte) error { return err } - c.cipherSuite = hs.suite.id copy(out, finished.verifyData) return nil