mirror of
https://github.com/golang/go
synced 2024-09-24 05:10:13 -06:00
bufio: fix UnreadByte
Also: - fix error messages in tests - make tests more symmetric Fixes #7607. LGTM=r R=r CC=golang-codereviews https://golang.org/cl/86180043
This commit is contained in:
parent
a8787cd820
commit
8bd9242f7c
@ -177,7 +177,7 @@ func (b *Reader) Read(p []byte) (n int, err error) {
|
||||
// If no byte is available, returns an error.
|
||||
func (b *Reader) ReadByte() (c byte, err error) {
|
||||
b.lastRuneSize = -1
|
||||
for b.w == b.r {
|
||||
for b.r == b.w {
|
||||
if b.err != nil {
|
||||
return 0, b.readErr()
|
||||
}
|
||||
@ -191,19 +191,19 @@ func (b *Reader) ReadByte() (c byte, err error) {
|
||||
|
||||
// UnreadByte unreads the last byte. Only the most recently read byte can be unread.
|
||||
func (b *Reader) UnreadByte() error {
|
||||
b.lastRuneSize = -1
|
||||
if b.r == b.w && b.lastByte >= 0 {
|
||||
b.w = 1
|
||||
b.r = 0
|
||||
b.buf[0] = byte(b.lastByte)
|
||||
b.lastByte = -1
|
||||
return nil
|
||||
}
|
||||
if b.r <= 0 {
|
||||
if b.lastByte < 0 || b.r == 0 && b.w > 0 {
|
||||
return ErrInvalidUnreadByte
|
||||
}
|
||||
b.r--
|
||||
// b.r > 0 || b.w == 0
|
||||
if b.r > 0 {
|
||||
b.r--
|
||||
} else {
|
||||
// b.r == 0 && b.w == 0
|
||||
b.w = 1
|
||||
}
|
||||
b.buf[b.r] = byte(b.lastByte)
|
||||
b.lastByte = -1
|
||||
b.lastRuneSize = -1
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -233,7 +233,7 @@ func (b *Reader) ReadRune() (r rune, size int, err error) {
|
||||
// regard it is stricter than UnreadByte, which will unread the last byte
|
||||
// from any read operation.)
|
||||
func (b *Reader) UnreadRune() error {
|
||||
if b.lastRuneSize < 0 || b.r == 0 {
|
||||
if b.lastRuneSize < 0 || b.r < b.lastRuneSize {
|
||||
return ErrInvalidUnreadRune
|
||||
}
|
||||
b.r -= b.lastRuneSize
|
||||
|
@ -228,66 +228,94 @@ func TestReadRune(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUnreadRune(t *testing.T) {
|
||||
got := ""
|
||||
segments := []string{"Hello, world:", "日本語"}
|
||||
data := strings.Join(segments, "")
|
||||
r := NewReader(&StringReader{data: segments})
|
||||
got := ""
|
||||
want := strings.Join(segments, "")
|
||||
// Normal execution.
|
||||
for {
|
||||
r1, _, err := r.ReadRune()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
t.Error("unexpected EOF")
|
||||
t.Error("unexpected error on ReadRune:", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
got += string(r1)
|
||||
// Put it back and read it again
|
||||
// Put it back and read it again.
|
||||
if err = r.UnreadRune(); err != nil {
|
||||
t.Error("unexpected error on UnreadRune:", err)
|
||||
t.Fatal("unexpected error on UnreadRune:", err)
|
||||
}
|
||||
r2, _, err := r.ReadRune()
|
||||
if err != nil {
|
||||
t.Error("unexpected error reading after unreading:", err)
|
||||
t.Fatal("unexpected error reading after unreading:", err)
|
||||
}
|
||||
if r1 != r2 {
|
||||
t.Errorf("incorrect rune after unread: got %c wanted %c", r1, r2)
|
||||
t.Fatalf("incorrect rune after unread: got %c, want %c", r1, r2)
|
||||
}
|
||||
}
|
||||
if got != data {
|
||||
t.Errorf("want=%q got=%q", data, got)
|
||||
if got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnreadByte(t *testing.T) {
|
||||
want := "Hello, world"
|
||||
got := ""
|
||||
segments := []string{"Hello, ", "world"}
|
||||
r := NewReader(&StringReader{data: segments})
|
||||
got := ""
|
||||
want := strings.Join(segments, "")
|
||||
// Normal execution.
|
||||
for {
|
||||
b1, err := r.ReadByte()
|
||||
if err != nil {
|
||||
if err != io.EOF {
|
||||
t.Fatal("unexpected EOF")
|
||||
t.Error("unexpected error on ReadByte:", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
got += string(b1)
|
||||
// Put it back and read it again
|
||||
// Put it back and read it again.
|
||||
if err = r.UnreadByte(); err != nil {
|
||||
t.Fatalf("unexpected error on UnreadByte: %v", err)
|
||||
t.Fatal("unexpected error on UnreadByte:", err)
|
||||
}
|
||||
b2, err := r.ReadByte()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error reading after unreading: %v", err)
|
||||
t.Fatal("unexpected error reading after unreading:", err)
|
||||
}
|
||||
if b1 != b2 {
|
||||
t.Fatalf("incorrect byte after unread: got %c wanted %c", b1, b2)
|
||||
t.Fatalf("incorrect byte after unread: got %q, want %q", b1, b2)
|
||||
}
|
||||
}
|
||||
if got != want {
|
||||
t.Errorf("got=%q want=%q", got, want)
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnreadByteMultiple(t *testing.T) {
|
||||
segments := []string{"Hello, ", "world"}
|
||||
data := strings.Join(segments, "")
|
||||
for n := 0; n <= len(data); n++ {
|
||||
r := NewReader(&StringReader{data: segments})
|
||||
// Read n bytes.
|
||||
for i := 0; i < n; i++ {
|
||||
b, err := r.ReadByte()
|
||||
if err != nil {
|
||||
t.Fatalf("n = %d: unexpected error on ReadByte: %v", n, err)
|
||||
}
|
||||
if b != data[i] {
|
||||
t.Fatalf("n = %d: incorrect byte returned from ReadByte: got %q, want %q", n, b, data[i])
|
||||
}
|
||||
}
|
||||
// Unread one byte if there is one.
|
||||
if n > 0 {
|
||||
if err := r.UnreadByte(); err != nil {
|
||||
t.Errorf("n = %d: unexpected error on UnreadByte: %v", n, err)
|
||||
}
|
||||
}
|
||||
// Test that we cannot unread any further.
|
||||
if err := r.UnreadByte(); err == nil {
|
||||
t.Errorf("n = %d: expected error on UnreadByte", n)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user