From ffb8b1c586a61882b6c362c0454f56f89d2693c9 Mon Sep 17 00:00:00 2001 From: Filippo Valsorda Date: Wed, 16 Oct 2024 14:31:44 +0200 Subject: [PATCH] crypto/internal/mlkem768: make Decapsulate a method This will make it easier to support multiple sizes if needed. Change-Id: I47495559fdbbf678fd98421ad6cb28172e5c810d Reviewed-on: https://go-review.googlesource.com/c/go/+/621977 Reviewed-by: Daniel McCarney Reviewed-by: Russ Cox LUCI-TryBot-Result: Go LUCI Auto-Submit: Filippo Valsorda Reviewed-by: Roland Shoemaker --- src/crypto/internal/mlkem768/mlkem768.go | 2 +- src/crypto/internal/mlkem768/mlkem768_test.go | 12 ++++++------ src/crypto/tls/key_schedule.go | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/crypto/internal/mlkem768/mlkem768.go b/src/crypto/internal/mlkem768/mlkem768.go index 527c93ffe36..0daf3594466 100644 --- a/src/crypto/internal/mlkem768/mlkem768.go +++ b/src/crypto/internal/mlkem768/mlkem768.go @@ -320,7 +320,7 @@ func pkeEncrypt(cc *[CiphertextSize]byte, ex *encryptionKey, m *[messageSize]byt // If the ciphertext is not valid, Decapsulate returns an error. // // The shared key must be kept secret. -func Decapsulate(dk *DecapsulationKey, ciphertext []byte) (sharedKey []byte, err error) { +func (dk *DecapsulationKey) Decapsulate(ciphertext []byte) (sharedKey []byte, err error) { if len(ciphertext) != CiphertextSize { return nil, errors.New("mlkem768: invalid ciphertext length") } diff --git a/src/crypto/internal/mlkem768/mlkem768_test.go b/src/crypto/internal/mlkem768/mlkem768_test.go index 7d32805b2f7..5d129e11dfd 100644 --- a/src/crypto/internal/mlkem768/mlkem768_test.go +++ b/src/crypto/internal/mlkem768/mlkem768_test.go @@ -206,7 +206,7 @@ func TestRoundTrip(t *testing.T) { if err != nil { t.Fatal(err) } - Kd, err := Decapsulate(dk, c) + Kd, err := dk.Decapsulate(c) if err != nil { t.Fatal(err) } @@ -263,14 +263,14 @@ func TestBadLengths(t *testing.T) { } for i := 0; i < len(c)-1; i++ { - if _, err := Decapsulate(dk, c[:i]); err == nil { + if _, err := dk.Decapsulate(c[:i]); err == nil { t.Errorf("expected error for c length %d", i) } } cLong := c for i := 0; i < 100; i++ { cLong = append(cLong, 0) - if _, err := Decapsulate(dk, cLong); err == nil { + if _, err := dk.Decapsulate(cLong); err == nil { t.Errorf("expected error for c length %d", len(cLong)) } } @@ -315,7 +315,7 @@ func TestAccumulated(t *testing.T) { o.Write(ct) o.Write(k) - kk, err := Decapsulate(dk, ct) + kk, err := dk.Decapsulate(ct) if err != nil { t.Fatal(err) } @@ -324,7 +324,7 @@ func TestAccumulated(t *testing.T) { } s.Read(ct1) - k1, err := Decapsulate(dk, ct1) + k1, err := dk.Decapsulate(ct1) if err != nil { t.Fatal(err) } @@ -408,7 +408,7 @@ func BenchmarkRoundTrip(b *testing.B) { ekS := dkS.EncapsulationKey() sink ^= ekS[0] - Ks, err := Decapsulate(dk, c) + Ks, err := dk.Decapsulate(c) if err != nil { b.Fatal(err) } diff --git a/src/crypto/tls/key_schedule.go b/src/crypto/tls/key_schedule.go index 9c76ebe3670..e8ee9ce9c2f 100644 --- a/src/crypto/tls/key_schedule.go +++ b/src/crypto/tls/key_schedule.go @@ -59,7 +59,7 @@ type keySharePrivateKeys struct { // kyberDecapsulate implements decapsulation according to Kyber Round 3. func kyberDecapsulate(dk *mlkem768.DecapsulationKey, c []byte) ([]byte, error) { - K, err := mlkem768.Decapsulate(dk, c) + K, err := dk.Decapsulate(c) if err != nil { return nil, err }