mirror of
https://github.com/golang/go
synced 2024-11-26 14:08:37 -07:00
internal/zstd: new internal package for zstd decompression
This package only does zstd decompression, which is starting to be used for ELF debug sections. If we need zstd compression we should use github.com/klauspost/compress/zstd. But for now that is a very large package to vendor into the standard library. For #55107 Change-Id: I60ede735357d491be653477ed419cf5f2f0d3f71 Reviewed-on: https://go-review.googlesource.com/c/go/+/473356 Reviewed-by: Ian Lance Taylor <iant@google.com> Run-TryBot: Ian Lance Taylor <iant@google.com> Run-TryBot: Ian Lance Taylor <iant@golang.org> Reviewed-by: Joseph Tsai <joetsai@digital-static.net> TryBot-Result: Gopher Robot <gobot@golang.org> Reviewed-by: Bryan Mills <bcmills@google.com> Auto-Submit: Ian Lance Taylor <iant@google.com>
This commit is contained in:
parent
d5514013b6
commit
73ee0fcf37
@ -226,7 +226,7 @@ var depsRules = `
|
||||
|
||||
# compression
|
||||
FMT, encoding/binary, hash/adler32, hash/crc32
|
||||
< compress/bzip2, compress/flate, compress/lzw
|
||||
< compress/bzip2, compress/flate, compress/lzw, internal/zstd
|
||||
< archive/zip, compress/gzip, compress/zlib;
|
||||
|
||||
# templates
|
||||
|
130
src/internal/zstd/bits.go
Normal file
130
src/internal/zstd/bits.go
Normal file
@ -0,0 +1,130 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// block is the data for a single compressed block.
|
||||
// The data starts immediately after the 3 byte block header,
|
||||
// and is Block_Size bytes long.
|
||||
type block []byte
|
||||
|
||||
// bitReader reads a bit stream going forward.
|
||||
type bitReader struct {
|
||||
r *Reader // for error reporting
|
||||
data block // the bits to read
|
||||
off uint32 // current offset into data
|
||||
bits uint32 // bits ready to be returned
|
||||
cnt uint32 // number of valid bits in the bits field
|
||||
}
|
||||
|
||||
// makeBitReader makes a bit reader starting at off.
|
||||
func (r *Reader) makeBitReader(data block, off int) bitReader {
|
||||
return bitReader{
|
||||
r: r,
|
||||
data: data,
|
||||
off: uint32(off),
|
||||
}
|
||||
}
|
||||
|
||||
// moreBits is called to read more bits.
|
||||
// This ensures that at least 16 bits are available.
|
||||
func (br *bitReader) moreBits() error {
|
||||
for br.cnt < 16 {
|
||||
if br.off >= uint32(len(br.data)) {
|
||||
return br.r.makeEOFError(int(br.off))
|
||||
}
|
||||
c := br.data[br.off]
|
||||
br.off++
|
||||
br.bits |= uint32(c) << br.cnt
|
||||
br.cnt += 8
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// val is called to fetch a value of b bits.
|
||||
func (br *bitReader) val(b uint8) uint32 {
|
||||
r := br.bits & ((1 << b) - 1)
|
||||
br.bits >>= b
|
||||
br.cnt -= uint32(b)
|
||||
return r
|
||||
}
|
||||
|
||||
// backup steps back to the last byte we used.
|
||||
func (br *bitReader) backup() {
|
||||
for br.cnt >= 8 {
|
||||
br.off--
|
||||
br.cnt -= 8
|
||||
}
|
||||
}
|
||||
|
||||
// makeError returns an error at the current offset wrapping a string.
|
||||
func (br *bitReader) makeError(msg string) error {
|
||||
return br.r.makeError(int(br.off), msg)
|
||||
}
|
||||
|
||||
// reverseBitReader reads a bit stream in reverse.
|
||||
type reverseBitReader struct {
|
||||
r *Reader // for error reporting
|
||||
data block // the bits to read
|
||||
off uint32 // current offset into data
|
||||
start uint32 // start in data; we read backward to start
|
||||
bits uint32 // bits ready to be returned
|
||||
cnt uint32 // number of valid bits in bits field
|
||||
}
|
||||
|
||||
// makeReverseBitReader makes a reverseBitReader reading backward
|
||||
// from off to start. The bitstream starts with a 1 bit in the last
|
||||
// byte, at off.
|
||||
func (r *Reader) makeReverseBitReader(data block, off, start int) (reverseBitReader, error) {
|
||||
streamStart := data[off]
|
||||
if streamStart == 0 {
|
||||
return reverseBitReader{}, r.makeError(off, "zero byte at reverse bit stream start")
|
||||
}
|
||||
rbr := reverseBitReader{
|
||||
r: r,
|
||||
data: data,
|
||||
off: uint32(off),
|
||||
start: uint32(start),
|
||||
bits: uint32(streamStart),
|
||||
cnt: uint32(7 - bits.LeadingZeros8(streamStart)),
|
||||
}
|
||||
return rbr, nil
|
||||
}
|
||||
|
||||
// val is called to fetch a value of b bits.
|
||||
func (rbr *reverseBitReader) val(b uint8) (uint32, error) {
|
||||
if !rbr.fetch(b) {
|
||||
return 0, rbr.r.makeEOFError(int(rbr.off))
|
||||
}
|
||||
|
||||
rbr.cnt -= uint32(b)
|
||||
v := (rbr.bits >> rbr.cnt) & ((1 << b) - 1)
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// fetch is called to ensure that at least b bits are available.
|
||||
// It reports false if this can't be done,
|
||||
// in which case only rbr.cnt bits are available.
|
||||
func (rbr *reverseBitReader) fetch(b uint8) bool {
|
||||
for rbr.cnt < uint32(b) {
|
||||
if rbr.off <= rbr.start {
|
||||
return false
|
||||
}
|
||||
rbr.off--
|
||||
c := rbr.data[rbr.off]
|
||||
rbr.bits <<= 8
|
||||
rbr.bits |= uint32(c)
|
||||
rbr.cnt += 8
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// makeError returns an error at the current offset wrapping a string.
|
||||
func (rbr *reverseBitReader) makeError(msg string) error {
|
||||
return rbr.r.makeError(int(rbr.off), msg)
|
||||
}
|
436
src/internal/zstd/block.go
Normal file
436
src/internal/zstd/block.go
Normal file
@ -0,0 +1,436 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"io"
|
||||
)
|
||||
|
||||
// debug can be set in the source to print debug info using println.
|
||||
const debug = false
|
||||
|
||||
// compressedBlock decompresses a compressed block, storing the decompressed
|
||||
// data in r.buffer. The blockSize argument is the compressed size.
|
||||
// RFC 3.1.1.3.
|
||||
func (r *Reader) compressedBlock(blockSize int) error {
|
||||
if len(r.compressedBuf) >= blockSize {
|
||||
r.compressedBuf = r.compressedBuf[:blockSize]
|
||||
} else {
|
||||
// We know that blockSize <= 128K,
|
||||
// so this won't allocate an enormous amount.
|
||||
need := blockSize - len(r.compressedBuf)
|
||||
r.compressedBuf = append(r.compressedBuf, make([]byte, need)...)
|
||||
}
|
||||
|
||||
if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil {
|
||||
return r.wrapNonEOFError(0, err)
|
||||
}
|
||||
|
||||
data := block(r.compressedBuf)
|
||||
off := 0
|
||||
r.buffer = r.buffer[:0]
|
||||
|
||||
litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r.literals = litbuf
|
||||
|
||||
off = litoff
|
||||
|
||||
seqCount, off, err := r.initSeqs(data, off)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if seqCount == 0 {
|
||||
// No sequences, just literals.
|
||||
if off < len(data) {
|
||||
return r.makeError(off, "extraneous data after no sequences")
|
||||
}
|
||||
if len(litbuf) == 0 {
|
||||
return r.makeError(off, "no sequences and no literals")
|
||||
}
|
||||
r.buffer = append(r.buffer, litbuf...)
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.execSeqs(data, off, litbuf, seqCount)
|
||||
}
|
||||
|
||||
// seqCode is the kind of sequence codes we have to handle.
|
||||
type seqCode int
|
||||
|
||||
const (
|
||||
seqLiteral seqCode = iota
|
||||
seqOffset
|
||||
seqMatch
|
||||
)
|
||||
|
||||
// seqCodeInfoData is the information needed to set up seqTables and
|
||||
// seqTableBits for a particular kind of sequence code.
|
||||
type seqCodeInfoData struct {
|
||||
predefTable []fseBaselineEntry // predefined FSE
|
||||
predefTableBits int // number of bits in predefTable
|
||||
maxSym int // max symbol value in FSE
|
||||
maxBits int // max bits for FSE
|
||||
|
||||
// toBaseline converts from an FSE table to an FSE baseline table.
|
||||
toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error
|
||||
}
|
||||
|
||||
// seqCodeInfo is the seqCodeInfoData for each kind of sequence code.
|
||||
var seqCodeInfo = [3]seqCodeInfoData{
|
||||
seqLiteral: {
|
||||
predefTable: predefinedLiteralTable[:],
|
||||
predefTableBits: 6,
|
||||
maxSym: 35,
|
||||
maxBits: 9,
|
||||
toBaseline: (*Reader).makeLiteralBaselineFSE,
|
||||
},
|
||||
seqOffset: {
|
||||
predefTable: predefinedOffsetTable[:],
|
||||
predefTableBits: 5,
|
||||
maxSym: 31,
|
||||
maxBits: 8,
|
||||
toBaseline: (*Reader).makeOffsetBaselineFSE,
|
||||
},
|
||||
seqMatch: {
|
||||
predefTable: predefinedMatchTable[:],
|
||||
predefTableBits: 6,
|
||||
maxSym: 52,
|
||||
maxBits: 9,
|
||||
toBaseline: (*Reader).makeMatchBaselineFSE,
|
||||
},
|
||||
}
|
||||
|
||||
// initSeqs reads the Sequences_Section_Header and sets up the FSE
|
||||
// tables used to read the sequence codes. It returns the number of
|
||||
// sequences and the new offset. RFC 3.1.1.3.2.1.
|
||||
func (r *Reader) initSeqs(data block, off int) (int, int, error) {
|
||||
if off >= len(data) {
|
||||
return 0, 0, r.makeEOFError(off)
|
||||
}
|
||||
|
||||
seqHdr := data[off]
|
||||
off++
|
||||
if seqHdr == 0 {
|
||||
return 0, off, nil
|
||||
}
|
||||
|
||||
var seqCount int
|
||||
if seqHdr < 128 {
|
||||
seqCount = int(seqHdr)
|
||||
} else if seqHdr < 255 {
|
||||
if off >= len(data) {
|
||||
return 0, 0, r.makeEOFError(off)
|
||||
}
|
||||
seqCount = ((int(seqHdr) - 128) << 8) + int(data[off])
|
||||
off++
|
||||
} else {
|
||||
if off+1 >= len(data) {
|
||||
return 0, 0, r.makeEOFError(off)
|
||||
}
|
||||
seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00
|
||||
off += 2
|
||||
}
|
||||
|
||||
// Read the Symbol_Compression_Modes byte.
|
||||
|
||||
if off >= len(data) {
|
||||
return 0, 0, r.makeEOFError(off)
|
||||
}
|
||||
symMode := data[off]
|
||||
if symMode&3 != 0 {
|
||||
return 0, 0, r.makeError(off, "invalid symbol compression mode")
|
||||
}
|
||||
off++
|
||||
|
||||
// Set up the FSE tables used to decode the sequence codes.
|
||||
|
||||
var err error
|
||||
off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return seqCount, off, nil
|
||||
}
|
||||
|
||||
// setSeqTable uses the Compression_Mode in mode to set up r.seqTables and
|
||||
// r.seqTableBits for kind. We store these in the Reader because one of
|
||||
// the modes simply reuses the value from the last block in the frame.
|
||||
func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) {
|
||||
info := &seqCodeInfo[kind]
|
||||
switch mode {
|
||||
case 0:
|
||||
// Predefined_Mode
|
||||
r.seqTables[kind] = info.predefTable
|
||||
r.seqTableBits[kind] = uint8(info.predefTableBits)
|
||||
return off, nil
|
||||
|
||||
case 1:
|
||||
// RLE_Mode
|
||||
if off >= len(data) {
|
||||
return 0, r.makeEOFError(off)
|
||||
}
|
||||
rle := data[off]
|
||||
off++
|
||||
|
||||
// Build a simple baseline table that always returns rle.
|
||||
|
||||
entry := []fseEntry{
|
||||
{
|
||||
sym: rle,
|
||||
bits: 0,
|
||||
base: 0,
|
||||
},
|
||||
}
|
||||
if cap(r.seqTableBuffers[kind]) == 0 {
|
||||
r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
|
||||
}
|
||||
r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1]
|
||||
if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
r.seqTables[kind] = r.seqTableBuffers[kind]
|
||||
r.seqTableBits[kind] = 0
|
||||
return off, nil
|
||||
|
||||
case 2:
|
||||
// FSE_Compressed_Mode
|
||||
if cap(r.fseScratch) < 1<<info.maxBits {
|
||||
r.fseScratch = make([]fseEntry, 1<<info.maxBits)
|
||||
}
|
||||
r.fseScratch = r.fseScratch[:1<<info.maxBits]
|
||||
|
||||
tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
r.fseScratch = r.fseScratch[:1<<tableBits]
|
||||
|
||||
if cap(r.seqTableBuffers[kind]) == 0 {
|
||||
r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
|
||||
}
|
||||
r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits]
|
||||
|
||||
if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
r.seqTables[kind] = r.seqTableBuffers[kind]
|
||||
r.seqTableBits[kind] = uint8(tableBits)
|
||||
return roff, nil
|
||||
|
||||
case 3:
|
||||
// Repeat_Mode
|
||||
if len(r.seqTables[kind]) == 0 {
|
||||
return 0, r.makeError(off, "missing repeat sequence FSE table")
|
||||
}
|
||||
return off, nil
|
||||
}
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
// execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2.
|
||||
func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error {
|
||||
// Set up the initial states for the sequence code readers.
|
||||
|
||||
rbr, err := r.makeReverseBitReader(data, len(data)-1, off)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
literalState, err := rbr.val(r.seqTableBits[seqLiteral])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offsetState, err := rbr.val(r.seqTableBits[seqOffset])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
matchState, err := rbr.val(r.seqTableBits[seqMatch])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read and perform all the sequences. RFC 3.1.1.4.
|
||||
|
||||
seq := 0
|
||||
for seq < seqCount {
|
||||
if len(r.buffer)+len(litbuf) > 128<<10 {
|
||||
return rbr.makeError("uncompressed size too big")
|
||||
}
|
||||
|
||||
ptoffset := &r.seqTables[seqOffset][offsetState]
|
||||
ptmatch := &r.seqTables[seqMatch][matchState]
|
||||
ptliteral := &r.seqTables[seqLiteral][literalState]
|
||||
|
||||
add, err := rbr.val(ptoffset.basebits)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
offset := ptoffset.baseline + add
|
||||
|
||||
add, err = rbr.val(ptmatch.basebits)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
match := ptmatch.baseline + add
|
||||
|
||||
add, err = rbr.val(ptliteral.basebits)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
literal := ptliteral.baseline + add
|
||||
|
||||
// Handle repeat offsets. RFC 3.1.1.5.
|
||||
// See the comment in makeOffsetBaselineFSE.
|
||||
if ptoffset.basebits > 1 {
|
||||
r.repeatedOffset3 = r.repeatedOffset2
|
||||
r.repeatedOffset2 = r.repeatedOffset1
|
||||
r.repeatedOffset1 = offset
|
||||
} else {
|
||||
if literal == 0 {
|
||||
offset++
|
||||
}
|
||||
switch offset {
|
||||
case 1:
|
||||
offset = r.repeatedOffset1
|
||||
case 2:
|
||||
offset = r.repeatedOffset2
|
||||
r.repeatedOffset2 = r.repeatedOffset1
|
||||
r.repeatedOffset1 = offset
|
||||
case 3:
|
||||
offset = r.repeatedOffset3
|
||||
r.repeatedOffset3 = r.repeatedOffset2
|
||||
r.repeatedOffset2 = r.repeatedOffset1
|
||||
r.repeatedOffset1 = offset
|
||||
case 4:
|
||||
offset = r.repeatedOffset1 - 1
|
||||
r.repeatedOffset3 = r.repeatedOffset2
|
||||
r.repeatedOffset2 = r.repeatedOffset1
|
||||
r.repeatedOffset1 = offset
|
||||
}
|
||||
}
|
||||
|
||||
seq++
|
||||
if seq < seqCount {
|
||||
// Update the states.
|
||||
add, err = rbr.val(ptliteral.bits)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
literalState = uint32(ptliteral.base) + add
|
||||
|
||||
add, err = rbr.val(ptmatch.bits)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
matchState = uint32(ptmatch.base) + add
|
||||
|
||||
add, err = rbr.val(ptoffset.bits)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
offsetState = uint32(ptoffset.base) + add
|
||||
}
|
||||
|
||||
// The next sequence is now in literal, offset, match.
|
||||
|
||||
if debug {
|
||||
println("literal", literal, "offset", offset, "match", match)
|
||||
}
|
||||
|
||||
// Copy literal bytes from litbuf.
|
||||
if literal > uint32(len(litbuf)) {
|
||||
return rbr.makeError("literal byte overflow")
|
||||
}
|
||||
if literal > 0 {
|
||||
r.buffer = append(r.buffer, litbuf[:literal]...)
|
||||
litbuf = litbuf[literal:]
|
||||
}
|
||||
|
||||
if match > 0 {
|
||||
if err := r.copyFromWindow(&rbr, offset, match); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(litbuf) > 0 {
|
||||
r.buffer = append(r.buffer, litbuf...)
|
||||
}
|
||||
|
||||
if rbr.cnt != 0 {
|
||||
return r.makeError(off, "extraneous data after sequences")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Copy match bytes from the decoded output, or the window, at offset.
|
||||
func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error {
|
||||
if offset == 0 {
|
||||
return rbr.makeError("invalid zero offset")
|
||||
}
|
||||
|
||||
lenBlock := uint32(len(r.buffer))
|
||||
if lenBlock < offset {
|
||||
lenWindow := uint32(len(r.window))
|
||||
windowOffset := offset - lenBlock
|
||||
if windowOffset > lenWindow {
|
||||
return rbr.makeError("offset past window")
|
||||
}
|
||||
from := lenWindow - windowOffset
|
||||
if from+match <= lenWindow {
|
||||
r.buffer = append(r.buffer, r.window[from:from+match]...)
|
||||
return nil
|
||||
}
|
||||
r.buffer = append(r.buffer, r.window[from:]...)
|
||||
copied := lenWindow - from
|
||||
offset -= copied
|
||||
match -= copied
|
||||
|
||||
if offset == 0 && match > 0 {
|
||||
return rbr.makeError("invalid offset")
|
||||
}
|
||||
}
|
||||
|
||||
from := lenBlock - offset
|
||||
if offset >= match {
|
||||
r.buffer = append(r.buffer, r.buffer[from:from+match]...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// We are being asked to copy data that we are adding to the
|
||||
// buffer in the same copy.
|
||||
for match > 0 {
|
||||
var copy uint32
|
||||
if offset >= match {
|
||||
copy = match
|
||||
} else {
|
||||
copy = offset
|
||||
}
|
||||
r.buffer = append(r.buffer, r.buffer[from:from+copy]...)
|
||||
match -= copy
|
||||
from += copy
|
||||
}
|
||||
return nil
|
||||
}
|
437
src/internal/zstd/fse.go
Normal file
437
src/internal/zstd/fse.go
Normal file
@ -0,0 +1,437 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// fseEntry is one entry in an FSE table.
|
||||
type fseEntry struct {
|
||||
sym uint8 // value that this entry records
|
||||
bits uint8 // number of bits to read to determine next state
|
||||
base uint16 // add those bits to this state to get the next state
|
||||
}
|
||||
|
||||
// readFSE reads an FSE table from data starting at off.
|
||||
// maxSym is the maximum symbol value.
|
||||
// maxBits is the maximum number of bits permitted for symbols in the table.
|
||||
// The FSE is written into table, which must be at least 1<<maxBits in size.
|
||||
// This returns the number of bits in the FSE table and the new offset.
|
||||
// RFC 4.1.1.
|
||||
func (r *Reader) readFSE(data block, off, maxSym, maxBits int, table []fseEntry) (tableBits, roff int, err error) {
|
||||
br := r.makeBitReader(data, off)
|
||||
if err := br.moreBits(); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
accuracyLog := int(br.val(4)) + 5
|
||||
if accuracyLog > maxBits {
|
||||
return 0, 0, br.makeError("FSE accuracy log too large")
|
||||
}
|
||||
|
||||
// The number of remaining probabilities, plus 1.
|
||||
// This determines the number of bits to be read for the next value.
|
||||
remaining := (1 << accuracyLog) + 1
|
||||
|
||||
// The current difference between small and large values,
|
||||
// which depends on the number of remaining values.
|
||||
// Small values use 1 less bit.
|
||||
threshold := 1 << accuracyLog
|
||||
|
||||
// The number of bits needed to compute threshold.
|
||||
bitsNeeded := accuracyLog + 1
|
||||
|
||||
// The next character value.
|
||||
sym := 0
|
||||
|
||||
// Whether the last count was 0.
|
||||
prev0 := false
|
||||
|
||||
var norm [256]int16
|
||||
|
||||
for remaining > 1 && sym <= maxSym {
|
||||
if err := br.moreBits(); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
if prev0 {
|
||||
// Previous count was 0, so there is a 2-bit
|
||||
// repeat flag. If the 2-bit flag is 0b11,
|
||||
// it adds 3 and then there is another repeat flag.
|
||||
zsym := sym
|
||||
for (br.bits & 0xfff) == 0xfff {
|
||||
zsym += 3 * 6
|
||||
br.bits >>= 12
|
||||
br.cnt -= 12
|
||||
if err := br.moreBits(); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
for (br.bits & 3) == 3 {
|
||||
zsym += 3
|
||||
br.bits >>= 2
|
||||
br.cnt -= 2
|
||||
if err := br.moreBits(); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
// We have at least 14 bits here,
|
||||
// no need to call moreBits
|
||||
|
||||
zsym += int(br.val(2))
|
||||
|
||||
if zsym > maxSym {
|
||||
return 0, 0, br.makeError("FSE symbol index overflow")
|
||||
}
|
||||
|
||||
for ; sym < zsym; sym++ {
|
||||
norm[uint8(sym)] = 0
|
||||
}
|
||||
|
||||
prev0 = false
|
||||
continue
|
||||
}
|
||||
|
||||
max := (2*threshold - 1) - remaining
|
||||
var count int
|
||||
if int(br.bits&uint32(threshold-1)) < max {
|
||||
// A small value.
|
||||
count = int(br.bits & uint32((threshold - 1)))
|
||||
br.bits >>= bitsNeeded - 1
|
||||
br.cnt -= uint32(bitsNeeded - 1)
|
||||
} else {
|
||||
// A large value.
|
||||
count = int(br.bits & uint32((2*threshold - 1)))
|
||||
if count >= threshold {
|
||||
count -= max
|
||||
}
|
||||
br.bits >>= bitsNeeded
|
||||
br.cnt -= uint32(bitsNeeded)
|
||||
}
|
||||
|
||||
count--
|
||||
if count >= 0 {
|
||||
remaining -= count
|
||||
} else {
|
||||
remaining--
|
||||
}
|
||||
if sym >= 256 {
|
||||
return 0, 0, br.makeError("FSE sym overflow")
|
||||
}
|
||||
norm[uint8(sym)] = int16(count)
|
||||
sym++
|
||||
|
||||
prev0 = count == 0
|
||||
|
||||
for remaining < threshold {
|
||||
bitsNeeded--
|
||||
threshold >>= 1
|
||||
}
|
||||
}
|
||||
|
||||
if remaining != 1 {
|
||||
return 0, 0, br.makeError("too many symbols in FSE table")
|
||||
}
|
||||
|
||||
for ; sym <= maxSym; sym++ {
|
||||
norm[uint8(sym)] = 0
|
||||
}
|
||||
|
||||
br.backup()
|
||||
|
||||
if err := r.buildFSE(off, norm[:maxSym+1], table, accuracyLog); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
return accuracyLog, int(br.off), nil
|
||||
}
|
||||
|
||||
// buildFSE builds an FSE decoding table from a list of probabilities.
|
||||
// The probabilities are in norm. next is scratch space. The number of bits
|
||||
// in the table is tableBits.
|
||||
func (r *Reader) buildFSE(off int, norm []int16, table []fseEntry, tableBits int) error {
|
||||
tableSize := 1 << tableBits
|
||||
highThreshold := tableSize - 1
|
||||
|
||||
var next [256]uint16
|
||||
|
||||
for i, n := range norm {
|
||||
if n >= 0 {
|
||||
next[uint8(i)] = uint16(n)
|
||||
} else {
|
||||
table[highThreshold].sym = uint8(i)
|
||||
highThreshold--
|
||||
next[uint8(i)] = 1
|
||||
}
|
||||
}
|
||||
|
||||
pos := 0
|
||||
step := (tableSize >> 1) + (tableSize >> 3) + 3
|
||||
mask := tableSize - 1
|
||||
for i, n := range norm {
|
||||
for j := 0; j < int(n); j++ {
|
||||
table[pos].sym = uint8(i)
|
||||
pos = (pos + step) & mask
|
||||
for pos > highThreshold {
|
||||
pos = (pos + step) & mask
|
||||
}
|
||||
}
|
||||
}
|
||||
if pos != 0 {
|
||||
return r.makeError(off, "FSE count error")
|
||||
}
|
||||
|
||||
for i := 0; i < tableSize; i++ {
|
||||
sym := table[i].sym
|
||||
nextState := next[sym]
|
||||
next[sym]++
|
||||
|
||||
if nextState == 0 {
|
||||
return r.makeError(off, "FSE state error")
|
||||
}
|
||||
|
||||
highBit := 15 - bits.LeadingZeros16(nextState)
|
||||
|
||||
bits := tableBits - highBit
|
||||
table[i].bits = uint8(bits)
|
||||
table[i].base = (nextState << bits) - uint16(tableSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// fseBaselineEntry is an entry in an FSE baseline table.
|
||||
// We use these for literal/match/length values.
|
||||
// Those require mapping the symbol to a baseline value,
|
||||
// and then reading zero or more bits and adding the value to the baseline.
|
||||
// Rather than looking thees up in separate tables,
|
||||
// we convert the FSE table to an FSE baseline table.
|
||||
type fseBaselineEntry struct {
|
||||
baseline uint32 // baseline for value that this entry represents
|
||||
basebits uint8 // number of bits to read to add to baseline
|
||||
bits uint8 // number of bits to read to determine next state
|
||||
base uint16 // add the bits to this base to get the next state
|
||||
}
|
||||
|
||||
// Given a literal length code, we need to read a number of bits and
|
||||
// add that to a baseline. For states 0 to 15 the baseline is the
|
||||
// state and the number of bits is zero. RFC 3.1.1.3.2.1.1.
|
||||
|
||||
const literalLengthOffset = 16
|
||||
|
||||
var literalLengthBase = []uint32{
|
||||
16 | (1 << 24),
|
||||
18 | (1 << 24),
|
||||
20 | (1 << 24),
|
||||
22 | (1 << 24),
|
||||
24 | (2 << 24),
|
||||
28 | (2 << 24),
|
||||
32 | (3 << 24),
|
||||
40 | (3 << 24),
|
||||
48 | (4 << 24),
|
||||
64 | (6 << 24),
|
||||
128 | (7 << 24),
|
||||
256 | (8 << 24),
|
||||
512 | (9 << 24),
|
||||
1024 | (10 << 24),
|
||||
2048 | (11 << 24),
|
||||
4096 | (12 << 24),
|
||||
8192 | (13 << 24),
|
||||
16384 | (14 << 24),
|
||||
32768 | (15 << 24),
|
||||
65536 | (16 << 24),
|
||||
}
|
||||
|
||||
// makeLiteralBaselineFSE converts the literal length fseTable to baselineTable.
|
||||
func (r *Reader) makeLiteralBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
|
||||
for i, e := range fseTable {
|
||||
be := fseBaselineEntry{
|
||||
bits: e.bits,
|
||||
base: e.base,
|
||||
}
|
||||
if e.sym < literalLengthOffset {
|
||||
be.baseline = uint32(e.sym)
|
||||
be.basebits = 0
|
||||
} else {
|
||||
if e.sym > 35 {
|
||||
return r.makeError(off, "FSE baseline symbol overflow")
|
||||
}
|
||||
idx := e.sym - literalLengthOffset
|
||||
basebits := literalLengthBase[idx]
|
||||
be.baseline = basebits & 0xffffff
|
||||
be.basebits = uint8(basebits >> 24)
|
||||
}
|
||||
baselineTable[i] = be
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// makeOffsetBaselineFSE converts the offset length fseTable to baselineTable.
|
||||
func (r *Reader) makeOffsetBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
|
||||
for i, e := range fseTable {
|
||||
be := fseBaselineEntry{
|
||||
bits: e.bits,
|
||||
base: e.base,
|
||||
}
|
||||
if e.sym > 31 {
|
||||
return r.makeError(off, "FSE offset symbol overflow")
|
||||
}
|
||||
|
||||
// The simple way to write this is
|
||||
// be.baseline = 1 << e.sym
|
||||
// be.basebits = e.sym
|
||||
// That would give us an offset value that corresponds to
|
||||
// the one described in the RFC. However, for offsets > 3
|
||||
// we have to subtract 3. And for offset values 1, 2, 3
|
||||
// we use a repeated offset.
|
||||
//
|
||||
// The baseline is always a power of 2, and is never 0,
|
||||
// so for those low values we will see one entry that is
|
||||
// baseline 1, basebits 0, and one entry that is baseline 2,
|
||||
// basebits 1. All other entries will have baseline >= 4
|
||||
// basebits >= 2.
|
||||
//
|
||||
// So we can check for RFC offset <= 3 by checking for
|
||||
// basebits <= 1. That means that we can subtract 3 here
|
||||
// and not worry about doing it in the hot loop.
|
||||
|
||||
be.baseline = 1 << e.sym
|
||||
if e.sym >= 2 {
|
||||
be.baseline -= 3
|
||||
}
|
||||
be.basebits = e.sym
|
||||
baselineTable[i] = be
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Given a match length code, we need to read a number of bits and add
|
||||
// that to a baseline. For states 0 to 31 the baseline is state+3 and
|
||||
// the number of bits is zero. RFC 3.1.1.3.2.1.1.
|
||||
|
||||
const matchLengthOffset = 32
|
||||
|
||||
var matchLengthBase = []uint32{
|
||||
35 | (1 << 24),
|
||||
37 | (1 << 24),
|
||||
39 | (1 << 24),
|
||||
41 | (1 << 24),
|
||||
43 | (2 << 24),
|
||||
47 | (2 << 24),
|
||||
51 | (3 << 24),
|
||||
59 | (3 << 24),
|
||||
67 | (4 << 24),
|
||||
83 | (4 << 24),
|
||||
99 | (5 << 24),
|
||||
131 | (7 << 24),
|
||||
259 | (8 << 24),
|
||||
515 | (9 << 24),
|
||||
1027 | (10 << 24),
|
||||
2051 | (11 << 24),
|
||||
4099 | (12 << 24),
|
||||
8195 | (13 << 24),
|
||||
16387 | (14 << 24),
|
||||
32771 | (15 << 24),
|
||||
65539 | (16 << 24),
|
||||
}
|
||||
|
||||
// makeMatchBaselineFSE converts the match length fseTable to baselineTable.
|
||||
func (r *Reader) makeMatchBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
|
||||
for i, e := range fseTable {
|
||||
be := fseBaselineEntry{
|
||||
bits: e.bits,
|
||||
base: e.base,
|
||||
}
|
||||
if e.sym < matchLengthOffset {
|
||||
be.baseline = uint32(e.sym) + 3
|
||||
be.basebits = 0
|
||||
} else {
|
||||
if e.sym > 52 {
|
||||
return r.makeError(off, "FSE baseline symbol overflow")
|
||||
}
|
||||
idx := e.sym - matchLengthOffset
|
||||
basebits := matchLengthBase[idx]
|
||||
be.baseline = basebits & 0xffffff
|
||||
be.basebits = uint8(basebits >> 24)
|
||||
}
|
||||
baselineTable[i] = be
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// predefinedLiteralTable is the predefined table to use for literal lengths.
|
||||
// Generated from table in RFC 3.1.1.3.2.2.1.
|
||||
// Checked by TestPredefinedTables.
|
||||
var predefinedLiteralTable = [...]fseBaselineEntry{
|
||||
{0, 0, 4, 0}, {0, 0, 4, 16}, {1, 0, 5, 32},
|
||||
{3, 0, 5, 0}, {4, 0, 5, 0}, {6, 0, 5, 0},
|
||||
{7, 0, 5, 0}, {9, 0, 5, 0}, {10, 0, 5, 0},
|
||||
{12, 0, 5, 0}, {14, 0, 6, 0}, {16, 1, 5, 0},
|
||||
{20, 1, 5, 0}, {22, 1, 5, 0}, {28, 2, 5, 0},
|
||||
{32, 3, 5, 0}, {48, 4, 5, 0}, {64, 6, 5, 32},
|
||||
{128, 7, 5, 0}, {256, 8, 6, 0}, {1024, 10, 6, 0},
|
||||
{4096, 12, 6, 0}, {0, 0, 4, 32}, {1, 0, 4, 0},
|
||||
{2, 0, 5, 0}, {4, 0, 5, 32}, {5, 0, 5, 0},
|
||||
{7, 0, 5, 32}, {8, 0, 5, 0}, {10, 0, 5, 32},
|
||||
{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 1, 5, 32},
|
||||
{18, 1, 5, 0}, {22, 1, 5, 32}, {24, 2, 5, 0},
|
||||
{32, 3, 5, 32}, {40, 3, 5, 0}, {64, 6, 4, 0},
|
||||
{64, 6, 4, 16}, {128, 7, 5, 32}, {512, 9, 6, 0},
|
||||
{2048, 11, 6, 0}, {0, 0, 4, 48}, {1, 0, 4, 16},
|
||||
{2, 0, 5, 32}, {3, 0, 5, 32}, {5, 0, 5, 32},
|
||||
{6, 0, 5, 32}, {8, 0, 5, 32}, {9, 0, 5, 32},
|
||||
{11, 0, 5, 32}, {12, 0, 5, 32}, {15, 0, 6, 0},
|
||||
{18, 1, 5, 32}, {20, 1, 5, 32}, {24, 2, 5, 32},
|
||||
{28, 2, 5, 32}, {40, 3, 5, 32}, {48, 4, 5, 32},
|
||||
{65536, 16, 6, 0}, {32768, 15, 6, 0}, {16384, 14, 6, 0},
|
||||
{8192, 13, 6, 0},
|
||||
}
|
||||
|
||||
// predefinedOffsetTable is the predefined table to use for offsets.
|
||||
// Generated from table in RFC 3.1.1.3.2.2.3.
|
||||
// Checked by TestPredefinedTables.
|
||||
var predefinedOffsetTable = [...]fseBaselineEntry{
|
||||
{1, 0, 5, 0}, {61, 6, 4, 0}, {509, 9, 5, 0},
|
||||
{32765, 15, 5, 0}, {2097149, 21, 5, 0}, {5, 3, 5, 0},
|
||||
{125, 7, 4, 0}, {4093, 12, 5, 0}, {262141, 18, 5, 0},
|
||||
{8388605, 23, 5, 0}, {29, 5, 5, 0}, {253, 8, 4, 0},
|
||||
{16381, 14, 5, 0}, {1048573, 20, 5, 0}, {1, 2, 5, 0},
|
||||
{125, 7, 4, 16}, {2045, 11, 5, 0}, {131069, 17, 5, 0},
|
||||
{4194301, 22, 5, 0}, {13, 4, 5, 0}, {253, 8, 4, 16},
|
||||
{8189, 13, 5, 0}, {524285, 19, 5, 0}, {2, 1, 5, 0},
|
||||
{61, 6, 4, 16}, {1021, 10, 5, 0}, {65533, 16, 5, 0},
|
||||
{268435453, 28, 5, 0}, {134217725, 27, 5, 0}, {67108861, 26, 5, 0},
|
||||
{33554429, 25, 5, 0}, {16777213, 24, 5, 0},
|
||||
}
|
||||
|
||||
// predefinedMatchTable is the predefined table to use for match lengths.
|
||||
// Generated from table in RFC 3.1.1.3.2.2.2.
|
||||
// Checked by TestPredefinedTables.
|
||||
var predefinedMatchTable = [...]fseBaselineEntry{
|
||||
{3, 0, 6, 0}, {4, 0, 4, 0}, {5, 0, 5, 32},
|
||||
{6, 0, 5, 0}, {8, 0, 5, 0}, {9, 0, 5, 0},
|
||||
{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 0, 6, 0},
|
||||
{19, 0, 6, 0}, {22, 0, 6, 0}, {25, 0, 6, 0},
|
||||
{28, 0, 6, 0}, {31, 0, 6, 0}, {34, 0, 6, 0},
|
||||
{37, 1, 6, 0}, {41, 1, 6, 0}, {47, 2, 6, 0},
|
||||
{59, 3, 6, 0}, {83, 4, 6, 0}, {131, 7, 6, 0},
|
||||
{515, 9, 6, 0}, {4, 0, 4, 16}, {5, 0, 4, 0},
|
||||
{6, 0, 5, 32}, {7, 0, 5, 0}, {9, 0, 5, 32},
|
||||
{10, 0, 5, 0}, {12, 0, 6, 0}, {15, 0, 6, 0},
|
||||
{18, 0, 6, 0}, {21, 0, 6, 0}, {24, 0, 6, 0},
|
||||
{27, 0, 6, 0}, {30, 0, 6, 0}, {33, 0, 6, 0},
|
||||
{35, 1, 6, 0}, {39, 1, 6, 0}, {43, 2, 6, 0},
|
||||
{51, 3, 6, 0}, {67, 4, 6, 0}, {99, 5, 6, 0},
|
||||
{259, 8, 6, 0}, {4, 0, 4, 32}, {4, 0, 4, 48},
|
||||
{5, 0, 4, 16}, {7, 0, 5, 32}, {8, 0, 5, 32},
|
||||
{10, 0, 5, 32}, {11, 0, 5, 32}, {14, 0, 6, 0},
|
||||
{17, 0, 6, 0}, {20, 0, 6, 0}, {23, 0, 6, 0},
|
||||
{26, 0, 6, 0}, {29, 0, 6, 0}, {32, 0, 6, 0},
|
||||
{65539, 16, 6, 0}, {32771, 15, 6, 0}, {16387, 14, 6, 0},
|
||||
{8195, 13, 6, 0}, {4099, 12, 6, 0}, {2051, 11, 6, 0},
|
||||
{1027, 10, 6, 0},
|
||||
}
|
89
src/internal/zstd/fse_test.go
Normal file
89
src/internal/zstd/fse_test.go
Normal file
@ -0,0 +1,89 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"slices"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// literalPredefinedDistribution is the predefined distribution table
|
||||
// for literal lengths. RFC 3.1.1.3.2.2.1.
|
||||
var literalPredefinedDistribution = []int16{
|
||||
4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1,
|
||||
2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
|
||||
-1, -1, -1, -1,
|
||||
}
|
||||
|
||||
// offsetPredefinedDistribution is the predefined distribution table
|
||||
// for offsets. RFC 3.1.1.3.2.2.3.
|
||||
var offsetPredefinedDistribution = []int16{
|
||||
1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
|
||||
}
|
||||
|
||||
// matchPredefinedDistribution is the predefined distribution table
|
||||
// for match lengths. RFC 3.1.1.3.2.2.2.
|
||||
var matchPredefinedDistribution = []int16{
|
||||
1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
|
||||
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1,
|
||||
-1, -1, -1, -1, -1,
|
||||
}
|
||||
|
||||
// TestPredefinedTables verifies that we can generate the predefined
|
||||
// literal/offset/match tables from the input data in RFC 8878.
|
||||
// This serves as a test of the predefined tables, and also of buildFSE
|
||||
// and the functions that make baseline FSE tables.
|
||||
func TestPredefinedTables(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
distribution []int16
|
||||
tableBits int
|
||||
toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error
|
||||
predef []fseBaselineEntry
|
||||
}{
|
||||
{
|
||||
name: "literal",
|
||||
distribution: literalPredefinedDistribution,
|
||||
tableBits: 6,
|
||||
toBaseline: (*Reader).makeLiteralBaselineFSE,
|
||||
predef: predefinedLiteralTable[:],
|
||||
},
|
||||
{
|
||||
name: "offset",
|
||||
distribution: offsetPredefinedDistribution,
|
||||
tableBits: 5,
|
||||
toBaseline: (*Reader).makeOffsetBaselineFSE,
|
||||
predef: predefinedOffsetTable[:],
|
||||
},
|
||||
{
|
||||
name: "match",
|
||||
distribution: matchPredefinedDistribution,
|
||||
tableBits: 6,
|
||||
toBaseline: (*Reader).makeMatchBaselineFSE,
|
||||
predef: predefinedMatchTable[:],
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var r Reader
|
||||
table := make([]fseEntry, 1<<test.tableBits)
|
||||
if err := r.buildFSE(0, test.distribution, table, test.tableBits); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
baselineTable := make([]fseBaselineEntry, len(table))
|
||||
if err := test.toBaseline(&r, 0, table, baselineTable); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(baselineTable, test.predef) {
|
||||
t.Errorf("got %v, want %v", baselineTable, test.predef)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
140
src/internal/zstd/fuzz_test.go
Normal file
140
src/internal/zstd/fuzz_test.go
Normal file
@ -0,0 +1,140 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// badStrings is some inputs that FuzzReader failed on earlier.
|
||||
var badStrings = []string{
|
||||
"(\xb5/\xfdd00,\x05\x00\xc4\x0400000000000000000000000000000000000000000000000000000000000000000000000000000 \xa07100000000000000000000000000000000000000000000000000000000000000000000000000aM\x8a2y0B\b",
|
||||
"(\xb5/\xfd00$\x05\x0020 00X70000a70000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
"(\xb5/\xfd00$\x05\x0020 00B00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
"(\xb5/\xfd00}\x00\x0020\x00\x9000000000000",
|
||||
"(\xb5/\xfd00}\x00\x00&0\x02\x830!000000000",
|
||||
"(\xb5/\xfd\x1002000$\x05\x0010\xcc0\xa8100000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
"(\xb5/\xfd\x1002000$\x05\x0000\xcc0\xa8100d\x0000001000000000000000000000000000000000000000000000000000000000000000000000000\x000000000000000000000000000000000000000000000000000000000000000000000000000000",
|
||||
"(\xb5/\xfd001\x00\x0000000000000000000",
|
||||
}
|
||||
|
||||
// This is a simple fuzzer to see if the decompressor panics.
|
||||
func FuzzReader(f *testing.F) {
|
||||
for _, test := range tests {
|
||||
f.Add([]byte(test.compressed))
|
||||
}
|
||||
for _, s := range badStrings {
|
||||
f.Add([]byte(s))
|
||||
}
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
r := NewReader(bytes.NewReader(b))
|
||||
io.Copy(io.Discard, r)
|
||||
})
|
||||
}
|
||||
|
||||
// Fuzz test to verify that what we decompress is what we compress.
|
||||
// This isn't a great fuzz test because the fuzzer can't efficiently
|
||||
// explore the space of decompressor behavior, since it can't see
|
||||
// what the compressor is doing. But it's better than nothing.
|
||||
func FuzzDecompressor(f *testing.F) {
|
||||
if _, err := os.Stat("/usr/bin/zstd"); err != nil {
|
||||
f.Skip("skipping because /usr/bin/zstd does not exist")
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
f.Add([]byte(test.uncompressed))
|
||||
}
|
||||
|
||||
// Add some larger data, as that has more interesting compression.
|
||||
f.Add(bytes.Repeat([]byte("abcdefghijklmnop"), 256))
|
||||
var buf bytes.Buffer
|
||||
for i := 0; i < 256; i++ {
|
||||
buf.WriteByte(byte(i))
|
||||
}
|
||||
f.Add(bytes.Repeat(buf.Bytes(), 64))
|
||||
f.Add(bigData(f))
|
||||
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
cmd := exec.Command("/usr/bin/zstd", "-z")
|
||||
cmd.Stdin = bytes.NewReader(b)
|
||||
var compressed bytes.Buffer
|
||||
cmd.Stdout = &compressed
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Errorf("running zstd failed: %v", err)
|
||||
}
|
||||
|
||||
r := NewReader(bytes.NewReader(compressed.Bytes()))
|
||||
got, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(got, b) {
|
||||
showDiffs(t, got, b)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Fuzz test to check that if we can decompress some data,
|
||||
// so can zstd, and that we get the same result.
|
||||
func FuzzReverse(f *testing.F) {
|
||||
if _, err := os.Stat("/usr/bin/zstd"); err != nil {
|
||||
f.Skip("skipping because /usr/bin/zstd does not exist")
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
f.Add([]byte(test.compressed))
|
||||
}
|
||||
|
||||
// Set a hook to reject some cases where we don't match zstd.
|
||||
fuzzing = true
|
||||
defer func() { fuzzing = false }()
|
||||
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
r := NewReader(bytes.NewReader(b))
|
||||
goExp, goErr := io.ReadAll(r)
|
||||
|
||||
cmd := exec.Command("/usr/bin/zstd", "-d")
|
||||
cmd.Stdin = bytes.NewReader(b)
|
||||
var uncompressed bytes.Buffer
|
||||
cmd.Stdout = &uncompressed
|
||||
cmd.Stderr = os.Stderr
|
||||
zstdErr := cmd.Run()
|
||||
zstdExp := uncompressed.Bytes()
|
||||
|
||||
if goErr == nil && zstdErr == nil {
|
||||
if !bytes.Equal(zstdExp, goExp) {
|
||||
showDiffs(t, zstdExp, goExp)
|
||||
}
|
||||
} else {
|
||||
// Ideally we should check that this package and
|
||||
// the zstd program both fail or both succeed,
|
||||
// and that if they both fail one byte sequence
|
||||
// is an exact prefix of the other.
|
||||
// Actually trying this proved to be frustrating,
|
||||
// as the zstd program appears to accept invalid
|
||||
// byte sequences using rules that are difficult
|
||||
// to determine.
|
||||
// So we just check the prefix.
|
||||
|
||||
c := len(goExp)
|
||||
if c > len(zstdExp) {
|
||||
c = len(zstdExp)
|
||||
}
|
||||
goExp = goExp[:c]
|
||||
zstdExp = zstdExp[:c]
|
||||
if !bytes.Equal(goExp, zstdExp) {
|
||||
t.Error("byte mismatch after error")
|
||||
t.Logf("Go error: %v\n", goErr)
|
||||
t.Logf("zstd error: %v\n", zstdErr)
|
||||
showDiffs(t, zstdExp, goExp)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
204
src/internal/zstd/huff.go
Normal file
204
src/internal/zstd/huff.go
Normal file
@ -0,0 +1,204 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"io"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
// maxHuffmanBits is the largest possible Huffman table bits.
|
||||
const maxHuffmanBits = 11
|
||||
|
||||
// readHuff reads Huffman table from data starting at off into table.
|
||||
// Each entry in a Huffman table is a pair of bytes.
|
||||
// The high byte is the encoded value. The low byte is the number
|
||||
// of bits used to encode that value. We index into the table
|
||||
// with a value of size tableBits. A value that requires fewer bits
|
||||
// appear in the table multiple times.
|
||||
// This returns the number of bits in the Huffman table and the new offset.
|
||||
// RFC 4.2.1.
|
||||
func (r *Reader) readHuff(data block, off int, table []uint16) (tableBits, roff int, err error) {
|
||||
if off >= len(data) {
|
||||
return 0, 0, r.makeEOFError(off)
|
||||
}
|
||||
|
||||
hdr := data[off]
|
||||
off++
|
||||
|
||||
var weights [256]uint8
|
||||
var count int
|
||||
if hdr < 128 {
|
||||
// The table is compressed using an FSE. RFC 4.2.1.2.
|
||||
if len(r.fseScratch) < 1<<6 {
|
||||
r.fseScratch = make([]fseEntry, 1<<6)
|
||||
}
|
||||
fseBits, noff, err := r.readFSE(data, off, 255, 6, r.fseScratch)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
fseTable := r.fseScratch
|
||||
|
||||
if off+int(hdr) > len(data) {
|
||||
return 0, 0, r.makeEOFError(off)
|
||||
}
|
||||
|
||||
rbr, err := r.makeReverseBitReader(data, off+int(hdr)-1, noff)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
state1, err := rbr.val(uint8(fseBits))
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
state2, err := rbr.val(uint8(fseBits))
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
// There are two independent FSE streams, tracked by
|
||||
// state1 and state2. We decode them alternately.
|
||||
|
||||
for {
|
||||
pt := &fseTable[state1]
|
||||
if !rbr.fetch(pt.bits) {
|
||||
if count >= 254 {
|
||||
return 0, 0, rbr.makeError("Huffman count overflow")
|
||||
}
|
||||
weights[count] = pt.sym
|
||||
weights[count+1] = fseTable[state2].sym
|
||||
count += 2
|
||||
break
|
||||
}
|
||||
|
||||
v, err := rbr.val(pt.bits)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
state1 = uint32(pt.base) + v
|
||||
|
||||
if count >= 255 {
|
||||
return 0, 0, rbr.makeError("Huffman count overflow")
|
||||
}
|
||||
|
||||
weights[count] = pt.sym
|
||||
count++
|
||||
|
||||
pt = &fseTable[state2]
|
||||
|
||||
if !rbr.fetch(pt.bits) {
|
||||
if count >= 254 {
|
||||
return 0, 0, rbr.makeError("Huffman count overflow")
|
||||
}
|
||||
weights[count] = pt.sym
|
||||
weights[count+1] = fseTable[state1].sym
|
||||
count += 2
|
||||
break
|
||||
}
|
||||
|
||||
v, err = rbr.val(pt.bits)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
state2 = uint32(pt.base) + v
|
||||
|
||||
if count >= 255 {
|
||||
return 0, 0, rbr.makeError("Huffman count overflow")
|
||||
}
|
||||
|
||||
weights[count] = pt.sym
|
||||
count++
|
||||
}
|
||||
|
||||
off += int(hdr)
|
||||
} else {
|
||||
// The table is not compressed. Each weight is 4 bits.
|
||||
|
||||
count = int(hdr) - 127
|
||||
if off+((count+1)/2) >= len(data) {
|
||||
return 0, 0, io.ErrUnexpectedEOF
|
||||
}
|
||||
for i := 0; i < count; i += 2 {
|
||||
b := data[off]
|
||||
off++
|
||||
weights[i] = b >> 4
|
||||
weights[i+1] = b & 0xf
|
||||
}
|
||||
}
|
||||
|
||||
// RFC 4.2.1.3.
|
||||
|
||||
var weightMark [13]uint32
|
||||
weightMask := uint32(0)
|
||||
for _, w := range weights[:count] {
|
||||
if w > 12 {
|
||||
return 0, 0, r.makeError(off, "Huffman weight overflow")
|
||||
}
|
||||
weightMark[w]++
|
||||
if w > 0 {
|
||||
weightMask += 1 << (w - 1)
|
||||
}
|
||||
}
|
||||
if weightMask == 0 {
|
||||
return 0, 0, r.makeError(off, "bad Huffman weights")
|
||||
}
|
||||
|
||||
tableBits = 32 - bits.LeadingZeros32(weightMask)
|
||||
if tableBits > maxHuffmanBits {
|
||||
return 0, 0, r.makeError(off, "bad Huffman weights")
|
||||
}
|
||||
|
||||
if len(table) < 1<<tableBits {
|
||||
return 0, 0, r.makeError(off, "Huffman table too small")
|
||||
}
|
||||
|
||||
// Work out the last weight value, which is omitted because
|
||||
// the weights must sum to a power of two.
|
||||
left := (uint32(1) << tableBits) - weightMask
|
||||
if left == 0 {
|
||||
return 0, 0, r.makeError(off, "bad Huffman weights")
|
||||
}
|
||||
highBit := 31 - bits.LeadingZeros32(left)
|
||||
if uint32(1)<<highBit != left {
|
||||
return 0, 0, r.makeError(off, "bad Huffman weights")
|
||||
}
|
||||
if count >= 256 {
|
||||
return 0, 0, r.makeError(off, "Huffman weight overflow")
|
||||
}
|
||||
weights[count] = uint8(highBit + 1)
|
||||
count++
|
||||
weightMark[highBit+1]++
|
||||
|
||||
if weightMark[1] < 2 || weightMark[1]&1 != 0 {
|
||||
return 0, 0, r.makeError(off, "bad Huffman weights")
|
||||
}
|
||||
|
||||
// Change weightMark from a count of weights to the index of
|
||||
// the first symbol for that weight. We shift the indexes to
|
||||
// also store how many we have seen so far,
|
||||
next := uint32(0)
|
||||
for i := 0; i < tableBits; i++ {
|
||||
cur := next
|
||||
next += weightMark[i+1] << i
|
||||
weightMark[i+1] = cur
|
||||
}
|
||||
|
||||
for i, w := range weights[:count] {
|
||||
if w == 0 {
|
||||
continue
|
||||
}
|
||||
length := uint32(1) << (w - 1)
|
||||
tval := uint16(i)<<8 | (uint16(tableBits) + 1 - uint16(w))
|
||||
start := weightMark[w]
|
||||
for j := uint32(0); j < length; j++ {
|
||||
table[start+j] = tval
|
||||
}
|
||||
weightMark[w] += length
|
||||
}
|
||||
|
||||
return tableBits, off, nil
|
||||
}
|
330
src/internal/zstd/literals.go
Normal file
330
src/internal/zstd/literals.go
Normal file
@ -0,0 +1,330 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
// readLiterals reads and decompresses the literals from data at off.
|
||||
// The literals are appended to outbuf, which is returned.
|
||||
// Also returns the new input offset. RFC 3.1.1.3.1.
|
||||
func (r *Reader) readLiterals(data block, off int, outbuf []byte) (int, []byte, error) {
|
||||
if off >= len(data) {
|
||||
return 0, nil, r.makeEOFError(off)
|
||||
}
|
||||
|
||||
// Literals section header. RFC 3.1.1.3.1.1.
|
||||
hdr := data[off]
|
||||
off++
|
||||
|
||||
if (hdr&3) == 0 || (hdr&3) == 1 {
|
||||
return r.readRawRLELiterals(data, off, hdr, outbuf)
|
||||
} else {
|
||||
return r.readHuffLiterals(data, off, hdr, outbuf)
|
||||
}
|
||||
}
|
||||
|
||||
// readRawRLELiterals reads and decompresses a Raw_Literals_Block or
|
||||
// a RLE_Literals_Block. RFC 3.1.1.3.1.1.
|
||||
func (r *Reader) readRawRLELiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) {
|
||||
raw := (hdr & 3) == 0
|
||||
|
||||
var regeneratedSize int
|
||||
switch (hdr >> 2) & 3 {
|
||||
case 0, 2:
|
||||
regeneratedSize = int(hdr >> 3)
|
||||
case 1:
|
||||
if off >= len(data) {
|
||||
return 0, nil, r.makeEOFError(off)
|
||||
}
|
||||
regeneratedSize = int(hdr>>4) + (int(data[off]) << 4)
|
||||
off++
|
||||
case 3:
|
||||
if off+1 >= len(data) {
|
||||
return 0, nil, r.makeEOFError(off)
|
||||
}
|
||||
regeneratedSize = int(hdr>>4) + (int(data[off]) << 4) + (int(data[off+1]) << 12)
|
||||
off += 2
|
||||
}
|
||||
|
||||
// We are going to use the entire literal block in the output.
|
||||
// The maximum size of one decompressed block is 128K,
|
||||
// so we can't have more literals than that.
|
||||
if regeneratedSize > 128<<10 {
|
||||
return 0, nil, r.makeError(off, "literal size too large")
|
||||
}
|
||||
|
||||
if raw {
|
||||
// RFC 3.1.1.3.1.2.
|
||||
if off+regeneratedSize > len(data) {
|
||||
return 0, nil, r.makeError(off, "raw literal size too large")
|
||||
}
|
||||
outbuf = append(outbuf, data[off:off+regeneratedSize]...)
|
||||
off += regeneratedSize
|
||||
} else {
|
||||
// RFC 3.1.1.3.1.3.
|
||||
if off >= len(data) {
|
||||
return 0, nil, r.makeError(off, "RLE literal missing")
|
||||
}
|
||||
rle := data[off]
|
||||
off++
|
||||
for i := 0; i < regeneratedSize; i++ {
|
||||
outbuf = append(outbuf, rle)
|
||||
}
|
||||
}
|
||||
|
||||
return off, outbuf, nil
|
||||
}
|
||||
|
||||
// readHuffLiterals reads and decompresses a Compressed_Literals_Block or
|
||||
// a Treeless_Literals_Block. RFC 3.1.1.3.1.4.
|
||||
func (r *Reader) readHuffLiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) {
|
||||
var (
|
||||
regeneratedSize int
|
||||
compressedSize int
|
||||
streams int
|
||||
)
|
||||
switch (hdr >> 2) & 3 {
|
||||
case 0, 1:
|
||||
if off+1 >= len(data) {
|
||||
return 0, nil, r.makeEOFError(off)
|
||||
}
|
||||
regeneratedSize = (int(hdr) >> 4) | ((int(data[off]) & 0x3f) << 4)
|
||||
compressedSize = (int(data[off]) >> 6) | (int(data[off+1]) << 2)
|
||||
off += 2
|
||||
if ((hdr >> 2) & 3) == 0 {
|
||||
streams = 1
|
||||
} else {
|
||||
streams = 4
|
||||
}
|
||||
case 2:
|
||||
if off+2 >= len(data) {
|
||||
return 0, nil, r.makeEOFError(off)
|
||||
}
|
||||
regeneratedSize = (int(hdr) >> 4) | (int(data[off]) << 4) | ((int(data[off+1]) & 3) << 12)
|
||||
compressedSize = (int(data[off+1]) >> 2) | (int(data[off+2]) << 6)
|
||||
off += 3
|
||||
streams = 4
|
||||
case 3:
|
||||
if off+3 >= len(data) {
|
||||
return 0, nil, r.makeEOFError(off)
|
||||
}
|
||||
regeneratedSize = (int(hdr) >> 4) | (int(data[off]) << 4) | ((int(data[off+1]) & 0x3f) << 12)
|
||||
compressedSize = (int(data[off+1]) >> 6) | (int(data[off+2]) << 2) | (int(data[off+3]) << 10)
|
||||
off += 4
|
||||
streams = 4
|
||||
}
|
||||
|
||||
// We are going to use the entire literal block in the output.
|
||||
// The maximum size of one decompressed block is 128K,
|
||||
// so we can't have more literals than that.
|
||||
if regeneratedSize > 128<<10 {
|
||||
return 0, nil, r.makeError(off, "literal size too large")
|
||||
}
|
||||
|
||||
roff := off + compressedSize
|
||||
if roff > len(data) || roff < 0 {
|
||||
return 0, nil, r.makeEOFError(off)
|
||||
}
|
||||
|
||||
totalStreamsSize := compressedSize
|
||||
if (hdr & 3) == 2 {
|
||||
// Compressed_Literals_Block.
|
||||
// Read new huffman tree.
|
||||
|
||||
if len(r.huffmanTable) < 1<<maxHuffmanBits {
|
||||
r.huffmanTable = make([]uint16, 1<<maxHuffmanBits)
|
||||
}
|
||||
|
||||
huffmanTableBits, hoff, err := r.readHuff(data, off, r.huffmanTable)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
r.huffmanTableBits = huffmanTableBits
|
||||
|
||||
if totalStreamsSize < hoff-off {
|
||||
return 0, nil, r.makeError(off, "Huffman table too big")
|
||||
}
|
||||
totalStreamsSize -= hoff - off
|
||||
off = hoff
|
||||
} else {
|
||||
// Treeless_Literals_Block
|
||||
// Reuse previous Huffman tree.
|
||||
if r.huffmanTableBits == 0 {
|
||||
return 0, nil, r.makeError(off, "missing literals Huffman tree")
|
||||
}
|
||||
}
|
||||
|
||||
// Decompress compressedSize bytes of data at off using the
|
||||
// Huffman tree.
|
||||
|
||||
var err error
|
||||
if streams == 1 {
|
||||
outbuf, err = r.readLiteralsOneStream(data, off, totalStreamsSize, regeneratedSize, outbuf)
|
||||
} else {
|
||||
outbuf, err = r.readLiteralsFourStreams(data, off, totalStreamsSize, regeneratedSize, outbuf)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
|
||||
return roff, outbuf, nil
|
||||
}
|
||||
|
||||
// readLiteralsOneStream reads a single stream of compressed literals.
|
||||
func (r *Reader) readLiteralsOneStream(data block, off, compressedSize, regeneratedSize int, outbuf []byte) ([]byte, error) {
|
||||
// We let the reverse bit reader read earlier bytes,
|
||||
// because the Huffman table ignores bits that it doesn't need.
|
||||
rbr, err := r.makeReverseBitReader(data, off+compressedSize-1, off-2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
huffTable := r.huffmanTable
|
||||
huffBits := uint32(r.huffmanTableBits)
|
||||
huffMask := (uint32(1) << huffBits) - 1
|
||||
|
||||
for i := 0; i < regeneratedSize; i++ {
|
||||
if !rbr.fetch(uint8(huffBits)) {
|
||||
return nil, rbr.makeError("literals Huffman stream out of bits")
|
||||
}
|
||||
|
||||
var t uint16
|
||||
idx := (rbr.bits >> (rbr.cnt - huffBits)) & huffMask
|
||||
t = huffTable[idx]
|
||||
outbuf = append(outbuf, byte(t>>8))
|
||||
rbr.cnt -= uint32(t & 0xff)
|
||||
}
|
||||
|
||||
return outbuf, nil
|
||||
}
|
||||
|
||||
// readLiteralsFourStreams reads four interleaved streams of
|
||||
// compressed literals.
|
||||
func (r *Reader) readLiteralsFourStreams(data block, off, totalStreamsSize, regeneratedSize int, outbuf []byte) ([]byte, error) {
|
||||
// Read the jump table to find out where the streams are.
|
||||
// RFC 3.1.1.3.1.6.
|
||||
if off+5 >= len(data) {
|
||||
return nil, r.makeEOFError(off)
|
||||
}
|
||||
if totalStreamsSize < 6 {
|
||||
return nil, r.makeError(off, "total streams size too small for jump table")
|
||||
}
|
||||
|
||||
streamSize1 := binary.LittleEndian.Uint16(data[off:])
|
||||
streamSize2 := binary.LittleEndian.Uint16(data[off+2:])
|
||||
streamSize3 := binary.LittleEndian.Uint16(data[off+4:])
|
||||
off += 6
|
||||
|
||||
tot := uint64(streamSize1) + uint64(streamSize2) + uint64(streamSize3)
|
||||
if tot > uint64(totalStreamsSize)-6 {
|
||||
return nil, r.makeEOFError(off)
|
||||
}
|
||||
streamSize4 := uint32(totalStreamsSize) - 6 - uint32(tot)
|
||||
|
||||
off--
|
||||
off1 := off + int(streamSize1)
|
||||
start1 := off + 1
|
||||
|
||||
off2 := off1 + int(streamSize2)
|
||||
start2 := off1 + 1
|
||||
|
||||
off3 := off2 + int(streamSize3)
|
||||
start3 := off2 + 1
|
||||
|
||||
off4 := off3 + int(streamSize4)
|
||||
start4 := off3 + 1
|
||||
|
||||
// We let the reverse bit readers read earlier bytes,
|
||||
// because the Huffman tables ignore bits that they don't need.
|
||||
|
||||
rbr1, err := r.makeReverseBitReader(data, off1, start1-2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rbr2, err := r.makeReverseBitReader(data, off2, start2-2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rbr3, err := r.makeReverseBitReader(data, off3, start3-2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rbr4, err := r.makeReverseBitReader(data, off4, start4-2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
regeneratedStreamSize := (regeneratedSize + 3) / 4
|
||||
|
||||
out1 := len(outbuf)
|
||||
out2 := out1 + regeneratedStreamSize
|
||||
out3 := out2 + regeneratedStreamSize
|
||||
out4 := out3 + regeneratedStreamSize
|
||||
|
||||
regeneratedStreamSize4 := regeneratedSize - regeneratedStreamSize*3
|
||||
|
||||
outbuf = append(outbuf, make([]byte, regeneratedSize)...)
|
||||
|
||||
huffTable := r.huffmanTable
|
||||
huffBits := uint32(r.huffmanTableBits)
|
||||
huffMask := (uint32(1) << huffBits) - 1
|
||||
|
||||
for i := 0; i < regeneratedStreamSize; i++ {
|
||||
use4 := i < regeneratedStreamSize4
|
||||
|
||||
fetchHuff := func(rbr *reverseBitReader) (uint16, error) {
|
||||
if !rbr.fetch(uint8(huffBits)) {
|
||||
return 0, rbr.makeError("literals Huffman stream out of bits")
|
||||
}
|
||||
idx := (rbr.bits >> (rbr.cnt - huffBits)) & huffMask
|
||||
return huffTable[idx], nil
|
||||
}
|
||||
|
||||
t1, err := fetchHuff(&rbr1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t2, err := fetchHuff(&rbr2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t3, err := fetchHuff(&rbr3)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if use4 {
|
||||
t4, err := fetchHuff(&rbr4)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outbuf[out4] = byte(t4 >> 8)
|
||||
out4++
|
||||
rbr4.cnt -= uint32(t4 & 0xff)
|
||||
}
|
||||
|
||||
outbuf[out1] = byte(t1 >> 8)
|
||||
out1++
|
||||
rbr1.cnt -= uint32(t1 & 0xff)
|
||||
|
||||
outbuf[out2] = byte(t2 >> 8)
|
||||
out2++
|
||||
rbr2.cnt -= uint32(t2 & 0xff)
|
||||
|
||||
outbuf[out3] = byte(t3 >> 8)
|
||||
out3++
|
||||
rbr3.cnt -= uint32(t3 & 0xff)
|
||||
}
|
||||
|
||||
return outbuf, nil
|
||||
}
|
148
src/internal/zstd/xxhash.go
Normal file
148
src/internal/zstd/xxhash.go
Normal file
@ -0,0 +1,148 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"math/bits"
|
||||
)
|
||||
|
||||
const (
|
||||
xxhPrime64c1 = 0x9e3779b185ebca87
|
||||
xxhPrime64c2 = 0xc2b2ae3d27d4eb4f
|
||||
xxhPrime64c3 = 0x165667b19e3779f9
|
||||
xxhPrime64c4 = 0x85ebca77c2b2ae63
|
||||
xxhPrime64c5 = 0x27d4eb2f165667c5
|
||||
)
|
||||
|
||||
// xxhash64 is the state of a xxHash-64 checksum.
|
||||
type xxhash64 struct {
|
||||
len uint64 // total length hashed
|
||||
v [4]uint64 // accumulators
|
||||
buf [32]byte // buffer
|
||||
cnt int // number of bytes in buffer
|
||||
}
|
||||
|
||||
// reset discards the current state and prepares to compute a new hash.
|
||||
// We assume a seed of 0 since that is what zstd uses.
|
||||
func (xh *xxhash64) reset() {
|
||||
xh.len = 0
|
||||
|
||||
// Separate addition for awkward constant overflow.
|
||||
xh.v[0] = xxhPrime64c1
|
||||
xh.v[0] += xxhPrime64c2
|
||||
|
||||
xh.v[1] = xxhPrime64c2
|
||||
xh.v[2] = 0
|
||||
|
||||
// Separate negation for awkward constant overflow.
|
||||
xh.v[3] = xxhPrime64c1
|
||||
xh.v[3] = -xh.v[3]
|
||||
|
||||
for i := range xh.buf {
|
||||
xh.buf[i] = 0
|
||||
}
|
||||
xh.cnt = 0
|
||||
}
|
||||
|
||||
// update adds a buffer to the has.
|
||||
func (xh *xxhash64) update(b []byte) {
|
||||
xh.len += uint64(len(b))
|
||||
|
||||
if xh.cnt+len(b) < len(xh.buf) {
|
||||
copy(xh.buf[xh.cnt:], b)
|
||||
xh.cnt += len(b)
|
||||
return
|
||||
}
|
||||
|
||||
if xh.cnt > 0 {
|
||||
n := copy(xh.buf[xh.cnt:], b)
|
||||
b = b[n:]
|
||||
xh.v[0] = xh.round(xh.v[0], binary.LittleEndian.Uint64(xh.buf[:]))
|
||||
xh.v[1] = xh.round(xh.v[1], binary.LittleEndian.Uint64(xh.buf[8:]))
|
||||
xh.v[2] = xh.round(xh.v[2], binary.LittleEndian.Uint64(xh.buf[16:]))
|
||||
xh.v[3] = xh.round(xh.v[3], binary.LittleEndian.Uint64(xh.buf[24:]))
|
||||
xh.cnt = 0
|
||||
}
|
||||
|
||||
for len(b) >= 32 {
|
||||
xh.v[0] = xh.round(xh.v[0], binary.LittleEndian.Uint64(b))
|
||||
xh.v[1] = xh.round(xh.v[1], binary.LittleEndian.Uint64(b[8:]))
|
||||
xh.v[2] = xh.round(xh.v[2], binary.LittleEndian.Uint64(b[16:]))
|
||||
xh.v[3] = xh.round(xh.v[3], binary.LittleEndian.Uint64(b[24:]))
|
||||
b = b[32:]
|
||||
}
|
||||
|
||||
if len(b) > 0 {
|
||||
copy(xh.buf[:], b)
|
||||
xh.cnt = len(b)
|
||||
}
|
||||
}
|
||||
|
||||
// digest returns the final hash value.
|
||||
func (xh *xxhash64) digest() uint64 {
|
||||
var h64 uint64
|
||||
if xh.len < 32 {
|
||||
h64 = xh.v[2] + xxhPrime64c5
|
||||
} else {
|
||||
h64 = bits.RotateLeft64(xh.v[0], 1) +
|
||||
bits.RotateLeft64(xh.v[1], 7) +
|
||||
bits.RotateLeft64(xh.v[2], 12) +
|
||||
bits.RotateLeft64(xh.v[3], 18)
|
||||
h64 = xh.mergeRound(h64, xh.v[0])
|
||||
h64 = xh.mergeRound(h64, xh.v[1])
|
||||
h64 = xh.mergeRound(h64, xh.v[2])
|
||||
h64 = xh.mergeRound(h64, xh.v[3])
|
||||
}
|
||||
|
||||
h64 += xh.len
|
||||
|
||||
len := xh.len
|
||||
len &= 31
|
||||
buf := xh.buf[:]
|
||||
for len >= 8 {
|
||||
k1 := xh.round(0, binary.LittleEndian.Uint64(buf))
|
||||
buf = buf[8:]
|
||||
h64 ^= k1
|
||||
h64 = bits.RotateLeft64(h64, 27)*xxhPrime64c1 + xxhPrime64c4
|
||||
len -= 8
|
||||
}
|
||||
if len >= 4 {
|
||||
h64 ^= uint64(binary.LittleEndian.Uint32(buf)) * xxhPrime64c1
|
||||
buf = buf[4:]
|
||||
h64 = bits.RotateLeft64(h64, 23)*xxhPrime64c2 + xxhPrime64c3
|
||||
len -= 4
|
||||
}
|
||||
for len > 0 {
|
||||
h64 ^= uint64(buf[0]) * xxhPrime64c5
|
||||
buf = buf[1:]
|
||||
h64 = bits.RotateLeft64(h64, 11) * xxhPrime64c1
|
||||
len--
|
||||
}
|
||||
|
||||
h64 ^= h64 >> 33
|
||||
h64 *= xxhPrime64c2
|
||||
h64 ^= h64 >> 29
|
||||
h64 *= xxhPrime64c3
|
||||
h64 ^= h64 >> 32
|
||||
|
||||
return h64
|
||||
}
|
||||
|
||||
// round updates a value.
|
||||
func (xh *xxhash64) round(v, n uint64) uint64 {
|
||||
v += n * xxhPrime64c2
|
||||
v = bits.RotateLeft64(v, 31)
|
||||
v *= xxhPrime64c1
|
||||
return v
|
||||
}
|
||||
|
||||
// mergeRound updates a value in the final round.
|
||||
func (xh *xxhash64) mergeRound(v, n uint64) uint64 {
|
||||
n = xh.round(0, n)
|
||||
v ^= n
|
||||
v = v*xxhPrime64c1 + xxhPrime64c4
|
||||
return v
|
||||
}
|
105
src/internal/zstd/xxhash_test.go
Normal file
105
src/internal/zstd/xxhash_test.go
Normal file
@ -0,0 +1,105 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var xxHashTests = []struct {
|
||||
data string
|
||||
hash uint64
|
||||
}{
|
||||
{
|
||||
"hello, world",
|
||||
0xb33a384e6d1b1242,
|
||||
},
|
||||
{
|
||||
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789$",
|
||||
0x1032d841e824f998,
|
||||
},
|
||||
}
|
||||
|
||||
func TestXXHash(t *testing.T) {
|
||||
var xh xxhash64
|
||||
for i, test := range xxHashTests {
|
||||
xh.reset()
|
||||
xh.update([]byte(test.data))
|
||||
if got := xh.digest(); got != test.hash {
|
||||
t.Errorf("#%d: got %#x want %#x", i, got, test.hash)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLargeXXHash(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping expensive test in short mode")
|
||||
}
|
||||
|
||||
data := bigData(t)
|
||||
var xh xxhash64
|
||||
xh.reset()
|
||||
i := 0
|
||||
for i < len(data) {
|
||||
// Write varying amounts to test buffering.
|
||||
c := i%4094 + 1
|
||||
if i+c > len(data) {
|
||||
c = len(data) - i
|
||||
}
|
||||
xh.update(data[i : i+c])
|
||||
i += c
|
||||
}
|
||||
|
||||
got := xh.digest()
|
||||
want := uint64(0xf0dd39fd7e063f82)
|
||||
if got != want {
|
||||
t.Errorf("got %#x want %#x", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzXXHash(f *testing.F) {
|
||||
if _, err := os.Stat("/usr/bin/xxhsum"); err != nil {
|
||||
f.Skip("skipping because /usr/bin/xxhsum does not exist")
|
||||
}
|
||||
|
||||
for _, test := range xxHashTests {
|
||||
f.Add([]byte(test.data))
|
||||
}
|
||||
f.Add(bytes.Repeat([]byte("abcdefghijklmnop"), 256))
|
||||
var buf bytes.Buffer
|
||||
for i := 0; i < 256; i++ {
|
||||
buf.WriteByte(byte(i))
|
||||
}
|
||||
f.Add(bytes.Repeat(buf.Bytes(), 64))
|
||||
f.Add(bigData(f))
|
||||
|
||||
f.Fuzz(func(t *testing.T, b []byte) {
|
||||
cmd := exec.Command("/usr/bin/xxhsum", "-H64")
|
||||
cmd.Stdin = bytes.NewReader(b)
|
||||
var hhsumHash bytes.Buffer
|
||||
cmd.Stdout = &hhsumHash
|
||||
if err := cmd.Run(); err != nil {
|
||||
t.Fatalf("running hhsum failed: %v", err)
|
||||
}
|
||||
hhHashBytes := bytes.Fields(bytes.TrimSpace(hhsumHash.Bytes()))[0]
|
||||
hhHash, err := strconv.ParseUint(string(hhHashBytes), 16, 64)
|
||||
if err != nil {
|
||||
t.Fatalf("could not parse hash %q: %v", hhHashBytes, err)
|
||||
}
|
||||
|
||||
var xh xxhash64
|
||||
xh.reset()
|
||||
xh.update(b)
|
||||
goHash := xh.digest()
|
||||
|
||||
if goHash != hhHash {
|
||||
t.Errorf("Go hash %#x != xxhsum hash %#x", goHash, hhHash)
|
||||
}
|
||||
})
|
||||
}
|
508
src/internal/zstd/zstd.go
Normal file
508
src/internal/zstd/zstd.go
Normal file
@ -0,0 +1,508 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package zstd provides a decompressor for zstd streams,
|
||||
// described in RFC 8878. It does not support dictionaries.
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
// fuzzing is a fuzzer hook set to true when fuzzing.
|
||||
// This is used to reject cases where we don't match zstd.
|
||||
var fuzzing = false
|
||||
|
||||
// Reader implements [io.Reader] to read a zstd compressed stream.
|
||||
type Reader struct {
|
||||
// The underlying Reader.
|
||||
r io.Reader
|
||||
|
||||
// Whether we have read the frame header.
|
||||
// This is of interest when buffer is empty.
|
||||
// If true we expect to see a new block.
|
||||
sawFrameHeader bool
|
||||
|
||||
// Whether the current frame expects a checksum.
|
||||
hasChecksum bool
|
||||
|
||||
// Whether we have read at least one frame.
|
||||
readOneFrame bool
|
||||
|
||||
// True if the frame size is not known.
|
||||
frameSizeUnknown bool
|
||||
|
||||
// The number of uncompressed bytes remaining in the current frame.
|
||||
// If frameSizeUnknown is true, this is not valid.
|
||||
remainingFrameSize uint64
|
||||
|
||||
// The number of bytes read from r up to the start of the current
|
||||
// block, for error reporting.
|
||||
blockOffset int64
|
||||
|
||||
// Buffered decompressed data.
|
||||
buffer []byte
|
||||
// Current read offset in buffer.
|
||||
off int
|
||||
|
||||
// The current repeated offsets.
|
||||
repeatedOffset1 uint32
|
||||
repeatedOffset2 uint32
|
||||
repeatedOffset3 uint32
|
||||
|
||||
// The current Huffman tree used for compressing literals.
|
||||
huffmanTable []uint16
|
||||
huffmanTableBits int
|
||||
|
||||
// The window for back references.
|
||||
windowSize int // maximum required window size
|
||||
window []byte // window data
|
||||
|
||||
// A buffer available to hold a compressed block.
|
||||
compressedBuf []byte
|
||||
|
||||
// A buffer for literals.
|
||||
literals []byte
|
||||
|
||||
// Sequence decode FSE tables.
|
||||
seqTables [3][]fseBaselineEntry
|
||||
seqTableBits [3]uint8
|
||||
|
||||
// Buffers for sequence decode FSE tables.
|
||||
seqTableBuffers [3][]fseBaselineEntry
|
||||
|
||||
// Scratch space used for small reads, to avoid allocation.
|
||||
scratch [16]byte
|
||||
|
||||
// A scratch table for reading an FSE. Only temporarily valid.
|
||||
fseScratch []fseEntry
|
||||
|
||||
// For checksum computation.
|
||||
checksum xxhash64
|
||||
}
|
||||
|
||||
// NewReader creates a new Reader that decompresses data from the given reader.
|
||||
func NewReader(input io.Reader) *Reader {
|
||||
r := new(Reader)
|
||||
r.Reset(input)
|
||||
return r
|
||||
}
|
||||
|
||||
// Reset discards the current state and starts reading a new stream from r.
|
||||
// This permits reusing a Reader rather than allocating a new one.
|
||||
func (r *Reader) Reset(input io.Reader) {
|
||||
r.r = input
|
||||
|
||||
// Several fields are preserved to avoid allocation.
|
||||
// Others are always set before they are used.
|
||||
r.sawFrameHeader = false
|
||||
r.hasChecksum = false
|
||||
r.readOneFrame = false
|
||||
r.frameSizeUnknown = false
|
||||
r.remainingFrameSize = 0
|
||||
r.blockOffset = 0
|
||||
// buffer
|
||||
r.off = 0
|
||||
// repeatedOffset1
|
||||
// repeatedOffset2
|
||||
// repeatedOffset3
|
||||
// huffmanTable
|
||||
// huffmanTableBits
|
||||
// windowSize
|
||||
// window
|
||||
// compressedBuf
|
||||
// literals
|
||||
// seqTables
|
||||
// seqTableBits
|
||||
// seqTableBuffers
|
||||
// scratch
|
||||
// fseScratch
|
||||
}
|
||||
|
||||
// Read implements [io.Reader].
|
||||
func (r *Reader) Read(p []byte) (int, error) {
|
||||
if err := r.refillIfNeeded(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := copy(p, r.buffer[r.off:])
|
||||
r.off += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// ReadByte implements [io.ByteReader].
|
||||
func (r *Reader) ReadByte() (byte, error) {
|
||||
if err := r.refillIfNeeded(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
ret := r.buffer[r.off]
|
||||
r.off++
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// refillIfNeeded reads the next block if necessary.
|
||||
func (r *Reader) refillIfNeeded() error {
|
||||
for r.off >= len(r.buffer) {
|
||||
if err := r.refill(); err != nil {
|
||||
return err
|
||||
}
|
||||
r.off = 0
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// refill reads and decompresses the next block.
|
||||
func (r *Reader) refill() error {
|
||||
if !r.sawFrameHeader {
|
||||
if err := r.readFrameHeader(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return r.readBlock()
|
||||
}
|
||||
|
||||
// readFrameHeader reads the frame header and prepares to read a block.
|
||||
func (r *Reader) readFrameHeader() error {
|
||||
retry:
|
||||
relativeOffset := 0
|
||||
|
||||
// Read magic number. RFC 3.1.1.
|
||||
if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
|
||||
// We require that the stream contain at least one frame.
|
||||
if err == io.EOF && !r.readOneFrame {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return r.wrapError(relativeOffset, err)
|
||||
}
|
||||
|
||||
if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
|
||||
if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
|
||||
// This is a skippable frame.
|
||||
r.blockOffset += int64(relativeOffset) + 4
|
||||
if err := r.skipFrame(); err != nil {
|
||||
return err
|
||||
}
|
||||
goto retry
|
||||
}
|
||||
|
||||
return r.makeError(relativeOffset, "invalid magic number")
|
||||
}
|
||||
|
||||
relativeOffset += 4
|
||||
|
||||
// Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
|
||||
if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
|
||||
return r.wrapNonEOFError(relativeOffset, err)
|
||||
}
|
||||
descriptor := r.scratch[0]
|
||||
|
||||
singleSegment := descriptor&(1<<5) != 0
|
||||
|
||||
fcsFieldSize := 1 << (descriptor >> 6)
|
||||
if fcsFieldSize == 1 && !singleSegment {
|
||||
fcsFieldSize = 0
|
||||
}
|
||||
|
||||
var windowDescriptorSize int
|
||||
if singleSegment {
|
||||
windowDescriptorSize = 0
|
||||
} else {
|
||||
windowDescriptorSize = 1
|
||||
}
|
||||
|
||||
if descriptor&(1<<3) != 0 {
|
||||
return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
|
||||
}
|
||||
|
||||
r.hasChecksum = descriptor&(1<<2) != 0
|
||||
if r.hasChecksum {
|
||||
r.checksum.reset()
|
||||
}
|
||||
|
||||
if descriptor&3 != 0 {
|
||||
return r.makeError(relativeOffset, "dictionaries are not supported")
|
||||
}
|
||||
|
||||
relativeOffset++
|
||||
|
||||
headerSize := windowDescriptorSize + fcsFieldSize
|
||||
|
||||
if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
|
||||
return r.wrapNonEOFError(relativeOffset, err)
|
||||
}
|
||||
|
||||
// Figure out the maximum amount of data we need to retain
|
||||
// for backreferences.
|
||||
|
||||
if singleSegment {
|
||||
// No window required, as all the data is in a single buffer.
|
||||
r.windowSize = 0
|
||||
} else {
|
||||
// Window descriptor. RFC 3.1.1.1.2.
|
||||
windowDescriptor := r.scratch[0]
|
||||
exponent := uint64(windowDescriptor >> 3)
|
||||
mantissa := uint64(windowDescriptor & 7)
|
||||
windowLog := exponent + 10
|
||||
windowBase := uint64(1) << windowLog
|
||||
windowAdd := (windowBase / 8) * mantissa
|
||||
windowSize := windowBase + windowAdd
|
||||
|
||||
// Default zstd sets limits on the window size.
|
||||
if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
|
||||
return r.makeError(relativeOffset, "windowSize too large")
|
||||
}
|
||||
|
||||
// RFC 8878 permits us to set an 8M max on window size.
|
||||
if windowSize > 8<<20 {
|
||||
windowSize = 8 << 20
|
||||
}
|
||||
|
||||
r.windowSize = int(windowSize)
|
||||
}
|
||||
|
||||
// Frame_Content_Size. RFC 3.1.1.4.
|
||||
r.frameSizeUnknown = false
|
||||
r.remainingFrameSize = 0
|
||||
fb := r.scratch[windowDescriptorSize:]
|
||||
switch fcsFieldSize {
|
||||
case 0:
|
||||
r.frameSizeUnknown = true
|
||||
case 1:
|
||||
r.remainingFrameSize = uint64(fb[0])
|
||||
case 2:
|
||||
r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
|
||||
case 4:
|
||||
r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
|
||||
case 8:
|
||||
r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
|
||||
relativeOffset += headerSize
|
||||
|
||||
r.sawFrameHeader = true
|
||||
r.readOneFrame = true
|
||||
r.blockOffset += int64(relativeOffset)
|
||||
|
||||
// Prepare to read blocks from the frame.
|
||||
r.repeatedOffset1 = 1
|
||||
r.repeatedOffset2 = 4
|
||||
r.repeatedOffset3 = 8
|
||||
r.huffmanTableBits = 0
|
||||
r.window = r.window[:0]
|
||||
r.seqTables[0] = nil
|
||||
r.seqTables[1] = nil
|
||||
r.seqTables[2] = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// skipFrame skips a skippable frame. RFC 3.1.2.
|
||||
func (r *Reader) skipFrame() error {
|
||||
relativeOffset := 0
|
||||
|
||||
if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
|
||||
return r.wrapNonEOFError(relativeOffset, err)
|
||||
}
|
||||
|
||||
relativeOffset += 4
|
||||
|
||||
size := binary.LittleEndian.Uint32(r.scratch[:4])
|
||||
|
||||
if seeker, ok := r.r.(io.Seeker); ok {
|
||||
if _, err := seeker.Seek(int64(size), io.SeekCurrent); err != nil {
|
||||
return err
|
||||
}
|
||||
r.blockOffset += int64(relativeOffset) + int64(size)
|
||||
return nil
|
||||
}
|
||||
|
||||
var skip []byte
|
||||
const chunk = 1 << 20 // 1M
|
||||
for size >= chunk {
|
||||
if len(skip) == 0 {
|
||||
skip = make([]byte, chunk)
|
||||
}
|
||||
if _, err := io.ReadFull(r.r, skip); err != nil {
|
||||
return r.wrapNonEOFError(relativeOffset, err)
|
||||
}
|
||||
relativeOffset += chunk
|
||||
size -= chunk
|
||||
}
|
||||
if size > 0 {
|
||||
if len(skip) == 0 {
|
||||
skip = make([]byte, size)
|
||||
}
|
||||
if _, err := io.ReadFull(r.r, skip); err != nil {
|
||||
return r.wrapNonEOFError(relativeOffset, err)
|
||||
}
|
||||
relativeOffset += int(size)
|
||||
}
|
||||
|
||||
r.blockOffset += int64(relativeOffset)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// readBlock reads the next block from a frame.
|
||||
func (r *Reader) readBlock() error {
|
||||
relativeOffset := 0
|
||||
|
||||
// Read Block_Header. RFC 3.1.1.2.
|
||||
if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
|
||||
return r.wrapNonEOFError(relativeOffset, err)
|
||||
}
|
||||
|
||||
relativeOffset += 3
|
||||
|
||||
header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)
|
||||
|
||||
lastBlock := header&1 != 0
|
||||
blockType := (header >> 1) & 3
|
||||
blockSize := int(header >> 3)
|
||||
|
||||
// Maximum block size is smaller of window size and 128K.
|
||||
// We don't record the window size for a single segment frame,
|
||||
// so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4.
|
||||
if blockSize > 128<<10 || (r.windowSize > 0 && blockSize > r.windowSize) {
|
||||
return r.makeError(relativeOffset, "block size too large")
|
||||
}
|
||||
|
||||
// Handle different block types. RFC 3.1.1.2.2.
|
||||
switch blockType {
|
||||
case 0:
|
||||
r.setBufferSize(blockSize)
|
||||
if _, err := io.ReadFull(r.r, r.buffer); err != nil {
|
||||
return r.wrapNonEOFError(relativeOffset, err)
|
||||
}
|
||||
relativeOffset += blockSize
|
||||
r.blockOffset += int64(relativeOffset)
|
||||
case 1:
|
||||
r.setBufferSize(blockSize)
|
||||
if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
|
||||
return r.wrapNonEOFError(relativeOffset, err)
|
||||
}
|
||||
relativeOffset++
|
||||
v := r.scratch[0]
|
||||
for i := range r.buffer {
|
||||
r.buffer[i] = v
|
||||
}
|
||||
r.blockOffset += int64(relativeOffset)
|
||||
case 2:
|
||||
r.blockOffset += int64(relativeOffset)
|
||||
if err := r.compressedBlock(blockSize); err != nil {
|
||||
return err
|
||||
}
|
||||
r.blockOffset += int64(blockSize)
|
||||
case 3:
|
||||
return r.makeError(relativeOffset, "invalid block type")
|
||||
}
|
||||
|
||||
if !r.frameSizeUnknown {
|
||||
if uint64(len(r.buffer)) > r.remainingFrameSize {
|
||||
return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
|
||||
}
|
||||
r.remainingFrameSize -= uint64(len(r.buffer))
|
||||
}
|
||||
|
||||
if r.hasChecksum {
|
||||
r.checksum.update(r.buffer)
|
||||
}
|
||||
|
||||
if !lastBlock {
|
||||
r.saveWindow(r.buffer)
|
||||
} else {
|
||||
if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
|
||||
return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
|
||||
}
|
||||
// Check for checksum at end of frame. RFC 3.1.1.
|
||||
if r.hasChecksum {
|
||||
if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
|
||||
return r.wrapNonEOFError(0, err)
|
||||
}
|
||||
|
||||
inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
|
||||
dataChecksum := uint32(r.checksum.digest())
|
||||
if inputChecksum != dataChecksum {
|
||||
return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
|
||||
}
|
||||
|
||||
r.blockOffset += 4
|
||||
}
|
||||
r.sawFrameHeader = false
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setBufferSize sets the decompressed buffer size.
|
||||
// When this is called the buffer is empty.
|
||||
func (r *Reader) setBufferSize(size int) {
|
||||
if cap(r.buffer) < size {
|
||||
need := size - cap(r.buffer)
|
||||
r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...)
|
||||
}
|
||||
r.buffer = r.buffer[:size]
|
||||
}
|
||||
|
||||
// saveWindow saves bytes in the backreference window.
|
||||
// TODO: use a circular buffer for less data movement.
|
||||
func (r *Reader) saveWindow(buf []byte) {
|
||||
if r.windowSize == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if len(buf) >= r.windowSize {
|
||||
from := len(buf) - r.windowSize
|
||||
r.window = append(r.window[:0], buf[from:]...)
|
||||
return
|
||||
}
|
||||
|
||||
keep := r.windowSize - len(buf) // must be positive
|
||||
if keep < len(r.window) {
|
||||
remove := len(r.window) - keep
|
||||
copy(r.window[:], r.window[remove:])
|
||||
}
|
||||
|
||||
r.window = append(r.window, buf...)
|
||||
}
|
||||
|
||||
// zstdError is an error while decompressing.
|
||||
type zstdError struct {
|
||||
offset int64
|
||||
err error
|
||||
}
|
||||
|
||||
func (ze *zstdError) Error() string {
|
||||
return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err)
|
||||
}
|
||||
|
||||
func (ze *zstdError) Unwrap() error {
|
||||
return ze.err
|
||||
}
|
||||
|
||||
func (r *Reader) makeEOFError(off int) error {
|
||||
return r.wrapError(off, io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
func (r *Reader) wrapNonEOFError(off int, err error) error {
|
||||
if err == io.EOF {
|
||||
err = io.ErrUnexpectedEOF
|
||||
}
|
||||
return r.wrapError(off, err)
|
||||
}
|
||||
|
||||
func (r *Reader) makeError(off int, msg string) error {
|
||||
return r.wrapError(off, errors.New(msg))
|
||||
}
|
||||
|
||||
func (r *Reader) wrapError(off int, err error) error {
|
||||
if err == io.EOF {
|
||||
return err
|
||||
}
|
||||
return &zstdError{r.blockOffset + int64(off), err}
|
||||
}
|
249
src/internal/zstd/zstd_test.go
Normal file
249
src/internal/zstd/zstd_test.go
Normal file
@ -0,0 +1,249 @@
|
||||
// Copyright 2023 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package zstd
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"internal/race"
|
||||
"internal/testenv"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// tests holds some simple test cases, including some found by fuzzing.
|
||||
var tests = []struct {
|
||||
name, uncompressed, compressed string
|
||||
}{
|
||||
{
|
||||
"hello",
|
||||
"hello, world\n",
|
||||
"\x28\xb5\x2f\xfd\x24\x0d\x69\x00\x00\x68\x65\x6c\x6c\x6f\x2c\x20\x77\x6f\x72\x6c\x64\x0a\x4c\x1f\xf9\xf1",
|
||||
},
|
||||
{
|
||||
// a small compressed .debug_ranges section.
|
||||
"ranges",
|
||||
"\xcc\x11\x00\x00\x00\x00\x00\x00\xd5\x13\x00\x00\x00\x00\x00\x00" +
|
||||
"\x1c\x14\x00\x00\x00\x00\x00\x00\x72\x14\x00\x00\x00\x00\x00\x00" +
|
||||
"\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
|
||||
"\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" +
|
||||
"\x0c\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" +
|
||||
"\x29\x14\x00\x00\x00\x00\x00\x00\x4e\x14\x00\x00\x00\x00\x00\x00" +
|
||||
"\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
|
||||
"\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" +
|
||||
"\x67\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" +
|
||||
"\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
|
||||
"\x5f\x0b\x00\x00\x00\x00\x00\x00\x6c\x0b\x00\x00\x00\x00\x00\x00" +
|
||||
"\x7d\x0b\x00\x00\x00\x00\x00\x00\x7e\x0c\x00\x00\x00\x00\x00\x00" +
|
||||
"\x38\x0f\x00\x00\x00\x00\x00\x00\x5c\x0f\x00\x00\x00\x00\x00\x00" +
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
|
||||
"\x83\x0c\x00\x00\x00\x00\x00\x00\xfa\x0c\x00\x00\x00\x00\x00\x00" +
|
||||
"\xfd\x0d\x00\x00\x00\x00\x00\x00\xef\x0e\x00\x00\x00\x00\x00\x00" +
|
||||
"\x14\x0f\x00\x00\x00\x00\x00\x00\x38\x0f\x00\x00\x00\x00\x00\x00" +
|
||||
"\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" +
|
||||
"\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" +
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
|
||||
"\xfd\x0d\x00\x00\x00\x00\x00\x00\xd8\x0e\x00\x00\x00\x00\x00\x00" +
|
||||
"\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" +
|
||||
"\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" +
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
|
||||
"\xfa\x0c\x00\x00\x00\x00\x00\x00\xea\x0d\x00\x00\x00\x00\x00\x00" +
|
||||
"\xef\x0e\x00\x00\x00\x00\x00\x00\x14\x0f\x00\x00\x00\x00\x00\x00" +
|
||||
"\x5c\x0f\x00\x00\x00\x00\x00\x00\x9f\x0f\x00\x00\x00\x00\x00\x00" +
|
||||
"\xac\x0f\x00\x00\x00\x00\x00\x00\xdb\x0f\x00\x00\x00\x00\x00\x00" +
|
||||
"\xff\x0f\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" +
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
|
||||
"\x60\x11\x00\x00\x00\x00\x00\x00\xd1\x16\x00\x00\x00\x00\x00\x00" +
|
||||
"\x40\x0b\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" +
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
|
||||
"\x7a\x00\x00\x00\x00\x00\x00\x00\xb6\x00\x00\x00\x00\x00\x00\x00" +
|
||||
"\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" +
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
|
||||
"\x7a\x00\x00\x00\x00\x00\x00\x00\xa9\x00\x00\x00\x00\x00\x00\x00" +
|
||||
"\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" +
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
|
||||
|
||||
"\x28\xb5\x2f\xfd\x64\xa0\x01\x2d\x05\x00\xc4\x04\xcc\x11\x00\xd5" +
|
||||
"\x13\x00\x1c\x14\x00\x72\x9d\xd5\xfb\x12\x00\x09\x0c\x13\xcb\x13" +
|
||||
"\x29\x4e\x67\x5f\x0b\x6c\x0b\x7d\x0b\x7e\x0c\x38\x0f\x5c\x0f\x83" +
|
||||
"\x0c\xfa\x0c\xfd\x0d\xef\x0e\x14\x38\x9f\x0f\xac\x0f\xdb\x0f\xff" +
|
||||
"\x0f\xd8\x9f\xac\xdb\xff\xea\x5c\x2c\x10\x60\xd1\x16\x40\x0b\x7a" +
|
||||
"\x00\xb6\x00\x9f\x01\xa7\x01\xa9\x36\x20\xa0\x83\x14\x34\x63\x4a" +
|
||||
"\x21\x70\x8c\x07\x46\x03\x4e\x10\x62\x3c\x06\x4e\xc8\x8c\xb0\x32" +
|
||||
"\x2a\x59\xad\xb2\xf1\x02\x82\x7c\x33\xcb\x92\x6f\x32\x4f\x9b\xb0" +
|
||||
"\xa2\x30\xf0\xc0\x06\x1e\x98\x99\x2c\x06\x1e\xd8\xc0\x03\x56\xd8" +
|
||||
"\xc0\x03\x0f\x6c\xe0\x01\xf1\xf0\xee\x9a\xc6\xc8\x97\x99\xd1\x6c" +
|
||||
"\xb4\x21\x45\x3b\x10\xe4\x7b\x99\x4d\x8a\x36\x64\x5c\x77\x08\x02" +
|
||||
"\xcb\xe0\xce",
|
||||
},
|
||||
{
|
||||
"fuzz1",
|
||||
"0\x00\x00\x00\x00\x000\x00\x00\x00\x00\x001\x00\x00\x00\x00\x000000",
|
||||
"(\xb5/\xfd\x04X\x8d\x00\x00P0\x000\x001\x000000\x03T\x02\x00\x01\x01m\xf9\xb7G",
|
||||
},
|
||||
}
|
||||
|
||||
func TestSamples(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
r := NewReader(strings.NewReader(test.compressed))
|
||||
got, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gotstr := string(got)
|
||||
if gotstr != test.uncompressed {
|
||||
t.Errorf("got %q want %q", gotstr, test.uncompressed)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
bigDataOnce sync.Once
|
||||
bigDataBytes []byte
|
||||
bigDataErr error
|
||||
)
|
||||
|
||||
// bigData returns the contents of our large test file.
|
||||
func bigData(t testing.TB) []byte {
|
||||
bigDataOnce.Do(func() {
|
||||
bigDataBytes, bigDataErr = os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
|
||||
})
|
||||
if bigDataErr != nil {
|
||||
t.Fatal(bigDataErr)
|
||||
}
|
||||
return bigDataBytes
|
||||
}
|
||||
|
||||
var (
|
||||
zstdBigOnce sync.Once
|
||||
zstdBigBytes []byte
|
||||
zstdBigSkip bool
|
||||
zstdBigErr error
|
||||
)
|
||||
|
||||
// zstdBigData returns the compressed contents of our large test file.
|
||||
// This will only run on Unix systems with zstd installed.
|
||||
// That's OK as the package is GOOS-independent.
|
||||
func zstdBigData(t testing.TB) []byte {
|
||||
input := bigData(t)
|
||||
|
||||
zstdBigOnce.Do(func() {
|
||||
if _, err := os.Stat("/usr/bin/zstd"); err != nil {
|
||||
zstdBigSkip = true
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("/usr/bin/zstd", "-z")
|
||||
cmd.Stdin = bytes.NewReader(input)
|
||||
var compressed bytes.Buffer
|
||||
cmd.Stdout = &compressed
|
||||
cmd.Stderr = os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
zstdBigErr = fmt.Errorf("running zstd failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
zstdBigBytes = compressed.Bytes()
|
||||
})
|
||||
if zstdBigSkip {
|
||||
t.Skip("skipping because /usr/bin/zstd does not exist")
|
||||
}
|
||||
if zstdBigErr != nil {
|
||||
t.Fatal(zstdBigErr)
|
||||
}
|
||||
return zstdBigBytes
|
||||
}
|
||||
|
||||
// Test decompressing a large file. We don't have a compressor,
|
||||
// so this test only runs on systems with zstd installed.
|
||||
func TestLarge(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping expensive test in short mode")
|
||||
}
|
||||
|
||||
data := bigData(t)
|
||||
compressed := zstdBigData(t)
|
||||
|
||||
t.Logf("/usr/bin/zstd compressed %d bytes to %d", len(data), len(compressed))
|
||||
|
||||
r := NewReader(bytes.NewReader(compressed))
|
||||
got, err := io.ReadAll(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(got, data) {
|
||||
showDiffs(t, got, data)
|
||||
}
|
||||
}
|
||||
|
||||
// showDiffs reports the first few differences in two []byte.
|
||||
func showDiffs(t *testing.T, got, want []byte) {
|
||||
t.Error("data mismatch")
|
||||
if len(got) != len(want) {
|
||||
t.Errorf("got data length %d, want %d", len(got), len(want))
|
||||
}
|
||||
diffs := 0
|
||||
for i, b := range got {
|
||||
if i >= len(want) {
|
||||
break
|
||||
}
|
||||
if b != want[i] {
|
||||
diffs++
|
||||
if diffs > 20 {
|
||||
break
|
||||
}
|
||||
t.Logf("%d: %#x != %#x", i, b, want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAlloc(t *testing.T) {
|
||||
testenv.SkipIfOptimizationOff(t)
|
||||
if race.Enabled {
|
||||
t.Skip("skipping allocation test under race detector")
|
||||
}
|
||||
|
||||
compressed := zstdBigData(t)
|
||||
input := bytes.NewReader(compressed)
|
||||
r := NewReader(input)
|
||||
c := testing.AllocsPerRun(10, func() {
|
||||
input.Reset(compressed)
|
||||
r.Reset(input)
|
||||
io.Copy(io.Discard, r)
|
||||
})
|
||||
if c != 0 {
|
||||
t.Errorf("got %v allocs, want 0", c)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkLarge(b *testing.B) {
|
||||
b.StopTimer()
|
||||
b.ReportAllocs()
|
||||
|
||||
compressed := zstdBigData(b)
|
||||
|
||||
b.SetBytes(int64(len(compressed)))
|
||||
|
||||
input := bytes.NewReader(compressed)
|
||||
r := NewReader(input)
|
||||
|
||||
b.StartTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
input.Reset(compressed)
|
||||
r.Reset(input)
|
||||
io.Copy(io.Discard, r)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user