From 696bf79350b5cb0e977def1fc98ba6d6c8bd829f Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Fri, 20 Jan 2012 13:51:49 -0800 Subject: [PATCH] bytes.Buffer: turn buffer size overflows into errors Fixes #2743. R=golang-dev, gri, r CC=golang-dev https://golang.org/cl/5556072 --- src/pkg/bytes/buffer.go | 40 ++++++++++++++++++++++++++++++++---- src/pkg/bytes/buffer_test.go | 16 +++++++++++++++ 2 files changed, 52 insertions(+), 4 deletions(-) diff --git a/src/pkg/bytes/buffer.go b/src/pkg/bytes/buffer.go index 77757af1d8..9d58326a4f 100644 --- a/src/pkg/bytes/buffer.go +++ b/src/pkg/bytes/buffer.go @@ -33,6 +33,9 @@ const ( opRead // Any other read operation. ) +// ErrTooLarge is returned if there is too much data to fit in a buffer. +var ErrTooLarge = errors.New("bytes.Buffer: too large") + // Bytes returns a slice of the contents of the unread portion of the buffer; // len(b.Bytes()) == b.Len(). If the caller changes the contents of the // returned slice, the contents of the buffer will change provided there @@ -68,8 +71,10 @@ func (b *Buffer) Truncate(n int) { // b.Reset() is the same as b.Truncate(0). func (b *Buffer) Reset() { b.Truncate(0) } -// Grow buffer to guarantee space for n more bytes. -// Return index where bytes should be written. +// grow grows the buffer to guarantee space for n more bytes. +// It returns the index where bytes should be written. +// If the buffer can't grow, it returns -1, which will +// become ErrTooLarge in the caller. func (b *Buffer) grow(n int) int { m := b.Len() // If buffer is empty, reset to recover space. @@ -82,7 +87,10 @@ func (b *Buffer) grow(n int) int { buf = b.bootstrap[0:] } else { // not enough space anywhere - buf = make([]byte, 2*cap(b.buf)+n) + buf = makeSlice(2*cap(b.buf) + n) + if buf == nil { + return -1 + } copy(buf, b.buf[b.off:]) } b.buf = buf @@ -97,6 +105,9 @@ func (b *Buffer) grow(n int) int { func (b *Buffer) Write(p []byte) (n int, err error) { b.lastRead = opInvalid m := b.grow(len(p)) + if m < 0 { + return 0, ErrTooLarge + } return copy(b.buf[m:], p), nil } @@ -105,6 +116,9 @@ func (b *Buffer) Write(p []byte) (n int, err error) { func (b *Buffer) WriteString(s string) (n int, err error) { b.lastRead = opInvalid m := b.grow(len(s)) + if m < 0 { + return 0, ErrTooLarge + } return copy(b.buf[m:], s), nil } @@ -133,7 +147,10 @@ func (b *Buffer) ReadFrom(r io.Reader) (n int64, err error) { newBuf = b.buf[0 : len(b.buf)-b.off] } else { // not enough space at end; put space on end - newBuf = make([]byte, len(b.buf)-b.off, 2*(cap(b.buf)-b.off)+MinRead) + newBuf = makeSlice(2*(cap(b.buf)-b.off) + MinRead)[:len(b.buf)-b.off] + if newBuf == nil { + return n, ErrTooLarge + } } copy(newBuf, b.buf[b.off:]) b.buf = newBuf @@ -152,6 +169,18 @@ func (b *Buffer) ReadFrom(r io.Reader) (n int64, err error) { return n, nil // err is EOF, so return nil explicitly } +// makeSlice allocates a slice of size n, returning nil if the slice cannot be allocated. +func makeSlice(n int) []byte { + if n < 0 { + return nil + } + // Catch out of memory panics. + defer func() { + recover() + }() + return make([]byte, n) +} + // WriteTo writes data to w until the buffer is drained or an error // occurs. The return value n is the number of bytes written; it always // fits into an int, but it is int64 to match the io.WriterTo interface. @@ -179,6 +208,9 @@ func (b *Buffer) WriteTo(w io.Writer) (n int64, err error) { func (b *Buffer) WriteByte(c byte) error { b.lastRead = opInvalid m := b.grow(1) + if m < 0 { + return ErrTooLarge + } b.buf[m] = c return nil } diff --git a/src/pkg/bytes/buffer_test.go b/src/pkg/bytes/buffer_test.go index d0af11f104..a36d1d010a 100644 --- a/src/pkg/bytes/buffer_test.go +++ b/src/pkg/bytes/buffer_test.go @@ -386,3 +386,19 @@ func TestReadEmptyAtEOF(t *testing.T) { t.Errorf("wrong count; got %d want 0", n) } } + +func TestHuge(t *testing.T) { + // About to use tons of memory, so avoid for simple installation testing. + if testing.Short() { + return + } + b := new(Buffer) + big := make([]byte, 500e6) + for i := 0; i < 1000; i++ { + if _, err := b.Write(big); err != nil { + // Got error as expected. Stop + return + } + } + t.Error("error expected") +}