From 570721ddf553bb2d5724549efc35226b17e6f7d9 Mon Sep 17 00:00:00 2001 From: Derek Parker Date: Wed, 6 Nov 2024 17:01:28 -0800 Subject: [PATCH] crypto/rsa: port PrivateKey.Validate to bigmod, add validations This patch ports the implementation of PrivateKey.Validate to use the bigmod math library, ensuring that the arithmetic operations happen in constant time. A few new APIs have been added to bigmod to add operations which don't explicitly require modulus arithmetic, but do take the modulus size into account to ensure we don't leak any non-public information. In addition to porting this routine to use bigmod this patch also adds a few more steps to the validation as defined by NIST SP 800-56B REV. 2 Section 6.4.1.4.3. For #69536 --- src/crypto/ecdsa/ecdsa.go | 8 +- src/crypto/internal/bigmod/nat.go | 94 +++++++++++---- src/crypto/internal/bigmod/nat_test.go | 61 +++++++--- src/crypto/rsa/rsa.go | 159 +++++++++++++++++++++---- src/crypto/rsa/rsa_test.go | 22 ++++ 5 files changed, 282 insertions(+), 62 deletions(-) diff --git a/src/crypto/ecdsa/ecdsa.go b/src/crypto/ecdsa/ecdsa.go index 2179b01e8e3..9a4ea36d6a1 100644 --- a/src/crypto/ecdsa/ecdsa.go +++ b/src/crypto/ecdsa/ecdsa.go @@ -327,9 +327,9 @@ func signNISTEC[Point nistPoint[Point]](c *nistCurve[Point], priv *PrivateKey, c if err != nil { return nil, err } - s.Mul(r, c.N) + s.MulMod(r, c.N) s.Add(e, c.N) - s.Mul(kInv, c.N) + s.MulMod(kInv, c.N) // Again, the chance of this happening is cryptographically negligible. if s.IsZero() == 1 { @@ -528,12 +528,12 @@ func verifyNISTEC[Point nistPoint[Point]](c *nistCurve[Point], pub *PublicKey, h inverse(c, w, s) // p₁ = [e * s⁻¹]G - p1, err := c.newPoint().ScalarBaseMult(e.Mul(w, c.N).Bytes(c.N)) + p1, err := c.newPoint().ScalarBaseMult(e.MulMod(w, c.N).Bytes(c.N)) if err != nil { return false } // p₂ = [r * s⁻¹]Q - p2, err := Q.ScalarMult(Q, w.Mul(r, c.N).Bytes(c.N)) + p2, err := Q.ScalarMult(Q, w.MulMod(r, c.N).Bytes(c.N)) if err != nil { return false } diff --git a/src/crypto/internal/bigmod/nat.go b/src/crypto/internal/bigmod/nat.go index 5cbae40efe9..0982730ce29 100644 --- a/src/crypto/internal/bigmod/nat.go +++ b/src/crypto/internal/bigmod/nat.go @@ -92,8 +92,8 @@ func (x *Nat) reset(n int) *Nat { return x } -// set assigns x = y, optionally resizing x to the appropriate size. -func (x *Nat) set(y *Nat) *Nat { +// Set assigns x = y, optionally resizing x to the appropriate size. +func (x *Nat) Set(y *Nat) *Nat { x.reset(len(y.limbs)) copy(x.limbs, y.limbs) return x @@ -226,6 +226,29 @@ func (x *Nat) IsZero() choice { // // Both operands must have the same announced length. func (x *Nat) cmpGeq(y *Nat) choice { + c := x.subCarry(y) + // If there was a carry, then subtracting y underflowed, so + // x is not greater than or equal to y. + return not(choice(c)) +} + +// Cmp compares x and y and returns the result of the compare: +// 1 if x > y +// 0 if x == y +// -1 if x < y +func (x *Nat) Cmp(y *Nat) int { + if x.Equal(y) == yes { + return 0 + } + c := x.subCarry(y) + res := 1 + if c > 0 { + res = -1 + } + return res +} + +func (x *Nat) subCarry(y *Nat) uint { // Eliminate bounds checks in the loop. size := len(x.limbs) xLimbs := x.limbs[:size] @@ -235,9 +258,7 @@ func (x *Nat) cmpGeq(y *Nat) choice { for i := 0; i < size; i++ { _, c = bits.Sub(xLimbs[i], yLimbs[i], c) } - // If there was a carry, then subtracting y underflowed, so - // x is not greater than or equal to y. - return not(choice(c)) + return c } // assign sets x <- y if on == 1, and does nothing otherwise. @@ -274,7 +295,7 @@ func (x *Nat) add(y *Nat) (c uint) { // sub computes x -= y. It returns the borrow of the subtraction. // // Both operands must have the same announced length. -func (x *Nat) sub(y *Nat) (c uint) { +func (x *Nat) Sub(y *Nat) (c uint) { // Eliminate bounds checks in the loop. size := len(x.limbs) xLimbs := x.limbs[:size] @@ -301,6 +322,7 @@ type Modulus struct { leading int // number of leading zeros in the modulus m0inv uint // -nat.limbs[0]⁻¹ mod _W rr *Nat // R*R for montgomeryRepresentation + even bool } // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs). @@ -374,16 +396,18 @@ func minusInverseModW(x uint) uint { // The Int must be odd. The number of significant bits (and nothing else) is // leaked through timing side-channels. func NewModulusFromBig(n *big.Int) (*Modulus, error) { - if b := n.Bits(); len(b) == 0 { + b := n.Bits() + if len(b) == 0 { return nil, errors.New("modulus must be >= 0") - } else if b[0]&1 != 1 { - return nil, errors.New("modulus must be odd") } m := &Modulus{} + m.even = b[0]&1 != 1 m.nat = NewNat().setBig(n) m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1]) - m.m0inv = minusInverseModW(m.nat.limbs[0]) - m.rr = rr(m) + if !m.even { + m.m0inv = minusInverseModW(m.nat.limbs[0]) + m.rr = rr(m) + } return m, nil } @@ -508,8 +532,8 @@ func (out *Nat) resetFor(m *Modulus) *Nat { // // x and m operands must have the same announced length. func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { - t := NewNat().set(x) - underflow := t.sub(m.nat) + t := NewNat().Set(x) + underflow := t.Sub(m.nat) // We keep the result if x - m didn't underflow (meaning x >= m) // or if always was set. keep := not(choice(underflow)) | choice(always) @@ -520,10 +544,10 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { // // The length of both operands must be the same as the modulus. Both operands // must already be reduced modulo m. -func (x *Nat) Sub(y *Nat, m *Modulus) *Nat { - underflow := x.sub(y) +func (x *Nat) SubMod(y *Nat, m *Modulus) *Nat { + underflow := x.Sub(y) // If the subtraction underflowed, add m. - t := NewNat().set(x) + t := NewNat().Set(x) t.add(m.nat) x.assign(choice(underflow), t) return x @@ -571,6 +595,9 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat { // All inputs should be the same length and already reduced modulo m. // x will be resized to the size of m and overwritten. func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { + if m.even { + panic("crypto/rsa: montgomery multiplication on even modulus") + } n := len(m.nat.limbs) mLimbs := m.nat.limbs[:n] aLimbs := a.limbs[:n] @@ -707,17 +734,40 @@ func addMulVVW(z, x []uint, y uint) (carry uint) { return carry } -// Mul calculates x = x * y mod m. +// MulMod calculates x = x * y mod m. // // The length of both operands must be the same as the modulus. Both operands // must already be reduced modulo m. -func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { +func (x *Nat) MulMod(y *Nat, m *Modulus) *Nat { // A Montgomery multiplication by a value out of the Montgomery domain // takes the result out of Montgomery representation. - xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m + xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m } +// Mul calculates z = x * y. +// +// All inputs should be the same length and already reduced modulo m. +// z will be resized to the size of m and overwritten. +func (z *Nat) Mul(x *Nat, y *Nat, m *Modulus) *Nat { + n := len(m.nat.limbs) + zLimbs := z.resetFor(m).limbs + xLimbs := x.limbs + yLimbs := y.limbs + switch n { + default: + for i := 0; i < n; i++ { + addMulVVW(zLimbs[i:], xLimbs, yLimbs[i]) + } + case 2048 / _W: + const n = 2048 / _W // compiler hint + for i := 0; i < n; i++ { + addMulVVW2048(&zLimbs[i:][0], &xLimbs[0], yLimbs[i]) + } + } + return z +} + // Exp calculates out = x^e mod m. // // The exponent e is represented in big-endian order. The output will be resized @@ -734,7 +784,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), } - table[0].set(x).montgomeryRepresentation(m) + table[0].Set(x).montgomeryRepresentation(m) for i := 1; i < len(table); i++ { table[i].montgomeryMul(table[i-1], table[0], m) } @@ -775,8 +825,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat { // For short exponents, precomputing a table and using a window like in Exp // doesn't pay off. Instead, we do a simple conditional square-and-multiply // chain, skipping the initial run of zeroes. - xR := NewNat().set(x).montgomeryRepresentation(m) - out.set(xR) + xR := NewNat().Set(x).montgomeryRepresentation(m) + out.Set(xR) for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ { out.montgomeryMul(out, out, m) if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 { diff --git a/src/crypto/internal/bigmod/nat_test.go b/src/crypto/internal/bigmod/nat_test.go index 7a956e3a57d..d3ac157762d 100644 --- a/src/crypto/internal/bigmod/nat_test.go +++ b/src/crypto/internal/bigmod/nat_test.go @@ -35,9 +35,9 @@ func (*Nat) Generate(r *rand.Rand, size int) reflect.Value { func testModAddCommutative(a *Nat, b *Nat) bool { m := maxModulus(uint(len(a.limbs))) - aPlusB := new(Nat).set(a) + aPlusB := new(Nat).Set(a) aPlusB.Add(b, m) - bPlusA := new(Nat).set(b) + bPlusA := new(Nat).Set(b) bPlusA.Add(a, m) return aPlusB.Equal(bPlusA) == 1 } @@ -51,8 +51,8 @@ func TestModAddCommutative(t *testing.T) { func testModSubThenAddIdentity(a *Nat, b *Nat) bool { m := maxModulus(uint(len(a.limbs))) - original := new(Nat).set(a) - a.Sub(b, m) + original := new(Nat).Set(a) + a.SubMod(b, m) a.Add(b, m) return a.Equal(original) == 1 } @@ -71,9 +71,9 @@ func TestMontgomeryRoundtrip(t *testing.T) { aPlusOne := new(big.Int).SetBytes(natBytes(a)) aPlusOne.Add(aPlusOne, big.NewInt(1)) m, _ := NewModulusFromBig(aPlusOne) - monty := new(Nat).set(a) + monty := new(Nat).Set(a) monty.montgomeryRepresentation(m) - aAgain := new(Nat).set(monty) + aAgain := new(Nat).Set(monty) aAgain.montgomeryMul(monty, one, m) if a.Equal(aAgain) != 1 { t.Errorf("%v != %v", a, aAgain) @@ -260,12 +260,12 @@ func TestModSub(t *testing.T) { m := modulusFromBytes([]byte{13}) x := &Nat{[]uint{6}} y := &Nat{[]uint{7}} - x.Sub(y, m) + x.SubMod(y, m) expected := &Nat{[]uint{12}} if x.Equal(expected) != 1 { t.Errorf("%+v != %+v", x, expected) } - x.Sub(y, m) + x.SubMod(y, m) expected = &Nat{[]uint{5}} if x.Equal(expected) != 1 { t.Errorf("%+v != %+v", x, expected) @@ -323,7 +323,7 @@ func TestMulReductions(t *testing.T) { A := NewNat().setBig(a).ExpandFor(N) B := NewNat().setBig(b).ExpandFor(N) - if A.Mul(B, N).IsZero() != 1 { + if A.MulMod(B, N).IsZero() != 1 { t.Error("a * b mod (a * b) != 0") } @@ -333,7 +333,7 @@ func TestMulReductions(t *testing.T) { I := NewNat().setBig(i).ExpandFor(N) one := NewNat().setBig(big.NewInt(1)).ExpandFor(N) - if A.Mul(I, N).Equal(one) != 1 { + if A.MulMod(I, N).Equal(one) != 1 { t.Error("a * inv(a) mod b != 1") } } @@ -401,7 +401,7 @@ func BenchmarkModSub(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - x.Sub(y, m) + x.SubMod(y, m) } } @@ -434,7 +434,7 @@ func BenchmarkModMul(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - x.Mul(y, m) + x.MulMod(y, m) } } @@ -472,9 +472,38 @@ func TestNewModFromBigZero(t *testing.T) { t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected) } - expected = "modulus must be odd" - _, err = NewModulusFromBig(big.NewInt(2)) - if err == nil || err.Error() != expected { - t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected) + defer func(t *testing.T) { + if r := recover(); r != nil { + if s, ok := r.(string); !ok || !strings.Contains(s, "montgomery multiplication on even modulus") { + t.Errorf("Unexpected panic: %#v", r) + } + } else { + t.Error("Expected panic to be recovered, got nothing.") + } + }(t) + + m, err := NewModulusFromBig(big.NewInt(10)) + if err != nil { + t.Errorf("NewModulusFromBig(2) got %q, want %q", err, "") + } + x := NewNat().setBig(big.NewInt(1)) + y := NewNat().setBig(big.NewInt(2)) + x.MulMod(y, m) +} + +func TestNatCmp(t *testing.T) { + testcases := [][3]int64{ + {33, 22, 1}, + {33, 33, 0}, + {22, 33, -1}, + } + for _, tc := range testcases { + a := new(big.Int).SetInt64(tc[0]) + b := new(big.Int).SetInt64(tc[1]) + na := natFromBytes(a.Bytes()) + nb := natFromBytes(b.Bytes()) + if res, _ := na.Cmp(nb); res != int(tc[2]) { + t.Errorf("expected %d got %d for (%d).cmp(%d)", tc[2], res, tc[0], tc[1]) + } } } diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go index 4d78d1eaaa6..81ece3db6eb 100644 --- a/src/crypto/rsa/rsa.go +++ b/src/crypto/rsa/rsa.go @@ -32,6 +32,7 @@ import ( "crypto/internal/randutil" "crypto/rand" "crypto/subtle" + "encoding/binary" "errors" "hash" "io" @@ -236,31 +237,149 @@ func (priv *PrivateKey) Validate() error { } // Check that Πprimes == n. - modulus := new(big.Int).Set(bigOne) + N, err := bigmod.NewModulusFromBig(priv.N) + if err != nil { + return err + } + bigOneNat, err := bigmod.NewNat().SetBytes(bigOne.Bytes(), N) + if err != nil { + return err + } + modulus, err := bigmod.NewNat().SetBytes(bigOne.Bytes(), N) + if err != nil { + return err + } for _, prime := range priv.Primes { // Any primes ≤ 1 will cause divide-by-zero panics later. - if prime.Cmp(bigOne) <= 0 { + nprime, err := bigmod.NewNat().SetBytes(prime.Bytes(), N) + if err != nil { + return err + } + if d := nprime.Cmp(bigOneNat); d <= 0 { return errors.New("crypto/rsa: invalid prime value") } - modulus.Mul(modulus, prime) + modulus.MulMod(nprime, N) } - if modulus.Cmp(priv.N) != 0 { + if modulus.Equal(N.Nat()) != 0 { return errors.New("crypto/rsa: invalid modulus") } - // Check that de ≡ 1 mod p-1, for each prime. - // This implies that e is coprime to each p-1 as e has a multiplicative - // inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) = - // exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1 - // mod p. Thus a^de ≡ a mod n for all a coprime to n, as required. - congruence := new(big.Int) - de := new(big.Int).SetInt64(int64(priv.E)) - de.Mul(de, priv.D) - for _, prime := range priv.Primes { - pminus1 := new(big.Int).Sub(prime, bigOne) - congruence.Mod(de, pminus1) - if congruence.Cmp(bigOne) != 0 { - return errors.New("crypto/rsa: invalid exponents") + // NIST SP 800-56B REV.2 6.4.1.1 3.a + // Key-pair consistency verifying that m = (m^e)^d mod n for some integer m satisfying 1 < m < (n − 1). + m, err := bigmod.NewNat().SetBytes([]byte{2}, N) + ebytes := make([]byte, 8) + binary.BigEndian.PutUint64(ebytes, uint64(priv.E)) + mm := bigmod.NewNat().Exp(m, ebytes, N) + mm.Exp(mm, priv.D.Bytes(), N) + if mm.Equal(m) != 1 { + return errors.New("crypto/rsa: key-pair consistency check failed") + } + + if len(priv.Primes) > 2 { + // Check that de ≡ 1 mod p-1, for each prime. + // This implies that e is coprime to each p-1 as e has a multiplicative + // inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) = + // exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1 + // mod p. Thus a^de ≡ a mod n for all a coprime to n, as required. + congruence := new(big.Int) + de := new(big.Int).SetInt64(int64(priv.E)) + de.Mul(de, priv.D) + for _, prime := range priv.Primes { + pminus1 := new(big.Int).Sub(prime, bigOne) + congruence.Mod(de, pminus1) + if congruence.Cmp(bigOne) != 0 { + return errors.New("crypto/rsa: invalid exponents") + } + } + } else { + // 6.4.1.2.1.D rsakpv1-crt + // + // (npub == p × q) + pBytes := priv.Primes[0].Bytes() + qBytes := priv.Primes[1].Bytes() + p, err := bigmod.NewNat().SetBytes(pBytes, N) + if err != nil { + return err + } + q, err := bigmod.NewNat().SetBytes(qBytes, N) + if err != nil { + return err + } + product := bigmod.NewNat().Mul(p, q, N) + if r := product.Cmp(N.Nat()); r != 0 { + return errors.New("crypto/rsa: invalid RSA key pair") + } + + // 6.4.1.2.1.F rsakpv1-crt + priv.Precompute() + + // Step a: 1 < dP < (p – 1). + pminus1 := bigmod.NewNat().Set(p) + pminus1.Sub(bigOneNat) + pminus1big := new(big.Int).SetBytes(pminus1.Bytes(N)) + pminus1mod, err := bigmod.NewModulusFromBig(pminus1big) + dP, err := bigmod.NewNat().SetBytes(priv.Precomputed.Dp.Bytes(), N) + if err != nil { + return err + } + // dP := bigmod.NewNat().Mod(d, pminus1mod) + res1 := bigOneNat.Cmp(dP) + res2 := dP.Cmp(pminus1) + if res1 != -1 || res2 != -1 { + return errors.New("crypto/rsa: (step A) invalid RSA key pair") + } + + // Step b: 1 < dQ < (q – 1). + dQ, err := bigmod.NewNat().SetBytes(priv.Precomputed.Dq.Bytes(), N) + if err != nil { + return err + } + qminus1 := bigmod.NewNat().Set(q) + qminus1.Sub(bigOneNat) + qminus1big := new(big.Int).SetBytes(qminus1.Bytes(N)) + qminus1mod, err := bigmod.NewModulusFromBig(qminus1big) + res1 = bigOneNat.Cmp(dQ) + res2 = dQ.Cmp(qminus1) + if res1 != -1 || res2 != -1 { + return errors.New("crypto/rsa: invalid RSA key pair") + } + // Step c: 1 < qInv < p. + qInv, err := bigmod.NewNat().SetBytes(priv.Precomputed.Qinv.Bytes(), N) + if err != nil { + return err + } + res1 = bigOneNat.Cmp(qInv) + res2 = qInv.Cmp(p) + if res1 != -1 || res2 != -1 { + return errors.New("crypto/rsa: invalid RSA key pair") + } + // Step d: 1 = (dP × epub) mod (p – 1). + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(priv.E)) + epub, err := bigmod.NewNat().SetBytes(buf, N) + if err != nil { + return err + } + if bigmod.NewNat().Mod( + bigmod.NewNat().Mul(dP, epub, N), + pminus1mod).Equal(bigOneNat) != 1 { + return errors.New("crypto/rsa: invalid RSA key pair") + } + // Step e: 1 = (dQ × epub) mod (q – 1). + if bigmod.NewNat().Mod( + bigmod.NewNat().Mul(dQ, epub, N), + qminus1mod).Equal(bigOneNat) != 1 { + return errors.New("crypto/rsa: invalid RSA key pair") + } + // Step f: 1 = (qInv × q) mod p. + pmod, err := bigmod.NewModulusFromBig(priv.Primes[0]) + if err != nil { + return err + } + if bigmod.NewNat().Mod( + bigmod.NewNat().Mul(qInv, q, N), + pmod).Equal(bigOneNat) != 1 { + return errors.New("crypto/rsa: invalid RSA key pair") } } return nil @@ -675,11 +794,11 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) { // m2 = c ^ Dq mod q m2 := bigmod.NewNat().Exp(t0.Mod(c, Q), priv.Precomputed.Dq.Bytes(), Q) // m = m - m2 mod p - m.Sub(t0.Mod(m2, P), P) + m.SubMod(t0.Mod(m2, P), P) // m = m * Qinv mod p - m.Mul(Qinv, P) + m.MulMod(Qinv, P) // m = m * q mod N - m.ExpandFor(N).Mul(t0.Mod(Q.Nat(), N), N) + m.ExpandFor(N).MulMod(t0.Mod(Q.Nat(), N), N) // m = m + m2 mod N m.Add(m2.ExpandFor(N), N) } diff --git a/src/crypto/rsa/rsa_test.go b/src/crypto/rsa/rsa_test.go index 2afa045a3a0..886da39c27b 100644 --- a/src/crypto/rsa/rsa_test.go +++ b/src/crypto/rsa/rsa_test.go @@ -621,6 +621,28 @@ func BenchmarkVerifyPSS(b *testing.B) { }) } +func BenchmarkPrecompute(b *testing.B) { + b.Run("2048", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + k := *test2048Key + k.Precomputed = PrecomputedValues{} + k.Precompute() + } + }) +} + +func BenchmarkValidate(b *testing.B) { + b.Run("2048", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := test2048Key.Validate(); err != nil { + b.Fatal(err) + } + } + }) +} + type testEncryptOAEPMessage struct { in []byte seed []byte