diff --git a/src/crypto/ecdsa/ecdsa.go b/src/crypto/ecdsa/ecdsa.go index 2179b01e8e3..9a4ea36d6a1 100644 --- a/src/crypto/ecdsa/ecdsa.go +++ b/src/crypto/ecdsa/ecdsa.go @@ -327,9 +327,9 @@ func signNISTEC[Point nistPoint[Point]](c *nistCurve[Point], priv *PrivateKey, c if err != nil { return nil, err } - s.Mul(r, c.N) + s.MulMod(r, c.N) s.Add(e, c.N) - s.Mul(kInv, c.N) + s.MulMod(kInv, c.N) // Again, the chance of this happening is cryptographically negligible. if s.IsZero() == 1 { @@ -528,12 +528,12 @@ func verifyNISTEC[Point nistPoint[Point]](c *nistCurve[Point], pub *PublicKey, h inverse(c, w, s) // p₁ = [e * s⁻¹]G - p1, err := c.newPoint().ScalarBaseMult(e.Mul(w, c.N).Bytes(c.N)) + p1, err := c.newPoint().ScalarBaseMult(e.MulMod(w, c.N).Bytes(c.N)) if err != nil { return false } // p₂ = [r * s⁻¹]Q - p2, err := Q.ScalarMult(Q, w.Mul(r, c.N).Bytes(c.N)) + p2, err := Q.ScalarMult(Q, w.MulMod(r, c.N).Bytes(c.N)) if err != nil { return false } diff --git a/src/crypto/internal/bigmod/nat.go b/src/crypto/internal/bigmod/nat.go index 5cbae40efe9..0982730ce29 100644 --- a/src/crypto/internal/bigmod/nat.go +++ b/src/crypto/internal/bigmod/nat.go @@ -92,8 +92,8 @@ func (x *Nat) reset(n int) *Nat { return x } -// set assigns x = y, optionally resizing x to the appropriate size. -func (x *Nat) set(y *Nat) *Nat { +// Set assigns x = y, optionally resizing x to the appropriate size. +func (x *Nat) Set(y *Nat) *Nat { x.reset(len(y.limbs)) copy(x.limbs, y.limbs) return x @@ -226,6 +226,29 @@ func (x *Nat) IsZero() choice { // // Both operands must have the same announced length. func (x *Nat) cmpGeq(y *Nat) choice { + c := x.subCarry(y) + // If there was a carry, then subtracting y underflowed, so + // x is not greater than or equal to y. + return not(choice(c)) +} + +// Cmp compares x and y and returns the result of the compare: +// 1 if x > y +// 0 if x == y +// -1 if x < y +func (x *Nat) Cmp(y *Nat) int { + if x.Equal(y) == yes { + return 0 + } + c := x.subCarry(y) + res := 1 + if c > 0 { + res = -1 + } + return res +} + +func (x *Nat) subCarry(y *Nat) uint { // Eliminate bounds checks in the loop. size := len(x.limbs) xLimbs := x.limbs[:size] @@ -235,9 +258,7 @@ func (x *Nat) cmpGeq(y *Nat) choice { for i := 0; i < size; i++ { _, c = bits.Sub(xLimbs[i], yLimbs[i], c) } - // If there was a carry, then subtracting y underflowed, so - // x is not greater than or equal to y. - return not(choice(c)) + return c } // assign sets x <- y if on == 1, and does nothing otherwise. @@ -274,7 +295,7 @@ func (x *Nat) add(y *Nat) (c uint) { // sub computes x -= y. It returns the borrow of the subtraction. // // Both operands must have the same announced length. -func (x *Nat) sub(y *Nat) (c uint) { +func (x *Nat) Sub(y *Nat) (c uint) { // Eliminate bounds checks in the loop. size := len(x.limbs) xLimbs := x.limbs[:size] @@ -301,6 +322,7 @@ type Modulus struct { leading int // number of leading zeros in the modulus m0inv uint // -nat.limbs[0]⁻¹ mod _W rr *Nat // R*R for montgomeryRepresentation + even bool } // rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs). @@ -374,16 +396,18 @@ func minusInverseModW(x uint) uint { // The Int must be odd. The number of significant bits (and nothing else) is // leaked through timing side-channels. func NewModulusFromBig(n *big.Int) (*Modulus, error) { - if b := n.Bits(); len(b) == 0 { + b := n.Bits() + if len(b) == 0 { return nil, errors.New("modulus must be >= 0") - } else if b[0]&1 != 1 { - return nil, errors.New("modulus must be odd") } m := &Modulus{} + m.even = b[0]&1 != 1 m.nat = NewNat().setBig(n) m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1]) - m.m0inv = minusInverseModW(m.nat.limbs[0]) - m.rr = rr(m) + if !m.even { + m.m0inv = minusInverseModW(m.nat.limbs[0]) + m.rr = rr(m) + } return m, nil } @@ -508,8 +532,8 @@ func (out *Nat) resetFor(m *Modulus) *Nat { // // x and m operands must have the same announced length. func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { - t := NewNat().set(x) - underflow := t.sub(m.nat) + t := NewNat().Set(x) + underflow := t.Sub(m.nat) // We keep the result if x - m didn't underflow (meaning x >= m) // or if always was set. keep := not(choice(underflow)) | choice(always) @@ -520,10 +544,10 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) { // // The length of both operands must be the same as the modulus. Both operands // must already be reduced modulo m. -func (x *Nat) Sub(y *Nat, m *Modulus) *Nat { - underflow := x.sub(y) +func (x *Nat) SubMod(y *Nat, m *Modulus) *Nat { + underflow := x.Sub(y) // If the subtraction underflowed, add m. - t := NewNat().set(x) + t := NewNat().Set(x) t.add(m.nat) x.assign(choice(underflow), t) return x @@ -571,6 +595,9 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat { // All inputs should be the same length and already reduced modulo m. // x will be resized to the size of m and overwritten. func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat { + if m.even { + panic("crypto/rsa: montgomery multiplication on even modulus") + } n := len(m.nat.limbs) mLimbs := m.nat.limbs[:n] aLimbs := a.limbs[:n] @@ -707,17 +734,40 @@ func addMulVVW(z, x []uint, y uint) (carry uint) { return carry } -// Mul calculates x = x * y mod m. +// MulMod calculates x = x * y mod m. // // The length of both operands must be the same as the modulus. Both operands // must already be reduced modulo m. -func (x *Nat) Mul(y *Nat, m *Modulus) *Nat { +func (x *Nat) MulMod(y *Nat, m *Modulus) *Nat { // A Montgomery multiplication by a value out of the Montgomery domain // takes the result out of Montgomery representation. - xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m + xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m } +// Mul calculates z = x * y. +// +// All inputs should be the same length and already reduced modulo m. +// z will be resized to the size of m and overwritten. +func (z *Nat) Mul(x *Nat, y *Nat, m *Modulus) *Nat { + n := len(m.nat.limbs) + zLimbs := z.resetFor(m).limbs + xLimbs := x.limbs + yLimbs := y.limbs + switch n { + default: + for i := 0; i < n; i++ { + addMulVVW(zLimbs[i:], xLimbs, yLimbs[i]) + } + case 2048 / _W: + const n = 2048 / _W // compiler hint + for i := 0; i < n; i++ { + addMulVVW2048(&zLimbs[i:][0], &xLimbs[0], yLimbs[i]) + } + } + return z +} + // Exp calculates out = x^e mod m. // // The exponent e is represented in big-endian order. The output will be resized @@ -734,7 +784,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat { NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), NewNat(), } - table[0].set(x).montgomeryRepresentation(m) + table[0].Set(x).montgomeryRepresentation(m) for i := 1; i < len(table); i++ { table[i].montgomeryMul(table[i-1], table[0], m) } @@ -775,8 +825,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat { // For short exponents, precomputing a table and using a window like in Exp // doesn't pay off. Instead, we do a simple conditional square-and-multiply // chain, skipping the initial run of zeroes. - xR := NewNat().set(x).montgomeryRepresentation(m) - out.set(xR) + xR := NewNat().Set(x).montgomeryRepresentation(m) + out.Set(xR) for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ { out.montgomeryMul(out, out, m) if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 { diff --git a/src/crypto/internal/bigmod/nat_test.go b/src/crypto/internal/bigmod/nat_test.go index 7a956e3a57d..d3ac157762d 100644 --- a/src/crypto/internal/bigmod/nat_test.go +++ b/src/crypto/internal/bigmod/nat_test.go @@ -35,9 +35,9 @@ func (*Nat) Generate(r *rand.Rand, size int) reflect.Value { func testModAddCommutative(a *Nat, b *Nat) bool { m := maxModulus(uint(len(a.limbs))) - aPlusB := new(Nat).set(a) + aPlusB := new(Nat).Set(a) aPlusB.Add(b, m) - bPlusA := new(Nat).set(b) + bPlusA := new(Nat).Set(b) bPlusA.Add(a, m) return aPlusB.Equal(bPlusA) == 1 } @@ -51,8 +51,8 @@ func TestModAddCommutative(t *testing.T) { func testModSubThenAddIdentity(a *Nat, b *Nat) bool { m := maxModulus(uint(len(a.limbs))) - original := new(Nat).set(a) - a.Sub(b, m) + original := new(Nat).Set(a) + a.SubMod(b, m) a.Add(b, m) return a.Equal(original) == 1 } @@ -71,9 +71,9 @@ func TestMontgomeryRoundtrip(t *testing.T) { aPlusOne := new(big.Int).SetBytes(natBytes(a)) aPlusOne.Add(aPlusOne, big.NewInt(1)) m, _ := NewModulusFromBig(aPlusOne) - monty := new(Nat).set(a) + monty := new(Nat).Set(a) monty.montgomeryRepresentation(m) - aAgain := new(Nat).set(monty) + aAgain := new(Nat).Set(monty) aAgain.montgomeryMul(monty, one, m) if a.Equal(aAgain) != 1 { t.Errorf("%v != %v", a, aAgain) @@ -260,12 +260,12 @@ func TestModSub(t *testing.T) { m := modulusFromBytes([]byte{13}) x := &Nat{[]uint{6}} y := &Nat{[]uint{7}} - x.Sub(y, m) + x.SubMod(y, m) expected := &Nat{[]uint{12}} if x.Equal(expected) != 1 { t.Errorf("%+v != %+v", x, expected) } - x.Sub(y, m) + x.SubMod(y, m) expected = &Nat{[]uint{5}} if x.Equal(expected) != 1 { t.Errorf("%+v != %+v", x, expected) @@ -323,7 +323,7 @@ func TestMulReductions(t *testing.T) { A := NewNat().setBig(a).ExpandFor(N) B := NewNat().setBig(b).ExpandFor(N) - if A.Mul(B, N).IsZero() != 1 { + if A.MulMod(B, N).IsZero() != 1 { t.Error("a * b mod (a * b) != 0") } @@ -333,7 +333,7 @@ func TestMulReductions(t *testing.T) { I := NewNat().setBig(i).ExpandFor(N) one := NewNat().setBig(big.NewInt(1)).ExpandFor(N) - if A.Mul(I, N).Equal(one) != 1 { + if A.MulMod(I, N).Equal(one) != 1 { t.Error("a * inv(a) mod b != 1") } } @@ -401,7 +401,7 @@ func BenchmarkModSub(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - x.Sub(y, m) + x.SubMod(y, m) } } @@ -434,7 +434,7 @@ func BenchmarkModMul(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - x.Mul(y, m) + x.MulMod(y, m) } } @@ -472,9 +472,38 @@ func TestNewModFromBigZero(t *testing.T) { t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected) } - expected = "modulus must be odd" - _, err = NewModulusFromBig(big.NewInt(2)) - if err == nil || err.Error() != expected { - t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected) + defer func(t *testing.T) { + if r := recover(); r != nil { + if s, ok := r.(string); !ok || !strings.Contains(s, "montgomery multiplication on even modulus") { + t.Errorf("Unexpected panic: %#v", r) + } + } else { + t.Error("Expected panic to be recovered, got nothing.") + } + }(t) + + m, err := NewModulusFromBig(big.NewInt(10)) + if err != nil { + t.Errorf("NewModulusFromBig(2) got %q, want %q", err, "") + } + x := NewNat().setBig(big.NewInt(1)) + y := NewNat().setBig(big.NewInt(2)) + x.MulMod(y, m) +} + +func TestNatCmp(t *testing.T) { + testcases := [][3]int64{ + {33, 22, 1}, + {33, 33, 0}, + {22, 33, -1}, + } + for _, tc := range testcases { + a := new(big.Int).SetInt64(tc[0]) + b := new(big.Int).SetInt64(tc[1]) + na := natFromBytes(a.Bytes()) + nb := natFromBytes(b.Bytes()) + if res, _ := na.Cmp(nb); res != int(tc[2]) { + t.Errorf("expected %d got %d for (%d).cmp(%d)", tc[2], res, tc[0], tc[1]) + } } } diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go index 4d78d1eaaa6..81ece3db6eb 100644 --- a/src/crypto/rsa/rsa.go +++ b/src/crypto/rsa/rsa.go @@ -32,6 +32,7 @@ import ( "crypto/internal/randutil" "crypto/rand" "crypto/subtle" + "encoding/binary" "errors" "hash" "io" @@ -236,31 +237,149 @@ func (priv *PrivateKey) Validate() error { } // Check that Πprimes == n. - modulus := new(big.Int).Set(bigOne) + N, err := bigmod.NewModulusFromBig(priv.N) + if err != nil { + return err + } + bigOneNat, err := bigmod.NewNat().SetBytes(bigOne.Bytes(), N) + if err != nil { + return err + } + modulus, err := bigmod.NewNat().SetBytes(bigOne.Bytes(), N) + if err != nil { + return err + } for _, prime := range priv.Primes { // Any primes ≤ 1 will cause divide-by-zero panics later. - if prime.Cmp(bigOne) <= 0 { + nprime, err := bigmod.NewNat().SetBytes(prime.Bytes(), N) + if err != nil { + return err + } + if d := nprime.Cmp(bigOneNat); d <= 0 { return errors.New("crypto/rsa: invalid prime value") } - modulus.Mul(modulus, prime) + modulus.MulMod(nprime, N) } - if modulus.Cmp(priv.N) != 0 { + if modulus.Equal(N.Nat()) != 0 { return errors.New("crypto/rsa: invalid modulus") } - // Check that de ≡ 1 mod p-1, for each prime. - // This implies that e is coprime to each p-1 as e has a multiplicative - // inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) = - // exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1 - // mod p. Thus a^de ≡ a mod n for all a coprime to n, as required. - congruence := new(big.Int) - de := new(big.Int).SetInt64(int64(priv.E)) - de.Mul(de, priv.D) - for _, prime := range priv.Primes { - pminus1 := new(big.Int).Sub(prime, bigOne) - congruence.Mod(de, pminus1) - if congruence.Cmp(bigOne) != 0 { - return errors.New("crypto/rsa: invalid exponents") + // NIST SP 800-56B REV.2 6.4.1.1 3.a + // Key-pair consistency verifying that m = (m^e)^d mod n for some integer m satisfying 1 < m < (n − 1). + m, err := bigmod.NewNat().SetBytes([]byte{2}, N) + ebytes := make([]byte, 8) + binary.BigEndian.PutUint64(ebytes, uint64(priv.E)) + mm := bigmod.NewNat().Exp(m, ebytes, N) + mm.Exp(mm, priv.D.Bytes(), N) + if mm.Equal(m) != 1 { + return errors.New("crypto/rsa: key-pair consistency check failed") + } + + if len(priv.Primes) > 2 { + // Check that de ≡ 1 mod p-1, for each prime. + // This implies that e is coprime to each p-1 as e has a multiplicative + // inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) = + // exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1 + // mod p. Thus a^de ≡ a mod n for all a coprime to n, as required. + congruence := new(big.Int) + de := new(big.Int).SetInt64(int64(priv.E)) + de.Mul(de, priv.D) + for _, prime := range priv.Primes { + pminus1 := new(big.Int).Sub(prime, bigOne) + congruence.Mod(de, pminus1) + if congruence.Cmp(bigOne) != 0 { + return errors.New("crypto/rsa: invalid exponents") + } + } + } else { + // 6.4.1.2.1.D rsakpv1-crt + // + // (npub == p × q) + pBytes := priv.Primes[0].Bytes() + qBytes := priv.Primes[1].Bytes() + p, err := bigmod.NewNat().SetBytes(pBytes, N) + if err != nil { + return err + } + q, err := bigmod.NewNat().SetBytes(qBytes, N) + if err != nil { + return err + } + product := bigmod.NewNat().Mul(p, q, N) + if r := product.Cmp(N.Nat()); r != 0 { + return errors.New("crypto/rsa: invalid RSA key pair") + } + + // 6.4.1.2.1.F rsakpv1-crt + priv.Precompute() + + // Step a: 1 < dP < (p – 1). + pminus1 := bigmod.NewNat().Set(p) + pminus1.Sub(bigOneNat) + pminus1big := new(big.Int).SetBytes(pminus1.Bytes(N)) + pminus1mod, err := bigmod.NewModulusFromBig(pminus1big) + dP, err := bigmod.NewNat().SetBytes(priv.Precomputed.Dp.Bytes(), N) + if err != nil { + return err + } + // dP := bigmod.NewNat().Mod(d, pminus1mod) + res1 := bigOneNat.Cmp(dP) + res2 := dP.Cmp(pminus1) + if res1 != -1 || res2 != -1 { + return errors.New("crypto/rsa: (step A) invalid RSA key pair") + } + + // Step b: 1 < dQ < (q – 1). + dQ, err := bigmod.NewNat().SetBytes(priv.Precomputed.Dq.Bytes(), N) + if err != nil { + return err + } + qminus1 := bigmod.NewNat().Set(q) + qminus1.Sub(bigOneNat) + qminus1big := new(big.Int).SetBytes(qminus1.Bytes(N)) + qminus1mod, err := bigmod.NewModulusFromBig(qminus1big) + res1 = bigOneNat.Cmp(dQ) + res2 = dQ.Cmp(qminus1) + if res1 != -1 || res2 != -1 { + return errors.New("crypto/rsa: invalid RSA key pair") + } + // Step c: 1 < qInv < p. + qInv, err := bigmod.NewNat().SetBytes(priv.Precomputed.Qinv.Bytes(), N) + if err != nil { + return err + } + res1 = bigOneNat.Cmp(qInv) + res2 = qInv.Cmp(p) + if res1 != -1 || res2 != -1 { + return errors.New("crypto/rsa: invalid RSA key pair") + } + // Step d: 1 = (dP × epub) mod (p – 1). + buf := make([]byte, 8) + binary.BigEndian.PutUint64(buf, uint64(priv.E)) + epub, err := bigmod.NewNat().SetBytes(buf, N) + if err != nil { + return err + } + if bigmod.NewNat().Mod( + bigmod.NewNat().Mul(dP, epub, N), + pminus1mod).Equal(bigOneNat) != 1 { + return errors.New("crypto/rsa: invalid RSA key pair") + } + // Step e: 1 = (dQ × epub) mod (q – 1). + if bigmod.NewNat().Mod( + bigmod.NewNat().Mul(dQ, epub, N), + qminus1mod).Equal(bigOneNat) != 1 { + return errors.New("crypto/rsa: invalid RSA key pair") + } + // Step f: 1 = (qInv × q) mod p. + pmod, err := bigmod.NewModulusFromBig(priv.Primes[0]) + if err != nil { + return err + } + if bigmod.NewNat().Mod( + bigmod.NewNat().Mul(qInv, q, N), + pmod).Equal(bigOneNat) != 1 { + return errors.New("crypto/rsa: invalid RSA key pair") } } return nil @@ -675,11 +794,11 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) { // m2 = c ^ Dq mod q m2 := bigmod.NewNat().Exp(t0.Mod(c, Q), priv.Precomputed.Dq.Bytes(), Q) // m = m - m2 mod p - m.Sub(t0.Mod(m2, P), P) + m.SubMod(t0.Mod(m2, P), P) // m = m * Qinv mod p - m.Mul(Qinv, P) + m.MulMod(Qinv, P) // m = m * q mod N - m.ExpandFor(N).Mul(t0.Mod(Q.Nat(), N), N) + m.ExpandFor(N).MulMod(t0.Mod(Q.Nat(), N), N) // m = m + m2 mod N m.Add(m2.ExpandFor(N), N) } diff --git a/src/crypto/rsa/rsa_test.go b/src/crypto/rsa/rsa_test.go index 2afa045a3a0..886da39c27b 100644 --- a/src/crypto/rsa/rsa_test.go +++ b/src/crypto/rsa/rsa_test.go @@ -621,6 +621,28 @@ func BenchmarkVerifyPSS(b *testing.B) { }) } +func BenchmarkPrecompute(b *testing.B) { + b.Run("2048", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + k := *test2048Key + k.Precomputed = PrecomputedValues{} + k.Precompute() + } + }) +} + +func BenchmarkValidate(b *testing.B) { + b.Run("2048", func(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + if err := test2048Key.Validate(); err != nil { + b.Fatal(err) + } + } + }) +} + type testEncryptOAEPMessage struct { in []byte seed []byte