mirror of
https://github.com/golang/go
synced 2024-11-27 05:11:22 -07:00
crypto/rand: only read necessary bytes for Int
We only need to read the number of bytes required to store the value "max - 1" to generate a random number in the range [0, max). Before, there was an off-by-one error where an extra byte was read from the io.Reader for inputs like "256" (right at the boundary for a byte). There was a similar off-by-one error in the logic for clearing bits and thus for any input that was a power of 2, there was a 50% chance the read would continue to be retried as the mask failed to remove a bit. Fixes #18165. Change-Id: I548c1368990e23e365591e77980e9086fafb6518 Reviewed-on: https://go-review.googlesource.com/43891 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org> Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org> TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
parent
9f03e89552
commit
8a2553e380
@ -107,16 +107,23 @@ func Int(rand io.Reader, max *big.Int) (n *big.Int, err error) {
|
||||
if max.Sign() <= 0 {
|
||||
panic("crypto/rand: argument to Int is <= 0")
|
||||
}
|
||||
k := (max.BitLen() + 7) / 8
|
||||
|
||||
// b is the number of bits in the most significant byte of max.
|
||||
b := uint(max.BitLen() % 8)
|
||||
n = new(big.Int)
|
||||
n.Sub(max, n.SetUint64(1))
|
||||
// bitLen is the maximum bit length needed to encode a value < max.
|
||||
bitLen := n.BitLen()
|
||||
if bitLen == 0 {
|
||||
// the only valid result is 0
|
||||
return
|
||||
}
|
||||
// k is the maximum byte length needed to encode a value < max.
|
||||
k := (bitLen + 7) / 8
|
||||
// b is the number of bits in the most significant byte of max-1.
|
||||
b := uint(bitLen % 8)
|
||||
if b == 0 {
|
||||
b = 8
|
||||
}
|
||||
|
||||
bytes := make([]byte, k)
|
||||
n = new(big.Int)
|
||||
|
||||
for {
|
||||
_, err = io.ReadFull(rand, bytes)
|
||||
|
@ -5,7 +5,10 @@
|
||||
package rand_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
mathrand "math/rand"
|
||||
"testing"
|
||||
@ -45,6 +48,56 @@ func TestInt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type countingReader struct {
|
||||
r io.Reader
|
||||
n int
|
||||
}
|
||||
|
||||
func (r *countingReader) Read(p []byte) (n int, err error) {
|
||||
n, err = r.r.Read(p)
|
||||
r.n += n
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Test that Int reads only the necessary number of bytes from the reader for
|
||||
// max at each bit length
|
||||
func TestIntReads(t *testing.T) {
|
||||
for i := 0; i < 32; i++ {
|
||||
max := int64(1 << uint64(i))
|
||||
t.Run(fmt.Sprintf("max=%d", max), func(t *testing.T) {
|
||||
reader := &countingReader{r: rand.Reader}
|
||||
|
||||
_, err := rand.Int(reader, big.NewInt(max))
|
||||
if err != nil {
|
||||
t.Fatalf("Can't generate random value: %d, %v", max, err)
|
||||
}
|
||||
expected := (i + 7) / 8
|
||||
if reader.n != expected {
|
||||
t.Errorf("Int(reader, %d) should read %d bytes, but it read: %d", max, expected, reader.n)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test that Int does not mask out valid return values
|
||||
func TestIntMask(t *testing.T) {
|
||||
for max := 1; max <= 256; max++ {
|
||||
t.Run(fmt.Sprintf("max=%d", max), func(t *testing.T) {
|
||||
for i := 0; i < max; i++ {
|
||||
var b bytes.Buffer
|
||||
b.WriteByte(byte(i))
|
||||
n, err := rand.Int(&b, big.NewInt(int64(max)))
|
||||
if err != nil {
|
||||
t.Fatalf("Can't generate random value: %d, %v", max, err)
|
||||
}
|
||||
if n.Int64() != int64(i) {
|
||||
t.Errorf("Int(reader, %d) should have returned value of %d, but it returned: %v", max, i, n)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testIntPanics(t *testing.T, b *big.Int) {
|
||||
defer func() {
|
||||
if err := recover(); err == nil {
|
||||
|
Loading…
Reference in New Issue
Block a user