diff --git a/src/crypto/internal/mlkem768/mlkem768.go b/src/crypto/internal/mlkem768/mlkem768.go index f152e7682ee..830841f7380 100644 --- a/src/crypto/internal/mlkem768/mlkem768.go +++ b/src/crypto/internal/mlkem768/mlkem768.go @@ -73,6 +73,8 @@ type DecapsulationKey struct { } // Bytes returns the decapsulation key as a 64-byte seed in the "d || z" form. +// +// The decapsulation key must be kept secret. func (dk *DecapsulationKey) Bytes() []byte { var b [SeedSize]byte copy(b[:], dk.d[:]) @@ -82,17 +84,34 @@ func (dk *DecapsulationKey) Bytes() []byte { // EncapsulationKey returns the public encapsulation key necessary to produce // ciphertexts. -func (dk *DecapsulationKey) EncapsulationKey() []byte { - // The actual logic is in a separate function to outline this allocation. - b := make([]byte, 0, EncapsulationKeySize) - return dk.encapsulationKey(b) +func (dk *DecapsulationKey) EncapsulationKey() *EncapsulationKey { + return &EncapsulationKey{ + ρ: dk.ρ, + h: dk.h, + encryptionKey: dk.encryptionKey, + } } -func (dk *DecapsulationKey) encapsulationKey(b []byte) []byte { - for i := range dk.t { - b = polyByteEncode(b, dk.t[i]) +// An EncapsulationKey is the public key used to produce ciphertexts to be +// decapsulated by the corresponding [DecapsulationKey]. +type EncapsulationKey struct { + ρ [32]byte // sampleNTT seed for A + h [32]byte // H(ek) + encryptionKey +} + +// Bytes returns the encapsulation key as a byte slice. +func (ek *EncapsulationKey) Bytes() []byte { + // The actual logic is in a separate function to outline this allocation. + b := make([]byte, 0, EncapsulationKeySize) + return ek.bytes(b) +} + +func (ek *EncapsulationKey) bytes(b []byte) []byte { + for i := range ek.t { + b = polyByteEncode(b, ek.t[i]) } - b = append(b, dk.ρ[:]...) + b = append(b, ek.ρ[:]...) return b } @@ -123,9 +142,9 @@ func generateKey(dk *DecapsulationKey) *DecapsulationKey { return kemKeyGen(dk, &d, &z) } -// NewKeyFromSeed deterministically generates a decapsulation key from a 64-byte +// NewDecapsulationKey parses a decapsulation key from a 64-byte // seed in the "d || z" form. The seed must be uniformly random. -func NewKeyFromSeed(seed []byte) (*DecapsulationKey, error) { +func NewDecapsulationKey(seed []byte) (*DecapsulationKey, error) { // The actual logic is in a separate function to outline this allocation. dk := &DecapsulationKey{} return newKeyFromSeed(dk, seed) @@ -187,7 +206,7 @@ func kemKeyGen(dk *DecapsulationKey, d, z *[32]byte) *DecapsulationKey { } H := sha3.New256() - ek := dk.EncapsulationKey() + ek := dk.EncapsulationKey().Bytes() H.Write(ek) H.Sum(dk.h[:0]) @@ -196,74 +215,75 @@ func kemKeyGen(dk *DecapsulationKey, d, z *[32]byte) *DecapsulationKey { // Encapsulate generates a shared key and an associated ciphertext from an // encapsulation key, drawing random bytes from crypto/rand. -// If the encapsulation key is not valid, Encapsulate returns an error. // // The shared key must be kept secret. -func Encapsulate(encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) { +func (ek *EncapsulationKey) Encapsulate() (ciphertext, sharedKey []byte) { // The actual logic is in a separate function to outline this allocation. var cc [CiphertextSize]byte - return encapsulate(&cc, encapsulationKey) + return ek.encapsulate(&cc) } -func encapsulate(cc *[CiphertextSize]byte, encapsulationKey []byte) (ciphertext, sharedKey []byte, err error) { - if len(encapsulationKey) != EncapsulationKeySize { - return nil, nil, errors.New("mlkem768: invalid encapsulation key length") - } +func (ek *EncapsulationKey) encapsulate(cc *[CiphertextSize]byte) (ciphertext, sharedKey []byte) { var m [messageSize]byte rand.Read(m[:]) // Note that the modulus check (step 2 of the encapsulation key check from // FIPS 203, Section 7.2) is performed by polyByteDecode in parseEK. - return kemEncaps(cc, encapsulationKey, &m) + return kemEncaps(cc, ek, &m) } // kemEncaps generates a shared key and an associated ciphertext. // // It implements ML-KEM.Encaps_internal according to FIPS 203, Algorithm 17. -func kemEncaps(cc *[CiphertextSize]byte, ek []byte, m *[messageSize]byte) (c, K []byte, err error) { +func kemEncaps(cc *[CiphertextSize]byte, ek *EncapsulationKey, m *[messageSize]byte) (c, K []byte) { if cc == nil { cc = &[CiphertextSize]byte{} } - H := sha3.Sum256(ek[:]) g := sha3.New512() g.Write(m[:]) - g.Write(H[:]) + g.Write(ek.h[:]) G := g.Sum(nil) K, r := G[:SharedKeySize], G[SharedKeySize:] - var ex encryptionKey - if err := parseEK(&ex, ek[:]); err != nil { - return nil, nil, err - } - c = pkeEncrypt(cc, &ex, m, r) - return c, K, nil + c = pkeEncrypt(cc, &ek.encryptionKey, m, r) + return c, K +} + +// NewEncapsulationKey parses an encapsulation key from its encoded form. +// If the encapsulation key is not valid, NewEncapsulationKey returns an error. +func NewEncapsulationKey(encapsulationKey []byte) (*EncapsulationKey, error) { + // The actual logic is in a separate function to outline this allocation. + ek := &EncapsulationKey{} + return parseEK(ek, encapsulationKey) } // parseEK parses an encryption key from its encoded form. // // It implements the initial stages of K-PKE.Encrypt according to FIPS 203, // Algorithm 14. -func parseEK(ex *encryptionKey, ekPKE []byte) error { +func parseEK(ek *EncapsulationKey, ekPKE []byte) (*EncapsulationKey, error) { if len(ekPKE) != encryptionKeySize { - return errors.New("mlkem768: invalid encryption key length") + return nil, errors.New("mlkem768: invalid encapsulation key length") } - for i := range ex.t { + ek.h = sha3.Sum256(ekPKE[:]) + + for i := range ek.t { var err error - ex.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12]) + ek.t[i], err = polyByteDecode[nttElement](ekPKE[:encodingSize12]) if err != nil { - return err + return nil, err } ekPKE = ekPKE[encodingSize12:] } - ρ := ekPKE + copy(ek.ρ[:], ekPKE) for i := byte(0); i < k; i++ { for j := byte(0); j < k; j++ { - ex.a[i*k+j] = sampleNTT(ρ, j, i) + ek.a[i*k+j] = sampleNTT(ek.ρ[:], j, i) } } - return nil + return ek, nil } // pkeEncrypt encrypt a plaintext message. diff --git a/src/crypto/internal/mlkem768/mlkem768_test.go b/src/crypto/internal/mlkem768/mlkem768_test.go index 4775f77aeba..295aa95d0ad 100644 --- a/src/crypto/internal/mlkem768/mlkem768_test.go +++ b/src/crypto/internal/mlkem768/mlkem768_test.go @@ -20,10 +20,7 @@ func TestRoundTrip(t *testing.T) { if err != nil { t.Fatal(err) } - c, Ke, err := Encapsulate(dk.EncapsulationKey()) - if err != nil { - t.Fatal(err) - } + c, Ke := dk.EncapsulationKey().Encapsulate() Kd, err := dk.Decapsulate(c) if err != nil { t.Fatal(err) @@ -36,17 +33,14 @@ func TestRoundTrip(t *testing.T) { if err != nil { t.Fatal(err) } - if bytes.Equal(dk.EncapsulationKey(), dk1.EncapsulationKey()) { + if bytes.Equal(dk.EncapsulationKey().Bytes(), dk1.EncapsulationKey().Bytes()) { t.Fail() } if bytes.Equal(dk.Bytes(), dk1.Bytes()) { t.Fail() } - c1, Ke1, err := Encapsulate(dk.EncapsulationKey()) - if err != nil { - t.Fatal(err) - } + c1, Ke1 := dk.EncapsulationKey().Encapsulate() if bytes.Equal(c, c1) { t.Fail() } @@ -61,25 +55,22 @@ func TestBadLengths(t *testing.T) { t.Fatal(err) } ek := dk.EncapsulationKey() + ekBytes := dk.EncapsulationKey().Bytes() + c, _ := ek.Encapsulate() - for i := 0; i < len(ek)-1; i++ { - if _, _, err := Encapsulate(ek[:i]); err == nil { + for i := 0; i < len(ekBytes)-1; i++ { + if _, err := NewEncapsulationKey(ekBytes[:i]); err == nil { t.Errorf("expected error for ek length %d", i) } } - ekLong := ek + ekLong := ekBytes for i := 0; i < 100; i++ { ekLong = append(ekLong, 0) - if _, _, err := Encapsulate(ekLong); err == nil { + if _, err := NewEncapsulationKey(ekLong); err == nil { t.Errorf("expected error for ek length %d", len(ekLong)) } } - c, _, err := Encapsulate(ek) - if err != nil { - t.Fatal(err) - } - for i := 0; i < len(c)-1; i++ { if _, err := dk.Decapsulate(c[:i]); err == nil { t.Errorf("expected error for c length %d", i) @@ -118,18 +109,15 @@ func TestAccumulated(t *testing.T) { for i := 0; i < n; i++ { s.Read(seed) - dk, err := NewKeyFromSeed(seed) + dk, err := NewDecapsulationKey(seed) if err != nil { t.Fatal(err) } ek := dk.EncapsulationKey() - o.Write(ek) + o.Write(ek.Bytes()) s.Read(msg[:]) - ct, k, err := kemEncaps(nil, ek, &msg) - if err != nil { - t.Fatal(err) - } + ct, k := kemEncaps(nil, ek, &msg) o.Write(ct) o.Write(k) @@ -165,7 +153,7 @@ func BenchmarkKeyGen(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { dk := kemKeyGen(&dk, &d, &z) - sink ^= dk.EncapsulationKey()[0] + sink ^= dk.EncapsulationKey().Bytes()[0] } } @@ -174,18 +162,19 @@ func BenchmarkEncaps(b *testing.B) { rand.Read(seed) var m [messageSize]byte rand.Read(m[:]) - dk, err := NewKeyFromSeed(seed) + dk, err := NewDecapsulationKey(seed) if err != nil { b.Fatal(err) } - ek := dk.EncapsulationKey() + ekBytes := dk.EncapsulationKey().Bytes() var c [CiphertextSize]byte b.ResetTimer() for i := 0; i < b.N; i++ { - c, K, err := kemEncaps(&c, ek, &m) + ek, err := NewEncapsulationKey(ekBytes) if err != nil { b.Fatal(err) } + c, K := kemEncaps(&c, ek, &m) sink ^= c[0] ^ K[0] } } @@ -196,10 +185,7 @@ func BenchmarkDecaps(b *testing.B) { b.Fatal(err) } ek := dk.EncapsulationKey() - c, _, err := Encapsulate(ek) - if err != nil { - b.Fatal(err) - } + c, _ := ek.Encapsulate() b.ResetTimer() for i := 0; i < b.N; i++ { K := kemDecaps(dk, (*[CiphertextSize]byte)(c)) @@ -213,7 +199,8 @@ func BenchmarkRoundTrip(b *testing.B) { b.Fatal(err) } ek := dk.EncapsulationKey() - c, _, err := Encapsulate(ek) + ekBytes := ek.Bytes() + c, _ := ek.Encapsulate() if err != nil { b.Fatal(err) } @@ -223,7 +210,7 @@ func BenchmarkRoundTrip(b *testing.B) { if err != nil { b.Fatal(err) } - ekS := dkS.EncapsulationKey() + ekS := dkS.EncapsulationKey().Bytes() sink ^= ekS[0] Ks, err := dk.Decapsulate(c) @@ -235,7 +222,11 @@ func BenchmarkRoundTrip(b *testing.B) { }) b.Run("Bob", func(b *testing.B) { for i := 0; i < b.N; i++ { - cS, Ks, err := Encapsulate(ek) + ek, err := NewEncapsulationKey(ekBytes) + if err != nil { + b.Fatal(err) + } + cS, Ks := ek.Encapsulate() if err != nil { b.Fatal(err) } diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index f6bccc40bcb..8965ad6eafd 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -164,7 +164,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon if _, err := io.ReadFull(config.rand(), seed); err != nil { return nil, nil, nil, err } - keyShareKeys.kyber, err = mlkem768.NewKeyFromSeed(seed) + keyShareKeys.kyber, err = mlkem768.NewDecapsulationKey(seed) if err != nil { return nil, nil, nil, err } @@ -174,7 +174,7 @@ func (c *Conn) makeClientHello() (*clientHelloMsg, *keySharePrivateKeys, *echCon // both, as allowed by draft-ietf-tls-hybrid-design-09, Section 3.2. hello.keyShares = []keyShare{ {group: x25519Kyber768Draft00, data: append(keyShareKeys.ecdhe.PublicKey().Bytes(), - keyShareKeys.kyber.EncapsulationKey()...)}, + keyShareKeys.kyber.EncapsulationKey().Bytes()...)}, {group: X25519, data: keyShareKeys.ecdhe.PublicKey().Bytes()}, } } else { diff --git a/src/crypto/tls/key_schedule.go b/src/crypto/tls/key_schedule.go index e8ee9ce9c2f..3bbfc1b435f 100644 --- a/src/crypto/tls/key_schedule.go +++ b/src/crypto/tls/key_schedule.go @@ -63,19 +63,20 @@ func kyberDecapsulate(dk *mlkem768.DecapsulationKey, c []byte) ([]byte, error) { if err != nil { return nil, err } - return kyberSharedSecret(K, c), nil + return kyberSharedSecret(c, K), nil } // kyberEncapsulate implements encapsulation according to Kyber Round 3. func kyberEncapsulate(ek []byte) (c, ss []byte, err error) { - c, ss, err = mlkem768.Encapsulate(ek) + k, err := mlkem768.NewEncapsulationKey(ek) if err != nil { return nil, nil, err } - return c, kyberSharedSecret(ss, c), nil + c, ss = k.Encapsulate() + return c, kyberSharedSecret(c, ss), nil } -func kyberSharedSecret(K, c []byte) []byte { +func kyberSharedSecret(c, K []byte) []byte { // Package mlkem768 implements ML-KEM, which compared to Kyber removed a // final hashing step. Compute SHAKE-256(K || SHA3-256(c), 32) to match Kyber. // See https://words.filippo.io/mlkem768/#bonus-track-using-a-ml-kem-implementation-as-kyber-v3. diff --git a/src/crypto/tls/key_schedule_test.go b/src/crypto/tls/key_schedule_test.go index 095113ca179..32532770d41 100644 --- a/src/crypto/tls/key_schedule_test.go +++ b/src/crypto/tls/key_schedule_test.go @@ -124,7 +124,7 @@ func TestKyberEncapsulate(t *testing.T) { if err != nil { t.Fatal(err) } - ct, ss, err := kyberEncapsulate(dk.EncapsulationKey()) + ct, ss, err := kyberEncapsulate(dk.EncapsulationKey().Bytes()) if err != nil { t.Fatal(err) }