diff --git a/src/bufio/bufio.go b/src/bufio/bufio.go index a58df254941..063a7785f36 100644 --- a/src/bufio/bufio.go +++ b/src/bufio/bufio.go @@ -745,19 +745,14 @@ func (b *Writer) WriteString(s string) (int, error) { } // ReadFrom implements io.ReaderFrom. If the underlying writer -// supports the ReadFrom method, and b has no buffered data yet, -// this calls the underlying ReadFrom without buffering. +// supports the ReadFrom method, this calls the underlying ReadFrom. +// If there is buffered data and an underlying ReadFrom, this fills +// the buffer and writes it before calling ReadFrom. func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) { if b.err != nil { return 0, b.err } - if b.Buffered() == 0 { - if w, ok := b.wr.(io.ReaderFrom); ok { - n, err = w.ReadFrom(r) - b.err = err - return n, err - } - } + readerFrom, readerFromOK := b.wr.(io.ReaderFrom) var m int for { if b.Available() == 0 { @@ -765,6 +760,12 @@ func (b *Writer) ReadFrom(r io.Reader) (n int64, err error) { return n, err1 } } + if readerFromOK && b.Buffered() == 0 { + nn, err := readerFrom.ReadFrom(r) + b.err = err + n += nn + return n, err + } nr := 0 for nr < maxConsecutiveEmptyReads { m, err = r.Read(b.buf[b.n:]) diff --git a/src/bufio/bufio_test.go b/src/bufio/bufio_test.go index 8e8a8a1778a..66b3e700531 100644 --- a/src/bufio/bufio_test.go +++ b/src/bufio/bufio_test.go @@ -1351,6 +1351,54 @@ func TestWriterReadFromErrNoProgress(t *testing.T) { } } +type readFromWriter struct { + buf []byte + writeBytes int + readFromBytes int +} + +func (w *readFromWriter) Write(p []byte) (int, error) { + w.buf = append(w.buf, p...) + w.writeBytes += len(p) + return len(p), nil +} + +func (w *readFromWriter) ReadFrom(r io.Reader) (int64, error) { + b, err := io.ReadAll(r) + w.buf = append(w.buf, b...) + w.readFromBytes += len(b) + return int64(len(b)), err +} + +// Test that calling (*Writer).ReadFrom with a partially-filled buffer +// fills the buffer before switching over to ReadFrom. +func TestWriterReadFromWithBufferedData(t *testing.T) { + const bufsize = 16 + + input := createTestInput(64) + rfw := &readFromWriter{} + w := NewWriterSize(rfw, bufsize) + + const writeSize = 8 + if n, err := w.Write(input[:writeSize]); n != writeSize || err != nil { + t.Errorf("w.Write(%v bytes) = %v, %v; want %v, nil", writeSize, n, err, writeSize) + } + n, err := w.ReadFrom(bytes.NewReader(input[writeSize:])) + if wantn := len(input[writeSize:]); int(n) != wantn || err != nil { + t.Errorf("io.Copy(w, %v bytes) = %v, %v; want %v, nil", wantn, n, err, wantn) + } + if err := w.Flush(); err != nil { + t.Errorf("w.Flush() = %v, want nil", err) + } + + if got, want := rfw.writeBytes, bufsize; got != want { + t.Errorf("wrote %v bytes with Write, want %v", got, want) + } + if got, want := rfw.readFromBytes, len(input)-bufsize; got != want { + t.Errorf("wrote %v bytes with ReadFrom, want %v", got, want) + } +} + func TestReadZero(t *testing.T) { for _, size := range []int{100, 2} { t.Run(fmt.Sprintf("bufsize=%d", size), func(t *testing.T) {