diff --git a/src/encoding/base32/base32.go b/src/encoding/base32/base32.go index c193e65e1b..788a06115a 100644 --- a/src/encoding/base32/base32.go +++ b/src/encoding/base32/base32.go @@ -343,18 +343,33 @@ type decoder struct { outbuf [1024 / 8 * 5]byte } -func (d *decoder) Read(p []byte) (n int, err error) { - if d.err != nil { - return 0, d.err +func readEncodedData(r io.Reader, buf []byte, min int) (n int, err error) { + for n < min && err == nil { + var nn int + nn, err = r.Read(buf[n:]) + n += nn } + if n < min && n > 0 && err == io.EOF { + err = io.ErrUnexpectedEOF + } + return +} +func (d *decoder) Read(p []byte) (n int, err error) { // Use leftover decoded output from last read. if len(d.out) > 0 { n = copy(p, d.out) d.out = d.out[n:] + if len(d.out) == 0 { + return n, d.err + } return n, nil } + if d.err != nil { + return 0, d.err + } + // Read a chunk. nn := len(p) / 5 * 8 if nn < 8 { @@ -363,7 +378,8 @@ func (d *decoder) Read(p []byte) (n int, err error) { if nn > len(d.buf) { nn = len(d.buf) } - nn, d.err = io.ReadAtLeast(d.r, d.buf[d.nbuf:nn], 8-d.nbuf) + + nn, d.err = readEncodedData(d.r, d.buf[d.nbuf:nn], 8-d.nbuf) d.nbuf += nn if d.nbuf < 8 { return 0, d.err @@ -373,21 +389,30 @@ func (d *decoder) Read(p []byte) (n int, err error) { nr := d.nbuf / 8 * 8 nw := d.nbuf / 8 * 5 if nw > len(p) { - nw, d.end, d.err = d.enc.decode(d.outbuf[0:], d.buf[0:nr]) + nw, d.end, err = d.enc.decode(d.outbuf[0:], d.buf[0:nr]) d.out = d.outbuf[0:nw] n = copy(p, d.out) d.out = d.out[n:] } else { - n, d.end, d.err = d.enc.decode(p, d.buf[0:nr]) + n, d.end, err = d.enc.decode(p, d.buf[0:nr]) } d.nbuf -= nr for i := 0; i < d.nbuf; i++ { d.buf[i] = d.buf[i+nr] } - if d.err == nil { + if err != nil && (d.err == nil || d.err == io.EOF) { d.err = err } + + if len(d.out) > 0 { + // We cannot return all the decoded bytes to the caller in this + // invocation of Read, so we return a nil error to ensure that Read + // will be called again. The error stored in d.err, if any, will be + // returned with the last set of decoded bytes. + return n, nil + } + return n, d.err } @@ -407,7 +432,7 @@ func (r *newlineFilteringReader) Read(p []byte) (int, error) { offset++ } } - if offset > 0 { + if err != nil || offset > 0 { return offset, err } // Previous buffer entirely whitespace, read again diff --git a/src/encoding/base32/base32_test.go b/src/encoding/base32/base32_test.go index 66a48a3f6f..37db770b02 100644 --- a/src/encoding/base32/base32_test.go +++ b/src/encoding/base32/base32_test.go @@ -6,6 +6,7 @@ package base32 import ( "bytes" + "errors" "io" "io/ioutil" "strings" @@ -123,6 +124,160 @@ func TestDecoder(t *testing.T) { } } +type badReader struct { + data []byte + errs []error + called int + limit int +} + +// Populates p with data, returns a count of the bytes written and an +// error. The error returned is taken from badReader.errs, with each +// invocation of Read returning the next error in this slice, or io.EOF, +// if all errors from the slice have already been returned. The +// number of bytes returned is determined by the size of the input buffer +// the test passes to decoder.Read and will be a multiple of 8, unless +// badReader.limit is non zero. +func (b *badReader) Read(p []byte) (int, error) { + lim := len(p) + if b.limit != 0 && b.limit < lim { + lim = b.limit + } + if len(b.data) < lim { + lim = len(b.data) + } + for i := range p[:lim] { + p[i] = b.data[i] + } + b.data = b.data[lim:] + err := io.EOF + if b.called < len(b.errs) { + err = b.errs[b.called] + } + b.called++ + return lim, err +} + +// TestIssue20044 tests that decoder.Read behaves correctly when the caller +// supplied reader returns an error. +func TestIssue20044(t *testing.T) { + badErr := errors.New("bad reader error") + testCases := []struct { + r badReader + res string + err error + dbuflen int + }{ + // Check valid input data accompanied by an error is processed and the error is propagated. + {r: badReader{data: []byte("MY======"), errs: []error{badErr}}, + res: "f", err: badErr}, + // Check a read error accompanied by input data consisting of newlines only is propagated. + {r: badReader{data: []byte("\n\n\n\n\n\n\n\n"), errs: []error{badErr, nil}}, + res: "", err: badErr}, + // Reader will be called twice. The first time it will return 8 newline characters. The + // second time valid base32 encoded data and an error. The data should be decoded + // correctly and the error should be propagated. + {r: badReader{data: []byte("\n\n\n\n\n\n\n\nMY======"), errs: []error{nil, badErr}}, + res: "f", err: badErr, dbuflen: 8}, + // Reader returns invalid input data (too short) and an error. Verify the reader + // error is returned. + {r: badReader{data: []byte("MY====="), errs: []error{badErr}}, + res: "", err: badErr}, + // Reader returns invalid input data (too short) but no error. Verify io.ErrUnexpectedEOF + // is returned. + {r: badReader{data: []byte("MY====="), errs: []error{nil}}, + res: "", err: io.ErrUnexpectedEOF}, + // Reader returns invalid input data and an error. Verify the reader and not the + // decoder error is returned. + {r: badReader{data: []byte("Ma======"), errs: []error{badErr}}, + res: "", err: badErr}, + // Reader returns valid data and io.EOF. Check data is decoded and io.EOF is propagated. + {r: badReader{data: []byte("MZXW6YTB"), errs: []error{io.EOF}}, + res: "fooba", err: io.EOF}, + // Check errors are properly reported when decoder.Read is called multiple times. + // decoder.Read will be called 8 times, badReader.Read will be called twice, returning + // valid data both times but an error on the second call. + {r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, badErr}}, + res: "leasure.", err: badErr, dbuflen: 1}, + // Check io.EOF is properly reported when decoder.Read is called multiple times. + // decoder.Read will be called 8 times, badReader.Read will be called twice, returning + // valid data both times but io.EOF on the second call. + {r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, io.EOF}}, + res: "leasure.", err: io.EOF, dbuflen: 1}, + // The following two test cases check that errors are propagated correctly when more than + // 8 bytes are read at a time. + {r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{io.EOF}}, + res: "leasure.", err: io.EOF, dbuflen: 11}, + {r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{badErr}}, + res: "leasure.", err: badErr, dbuflen: 11}, + // Check that errors are correctly propagated when the reader returns valid bytes in + // groups that are not divisible by 8. The first read will return 11 bytes and no + // error. The second will return 7 and an error. The data should be decoded correctly + // and the error should be propagated. + {r: badReader{data: []byte("NRSWC43VOJSS4==="), errs: []error{nil, badErr}, limit: 11}, + res: "leasure.", err: badErr}, + } + + for _, tc := range testCases { + input := tc.r.data + decoder := NewDecoder(StdEncoding, &tc.r) + var dbuflen int + if tc.dbuflen > 0 { + dbuflen = tc.dbuflen + } else { + dbuflen = StdEncoding.DecodedLen(len(input)) + } + dbuf := make([]byte, dbuflen) + var err error + var res []byte + for err == nil { + var n int + n, err = decoder.Read(dbuf) + if n > 0 { + res = append(res, dbuf[:n]...) + } + } + + testEqual(t, "Decoding of %q = %q, want %q", string(input), string(res), tc.res) + testEqual(t, "Decoding of %q err = %v, expected %v", string(input), err, tc.err) + } +} + +// TestDecoderError verifies decode errors are propagated when there are no read +// errors. +func TestDecoderError(t *testing.T) { + for _, readErr := range []error{io.EOF, nil} { + input := "MZXW6YTb" + dbuf := make([]byte, StdEncoding.DecodedLen(len(input))) + br := badReader{data: []byte(input), errs: []error{readErr}} + decoder := NewDecoder(StdEncoding, &br) + n, err := decoder.Read(dbuf) + testEqual(t, "Read after EOF, n = %d, expected %d", n, 0) + if _, ok := err.(CorruptInputError); !ok { + t.Errorf("Corrupt input error expected. Found %T", err) + } + } +} + +// TestReaderEOF ensures decoder.Read behaves correctly when input data is +// exhausted. +func TestReaderEOF(t *testing.T) { + for _, readErr := range []error{io.EOF, nil} { + input := "MZXW6YTB" + br := badReader{data: []byte(input), errs: []error{nil, readErr}} + decoder := NewDecoder(StdEncoding, &br) + dbuf := make([]byte, StdEncoding.DecodedLen(len(input))) + n, err := decoder.Read(dbuf) + testEqual(t, "Decoding of %q err = %v, expected %v", string(input), err, error(nil)) + n, err = decoder.Read(dbuf) + testEqual(t, "Read after EOF, n = %d, expected %d", n, 0) + testEqual(t, "Read after EOF, err = %v, expected %v", err, io.EOF) + n, err = decoder.Read(dbuf) + testEqual(t, "Read after EOF, n = %d, expected %d", n, 0) + testEqual(t, "Read after EOF, err = %v, expected %v", err, io.EOF) + } +} + func TestDecoderBuffering(t *testing.T) { for bs := 1; bs <= 12; bs++ { decoder := NewDecoder(StdEncoding, strings.NewReader(bigtest.encoded))