1
0
mirror of https://github.com/golang/go synced 2024-11-18 15:54:42 -07:00

encoding/base64: add unpadded encodings, and test all encodings.

Some applications use unpadded base64 format, omitting the trailing
'=' padding characters from the standard base64 format, either to
minimize size or (more justifiably) to avoid use of the '=' character.
Unpadded flavors are standard and documented in section 3.2 of RFC 4648.

To support these unpadded flavors, this change adds two predefined
encoding variables, RawStdEncoding and RawURLEncoding, for unpadded
encodings using the standard and URL character set, respectively.
The change also adds a function WithPadding() to customize the padding
character or disable padding in a custom Encoding.

Finally, I noticed that the existing base64 test-suite was only
exercising the StdEncoding, and not referencing URLEncoding at all.
This change adds test-suite functionality to exercise all four encodings
(the two existing ones and the two new unpadded flavors),
although it still doesn't run *every* test on all four encodings.

Naming: I used the "Raw" prefix because it's more concise than "Unpadded"
and seemed just as expressive, but I have no strong preferences here.
Another short alternative prefix would be "Min" ("minimal" encoding).

Change-Id: Ic0423e02589b39a6b2bb7d0763bd073fd244f469
Reviewed-on: https://go-review.googlesource.com/1511
Reviewed-by: Russ Cox <rsc@golang.org>
Reviewed-by: Minux Ma <minux@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
This commit is contained in:
Bryan Ford 2014-12-13 13:54:39 -05:00 committed by Brad Fitzpatrick
parent f3c85c507b
commit 2e0a1a7573
2 changed files with 135 additions and 41 deletions

View File

@ -24,16 +24,25 @@ import (
type Encoding struct {
encode string
decodeMap [256]byte
padChar rune
}
const (
StdPadding rune = '=' // Standard padding character
NoPadding rune = -1 // No padding
)
const encodeStd = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
const encodeURL = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
// NewEncoding returns a new Encoding defined by the given alphabet,
// NewEncoding returns a new padded Encoding defined by the given alphabet,
// which must be a 64-byte string.
// The resulting Encoding uses the default padding character ('='),
// which may be changed or disabled via WithPadding.
func NewEncoding(encoder string) *Encoding {
e := new(Encoding)
e.encode = encoder
e.padChar = StdPadding
for i := 0; i < len(e.decodeMap); i++ {
e.decodeMap[i] = 0xFF
}
@ -43,6 +52,13 @@ func NewEncoding(encoder string) *Encoding {
return e
}
// WithPadding creates a new encoding identical to enc except
// with a specified padding character, or NoPadding to disable padding.
func (enc Encoding) WithPadding(padding rune) *Encoding {
enc.padChar = padding
return &enc
}
// StdEncoding is the standard base64 encoding, as defined in
// RFC 4648.
var StdEncoding = NewEncoding(encodeStd)
@ -51,6 +67,16 @@ var StdEncoding = NewEncoding(encodeStd)
// It is typically used in URLs and file names.
var URLEncoding = NewEncoding(encodeURL)
// RawStdEncoding is the standard raw, unpadded base64 encoding,
// as defined in RFC 4648 section 3.2.
// This is the same as StdEncoding but omits padding characters.
var RawStdEncoding = StdEncoding.WithPadding(NoPadding)
// URLEncoding is the unpadded alternate base64 encoding defined in RFC 4648.
// It is typically used in URLs and file names.
// This is the same as URLEncoding but omits padding characters.
var RawURLEncoding = URLEncoding.WithPadding(NoPadding)
var removeNewlinesMapper = func(r rune) rune {
if r == '\r' || r == '\n' {
return -1
@ -95,14 +121,18 @@ func (enc *Encoding) Encode(dst, src []byte) {
// Encode 6-bit blocks using the base64 alphabet
dst[0] = enc.encode[b0]
dst[1] = enc.encode[b1]
dst[2] = enc.encode[b2]
dst[3] = enc.encode[b3]
// Pad the final quantum
if len(src) < 3 {
dst[3] = '='
if len(src) < 2 {
dst[2] = '='
if len(src) >= 3 {
dst[2] = enc.encode[b2]
dst[3] = enc.encode[b3]
} else { // Final incomplete quantum
if len(src) >= 2 {
dst[2] = enc.encode[b2]
}
if enc.padChar != NoPadding {
if len(src) < 2 {
dst[2] = byte(enc.padChar)
}
dst[3] = byte(enc.padChar)
}
break
}
@ -145,8 +175,8 @@ func (e *encoder) Write(p []byte) (n int, err error) {
if e.nbuf < 3 {
return
}
e.enc.Encode(e.out[0:], e.buf[0:])
if _, e.err = e.w.Write(e.out[0:4]); e.err != nil {
e.enc.Encode(e.out[:], e.buf[:])
if _, e.err = e.w.Write(e.out[:4]); e.err != nil {
return n, e.err
}
e.nbuf = 0
@ -159,7 +189,7 @@ func (e *encoder) Write(p []byte) (n int, err error) {
nn = len(p)
nn -= nn % 3
}
e.enc.Encode(e.out[0:], p[0:nn])
e.enc.Encode(e.out[:], p[:nn])
if _, e.err = e.w.Write(e.out[0 : nn/3*4]); e.err != nil {
return n, e.err
}
@ -181,9 +211,9 @@ func (e *encoder) Write(p []byte) (n int, err error) {
func (e *encoder) Close() error {
// If there's anything left in the buffer, flush it out
if e.err == nil && e.nbuf > 0 {
e.enc.Encode(e.out[0:], e.buf[0:e.nbuf])
e.enc.Encode(e.out[:], e.buf[:e.nbuf])
_, e.err = e.w.Write(e.out[:e.enc.EncodedLen(e.nbuf)])
e.nbuf = 0
_, e.err = e.w.Write(e.out[0:4])
}
return e.err
}
@ -199,7 +229,12 @@ func NewEncoder(enc *Encoding, w io.Writer) io.WriteCloser {
// EncodedLen returns the length in bytes of the base64 encoding
// of an input buffer of length n.
func (enc *Encoding) EncodedLen(n int) int { return (n + 2) / 3 * 4 }
func (enc *Encoding) EncodedLen(n int) int {
if enc.padChar == NoPadding {
return (n*8 + 5) / 6 // minimum # chars at 6 bits per char
}
return (n + 2) / 3 * 4 // minimum # 4-char quanta, 3 bytes each
}
/*
* Decoder
@ -212,23 +247,27 @@ func (e CorruptInputError) Error() string {
}
// decode is like Decode but returns an additional 'end' value, which
// indicates if end-of-message padding was encountered and thus any
// additional data is an error. This method assumes that src has been
// indicates if end-of-message padding or a partial quantum was encountered
// and thus any additional data is an error. This method assumes that src has been
// stripped of all supported whitespace ('\r' and '\n').
func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
olen := len(src)
for len(src) > 0 && !end {
// Decode quantum using the base64 alphabet
var dbuf [4]byte
dlen := 4
dinc, dlen := 3, 4
for j := range dbuf {
if len(src) == 0 {
return n, false, CorruptInputError(olen - len(src) - j)
if enc.padChar != NoPadding || j < 2 {
return n, false, CorruptInputError(olen - len(src) - j)
}
dinc, dlen, end = j-1, j, true
break
}
in := src[0]
src = src[1:]
if in == '=' {
if rune(in) == enc.padChar {
// We've reached the end and there's padding
switch j {
case 0, 1:
@ -240,7 +279,7 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// not enough padding
return n, false, CorruptInputError(olen)
}
if src[0] != '=' {
if rune(src[0]) != enc.padChar {
// incorrect padding
return n, false, CorruptInputError(olen - len(src) - 1)
}
@ -250,7 +289,7 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
// trailing garbage
err = CorruptInputError(olen - len(src))
}
dlen, end = j, true
dinc, dlen, end = 3, j, true
break
}
dbuf[j] = enc.decodeMap[in]
@ -271,7 +310,7 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {
case 2:
dst[0] = dbuf[0]<<2 | dbuf[1]>>4
}
dst = dst[3:]
dst = dst[dinc:]
n += dlen - 1
}
@ -338,12 +377,12 @@ func (d *decoder) Read(p []byte) (n int, err error) {
nr := d.nbuf / 4 * 4
nw := d.nbuf / 4 * 3
if nw > len(p) {
nw, d.end, d.err = d.enc.decode(d.outbuf[0:], d.buf[0:nr])
d.out = d.outbuf[0:nw]
nw, d.end, d.err = d.enc.decode(d.outbuf[:], d.buf[:nr])
d.out = d.outbuf[:nw]
n = copy(p, d.out)
d.out = d.out[n:]
} else {
n, d.end, d.err = d.enc.decode(p, d.buf[0:nr])
n, d.end, d.err = d.enc.decode(p, d.buf[:nr])
}
d.nbuf -= nr
for i := 0; i < d.nbuf; i++ {
@ -364,7 +403,7 @@ func (r *newlineFilteringReader) Read(p []byte) (int, error) {
n, err := r.wrapped.Read(p)
for n > 0 {
offset := 0
for i, b := range p[0:n] {
for i, b := range p[:n] {
if b != '\r' && b != '\n' {
if i != offset {
p[offset] = b
@ -388,4 +427,11 @@ func NewDecoder(enc *Encoding, r io.Reader) io.Reader {
// DecodedLen returns the maximum length in bytes of the decoded data
// corresponding to n bytes of base64-encoded data.
func (enc *Encoding) DecodedLen(n int) int { return n / 4 * 3 }
func (enc *Encoding) DecodedLen(n int) int {
if enc.padChar == NoPadding {
// Unpadded data may end with partial block of 2-3 characters.
return (n*6 + 7) / 8
}
// Padded base64 should always be a multiple of 4 characters in length.
return n / 4 * 3
}

View File

@ -45,6 +45,48 @@ var pairs = []testpair{
{"sure.", "c3VyZS4="},
}
// Do nothing to a reference base64 string (leave in standard format)
func stdRef(ref string) string {
return ref
}
// Convert a reference string to URL-encoding
func urlRef(ref string) string {
ref = strings.Replace(ref, "+", "-", -1)
ref = strings.Replace(ref, "/", "_", -1)
return ref
}
// Convert a reference string to raw, unpadded format
func rawRef(ref string) string {
return strings.TrimRight(ref, "=")
}
// Both URL and unpadding conversions
func rawUrlRef(ref string) string {
return rawRef(urlRef(ref))
}
// A nonstandard encoding with a funny padding character, for testing
var funnyEncoding = NewEncoding(encodeStd).WithPadding(rune('@'))
func funnyRef(ref string) string {
return strings.Replace(ref, "=", "@", -1)
}
type encodingTest struct {
enc *Encoding // Encoding to test
conv func(string) string // Reference string converter
}
var encodingTests = []encodingTest{
encodingTest{StdEncoding, stdRef},
encodingTest{URLEncoding, urlRef},
encodingTest{RawStdEncoding, rawRef},
encodingTest{RawURLEncoding, rawUrlRef},
encodingTest{funnyEncoding, funnyRef},
}
var bigtest = testpair{
"Twas brillig, and the slithy toves",
"VHdhcyBicmlsbGlnLCBhbmQgdGhlIHNsaXRoeSB0b3Zlcw==",
@ -60,8 +102,11 @@ func testEqual(t *testing.T, msg string, args ...interface{}) bool {
func TestEncode(t *testing.T) {
for _, p := range pairs {
got := StdEncoding.EncodeToString([]byte(p.decoded))
testEqual(t, "Encode(%q) = %q, want %q", p.decoded, got, p.encoded)
for _, tt := range encodingTests {
got := tt.enc.EncodeToString([]byte(p.decoded))
testEqual(t, "Encode(%q) = %q, want %q", p.decoded,
got, tt.conv(p.encoded))
}
}
}
@ -97,18 +142,21 @@ func TestEncoderBuffering(t *testing.T) {
func TestDecode(t *testing.T) {
for _, p := range pairs {
dbuf := make([]byte, StdEncoding.DecodedLen(len(p.encoded)))
count, end, err := StdEncoding.decode(dbuf, []byte(p.encoded))
testEqual(t, "Decode(%q) = error %v, want %v", p.encoded, err, error(nil))
testEqual(t, "Decode(%q) = length %v, want %v", p.encoded, count, len(p.decoded))
if len(p.encoded) > 0 {
testEqual(t, "Decode(%q) = end %v, want %v", p.encoded, end, (p.encoded[len(p.encoded)-1] == '='))
}
testEqual(t, "Decode(%q) = %q, want %q", p.encoded, string(dbuf[0:count]), p.decoded)
for _, tt := range encodingTests {
encoded := tt.conv(p.encoded)
dbuf := make([]byte, tt.enc.DecodedLen(len(encoded)))
count, end, err := tt.enc.decode(dbuf, []byte(encoded))
testEqual(t, "Decode(%q) = error %v, want %v", encoded, err, error(nil))
testEqual(t, "Decode(%q) = length %v, want %v", encoded, count, len(p.decoded))
if len(encoded) > 0 {
testEqual(t, "Decode(%q) = end %v, want %v", encoded, end, len(p.decoded)%3 != 0)
}
testEqual(t, "Decode(%q) = %q, want %q", encoded, string(dbuf[0:count]), p.decoded)
dbuf, err = StdEncoding.DecodeString(p.encoded)
testEqual(t, "DecodeString(%q) = error %v, want %v", p.encoded, err, error(nil))
testEqual(t, "DecodeString(%q) = %q, want %q", string(dbuf), p.decoded)
dbuf, err = tt.enc.DecodeString(encoded)
testEqual(t, "DecodeString(%q) = error %v, want %v", encoded, err, error(nil))
testEqual(t, "DecodeString(%q) = %q, want %q", string(dbuf), p.decoded)
}
}
}