diff --git a/src/crypto/sha1/sha1.go b/src/crypto/sha1/sha1.go index ac593b1bf0..fbb2f94613 100644 --- a/src/crypto/sha1/sha1.go +++ b/src/crypto/sha1/sha1.go @@ -90,7 +90,7 @@ func (d0 *digest) Sum(in []byte) []byte { func (d *digest) checkSum() [Size]byte { len := d.len - // Padding. Add a 1 bit and 0 bits until 56 bytes mod 64. + // Padding. Add a 1 bit and 0 bits until 56 bytes mod 64. var tmp [64]byte tmp[0] = 0x80 if len%64 < 56 { @@ -121,6 +121,74 @@ func (d *digest) checkSum() [Size]byte { return digest } +// ConstantTimeSum computes the same result of Sum() but in constant time +func (d0 *digest) ConstantTimeSum(in []byte) []byte { + d := *d0 + hash := d.constSum() + return append(in, hash[:]...) +} + +func (d *digest) constSum() [Size]byte { + var length [8]byte + l := d.len << 3 + for i := uint(0); i < 8; i++ { + length[i] = byte(l >> (56 - 8*i)) + } + + nx := byte(d.nx) + t := nx - 56 // if nx < 56 then the MSB of t is one + mask1b := byte(int8(t) >> 7) // mask1b is 0xFF iff one block is enough + + separator := byte(0x80) // gets reset to 0x00 once used + for i := byte(0); i < chunk; i++ { + mask := byte(int8(i-nx) >> 7) // 0x00 after the end of data + + // if we reached the end of the data, replace with 0x80 or 0x00 + d.x[i] = (^mask & separator) | (mask & d.x[i]) + + // zero the separator once used + separator &= mask + + if i >= 56 { + // we might have to write the length here if all fit in one block + d.x[i] |= mask1b & length[i-56] + } + } + + // compress, and only keep the digest if all fit in one block + block(d, d.x[:]) + + var digest [Size]byte + for i, s := range d.h { + digest[i*4] = mask1b & byte(s>>24) + digest[i*4+1] = mask1b & byte(s>>16) + digest[i*4+2] = mask1b & byte(s>>8) + digest[i*4+3] = mask1b & byte(s) + } + + for i := byte(0); i < chunk; i++ { + // second block, it's always past the end of data, might start with 0x80 + if i < 56 { + d.x[i] = separator + separator = 0 + } else { + d.x[i] = length[i-56] + } + } + + // compress, and only keep the digest if we actually needed the second block + block(d, d.x[:]) + + for i, s := range d.h { + digest[i*4] |= ^mask1b & byte(s>>24) + digest[i*4+1] |= ^mask1b & byte(s>>16) + digest[i*4+2] |= ^mask1b & byte(s>>8) + digest[i*4+3] |= ^mask1b & byte(s) + } + + return digest +} + // Sum returns the SHA1 checksum of the data. func Sum(data []byte) [Size]byte { var d digest diff --git a/src/crypto/sha1/sha1_test.go b/src/crypto/sha1/sha1_test.go index 214afc51e1..3e59a5defe 100644 --- a/src/crypto/sha1/sha1_test.go +++ b/src/crypto/sha1/sha1_test.go @@ -61,15 +61,24 @@ func TestGolden(t *testing.T) { t.Fatalf("Sum function: sha1(%s) = %s want %s", g.in, s, g.out) } c := New() - for j := 0; j < 3; j++ { - if j < 2 { + for j := 0; j < 4; j++ { + var sum []byte + switch j { + case 0, 1: io.WriteString(c, g.in) - } else { + sum = c.Sum(nil) + case 2: io.WriteString(c, g.in[0:len(g.in)/2]) c.Sum(nil) io.WriteString(c, g.in[len(g.in)/2:]) + sum = c.Sum(nil) + case 3: + io.WriteString(c, g.in[0:len(g.in)/2]) + c.(*digest).ConstantTimeSum(nil) + io.WriteString(c, g.in[len(g.in)/2:]) + sum = c.(*digest).ConstantTimeSum(nil) } - s := fmt.Sprintf("%x", c.Sum(nil)) + s := fmt.Sprintf("%x", sum) if s != g.out { t.Fatalf("sha1[%d](%s) = %s want %s", j, g.in, s, g.out) } diff --git a/src/crypto/tls/cipher_suites.go b/src/crypto/tls/cipher_suites.go index d6bcc192d4..7efbe5a364 100644 --- a/src/crypto/tls/cipher_suites.go +++ b/src/crypto/tls/cipher_suites.go @@ -131,7 +131,7 @@ func macSHA1(version uint16, key []byte) macFunction { copy(mac.key, key) return mac } - return tls10MAC{hmac.New(sha1.New, key)} + return tls10MAC{hmac.New(newConstantTimeHash(sha1.New), key)} } // macSHA256 returns a SHA-256 based MAC. These are only supported in TLS 1.2 @@ -142,7 +142,7 @@ func macSHA256(version uint16, key []byte) macFunction { type macFunction interface { Size() int - MAC(digestBuf, seq, header, data []byte) []byte + MAC(digestBuf, seq, header, data, extra []byte) []byte } // fixedNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to @@ -200,7 +200,9 @@ var ssl30Pad1 = [48]byte{0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0 var ssl30Pad2 = [48]byte{0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c} -func (s ssl30MAC) MAC(digestBuf, seq, header, data []byte) []byte { +// MAC does not offer constant timing guarantees for SSL v3.0, since it's deemed +// useless considering the similar, protocol-level POODLE vulnerability. +func (s ssl30MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte { padLength := 48 if s.h.Size() == 20 { padLength = 40 @@ -222,6 +224,29 @@ func (s ssl30MAC) MAC(digestBuf, seq, header, data []byte) []byte { return s.h.Sum(digestBuf[:0]) } +type constantTimeHash interface { + hash.Hash + ConstantTimeSum(b []byte) []byte +} + +// cthWrapper wraps any hash.Hash that implements ConstantTimeSum, and replaces +// with that all calls to Sum. It's used to obtain a ConstantTimeSum-based HMAC. +type cthWrapper struct { + h constantTimeHash +} + +func (c *cthWrapper) Size() int { return c.h.Size() } +func (c *cthWrapper) BlockSize() int { return c.h.BlockSize() } +func (c *cthWrapper) Reset() { c.h.Reset() } +func (c *cthWrapper) Write(p []byte) (int, error) { return c.h.Write(p) } +func (c *cthWrapper) Sum(b []byte) []byte { return c.h.ConstantTimeSum(b) } + +func newConstantTimeHash(h func() hash.Hash) func() hash.Hash { + return func() hash.Hash { + return &cthWrapper{h().(constantTimeHash)} + } +} + // tls10MAC implements the TLS 1.0 MAC function. RFC 2246, section 6.2.3. type tls10MAC struct { h hash.Hash @@ -231,12 +256,19 @@ func (s tls10MAC) Size() int { return s.h.Size() } -func (s tls10MAC) MAC(digestBuf, seq, header, data []byte) []byte { +// MAC is guaranteed to take constant time, as long as +// len(seq)+len(header)+len(data)+len(extra) is constant. extra is not fed into +// the MAC, but is only provided to make the timing profile constant. +func (s tls10MAC) MAC(digestBuf, seq, header, data, extra []byte) []byte { s.h.Reset() s.h.Write(seq) s.h.Write(header) s.h.Write(data) - return s.h.Sum(digestBuf[:0]) + res := s.h.Sum(digestBuf[:0]) + if extra != nil { + s.h.Write(extra) + } + return res } func rsaKA(version uint16) keyAgreement { diff --git a/src/crypto/tls/conn.go b/src/crypto/tls/conn.go index 6fd486462f..a0c29f0c48 100644 --- a/src/crypto/tls/conn.go +++ b/src/crypto/tls/conn.go @@ -193,18 +193,18 @@ func (hc *halfConn) incSeq() { panic("TLS: sequence number wraparound") } -// removePadding returns an unpadded slice, in constant time, which is a prefix -// of the input. It also returns a byte which is equal to 255 if the padding -// was valid and 0 otherwise. See RFC 2246, section 6.2.3.2 -func removePadding(payload []byte) ([]byte, byte) { +// extractPadding returns, in constant time, the length of the padding to remove +// from the end of payload. It also returns a byte which is equal to 255 if the +// padding was valid and 0 otherwise. See RFC 2246, section 6.2.3.2 +func extractPadding(payload []byte) (toRemove int, good byte) { if len(payload) < 1 { - return payload, 0 + return 0, 0 } paddingLen := payload[len(payload)-1] t := uint(len(payload)-1) - uint(paddingLen) // if len(payload) >= (paddingLen - 1) then the MSB of t is zero - good := byte(int32(^t) >> 31) + good = byte(int32(^t) >> 31) toCheck := 255 // the maximum possible padding length // The length of the padded data is public, so we can use an if here @@ -227,24 +227,24 @@ func removePadding(payload []byte) ([]byte, byte) { good &= good << 1 good = uint8(int8(good) >> 7) - toRemove := good&paddingLen + 1 - return payload[:len(payload)-int(toRemove)], good + toRemove = int(paddingLen) + 1 + return } -// removePaddingSSL30 is a replacement for removePadding in the case that the +// extractPaddingSSL30 is a replacement for extractPadding in the case that the // protocol version is SSLv3. In this version, the contents of the padding // are random and cannot be checked. -func removePaddingSSL30(payload []byte) ([]byte, byte) { +func extractPaddingSSL30(payload []byte) (toRemove int, good byte) { if len(payload) < 1 { - return payload, 0 + return 0, 0 } paddingLen := int(payload[len(payload)-1]) + 1 if paddingLen > len(payload) { - return payload, 0 + return 0, 0 } - return payload[:len(payload)-paddingLen], 255 + return paddingLen, 255 } func roundUp(a, b int) int { @@ -270,6 +270,7 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) } paddingGood := byte(255) + paddingLen := 0 explicitIVLen := 0 // decrypt @@ -312,22 +313,17 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) } c.CryptBlocks(payload, payload) if hc.version == VersionSSL30 { - payload, paddingGood = removePaddingSSL30(payload) + paddingLen, paddingGood = extractPaddingSSL30(payload) } else { - payload, paddingGood = removePadding(payload) - } - b.resize(recordHeaderLen + explicitIVLen + len(payload)) + paddingLen, paddingGood = extractPadding(payload) - // note that we still have a timing side-channel in the - // MAC check, below. An attacker can align the record - // so that a correct padding will cause one less hash - // block to be calculated. Then they can iteratively - // decrypt a record by breaking each byte. See - // "Password Interception in a SSL/TLS Channel", Brice - // Canvel et al. - // - // However, our behavior matches OpenSSL, so we leak - // only as much as they do. + // To protect against CBC padding oracles like Lucky13, the data + // past paddingLen (which is secret) is passed to the MAC + // function as extra data, to be fed into the HMAC after + // computing the digest. This makes the MAC constant time as + // long as the digest computation is constant time and does not + // affect the subsequent write. + } default: panic("unknown cipher type") } @@ -340,17 +336,19 @@ func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) } // strip mac off payload, b.data - n := len(payload) - macSize + n := len(payload) - macSize - paddingLen + n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n) // if n < 0 { n = 0 } b.data[3] = byte(n >> 8) b.data[4] = byte(n) - b.resize(recordHeaderLen + explicitIVLen + n) - remoteMAC := payload[n:] - localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n]) + remoteMAC := payload[n : n+macSize] + localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n], payload[n+macSize:]) if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 { return false, 0, alertBadRecordMAC } hc.inDigestBuf = localMAC + + b.resize(recordHeaderLen + explicitIVLen + n) } hc.incSeq() @@ -378,7 +376,7 @@ func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) { func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { // mac if hc.mac != nil { - mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:]) + mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:], nil) n := len(b.data) b.resize(n + len(mac)) diff --git a/src/crypto/tls/conn_test.go b/src/crypto/tls/conn_test.go index 15397d607e..5e5c7a2e96 100644 --- a/src/crypto/tls/conn_test.go +++ b/src/crypto/tls/conn_test.go @@ -40,7 +40,7 @@ var paddingTests = []struct { func TestRemovePadding(t *testing.T) { for i, test := range paddingTests { - payload, good := removePadding(test.in) + paddingLen, good := extractPadding(test.in) expectedGood := byte(255) if !test.good { expectedGood = 0 @@ -48,8 +48,8 @@ func TestRemovePadding(t *testing.T) { if good != expectedGood { t.Errorf("#%d: wrong validity, want:%d got:%d", i, expectedGood, good) } - if good == 255 && len(payload) != test.expectedLen { - t.Errorf("#%d: got %d, want %d", i, len(payload), test.expectedLen) + if good == 255 && len(test.in)-paddingLen != test.expectedLen { + t.Errorf("#%d: got %d, want %d", i, len(test.in)-paddingLen, test.expectedLen) } } }