1
0
mirror of https://github.com/golang/go synced 2024-09-23 11:20:17 -06:00

crypto/rand: properly handle large Read on windows

Use the batched reader to chunk large Read calls on windows to a max of
1 << 31 - 1 bytes. This prevents an infinite loop when trying to read
more than 1 << 32 -1 bytes, due to how RtlGenRandom works.

This change moves the batched function from rand_unix.go to rand.go,
since it is now needed for both windows and unix implementations.

Fixes #52561

Change-Id: Id98fc4b1427e5cb2132762a445b2aed646a37473
Reviewed-on: https://go-review.googlesource.com/c/go/+/402257
Run-TryBot: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Filippo Valsorda <valsorda@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
Roland Shoemaker 2022-04-25 19:02:35 -07:00
parent 7c74b0db8a
commit bb1f441618
4 changed files with 32 additions and 38 deletions

View File

@ -24,3 +24,21 @@ var Reader io.Reader
func Read(b []byte) (n int, err error) { func Read(b []byte) (n int, err error) {
return io.ReadFull(Reader, b) return io.ReadFull(Reader, b)
} }
// batched returns a function that calls f to populate a []byte by chunking it
// into subslices of, at most, readMax bytes.
func batched(f func([]byte) error, readMax int) func([]byte) error {
return func(out []byte) error {
for len(out) > 0 {
read := len(out)
if read > readMax {
read = readMax
}
if err := f(out[:read]); err != nil {
return err
}
out = out[read:]
}
return nil
}
}

View File

@ -23,8 +23,8 @@ func TestBatched(t *testing.T) {
}, 5) }, 5)
p := make([]byte, 13) p := make([]byte, 13)
if !fillBatched(p) { if err := fillBatched(p); err != nil {
t.Fatal("batched function returned false") t.Fatalf("batched function returned error: %s", err)
} }
expected := []byte{0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2} expected := []byte{0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2}
if !bytes.Equal(expected, p) { if !bytes.Equal(expected, p) {
@ -55,8 +55,8 @@ func TestBatchedBuffering(t *testing.T) {
max = len(outputMarker) max = len(outputMarker)
} }
howMuch := prand.Intn(max + 1) howMuch := prand.Intn(max + 1)
if !fillBatched(outputMarker[:howMuch]) { if err := fillBatched(outputMarker[:howMuch]); err != nil {
t.Fatal("batched function returned false") t.Fatalf("batched function returned error: %s", err)
} }
outputMarker = outputMarker[howMuch:] outputMarker = outputMarker[howMuch:]
} }
@ -67,14 +67,14 @@ func TestBatchedBuffering(t *testing.T) {
func TestBatchedError(t *testing.T) { func TestBatchedError(t *testing.T) {
b := batched(func(p []byte) error { return errors.New("failure") }, 5) b := batched(func(p []byte) error { return errors.New("failure") }, 5)
if b(make([]byte, 13)) { if b(make([]byte, 13)) == nil {
t.Fatal("batched function should have returned an error") t.Fatal("batched function should have returned an error")
} }
} }
func TestBatchedEmpty(t *testing.T) { func TestBatchedEmpty(t *testing.T) {
b := batched(func(p []byte) error { return errors.New("failure") }, 5) b := batched(func(p []byte) error { return errors.New("failure") }, 5)
if !b(make([]byte, 0)) { if b(make([]byte, 0)) != nil {
t.Fatal("empty slice should always return successful") t.Fatal("empty slice should always return successful")
} }
} }

View File

@ -40,25 +40,7 @@ type reader struct {
// altGetRandom if non-nil specifies an OS-specific function to get // altGetRandom if non-nil specifies an OS-specific function to get
// urandom-style randomness. // urandom-style randomness.
var altGetRandom func([]byte) (ok bool) var altGetRandom func([]byte) (err error)
// batched returns a function that calls f to populate a []byte by chunking it
// into subslices of, at most, readMax bytes.
func batched(f func([]byte) error, readMax int) func([]byte) bool {
return func(out []byte) bool {
for len(out) > 0 {
read := len(out)
if read > readMax {
read = readMax
}
if f(out[:read]) != nil {
return false
}
out = out[read:]
}
return true
}
}
func warnBlocked() { func warnBlocked() {
println("crypto/rand: blocked for 60 seconds waiting to read random data from the kernel") println("crypto/rand: blocked for 60 seconds waiting to read random data from the kernel")
@ -72,7 +54,7 @@ func (r *reader) Read(b []byte) (n int, err error) {
t := time.AfterFunc(time.Minute, warnBlocked) t := time.AfterFunc(time.Minute, warnBlocked)
defer t.Stop() defer t.Stop()
} }
if altGetRandom != nil && altGetRandom(b) { if altGetRandom != nil && altGetRandom(b) == nil {
return len(b), nil return len(b), nil
} }
if atomic.LoadUint32(&r.used) != 2 { if atomic.LoadUint32(&r.used) != 2 {

View File

@ -9,7 +9,6 @@ package rand
import ( import (
"internal/syscall/windows" "internal/syscall/windows"
"os"
) )
func init() { Reader = &rngReader{} } func init() { Reader = &rngReader{} }
@ -17,16 +16,11 @@ func init() { Reader = &rngReader{} }
type rngReader struct{} type rngReader struct{}
func (r *rngReader) Read(b []byte) (n int, err error) { func (r *rngReader) Read(b []byte) (n int, err error) {
// RtlGenRandom only accepts 2**32-1 bytes at a time, so truncate. // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
inputLen := uint32(len(b)) // most 1<<31-1 bytes at a time so that this works the same on 32-bit
// and 64-bit systems.
if inputLen == 0 { if err := batched(windows.RtlGenRandom, 1<<31-1)(b); err != nil {
return 0, nil return 0, err
} }
return len(b), nil
err = windows.RtlGenRandom(b)
if err != nil {
return 0, os.NewSyscallError("RtlGenRandom", err)
}
return int(inputLen), nil
} }