1
0
mirror of https://github.com/golang/go synced 2024-11-21 22:54:40 -07:00

crypto/rsa: port PrivateKey.Validate to bigmod, add validations

This patch ports the implementation of PrivateKey.Validate to use the
bigmod math library, ensuring that the arithmetic operations happen in
constant time. A few new APIs have been added to bigmod to add
operations which don't explicitly require modulus arithmetic, but do
take the modulus size into account to ensure we don't leak any
non-public information.

In addition to porting this routine to use bigmod this patch also adds a
few more steps to the validation as defined by NIST SP 800-56B REV. 2
Section 6.4.1.4.3.

For #69536
This commit is contained in:
Derek Parker 2024-11-06 17:01:28 -08:00
parent 840ac5e037
commit 570721ddf5
5 changed files with 282 additions and 62 deletions

View File

@ -327,9 +327,9 @@ func signNISTEC[Point nistPoint[Point]](c *nistCurve[Point], priv *PrivateKey, c
if err != nil {
return nil, err
}
s.Mul(r, c.N)
s.MulMod(r, c.N)
s.Add(e, c.N)
s.Mul(kInv, c.N)
s.MulMod(kInv, c.N)
// Again, the chance of this happening is cryptographically negligible.
if s.IsZero() == 1 {
@ -528,12 +528,12 @@ func verifyNISTEC[Point nistPoint[Point]](c *nistCurve[Point], pub *PublicKey, h
inverse(c, w, s)
// p₁ = [e * s⁻¹]G
p1, err := c.newPoint().ScalarBaseMult(e.Mul(w, c.N).Bytes(c.N))
p1, err := c.newPoint().ScalarBaseMult(e.MulMod(w, c.N).Bytes(c.N))
if err != nil {
return false
}
// p₂ = [r * s⁻¹]Q
p2, err := Q.ScalarMult(Q, w.Mul(r, c.N).Bytes(c.N))
p2, err := Q.ScalarMult(Q, w.MulMod(r, c.N).Bytes(c.N))
if err != nil {
return false
}

View File

@ -92,8 +92,8 @@ func (x *Nat) reset(n int) *Nat {
return x
}
// set assigns x = y, optionally resizing x to the appropriate size.
func (x *Nat) set(y *Nat) *Nat {
// Set assigns x = y, optionally resizing x to the appropriate size.
func (x *Nat) Set(y *Nat) *Nat {
x.reset(len(y.limbs))
copy(x.limbs, y.limbs)
return x
@ -226,6 +226,29 @@ func (x *Nat) IsZero() choice {
//
// Both operands must have the same announced length.
func (x *Nat) cmpGeq(y *Nat) choice {
c := x.subCarry(y)
// If there was a carry, then subtracting y underflowed, so
// x is not greater than or equal to y.
return not(choice(c))
}
// Cmp compares x and y and returns the result of the compare:
// 1 if x > y
// 0 if x == y
// -1 if x < y
func (x *Nat) Cmp(y *Nat) int {
if x.Equal(y) == yes {
return 0
}
c := x.subCarry(y)
res := 1
if c > 0 {
res = -1
}
return res
}
func (x *Nat) subCarry(y *Nat) uint {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
@ -235,9 +258,7 @@ func (x *Nat) cmpGeq(y *Nat) choice {
for i := 0; i < size; i++ {
_, c = bits.Sub(xLimbs[i], yLimbs[i], c)
}
// If there was a carry, then subtracting y underflowed, so
// x is not greater than or equal to y.
return not(choice(c))
return c
}
// assign sets x <- y if on == 1, and does nothing otherwise.
@ -274,7 +295,7 @@ func (x *Nat) add(y *Nat) (c uint) {
// sub computes x -= y. It returns the borrow of the subtraction.
//
// Both operands must have the same announced length.
func (x *Nat) sub(y *Nat) (c uint) {
func (x *Nat) Sub(y *Nat) (c uint) {
// Eliminate bounds checks in the loop.
size := len(x.limbs)
xLimbs := x.limbs[:size]
@ -301,6 +322,7 @@ type Modulus struct {
leading int // number of leading zeros in the modulus
m0inv uint // -nat.limbs[0]⁻¹ mod _W
rr *Nat // R*R for montgomeryRepresentation
even bool
}
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
@ -374,16 +396,18 @@ func minusInverseModW(x uint) uint {
// The Int must be odd. The number of significant bits (and nothing else) is
// leaked through timing side-channels.
func NewModulusFromBig(n *big.Int) (*Modulus, error) {
if b := n.Bits(); len(b) == 0 {
b := n.Bits()
if len(b) == 0 {
return nil, errors.New("modulus must be >= 0")
} else if b[0]&1 != 1 {
return nil, errors.New("modulus must be odd")
}
m := &Modulus{}
m.even = b[0]&1 != 1
m.nat = NewNat().setBig(n)
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
m.m0inv = minusInverseModW(m.nat.limbs[0])
m.rr = rr(m)
if !m.even {
m.m0inv = minusInverseModW(m.nat.limbs[0])
m.rr = rr(m)
}
return m, nil
}
@ -508,8 +532,8 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
//
// x and m operands must have the same announced length.
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
t := NewNat().set(x)
underflow := t.sub(m.nat)
t := NewNat().Set(x)
underflow := t.Sub(m.nat)
// We keep the result if x - m didn't underflow (meaning x >= m)
// or if always was set.
keep := not(choice(underflow)) | choice(always)
@ -520,10 +544,10 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
underflow := x.sub(y)
func (x *Nat) SubMod(y *Nat, m *Modulus) *Nat {
underflow := x.Sub(y)
// If the subtraction underflowed, add m.
t := NewNat().set(x)
t := NewNat().Set(x)
t.add(m.nat)
x.assign(choice(underflow), t)
return x
@ -571,6 +595,9 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
// All inputs should be the same length and already reduced modulo m.
// x will be resized to the size of m and overwritten.
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
if m.even {
panic("crypto/rsa: montgomery multiplication on even modulus")
}
n := len(m.nat.limbs)
mLimbs := m.nat.limbs[:n]
aLimbs := a.limbs[:n]
@ -707,17 +734,40 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
return carry
}
// Mul calculates x = x * y mod m.
// MulMod calculates x = x * y mod m.
//
// The length of both operands must be the same as the modulus. Both operands
// must already be reduced modulo m.
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
func (x *Nat) MulMod(y *Nat, m *Modulus) *Nat {
// A Montgomery multiplication by a value out of the Montgomery domain
// takes the result out of Montgomery representation.
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
}
// Mul calculates z = x * y.
//
// All inputs should be the same length and already reduced modulo m.
// z will be resized to the size of m and overwritten.
func (z *Nat) Mul(x *Nat, y *Nat, m *Modulus) *Nat {
n := len(m.nat.limbs)
zLimbs := z.resetFor(m).limbs
xLimbs := x.limbs
yLimbs := y.limbs
switch n {
default:
for i := 0; i < n; i++ {
addMulVVW(zLimbs[i:], xLimbs, yLimbs[i])
}
case 2048 / _W:
const n = 2048 / _W // compiler hint
for i := 0; i < n; i++ {
addMulVVW2048(&zLimbs[i:][0], &xLimbs[0], yLimbs[i])
}
}
return z
}
// Exp calculates out = x^e mod m.
//
// The exponent e is represented in big-endian order. The output will be resized
@ -734,7 +784,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
}
table[0].set(x).montgomeryRepresentation(m)
table[0].Set(x).montgomeryRepresentation(m)
for i := 1; i < len(table); i++ {
table[i].montgomeryMul(table[i-1], table[0], m)
}
@ -775,8 +825,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
// For short exponents, precomputing a table and using a window like in Exp
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
// chain, skipping the initial run of zeroes.
xR := NewNat().set(x).montgomeryRepresentation(m)
out.set(xR)
xR := NewNat().Set(x).montgomeryRepresentation(m)
out.Set(xR)
for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ {
out.montgomeryMul(out, out, m)
if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {

View File

@ -35,9 +35,9 @@ func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
func testModAddCommutative(a *Nat, b *Nat) bool {
m := maxModulus(uint(len(a.limbs)))
aPlusB := new(Nat).set(a)
aPlusB := new(Nat).Set(a)
aPlusB.Add(b, m)
bPlusA := new(Nat).set(b)
bPlusA := new(Nat).Set(b)
bPlusA.Add(a, m)
return aPlusB.Equal(bPlusA) == 1
}
@ -51,8 +51,8 @@ func TestModAddCommutative(t *testing.T) {
func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
m := maxModulus(uint(len(a.limbs)))
original := new(Nat).set(a)
a.Sub(b, m)
original := new(Nat).Set(a)
a.SubMod(b, m)
a.Add(b, m)
return a.Equal(original) == 1
}
@ -71,9 +71,9 @@ func TestMontgomeryRoundtrip(t *testing.T) {
aPlusOne := new(big.Int).SetBytes(natBytes(a))
aPlusOne.Add(aPlusOne, big.NewInt(1))
m, _ := NewModulusFromBig(aPlusOne)
monty := new(Nat).set(a)
monty := new(Nat).Set(a)
monty.montgomeryRepresentation(m)
aAgain := new(Nat).set(monty)
aAgain := new(Nat).Set(monty)
aAgain.montgomeryMul(monty, one, m)
if a.Equal(aAgain) != 1 {
t.Errorf("%v != %v", a, aAgain)
@ -260,12 +260,12 @@ func TestModSub(t *testing.T) {
m := modulusFromBytes([]byte{13})
x := &Nat{[]uint{6}}
y := &Nat{[]uint{7}}
x.Sub(y, m)
x.SubMod(y, m)
expected := &Nat{[]uint{12}}
if x.Equal(expected) != 1 {
t.Errorf("%+v != %+v", x, expected)
}
x.Sub(y, m)
x.SubMod(y, m)
expected = &Nat{[]uint{5}}
if x.Equal(expected) != 1 {
t.Errorf("%+v != %+v", x, expected)
@ -323,7 +323,7 @@ func TestMulReductions(t *testing.T) {
A := NewNat().setBig(a).ExpandFor(N)
B := NewNat().setBig(b).ExpandFor(N)
if A.Mul(B, N).IsZero() != 1 {
if A.MulMod(B, N).IsZero() != 1 {
t.Error("a * b mod (a * b) != 0")
}
@ -333,7 +333,7 @@ func TestMulReductions(t *testing.T) {
I := NewNat().setBig(i).ExpandFor(N)
one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
if A.Mul(I, N).Equal(one) != 1 {
if A.MulMod(I, N).Equal(one) != 1 {
t.Error("a * inv(a) mod b != 1")
}
}
@ -401,7 +401,7 @@ func BenchmarkModSub(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
x.Sub(y, m)
x.SubMod(y, m)
}
}
@ -434,7 +434,7 @@ func BenchmarkModMul(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
x.Mul(y, m)
x.MulMod(y, m)
}
}
@ -472,9 +472,38 @@ func TestNewModFromBigZero(t *testing.T) {
t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
}
expected = "modulus must be odd"
_, err = NewModulusFromBig(big.NewInt(2))
if err == nil || err.Error() != expected {
t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected)
defer func(t *testing.T) {
if r := recover(); r != nil {
if s, ok := r.(string); !ok || !strings.Contains(s, "montgomery multiplication on even modulus") {
t.Errorf("Unexpected panic: %#v", r)
}
} else {
t.Error("Expected panic to be recovered, got nothing.")
}
}(t)
m, err := NewModulusFromBig(big.NewInt(10))
if err != nil {
t.Errorf("NewModulusFromBig(2) got %q, want %q", err, "")
}
x := NewNat().setBig(big.NewInt(1))
y := NewNat().setBig(big.NewInt(2))
x.MulMod(y, m)
}
func TestNatCmp(t *testing.T) {
testcases := [][3]int64{
{33, 22, 1},
{33, 33, 0},
{22, 33, -1},
}
for _, tc := range testcases {
a := new(big.Int).SetInt64(tc[0])
b := new(big.Int).SetInt64(tc[1])
na := natFromBytes(a.Bytes())
nb := natFromBytes(b.Bytes())
if res, _ := na.Cmp(nb); res != int(tc[2]) {
t.Errorf("expected %d got %d for (%d).cmp(%d)", tc[2], res, tc[0], tc[1])
}
}
}

View File

@ -32,6 +32,7 @@ import (
"crypto/internal/randutil"
"crypto/rand"
"crypto/subtle"
"encoding/binary"
"errors"
"hash"
"io"
@ -236,31 +237,149 @@ func (priv *PrivateKey) Validate() error {
}
// Check that Πprimes == n.
modulus := new(big.Int).Set(bigOne)
N, err := bigmod.NewModulusFromBig(priv.N)
if err != nil {
return err
}
bigOneNat, err := bigmod.NewNat().SetBytes(bigOne.Bytes(), N)
if err != nil {
return err
}
modulus, err := bigmod.NewNat().SetBytes(bigOne.Bytes(), N)
if err != nil {
return err
}
for _, prime := range priv.Primes {
// Any primes ≤ 1 will cause divide-by-zero panics later.
if prime.Cmp(bigOne) <= 0 {
nprime, err := bigmod.NewNat().SetBytes(prime.Bytes(), N)
if err != nil {
return err
}
if d := nprime.Cmp(bigOneNat); d <= 0 {
return errors.New("crypto/rsa: invalid prime value")
}
modulus.Mul(modulus, prime)
modulus.MulMod(nprime, N)
}
if modulus.Cmp(priv.N) != 0 {
if modulus.Equal(N.Nat()) != 0 {
return errors.New("crypto/rsa: invalid modulus")
}
// Check that de ≡ 1 mod p-1, for each prime.
// This implies that e is coprime to each p-1 as e has a multiplicative
// inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) =
// exponent(/n). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1
// mod p. Thus a^de ≡ a mod n for all a coprime to n, as required.
congruence := new(big.Int)
de := new(big.Int).SetInt64(int64(priv.E))
de.Mul(de, priv.D)
for _, prime := range priv.Primes {
pminus1 := new(big.Int).Sub(prime, bigOne)
congruence.Mod(de, pminus1)
if congruence.Cmp(bigOne) != 0 {
return errors.New("crypto/rsa: invalid exponents")
// NIST SP 800-56B REV.2 6.4.1.1 3.a
// Key-pair consistency verifying that m = (m^e)^d mod n for some integer m satisfying 1 < m < (n 1).
m, err := bigmod.NewNat().SetBytes([]byte{2}, N)
ebytes := make([]byte, 8)
binary.BigEndian.PutUint64(ebytes, uint64(priv.E))
mm := bigmod.NewNat().Exp(m, ebytes, N)
mm.Exp(mm, priv.D.Bytes(), N)
if mm.Equal(m) != 1 {
return errors.New("crypto/rsa: key-pair consistency check failed")
}
if len(priv.Primes) > 2 {
// Check that de ≡ 1 mod p-1, for each prime.
// This implies that e is coprime to each p-1 as e has a multiplicative
// inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) =
// exponent(/n). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1
// mod p. Thus a^de ≡ a mod n for all a coprime to n, as required.
congruence := new(big.Int)
de := new(big.Int).SetInt64(int64(priv.E))
de.Mul(de, priv.D)
for _, prime := range priv.Primes {
pminus1 := new(big.Int).Sub(prime, bigOne)
congruence.Mod(de, pminus1)
if congruence.Cmp(bigOne) != 0 {
return errors.New("crypto/rsa: invalid exponents")
}
}
} else {
// 6.4.1.2.1.D rsakpv1-crt
//
// (npub == p × q)
pBytes := priv.Primes[0].Bytes()
qBytes := priv.Primes[1].Bytes()
p, err := bigmod.NewNat().SetBytes(pBytes, N)
if err != nil {
return err
}
q, err := bigmod.NewNat().SetBytes(qBytes, N)
if err != nil {
return err
}
product := bigmod.NewNat().Mul(p, q, N)
if r := product.Cmp(N.Nat()); r != 0 {
return errors.New("crypto/rsa: invalid RSA key pair")
}
// 6.4.1.2.1.F rsakpv1-crt
priv.Precompute()
// Step a: 1 < dP < (p 1).
pminus1 := bigmod.NewNat().Set(p)
pminus1.Sub(bigOneNat)
pminus1big := new(big.Int).SetBytes(pminus1.Bytes(N))
pminus1mod, err := bigmod.NewModulusFromBig(pminus1big)
dP, err := bigmod.NewNat().SetBytes(priv.Precomputed.Dp.Bytes(), N)
if err != nil {
return err
}
// dP := bigmod.NewNat().Mod(d, pminus1mod)
res1 := bigOneNat.Cmp(dP)
res2 := dP.Cmp(pminus1)
if res1 != -1 || res2 != -1 {
return errors.New("crypto/rsa: (step A) invalid RSA key pair")
}
// Step b: 1 < dQ < (q 1).
dQ, err := bigmod.NewNat().SetBytes(priv.Precomputed.Dq.Bytes(), N)
if err != nil {
return err
}
qminus1 := bigmod.NewNat().Set(q)
qminus1.Sub(bigOneNat)
qminus1big := new(big.Int).SetBytes(qminus1.Bytes(N))
qminus1mod, err := bigmod.NewModulusFromBig(qminus1big)
res1 = bigOneNat.Cmp(dQ)
res2 = dQ.Cmp(qminus1)
if res1 != -1 || res2 != -1 {
return errors.New("crypto/rsa: invalid RSA key pair")
}
// Step c: 1 < qInv < p.
qInv, err := bigmod.NewNat().SetBytes(priv.Precomputed.Qinv.Bytes(), N)
if err != nil {
return err
}
res1 = bigOneNat.Cmp(qInv)
res2 = qInv.Cmp(p)
if res1 != -1 || res2 != -1 {
return errors.New("crypto/rsa: invalid RSA key pair")
}
// Step d: 1 = (dP × epub) mod (p 1).
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(priv.E))
epub, err := bigmod.NewNat().SetBytes(buf, N)
if err != nil {
return err
}
if bigmod.NewNat().Mod(
bigmod.NewNat().Mul(dP, epub, N),
pminus1mod).Equal(bigOneNat) != 1 {
return errors.New("crypto/rsa: invalid RSA key pair")
}
// Step e: 1 = (dQ × epub) mod (q 1).
if bigmod.NewNat().Mod(
bigmod.NewNat().Mul(dQ, epub, N),
qminus1mod).Equal(bigOneNat) != 1 {
return errors.New("crypto/rsa: invalid RSA key pair")
}
// Step f: 1 = (qInv × q) mod p.
pmod, err := bigmod.NewModulusFromBig(priv.Primes[0])
if err != nil {
return err
}
if bigmod.NewNat().Mod(
bigmod.NewNat().Mul(qInv, q, N),
pmod).Equal(bigOneNat) != 1 {
return errors.New("crypto/rsa: invalid RSA key pair")
}
}
return nil
@ -675,11 +794,11 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
// m2 = c ^ Dq mod q
m2 := bigmod.NewNat().Exp(t0.Mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
// m = m - m2 mod p
m.Sub(t0.Mod(m2, P), P)
m.SubMod(t0.Mod(m2, P), P)
// m = m * Qinv mod p
m.Mul(Qinv, P)
m.MulMod(Qinv, P)
// m = m * q mod N
m.ExpandFor(N).Mul(t0.Mod(Q.Nat(), N), N)
m.ExpandFor(N).MulMod(t0.Mod(Q.Nat(), N), N)
// m = m + m2 mod N
m.Add(m2.ExpandFor(N), N)
}

View File

@ -621,6 +621,28 @@ func BenchmarkVerifyPSS(b *testing.B) {
})
}
func BenchmarkPrecompute(b *testing.B) {
b.Run("2048", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
k := *test2048Key
k.Precomputed = PrecomputedValues{}
k.Precompute()
}
})
}
func BenchmarkValidate(b *testing.B) {
b.Run("2048", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
if err := test2048Key.Validate(); err != nil {
b.Fatal(err)
}
}
})
}
type testEncryptOAEPMessage struct {
in []byte
seed []byte