diff --git a/src/crypto/rsa/pkcs1v15.go b/src/crypto/rsa/pkcs1v15.go index 59cf0e6e6c1..971aee6a6d5 100644 --- a/src/crypto/rsa/pkcs1v15.go +++ b/src/crypto/rsa/pkcs1v15.go @@ -181,7 +181,7 @@ func decryptPKCS1v15(priv *PrivateKey, ciphertext []byte) (valid int, em []byte, return } } else { - em, err = decrypt(priv, ciphertext) + em, err = decrypt(priv, ciphertext, noCheck) if err != nil { return } @@ -295,7 +295,7 @@ func SignPKCS1v15(random io.Reader, priv *PrivateKey, hash crypto.Hash, hashed [ copy(em[k-tLen:k-hashLen], prefix) copy(em[k-hashLen:k], hashed) - return decryptAndCheck(priv, em) + return decrypt(priv, em, withCheck) } // VerifyPKCS1v15 verifies an RSA PKCS #1 v1.5 signature. diff --git a/src/crypto/rsa/pss.go b/src/crypto/rsa/pss.go index 3e01bbd8f8e..6f1a0c12a5b 100644 --- a/src/crypto/rsa/pss.go +++ b/src/crypto/rsa/pss.go @@ -219,7 +219,7 @@ func signPSSWithSalt(priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([ if err != nil { return nil, err } - // Note: BoringCrypto takes care of the "AndCheck" part of "decryptAndCheck". + // Note: BoringCrypto always does decrypt "withCheck". // (It's not just decrypt.) s, err := boring.DecryptRSANoPadding(bkey, em) if err != nil { @@ -241,7 +241,7 @@ func signPSSWithSalt(priv *PrivateKey, hash crypto.Hash, hashed, salt []byte) ([ em = emNew } - return decryptAndCheck(priv, em) + return decrypt(priv, em, withCheck) } const ( diff --git a/src/crypto/rsa/rsa.go b/src/crypto/rsa/rsa.go index 14500326522..c71d35ab5ea 100644 --- a/src/crypto/rsa/rsa.go +++ b/src/crypto/rsa/rsa.go @@ -111,8 +111,9 @@ type PrivateKey struct { D *big.Int // private exponent Primes []*big.Int // prime factors of N, has >= 2 elements. - // Precomputed contains precomputed values that speed up private - // operations, if available. + // Precomputed contains precomputed values that speed up RSA operations, + // if available. It must be generated by calling PrivateKey.Precompute and + // must not be modified. Precomputed PrecomputedValues } @@ -207,6 +208,8 @@ type PrecomputedValues struct { // and is implemented by this package without CRT optimizations to limit // complexity. CRTValues []CRTValue + + n, p, q *modulus // moduli for CRT with Montgomery precomputed constants } // CRTValue contains the precomputed Chinese remainder theorem values. @@ -311,6 +314,9 @@ func GenerateMultiPrimeKey(random io.Reader, nprimes int, bits int) (*PrivateKey Dq: Dq, Qinv: Qinv, CRTValues: make([]CRTValue, 0), // non-nil, to match Precompute + n: modulusFromNat(natFromBig(N)), + p: modulusFromNat(natFromBig(P)), + q: modulusFromNat(natFromBig(Q)), }, } return key, nil @@ -450,17 +456,23 @@ func encrypt(pub *PublicKey, plaintext []byte) []byte { N := modulusFromNat(natFromBig(pub.N)) m := natFromBytes(plaintext).expandFor(N) - - e := make([]byte, 8) - binary.BigEndian.PutUint64(e, uint64(pub.E)) - for len(e) > 1 && e[0] == 0 { - e = e[1:] - } + e := intToBytes(pub.E) out := make([]byte, modulusSize(N)) return new(nat).exp(m, e, N).fillBytes(out) } +// intToBytes returns i as a big-endian slice of bytes with no leading zeroes, +// leaking only the bit size of i through timing side-channels. +func intToBytes(i int) []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, uint64(i)) + for len(b) > 1 && b[0] == 0 { + b = b[1:] + } + return b +} + // EncryptOAEP encrypts the given message with RSA-OAEP. // // OAEP is parameterised by a hash function that is used as a random oracle. @@ -540,6 +552,13 @@ var ErrVerification = errors.New("crypto/rsa: verification error") // Precompute performs some calculations that speed up private key operations // in the future. func (priv *PrivateKey) Precompute() { + if priv.Precomputed.n == nil && len(priv.Primes) == 2 { + priv.Precomputed.n = modulusFromNat(natFromBig(priv.N)) + priv.Precomputed.p = modulusFromNat(natFromBig(priv.Primes[0])) + priv.Precomputed.q = modulusFromNat(natFromBig(priv.Primes[1])) + } + + // Fill in the backwards-compatibility *big.Int values. if priv.Precomputed.Dp != nil { return } @@ -568,13 +587,21 @@ func (priv *PrivateKey) Precompute() { } } -// decrypt performs an RSA decryption of ciphertext into out. -func decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { +const withCheck = true +const noCheck = false + +// decrypt performs an RSA decryption of ciphertext into out. If check is true, +// m^e is calculated and compared with ciphertext, in order to defend against +// errors in the CRT computation. +func decrypt(priv *PrivateKey, ciphertext []byte, check bool) ([]byte, error) { if len(priv.Primes) <= 2 { boring.Unreachable() } - N := modulusFromNat(natFromBig(priv.N)) + N := priv.Precomputed.n + if N == nil { + N = modulusFromNat(natFromBig(priv.N)) + } c := natFromBytes(ciphertext).expandFor(N) if c.cmpGeq(N.nat) == 1 { return nil, ErrDecryption @@ -583,49 +610,37 @@ func decrypt(priv *PrivateKey, ciphertext []byte) ([]byte, error) { return nil, ErrDecryption } - // Note that because our private decryption exponents are stored as big.Int, - // we potentially leak the exact number of bits of these exponents. This - // isn't great, but should be fine. - if priv.Precomputed.Dp == nil || len(priv.Primes) > 2 { - out := make([]byte, modulusSize(N)) - return new(nat).exp(c, priv.D.Bytes(), N).fillBytes(out), nil + var m *nat + if priv.Precomputed.n == nil { + m = new(nat).exp(c, priv.D.Bytes(), N) + } else { + t0 := new(nat) + P, Q := priv.Precomputed.p, priv.Precomputed.q + // m = c ^ Dp mod p + m = new(nat).exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P) + // m2 = c ^ Dq mod q + m2 := new(nat).exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q) + // m = m - m2 mod p + m.modSub(t0.mod(m2, P), P) + // m = m * Qinv mod p + m.modMul(natFromBig(priv.Precomputed.Qinv).expandFor(P), P) + // m = m * q mod N + m.expandFor(N).modMul(t0.mod(Q.nat, N), N) + // m = m + m2 mod N + m.modAdd(m2.expandFor(N), N) } - t0 := new(nat) - P := modulusFromNat(natFromBig(priv.Primes[0])) - Q := modulusFromNat(natFromBig(priv.Primes[1])) - // m = c ^ Dp mod p - m := new(nat).exp(t0.mod(c, P), priv.Precomputed.Dp.Bytes(), P) - // m2 = c ^ Dq mod q - m2 := new(nat).exp(t0.mod(c, Q), priv.Precomputed.Dq.Bytes(), Q) - // m = m - m2 mod p - m.modSub(t0.mod(m2, P), P) - // m = m * Qinv mod p - m.modMul(natFromBig(priv.Precomputed.Qinv).expandFor(P), P) - // m = m * q mod N - m.expandFor(N).modMul(t0.mod(Q.nat, N), N) - // m = m + m2 mod N - m.modAdd(m2.expandFor(N), N) + if check { + c1 := new(nat).exp(m, intToBytes(priv.E), N) + if c1.cmpEq(c) != 1 { + return nil, ErrDecryption + } + } out := make([]byte, modulusSize(N)) return m.fillBytes(out), nil } -func decryptAndCheck(priv *PrivateKey, ciphertext []byte) (m []byte, err error) { - m, err = decrypt(priv, ciphertext) - if err != nil { - return nil, err - } - - // In order to defend against errors in the CRT computation, m^e is - // calculated, which should match the original ciphertext. - check := encrypt(&priv.PublicKey, m) - if subtle.ConstantTimeCompare(ciphertext, check) != 1 { - return nil, errors.New("rsa: internal error") - } - return m, nil -} - // DecryptOAEP decrypts ciphertext using RSA-OAEP. // // OAEP is parameterised by a hash function that is used as a random oracle. @@ -662,7 +677,7 @@ func decryptOAEP(hash, mgfHash hash.Hash, random io.Reader, priv *PrivateKey, ci return out, nil } - em, err := decrypt(priv, ciphertext) + em, err := decrypt(priv, ciphertext, noCheck) if err != nil { return nil, err }