From 8bd9242f7cd27f69449b2cdae8acbf25a952134c Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Wed, 9 Apr 2014 14:19:13 -0700 Subject: [PATCH] bufio: fix UnreadByte Also: - fix error messages in tests - make tests more symmetric Fixes #7607. LGTM=r R=r CC=golang-codereviews https://golang.org/cl/86180043 --- src/pkg/bufio/bufio.go | 24 +++++++------- src/pkg/bufio/bufio_test.go | 62 +++++++++++++++++++++++++++---------- 2 files changed, 57 insertions(+), 29 deletions(-) diff --git a/src/pkg/bufio/bufio.go b/src/pkg/bufio/bufio.go index de81b4ddfd..1e0cdae38e 100644 --- a/src/pkg/bufio/bufio.go +++ b/src/pkg/bufio/bufio.go @@ -177,7 +177,7 @@ func (b *Reader) Read(p []byte) (n int, err error) { // If no byte is available, returns an error. func (b *Reader) ReadByte() (c byte, err error) { b.lastRuneSize = -1 - for b.w == b.r { + for b.r == b.w { if b.err != nil { return 0, b.readErr() } @@ -191,19 +191,19 @@ func (b *Reader) ReadByte() (c byte, err error) { // UnreadByte unreads the last byte. Only the most recently read byte can be unread. func (b *Reader) UnreadByte() error { - b.lastRuneSize = -1 - if b.r == b.w && b.lastByte >= 0 { - b.w = 1 - b.r = 0 - b.buf[0] = byte(b.lastByte) - b.lastByte = -1 - return nil - } - if b.r <= 0 { + if b.lastByte < 0 || b.r == 0 && b.w > 0 { return ErrInvalidUnreadByte } - b.r-- + // b.r > 0 || b.w == 0 + if b.r > 0 { + b.r-- + } else { + // b.r == 0 && b.w == 0 + b.w = 1 + } + b.buf[b.r] = byte(b.lastByte) b.lastByte = -1 + b.lastRuneSize = -1 return nil } @@ -233,7 +233,7 @@ func (b *Reader) ReadRune() (r rune, size int, err error) { // regard it is stricter than UnreadByte, which will unread the last byte // from any read operation.) func (b *Reader) UnreadRune() error { - if b.lastRuneSize < 0 || b.r == 0 { + if b.lastRuneSize < 0 || b.r < b.lastRuneSize { return ErrInvalidUnreadRune } b.r -= b.lastRuneSize diff --git a/src/pkg/bufio/bufio_test.go b/src/pkg/bufio/bufio_test.go index 800c6d2717..32ca86161f 100644 --- a/src/pkg/bufio/bufio_test.go +++ b/src/pkg/bufio/bufio_test.go @@ -228,66 +228,94 @@ func TestReadRune(t *testing.T) { } func TestUnreadRune(t *testing.T) { - got := "" segments := []string{"Hello, world:", "日本語"} - data := strings.Join(segments, "") r := NewReader(&StringReader{data: segments}) + got := "" + want := strings.Join(segments, "") // Normal execution. for { r1, _, err := r.ReadRune() if err != nil { if err != io.EOF { - t.Error("unexpected EOF") + t.Error("unexpected error on ReadRune:", err) } break } got += string(r1) - // Put it back and read it again + // Put it back and read it again. if err = r.UnreadRune(); err != nil { - t.Error("unexpected error on UnreadRune:", err) + t.Fatal("unexpected error on UnreadRune:", err) } r2, _, err := r.ReadRune() if err != nil { - t.Error("unexpected error reading after unreading:", err) + t.Fatal("unexpected error reading after unreading:", err) } if r1 != r2 { - t.Errorf("incorrect rune after unread: got %c wanted %c", r1, r2) + t.Fatalf("incorrect rune after unread: got %c, want %c", r1, r2) } } - if got != data { - t.Errorf("want=%q got=%q", data, got) + if got != want { + t.Errorf("got %q, want %q", got, want) } } func TestUnreadByte(t *testing.T) { - want := "Hello, world" - got := "" segments := []string{"Hello, ", "world"} r := NewReader(&StringReader{data: segments}) + got := "" + want := strings.Join(segments, "") // Normal execution. for { b1, err := r.ReadByte() if err != nil { if err != io.EOF { - t.Fatal("unexpected EOF") + t.Error("unexpected error on ReadByte:", err) } break } got += string(b1) - // Put it back and read it again + // Put it back and read it again. if err = r.UnreadByte(); err != nil { - t.Fatalf("unexpected error on UnreadByte: %v", err) + t.Fatal("unexpected error on UnreadByte:", err) } b2, err := r.ReadByte() if err != nil { - t.Fatalf("unexpected error reading after unreading: %v", err) + t.Fatal("unexpected error reading after unreading:", err) } if b1 != b2 { - t.Fatalf("incorrect byte after unread: got %c wanted %c", b1, b2) + t.Fatalf("incorrect byte after unread: got %q, want %q", b1, b2) } } if got != want { - t.Errorf("got=%q want=%q", got, want) + t.Errorf("got %q, want %q", got, want) + } +} + +func TestUnreadByteMultiple(t *testing.T) { + segments := []string{"Hello, ", "world"} + data := strings.Join(segments, "") + for n := 0; n <= len(data); n++ { + r := NewReader(&StringReader{data: segments}) + // Read n bytes. + for i := 0; i < n; i++ { + b, err := r.ReadByte() + if err != nil { + t.Fatalf("n = %d: unexpected error on ReadByte: %v", n, err) + } + if b != data[i] { + t.Fatalf("n = %d: incorrect byte returned from ReadByte: got %q, want %q", n, b, data[i]) + } + } + // Unread one byte if there is one. + if n > 0 { + if err := r.UnreadByte(); err != nil { + t.Errorf("n = %d: unexpected error on UnreadByte: %v", n, err) + } + } + // Test that we cannot unread any further. + if err := r.UnreadByte(); err == nil { + t.Errorf("n = %d: expected error on UnreadByte", n) + } } }