diff --git a/src/pkg/io/io.go b/src/pkg/io/io.go index a41a674cea9..2b2f4d56714 100644 --- a/src/pkg/io/io.go +++ b/src/pkg/io/io.go @@ -19,6 +19,9 @@ type Error struct { // but failed to return an explicit error. var ErrShortWrite os.Error = &Error{"short write"} +// ErrShortBuffer means that a read required a longer buffer than was provided. +var ErrShortBuffer os.Error = &Error{"short buffer"} + // ErrUnexpectedEOF means that os.EOF was encountered in the // middle of reading a fixed-size block or data structure. var ErrUnexpectedEOF os.Error = &Error{"unexpected EOF"} @@ -165,8 +168,11 @@ func WriteString(w Writer, s string) (n int, err os.Error) { // The error is os.EOF only if no bytes were read. // If an EOF happens after reading fewer than min bytes, // ReadAtLeast returns ErrUnexpectedEOF. +// If min is greater than the length of buf, ReadAtLeast returns ErrShortBuffer. func ReadAtLeast(r Reader, buf []byte, min int) (n int, err os.Error) { - n = 0 + if len(buf) < min { + return 0, ErrShortBuffer + } for n < min { nn, e := r.Read(buf[n:]) if nn > 0 { @@ -179,7 +185,7 @@ func ReadAtLeast(r Reader, buf []byte, min int) (n int, err os.Error) { return n, e } } - return n, nil + return } // ReadFull reads exactly len(buf) bytes from r into buf. diff --git a/src/pkg/io/io_test.go b/src/pkg/io/io_test.go index 4ad1e595107..20f240a51a5 100644 --- a/src/pkg/io/io_test.go +++ b/src/pkg/io/io_test.go @@ -7,6 +7,7 @@ package io_test import ( "bytes" . "io" + "os" "testing" ) @@ -78,3 +79,42 @@ func TestCopynWriteTo(t *testing.T) { t.Errorf("Copyn did not work properly") } } + +func TestReadAtLeast(t *testing.T) { + var rb bytes.Buffer + rb.Write([]byte("0123")) + buf := make([]byte, 2) + n, err := ReadAtLeast(&rb, buf, 2) + if err != nil { + t.Error(err) + } + n, err = ReadAtLeast(&rb, buf, 4) + if err != ErrShortBuffer { + t.Errorf("expected ErrShortBuffer got %v", err) + } + if n != 0 { + t.Errorf("expected to have read 0 bytes, got %v", n) + } + n, err = ReadAtLeast(&rb, buf, 1) + if err != nil { + t.Error(err) + } + if n != 2 { + t.Errorf("expected to have read 2 bytes, got %v", n) + } + n, err = ReadAtLeast(&rb, buf, 2) + if err != os.EOF { + t.Errorf("expected EOF, got %v", err) + } + if n != 0 { + t.Errorf("expected to have read 0 bytes, got %v", n) + } + rb.Write([]byte("4")) + n, err = ReadAtLeast(&rb, buf, 2) + if err != ErrUnexpectedEOF { + t.Errorf("expected ErrUnexpectedEOF, got %v", err) + } + if n != 1 { + t.Errorf("expected to have read 1 bytes, got %v", n) + } +}