mirror of
https://github.com/golang/go
synced 2024-11-11 23:40:22 -07:00
bufio: fix potential endless loop in ReadByte
Also: Simplify ReadSlice implementation and ensure that it doesn't call fill() with a full buffer (this caused a failure in net/textproto TestLargeReadMIMEHeader because fill() wasn't able to read more data). Fixes #7745. LGTM=bradfitz R=r, bradfitz CC=golang-codereviews https://golang.org/cl/86590043
This commit is contained in:
parent
08d8eca968
commit
7b6bc3ebb3
@ -88,15 +88,26 @@ func (b *Reader) fill() {
|
|||||||
b.r = 0
|
b.r = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read new data.
|
if b.w >= len(b.buf) {
|
||||||
n, err := b.rd.Read(b.buf[b.w:])
|
panic("bufio: tried to fill full buffer")
|
||||||
if n < 0 {
|
|
||||||
panic(errNegativeRead)
|
|
||||||
}
|
}
|
||||||
b.w += n
|
|
||||||
if err != nil {
|
// Read new data: try a limited number of times.
|
||||||
b.err = err
|
for i := maxConsecutiveEmptyReads; i > 0; i-- {
|
||||||
|
n, err := b.rd.Read(b.buf[b.w:])
|
||||||
|
if n < 0 {
|
||||||
|
panic(errNegativeRead)
|
||||||
|
}
|
||||||
|
b.w += n
|
||||||
|
if err != nil {
|
||||||
|
b.err = err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if n > 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
b.err = io.ErrNoProgress
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Reader) readErr() error {
|
func (b *Reader) readErr() error {
|
||||||
@ -116,8 +127,9 @@ func (b *Reader) Peek(n int) ([]byte, error) {
|
|||||||
if n > len(b.buf) {
|
if n > len(b.buf) {
|
||||||
return nil, ErrBufferFull
|
return nil, ErrBufferFull
|
||||||
}
|
}
|
||||||
|
// 0 <= n <= len(b.buf)
|
||||||
for b.w-b.r < n && b.err == nil {
|
for b.w-b.r < n && b.err == nil {
|
||||||
b.fill()
|
b.fill() // b.w-b.r < len(b.buf) => buffer is not full
|
||||||
}
|
}
|
||||||
m := b.w - b.r
|
m := b.w - b.r
|
||||||
if m > n {
|
if m > n {
|
||||||
@ -143,7 +155,7 @@ func (b *Reader) Read(p []byte) (n int, err error) {
|
|||||||
if n == 0 {
|
if n == 0 {
|
||||||
return 0, b.readErr()
|
return 0, b.readErr()
|
||||||
}
|
}
|
||||||
if b.w == b.r {
|
if b.r == b.w {
|
||||||
if b.err != nil {
|
if b.err != nil {
|
||||||
return 0, b.readErr()
|
return 0, b.readErr()
|
||||||
}
|
}
|
||||||
@ -151,13 +163,16 @@ func (b *Reader) Read(p []byte) (n int, err error) {
|
|||||||
// Large read, empty buffer.
|
// Large read, empty buffer.
|
||||||
// Read directly into p to avoid copy.
|
// Read directly into p to avoid copy.
|
||||||
n, b.err = b.rd.Read(p)
|
n, b.err = b.rd.Read(p)
|
||||||
|
if n < 0 {
|
||||||
|
panic(errNegativeRead)
|
||||||
|
}
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
b.lastByte = int(p[n-1])
|
b.lastByte = int(p[n-1])
|
||||||
b.lastRuneSize = -1
|
b.lastRuneSize = -1
|
||||||
}
|
}
|
||||||
return n, b.readErr()
|
return n, b.readErr()
|
||||||
}
|
}
|
||||||
b.fill()
|
b.fill() // buffer is empty
|
||||||
if b.w == b.r {
|
if b.w == b.r {
|
||||||
return 0, b.readErr()
|
return 0, b.readErr()
|
||||||
}
|
}
|
||||||
@ -181,7 +196,7 @@ func (b *Reader) ReadByte() (c byte, err error) {
|
|||||||
if b.err != nil {
|
if b.err != nil {
|
||||||
return 0, b.readErr()
|
return 0, b.readErr()
|
||||||
}
|
}
|
||||||
b.fill()
|
b.fill() // buffer is empty
|
||||||
}
|
}
|
||||||
c = b.buf[b.r]
|
c = b.buf[b.r]
|
||||||
b.r++
|
b.r++
|
||||||
@ -211,8 +226,8 @@ func (b *Reader) UnreadByte() error {
|
|||||||
// rune and its size in bytes. If the encoded rune is invalid, it consumes one byte
|
// rune and its size in bytes. If the encoded rune is invalid, it consumes one byte
|
||||||
// and returns unicode.ReplacementChar (U+FFFD) with a size of 1.
|
// and returns unicode.ReplacementChar (U+FFFD) with a size of 1.
|
||||||
func (b *Reader) ReadRune() (r rune, size int, err error) {
|
func (b *Reader) ReadRune() (r rune, size int, err error) {
|
||||||
for b.r+utf8.UTFMax > b.w && !utf8.FullRune(b.buf[b.r:b.w]) && b.err == nil {
|
for b.r+utf8.UTFMax > b.w && !utf8.FullRune(b.buf[b.r:b.w]) && b.err == nil && b.w-b.r < len(b.buf) {
|
||||||
b.fill()
|
b.fill() // b.w-b.r < len(buf) => buffer is not full
|
||||||
}
|
}
|
||||||
b.lastRuneSize = -1
|
b.lastRuneSize = -1
|
||||||
if b.r == b.w {
|
if b.r == b.w {
|
||||||
@ -256,36 +271,28 @@ func (b *Reader) Buffered() int { return b.w - b.r }
|
|||||||
// ReadBytes or ReadString instead.
|
// ReadBytes or ReadString instead.
|
||||||
// ReadSlice returns err != nil if and only if line does not end in delim.
|
// ReadSlice returns err != nil if and only if line does not end in delim.
|
||||||
func (b *Reader) ReadSlice(delim byte) (line []byte, err error) {
|
func (b *Reader) ReadSlice(delim byte) (line []byte, err error) {
|
||||||
// Look in buffer.
|
|
||||||
if i := bytes.IndexByte(b.buf[b.r:b.w], delim); i >= 0 {
|
|
||||||
line1 := b.buf[b.r : b.r+i+1]
|
|
||||||
b.r += i + 1
|
|
||||||
return line1, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read more into buffer, until buffer fills or we find delim.
|
|
||||||
for {
|
for {
|
||||||
|
// Search buffer.
|
||||||
|
if i := bytes.IndexByte(b.buf[b.r:b.w], delim); i >= 0 {
|
||||||
|
line := b.buf[b.r : b.r+i+1]
|
||||||
|
b.r += i + 1
|
||||||
|
return line, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pending error?
|
||||||
if b.err != nil {
|
if b.err != nil {
|
||||||
line := b.buf[b.r:b.w]
|
line := b.buf[b.r:b.w]
|
||||||
b.r = b.w
|
b.r = b.w
|
||||||
return line, b.readErr()
|
return line, b.readErr()
|
||||||
}
|
}
|
||||||
|
|
||||||
n := b.Buffered()
|
// Buffer full?
|
||||||
b.fill()
|
if n := b.Buffered(); n >= len(b.buf) {
|
||||||
|
|
||||||
// Search new part of buffer
|
|
||||||
if i := bytes.IndexByte(b.buf[n:b.w], delim); i >= 0 {
|
|
||||||
line := b.buf[0 : n+i+1]
|
|
||||||
b.r = n + i + 1
|
|
||||||
return line, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Buffer is full?
|
|
||||||
if b.Buffered() >= len(b.buf) {
|
|
||||||
b.r = b.w
|
b.r = b.w
|
||||||
return b.buf, ErrBufferFull
|
return b.buf, ErrBufferFull
|
||||||
}
|
}
|
||||||
|
|
||||||
|
b.fill() // buffer is not full
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -417,12 +424,18 @@ func (b *Reader) WriteTo(w io.Writer) (n int64, err error) {
|
|||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for b.fill(); b.r < b.w; b.fill() {
|
if b.w-b.r < len(b.buf) {
|
||||||
|
b.fill() // buffer not full
|
||||||
|
}
|
||||||
|
|
||||||
|
for b.r < b.w {
|
||||||
|
// b.r < b.w => buffer is not empty
|
||||||
m, err := b.writeBuf(w)
|
m, err := b.writeBuf(w)
|
||||||
n += m
|
n += m
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
b.fill() // buffer is empty
|
||||||
}
|
}
|
||||||
|
|
||||||
if b.err == io.EOF {
|
if b.err == io.EOF {
|
||||||
@ -435,6 +448,9 @@ func (b *Reader) WriteTo(w io.Writer) (n int64, err error) {
|
|||||||
// writeBuf writes the Reader's buffer to the writer.
|
// writeBuf writes the Reader's buffer to the writer.
|
||||||
func (b *Reader) writeBuf(w io.Writer) (int64, error) {
|
func (b *Reader) writeBuf(w io.Writer) (int64, error) {
|
||||||
n, err := w.Write(b.buf[b.r:b.w])
|
n, err := w.Write(b.buf[b.r:b.w])
|
||||||
|
if n < b.r-b.w {
|
||||||
|
panic(errors.New("bufio: writer did not write all data"))
|
||||||
|
}
|
||||||
b.r += n
|
b.r += n
|
||||||
return int64(n), err
|
return int64(n), err
|
||||||
}
|
}
|
||||||
|
@ -14,6 +14,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"testing/iotest"
|
"testing/iotest"
|
||||||
|
"time"
|
||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -174,6 +175,34 @@ func TestReader(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type zeroReader struct{}
|
||||||
|
|
||||||
|
func (zeroReader) Read(p []byte) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestZeroReader(t *testing.T) {
|
||||||
|
var z zeroReader
|
||||||
|
r := NewReader(z)
|
||||||
|
|
||||||
|
c := make(chan error)
|
||||||
|
go func() {
|
||||||
|
_, err := r.ReadByte()
|
||||||
|
c <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err := <-c:
|
||||||
|
if err == nil {
|
||||||
|
t.Error("error expected")
|
||||||
|
} else if err != io.ErrNoProgress {
|
||||||
|
t.Error("unexpected error:", err)
|
||||||
|
}
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Error("test timed out (endless loop in ReadByte?)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// A StringReader delivers its data one string segment at a time via Read.
|
// A StringReader delivers its data one string segment at a time via Read.
|
||||||
type StringReader struct {
|
type StringReader struct {
|
||||||
data []string
|
data []string
|
||||||
|
Loading…
Reference in New Issue
Block a user