diff --git a/src/encoding/hex/hex.go b/src/encoding/hex/hex.go index 18e0c09ef3..f47b7fa34e 100644 --- a/src/encoding/hex/hex.go +++ b/src/encoding/hex/hex.go @@ -58,11 +58,11 @@ func Decode(dst, src []byte) (int, error) { for i := 0; i < len(src)/2; i++ { a, ok := fromHexChar(src[i*2]) if !ok { - return 0, InvalidByteError(src[i*2]) + return i, InvalidByteError(src[i*2]) } b, ok := fromHexChar(src[i*2+1]) if !ok { - return 0, InvalidByteError(src[i*2+1]) + return i, InvalidByteError(src[i*2+1]) } dst[i] = (a << 4) | b } @@ -113,6 +113,77 @@ func Dump(data []byte) string { return buf.String() } +// bufferSize is the number of hexadecimal characters to buffer in encoder and decoder. +const bufferSize = 1024 + +type encoder struct { + w io.Writer + err error + out [bufferSize]byte // output buffer +} + +// NewEncoder returns an io.Writer that writes lowercase hexadecimal characters to w. +func NewEncoder(w io.Writer) io.Writer { + return &encoder{w: w} +} + +func (e *encoder) Write(p []byte) (n int, err error) { + for len(p) > 0 && e.err == nil { + chunkSize := bufferSize / 2 + if len(p) < chunkSize { + chunkSize = len(p) + } + + var written int + encoded := Encode(e.out[:], p[:chunkSize]) + written, e.err = e.w.Write(e.out[:encoded]) + n += written / 2 + p = p[chunkSize:] + } + return n, e.err +} + +type decoder struct { + r io.Reader + err error + in []byte // input buffer (encoded form) + arr [bufferSize]byte // backing array for in +} + +// NewDecoder returns an io.Reader that decodes hexadecimal characters from r. +// NewDecoder expects that r contain only an even number of hexadecimal characters. +func NewDecoder(r io.Reader) io.Reader { + return &decoder{r: r} +} + +func (d *decoder) Read(p []byte) (n int, err error) { + // Fill internal buffer with sufficient bytes to decode + if len(d.in) < 2 && d.err == nil { + var numCopy, numRead int + numCopy = copy(d.arr[:], d.in) // Copies either 0 or 1 bytes + numRead, d.err = d.r.Read(d.arr[numCopy:]) + d.in = d.arr[:numCopy+numRead] + if d.err == io.EOF && len(d.in)%2 != 0 { + d.err = io.ErrUnexpectedEOF + } + } + + // Decode internal buffer into output buffer + if numAvail := len(d.in) / 2; len(p) > numAvail { + p = p[:numAvail] + } + numDec, err := Decode(p, d.in[:len(p)*2]) + d.in = d.in[2*numDec:] + if err != nil { + d.in, d.err = nil, err // Decode error; discard input remainder + } + + if len(d.in) < 2 { + return numDec, d.err // Only expose errors when buffer fully consumed + } + return numDec, nil +} + // Dumper returns a WriteCloser that writes a hex dump of all written data to // w. The format of the dump matches the output of `hexdump -C` on the command // line. diff --git a/src/encoding/hex/hex_test.go b/src/encoding/hex/hex_test.go index e6dc765c95..d874b39e95 100644 --- a/src/encoding/hex/hex_test.go +++ b/src/encoding/hex/hex_test.go @@ -7,6 +7,9 @@ package hex import ( "bytes" "fmt" + "io" + "io/ioutil" + "strings" "testing" ) @@ -111,6 +114,63 @@ func TestInvalidStringErr(t *testing.T) { } } +func TestEncoderDecoder(t *testing.T) { + for _, multiplier := range []int{1, 128, 192} { + for _, test := range encDecTests { + input := bytes.Repeat(test.dec, multiplier) + output := strings.Repeat(test.enc, multiplier) + + var buf bytes.Buffer + enc := NewEncoder(&buf) + r := struct{ io.Reader }{bytes.NewReader(input)} // io.Reader only; not io.WriterTo + if n, err := io.CopyBuffer(enc, r, make([]byte, 7)); n != int64(len(input)) || err != nil { + t.Errorf("encoder.Write(%q*%d) = (%d, %v), want (%d, nil)", test.dec, multiplier, n, err, len(input)) + continue + } + + if encDst := buf.String(); encDst != output { + t.Errorf("buf(%q*%d) = %v, want %v", test.dec, multiplier, encDst, output) + continue + } + + dec := NewDecoder(&buf) + var decBuf bytes.Buffer + w := struct{ io.Writer }{&decBuf} // io.Writer only; not io.ReaderFrom + if _, err := io.CopyBuffer(w, dec, make([]byte, 7)); err != nil || decBuf.Len() != len(input) { + t.Errorf("decoder.Read(%q*%d) = (%d, %v), want (%d, nil)", test.enc, multiplier, decBuf.Len(), err, len(input)) + } + + if !bytes.Equal(decBuf.Bytes(), input) { + t.Errorf("decBuf(%q*%d) = %v, want %v", test.dec, multiplier, decBuf.Bytes(), input) + continue + } + } + } +} + +func TestDecodeErr(t *testing.T) { + tests := []struct { + in string + wantOut string + wantErr error + }{ + {"", "", nil}, + {"0", "", io.ErrUnexpectedEOF}, + {"0g", "", InvalidByteError('g')}, + {"00gg", "\x00", InvalidByteError('g')}, + {"0\x01", "", InvalidByteError('\x01')}, + {"ffeed", "\xff\xee", io.ErrUnexpectedEOF}, + } + + for _, tt := range tests { + dec := NewDecoder(strings.NewReader(tt.in)) + got, err := ioutil.ReadAll(dec) + if string(got) != tt.wantOut || err != tt.wantErr { + t.Errorf("NewDecoder(%q) = (%q, %v), want (%q, %v)", tt.in, got, err, tt.wantOut, tt.wantErr) + } + } +} + func TestDumper(t *testing.T) { var in [40]byte for i := range in {