From 652c8c6712886184e59a110c3fa1e6dcb643d93b Mon Sep 17 00:00:00 2001 From: chanxuehong Date: Mon, 24 Jul 2023 15:29:47 +0800 Subject: [PATCH] encoding/base32: reduce the overflow risk when computing encode/decode length --- src/encoding/base32/base32.go | 5 +- src/encoding/base32/base32_test.go | 124 +++++++++++++++++++---------- 2 files changed, 82 insertions(+), 47 deletions(-) diff --git a/src/encoding/base32/base32.go b/src/encoding/base32/base32.go index 3dc37b0aa79..a4d515edbd7 100644 --- a/src/encoding/base32/base32.go +++ b/src/encoding/base32/base32.go @@ -271,7 +271,7 @@ func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser { // of an input buffer of length n. func (enc *Encoding) EncodedLen(n int) int { if enc.padChar == NoPadding { - return (n*8 + 4) / 5 + return n/5*8 + (n%5*8+4)/5 } return (n + 4) / 5 * 8 } @@ -545,8 +545,7 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader { // corresponding to n bytes of base32-encoded data. func (enc *Encoding) DecodedLen(n int) int { if enc.padChar == NoPadding { - return n * 5 / 8 + return n/8*5 + n%8*5/8 } - return n / 8 * 5 } diff --git a/src/encoding/base32/base32_test.go b/src/encoding/base32/base32_test.go index 8118531b383..bdb9f0e61f8 100644 --- a/src/encoding/base32/base32_test.go +++ b/src/encoding/base32/base32_test.go @@ -8,6 +8,8 @@ import ( "bytes" "errors" "io" + "math" + "strconv" "strings" "testing" ) @@ -679,52 +681,86 @@ func TestBufferedDecodingPadding(t *testing.T) { } } -func TestEncodedDecodedLen(t *testing.T) { +func TestEncodedLen(t *testing.T) { + var rawStdEncoding = StdEncoding.WithPadding(NoPadding) type test struct { - in int - wantEnc int - wantDec int + enc *Encoding + n int + want int64 } - data := bytes.Repeat([]byte("x"), 100) - for _, test := range []struct { - name string - enc *Encoding - cases []test - }{ - {"StdEncoding", StdEncoding, []test{ - {0, 0, 0}, - {1, 8, 5}, - {5, 8, 5}, - {6, 16, 10}, - {10, 16, 10}, - }}, - {"NoPadding", StdEncoding.WithPadding(NoPadding), []test{ - {0, 0, 0}, - {1, 2, 1}, - {2, 4, 2}, - {5, 8, 5}, - {6, 10, 6}, - {7, 12, 7}, - {10, 16, 10}, - {11, 18, 11}, - }}, - } { - t.Run(test.name, func(t *testing.T) { - for _, tc := range test.cases { - encLen := test.enc.EncodedLen(tc.in) - decLen := test.enc.DecodedLen(encLen) - enc := test.enc.EncodeToString(data[:tc.in]) - if len(enc) != encLen { - t.Fatalf("EncodedLen(%d) = %d but encoded to %q (%d)", tc.in, encLen, enc, len(enc)) - } - if encLen != tc.wantEnc { - t.Fatalf("EncodedLen(%d) = %d; want %d", tc.in, encLen, tc.wantEnc) - } - if decLen != tc.wantDec { - t.Fatalf("DecodedLen(%d) = %d; want %d", encLen, decLen, tc.wantDec) - } - } - }) + tests := []test{ + {StdEncoding, 0, 0}, + {StdEncoding, 1, 8}, + {StdEncoding, 2, 8}, + {StdEncoding, 3, 8}, + {StdEncoding, 4, 8}, + {StdEncoding, 5, 8}, + {StdEncoding, 6, 16}, + {StdEncoding, 10, 16}, + {StdEncoding, 11, 24}, + {rawStdEncoding, 0, 0}, + {rawStdEncoding, 1, 2}, + {rawStdEncoding, 2, 4}, + {rawStdEncoding, 3, 5}, + {rawStdEncoding, 4, 7}, + {rawStdEncoding, 5, 8}, + {rawStdEncoding, 6, 10}, + {rawStdEncoding, 7, 12}, + {rawStdEncoding, 10, 16}, + {rawStdEncoding, 11, 18}, + } + // check overflow + switch strconv.IntSize { + case 32: + tests = append(tests, test{rawStdEncoding, (math.MaxInt-4)/8 + 1, 429496730}) + tests = append(tests, test{rawStdEncoding, math.MaxInt/8*5 + 4, math.MaxInt}) + case 64: + tests = append(tests, test{rawStdEncoding, (math.MaxInt-4)/8 + 1, 1844674407370955162}) + tests = append(tests, test{rawStdEncoding, math.MaxInt/8*5 + 4, math.MaxInt}) + } + for _, tt := range tests { + if got := tt.enc.EncodedLen(tt.n); int64(got) != tt.want { + t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want) + } + } +} + +func TestDecodedLen(t *testing.T) { + var rawStdEncoding = StdEncoding.WithPadding(NoPadding) + type test struct { + enc *Encoding + n int + want int64 + } + tests := []test{ + {StdEncoding, 0, 0}, + {StdEncoding, 8, 5}, + {StdEncoding, 16, 10}, + {StdEncoding, 24, 15}, + {rawStdEncoding, 0, 0}, + {rawStdEncoding, 2, 1}, + {rawStdEncoding, 4, 2}, + {rawStdEncoding, 5, 3}, + {rawStdEncoding, 7, 4}, + {rawStdEncoding, 8, 5}, + {rawStdEncoding, 10, 6}, + {rawStdEncoding, 12, 7}, + {rawStdEncoding, 16, 10}, + {rawStdEncoding, 18, 11}, + } + // check overflow + switch strconv.IntSize { + case 32: + tests = append(tests, test{rawStdEncoding, math.MaxInt/5 + 1, 268435456}) + tests = append(tests, test{rawStdEncoding, math.MaxInt, 1342177279}) + case 64: + tests = append(tests, test{rawStdEncoding, math.MaxInt/5 + 1, 1152921504606846976}) + tests = append(tests, test{rawStdEncoding, math.MaxInt, 5764607523034234879}) + } + for _, tt := range tests { + if got := tt.enc.DecodedLen(tt.n); int64(got) != tt.want { + t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want) + } } }