diff --git a/src/encoding/base64/base64.go b/src/encoding/base64/base64.go index 6aa8a15bdc8..87f68970627 100644 --- a/src/encoding/base64/base64.go +++ b/src/encoding/base64/base64.go @@ -278,7 +278,7 @@ func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser { // of an input buffer of length n. func (enc *Encoding) EncodedLen(n int) int { if enc.padChar == NoPadding { - return (n*8 + 5) / 6 // minimum # chars at 6 bits per char + return n/3*4 + (n%3*8+5)/6 // minimum # chars at 6 bits per char } return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each } @@ -623,7 +623,7 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader { func (enc *Encoding) DecodedLen(n int) int { if enc.padChar == NoPadding { // Unpadded data may end with partial block of 2-3 characters. - return n * 6 / 8 + return n/4*3 + n%4*6/8 } // Padded base64 should always be a multiple of 4 characters in length. return n / 4 * 3 diff --git a/src/encoding/base64/base64_test.go b/src/encoding/base64/base64_test.go index 0ad88ebb3a7..97aea845ae3 100644 --- a/src/encoding/base64/base64_test.go +++ b/src/encoding/base64/base64_test.go @@ -9,8 +9,10 @@ import ( "errors" "fmt" "io" + "math" "reflect" "runtime/debug" + "strconv" "strings" "testing" "time" @@ -262,11 +264,12 @@ func TestDecodeBounds(t *testing.T) { } func TestEncodedLen(t *testing.T) { - for _, tt := range []struct { + type test struct { enc *Encoding n int - want int - }{ + want int64 + } + tests := []test{ {RawStdEncoding, 0, 0}, {RawStdEncoding, 1, 2}, {RawStdEncoding, 2, 3}, @@ -278,19 +281,30 @@ func TestEncodedLen(t *testing.T) { {StdEncoding, 3, 4}, {StdEncoding, 4, 8}, {StdEncoding, 7, 12}, - } { - if got := tt.enc.EncodedLen(tt.n); got != tt.want { + } + // check overflow + switch strconv.IntSize { + case 32: + tests = append(tests, test{RawStdEncoding, (math.MaxInt-5)/8 + 1, 357913942}) + tests = append(tests, test{RawStdEncoding, math.MaxInt/4*3 + 2, math.MaxInt}) + case 64: + tests = append(tests, test{RawStdEncoding, (math.MaxInt-5)/8 + 1, 1537228672809129302}) + tests = append(tests, test{RawStdEncoding, math.MaxInt/4*3 + 2, math.MaxInt}) + } + for _, tt := range tests { + if got := tt.enc.EncodedLen(tt.n); int64(got) != tt.want { t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want) } } } func TestDecodedLen(t *testing.T) { - for _, tt := range []struct { + type test struct { enc *Encoding n int - want int - }{ + want int64 + } + tests := []test{ {RawStdEncoding, 0, 0}, {RawStdEncoding, 2, 1}, {RawStdEncoding, 3, 2}, @@ -299,8 +313,18 @@ func TestDecodedLen(t *testing.T) { {StdEncoding, 0, 0}, {StdEncoding, 4, 3}, {StdEncoding, 8, 6}, - } { - if got := tt.enc.DecodedLen(tt.n); got != tt.want { + } + // check overflow + switch strconv.IntSize { + case 32: + tests = append(tests, test{RawStdEncoding, math.MaxInt/6 + 1, 268435456}) + tests = append(tests, test{RawStdEncoding, math.MaxInt, 1610612735}) + case 64: + tests = append(tests, test{RawStdEncoding, math.MaxInt/6 + 1, 1152921504606846976}) + tests = append(tests, test{RawStdEncoding, math.MaxInt, 6917529027641081855}) + } + for _, tt := range tests { + if got := tt.enc.DecodedLen(tt.n); int64(got) != tt.want { t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want) } }