mirror of
https://github.com/golang/go
synced 2024-11-23 07:40:04 -07:00
crypto/rand: properly handle large Read on windows
Use the batched reader to chunk large Read calls on windows to a max of 1 << 31 - 1 bytes. This prevents an infinite loop when trying to read more than 1 << 32 -1 bytes, due to how RtlGenRandom works. This change moves the batched function from rand_unix.go to rand.go, since it is now needed for both windows and unix implementations. Fixes #52561 Change-Id: Id98fc4b1427e5cb2132762a445b2aed646a37473 Reviewed-on: https://go-review.googlesource.com/c/go/+/402257 Run-TryBot: Roland Shoemaker <roland@golang.org> Reviewed-by: Filippo Valsorda <filippo@golang.org> Reviewed-by: Filippo Valsorda <valsorda@google.com> TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
parent
7c74b0db8a
commit
bb1f441618
@ -24,3 +24,21 @@ var Reader io.Reader
|
|||||||
func Read(b []byte) (n int, err error) {
|
func Read(b []byte) (n int, err error) {
|
||||||
return io.ReadFull(Reader, b)
|
return io.ReadFull(Reader, b)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// batched returns a function that calls f to populate a []byte by chunking it
|
||||||
|
// into subslices of, at most, readMax bytes.
|
||||||
|
func batched(f func([]byte) error, readMax int) func([]byte) error {
|
||||||
|
return func(out []byte) error {
|
||||||
|
for len(out) > 0 {
|
||||||
|
read := len(out)
|
||||||
|
if read > readMax {
|
||||||
|
read = readMax
|
||||||
|
}
|
||||||
|
if err := f(out[:read]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
out = out[read:]
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -23,8 +23,8 @@ func TestBatched(t *testing.T) {
|
|||||||
}, 5)
|
}, 5)
|
||||||
|
|
||||||
p := make([]byte, 13)
|
p := make([]byte, 13)
|
||||||
if !fillBatched(p) {
|
if err := fillBatched(p); err != nil {
|
||||||
t.Fatal("batched function returned false")
|
t.Fatalf("batched function returned error: %s", err)
|
||||||
}
|
}
|
||||||
expected := []byte{0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2}
|
expected := []byte{0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2}
|
||||||
if !bytes.Equal(expected, p) {
|
if !bytes.Equal(expected, p) {
|
||||||
@ -55,8 +55,8 @@ func TestBatchedBuffering(t *testing.T) {
|
|||||||
max = len(outputMarker)
|
max = len(outputMarker)
|
||||||
}
|
}
|
||||||
howMuch := prand.Intn(max + 1)
|
howMuch := prand.Intn(max + 1)
|
||||||
if !fillBatched(outputMarker[:howMuch]) {
|
if err := fillBatched(outputMarker[:howMuch]); err != nil {
|
||||||
t.Fatal("batched function returned false")
|
t.Fatalf("batched function returned error: %s", err)
|
||||||
}
|
}
|
||||||
outputMarker = outputMarker[howMuch:]
|
outputMarker = outputMarker[howMuch:]
|
||||||
}
|
}
|
||||||
@ -67,14 +67,14 @@ func TestBatchedBuffering(t *testing.T) {
|
|||||||
|
|
||||||
func TestBatchedError(t *testing.T) {
|
func TestBatchedError(t *testing.T) {
|
||||||
b := batched(func(p []byte) error { return errors.New("failure") }, 5)
|
b := batched(func(p []byte) error { return errors.New("failure") }, 5)
|
||||||
if b(make([]byte, 13)) {
|
if b(make([]byte, 13)) == nil {
|
||||||
t.Fatal("batched function should have returned an error")
|
t.Fatal("batched function should have returned an error")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestBatchedEmpty(t *testing.T) {
|
func TestBatchedEmpty(t *testing.T) {
|
||||||
b := batched(func(p []byte) error { return errors.New("failure") }, 5)
|
b := batched(func(p []byte) error { return errors.New("failure") }, 5)
|
||||||
if !b(make([]byte, 0)) {
|
if b(make([]byte, 0)) != nil {
|
||||||
t.Fatal("empty slice should always return successful")
|
t.Fatal("empty slice should always return successful")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -40,25 +40,7 @@ type reader struct {
|
|||||||
|
|
||||||
// altGetRandom if non-nil specifies an OS-specific function to get
|
// altGetRandom if non-nil specifies an OS-specific function to get
|
||||||
// urandom-style randomness.
|
// urandom-style randomness.
|
||||||
var altGetRandom func([]byte) (ok bool)
|
var altGetRandom func([]byte) (err error)
|
||||||
|
|
||||||
// batched returns a function that calls f to populate a []byte by chunking it
|
|
||||||
// into subslices of, at most, readMax bytes.
|
|
||||||
func batched(f func([]byte) error, readMax int) func([]byte) bool {
|
|
||||||
return func(out []byte) bool {
|
|
||||||
for len(out) > 0 {
|
|
||||||
read := len(out)
|
|
||||||
if read > readMax {
|
|
||||||
read = readMax
|
|
||||||
}
|
|
||||||
if f(out[:read]) != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
out = out[read:]
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func warnBlocked() {
|
func warnBlocked() {
|
||||||
println("crypto/rand: blocked for 60 seconds waiting to read random data from the kernel")
|
println("crypto/rand: blocked for 60 seconds waiting to read random data from the kernel")
|
||||||
@ -72,7 +54,7 @@ func (r *reader) Read(b []byte) (n int, err error) {
|
|||||||
t := time.AfterFunc(time.Minute, warnBlocked)
|
t := time.AfterFunc(time.Minute, warnBlocked)
|
||||||
defer t.Stop()
|
defer t.Stop()
|
||||||
}
|
}
|
||||||
if altGetRandom != nil && altGetRandom(b) {
|
if altGetRandom != nil && altGetRandom(b) == nil {
|
||||||
return len(b), nil
|
return len(b), nil
|
||||||
}
|
}
|
||||||
if atomic.LoadUint32(&r.used) != 2 {
|
if atomic.LoadUint32(&r.used) != 2 {
|
||||||
|
@ -9,7 +9,6 @@ package rand
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"internal/syscall/windows"
|
"internal/syscall/windows"
|
||||||
"os"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() { Reader = &rngReader{} }
|
func init() { Reader = &rngReader{} }
|
||||||
@ -17,16 +16,11 @@ func init() { Reader = &rngReader{} }
|
|||||||
type rngReader struct{}
|
type rngReader struct{}
|
||||||
|
|
||||||
func (r *rngReader) Read(b []byte) (n int, err error) {
|
func (r *rngReader) Read(b []byte) (n int, err error) {
|
||||||
// RtlGenRandom only accepts 2**32-1 bytes at a time, so truncate.
|
// RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
|
||||||
inputLen := uint32(len(b))
|
// most 1<<31-1 bytes at a time so that this works the same on 32-bit
|
||||||
|
// and 64-bit systems.
|
||||||
if inputLen == 0 {
|
if err := batched(windows.RtlGenRandom, 1<<31-1)(b); err != nil {
|
||||||
return 0, nil
|
return 0, err
|
||||||
}
|
}
|
||||||
|
return len(b), nil
|
||||||
err = windows.RtlGenRandom(b)
|
|
||||||
if err != nil {
|
|
||||||
return 0, os.NewSyscallError("RtlGenRandom", err)
|
|
||||||
}
|
|
||||||
return int(inputLen), nil
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user