1
0
mirror of https://github.com/golang/go synced 2024-11-23 02:00:03 -07:00

encoding: optimize growth behavior in Encoding.AppendDecode

The Encoding.DecodedLen API only returns the maximum length of the
expected decoded output, since it does not know about padding.
Since we have the input, we can do better by computing the
input length without padding, and then perform the DecodedLen
calculation as if there were no padding.

This avoids over-growing the destination slice if possible.
Over-growth is still possible since the input may contain
ignore characters like newlines and carriage returns,
but those a rarely encountered in practice.

Change-Id: I38b8f91de1f4fbd3a7128c491a25098bd385cf74
Reviewed-on: https://go-review.googlesource.com/c/go/+/520267
Run-TryBot: Joseph Tsai <joetsai@digital-static.net>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Auto-Submit: Joseph Tsai <joetsai@digital-static.net>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
This commit is contained in:
Joe Tsai 2023-08-16 21:27:15 -07:00 committed by Gopher Robot
parent e47fad515d
commit e8cdab5c49
4 changed files with 38 additions and 4 deletions

View File

@ -402,7 +402,13 @@ func (enc *Encoding) Decode(dst, src []byte) (n int, err error) {
// and returns the extended buffer. // and returns the extended buffer.
// If the input is malformed, it returns the partially decoded src and an error. // If the input is malformed, it returns the partially decoded src and an error.
func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) { func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) {
n := enc.DecodedLen(len(src)) // Compute the output size without padding to avoid over allocating.
n := len(src)
for n > 0 && rune(src[n-1]) == enc.padChar {
n--
}
n = decodedLen(n, NoPadding)
dst = slices.Grow(dst, n) dst = slices.Grow(dst, n)
n, err := enc.Decode(dst[len(dst):][:n], src) n, err := enc.Decode(dst[len(dst):][:n], src)
return dst[:len(dst)+n], err return dst[:len(dst)+n], err
@ -567,7 +573,11 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
// DecodedLen returns the maximum length in bytes of the decoded data // DecodedLen returns the maximum length in bytes of the decoded data
// corresponding to n bytes of base32-encoded data. // corresponding to n bytes of base32-encoded data.
func (enc *Encoding) DecodedLen(n int) int { func (enc *Encoding) DecodedLen(n int) int {
if enc.padChar == NoPadding { return decodedLen(n, enc.padChar)
}
func decodedLen(n int, padChar rune) int {
if padChar == NoPadding {
return n/8*5 + n%8*5/8 return n/8*5 + n%8*5/8
} }
return n / 8 * 5 return n / 8 * 5

View File

@ -110,6 +110,13 @@ func TestDecode(t *testing.T) {
dst, err := StdEncoding.AppendDecode([]byte("lead"), []byte(p.encoded)) dst, err := StdEncoding.AppendDecode([]byte("lead"), []byte(p.encoded))
testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil)) testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil))
testEqual(t, `AppendDecode("lead", %q) = %q, want %q`, p.encoded, string(dst), "lead"+p.decoded) testEqual(t, `AppendDecode("lead", %q) = %q, want %q`, p.encoded, string(dst), "lead"+p.decoded)
dst2, err := StdEncoding.AppendDecode(dst[:0:len(p.decoded)], []byte(p.encoded))
testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil))
testEqual(t, `AppendDecode("", %q) = %q, want %q`, p.encoded, string(dst2), p.decoded)
if len(dst) > 0 && len(dst2) > 0 && &dst[0] != &dst2[0] {
t.Errorf("unexpected capacity growth: got %d, want %d", cap(dst2), cap(dst))
}
} }
} }

View File

@ -411,7 +411,13 @@ func (enc *Encoding) decodeQuantum(dst, src []byte, si int) (nsi, n int, err err
// and returns the extended buffer. // and returns the extended buffer.
// If the input is malformed, it returns the partially decoded src and an error. // If the input is malformed, it returns the partially decoded src and an error.
func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) { func (enc *Encoding) AppendDecode(dst, src []byte) ([]byte, error) {
n := enc.DecodedLen(len(src)) // Compute the output size without padding to avoid over allocating.
n := len(src)
for n > 0 && rune(src[n-1]) == enc.padChar {
n--
}
n = decodedLen(n, NoPadding)
dst = slices.Grow(dst, n) dst = slices.Grow(dst, n)
n, err := enc.Decode(dst[len(dst):][:n], src) n, err := enc.Decode(dst[len(dst):][:n], src)
return dst[:len(dst)+n], err return dst[:len(dst)+n], err
@ -643,7 +649,11 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
// DecodedLen returns the maximum length in bytes of the decoded data // DecodedLen returns the maximum length in bytes of the decoded data
// corresponding to n bytes of base64-encoded data. // corresponding to n bytes of base64-encoded data.
func (enc *Encoding) DecodedLen(n int) int { func (enc *Encoding) DecodedLen(n int) int {
if enc.padChar == NoPadding { return decodedLen(n, enc.padChar)
}
func decodedLen(n int, padChar rune) int {
if padChar == NoPadding {
// Unpadded data may end with partial block of 2-3 characters. // Unpadded data may end with partial block of 2-3 characters.
return n/4*3 + n%4*6/8 return n/4*3 + n%4*6/8
} }

View File

@ -167,6 +167,13 @@ func TestDecode(t *testing.T) {
dst, err := tt.enc.AppendDecode([]byte("lead"), []byte(encoded)) dst, err := tt.enc.AppendDecode([]byte("lead"), []byte(encoded))
testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil)) testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil))
testEqual(t, `AppendDecode("lead", %q) = %q, want %q`, p.encoded, string(dst), "lead"+p.decoded) testEqual(t, `AppendDecode("lead", %q) = %q, want %q`, p.encoded, string(dst), "lead"+p.decoded)
dst2, err := tt.enc.AppendDecode(dst[:0:len(p.decoded)], []byte(encoded))
testEqual(t, "AppendDecode(%q) = error %v, want %v", p.encoded, err, error(nil))
testEqual(t, `AppendDecode("", %q) = %q, want %q`, p.encoded, string(dst2), p.decoded)
if len(dst) > 0 && len(dst2) > 0 && &dst[0] != &dst2[0] {
t.Errorf("unexpected capacity growth: got %d, want %d", cap(dst2), cap(dst))
}
} }
} }
} }