diff --git a/src/io/io.go b/src/io/io.go index 86710ed6f3..27482de62e 100644 --- a/src/io/io.go +++ b/src/io/io.go @@ -335,7 +335,7 @@ func ReadFull(r Reader, buf []byte) (n int, err error) { // If dst implements the ReaderFrom interface, // the copy is implemented using it. func CopyN(dst Writer, src Reader, n int64) (written int64, err error) { - written, err = copyN(dst, src, n) + written, err = Copy(dst, LimitReader(src, n)) if written == n { return n, nil } @@ -346,55 +346,6 @@ func CopyN(dst Writer, src Reader, n int64) (written int64, err error) { return } -// copyN copies n bytes (or until an error) from src to dst. -// It returns the number of bytes copied and the earliest -// error encountered while copying. -// -// If dst implements the ReaderFrom interface, -// the copy is implemented using it. -func copyN(dst Writer, src Reader, n int64) (int64, error) { - // If the writer has a ReadFrom method, use it to do the copy. - if rt, ok := dst.(ReaderFrom); ok { - return rt.ReadFrom(LimitReader(src, n)) - } - - l := 32 * 1024 // same size as in copyBuffer - if n < int64(l) { - l = int(n) - } - buf := make([]byte, l) - - var written int64 - for n > 0 { - if n < int64(len(buf)) { - buf = buf[:n] - } - - nr, errR := src.Read(buf) - if nr > 0 { - n -= int64(nr) - nw, errW := dst.Write(buf[:nr]) - if nw > 0 { - written += int64(nw) - } - if errW != nil { - return written, errW - } - if nr != nw { - return written, ErrShortWrite - } - } - - if errR != nil { - if errR != EOF { - return written, errR - } - return written, nil - } - } - return written, nil -} - // Copy copies from src to dst until either EOF is reached // on src or an error occurs. It returns the number of bytes // copied and the first error encountered while copying, if any. @@ -434,8 +385,16 @@ func copyBuffer(dst Writer, src Reader, buf []byte) (written int64, err error) { if rt, ok := dst.(ReaderFrom); ok { return rt.ReadFrom(src) } + size := 32 * 1024 + if l, ok := src.(*LimitedReader); ok && int64(size) > l.N { + if l.N < 1 { + size = 1 + } else { + size = int(l.N) + } + } if buf == nil { - buf = make([]byte, 32*1024) + buf = make([]byte, size) } for { nr, er := src.Read(buf) diff --git a/src/io/io_test.go b/src/io/io_test.go index e81065c13d..0e4ce61240 100644 --- a/src/io/io_test.go +++ b/src/io/io_test.go @@ -32,6 +32,21 @@ func TestCopy(t *testing.T) { } } +func TestCopyNegative(t *testing.T) { + rb := new(Buffer) + wb := new(Buffer) + rb.WriteString("hello") + Copy(wb, &LimitedReader{R: rb, N: -1}) + if wb.String() != "" { + t.Errorf("Copy on LimitedReader with N<0 copied data") + } + + CopyN(wb, rb, -1) + if wb.String() != "" { + t.Errorf("CopyN with N<0 copied data") + } +} + func TestCopyBuffer(t *testing.T) { rb := new(Buffer) wb := new(Buffer)