1
0
mirror of https://github.com/golang/go synced 2024-11-11 23:40:22 -07:00

bufio: fix potential endless loop in ReadByte

Also: Simplify ReadSlice implementation and
ensure that it doesn't call fill() with a full
buffer (this caused a failure in net/textproto
TestLargeReadMIMEHeader because fill() wasn't able
to read more data).

Fixes #7745.

LGTM=bradfitz
R=r, bradfitz
CC=golang-codereviews
https://golang.org/cl/86590043
This commit is contained in:
Robert Griesemer 2014-04-10 21:46:00 -07:00
parent 08d8eca968
commit 7b6bc3ebb3
2 changed files with 79 additions and 34 deletions

View File

@ -88,15 +88,26 @@ func (b *Reader) fill() {
b.r = 0 b.r = 0
} }
// Read new data. if b.w >= len(b.buf) {
n, err := b.rd.Read(b.buf[b.w:]) panic("bufio: tried to fill full buffer")
if n < 0 {
panic(errNegativeRead)
} }
b.w += n
if err != nil { // Read new data: try a limited number of times.
b.err = err for i := maxConsecutiveEmptyReads; i > 0; i-- {
n, err := b.rd.Read(b.buf[b.w:])
if n < 0 {
panic(errNegativeRead)
}
b.w += n
if err != nil {
b.err = err
return
}
if n > 0 {
return
}
} }
b.err = io.ErrNoProgress
} }
func (b *Reader) readErr() error { func (b *Reader) readErr() error {
@ -116,8 +127,9 @@ func (b *Reader) Peek(n int) ([]byte, error) {
if n > len(b.buf) { if n > len(b.buf) {
return nil, ErrBufferFull return nil, ErrBufferFull
} }
// 0 <= n <= len(b.buf)
for b.w-b.r < n && b.err == nil { for b.w-b.r < n && b.err == nil {
b.fill() b.fill() // b.w-b.r < len(b.buf) => buffer is not full
} }
m := b.w - b.r m := b.w - b.r
if m > n { if m > n {
@ -143,7 +155,7 @@ func (b *Reader) Read(p []byte) (n int, err error) {
if n == 0 { if n == 0 {
return 0, b.readErr() return 0, b.readErr()
} }
if b.w == b.r { if b.r == b.w {
if b.err != nil { if b.err != nil {
return 0, b.readErr() return 0, b.readErr()
} }
@ -151,13 +163,16 @@ func (b *Reader) Read(p []byte) (n int, err error) {
// Large read, empty buffer. // Large read, empty buffer.
// Read directly into p to avoid copy. // Read directly into p to avoid copy.
n, b.err = b.rd.Read(p) n, b.err = b.rd.Read(p)
if n < 0 {
panic(errNegativeRead)
}
if n > 0 { if n > 0 {
b.lastByte = int(p[n-1]) b.lastByte = int(p[n-1])
b.lastRuneSize = -1 b.lastRuneSize = -1
} }
return n, b.readErr() return n, b.readErr()
} }
b.fill() b.fill() // buffer is empty
if b.w == b.r { if b.w == b.r {
return 0, b.readErr() return 0, b.readErr()
} }
@ -181,7 +196,7 @@ func (b *Reader) ReadByte() (c byte, err error) {
if b.err != nil { if b.err != nil {
return 0, b.readErr() return 0, b.readErr()
} }
b.fill() b.fill() // buffer is empty
} }
c = b.buf[b.r] c = b.buf[b.r]
b.r++ b.r++
@ -211,8 +226,8 @@ func (b *Reader) UnreadByte() error {
// rune and its size in bytes. If the encoded rune is invalid, it consumes one byte // rune and its size in bytes. If the encoded rune is invalid, it consumes one byte
// and returns unicode.ReplacementChar (U+FFFD) with a size of 1. // and returns unicode.ReplacementChar (U+FFFD) with a size of 1.
func (b *Reader) ReadRune() (r rune, size int, err error) { func (b *Reader) ReadRune() (r rune, size int, err error) {
for b.r+utf8.UTFMax > b.w && !utf8.FullRune(b.buf[b.r:b.w]) && b.err == nil { for b.r+utf8.UTFMax > b.w && !utf8.FullRune(b.buf[b.r:b.w]) && b.err == nil && b.w-b.r < len(b.buf) {
b.fill() b.fill() // b.w-b.r < len(buf) => buffer is not full
} }
b.lastRuneSize = -1 b.lastRuneSize = -1
if b.r == b.w { if b.r == b.w {
@ -256,36 +271,28 @@ func (b *Reader) Buffered() int { return b.w - b.r }
// ReadBytes or ReadString instead. // ReadBytes or ReadString instead.
// ReadSlice returns err != nil if and only if line does not end in delim. // ReadSlice returns err != nil if and only if line does not end in delim.
func (b *Reader) ReadSlice(delim byte) (line []byte, err error) { func (b *Reader) ReadSlice(delim byte) (line []byte, err error) {
// Look in buffer.
if i := bytes.IndexByte(b.buf[b.r:b.w], delim); i >= 0 {
line1 := b.buf[b.r : b.r+i+1]
b.r += i + 1
return line1, nil
}
// Read more into buffer, until buffer fills or we find delim.
for { for {
// Search buffer.
if i := bytes.IndexByte(b.buf[b.r:b.w], delim); i >= 0 {
line := b.buf[b.r : b.r+i+1]
b.r += i + 1
return line, nil
}
// Pending error?
if b.err != nil { if b.err != nil {
line := b.buf[b.r:b.w] line := b.buf[b.r:b.w]
b.r = b.w b.r = b.w
return line, b.readErr() return line, b.readErr()
} }
n := b.Buffered() // Buffer full?
b.fill() if n := b.Buffered(); n >= len(b.buf) {
// Search new part of buffer
if i := bytes.IndexByte(b.buf[n:b.w], delim); i >= 0 {
line := b.buf[0 : n+i+1]
b.r = n + i + 1
return line, nil
}
// Buffer is full?
if b.Buffered() >= len(b.buf) {
b.r = b.w b.r = b.w
return b.buf, ErrBufferFull return b.buf, ErrBufferFull
} }
b.fill() // buffer is not full
} }
} }
@ -417,12 +424,18 @@ func (b *Reader) WriteTo(w io.Writer) (n int64, err error) {
return n, err return n, err
} }
for b.fill(); b.r < b.w; b.fill() { if b.w-b.r < len(b.buf) {
b.fill() // buffer not full
}
for b.r < b.w {
// b.r < b.w => buffer is not empty
m, err := b.writeBuf(w) m, err := b.writeBuf(w)
n += m n += m
if err != nil { if err != nil {
return n, err return n, err
} }
b.fill() // buffer is empty
} }
if b.err == io.EOF { if b.err == io.EOF {
@ -435,6 +448,9 @@ func (b *Reader) WriteTo(w io.Writer) (n int64, err error) {
// writeBuf writes the Reader's buffer to the writer. // writeBuf writes the Reader's buffer to the writer.
func (b *Reader) writeBuf(w io.Writer) (int64, error) { func (b *Reader) writeBuf(w io.Writer) (int64, error) {
n, err := w.Write(b.buf[b.r:b.w]) n, err := w.Write(b.buf[b.r:b.w])
if n < b.r-b.w {
panic(errors.New("bufio: writer did not write all data"))
}
b.r += n b.r += n
return int64(n), err return int64(n), err
} }

View File

@ -14,6 +14,7 @@ import (
"strings" "strings"
"testing" "testing"
"testing/iotest" "testing/iotest"
"time"
"unicode/utf8" "unicode/utf8"
) )
@ -174,6 +175,34 @@ func TestReader(t *testing.T) {
} }
} }
type zeroReader struct{}
func (zeroReader) Read(p []byte) (int, error) {
return 0, nil
}
func TestZeroReader(t *testing.T) {
var z zeroReader
r := NewReader(z)
c := make(chan error)
go func() {
_, err := r.ReadByte()
c <- err
}()
select {
case err := <-c:
if err == nil {
t.Error("error expected")
} else if err != io.ErrNoProgress {
t.Error("unexpected error:", err)
}
case <-time.After(time.Second):
t.Error("test timed out (endless loop in ReadByte?)")
}
}
// A StringReader delivers its data one string segment at a time via Read. // A StringReader delivers its data one string segment at a time via Read.
type StringReader struct { type StringReader struct {
data []string data []string