diff --git a/src/crypto/internal/hpke/hpke.go b/src/crypto/internal/hpke/hpke.go index 022cdd28dfe..978f79cbcf3 100644 --- a/src/crypto/internal/hpke/hpke.go +++ b/src/crypto/internal/hpke/hpke.go @@ -26,10 +26,10 @@ type hkdfKDF struct { hash crypto.Hash } -func (kdf *hkdfKDF) LabeledExtract(suiteID []byte, salt []byte, label string, inputKey []byte) []byte { - labeledIKM := make([]byte, 0, 7+len(suiteID)+len(label)+len(inputKey)) +func (kdf *hkdfKDF) LabeledExtract(sid []byte, salt []byte, label string, inputKey []byte) []byte { + labeledIKM := make([]byte, 0, 7+len(sid)+len(label)+len(inputKey)) labeledIKM = append(labeledIKM, []byte("HPKE-v1")...) - labeledIKM = append(labeledIKM, suiteID...) + labeledIKM = append(labeledIKM, sid...) labeledIKM = append(labeledIKM, label...) labeledIKM = append(labeledIKM, inputKey...) return hkdf.Extract(kdf.hash.New, labeledIKM, salt) @@ -59,13 +59,17 @@ type dhKEM struct { nSecret uint16 } +type KemID uint16 + +const DHKEM_X25519_HKDF_SHA256 = 0x0020 + var SupportedKEMs = map[uint16]struct { curve ecdh.Curve hash crypto.Hash nSecret uint16 }{ // RFC 9180 Section 7.1 - 0x0020: {ecdh.X25519(), crypto.SHA256, 32}, + DHKEM_X25519_HKDF_SHA256: {ecdh.X25519(), crypto.SHA256, 32}, } func newDHKem(kemID uint16) (*dhKEM, error) { @@ -108,9 +112,22 @@ func (dh *dhKEM) Encap(pubRecipient *ecdh.PublicKey) (sharedSecret []byte, encap return dh.ExtractAndExpand(dhVal, kemContext), encPubEph, nil } -type Sender struct { +func (dh *dhKEM) Decap(encPubEph []byte, secRecipient *ecdh.PrivateKey) ([]byte, error) { + pubEph, err := dh.dh.NewPublicKey(encPubEph) + if err != nil { + return nil, err + } + dhVal, err := secRecipient.ECDH(pubEph) + if err != nil { + return nil, err + } + kemContext := append(encPubEph, secRecipient.PublicKey().Bytes()...) + + return dh.ExtractAndExpand(dhVal, kemContext), nil +} + +type context struct { aead cipher.AEAD - kem *dhKEM sharedSecret []byte @@ -123,6 +140,14 @@ type Sender struct { seqNum uint128 } +type Sender struct { + *context +} + +type Receipient struct { + *context +} + var aesGCMNew = func(key []byte) (cipher.AEAD, error) { block, err := aes.NewCipher(key) if err != nil { @@ -131,97 +156,143 @@ var aesGCMNew = func(key []byte) (cipher.AEAD, error) { return cipher.NewGCM(block) } +type AEADID uint16 + +const ( + AEAD_AES_128_GCM = 0x0001 + AEAD_AES_256_GCM = 0x0002 + AEAD_ChaCha20Poly1305 = 0x0003 +) + var SupportedAEADs = map[uint16]struct { keySize int nonceSize int aead func([]byte) (cipher.AEAD, error) }{ // RFC 9180, Section 7.3 - 0x0001: {keySize: 16, nonceSize: 12, aead: aesGCMNew}, - 0x0002: {keySize: 32, nonceSize: 12, aead: aesGCMNew}, - 0x0003: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New}, + AEAD_AES_128_GCM: {keySize: 16, nonceSize: 12, aead: aesGCMNew}, + AEAD_AES_256_GCM: {keySize: 32, nonceSize: 12, aead: aesGCMNew}, + AEAD_ChaCha20Poly1305: {keySize: chacha20poly1305.KeySize, nonceSize: chacha20poly1305.NonceSize, aead: chacha20poly1305.New}, } +type KDFID uint16 + +const KDF_HKDF_SHA256 = 0x0001 + var SupportedKDFs = map[uint16]func() *hkdfKDF{ // RFC 9180, Section 7.2 - 0x0001: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} }, + KDF_HKDF_SHA256: func() *hkdfKDF { return &hkdfKDF{crypto.SHA256} }, } -func SetupSender(kemID, kdfID, aeadID uint16, pub crypto.PublicKey, info []byte) ([]byte, *Sender, error) { - suiteID := SuiteID(kemID, kdfID, aeadID) - - kem, err := newDHKem(kemID) - if err != nil { - return nil, nil, err - } - pubRecipient, ok := pub.(*ecdh.PublicKey) - if !ok { - return nil, nil, errors.New("incorrect public key type") - } - sharedSecret, encapsulatedKey, err := kem.Encap(pubRecipient) - if err != nil { - return nil, nil, err - } +func newContext(sharedSecret []byte, kemID, kdfID, aeadID uint16, info []byte) (*context, error) { + sid := suiteID(kemID, kdfID, aeadID) kdfInit, ok := SupportedKDFs[kdfID] if !ok { - return nil, nil, errors.New("unsupported KDF id") + return nil, errors.New("unsupported KDF id") } kdf := kdfInit() aeadInfo, ok := SupportedAEADs[aeadID] if !ok { - return nil, nil, errors.New("unsupported AEAD id") + return nil, errors.New("unsupported AEAD id") } - pskIDHash := kdf.LabeledExtract(suiteID, nil, "psk_id_hash", nil) - infoHash := kdf.LabeledExtract(suiteID, nil, "info_hash", info) + pskIDHash := kdf.LabeledExtract(sid, nil, "psk_id_hash", nil) + infoHash := kdf.LabeledExtract(sid, nil, "info_hash", info) ksContext := append([]byte{0}, pskIDHash...) ksContext = append(ksContext, infoHash...) - secret := kdf.LabeledExtract(suiteID, sharedSecret, "secret", nil) + secret := kdf.LabeledExtract(sid, sharedSecret, "secret", nil) - key := kdf.LabeledExpand(suiteID, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */) - baseNonce := kdf.LabeledExpand(suiteID, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */) - exporterSecret := kdf.LabeledExpand(suiteID, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/) + key := kdf.LabeledExpand(sid, secret, "key", ksContext, uint16(aeadInfo.keySize) /* Nk - key size for AEAD */) + baseNonce := kdf.LabeledExpand(sid, secret, "base_nonce", ksContext, uint16(aeadInfo.nonceSize) /* Nn - nonce size for AEAD */) + exporterSecret := kdf.LabeledExpand(sid, secret, "exp", ksContext, uint16(kdf.hash.Size()) /* Nh - hash output size of the kdf*/) aead, err := aeadInfo.aead(key) if err != nil { - return nil, nil, err + return nil, err } - return encapsulatedKey, &Sender{ - kem: kem, + return &context{ aead: aead, sharedSecret: sharedSecret, - suiteID: suiteID, + suiteID: sid, key: key, baseNonce: baseNonce, exporterSecret: exporterSecret, }, nil } -func (s *Sender) nextNonce() []byte { - nonce := s.seqNum.bytes()[16-s.aead.NonceSize():] - for i := range s.baseNonce { - nonce[i] ^= s.baseNonce[i] +func SetupSender(kemID, kdfID, aeadID uint16, pub *ecdh.PublicKey, info []byte) ([]byte, *Sender, error) { + kem, err := newDHKem(kemID) + if err != nil { + return nil, nil, err } - // Message limit is, according to the RFC, 2^95+1, which - // is somewhat confusing, but we do as we're told. - if s.seqNum.bitLen() >= (s.aead.NonceSize()*8)-1 { - panic("message limit reached") + sharedSecret, encapsulatedKey, err := kem.Encap(pub) + if err != nil { + return nil, nil, err + } + + context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info) + if err != nil { + return nil, nil, err + } + + return encapsulatedKey, &Sender{context}, nil +} + +func SetupReceipient(kemID, kdfID, aeadID uint16, priv *ecdh.PrivateKey, info, encPubEph []byte) (*Receipient, error) { + kem, err := newDHKem(kemID) + if err != nil { + return nil, err + } + sharedSecret, err := kem.Decap(encPubEph, priv) + if err != nil { + return nil, err + } + + context, err := newContext(sharedSecret, kemID, kdfID, aeadID, info) + if err != nil { + return nil, err + } + + return &Receipient{context}, nil +} + +func (ctx *context) nextNonce() []byte { + nonce := ctx.seqNum.bytes()[16-ctx.aead.NonceSize():] + for i := range ctx.baseNonce { + nonce[i] ^= ctx.baseNonce[i] } - s.seqNum = s.seqNum.addOne() return nonce } -func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) { +func (ctx *context) incrementNonce() { + // Message limit is, according to the RFC, 2^95+1, which + // is somewhat confusing, but we do as we're told. + if ctx.seqNum.bitLen() >= (ctx.aead.NonceSize()*8)-1 { + panic("message limit reached") + } + ctx.seqNum = ctx.seqNum.addOne() +} +func (s *Sender) Seal(aad, plaintext []byte) ([]byte, error) { ciphertext := s.aead.Seal(nil, s.nextNonce(), plaintext, aad) + s.incrementNonce() return ciphertext, nil } -func SuiteID(kemID, kdfID, aeadID uint16) []byte { +func (r *Receipient) Open(aad, ciphertext []byte) ([]byte, error) { + plaintext, err := r.aead.Open(nil, r.nextNonce(), ciphertext, aad) + if err != nil { + return nil, err + } + r.incrementNonce() + return plaintext, nil +} + +func suiteID(kemID, kdfID, aeadID uint16) []byte { suiteID := make([]byte, 0, 4+2+2+2) suiteID = append(suiteID, []byte("HPKE")...) suiteID = byteorder.BeAppendUint16(suiteID, kemID) @@ -238,6 +309,14 @@ func ParseHPKEPublicKey(kemID uint16, bytes []byte) (*ecdh.PublicKey, error) { return kemInfo.curve.NewPublicKey(bytes) } +func ParseHPKEPrivateKey(kemID uint16, bytes []byte) (*ecdh.PrivateKey, error) { + kemInfo, ok := SupportedKEMs[kemID] + if !ok { + return nil, errors.New("unsupported KEM id") + } + return kemInfo.curve.NewPrivateKey(bytes) +} + type uint128 struct { hi, lo uint64 } diff --git a/src/crypto/internal/hpke/hpke_test.go b/src/crypto/internal/hpke/hpke_test.go index dbdfd7a80a3..51beeed2128 100644 --- a/src/crypto/internal/hpke/hpke_test.go +++ b/src/crypto/internal/hpke/hpke_test.go @@ -104,7 +104,7 @@ func TestRFC9180Vectors(t *testing.T) { } t.Cleanup(func() { testingOnlyGenerateKey = nil }) - encap, context, err := SetupSender( + encap, sender, err := SetupSender( uint16(kemID), uint16(kdfID), uint16(aeadID), @@ -119,21 +119,42 @@ func TestRFC9180Vectors(t *testing.T) { if !bytes.Equal(encap, expectedEncap) { t.Errorf("unexpected encapsulated key, got: %x, want %x", encap, expectedEncap) } - expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"]) - if !bytes.Equal(context.sharedSecret, expectedSharedSecret) { - t.Errorf("unexpected shared secret, got: %x, want %x", context.sharedSecret, expectedSharedSecret) + + privKeyBytes := mustDecodeHex(t, setup["skRm"]) + priv, err := ParseHPKEPrivateKey(uint16(kemID), privKeyBytes) + if err != nil { + t.Fatal(err) } - expectedKey := mustDecodeHex(t, setup["key"]) - if !bytes.Equal(context.key, expectedKey) { - t.Errorf("unexpected key, got: %x, want %x", context.key, expectedKey) + + receipient, err := SetupReceipient( + uint16(kemID), + uint16(kdfID), + uint16(aeadID), + priv, + info, + encap, + ) + if err != nil { + t.Fatal(err) } - expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"]) - if !bytes.Equal(context.baseNonce, expectedBaseNonce) { - t.Errorf("unexpected base nonce, got: %x, want %x", context.baseNonce, expectedBaseNonce) - } - expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"]) - if !bytes.Equal(context.exporterSecret, expectedExporterSecret) { - t.Errorf("unexpected exporter secret, got: %x, want %x", context.exporterSecret, expectedExporterSecret) + + for _, ctx := range []*context{sender.context, receipient.context} { + expectedSharedSecret := mustDecodeHex(t, setup["shared_secret"]) + if !bytes.Equal(ctx.sharedSecret, expectedSharedSecret) { + t.Errorf("unexpected shared secret, got: %x, want %x", ctx.sharedSecret, expectedSharedSecret) + } + expectedKey := mustDecodeHex(t, setup["key"]) + if !bytes.Equal(ctx.key, expectedKey) { + t.Errorf("unexpected key, got: %x, want %x", ctx.key, expectedKey) + } + expectedBaseNonce := mustDecodeHex(t, setup["base_nonce"]) + if !bytes.Equal(ctx.baseNonce, expectedBaseNonce) { + t.Errorf("unexpected base nonce, got: %x, want %x", ctx.baseNonce, expectedBaseNonce) + } + expectedExporterSecret := mustDecodeHex(t, setup["exporter_secret"]) + if !bytes.Equal(ctx.exporterSecret, expectedExporterSecret) { + t.Errorf("unexpected exporter secret, got: %x, want %x", ctx.exporterSecret, expectedExporterSecret) + } } for _, enc := range parseVectorEncryptions(vector.Encryptions) { @@ -142,26 +163,31 @@ func TestRFC9180Vectors(t *testing.T) { if err != nil { t.Fatal(err) } - context.seqNum = uint128{lo: uint64(seqNum)} + sender.seqNum = uint128{lo: uint64(seqNum)} + receipient.seqNum = uint128{lo: uint64(seqNum)} expectedNonce := mustDecodeHex(t, enc["nonce"]) - // We can't call nextNonce, because it increments the sequence number, - // so just compute it directly. - computedNonce := context.seqNum.bytes()[16-context.aead.NonceSize():] - for i := range context.baseNonce { - computedNonce[i] ^= context.baseNonce[i] - } + computedNonce := sender.nextNonce() if !bytes.Equal(computedNonce, expectedNonce) { t.Errorf("unexpected nonce: got %x, want %x", computedNonce, expectedNonce) } expectedCiphertext := mustDecodeHex(t, enc["ct"]) - ciphertext, err := context.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"])) + ciphertext, err := sender.Seal(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["pt"])) if err != nil { t.Fatal(err) } if !bytes.Equal(ciphertext, expectedCiphertext) { t.Errorf("unexpected ciphertext: got %x want %x", ciphertext, expectedCiphertext) } + + expectedPlaintext := mustDecodeHex(t, enc["pt"]) + plaintext, err := receipient.Open(mustDecodeHex(t, enc["aad"]), mustDecodeHex(t, enc["ct"])) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(plaintext, expectedPlaintext) { + t.Errorf("unexpected plaintext: got %x want %x", plaintext, expectedPlaintext) + } }) } })