diff --git a/src/pkg/encoding/base64/base64.go b/src/pkg/encoding/base64/base64.go index 0b07e733a0..a6efd44615 100644 --- a/src/pkg/encoding/base64/base64.go +++ b/src/pkg/encoding/base64/base64.go @@ -224,21 +224,33 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { var dbuf [4]byte dlen := 4 - for j := 0; j < 4; { + for j := range dbuf { if len(src) == 0 { return n, false, CorruptInputError(olen - len(src) - j) } in := src[0] src = src[1:] - if in == '=' && j >= 2 && len(src) < 4 { + if in == '=' { // We've reached the end and there's padding - if len(src)+j < 4-1 { - // not enough padding - return n, false, CorruptInputError(olen) - } - if len(src) > 0 && src[0] != '=' { + switch j { + case 0, 1: // incorrect padding return n, false, CorruptInputError(olen - len(src) - 1) + case 2: + // "==" is expected, the first "=" is already consumed. + if len(src) == 0 { + // not enough padding + return n, false, CorruptInputError(olen) + } + if src[0] != '=' { + // incorrect padding + return n, false, CorruptInputError(olen - len(src) - 1) + } + src = src[1:] + } + if len(src) > 0 { + // trailing garbage + return n, false, CorruptInputError(olen - len(src)) } dlen, end = j, true break @@ -247,7 +259,6 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) { if dbuf[j] == 0xFF { return n, false, CorruptInputError(olen - len(src) - 1) } - j++ } // Pack 4x 6-bit source blocks into 3 byte destination diff --git a/src/pkg/encoding/base64/base64_test.go b/src/pkg/encoding/base64/base64_test.go index 6bcc724d9b..0285629029 100644 --- a/src/pkg/encoding/base64/base64_test.go +++ b/src/pkg/encoding/base64/base64_test.go @@ -149,9 +149,13 @@ func TestDecodeCorrupt(t *testing.T) { }{ {"", -1}, {"!!!!", 0}, + {"====", 0}, {"x===", 1}, + {"=AAA", 0}, + {"A=AA", 1}, {"AA=A", 2}, - {"AAA=AAAA", 3}, + {"AA==A", 4}, + {"AAA=AAAA", 4}, {"AAAAA", 4}, {"AAAAAA", 4}, {"A=", 1},