diff --git a/src/encoding/base32/base32.go b/src/encoding/base32/base32.go index 09e90eab5f..fdf42e5df3 100644 --- a/src/encoding/base32/base32.go +++ b/src/encoding/base32/base32.go @@ -404,7 +404,7 @@ type decoder struct { outbuf [1024 / 8 * 5]byte } -func readEncodedData(r io.Reader, buf []byte, min int) (n int, err error) { +func readEncodedData(r io.Reader, buf []byte, min int, expectsPadding bool) (n int, err error) { for n < min && err == nil { var nn int nn, err = r.Read(buf[n:]) @@ -415,7 +415,9 @@ func readEncodedData(r io.Reader, buf []byte, min int) (n int, err error) { err = io.ErrUnexpectedEOF } // no data was read, the buffer already contains some data - if min < 8 && n == 0 && err == io.EOF { + // when padding is disabled this is not an error, as the message can be of + // any length + if expectsPadding && min < 8 && n == 0 && err == io.EOF { err = io.ErrUnexpectedEOF } return @@ -445,15 +447,32 @@ func (d *decoder) Read(p []byte) (n int, err error) { nn = len(d.buf) } - nn, d.err = readEncodedData(d.r, d.buf[d.nbuf:nn], 8-d.nbuf) + // Minimum amount of bytes that needs to be read each cycle + var min int + var expectsPadding bool + if d.enc.padChar == NoPadding { + min = 1 + expectsPadding = false + } else { + min = 8 - d.nbuf + expectsPadding = true + } + + nn, d.err = readEncodedData(d.r, d.buf[d.nbuf:nn], min, expectsPadding) d.nbuf += nn - if d.nbuf < 8 { + if d.nbuf < min { return 0, d.err } // Decode chunk into p, or d.out and then p if p is too small. - nr := d.nbuf / 8 * 8 - nw := d.nbuf / 8 * 5 + var nr int + if d.enc.padChar == NoPadding { + nr = d.nbuf + } else { + nr = d.nbuf / 8 * 8 + } + nw := d.enc.DecodedLen(d.nbuf) + if nw > len(p) { nw, d.end, err = d.enc.decode(d.outbuf[0:], d.buf[0:nr]) d.out = d.outbuf[0:nw] diff --git a/src/encoding/base32/base32_test.go b/src/encoding/base32/base32_test.go index fdd862dc49..c5506ed4de 100644 --- a/src/encoding/base32/base32_test.go +++ b/src/encoding/base32/base32_test.go @@ -686,3 +686,66 @@ func TestWithoutPaddingClose(t *testing.T) { } } } + +func TestDecodeReadAll(t *testing.T) { + encodings := []*Encoding{ + StdEncoding, + StdEncoding.WithPadding(NoPadding), + } + + for _, pair := range pairs { + for encIndex, encoding := range encodings { + encoded := pair.encoded + if encoding.padChar == NoPadding { + encoded = strings.Replace(encoded, "=", "", -1) + } + + decReader, err := ioutil.ReadAll(NewDecoder(encoding, strings.NewReader(encoded))) + if err != nil { + t.Errorf("NewDecoder error: %v", err) + } + + if pair.decoded != string(decReader) { + t.Errorf("Expected %s got %s; Encoding %d", pair.decoded, decReader, encIndex) + } + } + } +} + +func TestDecodeSmallBuffer(t *testing.T) { + encodings := []*Encoding{ + StdEncoding, + StdEncoding.WithPadding(NoPadding), + } + + for bufferSize := 1; bufferSize < 200; bufferSize++ { + for _, pair := range pairs { + for encIndex, encoding := range encodings { + encoded := pair.encoded + if encoding.padChar == NoPadding { + encoded = strings.Replace(encoded, "=", "", -1) + } + + decoder := NewDecoder(encoding, strings.NewReader(encoded)) + + var allRead []byte + + for { + buf := make([]byte, bufferSize) + n, err := decoder.Read(buf) + allRead = append(allRead, buf[0:n]...) + if err == io.EOF { + break + } + if err != nil { + t.Error(err) + } + } + + if pair.decoded != string(allRead) { + t.Errorf("Expected %s got %s; Encoding %d; bufferSize %d", pair.decoded, allRead, encIndex, bufferSize) + } + } + } + } +}