From bb1f4416180511231de6d17a1f2f55c82aafc863 Mon Sep 17 00:00:00 2001 From: Roland Shoemaker Date: Mon, 25 Apr 2022 19:02:35 -0700 Subject: [PATCH] crypto/rand: properly handle large Read on windows Use the batched reader to chunk large Read calls on windows to a max of 1 << 31 - 1 bytes. This prevents an infinite loop when trying to read more than 1 << 32 -1 bytes, due to how RtlGenRandom works. This change moves the batched function from rand_unix.go to rand.go, since it is now needed for both windows and unix implementations. Fixes #52561 Change-Id: Id98fc4b1427e5cb2132762a445b2aed646a37473 Reviewed-on: https://go-review.googlesource.com/c/go/+/402257 Run-TryBot: Roland Shoemaker Reviewed-by: Filippo Valsorda Reviewed-by: Filippo Valsorda TryBot-Result: Gopher Robot --- src/crypto/rand/rand.go | 18 ++++++++++++++++++ src/crypto/rand/rand_batched_test.go | 12 ++++++------ src/crypto/rand/rand_unix.go | 22 ++-------------------- src/crypto/rand/rand_windows.go | 18 ++++++------------ 4 files changed, 32 insertions(+), 38 deletions(-) diff --git a/src/crypto/rand/rand.go b/src/crypto/rand/rand.go index b6248a4438..af85b966df 100644 --- a/src/crypto/rand/rand.go +++ b/src/crypto/rand/rand.go @@ -24,3 +24,21 @@ var Reader io.Reader func Read(b []byte) (n int, err error) { return io.ReadFull(Reader, b) } + +// batched returns a function that calls f to populate a []byte by chunking it +// into subslices of, at most, readMax bytes. +func batched(f func([]byte) error, readMax int) func([]byte) error { + return func(out []byte) error { + for len(out) > 0 { + read := len(out) + if read > readMax { + read = readMax + } + if err := f(out[:read]); err != nil { + return err + } + out = out[read:] + } + return nil + } +} diff --git a/src/crypto/rand/rand_batched_test.go b/src/crypto/rand/rand_batched_test.go index dfb9517d5e..89953776a8 100644 --- a/src/crypto/rand/rand_batched_test.go +++ b/src/crypto/rand/rand_batched_test.go @@ -23,8 +23,8 @@ func TestBatched(t *testing.T) { }, 5) p := make([]byte, 13) - if !fillBatched(p) { - t.Fatal("batched function returned false") + if err := fillBatched(p); err != nil { + t.Fatalf("batched function returned error: %s", err) } expected := []byte{0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2} if !bytes.Equal(expected, p) { @@ -55,8 +55,8 @@ func TestBatchedBuffering(t *testing.T) { max = len(outputMarker) } howMuch := prand.Intn(max + 1) - if !fillBatched(outputMarker[:howMuch]) { - t.Fatal("batched function returned false") + if err := fillBatched(outputMarker[:howMuch]); err != nil { + t.Fatalf("batched function returned error: %s", err) } outputMarker = outputMarker[howMuch:] } @@ -67,14 +67,14 @@ func TestBatchedBuffering(t *testing.T) { func TestBatchedError(t *testing.T) { b := batched(func(p []byte) error { return errors.New("failure") }, 5) - if b(make([]byte, 13)) { + if b(make([]byte, 13)) == nil { t.Fatal("batched function should have returned an error") } } func TestBatchedEmpty(t *testing.T) { b := batched(func(p []byte) error { return errors.New("failure") }, 5) - if !b(make([]byte, 0)) { + if b(make([]byte, 0)) != nil { t.Fatal("empty slice should always return successful") } } diff --git a/src/crypto/rand/rand_unix.go b/src/crypto/rand/rand_unix.go index 87ba9e3af7..64b865289d 100644 --- a/src/crypto/rand/rand_unix.go +++ b/src/crypto/rand/rand_unix.go @@ -40,25 +40,7 @@ type reader struct { // altGetRandom if non-nil specifies an OS-specific function to get // urandom-style randomness. -var altGetRandom func([]byte) (ok bool) - -// batched returns a function that calls f to populate a []byte by chunking it -// into subslices of, at most, readMax bytes. -func batched(f func([]byte) error, readMax int) func([]byte) bool { - return func(out []byte) bool { - for len(out) > 0 { - read := len(out) - if read > readMax { - read = readMax - } - if f(out[:read]) != nil { - return false - } - out = out[read:] - } - return true - } -} +var altGetRandom func([]byte) (err error) func warnBlocked() { println("crypto/rand: blocked for 60 seconds waiting to read random data from the kernel") @@ -72,7 +54,7 @@ func (r *reader) Read(b []byte) (n int, err error) { t := time.AfterFunc(time.Minute, warnBlocked) defer t.Stop() } - if altGetRandom != nil && altGetRandom(b) { + if altGetRandom != nil && altGetRandom(b) == nil { return len(b), nil } if atomic.LoadUint32(&r.used) != 2 { diff --git a/src/crypto/rand/rand_windows.go b/src/crypto/rand/rand_windows.go index 7379f1489a..6c0655c72b 100644 --- a/src/crypto/rand/rand_windows.go +++ b/src/crypto/rand/rand_windows.go @@ -9,7 +9,6 @@ package rand import ( "internal/syscall/windows" - "os" ) func init() { Reader = &rngReader{} } @@ -17,16 +16,11 @@ func init() { Reader = &rngReader{} } type rngReader struct{} func (r *rngReader) Read(b []byte) (n int, err error) { - // RtlGenRandom only accepts 2**32-1 bytes at a time, so truncate. - inputLen := uint32(len(b)) - - if inputLen == 0 { - return 0, nil + // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at + // most 1<<31-1 bytes at a time so that this works the same on 32-bit + // and 64-bit systems. + if err := batched(windows.RtlGenRandom, 1<<31-1)(b); err != nil { + return 0, err } - - err = windows.RtlGenRandom(b) - if err != nil { - return 0, os.NewSyscallError("RtlGenRandom", err) - } - return int(inputLen), nil + return len(b), nil }