mirror of
https://github.com/golang/go
synced 2024-11-23 01:50:04 -07:00
encoding/base64: reduce the overflow risk when computing encode/decode length
Change-Id: I0a55cdc38ae496e2070f0b9ef317a41f82352afd
GitHub-Last-Rev: c19527a26b
GitHub-Pull-Request: golang/go#61407
Reviewed-on: https://go-review.googlesource.com/c/go/+/510635
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
This commit is contained in:
parent
050d4d3b9e
commit
14adf4fb21
@ -278,7 +278,7 @@ func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
|
|||||||
// of an input buffer of length n.
|
// of an input buffer of length n.
|
||||||
func (enc *Encoding) EncodedLen(n int) int {
|
func (enc *Encoding) EncodedLen(n int) int {
|
||||||
if enc.padChar == NoPadding {
|
if enc.padChar == NoPadding {
|
||||||
return (n*8 + 5) / 6 // minimum # chars at 6 bits per char
|
return n/3*4 + (n%3*8+5)/6 // minimum # chars at 6 bits per char
|
||||||
}
|
}
|
||||||
return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each
|
return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each
|
||||||
}
|
}
|
||||||
@ -623,7 +623,7 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
|
|||||||
func (enc *Encoding) DecodedLen(n int) int {
|
func (enc *Encoding) DecodedLen(n int) int {
|
||||||
if enc.padChar == NoPadding {
|
if enc.padChar == NoPadding {
|
||||||
// Unpadded data may end with partial block of 2-3 characters.
|
// Unpadded data may end with partial block of 2-3 characters.
|
||||||
return n * 6 / 8
|
return n/4*3 + n%4*6/8
|
||||||
}
|
}
|
||||||
// Padded base64 should always be a multiple of 4 characters in length.
|
// Padded base64 should always be a multiple of 4 characters in length.
|
||||||
return n / 4 * 3
|
return n / 4 * 3
|
||||||
|
@ -9,8 +9,10 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"math"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@ -262,11 +264,12 @@ func TestDecodeBounds(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestEncodedLen(t *testing.T) {
|
func TestEncodedLen(t *testing.T) {
|
||||||
for _, tt := range []struct {
|
type test struct {
|
||||||
enc *Encoding
|
enc *Encoding
|
||||||
n int
|
n int
|
||||||
want int
|
want int64
|
||||||
}{
|
}
|
||||||
|
tests := []test{
|
||||||
{RawStdEncoding, 0, 0},
|
{RawStdEncoding, 0, 0},
|
||||||
{RawStdEncoding, 1, 2},
|
{RawStdEncoding, 1, 2},
|
||||||
{RawStdEncoding, 2, 3},
|
{RawStdEncoding, 2, 3},
|
||||||
@ -278,19 +281,30 @@ func TestEncodedLen(t *testing.T) {
|
|||||||
{StdEncoding, 3, 4},
|
{StdEncoding, 3, 4},
|
||||||
{StdEncoding, 4, 8},
|
{StdEncoding, 4, 8},
|
||||||
{StdEncoding, 7, 12},
|
{StdEncoding, 7, 12},
|
||||||
} {
|
}
|
||||||
if got := tt.enc.EncodedLen(tt.n); got != tt.want {
|
// check overflow
|
||||||
|
switch strconv.IntSize {
|
||||||
|
case 32:
|
||||||
|
tests = append(tests, test{RawStdEncoding, (math.MaxInt-5)/8 + 1, 357913942})
|
||||||
|
tests = append(tests, test{RawStdEncoding, math.MaxInt/4*3 + 2, math.MaxInt})
|
||||||
|
case 64:
|
||||||
|
tests = append(tests, test{RawStdEncoding, (math.MaxInt-5)/8 + 1, 1537228672809129302})
|
||||||
|
tests = append(tests, test{RawStdEncoding, math.MaxInt/4*3 + 2, math.MaxInt})
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := tt.enc.EncodedLen(tt.n); int64(got) != tt.want {
|
||||||
t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want)
|
t.Errorf("EncodedLen(%d): got %d, want %d", tt.n, got, tt.want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDecodedLen(t *testing.T) {
|
func TestDecodedLen(t *testing.T) {
|
||||||
for _, tt := range []struct {
|
type test struct {
|
||||||
enc *Encoding
|
enc *Encoding
|
||||||
n int
|
n int
|
||||||
want int
|
want int64
|
||||||
}{
|
}
|
||||||
|
tests := []test{
|
||||||
{RawStdEncoding, 0, 0},
|
{RawStdEncoding, 0, 0},
|
||||||
{RawStdEncoding, 2, 1},
|
{RawStdEncoding, 2, 1},
|
||||||
{RawStdEncoding, 3, 2},
|
{RawStdEncoding, 3, 2},
|
||||||
@ -299,8 +313,18 @@ func TestDecodedLen(t *testing.T) {
|
|||||||
{StdEncoding, 0, 0},
|
{StdEncoding, 0, 0},
|
||||||
{StdEncoding, 4, 3},
|
{StdEncoding, 4, 3},
|
||||||
{StdEncoding, 8, 6},
|
{StdEncoding, 8, 6},
|
||||||
} {
|
}
|
||||||
if got := tt.enc.DecodedLen(tt.n); got != tt.want {
|
// check overflow
|
||||||
|
switch strconv.IntSize {
|
||||||
|
case 32:
|
||||||
|
tests = append(tests, test{RawStdEncoding, math.MaxInt/6 + 1, 268435456})
|
||||||
|
tests = append(tests, test{RawStdEncoding, math.MaxInt, 1610612735})
|
||||||
|
case 64:
|
||||||
|
tests = append(tests, test{RawStdEncoding, math.MaxInt/6 + 1, 1152921504606846976})
|
||||||
|
tests = append(tests, test{RawStdEncoding, math.MaxInt, 6917529027641081855})
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
if got := tt.enc.DecodedLen(tt.n); int64(got) != tt.want {
|
||||||
t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want)
|
t.Errorf("DecodedLen(%d): got %d, want %d", tt.n, got, tt.want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user