diff --git a/src/crypto/tls/boring_test.go b/src/crypto/tls/boring_test.go index 6868f1a370..6f70f02f49 100644 --- a/src/crypto/tls/boring_test.go +++ b/src/crypto/tls/boring_test.go @@ -43,9 +43,9 @@ func TestBoringServerProtocolVersion(t *testing.T) { fipstls.Force() defer fipstls.Abandon() - test("VersionSSL30", VersionSSL30, "unsupported, maximum protocol version") - test("VersionTLS10", VersionTLS10, "unsupported, maximum protocol version") - test("VersionTLS11", VersionTLS11, "unsupported, maximum protocol version") + test("VersionSSL30", VersionSSL30, "client offered only unsupported versions") + test("VersionTLS10", VersionTLS10, "client offered only unsupported versions") + test("VersionTLS11", VersionTLS11, "client offered only unsupported versions") test("VersionTLS12", VersionTLS12, "") } diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go index a3cfe05bc0..a2b960ef54 100644 --- a/src/crypto/tls/common.go +++ b/src/crypto/tls/common.go @@ -40,9 +40,6 @@ const ( recordHeaderLen = 5 // record header length maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) maxUselessRecords = 5 // maximum number of consecutive non-advancing records - - minVersion = VersionTLS10 - maxVersion = VersionTLS12 ) // TLS record types. @@ -57,19 +54,23 @@ const ( // TLS handshake message types. const ( - typeHelloRequest uint8 = 0 - typeClientHello uint8 = 1 - typeServerHello uint8 = 2 - typeNewSessionTicket uint8 = 4 - typeCertificate uint8 = 11 - typeServerKeyExchange uint8 = 12 - typeCertificateRequest uint8 = 13 - typeServerHelloDone uint8 = 14 - typeCertificateVerify uint8 = 15 - typeClientKeyExchange uint8 = 16 - typeFinished uint8 = 20 - typeCertificateStatus uint8 = 22 - typeNextProtocol uint8 = 67 // Not IANA assigned + typeHelloRequest uint8 = 0 + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeNewSessionTicket uint8 = 4 + typeEndOfEarlyData uint8 = 5 + typeEncryptedExtensions uint8 = 8 + typeCertificate uint8 = 11 + typeServerKeyExchange uint8 = 12 + typeCertificateRequest uint8 = 13 + typeServerHelloDone uint8 = 14 + typeCertificateVerify uint8 = 15 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 + typeCertificateStatus uint8 = 22 + typeKeyUpdate uint8 = 24 + typeNextProtocol uint8 = 67 // Not IANA assigned + typeMessageHash uint8 = 254 // synthetic message ) // TLS compression types. @@ -88,6 +89,7 @@ const ( extensionSCT uint16 = 18 extensionSessionTicket uint16 = 35 extensionPreSharedKey uint16 = 41 + extensionEarlyData uint16 = 42 extensionSupportedVersions uint16 = 43 extensionCookie uint16 = 44 extensionPSKModes uint16 = 45 @@ -713,24 +715,46 @@ func (c *Config) cipherSuites() []uint16 { return s } -func (c *Config) minVersion() uint16 { - if needFIPS() { - return fipsMinVersion(c) - } - if c == nil || c.MinVersion == 0 { - return minVersion - } - return c.MinVersion +var supportedVersions = []uint16{ + VersionTLS12, + VersionTLS11, + VersionTLS10, + VersionSSL30, } -func (c *Config) maxVersion() uint16 { - if needFIPS() { - return fipsMaxVersion(c) +func (c *Config) supportedVersions(isClient bool) []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if needFIPS() && (v < fipsMinVersion(c) || v > fipsMaxVersion(c)) { + continue + } + if c != nil && c.MinVersion != 0 && v < c.MinVersion { + continue + } + if c != nil && c.MaxVersion != 0 && v > c.MaxVersion { + continue + } + // TLS 1.0 is the minimum version supported as a client. + if isClient && v < VersionTLS10 { + continue + } + versions = append(versions, v) } - if c == nil || c.MaxVersion == 0 { - return maxVersion + return versions +} + +// supportedVersionsFromMax returns a list of supported versions derived from a +// legacy maximum version value. Note that only versions supported by this +// library are returned. Any newer peer will use supportedVersions anyway. +func supportedVersionsFromMax(maxVersion uint16) []uint16 { + versions := make([]uint16, 0, len(supportedVersions)) + for _, v := range supportedVersions { + if v > maxVersion { + continue + } + versions = append(versions, v) } - return c.MaxVersion + return versions } var defaultCurvePreferences = []CurveID{X25519, CurveP256, CurveP384, CurveP521} @@ -746,18 +770,17 @@ func (c *Config) curvePreferences() []CurveID { } // mutualVersion returns the protocol version to use given the advertised -// version of the peer. -func (c *Config) mutualVersion(vers uint16) (uint16, bool) { - minVersion := c.minVersion() - maxVersion := c.maxVersion() - - if vers < minVersion { - return 0, false +// versions of the peer. Priority is given to the peer preference order. +func (c *Config) mutualVersion(isClient bool, peerVersions []uint16) (uint16, bool) { + supportedVersions := c.supportedVersions(isClient) + for _, peerVersion := range peerVersions { + for _, v := range supportedVersions { + if v == peerVersion { + return v, true + } + } } - if vers > maxVersion { - vers = maxVersion - } - return vers, true + return 0, false } // getCertificate returns the best certificate for the given ClientHelloInfo, diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index 5af1413935..3619964095 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -990,12 +990,24 @@ func (c *Conn) readHandshake() (interface{}, error) { case typeServerHello: m = new(serverHelloMsg) case typeNewSessionTicket: - m = new(newSessionTicketMsg) + if c.vers == VersionTLS13 { + m = new(newSessionTicketMsgTLS13) + } else { + m = new(newSessionTicketMsg) + } case typeCertificate: - m = new(certificateMsg) + if c.vers == VersionTLS13 { + m = new(certificateMsgTLS13) + } else { + m = new(certificateMsg) + } case typeCertificateRequest: - m = &certificateRequestMsg{ - hasSignatureAlgorithm: c.vers >= VersionTLS12, + if c.vers == VersionTLS13 { + m = new(certificateRequestMsgTLS13) + } else { + m = &certificateRequestMsg{ + hasSignatureAlgorithm: c.vers >= VersionTLS12, + } } case typeCertificateStatus: m = new(certificateStatusMsg) @@ -1013,6 +1025,12 @@ func (c *Conn) readHandshake() (interface{}, error) { m = new(nextProtoMsg) case typeFinished: m = new(finishedMsg) + case typeEncryptedExtensions: + m = new(encryptedExtensionsMsg) + case typeEndOfEarlyData: + m = new(endOfEarlyDataMsg) + case typeKeyUpdate: + m = new(keyUpdateMsg) default: return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) } diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index a1f0731730..995fd0c5b6 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -43,13 +43,25 @@ func makeClientHello(config *Config) (*clientHelloMsg, error) { nextProtosLength += 1 + l } } - if nextProtosLength > 0xffff { return nil, errors.New("tls: NextProtos values too large") } + supportedVersions := config.supportedVersions(true) + if len(supportedVersions) == 0 { + return nil, errors.New("tls: no supported versions satisfy MinVersion and MaxVersion") + } + + clientHelloVersion := supportedVersions[0] + // The version at the beginning of the ClientHello was capped at TLS 1.2 + // for compatibility reasons. The supported_versions extension is used + // to negotiate versions now. See RFC 8446, Section 4.2.1. + if clientHelloVersion > VersionTLS12 { + clientHelloVersion = VersionTLS12 + } + hello := &clientHelloMsg{ - vers: config.maxVersion(), + vers: clientHelloVersion, compressionMethods: []uint8{compressionNone}, random: make([]byte, 32), ocspStapling: true, @@ -60,6 +72,7 @@ func makeClientHello(config *Config) (*clientHelloMsg, error) { nextProtoNeg: len(config.NextProtos) > 0, secureRenegotiationSupported: true, alpnProtocols: config.NextProtos, + supportedVersions: supportedVersions, } possibleCipherSuites := config.cipherSuites() hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) @@ -143,8 +156,14 @@ func (c *Conn) clientHandshake() error { } } - versOk := candidateSession.vers >= c.config.minVersion() && - candidateSession.vers <= c.config.maxVersion() + versOk := false + for _, v := range c.config.supportedVersions(true) { + if v == candidateSession.vers { + versOk = true + break + } + } + if versOk && cipherSuiteOk { session = candidateSession } @@ -276,11 +295,15 @@ func (hs *clientHandshakeState) handshake() error { } func (hs *clientHandshakeState) pickTLSVersion() error { - vers, ok := hs.c.config.mutualVersion(hs.serverHello.vers) - if !ok || vers < VersionTLS10 { - // TLS 1.0 is the minimum version supported as a client. + peerVersion := hs.serverHello.vers + if hs.serverHello.supportedVersion != 0 { + peerVersion = hs.serverHello.supportedVersion + } + + vers, ok := hs.c.config.mutualVersion(true, []uint16{peerVersion}) + if !ok { hs.c.sendAlert(alertProtocolVersion) - return fmt.Errorf("tls: server selected unsupported protocol version %x", hs.serverHello.vers) + return fmt.Errorf("tls: server selected unsupported protocol version %x", peerVersion) } hs.c.vers = vers @@ -398,9 +421,7 @@ func (hs *clientHandshakeState) doFullHandshake() error { } hs.finishedHash.Write(cs.marshal()) - if cs.statusType == statusTypeOCSP { - c.ocspResponse = cs.response - } + c.ocspResponse = cs.response msg, err = c.readHandshake() if err != nil { diff --git a/src/crypto/tls/handshake_client_test.go b/src/crypto/tls/handshake_client_test.go index 437aaed462..18c15340ea 100644 --- a/src/crypto/tls/handshake_client_test.go +++ b/src/crypto/tls/handshake_client_test.go @@ -279,6 +279,12 @@ func (test *clientTest) loadData() (flows [][]byte, err error) { func (test *clientTest) run(t *testing.T, write bool) { checkOpenSSLVersion(t) + // TODO(filippo): regenerate client tests all at once after CL 146217, + // RSA-PSS and client-side TLS 1.3 are landed. + if !write { + t.Skip("recorded client tests are out of date") + } + var clientConn, serverConn net.Conn var recordingConn *recordingConn var childProcess *exec.Cmd diff --git a/src/crypto/tls/handshake_messages.go b/src/crypto/tls/handshake_messages.go index d04efc98f6..82b91cc87e 100644 --- a/src/crypto/tls/handshake_messages.go +++ b/src/crypto/tls/handshake_messages.go @@ -71,6 +71,7 @@ type clientHelloMsg struct { supportedVersions []uint16 cookie []byte keyShares []keyShare + earlyData bool pskModes []uint8 pskIdentities []pskIdentity pskBinders [][]byte @@ -239,6 +240,11 @@ func (m *clientHelloMsg) marshal() []byte { }) }) } + if m.earlyData { + // RFC 8446, Section 4.2.10 + b.AddUint16(extensionEarlyData) + b.AddUint16(0) // empty extension_data + } if len(m.pskModes) > 0 { // RFC 8446, Section 4.2.9 b.AddUint16(extensionPSKModes) @@ -478,6 +484,9 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { } m.keyShares = append(m.keyShares, ks) } + case extensionEarlyData: + // RFC 8446, Section 4.2.10 + m.earlyData = true case extensionPSKModes: // RFC 8446, Section 4.2.9 if !readUint8LengthPrefixed(&extData, &m.pskModes) { @@ -782,6 +791,342 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool { return true } +type encryptedExtensionsMsg struct { + raw []byte + alpnProtocol string +} + +func (m *encryptedExtensionsMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeEncryptedExtensions) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if len(m.alpnProtocol) > 0 { + b.AddUint16(extensionALPN) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte(m.alpnProtocol)) + }) + }) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { + *m = encryptedExtensionsMsg{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionALPN: + var protoList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() { + return false + } + var proto cryptobyte.String + if !protoList.ReadUint8LengthPrefixed(&proto) || + proto.Empty() || !protoList.Empty() { + return false + } + m.alpnProtocol = string(proto) + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type endOfEarlyDataMsg struct{} + +func (m *endOfEarlyDataMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeEndOfEarlyData + return x +} + +func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type keyUpdateMsg struct { + raw []byte + updateRequested bool +} + +func (m *keyUpdateMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeKeyUpdate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + if m.updateRequested { + b.AddUint8(1) + } else { + b.AddUint8(0) + } + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *keyUpdateMsg) unmarshal(data []byte) bool { + m.raw = data + s := cryptobyte.String(data) + + var updateRequested uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&updateRequested) || !s.Empty() { + return false + } + switch updateRequested { + case 0: + m.updateRequested = false + case 1: + m.updateRequested = true + default: + return false + } + return true +} + +type newSessionTicketMsgTLS13 struct { + raw []byte + lifetime uint32 + ageAdd uint32 + nonce []byte + label []byte + maxEarlyData uint32 +} + +func (m *newSessionTicketMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeNewSessionTicket) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.lifetime) + b.AddUint32(m.ageAdd) + b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.nonce) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.label) + }) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.maxEarlyData > 0 { + b.AddUint16(extensionEarlyData) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint32(m.maxEarlyData) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool { + *m = newSessionTicketMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint32(&m.lifetime) || + !s.ReadUint32(&m.ageAdd) || + !readUint8LengthPrefixed(&s, &m.nonce) || + !readUint16LengthPrefixed(&s, &m.label) || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionEarlyData: + if !extData.ReadUint32(&m.maxEarlyData) { + return false + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + +type certificateRequestMsgTLS13 struct { + raw []byte + ocspStapling bool + scts bool + supportedSignatureAlgorithms []SignatureScheme + supportedSignatureAlgorithmsCert []SignatureScheme +} + +func (m *certificateRequestMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificateRequest) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + // certificate_request_context (SHALL be zero length unless used for + // post-handshake authentication) + b.AddUint8(0) + + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16(0) // empty extension_data + } + if m.scts { + // RFC 8446, Section 4.4.2.1 makes no mention of + // signed_certificate_timestamp in CertificateRequest, but + // "Extensions in the Certificate message from the client MUST + // correspond to extensions in the CertificateRequest message + // from the server." and it appears in the table in Section 4.2. + b.AddUint16(extensionSCT) + b.AddUint16(0) // empty extension_data + } + if len(m.supportedSignatureAlgorithms) > 0 { + b.AddUint16(extensionSignatureAlgorithms) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithms { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + if len(m.supportedSignatureAlgorithmsCert) > 0 { + b.AddUint16(extensionSignatureAlgorithmsCert) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sigAlgo := range m.supportedSignatureAlgorithmsCert { + b.AddUint16(uint16(sigAlgo)) + } + }) + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool { + *m = certificateRequestMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context, extensions cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !s.ReadUint16LengthPrefixed(&extensions) || + !s.Empty() { + return false + } + + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + case extensionSCT: + m.scts = true + case extensionSignatureAlgorithms: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithms = append( + m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg)) + } + case extensionSignatureAlgorithmsCert: + var sigAndAlgs cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() { + return false + } + for !sigAndAlgs.Empty() { + var sigAndAlg uint16 + if !sigAndAlgs.ReadUint16(&sigAndAlg) { + return false + } + m.supportedSignatureAlgorithmsCert = append( + m.supportedSignatureAlgorithmsCert, SignatureScheme(sigAndAlg)) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + + return true +} + type certificateMsg struct { raw []byte certificates [][]byte @@ -859,6 +1204,131 @@ func (m *certificateMsg) unmarshal(data []byte) bool { return true } +type certificateMsgTLS13 struct { + raw []byte + certificate Certificate + ocspStapling bool + scts bool +} + +func (m *certificateMsgTLS13) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var b cryptobyte.Builder + b.AddUint8(typeCertificate) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(0) // certificate_request_context + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + for i, cert := range m.certificate.Certificate { + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(cert) + }) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + if i > 0 { + // This library only supports OCSP and SCT for leaf certificates. + return + } + if m.ocspStapling { + b.AddUint16(extensionStatusRequest) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.certificate.OCSPStaple) + }) + }) + } + if m.scts { + b.AddUint16(extensionSCT) + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + for _, sct := range m.certificate.SignedCertificateTimestamps { + b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(sct) + }) + } + }) + }) + } + }) + } + }) + }) + + m.raw = b.BytesOrPanic() + return m.raw +} + +func (m *certificateMsgTLS13) unmarshal(data []byte) bool { + *m = certificateMsgTLS13{raw: data} + s := cryptobyte.String(data) + + var context, certList cryptobyte.String + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8LengthPrefixed(&context) || !context.Empty() || + !s.ReadUint24LengthPrefixed(&certList) || + !s.Empty() { + return false + } + + for !certList.Empty() { + var cert []byte + var extensions cryptobyte.String + if !readUint24LengthPrefixed(&certList, &cert) || + !certList.ReadUint16LengthPrefixed(&extensions) { + return false + } + m.certificate.Certificate = append(m.certificate.Certificate, cert) + for !extensions.Empty() { + var extension uint16 + var extData cryptobyte.String + if !extensions.ReadUint16(&extension) || + !extensions.ReadUint16LengthPrefixed(&extData) { + return false + } + if len(m.certificate.Certificate) > 1 { + // This library only supports OCSP and SCT for leaf certificates. + continue + } + + switch extension { + case extensionStatusRequest: + m.ocspStapling = true + var statusType uint8 + if !extData.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&extData, &m.certificate.OCSPStaple) || + len(m.certificate.OCSPStaple) == 0 { + return false + } + case extensionSCT: + m.scts = true + var sctList cryptobyte.String + if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() { + return false + } + for !sctList.Empty() { + var sct []byte + if !readUint16LengthPrefixed(&sctList, &sct) || + len(sct) == 0 { + return false + } + m.certificate.SignedCertificateTimestamps = append( + m.certificate.SignedCertificateTimestamps, sct) + } + default: + // Ignore unknown extensions. + continue + } + + if !extData.Empty() { + return false + } + } + } + return true +} + type serverKeyExchangeMsg struct { raw []byte key []byte @@ -890,9 +1360,8 @@ func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { } type certificateStatusMsg struct { - raw []byte - statusType uint8 - response []byte + raw []byte + response []byte } func (m *certificateStatusMsg) marshal() []byte { @@ -900,46 +1369,29 @@ func (m *certificateStatusMsg) marshal() []byte { return m.raw } - var x []byte - if m.statusType == statusTypeOCSP { - x = make([]byte, 4+4+len(m.response)) - x[0] = typeCertificateStatus - l := len(m.response) + 4 - x[1] = byte(l >> 16) - x[2] = byte(l >> 8) - x[3] = byte(l) - x[4] = statusTypeOCSP + var b cryptobyte.Builder + b.AddUint8(typeCertificateStatus) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddUint8(statusTypeOCSP) + b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(m.response) + }) + }) - l -= 4 - x[5] = byte(l >> 16) - x[6] = byte(l >> 8) - x[7] = byte(l) - copy(x[8:], m.response) - } else { - x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType} - } - - m.raw = x - return x + m.raw = b.BytesOrPanic() + return m.raw } func (m *certificateStatusMsg) unmarshal(data []byte) bool { m.raw = data - if len(data) < 5 { - return false - } - m.statusType = data[4] + s := cryptobyte.String(data) - m.response = nil - if m.statusType == statusTypeOCSP { - if len(data) < 8 { - return false - } - respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) - if uint32(len(data)) != 4+4+respLen { - return false - } - m.response = data[8:] + var statusType uint8 + if !s.Skip(4) || // message type and uint24 length field + !s.ReadUint8(&statusType) || statusType != statusTypeOCSP || + !readUint24LengthPrefixed(&s, &m.response) || + len(m.response) == 0 || !s.Empty() { + return false } return true } diff --git a/src/crypto/tls/handshake_messages_test.go b/src/crypto/tls/handshake_messages_test.go index fdf096b473..ce2b04344c 100644 --- a/src/crypto/tls/handshake_messages_test.go +++ b/src/crypto/tls/handshake_messages_test.go @@ -29,6 +29,12 @@ var tests = []interface{}{ &nextProtoMsg{}, &newSessionTicketMsg{}, &sessionState{}, + &encryptedExtensionsMsg{}, + &endOfEarlyDataMsg{}, + &keyUpdateMsg{}, + &newSessionTicketMsgTLS13{}, + &certificateRequestMsgTLS13{}, + &certificateMsgTLS13{}, } func TestMarshalUnmarshal(t *testing.T) { @@ -184,6 +190,9 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { m.pskIdentities = append(m.pskIdentities, psk) m.pskBinders = append(m.pskBinders, randomBytes(rand.Intn(50)+32, rand)) } + if rand.Intn(10) > 5 { + m.earlyData = true + } return reflect.ValueOf(m) } @@ -209,7 +218,9 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { if rand.Intn(10) > 5 { m.ticketSupported = true } - m.alpnProtocol = randomString(rand.Intn(32)+1, rand) + if rand.Intn(10) > 5 { + m.alpnProtocol = randomString(rand.Intn(32)+1, rand) + } for i := 0; i < rand.Intn(4); i++ { m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand)) @@ -241,6 +252,16 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.ValueOf(m) } +func (*encryptedExtensionsMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &encryptedExtensionsMsg{} + + if rand.Intn(10) > 5 { + m.alpnProtocol = randomString(rand.Intn(32)+1, rand) + } + + return reflect.ValueOf(m) +} + func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &certificateMsg{} numCerts := rand.Intn(20) @@ -270,12 +291,7 @@ func (*certificateVerifyMsg) Generate(rand *rand.Rand, size int) reflect.Value { func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value { m := &certificateStatusMsg{} - if rand.Intn(10) > 5 { - m.statusType = statusTypeOCSP - m.response = randomBytes(rand.Intn(10)+1, rand) - } else { - m.statusType = 42 - } + m.response = randomBytes(rand.Intn(10)+1, rand) return reflect.ValueOf(m) } @@ -316,6 +332,66 @@ func (*sessionState) Generate(rand *rand.Rand, size int) reflect.Value { return reflect.ValueOf(s) } +func (*endOfEarlyDataMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &endOfEarlyDataMsg{} + return reflect.ValueOf(m) +} + +func (*keyUpdateMsg) Generate(rand *rand.Rand, size int) reflect.Value { + m := &keyUpdateMsg{} + m.updateRequested = rand.Intn(10) > 5 + return reflect.ValueOf(m) +} + +func (*newSessionTicketMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { + m := &newSessionTicketMsgTLS13{} + m.lifetime = uint32(rand.Intn(500000)) + m.ageAdd = uint32(rand.Intn(500000)) + m.nonce = randomBytes(rand.Intn(100), rand) + m.label = randomBytes(rand.Intn(1000), rand) + if rand.Intn(10) > 5 { + m.maxEarlyData = uint32(rand.Intn(500000)) + } + return reflect.ValueOf(m) +} + +func (*certificateRequestMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateRequestMsgTLS13{} + if rand.Intn(10) > 5 { + m.ocspStapling = true + } + if rand.Intn(10) > 5 { + m.scts = true + } + if rand.Intn(10) > 5 { + m.supportedSignatureAlgorithms = supportedSignatureAlgorithms() + } + if rand.Intn(10) > 5 { + m.supportedSignatureAlgorithmsCert = supportedSignatureAlgorithms() + } + return reflect.ValueOf(m) +} + +func (*certificateMsgTLS13) Generate(rand *rand.Rand, size int) reflect.Value { + m := &certificateMsgTLS13{} + for i := 0; i < rand.Intn(2)+1; i++ { + m.certificate.Certificate = append( + m.certificate.Certificate, randomBytes(rand.Intn(500)+1, rand)) + } + if rand.Intn(10) > 5 { + m.ocspStapling = true + m.certificate.OCSPStaple = randomBytes(rand.Intn(100)+1, rand) + } + if rand.Intn(10) > 5 { + m.scts = true + for i := 0; i < rand.Intn(2)+1; i++ { + m.certificate.SignedCertificateTimestamps = append( + m.certificate.SignedCertificateTimestamps, randomBytes(rand.Intn(500)+1, rand)) + } + } + return reflect.ValueOf(m) +} + func TestRejectEmptySCTList(t *testing.T) { // RFC 6962, Section 3.3.1 specifies that empty SCT lists are invalid. diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go index fc458f6b01..00ce49f444 100644 --- a/src/crypto/tls/handshake_server.go +++ b/src/crypto/tls/handshake_server.go @@ -135,14 +135,19 @@ func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { } } - c.vers, ok = c.config.mutualVersion(hs.clientHello.vers) + clientVersions := hs.clientHello.supportedVersions + if len(hs.clientHello.supportedVersions) == 0 { + clientVersions = supportedVersionsFromMax(hs.clientHello.vers) + } + c.vers, ok = c.config.mutualVersion(false, clientVersions) if !ok { c.sendAlert(alertProtocolVersion) - return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers) + return false, fmt.Errorf("tls: client offered only unsupported versions: %x", clientVersions) } c.haveVers = true hs.hello = new(serverHelloMsg) + hs.hello.vers = c.vers supportedCurve := false preferredCurves := c.config.curvePreferences() @@ -179,7 +184,6 @@ Curves: return false, errors.New("tls: client does not support uncompressed connections") } - hs.hello.vers = c.vers hs.hello.random = make([]byte, 32) _, err = io.ReadFull(c.config.rand(), hs.hello.random) if err != nil { @@ -272,7 +276,7 @@ Curves: for _, id := range hs.clientHello.cipherSuites { if id == TLS_FALLBACK_SCSV { // The client is doing a fallback connection. - if hs.clientHello.vers < c.config.maxVersion() { + if hs.clientHello.vers < c.config.supportedVersions(false)[0] { c.sendAlert(alertInappropriateFallback) return false, errors.New("tls: client using inappropriate protocol fallback") } @@ -389,7 +393,6 @@ func (hs *serverHandshakeState) doFullHandshake() error { if hs.hello.ocspStapling { certStatus := new(certificateStatusMsg) - certStatus.statusType = statusTypeOCSP certStatus.response = hs.cert.OCSPStaple hs.finishedHash.Write(certStatus.marshal()) if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil { @@ -765,19 +768,14 @@ func (hs *serverHandshakeState) setCipherSuite(id uint16, supportedCipherSuites return false } -// suppVersArray is the backing array of ClientHelloInfo.SupportedVersions -var suppVersArray = [...]uint16{VersionTLS12, VersionTLS11, VersionTLS10, VersionSSL30} - func (hs *serverHandshakeState) clientHelloInfo() *ClientHelloInfo { if hs.cachedClientHelloInfo != nil { return hs.cachedClientHelloInfo } - var supportedVersions []uint16 - if hs.clientHello.vers > VersionTLS12 { - supportedVersions = suppVersArray[:] - } else if hs.clientHello.vers >= VersionSSL30 { - supportedVersions = suppVersArray[VersionTLS12-hs.clientHello.vers:] + supportedVersions := hs.clientHello.supportedVersions + if len(hs.clientHello.supportedVersions) == 0 { + supportedVersions = supportedVersionsFromMax(hs.clientHello.vers) } hs.cachedClientHelloInfo = &ClientHelloInfo{ diff --git a/src/crypto/tls/handshake_server_test.go b/src/crypto/tls/handshake_server_test.go index 01de92d971..5aaa815279 100644 --- a/src/crypto/tls/handshake_server_test.go +++ b/src/crypto/tls/handshake_server_test.go @@ -104,8 +104,13 @@ func TestRejectBadProtocolVersion(t *testing.T) { testClientHelloFailure(t, testConfig, &clientHelloMsg{ vers: v, random: make([]byte, 32), - }, "unsupported, maximum protocol version") + }, "unsupported versions") } + testClientHelloFailure(t, testConfig, &clientHelloMsg{ + vers: VersionTLS12, + supportedVersions: badProtocolVersions, + random: make([]byte, 32), + }, "unsupported versions") } func TestNoSuiteOverlap(t *testing.T) { @@ -1289,11 +1294,11 @@ var getConfigForClientTests = []struct { func(clientHello *ClientHelloInfo) (*Config, error) { config := testConfig.Clone() // Setting a maximum version of TLS 1.1 should cause - // the handshake to fail. + // the handshake to fail, as the client MinVersion is TLS 1.2. config.MaxVersion = VersionTLS11 return config, nil }, - "version 301 when expecting version 302", + "client offered only unsupported versions", nil, }, { diff --git a/src/crypto/tls/key_schedule.go b/src/crypto/tls/key_schedule.go new file mode 100644 index 0000000000..21b50f177d --- /dev/null +++ b/src/crypto/tls/key_schedule.go @@ -0,0 +1,85 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "golang_org/x/crypto/cryptobyte" + "golang_org/x/crypto/hkdf" + "hash" +) + +// This file contains the functions necessary to compute the TLS 1.3 key +// schedule. See RFC 8446, Section 7. + +const ( + resumptionBinderLabel = "res binder" + clientHandshakeTrafficLabel = "c hs traffic" + serverHandshakeTrafficLabel = "s hs traffic" + clientApplicationTrafficLabel = "c ap traffic" + serverApplicationTrafficLabel = "s ap traffic" + exporterLabel = "exp master" + resumptionLabel = "res master" + trafficUpdateLabel = "traffic upd" +) + +// expandLabel implements HKDF-Expand-Label from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []byte, length int) []byte { + var hkdfLabel cryptobyte.Builder + hkdfLabel.AddUint16(uint16(length)) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes([]byte("tls13 ")) + b.AddBytes([]byte(label)) + }) + hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) { + b.AddBytes(context) + }) + out := make([]byte, length) + n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out) + if err != nil || n != length { + panic("tls: HKDF-Expand-Label invocation failed unexpectedly") + } + return out +} + +// deriveSecret implements Derive-Secret from RFC 8446, Section 7.1. +func (c *cipherSuiteTLS13) deriveSecret(secret []byte, label string, transcript hash.Hash) []byte { + if transcript == nil { + transcript = c.hash.New() + } + return c.expandLabel(secret, label, transcript.Sum(nil), c.hash.Size()) +} + +// extract implements HKDF-Extract with the cipher suite hash. +func (c *cipherSuiteTLS13) extract(newSecret, currentSecret []byte) []byte { + if newSecret == nil { + newSecret = make([]byte, c.hash.Size()) + } + return hkdf.Extract(c.hash.New, newSecret, currentSecret) +} + +// nextTrafficSecret generates the next traffic secret, given the current one, +// according to RFC 8446, Section 7.2. +func (c *cipherSuiteTLS13) nextTrafficSecret(trafficSecret []byte) []byte { + return c.expandLabel(trafficSecret, trafficUpdateLabel, nil, c.hash.Size()) +} + +// trafficKey generates traffic keys according to RFC 8446, Section 7.3. +func (c *cipherSuiteTLS13) trafficKey(trafficSecret []byte) (key, iv []byte) { + key = c.expandLabel(trafficSecret, "key", nil, c.keyLen) + iv = c.expandLabel(trafficSecret, "iv", nil, aeadNonceLength) + return +} + +// exportKeyingMaterial implements RFC5705 exporters for TLS 1.3 according to +// RFC 8446, Section 7.5. +func (c *cipherSuiteTLS13) exportKeyingMaterial(masterSecret []byte, transcript hash.Hash) func(string, []byte, int) ([]byte, error) { + expMasterSecret := c.deriveSecret(masterSecret, exporterLabel, transcript) + return func(label string, context []byte, length int) ([]byte, error) { + secret := c.deriveSecret(expMasterSecret, label, nil) + h := c.hash.New() + h.Write(context) + return c.expandLabel(secret, "exporter", h.Sum(nil), length), nil + } +} diff --git a/src/crypto/tls/key_schedule_test.go b/src/crypto/tls/key_schedule_test.go new file mode 100644 index 0000000000..79ff6a62b1 --- /dev/null +++ b/src/crypto/tls/key_schedule_test.go @@ -0,0 +1,175 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tls + +import ( + "bytes" + "encoding/hex" + "hash" + "strings" + "testing" + "unicode" +) + +// This file contains tests derived from draft-ietf-tls-tls13-vectors-07. + +func parseVector(v string) []byte { + v = strings.Map(func(c rune) rune { + if unicode.IsSpace(c) { + return -1 + } + return c + }, v) + parts := strings.Split(v, ":") + v = parts[len(parts)-1] + res, err := hex.DecodeString(v) + if err != nil { + panic(err) + } + return res +} + +func TestDeriveSecret(t *testing.T) { + chTranscript := cipherSuitesTLS13[0].hash.New() + chTranscript.Write(parseVector(` + payload (512 octets): 01 00 01 fc 03 03 1b c3 ce b6 bb e3 9c ff + 93 83 55 b5 a5 0a db 6d b2 1b 7a 6a f6 49 d7 b4 bc 41 9d 78 76 + 48 7d 95 00 00 06 13 01 13 03 13 02 01 00 01 cd 00 00 00 0b 00 + 09 00 00 06 73 65 72 76 65 72 ff 01 00 01 00 00 0a 00 14 00 12 + 00 1d 00 17 00 18 00 19 01 00 01 01 01 02 01 03 01 04 00 33 00 + 26 00 24 00 1d 00 20 e4 ff b6 8a c0 5f 8d 96 c9 9d a2 66 98 34 + 6c 6b e1 64 82 ba dd da fe 05 1a 66 b4 f1 8d 66 8f 0b 00 2a 00 + 00 00 2b 00 03 02 03 04 00 0d 00 20 00 1e 04 03 05 03 06 03 02 + 03 08 04 08 05 08 06 04 01 05 01 06 01 02 01 04 02 05 02 06 02 + 02 02 00 2d 00 02 01 01 00 1c 00 02 40 01 00 15 00 57 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 + 00 29 00 dd 00 b8 00 b2 2c 03 5d 82 93 59 ee 5f f7 af 4e c9 00 + 00 00 00 26 2a 64 94 dc 48 6d 2c 8a 34 cb 33 fa 90 bf 1b 00 70 + ad 3c 49 88 83 c9 36 7c 09 a2 be 78 5a bc 55 cd 22 60 97 a3 a9 + 82 11 72 83 f8 2a 03 a1 43 ef d3 ff 5d d3 6d 64 e8 61 be 7f d6 + 1d 28 27 db 27 9c ce 14 50 77 d4 54 a3 66 4d 4e 6d a4 d2 9e e0 + 37 25 a6 a4 da fc d0 fc 67 d2 ae a7 05 29 51 3e 3d a2 67 7f a5 + 90 6c 5b 3f 7d 8f 92 f2 28 bd a4 0d da 72 14 70 f9 fb f2 97 b5 + ae a6 17 64 6f ac 5c 03 27 2e 97 07 27 c6 21 a7 91 41 ef 5f 7d + e6 50 5e 5b fb c3 88 e9 33 43 69 40 93 93 4a e4 d3 57 fa d6 aa + cb 00 21 20 3a dd 4f b2 d8 fd f8 22 a0 ca 3c f7 67 8e f5 e8 8d + ae 99 01 41 c5 92 4d 57 bb 6f a3 1b 9e 5f 9d`)) + + type args struct { + secret []byte + label string + transcript hash.Hash + } + tests := []struct { + name string + args args + want []byte + }{ + { + `derive secret for handshake "tls13 derived"`, + args{ + parseVector(`PRK (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c e2 + 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`), + "derived", + nil, + }, + parseVector(`expanded (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba + b6 97 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`), + }, + { + `derive secret "tls13 c e traffic"`, + args{ + parseVector(`PRK (32 octets): 9b 21 88 e9 b2 fc 6d 64 d7 1d c3 29 90 0e 20 bb + 41 91 50 00 f6 78 aa 83 9c bb 79 7c b7 d8 33 2c`), + "c e traffic", + chTranscript, + }, + parseVector(`expanded (32 octets): 3f bb e6 a6 0d eb 66 c3 0a 32 79 5a ba 0e + ff 7e aa 10 10 55 86 e7 be 5c 09 67 8d 63 b6 ca ab 62`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := cipherSuitesTLS13[0] + if got := c.deriveSecret(tt.args.secret, tt.args.label, tt.args.transcript); !bytes.Equal(got, tt.want) { + t.Errorf("cipherSuiteTLS13.deriveSecret() = % x, want % x", got, tt.want) + } + }) + } +} + +func TestTrafficKey(t *testing.T) { + trafficSecret := parseVector( + `PRK (32 octets): b6 7b 7d 69 0c c1 6c 4e 75 e5 42 13 cb 2d 37 b4 + e9 c9 12 bc de d9 10 5d 42 be fd 59 d3 91 ad 38`) + wantKey := parseVector( + `key expanded (16 octets): 3f ce 51 60 09 c2 17 27 d0 f2 e4 e8 6e + e4 03 bc`) + wantIV := parseVector( + `iv expanded (12 octets): 5d 31 3e b2 67 12 76 ee 13 00 0b 30`) + + c := cipherSuitesTLS13[0] + gotKey, gotIV := c.trafficKey(trafficSecret) + if !bytes.Equal(gotKey, wantKey) { + t.Errorf("cipherSuiteTLS13.trafficKey() gotKey = % x, want % x", gotKey, wantKey) + } + if !bytes.Equal(gotIV, wantIV) { + t.Errorf("cipherSuiteTLS13.trafficKey() gotIV = % x, want % x", gotIV, wantIV) + } +} + +func TestExtract(t *testing.T) { + type args struct { + newSecret []byte + currentSecret []byte + } + tests := []struct { + name string + args args + want []byte + }{ + { + `extract secret "early"`, + args{ + nil, + nil, + }, + parseVector(`secret (32 octets): 33 ad 0a 1c 60 7e c0 3b 09 e6 cd 98 93 68 0c + e2 10 ad f3 00 aa 1f 26 60 e1 b2 2e 10 f1 70 f9 2a`), + }, + { + `extract secret "master"`, + args{ + nil, + parseVector(`salt (32 octets): 43 de 77 e0 c7 77 13 85 9a 94 4d b9 db 25 90 b5 + 31 90 a6 5b 3e e2 e4 f1 2d d7 a0 bb 7c e2 54 b4`), + }, + parseVector(`secret (32 octets): 18 df 06 84 3d 13 a0 8b f2 a4 49 84 4c 5f 8a + 47 80 01 bc 4d 4c 62 79 84 d5 a4 1d a8 d0 40 29 19`), + }, + { + `extract secret "handshake"`, + args{ + parseVector(`IKM (32 octets): 8b d4 05 4f b5 5b 9d 63 fd fb ac f9 f0 4b 9f 0d + 35 e6 d6 3f 53 75 63 ef d4 62 72 90 0f 89 49 2d`), + parseVector(`salt (32 octets): 6f 26 15 a1 08 c7 02 c5 67 8f 54 fc 9d ba b6 97 + 16 c0 76 18 9c 48 25 0c eb ea c3 57 6c 36 11 ba`), + }, + parseVector(`secret (32 octets): 1d c8 26 e9 36 06 aa 6f dc 0a ad c1 2f 74 1b + 01 04 6a a6 b9 9f 69 1e d2 21 a9 f0 ca 04 3f be ac`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := cipherSuitesTLS13[0] + if got := c.extract(tt.args.newSecret, tt.args.currentSecret); !bytes.Equal(got, tt.want) { + t.Errorf("cipherSuiteTLS13.extract() = % x, want % x", got, tt.want) + } + }) + } +} diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go index b40bd46672..d0feca4f90 100644 --- a/src/go/build/deps_test.go +++ b/src/go/build/deps_test.go @@ -396,7 +396,7 @@ var pkgDeps = map[string][]string{ // SSL/TLS. "crypto/tls": { - "L4", "CRYPTO-MATH", "OS", "golang_org/x/crypto/cryptobyte", + "L4", "CRYPTO-MATH", "OS", "golang_org/x/crypto/cryptobyte", "golang_org/x/crypto/hkdf", "container/list", "crypto/x509", "encoding/pem", "net", "syscall", }, "crypto/x509": { diff --git a/src/vendor/golang_org/x/crypto/hkdf/example_test.go b/src/vendor/golang_org/x/crypto/hkdf/example_test.go new file mode 100644 index 0000000000..1fd140a324 --- /dev/null +++ b/src/vendor/golang_org/x/crypto/hkdf/example_test.go @@ -0,0 +1,56 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package hkdf_test + +import ( + "bytes" + "crypto/rand" + "crypto/sha256" + "fmt" + "io" + + "golang_org/x/crypto/hkdf" +) + +// Usage example that expands one master secret into three other +// cryptographically secure keys. +func Example_usage() { + // Underlying hash function for HMAC. + hash := sha256.New + + // Cryptographically secure master secret. + secret := []byte{0x00, 0x01, 0x02, 0x03} // i.e. NOT this. + + // Non-secret salt, optional (can be nil). + // Recommended: hash-length random value. + salt := make([]byte, hash().Size()) + if _, err := rand.Read(salt); err != nil { + panic(err) + } + + // Non-secret context info, optional (can be nil). + info := []byte("hkdf example") + + // Generate three 128-bit derived keys. + hkdf := hkdf.New(hash, secret, salt, info) + + var keys [][]byte + for i := 0; i < 3; i++ { + key := make([]byte, 16) + if _, err := io.ReadFull(hkdf, key); err != nil { + panic(err) + } + keys = append(keys, key) + } + + for i := range keys { + fmt.Printf("Key #%d: %v\n", i+1, !bytes.Equal(keys[i], make([]byte, 16))) + } + + // Output: + // Key #1: true + // Key #2: true + // Key #3: true +} diff --git a/src/vendor/golang_org/x/crypto/hkdf/hkdf.go b/src/vendor/golang_org/x/crypto/hkdf/hkdf.go new file mode 100644 index 0000000000..dda3f143be --- /dev/null +++ b/src/vendor/golang_org/x/crypto/hkdf/hkdf.go @@ -0,0 +1,93 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package hkdf implements the HMAC-based Extract-and-Expand Key Derivation +// Function (HKDF) as defined in RFC 5869. +// +// HKDF is a cryptographic key derivation function (KDF) with the goal of +// expanding limited input keying material into one or more cryptographically +// strong secret keys. +package hkdf // import "golang.org/x/crypto/hkdf" + +import ( + "crypto/hmac" + "errors" + "hash" + "io" +) + +// Extract generates a pseudorandom key for use with Expand from an input secret +// and an optional independent salt. +// +// Only use this function if you need to reuse the extracted key with multiple +// Expand invocations and different context values. Most common scenarios, +// including the generation of multiple keys, should use New instead. +func Extract(hash func() hash.Hash, secret, salt []byte) []byte { + if salt == nil { + salt = make([]byte, hash().Size()) + } + extractor := hmac.New(hash, salt) + extractor.Write(secret) + return extractor.Sum(nil) +} + +type hkdf struct { + expander hash.Hash + size int + + info []byte + counter byte + + prev []byte + buf []byte +} + +func (f *hkdf) Read(p []byte) (int, error) { + // Check whether enough data can be generated + need := len(p) + remains := len(f.buf) + int(255-f.counter+1)*f.size + if remains < need { + return 0, errors.New("hkdf: entropy limit reached") + } + // Read any leftover from the buffer + n := copy(p, f.buf) + p = p[n:] + + // Fill the rest of the buffer + for len(p) > 0 { + f.expander.Reset() + f.expander.Write(f.prev) + f.expander.Write(f.info) + f.expander.Write([]byte{f.counter}) + f.prev = f.expander.Sum(f.prev[:0]) + f.counter++ + + // Copy the new batch into p + f.buf = f.prev + n = copy(p, f.buf) + p = p[n:] + } + // Save leftovers for next run + f.buf = f.buf[n:] + + return need, nil +} + +// Expand returns a Reader, from which keys can be read, using the given +// pseudorandom key and optional context info, skipping the extraction step. +// +// The pseudorandomKey should have been generated by Extract, or be a uniformly +// random or pseudorandom cryptographically strong key. See RFC 5869, Section +// 3.3. Most common scenarios will want to use New instead. +func Expand(hash func() hash.Hash, pseudorandomKey, info []byte) io.Reader { + expander := hmac.New(hash, pseudorandomKey) + return &hkdf{expander, expander.Size(), info, 1, nil, nil} +} + +// New returns a Reader, from which keys can be read, using the given hash, +// secret, salt and context info. Salt and info can be nil. +func New(hash func() hash.Hash, secret, salt, info []byte) io.Reader { + prk := Extract(hash, secret, salt) + return Expand(hash, prk, info) +} diff --git a/src/vendor/golang_org/x/crypto/hkdf/hkdf_test.go b/src/vendor/golang_org/x/crypto/hkdf/hkdf_test.go new file mode 100644 index 0000000000..ea575772ef --- /dev/null +++ b/src/vendor/golang_org/x/crypto/hkdf/hkdf_test.go @@ -0,0 +1,449 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +package hkdf + +import ( + "bytes" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "hash" + "io" + "testing" +) + +type hkdfTest struct { + hash func() hash.Hash + master []byte + salt []byte + prk []byte + info []byte + out []byte +} + +var hkdfTests = []hkdfTest{ + // Tests from RFC 5869 + { + sha256.New, + []byte{ + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + }, + []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, + }, + []byte{ + 0x07, 0x77, 0x09, 0x36, 0x2c, 0x2e, 0x32, 0xdf, + 0x0d, 0xdc, 0x3f, 0x0d, 0xc4, 0x7b, 0xba, 0x63, + 0x90, 0xb6, 0xc7, 0x3b, 0xb5, 0x0f, 0x9c, 0x31, + 0x22, 0xec, 0x84, 0x4a, 0xd7, 0xc2, 0xb3, 0xe5, + }, + []byte{ + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, + 0xf8, 0xf9, + }, + []byte{ + 0x3c, 0xb2, 0x5f, 0x25, 0xfa, 0xac, 0xd5, 0x7a, + 0x90, 0x43, 0x4f, 0x64, 0xd0, 0x36, 0x2f, 0x2a, + 0x2d, 0x2d, 0x0a, 0x90, 0xcf, 0x1a, 0x5a, 0x4c, + 0x5d, 0xb0, 0x2d, 0x56, 0xec, 0xc4, 0xc5, 0xbf, + 0x34, 0x00, 0x72, 0x08, 0xd5, 0xb8, 0x87, 0x18, + 0x58, 0x65, + }, + }, + { + sha256.New, + []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, + 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, + 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, + }, + []byte{ + 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, + 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, + 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, + 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, + 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, + 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, + 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf, + }, + []byte{ + 0x06, 0xa6, 0xb8, 0x8c, 0x58, 0x53, 0x36, 0x1a, + 0x06, 0x10, 0x4c, 0x9c, 0xeb, 0x35, 0xb4, 0x5c, + 0xef, 0x76, 0x00, 0x14, 0x90, 0x46, 0x71, 0x01, + 0x4a, 0x19, 0x3f, 0x40, 0xc1, 0x5f, 0xc2, 0x44, + }, + []byte{ + 0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7, + 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf, + 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, + 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf, + 0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, + 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf, + 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, + 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, + 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, + }, + []byte{ + 0xb1, 0x1e, 0x39, 0x8d, 0xc8, 0x03, 0x27, 0xa1, + 0xc8, 0xe7, 0xf7, 0x8c, 0x59, 0x6a, 0x49, 0x34, + 0x4f, 0x01, 0x2e, 0xda, 0x2d, 0x4e, 0xfa, 0xd8, + 0xa0, 0x50, 0xcc, 0x4c, 0x19, 0xaf, 0xa9, 0x7c, + 0x59, 0x04, 0x5a, 0x99, 0xca, 0xc7, 0x82, 0x72, + 0x71, 0xcb, 0x41, 0xc6, 0x5e, 0x59, 0x0e, 0x09, + 0xda, 0x32, 0x75, 0x60, 0x0c, 0x2f, 0x09, 0xb8, + 0x36, 0x77, 0x93, 0xa9, 0xac, 0xa3, 0xdb, 0x71, + 0xcc, 0x30, 0xc5, 0x81, 0x79, 0xec, 0x3e, 0x87, + 0xc1, 0x4c, 0x01, 0xd5, 0xc1, 0xf3, 0x43, 0x4f, + 0x1d, 0x87, + }, + }, + { + sha256.New, + []byte{ + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + }, + []byte{}, + []byte{ + 0x19, 0xef, 0x24, 0xa3, 0x2c, 0x71, 0x7b, 0x16, + 0x7f, 0x33, 0xa9, 0x1d, 0x6f, 0x64, 0x8b, 0xdf, + 0x96, 0x59, 0x67, 0x76, 0xaf, 0xdb, 0x63, 0x77, + 0xac, 0x43, 0x4c, 0x1c, 0x29, 0x3c, 0xcb, 0x04, + }, + []byte{}, + []byte{ + 0x8d, 0xa4, 0xe7, 0x75, 0xa5, 0x63, 0xc1, 0x8f, + 0x71, 0x5f, 0x80, 0x2a, 0x06, 0x3c, 0x5a, 0x31, + 0xb8, 0xa1, 0x1f, 0x5c, 0x5e, 0xe1, 0x87, 0x9e, + 0xc3, 0x45, 0x4e, 0x5f, 0x3c, 0x73, 0x8d, 0x2d, + 0x9d, 0x20, 0x13, 0x95, 0xfa, 0xa4, 0xb6, 0x1a, + 0x96, 0xc8, + }, + }, + { + sha256.New, + []byte{ + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + }, + nil, + []byte{ + 0x19, 0xef, 0x24, 0xa3, 0x2c, 0x71, 0x7b, 0x16, + 0x7f, 0x33, 0xa9, 0x1d, 0x6f, 0x64, 0x8b, 0xdf, + 0x96, 0x59, 0x67, 0x76, 0xaf, 0xdb, 0x63, 0x77, + 0xac, 0x43, 0x4c, 0x1c, 0x29, 0x3c, 0xcb, 0x04, + }, + nil, + []byte{ + 0x8d, 0xa4, 0xe7, 0x75, 0xa5, 0x63, 0xc1, 0x8f, + 0x71, 0x5f, 0x80, 0x2a, 0x06, 0x3c, 0x5a, 0x31, + 0xb8, 0xa1, 0x1f, 0x5c, 0x5e, 0xe1, 0x87, 0x9e, + 0xc3, 0x45, 0x4e, 0x5f, 0x3c, 0x73, 0x8d, 0x2d, + 0x9d, 0x20, 0x13, 0x95, 0xfa, 0xa4, 0xb6, 0x1a, + 0x96, 0xc8, + }, + }, + { + sha1.New, + []byte{ + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, + }, + []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, + }, + []byte{ + 0x9b, 0x6c, 0x18, 0xc4, 0x32, 0xa7, 0xbf, 0x8f, + 0x0e, 0x71, 0xc8, 0xeb, 0x88, 0xf4, 0xb3, 0x0b, + 0xaa, 0x2b, 0xa2, 0x43, + }, + []byte{ + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, + 0xf8, 0xf9, + }, + []byte{ + 0x08, 0x5a, 0x01, 0xea, 0x1b, 0x10, 0xf3, 0x69, + 0x33, 0x06, 0x8b, 0x56, 0xef, 0xa5, 0xad, 0x81, + 0xa4, 0xf1, 0x4b, 0x82, 0x2f, 0x5b, 0x09, 0x15, + 0x68, 0xa9, 0xcd, 0xd4, 0xf1, 0x55, 0xfd, 0xa2, + 0xc2, 0x2e, 0x42, 0x24, 0x78, 0xd3, 0x05, 0xf3, + 0xf8, 0x96, + }, + }, + { + sha1.New, + []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, + 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, + 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, + 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, + 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, + 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, + 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, + }, + []byte{ + 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, + 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, + 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, + 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, + 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, + 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, + 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, + 0xa0, 0xa1, 0xa2, 0xa3, 0xa4, 0xa5, 0xa6, 0xa7, + 0xa8, 0xa9, 0xaa, 0xab, 0xac, 0xad, 0xae, 0xaf, + }, + []byte{ + 0x8a, 0xda, 0xe0, 0x9a, 0x2a, 0x30, 0x70, 0x59, + 0x47, 0x8d, 0x30, 0x9b, 0x26, 0xc4, 0x11, 0x5a, + 0x22, 0x4c, 0xfa, 0xf6, + }, + []byte{ + 0xb0, 0xb1, 0xb2, 0xb3, 0xb4, 0xb5, 0xb6, 0xb7, + 0xb8, 0xb9, 0xba, 0xbb, 0xbc, 0xbd, 0xbe, 0xbf, + 0xc0, 0xc1, 0xc2, 0xc3, 0xc4, 0xc5, 0xc6, 0xc7, + 0xc8, 0xc9, 0xca, 0xcb, 0xcc, 0xcd, 0xce, 0xcf, + 0xd0, 0xd1, 0xd2, 0xd3, 0xd4, 0xd5, 0xd6, 0xd7, + 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf, + 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, + 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, + 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, + 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, + }, + []byte{ + 0x0b, 0xd7, 0x70, 0xa7, 0x4d, 0x11, 0x60, 0xf7, + 0xc9, 0xf1, 0x2c, 0xd5, 0x91, 0x2a, 0x06, 0xeb, + 0xff, 0x6a, 0xdc, 0xae, 0x89, 0x9d, 0x92, 0x19, + 0x1f, 0xe4, 0x30, 0x56, 0x73, 0xba, 0x2f, 0xfe, + 0x8f, 0xa3, 0xf1, 0xa4, 0xe5, 0xad, 0x79, 0xf3, + 0xf3, 0x34, 0xb3, 0xb2, 0x02, 0xb2, 0x17, 0x3c, + 0x48, 0x6e, 0xa3, 0x7c, 0xe3, 0xd3, 0x97, 0xed, + 0x03, 0x4c, 0x7f, 0x9d, 0xfe, 0xb1, 0x5c, 0x5e, + 0x92, 0x73, 0x36, 0xd0, 0x44, 0x1f, 0x4c, 0x43, + 0x00, 0xe2, 0xcf, 0xf0, 0xd0, 0x90, 0x0b, 0x52, + 0xd3, 0xb4, + }, + }, + { + sha1.New, + []byte{ + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, 0x0b, + }, + []byte{}, + []byte{ + 0xda, 0x8c, 0x8a, 0x73, 0xc7, 0xfa, 0x77, 0x28, + 0x8e, 0xc6, 0xf5, 0xe7, 0xc2, 0x97, 0x78, 0x6a, + 0xa0, 0xd3, 0x2d, 0x01, + }, + []byte{}, + []byte{ + 0x0a, 0xc1, 0xaf, 0x70, 0x02, 0xb3, 0xd7, 0x61, + 0xd1, 0xe5, 0x52, 0x98, 0xda, 0x9d, 0x05, 0x06, + 0xb9, 0xae, 0x52, 0x05, 0x72, 0x20, 0xa3, 0x06, + 0xe0, 0x7b, 0x6b, 0x87, 0xe8, 0xdf, 0x21, 0xd0, + 0xea, 0x00, 0x03, 0x3d, 0xe0, 0x39, 0x84, 0xd3, + 0x49, 0x18, + }, + }, + { + sha1.New, + []byte{ + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, + 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, 0x0c, + }, + nil, + []byte{ + 0x2a, 0xdc, 0xca, 0xda, 0x18, 0x77, 0x9e, 0x7c, + 0x20, 0x77, 0xad, 0x2e, 0xb1, 0x9d, 0x3f, 0x3e, + 0x73, 0x13, 0x85, 0xdd, + }, + nil, + []byte{ + 0x2c, 0x91, 0x11, 0x72, 0x04, 0xd7, 0x45, 0xf3, + 0x50, 0x0d, 0x63, 0x6a, 0x62, 0xf6, 0x4f, 0x0a, + 0xb3, 0xba, 0xe5, 0x48, 0xaa, 0x53, 0xd4, 0x23, + 0xb0, 0xd1, 0xf2, 0x7e, 0xbb, 0xa6, 0xf5, 0xe5, + 0x67, 0x3a, 0x08, 0x1d, 0x70, 0xcc, 0xe7, 0xac, + 0xfc, 0x48, + }, + }, +} + +func TestHKDF(t *testing.T) { + for i, tt := range hkdfTests { + prk := Extract(tt.hash, tt.master, tt.salt) + if !bytes.Equal(prk, tt.prk) { + t.Errorf("test %d: incorrect PRK: have %v, need %v.", i, prk, tt.prk) + } + + hkdf := New(tt.hash, tt.master, tt.salt, tt.info) + out := make([]byte, len(tt.out)) + + n, err := io.ReadFull(hkdf, out) + if n != len(tt.out) || err != nil { + t.Errorf("test %d: not enough output bytes: %d.", i, n) + } + + if !bytes.Equal(out, tt.out) { + t.Errorf("test %d: incorrect output: have %v, need %v.", i, out, tt.out) + } + + hkdf = Expand(tt.hash, prk, tt.info) + + n, err = io.ReadFull(hkdf, out) + if n != len(tt.out) || err != nil { + t.Errorf("test %d: not enough output bytes from Expand: %d.", i, n) + } + + if !bytes.Equal(out, tt.out) { + t.Errorf("test %d: incorrect output from Expand: have %v, need %v.", i, out, tt.out) + } + } +} + +func TestHKDFMultiRead(t *testing.T) { + for i, tt := range hkdfTests { + hkdf := New(tt.hash, tt.master, tt.salt, tt.info) + out := make([]byte, len(tt.out)) + + for b := 0; b < len(tt.out); b++ { + n, err := io.ReadFull(hkdf, out[b:b+1]) + if n != 1 || err != nil { + t.Errorf("test %d.%d: not enough output bytes: have %d, need %d .", i, b, n, len(tt.out)) + } + } + + if !bytes.Equal(out, tt.out) { + t.Errorf("test %d: incorrect output: have %v, need %v.", i, out, tt.out) + } + } +} + +func TestHKDFLimit(t *testing.T) { + hash := sha1.New + master := []byte{0x00, 0x01, 0x02, 0x03} + info := []byte{} + + hkdf := New(hash, master, nil, info) + limit := hash().Size() * 255 + out := make([]byte, limit) + + // The maximum output bytes should be extractable + n, err := io.ReadFull(hkdf, out) + if n != limit || err != nil { + t.Errorf("not enough output bytes: %d, %v.", n, err) + } + + // Reading one more should fail + n, err = io.ReadFull(hkdf, make([]byte, 1)) + if n > 0 || err == nil { + t.Errorf("key expansion overflowed: n = %d, err = %v", n, err) + } +} + +func Benchmark16ByteMD5Single(b *testing.B) { + benchmarkHKDFSingle(md5.New, 16, b) +} + +func Benchmark20ByteSHA1Single(b *testing.B) { + benchmarkHKDFSingle(sha1.New, 20, b) +} + +func Benchmark32ByteSHA256Single(b *testing.B) { + benchmarkHKDFSingle(sha256.New, 32, b) +} + +func Benchmark64ByteSHA512Single(b *testing.B) { + benchmarkHKDFSingle(sha512.New, 64, b) +} + +func Benchmark8ByteMD5Stream(b *testing.B) { + benchmarkHKDFStream(md5.New, 8, b) +} + +func Benchmark16ByteMD5Stream(b *testing.B) { + benchmarkHKDFStream(md5.New, 16, b) +} + +func Benchmark8ByteSHA1Stream(b *testing.B) { + benchmarkHKDFStream(sha1.New, 8, b) +} + +func Benchmark20ByteSHA1Stream(b *testing.B) { + benchmarkHKDFStream(sha1.New, 20, b) +} + +func Benchmark8ByteSHA256Stream(b *testing.B) { + benchmarkHKDFStream(sha256.New, 8, b) +} + +func Benchmark32ByteSHA256Stream(b *testing.B) { + benchmarkHKDFStream(sha256.New, 32, b) +} + +func Benchmark8ByteSHA512Stream(b *testing.B) { + benchmarkHKDFStream(sha512.New, 8, b) +} + +func Benchmark64ByteSHA512Stream(b *testing.B) { + benchmarkHKDFStream(sha512.New, 64, b) +} + +func benchmarkHKDFSingle(hasher func() hash.Hash, block int, b *testing.B) { + master := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07} + salt := []byte{0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17} + info := []byte{0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27} + out := make([]byte, block) + + b.SetBytes(int64(block)) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + hkdf := New(hasher, master, salt, info) + io.ReadFull(hkdf, out) + } +} + +func benchmarkHKDFStream(hasher func() hash.Hash, block int, b *testing.B) { + master := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07} + salt := []byte{0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17} + info := []byte{0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27} + out := make([]byte, block) + + b.SetBytes(int64(block)) + b.ResetTimer() + + hkdf := New(hasher, master, salt, info) + for i := 0; i < b.N; i++ { + _, err := io.ReadFull(hkdf, out) + if err != nil { + hkdf = New(hasher, master, salt, info) + i-- + } + } +}