diff --git a/src/crypto/cipher/cfb_test.go b/src/crypto/cipher/cfb_test.go index 72f62e69d3..67033d9a3b 100644 --- a/src/crypto/cipher/cfb_test.go +++ b/src/crypto/cipher/cfb_test.go @@ -8,8 +8,11 @@ import ( "bytes" "crypto/aes" "crypto/cipher" + "crypto/des" + "crypto/internal/cryptotest" "crypto/rand" "encoding/hex" + "fmt" "testing" ) @@ -111,3 +114,47 @@ func TestCFBInverse(t *testing.T) { t.Errorf("got: %x, want: %x", plaintextCopy, plaintext) } } + +func TestCFBStream(t *testing.T) { + + for _, keylen := range []int{128, 192, 256} { + + t.Run(fmt.Sprintf("AES-%d", keylen), func(t *testing.T) { + rng := newRandReader(t) + + key := make([]byte, keylen/8) + rng.Read(key) + + block, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + + t.Run("Encrypter", func(t *testing.T) { + cryptotest.TestStreamFromBlock(t, block, cipher.NewCFBEncrypter) + }) + t.Run("Decrypter", func(t *testing.T) { + cryptotest.TestStreamFromBlock(t, block, cipher.NewCFBDecrypter) + }) + }) + } + + t.Run("DES", func(t *testing.T) { + rng := newRandReader(t) + + key := make([]byte, 8) + rng.Read(key) + + block, err := des.NewCipher(key) + if err != nil { + panic(err) + } + + t.Run("Encrypter", func(t *testing.T) { + cryptotest.TestStreamFromBlock(t, block, cipher.NewCFBEncrypter) + }) + t.Run("Decrypter", func(t *testing.T) { + cryptotest.TestStreamFromBlock(t, block, cipher.NewCFBDecrypter) + }) + }) +} diff --git a/src/crypto/cipher/ctr_test.go b/src/crypto/cipher/ctr_test.go index e5cce576c7..4bb9deab80 100644 --- a/src/crypto/cipher/ctr_test.go +++ b/src/crypto/cipher/ctr_test.go @@ -6,7 +6,11 @@ package cipher_test import ( "bytes" + "crypto/aes" "crypto/cipher" + "crypto/des" + "crypto/internal/cryptotest" + "fmt" "testing" ) @@ -53,3 +57,37 @@ func TestCTR(t *testing.T) { } } } + +func TestCTRStream(t *testing.T) { + + for _, keylen := range []int{128, 192, 256} { + + t.Run(fmt.Sprintf("AES-%d", keylen), func(t *testing.T) { + rng := newRandReader(t) + + key := make([]byte, keylen/8) + rng.Read(key) + + block, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + + cryptotest.TestStreamFromBlock(t, block, cipher.NewCTR) + }) + } + + t.Run("DES", func(t *testing.T) { + rng := newRandReader(t) + + key := make([]byte, 8) + rng.Read(key) + + block, err := des.NewCipher(key) + if err != nil { + panic(err) + } + + cryptotest.TestStreamFromBlock(t, block, cipher.NewCTR) + }) +} diff --git a/src/crypto/cipher/ofb_test.go b/src/crypto/cipher/ofb_test.go index 8d3c5d3a38..036b76c45c 100644 --- a/src/crypto/cipher/ofb_test.go +++ b/src/crypto/cipher/ofb_test.go @@ -14,6 +14,9 @@ import ( "bytes" "crypto/aes" "crypto/cipher" + "crypto/des" + "crypto/internal/cryptotest" + "fmt" "testing" ) @@ -100,3 +103,37 @@ func TestOFB(t *testing.T) { } } } + +func TestOFBStream(t *testing.T) { + + for _, keylen := range []int{128, 192, 256} { + + t.Run(fmt.Sprintf("AES-%d", keylen), func(t *testing.T) { + rng := newRandReader(t) + + key := make([]byte, keylen/8) + rng.Read(key) + + block, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + + cryptotest.TestStreamFromBlock(t, block, cipher.NewOFB) + }) + } + + t.Run("DES", func(t *testing.T) { + rng := newRandReader(t) + + key := make([]byte, 8) + rng.Read(key) + + block, err := des.NewCipher(key) + if err != nil { + panic(err) + } + + cryptotest.TestStreamFromBlock(t, block, cipher.NewOFB) + }) +} diff --git a/src/crypto/internal/cryptotest/stream.go b/src/crypto/internal/cryptotest/stream.go new file mode 100644 index 0000000000..fb9c553bd1 --- /dev/null +++ b/src/crypto/internal/cryptotest/stream.go @@ -0,0 +1,241 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package cryptotest + +import ( + "bytes" + "crypto/cipher" + "crypto/subtle" + "fmt" + "strings" + "testing" +) + +// Each test is executed with each of the buffer lengths in bufLens. +var ( + bufLens = []int{0, 1, 3, 4, 8, 10, 15, 16, 20, 32, 50, 4096, 5000} + bufCap = 10000 +) + +// MakeStream returns a cipher.Stream instance. +// +// Multiple calls to MakeStream must return equivalent instances, +// so for example the key and/or IV must be fixed. +type MakeStream func() cipher.Stream + +// TestStream performs a set of tests on cipher.Stream implementations, +// checking the documented requirements of XORKeyStream. +func TestStream(t *testing.T, ms MakeStream) { + + t.Run("XORSemantics", func(t *testing.T) { + if strings.Contains(t.Name(), "TestCFBStream") { + // This is ugly, but so is CFB's abuse of cipher.Stream. + // Don't want to make it easier for anyone else to do that. + t.Skip("CFB implements cipher.Stream but does not follow XOR semantics") + } + + // Test that XORKeyStream inverts itself for encryption/decryption. + t.Run("Roundtrip", func(t *testing.T) { + + for _, length := range bufLens { + t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) { + rng := newRandReader(t) + + plaintext := make([]byte, length) + rng.Read(plaintext) + + ciphertext := make([]byte, length) + decrypted := make([]byte, length) + + ms().XORKeyStream(ciphertext, plaintext) // Encrypt plaintext + ms().XORKeyStream(decrypted, ciphertext) // Decrypt ciphertext + if !bytes.Equal(decrypted, plaintext) { + t.Errorf("plaintext is different after an encrypt/decrypt cycle; got %s, want %s", truncateHex(decrypted), truncateHex(plaintext)) + } + }) + } + }) + + // Test that XORKeyStream behaves the same as directly XORing + // plaintext with the stream. + t.Run("DirectXOR", func(t *testing.T) { + + for _, length := range bufLens { + t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) { + rng := newRandReader(t) + + plaintext := make([]byte, length) + rng.Read(plaintext) + + // Encrypting all zeros should reveal the stream itself + stream, directXOR := make([]byte, length), make([]byte, length) + ms().XORKeyStream(stream, stream) + // Encrypt plaintext by directly XORing the stream + subtle.XORBytes(directXOR, stream, plaintext) + + // Encrypt plaintext with XORKeyStream + ciphertext := make([]byte, length) + ms().XORKeyStream(ciphertext, plaintext) + if !bytes.Equal(ciphertext, directXOR) { + t.Errorf("xor semantics were not preserved; got %s, want %s", truncateHex(ciphertext), truncateHex(directXOR)) + } + }) + } + }) + }) + + t.Run("AlterInput", func(t *testing.T) { + rng := newRandReader(t) + src, dst, before := make([]byte, bufCap), make([]byte, bufCap), make([]byte, bufCap) + rng.Read(src) + + for _, length := range bufLens { + + t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) { + copy(before, src) + + ms().XORKeyStream(dst[:length], src[:length]) + if !bytes.Equal(src, before) { + t.Errorf("XORKeyStream modified src; got %s, want %s", truncateHex(src), truncateHex(before)) + } + }) + } + }) + + t.Run("Aliasing", func(t *testing.T) { + rng := newRandReader(t) + + buff, expectedOutput := make([]byte, bufCap), make([]byte, bufCap) + + for _, length := range bufLens { + // Record what output is when src and dst are different + rng.Read(buff) + ms().XORKeyStream(expectedOutput[:length], buff[:length]) + + // Check that the same output is generated when src=dst alias to the same + // memory + ms().XORKeyStream(buff[:length], buff[:length]) + if !bytes.Equal(buff[:length], expectedOutput[:length]) { + t.Errorf("block cipher produced different output when dst = src; got %x, want %x", buff[:length], expectedOutput[:length]) + } + } + }) + + t.Run("OutOfBoundsWrite", func(t *testing.T) { // Issue 21104 + rng := newRandReader(t) + + plaintext := make([]byte, bufCap) + rng.Read(plaintext) + ciphertext := make([]byte, bufCap) + + for _, length := range bufLens { + copy(ciphertext, plaintext) // Reset ciphertext buffer + + t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) { + mustPanic(t, "output smaller than input", func() { ms().XORKeyStream(ciphertext[:length], plaintext) }) + + if !bytes.Equal(ciphertext[length:], plaintext[length:]) { + t.Errorf("XORKeyStream did out of bounds write; got %s, want %s", truncateHex(ciphertext[length:]), truncateHex(plaintext[length:])) + } + }) + } + }) + + t.Run("BufferOverlap", func(t *testing.T) { + rng := newRandReader(t) + + buff := make([]byte, bufCap) + rng.Read(buff) + + for _, length := range bufLens { + if length == 0 || length == 1 { + continue + } + + t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) { + // Make src and dst slices point to same array with inexact overlap + src := buff[:length] + dst := buff[1 : length+1] + mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) }) + + // Only overlap on one byte + src = buff[:length] + dst = buff[length-1 : 2*length-1] + mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) }) + + // src comes after dst with one byte overlap + src = buff[length-1 : 2*length-1] + dst = buff[:length] + mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) }) + }) + } + }) + + t.Run("KeepState", func(t *testing.T) { + rng := newRandReader(t) + + plaintext := make([]byte, bufCap) + rng.Read(plaintext) + ciphertext := make([]byte, bufCap) + + // Make one long call to XORKeyStream + ms().XORKeyStream(ciphertext, plaintext) + + for _, step := range bufLens { + if step == 0 { + continue + } + stepMsg := fmt.Sprintf("step %d: ", step) + + dst := make([]byte, bufCap) + + // Make a bunch of small calls to (stateful) XORKeyStream + stream := ms() + i := 0 + for i+step < len(plaintext) { + stream.XORKeyStream(dst[i:], plaintext[i:i+step]) + i += step + } + stream.XORKeyStream(dst[i:], plaintext[i:]) + + if !bytes.Equal(dst, ciphertext) { + t.Errorf(stepMsg+"successive XORKeyStream calls returned a different result than a single one; got %s, want %s", truncateHex(dst), truncateHex(ciphertext)) + } + } + }) +} + +// TestStreamFromBlock creates a Stream from a cipher.Block used in a +// cipher.BlockMode. It addresses Issue 68377 by checking for a panic when the +// BlockMode uses an IV with incorrect length. +// For a valid IV, it also runs all TestStream tests on the resulting stream. +func TestStreamFromBlock(t *testing.T, block cipher.Block, blockMode func(b cipher.Block, iv []byte) cipher.Stream) { + + t.Run("WrongIVLen", func(t *testing.T) { + t.Skip("see Issue 68377") + + rng := newRandReader(t) + iv := make([]byte, block.BlockSize()+1) + rng.Read(iv) + mustPanic(t, "IV length must equal block size", func() { blockMode(block, iv) }) + }) + + t.Run("BlockModeStream", func(t *testing.T) { + rng := newRandReader(t) + iv := make([]byte, block.BlockSize()) + rng.Read(iv) + + TestStream(t, func() cipher.Stream { return blockMode(block, iv) }) + }) +} + +func truncateHex(b []byte) string { + numVals := 50 + + if len(b) <= numVals { + return fmt.Sprintf("%x", b) + } + return fmt.Sprintf("%x...", b[:numVals]) +} diff --git a/src/crypto/rc4/rc4_test.go b/src/crypto/rc4/rc4_test.go index e7356aa45d..f092f4f265 100644 --- a/src/crypto/rc4/rc4_test.go +++ b/src/crypto/rc4/rc4_test.go @@ -6,6 +6,8 @@ package rc4 import ( "bytes" + "crypto/cipher" + "crypto/internal/cryptotest" "fmt" "testing" ) @@ -136,6 +138,13 @@ func TestBlock(t *testing.T) { } } +func TestRC4Stream(t *testing.T) { + cryptotest.TestStream(t, func() cipher.Stream { + c, _ := NewCipher(golden[0].key) + return c + }) +} + func benchmark(b *testing.B, size int64) { buf := make([]byte, size) c, err := NewCipher(golden[0].key)