mirror of
https://github.com/golang/go
synced 2024-11-25 11:07:59 -07:00
crypto/rsa: port PrivateKey.Validate to bigmod, add validations
This patch ports the implementation of PrivateKey.Validate to use the bigmod math library, ensuring that the arithmetic operations happen in constant time. A few new APIs have been added to bigmod to add operations which don't explicitly require modulus arithmetic, but do take the modulus size into account to ensure we don't leak any non-public information. In addition to porting this routine to use bigmod this patch also adds a few more steps to the validation as defined by NIST SP 800-56B REV. 2 Section 6.4.1.4.3. For #69536
This commit is contained in:
parent
840ac5e037
commit
570721ddf5
@ -327,9 +327,9 @@ func signNISTEC[Point nistPoint[Point]](c *nistCurve[Point], priv *PrivateKey, c
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
s.Mul(r, c.N)
|
s.MulMod(r, c.N)
|
||||||
s.Add(e, c.N)
|
s.Add(e, c.N)
|
||||||
s.Mul(kInv, c.N)
|
s.MulMod(kInv, c.N)
|
||||||
|
|
||||||
// Again, the chance of this happening is cryptographically negligible.
|
// Again, the chance of this happening is cryptographically negligible.
|
||||||
if s.IsZero() == 1 {
|
if s.IsZero() == 1 {
|
||||||
@ -528,12 +528,12 @@ func verifyNISTEC[Point nistPoint[Point]](c *nistCurve[Point], pub *PublicKey, h
|
|||||||
inverse(c, w, s)
|
inverse(c, w, s)
|
||||||
|
|
||||||
// p₁ = [e * s⁻¹]G
|
// p₁ = [e * s⁻¹]G
|
||||||
p1, err := c.newPoint().ScalarBaseMult(e.Mul(w, c.N).Bytes(c.N))
|
p1, err := c.newPoint().ScalarBaseMult(e.MulMod(w, c.N).Bytes(c.N))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
// p₂ = [r * s⁻¹]Q
|
// p₂ = [r * s⁻¹]Q
|
||||||
p2, err := Q.ScalarMult(Q, w.Mul(r, c.N).Bytes(c.N))
|
p2, err := Q.ScalarMult(Q, w.MulMod(r, c.N).Bytes(c.N))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -92,8 +92,8 @@ func (x *Nat) reset(n int) *Nat {
|
|||||||
return x
|
return x
|
||||||
}
|
}
|
||||||
|
|
||||||
// set assigns x = y, optionally resizing x to the appropriate size.
|
// Set assigns x = y, optionally resizing x to the appropriate size.
|
||||||
func (x *Nat) set(y *Nat) *Nat {
|
func (x *Nat) Set(y *Nat) *Nat {
|
||||||
x.reset(len(y.limbs))
|
x.reset(len(y.limbs))
|
||||||
copy(x.limbs, y.limbs)
|
copy(x.limbs, y.limbs)
|
||||||
return x
|
return x
|
||||||
@ -226,6 +226,29 @@ func (x *Nat) IsZero() choice {
|
|||||||
//
|
//
|
||||||
// Both operands must have the same announced length.
|
// Both operands must have the same announced length.
|
||||||
func (x *Nat) cmpGeq(y *Nat) choice {
|
func (x *Nat) cmpGeq(y *Nat) choice {
|
||||||
|
c := x.subCarry(y)
|
||||||
|
// If there was a carry, then subtracting y underflowed, so
|
||||||
|
// x is not greater than or equal to y.
|
||||||
|
return not(choice(c))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cmp compares x and y and returns the result of the compare:
|
||||||
|
// 1 if x > y
|
||||||
|
// 0 if x == y
|
||||||
|
// -1 if x < y
|
||||||
|
func (x *Nat) Cmp(y *Nat) int {
|
||||||
|
if x.Equal(y) == yes {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
c := x.subCarry(y)
|
||||||
|
res := 1
|
||||||
|
if c > 0 {
|
||||||
|
res = -1
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *Nat) subCarry(y *Nat) uint {
|
||||||
// Eliminate bounds checks in the loop.
|
// Eliminate bounds checks in the loop.
|
||||||
size := len(x.limbs)
|
size := len(x.limbs)
|
||||||
xLimbs := x.limbs[:size]
|
xLimbs := x.limbs[:size]
|
||||||
@ -235,9 +258,7 @@ func (x *Nat) cmpGeq(y *Nat) choice {
|
|||||||
for i := 0; i < size; i++ {
|
for i := 0; i < size; i++ {
|
||||||
_, c = bits.Sub(xLimbs[i], yLimbs[i], c)
|
_, c = bits.Sub(xLimbs[i], yLimbs[i], c)
|
||||||
}
|
}
|
||||||
// If there was a carry, then subtracting y underflowed, so
|
return c
|
||||||
// x is not greater than or equal to y.
|
|
||||||
return not(choice(c))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// assign sets x <- y if on == 1, and does nothing otherwise.
|
// assign sets x <- y if on == 1, and does nothing otherwise.
|
||||||
@ -274,7 +295,7 @@ func (x *Nat) add(y *Nat) (c uint) {
|
|||||||
// sub computes x -= y. It returns the borrow of the subtraction.
|
// sub computes x -= y. It returns the borrow of the subtraction.
|
||||||
//
|
//
|
||||||
// Both operands must have the same announced length.
|
// Both operands must have the same announced length.
|
||||||
func (x *Nat) sub(y *Nat) (c uint) {
|
func (x *Nat) Sub(y *Nat) (c uint) {
|
||||||
// Eliminate bounds checks in the loop.
|
// Eliminate bounds checks in the loop.
|
||||||
size := len(x.limbs)
|
size := len(x.limbs)
|
||||||
xLimbs := x.limbs[:size]
|
xLimbs := x.limbs[:size]
|
||||||
@ -301,6 +322,7 @@ type Modulus struct {
|
|||||||
leading int // number of leading zeros in the modulus
|
leading int // number of leading zeros in the modulus
|
||||||
m0inv uint // -nat.limbs[0]⁻¹ mod _W
|
m0inv uint // -nat.limbs[0]⁻¹ mod _W
|
||||||
rr *Nat // R*R for montgomeryRepresentation
|
rr *Nat // R*R for montgomeryRepresentation
|
||||||
|
even bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
|
// rr returns R*R with R = 2^(_W * n) and n = len(m.nat.limbs).
|
||||||
@ -374,16 +396,18 @@ func minusInverseModW(x uint) uint {
|
|||||||
// The Int must be odd. The number of significant bits (and nothing else) is
|
// The Int must be odd. The number of significant bits (and nothing else) is
|
||||||
// leaked through timing side-channels.
|
// leaked through timing side-channels.
|
||||||
func NewModulusFromBig(n *big.Int) (*Modulus, error) {
|
func NewModulusFromBig(n *big.Int) (*Modulus, error) {
|
||||||
if b := n.Bits(); len(b) == 0 {
|
b := n.Bits()
|
||||||
|
if len(b) == 0 {
|
||||||
return nil, errors.New("modulus must be >= 0")
|
return nil, errors.New("modulus must be >= 0")
|
||||||
} else if b[0]&1 != 1 {
|
|
||||||
return nil, errors.New("modulus must be odd")
|
|
||||||
}
|
}
|
||||||
m := &Modulus{}
|
m := &Modulus{}
|
||||||
|
m.even = b[0]&1 != 1
|
||||||
m.nat = NewNat().setBig(n)
|
m.nat = NewNat().setBig(n)
|
||||||
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
|
m.leading = _W - bitLen(m.nat.limbs[len(m.nat.limbs)-1])
|
||||||
m.m0inv = minusInverseModW(m.nat.limbs[0])
|
if !m.even {
|
||||||
m.rr = rr(m)
|
m.m0inv = minusInverseModW(m.nat.limbs[0])
|
||||||
|
m.rr = rr(m)
|
||||||
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -508,8 +532,8 @@ func (out *Nat) resetFor(m *Modulus) *Nat {
|
|||||||
//
|
//
|
||||||
// x and m operands must have the same announced length.
|
// x and m operands must have the same announced length.
|
||||||
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
|
func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
|
||||||
t := NewNat().set(x)
|
t := NewNat().Set(x)
|
||||||
underflow := t.sub(m.nat)
|
underflow := t.Sub(m.nat)
|
||||||
// We keep the result if x - m didn't underflow (meaning x >= m)
|
// We keep the result if x - m didn't underflow (meaning x >= m)
|
||||||
// or if always was set.
|
// or if always was set.
|
||||||
keep := not(choice(underflow)) | choice(always)
|
keep := not(choice(underflow)) | choice(always)
|
||||||
@ -520,10 +544,10 @@ func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
|
|||||||
//
|
//
|
||||||
// The length of both operands must be the same as the modulus. Both operands
|
// The length of both operands must be the same as the modulus. Both operands
|
||||||
// must already be reduced modulo m.
|
// must already be reduced modulo m.
|
||||||
func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
|
func (x *Nat) SubMod(y *Nat, m *Modulus) *Nat {
|
||||||
underflow := x.sub(y)
|
underflow := x.Sub(y)
|
||||||
// If the subtraction underflowed, add m.
|
// If the subtraction underflowed, add m.
|
||||||
t := NewNat().set(x)
|
t := NewNat().Set(x)
|
||||||
t.add(m.nat)
|
t.add(m.nat)
|
||||||
x.assign(choice(underflow), t)
|
x.assign(choice(underflow), t)
|
||||||
return x
|
return x
|
||||||
@ -571,6 +595,9 @@ func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
|
|||||||
// All inputs should be the same length and already reduced modulo m.
|
// All inputs should be the same length and already reduced modulo m.
|
||||||
// x will be resized to the size of m and overwritten.
|
// x will be resized to the size of m and overwritten.
|
||||||
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
|
func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
|
||||||
|
if m.even {
|
||||||
|
panic("crypto/rsa: montgomery multiplication on even modulus")
|
||||||
|
}
|
||||||
n := len(m.nat.limbs)
|
n := len(m.nat.limbs)
|
||||||
mLimbs := m.nat.limbs[:n]
|
mLimbs := m.nat.limbs[:n]
|
||||||
aLimbs := a.limbs[:n]
|
aLimbs := a.limbs[:n]
|
||||||
@ -707,17 +734,40 @@ func addMulVVW(z, x []uint, y uint) (carry uint) {
|
|||||||
return carry
|
return carry
|
||||||
}
|
}
|
||||||
|
|
||||||
// Mul calculates x = x * y mod m.
|
// MulMod calculates x = x * y mod m.
|
||||||
//
|
//
|
||||||
// The length of both operands must be the same as the modulus. Both operands
|
// The length of both operands must be the same as the modulus. Both operands
|
||||||
// must already be reduced modulo m.
|
// must already be reduced modulo m.
|
||||||
func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
|
func (x *Nat) MulMod(y *Nat, m *Modulus) *Nat {
|
||||||
// A Montgomery multiplication by a value out of the Montgomery domain
|
// A Montgomery multiplication by a value out of the Montgomery domain
|
||||||
// takes the result out of Montgomery representation.
|
// takes the result out of Montgomery representation.
|
||||||
xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
|
xR := NewNat().Set(x).montgomeryRepresentation(m) // xR = x * R mod m
|
||||||
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
|
return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mul calculates z = x * y.
|
||||||
|
//
|
||||||
|
// All inputs should be the same length and already reduced modulo m.
|
||||||
|
// z will be resized to the size of m and overwritten.
|
||||||
|
func (z *Nat) Mul(x *Nat, y *Nat, m *Modulus) *Nat {
|
||||||
|
n := len(m.nat.limbs)
|
||||||
|
zLimbs := z.resetFor(m).limbs
|
||||||
|
xLimbs := x.limbs
|
||||||
|
yLimbs := y.limbs
|
||||||
|
switch n {
|
||||||
|
default:
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
addMulVVW(zLimbs[i:], xLimbs, yLimbs[i])
|
||||||
|
}
|
||||||
|
case 2048 / _W:
|
||||||
|
const n = 2048 / _W // compiler hint
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
addMulVVW2048(&zLimbs[i:][0], &xLimbs[0], yLimbs[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return z
|
||||||
|
}
|
||||||
|
|
||||||
// Exp calculates out = x^e mod m.
|
// Exp calculates out = x^e mod m.
|
||||||
//
|
//
|
||||||
// The exponent e is represented in big-endian order. The output will be resized
|
// The exponent e is represented in big-endian order. The output will be resized
|
||||||
@ -734,7 +784,7 @@ func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
|
|||||||
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
|
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
|
||||||
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
|
NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
|
||||||
}
|
}
|
||||||
table[0].set(x).montgomeryRepresentation(m)
|
table[0].Set(x).montgomeryRepresentation(m)
|
||||||
for i := 1; i < len(table); i++ {
|
for i := 1; i < len(table); i++ {
|
||||||
table[i].montgomeryMul(table[i-1], table[0], m)
|
table[i].montgomeryMul(table[i-1], table[0], m)
|
||||||
}
|
}
|
||||||
@ -775,8 +825,8 @@ func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
|
|||||||
// For short exponents, precomputing a table and using a window like in Exp
|
// For short exponents, precomputing a table and using a window like in Exp
|
||||||
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
|
// doesn't pay off. Instead, we do a simple conditional square-and-multiply
|
||||||
// chain, skipping the initial run of zeroes.
|
// chain, skipping the initial run of zeroes.
|
||||||
xR := NewNat().set(x).montgomeryRepresentation(m)
|
xR := NewNat().Set(x).montgomeryRepresentation(m)
|
||||||
out.set(xR)
|
out.Set(xR)
|
||||||
for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ {
|
for i := bits.UintSize - bitLen(e) + 1; i < bits.UintSize; i++ {
|
||||||
out.montgomeryMul(out, out, m)
|
out.montgomeryMul(out, out, m)
|
||||||
if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
|
if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
|
||||||
|
@ -35,9 +35,9 @@ func (*Nat) Generate(r *rand.Rand, size int) reflect.Value {
|
|||||||
|
|
||||||
func testModAddCommutative(a *Nat, b *Nat) bool {
|
func testModAddCommutative(a *Nat, b *Nat) bool {
|
||||||
m := maxModulus(uint(len(a.limbs)))
|
m := maxModulus(uint(len(a.limbs)))
|
||||||
aPlusB := new(Nat).set(a)
|
aPlusB := new(Nat).Set(a)
|
||||||
aPlusB.Add(b, m)
|
aPlusB.Add(b, m)
|
||||||
bPlusA := new(Nat).set(b)
|
bPlusA := new(Nat).Set(b)
|
||||||
bPlusA.Add(a, m)
|
bPlusA.Add(a, m)
|
||||||
return aPlusB.Equal(bPlusA) == 1
|
return aPlusB.Equal(bPlusA) == 1
|
||||||
}
|
}
|
||||||
@ -51,8 +51,8 @@ func TestModAddCommutative(t *testing.T) {
|
|||||||
|
|
||||||
func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
|
func testModSubThenAddIdentity(a *Nat, b *Nat) bool {
|
||||||
m := maxModulus(uint(len(a.limbs)))
|
m := maxModulus(uint(len(a.limbs)))
|
||||||
original := new(Nat).set(a)
|
original := new(Nat).Set(a)
|
||||||
a.Sub(b, m)
|
a.SubMod(b, m)
|
||||||
a.Add(b, m)
|
a.Add(b, m)
|
||||||
return a.Equal(original) == 1
|
return a.Equal(original) == 1
|
||||||
}
|
}
|
||||||
@ -71,9 +71,9 @@ func TestMontgomeryRoundtrip(t *testing.T) {
|
|||||||
aPlusOne := new(big.Int).SetBytes(natBytes(a))
|
aPlusOne := new(big.Int).SetBytes(natBytes(a))
|
||||||
aPlusOne.Add(aPlusOne, big.NewInt(1))
|
aPlusOne.Add(aPlusOne, big.NewInt(1))
|
||||||
m, _ := NewModulusFromBig(aPlusOne)
|
m, _ := NewModulusFromBig(aPlusOne)
|
||||||
monty := new(Nat).set(a)
|
monty := new(Nat).Set(a)
|
||||||
monty.montgomeryRepresentation(m)
|
monty.montgomeryRepresentation(m)
|
||||||
aAgain := new(Nat).set(monty)
|
aAgain := new(Nat).Set(monty)
|
||||||
aAgain.montgomeryMul(monty, one, m)
|
aAgain.montgomeryMul(monty, one, m)
|
||||||
if a.Equal(aAgain) != 1 {
|
if a.Equal(aAgain) != 1 {
|
||||||
t.Errorf("%v != %v", a, aAgain)
|
t.Errorf("%v != %v", a, aAgain)
|
||||||
@ -260,12 +260,12 @@ func TestModSub(t *testing.T) {
|
|||||||
m := modulusFromBytes([]byte{13})
|
m := modulusFromBytes([]byte{13})
|
||||||
x := &Nat{[]uint{6}}
|
x := &Nat{[]uint{6}}
|
||||||
y := &Nat{[]uint{7}}
|
y := &Nat{[]uint{7}}
|
||||||
x.Sub(y, m)
|
x.SubMod(y, m)
|
||||||
expected := &Nat{[]uint{12}}
|
expected := &Nat{[]uint{12}}
|
||||||
if x.Equal(expected) != 1 {
|
if x.Equal(expected) != 1 {
|
||||||
t.Errorf("%+v != %+v", x, expected)
|
t.Errorf("%+v != %+v", x, expected)
|
||||||
}
|
}
|
||||||
x.Sub(y, m)
|
x.SubMod(y, m)
|
||||||
expected = &Nat{[]uint{5}}
|
expected = &Nat{[]uint{5}}
|
||||||
if x.Equal(expected) != 1 {
|
if x.Equal(expected) != 1 {
|
||||||
t.Errorf("%+v != %+v", x, expected)
|
t.Errorf("%+v != %+v", x, expected)
|
||||||
@ -323,7 +323,7 @@ func TestMulReductions(t *testing.T) {
|
|||||||
A := NewNat().setBig(a).ExpandFor(N)
|
A := NewNat().setBig(a).ExpandFor(N)
|
||||||
B := NewNat().setBig(b).ExpandFor(N)
|
B := NewNat().setBig(b).ExpandFor(N)
|
||||||
|
|
||||||
if A.Mul(B, N).IsZero() != 1 {
|
if A.MulMod(B, N).IsZero() != 1 {
|
||||||
t.Error("a * b mod (a * b) != 0")
|
t.Error("a * b mod (a * b) != 0")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -333,7 +333,7 @@ func TestMulReductions(t *testing.T) {
|
|||||||
I := NewNat().setBig(i).ExpandFor(N)
|
I := NewNat().setBig(i).ExpandFor(N)
|
||||||
one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
|
one := NewNat().setBig(big.NewInt(1)).ExpandFor(N)
|
||||||
|
|
||||||
if A.Mul(I, N).Equal(one) != 1 {
|
if A.MulMod(I, N).Equal(one) != 1 {
|
||||||
t.Error("a * inv(a) mod b != 1")
|
t.Error("a * inv(a) mod b != 1")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -401,7 +401,7 @@ func BenchmarkModSub(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
x.Sub(y, m)
|
x.SubMod(y, m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -434,7 +434,7 @@ func BenchmarkModMul(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
x.Mul(y, m)
|
x.MulMod(y, m)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -472,9 +472,38 @@ func TestNewModFromBigZero(t *testing.T) {
|
|||||||
t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
|
t.Errorf("NewModulusFromBig(0) got %q, want %q", err, expected)
|
||||||
}
|
}
|
||||||
|
|
||||||
expected = "modulus must be odd"
|
defer func(t *testing.T) {
|
||||||
_, err = NewModulusFromBig(big.NewInt(2))
|
if r := recover(); r != nil {
|
||||||
if err == nil || err.Error() != expected {
|
if s, ok := r.(string); !ok || !strings.Contains(s, "montgomery multiplication on even modulus") {
|
||||||
t.Errorf("NewModulusFromBig(2) got %q, want %q", err, expected)
|
t.Errorf("Unexpected panic: %#v", r)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
t.Error("Expected panic to be recovered, got nothing.")
|
||||||
|
}
|
||||||
|
}(t)
|
||||||
|
|
||||||
|
m, err := NewModulusFromBig(big.NewInt(10))
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("NewModulusFromBig(2) got %q, want %q", err, "")
|
||||||
|
}
|
||||||
|
x := NewNat().setBig(big.NewInt(1))
|
||||||
|
y := NewNat().setBig(big.NewInt(2))
|
||||||
|
x.MulMod(y, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNatCmp(t *testing.T) {
|
||||||
|
testcases := [][3]int64{
|
||||||
|
{33, 22, 1},
|
||||||
|
{33, 33, 0},
|
||||||
|
{22, 33, -1},
|
||||||
|
}
|
||||||
|
for _, tc := range testcases {
|
||||||
|
a := new(big.Int).SetInt64(tc[0])
|
||||||
|
b := new(big.Int).SetInt64(tc[1])
|
||||||
|
na := natFromBytes(a.Bytes())
|
||||||
|
nb := natFromBytes(b.Bytes())
|
||||||
|
if res, _ := na.Cmp(nb); res != int(tc[2]) {
|
||||||
|
t.Errorf("expected %d got %d for (%d).cmp(%d)", tc[2], res, tc[0], tc[1])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -32,6 +32,7 @@ import (
|
|||||||
"crypto/internal/randutil"
|
"crypto/internal/randutil"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"hash"
|
"hash"
|
||||||
"io"
|
"io"
|
||||||
@ -236,31 +237,149 @@ func (priv *PrivateKey) Validate() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check that Πprimes == n.
|
// Check that Πprimes == n.
|
||||||
modulus := new(big.Int).Set(bigOne)
|
N, err := bigmod.NewModulusFromBig(priv.N)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
bigOneNat, err := bigmod.NewNat().SetBytes(bigOne.Bytes(), N)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
modulus, err := bigmod.NewNat().SetBytes(bigOne.Bytes(), N)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
for _, prime := range priv.Primes {
|
for _, prime := range priv.Primes {
|
||||||
// Any primes ≤ 1 will cause divide-by-zero panics later.
|
// Any primes ≤ 1 will cause divide-by-zero panics later.
|
||||||
if prime.Cmp(bigOne) <= 0 {
|
nprime, err := bigmod.NewNat().SetBytes(prime.Bytes(), N)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if d := nprime.Cmp(bigOneNat); d <= 0 {
|
||||||
return errors.New("crypto/rsa: invalid prime value")
|
return errors.New("crypto/rsa: invalid prime value")
|
||||||
}
|
}
|
||||||
modulus.Mul(modulus, prime)
|
modulus.MulMod(nprime, N)
|
||||||
}
|
}
|
||||||
if modulus.Cmp(priv.N) != 0 {
|
if modulus.Equal(N.Nat()) != 0 {
|
||||||
return errors.New("crypto/rsa: invalid modulus")
|
return errors.New("crypto/rsa: invalid modulus")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that de ≡ 1 mod p-1, for each prime.
|
// NIST SP 800-56B REV.2 6.4.1.1 3.a
|
||||||
// This implies that e is coprime to each p-1 as e has a multiplicative
|
// Key-pair consistency verifying that m = (m^e)^d mod n for some integer m satisfying 1 < m < (n − 1).
|
||||||
// inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) =
|
m, err := bigmod.NewNat().SetBytes([]byte{2}, N)
|
||||||
// exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1
|
ebytes := make([]byte, 8)
|
||||||
// mod p. Thus a^de ≡ a mod n for all a coprime to n, as required.
|
binary.BigEndian.PutUint64(ebytes, uint64(priv.E))
|
||||||
congruence := new(big.Int)
|
mm := bigmod.NewNat().Exp(m, ebytes, N)
|
||||||
de := new(big.Int).SetInt64(int64(priv.E))
|
mm.Exp(mm, priv.D.Bytes(), N)
|
||||||
de.Mul(de, priv.D)
|
if mm.Equal(m) != 1 {
|
||||||
for _, prime := range priv.Primes {
|
return errors.New("crypto/rsa: key-pair consistency check failed")
|
||||||
pminus1 := new(big.Int).Sub(prime, bigOne)
|
}
|
||||||
congruence.Mod(de, pminus1)
|
|
||||||
if congruence.Cmp(bigOne) != 0 {
|
if len(priv.Primes) > 2 {
|
||||||
return errors.New("crypto/rsa: invalid exponents")
|
// Check that de ≡ 1 mod p-1, for each prime.
|
||||||
|
// This implies that e is coprime to each p-1 as e has a multiplicative
|
||||||
|
// inverse. Therefore e is coprime to lcm(p-1,q-1,r-1,...) =
|
||||||
|
// exponent(ℤ/nℤ). It also implies that a^de ≡ a mod p as a^(p-1) ≡ 1
|
||||||
|
// mod p. Thus a^de ≡ a mod n for all a coprime to n, as required.
|
||||||
|
congruence := new(big.Int)
|
||||||
|
de := new(big.Int).SetInt64(int64(priv.E))
|
||||||
|
de.Mul(de, priv.D)
|
||||||
|
for _, prime := range priv.Primes {
|
||||||
|
pminus1 := new(big.Int).Sub(prime, bigOne)
|
||||||
|
congruence.Mod(de, pminus1)
|
||||||
|
if congruence.Cmp(bigOne) != 0 {
|
||||||
|
return errors.New("crypto/rsa: invalid exponents")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 6.4.1.2.1.D rsakpv1-crt
|
||||||
|
//
|
||||||
|
// (npub == p × q)
|
||||||
|
pBytes := priv.Primes[0].Bytes()
|
||||||
|
qBytes := priv.Primes[1].Bytes()
|
||||||
|
p, err := bigmod.NewNat().SetBytes(pBytes, N)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
q, err := bigmod.NewNat().SetBytes(qBytes, N)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
product := bigmod.NewNat().Mul(p, q, N)
|
||||||
|
if r := product.Cmp(N.Nat()); r != 0 {
|
||||||
|
return errors.New("crypto/rsa: invalid RSA key pair")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6.4.1.2.1.F rsakpv1-crt
|
||||||
|
priv.Precompute()
|
||||||
|
|
||||||
|
// Step a: 1 < dP < (p – 1).
|
||||||
|
pminus1 := bigmod.NewNat().Set(p)
|
||||||
|
pminus1.Sub(bigOneNat)
|
||||||
|
pminus1big := new(big.Int).SetBytes(pminus1.Bytes(N))
|
||||||
|
pminus1mod, err := bigmod.NewModulusFromBig(pminus1big)
|
||||||
|
dP, err := bigmod.NewNat().SetBytes(priv.Precomputed.Dp.Bytes(), N)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// dP := bigmod.NewNat().Mod(d, pminus1mod)
|
||||||
|
res1 := bigOneNat.Cmp(dP)
|
||||||
|
res2 := dP.Cmp(pminus1)
|
||||||
|
if res1 != -1 || res2 != -1 {
|
||||||
|
return errors.New("crypto/rsa: (step A) invalid RSA key pair")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step b: 1 < dQ < (q – 1).
|
||||||
|
dQ, err := bigmod.NewNat().SetBytes(priv.Precomputed.Dq.Bytes(), N)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
qminus1 := bigmod.NewNat().Set(q)
|
||||||
|
qminus1.Sub(bigOneNat)
|
||||||
|
qminus1big := new(big.Int).SetBytes(qminus1.Bytes(N))
|
||||||
|
qminus1mod, err := bigmod.NewModulusFromBig(qminus1big)
|
||||||
|
res1 = bigOneNat.Cmp(dQ)
|
||||||
|
res2 = dQ.Cmp(qminus1)
|
||||||
|
if res1 != -1 || res2 != -1 {
|
||||||
|
return errors.New("crypto/rsa: invalid RSA key pair")
|
||||||
|
}
|
||||||
|
// Step c: 1 < qInv < p.
|
||||||
|
qInv, err := bigmod.NewNat().SetBytes(priv.Precomputed.Qinv.Bytes(), N)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res1 = bigOneNat.Cmp(qInv)
|
||||||
|
res2 = qInv.Cmp(p)
|
||||||
|
if res1 != -1 || res2 != -1 {
|
||||||
|
return errors.New("crypto/rsa: invalid RSA key pair")
|
||||||
|
}
|
||||||
|
// Step d: 1 = (dP × epub) mod (p – 1).
|
||||||
|
buf := make([]byte, 8)
|
||||||
|
binary.BigEndian.PutUint64(buf, uint64(priv.E))
|
||||||
|
epub, err := bigmod.NewNat().SetBytes(buf, N)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if bigmod.NewNat().Mod(
|
||||||
|
bigmod.NewNat().Mul(dP, epub, N),
|
||||||
|
pminus1mod).Equal(bigOneNat) != 1 {
|
||||||
|
return errors.New("crypto/rsa: invalid RSA key pair")
|
||||||
|
}
|
||||||
|
// Step e: 1 = (dQ × epub) mod (q – 1).
|
||||||
|
if bigmod.NewNat().Mod(
|
||||||
|
bigmod.NewNat().Mul(dQ, epub, N),
|
||||||
|
qminus1mod).Equal(bigOneNat) != 1 {
|
||||||
|
return errors.New("crypto/rsa: invalid RSA key pair")
|
||||||
|
}
|
||||||
|
// Step f: 1 = (qInv × q) mod p.
|
||||||
|
pmod, err := bigmod.NewModulusFromBig(priv.Primes[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if bigmod.NewNat().Mod(
|
||||||
|
bigmod.NewNat().Mul(qInv, q, N),
|
||||||
|
pmod).Equal(bigOneNat) != 1 {
|
||||||
|
return errors.New("crypto/rsa: invalid RSA key pair")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@ -675,11 +794,11 @@ func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) {
|
|||||||
// m2 = c ^ Dq mod q
|
// m2 = c ^ Dq mod q
|
||||||
m2 := bigmod.NewNat().Exp(t0.Mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
|
m2 := bigmod.NewNat().Exp(t0.Mod(c, Q), priv.Precomputed.Dq.Bytes(), Q)
|
||||||
// m = m - m2 mod p
|
// m = m - m2 mod p
|
||||||
m.Sub(t0.Mod(m2, P), P)
|
m.SubMod(t0.Mod(m2, P), P)
|
||||||
// m = m * Qinv mod p
|
// m = m * Qinv mod p
|
||||||
m.Mul(Qinv, P)
|
m.MulMod(Qinv, P)
|
||||||
// m = m * q mod N
|
// m = m * q mod N
|
||||||
m.ExpandFor(N).Mul(t0.Mod(Q.Nat(), N), N)
|
m.ExpandFor(N).MulMod(t0.Mod(Q.Nat(), N), N)
|
||||||
// m = m + m2 mod N
|
// m = m + m2 mod N
|
||||||
m.Add(m2.ExpandFor(N), N)
|
m.Add(m2.ExpandFor(N), N)
|
||||||
}
|
}
|
||||||
|
@ -621,6 +621,28 @@ func BenchmarkVerifyPSS(b *testing.B) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func BenchmarkPrecompute(b *testing.B) {
|
||||||
|
b.Run("2048", func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
k := *test2048Key
|
||||||
|
k.Precomputed = PrecomputedValues{}
|
||||||
|
k.Precompute()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkValidate(b *testing.B) {
|
||||||
|
b.Run("2048", func(b *testing.B) {
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
if err := test2048Key.Validate(); err != nil {
|
||||||
|
b.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
type testEncryptOAEPMessage struct {
|
type testEncryptOAEPMessage struct {
|
||||||
in []byte
|
in []byte
|
||||||
seed []byte
|
seed []byte
|
||||||
|
Loading…
Reference in New Issue
Block a user