1
0
mirror of https://github.com/golang/go synced 2024-11-14 20:00:31 -07:00

crypto/rand: prevent Read argument from escaping to heap

Mateusz had this idea before me in CL 578516, but it got much easier
after the recent cleanup.

It's unfortunate we lose the test coverage of batched, but the package
is significantly simpler than when we introduced it, so it should be
easier to review that everything does what it's supposed to do.

Fixes #66779

Co-authored-by: Mateusz Poliwczak <mpoliwczak34@gmail.com>
Change-Id: Id35f1172e678fec184efb0efae3631afac8121d0
Reviewed-on: https://go-review.googlesource.com/c/go/+/602498
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
This commit is contained in:
Filippo Valsorda 2024-08-01 19:59:07 +02:00
parent c050d42e1a
commit 534d6a1a9c
6 changed files with 70 additions and 96 deletions

View File

@ -70,28 +70,20 @@ func fatal(string)
// If [Reader] is set to a non-default value, Read calls [io.ReadFull] on
// [Reader] and crashes the program irrecoverably if an error is returned.
func Read(b []byte) (n int, err error) {
_, err = io.ReadFull(Reader, b)
// We don't want b to escape to the heap, but escape analysis can't see
// through a potentially overridden Reader, so we special-case the default
// case which we can keep non-escaping, and in the general case we read into
// a heap buffer and copy from it.
if r, ok := Reader.(*reader); ok {
_, err = r.Read(b)
} else {
bb := make([]byte, len(b))
_, err = io.ReadFull(Reader, bb)
copy(b, bb)
}
if err != nil {
fatal("crypto/rand: failed to read random data (see https://go.dev/issue/66821): " + err.Error())
panic("unreachable") // To be sure.
}
return len(b), nil
}
// batched returns a function that calls f to populate a []byte by chunking it
// into subslices of, at most, readMax bytes.
func batched(f func([]byte) error, readMax int) func([]byte) error {
return func(out []byte) error {
for len(out) > 0 {
read := len(out)
if read > readMax {
read = readMax
}
if err := f(out[:read]); err != nil {
return err
}
out = out[read:]
}
return nil
}
}

View File

@ -1,75 +0,0 @@
// Copyright 2014 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build unix
package rand
import (
"bytes"
"errors"
prand "math/rand"
"testing"
)
func TestBatched(t *testing.T) {
fillBatched := batched(func(p []byte) error {
for i := range p {
p[i] = byte(i)
}
return nil
}, 5)
p := make([]byte, 13)
if err := fillBatched(p); err != nil {
t.Fatalf("batched function returned error: %s", err)
}
expected := []byte{0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2}
if !bytes.Equal(expected, p) {
t.Errorf("incorrect batch result: got %x, want %x", p, expected)
}
}
func TestBatchedBuffering(t *testing.T) {
backingStore := make([]byte, 1<<23)
prand.Read(backingStore)
backingMarker := backingStore[:]
output := make([]byte, len(backingStore))
outputMarker := output[:]
fillBatched := batched(func(p []byte) error {
n := copy(p, backingMarker)
backingMarker = backingMarker[n:]
return nil
}, 731)
for len(outputMarker) > 0 {
max := 9200
if max > len(outputMarker) {
max = len(outputMarker)
}
howMuch := prand.Intn(max + 1)
if err := fillBatched(outputMarker[:howMuch]); err != nil {
t.Fatalf("batched function returned error: %s", err)
}
outputMarker = outputMarker[howMuch:]
}
if !bytes.Equal(backingStore, output) {
t.Error("incorrect batch result")
}
}
func TestBatchedError(t *testing.T) {
b := batched(func(p []byte) error { return errors.New("failure") }, 5)
if b(make([]byte, 13)) == nil {
t.Fatal("batched function should have returned an error")
}
}
func TestBatchedEmpty(t *testing.T) {
b := batched(func(p []byte) error { return errors.New("failure") }, 5)
if b(make([]byte, 0)) != nil {
t.Fatal("empty slice should always return successful")
}
}

View File

@ -8,5 +8,17 @@ package rand
import "internal/syscall/unix"
func read(b []byte) error {
for len(b) > 0 {
size := len(b)
if size > 256 {
size = 256
}
// getentropy(2) returns a maximum of 256 bytes per call.
var read = batched(unix.GetEntropy, 256)
if err := unix.GetEntropy(b[:size]); err != nil {
return err
}
b = b[size:]
}
return nil
}

View File

@ -24,3 +24,21 @@ func getRandom(b []byte) error {
js.CopyBytesToGo(b, a)
return nil
}
// batched returns a function that calls f to populate a []byte by chunking it
// into subslices of, at most, readMax bytes.
func batched(f func([]byte) error, readMax int) func([]byte) error {
return func(out []byte) error {
for len(out) > 0 {
read := len(out)
if read > readMax {
read = readMax
}
if err := f(out[:read]); err != nil {
return err
}
out = out[read:]
}
return nil
}
}

View File

@ -7,8 +7,10 @@ package rand_test
import (
"bytes"
"compress/flate"
"crypto/internal/boring"
. "crypto/rand"
"io"
"runtime"
"sync"
"testing"
)
@ -121,6 +123,30 @@ func TestConcurrentRead(t *testing.T) {
wg.Wait()
}
var sink byte
func TestAllocations(t *testing.T) {
if boring.Enabled {
// Might be fixable with https://go.dev/issue/56378.
t.Skip("boringcrypto allocates")
}
if runtime.GOOS == "aix" {
t.Skip("/dev/urandom read path allocates")
}
if runtime.GOOS == "js" {
t.Skip("syscall/js allocates")
}
n := int(testing.AllocsPerRun(10, func() {
buf := make([]byte, 32)
Read(buf)
sink ^= buf[0]
}))
if n > 0 {
t.Errorf("allocs = %d, want 0", n)
}
}
func BenchmarkRead(b *testing.B) {
b.Run("4", func(b *testing.B) {
benchmarkRead(b, 4)

View File

@ -14,4 +14,5 @@ func GetEntropy(p []byte) error {
}
//go:linkname getentropy syscall.getentropy
//go:noescape
func getentropy(p []byte) error