1
0
mirror of https://github.com/golang/go synced 2024-11-26 14:08:37 -07:00

internal/zstd: new internal package for zstd decompression

This package only does zstd decompression, which is starting to
be used for ELF debug sections. If we need zstd compression we
should use github.com/klauspost/compress/zstd. But for now that
is a very large package to vendor into the standard library.

For #55107

Change-Id: I60ede735357d491be653477ed419cf5f2f0d3f71
Reviewed-on: https://go-review.googlesource.com/c/go/+/473356
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Joseph Tsai <joetsai@digital-static.net>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Bryan Mills <bcmills@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
This commit is contained in:
Ian Lance Taylor 2023-03-03 11:42:07 -08:00 committed by Gopher Robot
parent d5514013b6
commit 73ee0fcf37
12 changed files with 2777 additions and 1 deletions

View File

@ -226,7 +226,7 @@ var depsRules = `
# compression
FMT, encoding/binary, hash/adler32, hash/crc32
< compress/bzip2, compress/flate, compress/lzw
< compress/bzip2, compress/flate, compress/lzw, internal/zstd
< archive/zip, compress/gzip, compress/zlib;
# templates

130
src/internal/zstd/bits.go Normal file
View File

@ -0,0 +1,130 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zstd
import (
"math/bits"
)
// block is the data for a single compressed block.
// The data starts immediately after the 3 byte block header,
// and is Block_Size bytes long.
// The off values used by the bit readers in this file are indexes
// into this slice.
type block []byte
// bitReader reads a bit stream going forward.
// Bytes are taken from data in increasing offset order. Each new byte
// is placed above the bits that are already buffered, and values are
// handed out starting from the least significant buffered bit.
type bitReader struct {
r *Reader // for error reporting
data block // the bits to read
off uint32 // current offset into data
bits uint32 // bits ready to be returned
cnt uint32 // number of valid bits in the bits field
}
// makeBitReader returns a forward bit reader over data whose first
// byte to be consumed is data[off]. No bits are buffered yet; the
// caller is expected to call moreBits before fetching values.
func (r *Reader) makeBitReader(data block, off int) bitReader {
	var br bitReader
	br.r = r
	br.data = data
	br.off = uint32(off)
	return br
}
// moreBits refills the bit buffer until it holds at least 16 valid bits.
// It returns an EOF error at the current offset if data runs out
// before 16 bits are available.
func (br *bitReader) moreBits() error {
	end := uint32(len(br.data))
	for br.cnt < 16 {
		if br.off >= end {
			return br.r.makeEOFError(int(br.off))
		}
		// New bytes go above the bits that are already buffered.
		br.bits |= uint32(br.data[br.off]) << br.cnt
		br.off++
		br.cnt += 8
	}
	return nil
}
// val removes the low b bits from the bit buffer and returns them.
// It does not refill the buffer; the caller must already have made
// sure, via moreBits, that b bits are available.
func (br *bitReader) val(b uint8) uint32 {
	mask := uint32(1)<<b - 1
	v := br.bits & mask
	br.cnt -= uint32(b)
	br.bits >>= b
	return v
}
// backup steps back to the last byte we used: every whole byte that is
// still sitting unread in the bit buffer is handed back by moving off
// backward and shrinking cnt accordingly.
func (br *bitReader) backup() {
	whole := br.cnt / 8
	br.off -= whole
	br.cnt -= whole * 8
}
// makeError wraps msg in an error that is reported at the reader's
// current offset into the block.
func (br *bitReader) makeError(msg string) error {
	off := int(br.off)
	return br.r.makeError(off, msg)
}
// reverseBitReader reads a bit stream in reverse.
// Bytes are taken from data in decreasing offset order, stopping at
// start. Each new byte is shifted in below the bits that are already
// buffered, and values are handed out from the most significant
// buffered bits first.
type reverseBitReader struct {
r *Reader // for error reporting
data block // the bits to read
off uint32 // current offset into data
start uint32 // start in data; we read backward to start
bits uint32 // bits ready to be returned
cnt uint32 // number of valid bits in bits field
}
// makeReverseBitReader makes a reverseBitReader reading backward
// from off to start. The bitstream starts with a 1 bit in the last
// byte, at off.
//
// The highest set bit of data[off] is the end-of-stream marker. It is
// not part of the payload, so only the bits below it are made available.
// An error is returned if the range [start, off] is empty or does not
// lie inside data, or if data[off] is zero and so carries no marker bit.
func (r *Reader) makeReverseBitReader(data block, off, start int) (reverseBitReader, error) {
	// A stream must contain at least the byte that holds the marker bit.
	// Without this check an empty stream would take its marker from the
	// byte just before start, which belongs to whatever precedes the
	// stream, and an out-of-range off would panic on the index below.
	if off < start || off < 0 || off >= len(data) {
		return reverseBitReader{}, r.makeEOFError(off)
	}

	streamStart := data[off]
	if streamStart == 0 {
		return reverseBitReader{}, r.makeError(off, "zero byte at reverse bit stream start")
	}
	rbr := reverseBitReader{
		r:     r,
		data:  data,
		off:   uint32(off),
		start: uint32(start),
		bits:  uint32(streamStart),
		// Everything below the marker bit is payload.
		cnt: uint32(7 - bits.LeadingZeros8(streamStart)),
	}
	return rbr, nil
}
// val returns the next b bits of the reversed stream. It returns an
// EOF error at the current offset if fewer than b bits remain.
func (rbr *reverseBitReader) val(b uint8) (uint32, error) {
	if ok := rbr.fetch(b); !ok {
		return 0, rbr.r.makeEOFError(int(rbr.off))
	}

	// The requested bits are the topmost of the cnt valid bits.
	rbr.cnt -= uint32(b)
	mask := uint32(1)<<b - 1
	return (rbr.bits >> rbr.cnt) & mask, nil
}
// fetch tops up the bit buffer until at least b bits are valid.
// It reports false when the stream start is reached first; in that
// case only rbr.cnt bits are available.
func (rbr *reverseBitReader) fetch(b uint8) bool {
	want := uint32(b)
	for rbr.cnt < want {
		if rbr.off <= rbr.start {
			return false
		}
		// Step back one byte and shift it in below the buffered bits.
		rbr.off--
		rbr.bits = rbr.bits<<8 | uint32(rbr.data[rbr.off])
		rbr.cnt += 8
	}
	return true
}
// makeError wraps msg in an error that is reported at the reader's
// current offset into the block.
func (rbr *reverseBitReader) makeError(msg string) error {
	off := int(rbr.off)
	return rbr.r.makeError(off, msg)
}

436
src/internal/zstd/block.go Normal file
View File

@ -0,0 +1,436 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zstd
import (
"io"
)
// debug can be set in the source to print debug info using println.
// When it is true, execSeqs prints the literal length, offset and
// match length of every sequence it decodes.
const debug = false
// compressedBlock decompresses a compressed block, storing the decompressed
// data in r.buffer. The blockSize argument is the compressed size.
// RFC 3.1.1.3.
//
// The block is read in full from the underlying reader into
// r.compressedBuf. The literals section is decoded first, then the
// sequences header; the sequences, if any, are then executed against
// the literals to produce the output.
func (r *Reader) compressedBlock(blockSize int) error {
	if len(r.compressedBuf) >= blockSize {
		r.compressedBuf = r.compressedBuf[:blockSize]
	} else {
		// We know that blockSize <= 128K,
		// so this won't allocate an enormous amount.
		need := blockSize - len(r.compressedBuf)
		r.compressedBuf = append(r.compressedBuf, make([]byte, need)...)
	}

	if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil {
		return r.wrapNonEOFError(0, err)
	}

	data := block(r.compressedBuf)
	off := 0
	r.buffer = r.buffer[:0]

	// Decode the literals, reusing the previous literals buffer.
	litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0])
	if err != nil {
		return err
	}
	r.literals = litbuf

	off = litoff

	seqCount, off, err := r.initSeqs(data, off)
	if err != nil {
		return err
	}

	if seqCount == 0 {
		// No sequences, just literals.
		if off < len(data) {
			return r.makeError(off, "extraneous data after no sequences")
		}

		// A block with no sequences and no literals simply decodes
		// to zero bytes. That is unusual but well-formed, so it is
		// not treated as an error.
		r.buffer = append(r.buffer, litbuf...)

		return nil
	}

	return r.execSeqs(data, off, litbuf, seqCount)
}
// seqCode is the kind of sequence codes we have to handle.
// The values are used as indexes into seqCodeInfo and into the
// per-kind Reader fields seqTables, seqTableBits and seqTableBuffers.
type seqCode int
const (
seqLiteral seqCode = iota
seqOffset
seqMatch
)
// seqCodeInfoData is the information needed to set up seqTables and
// seqTableBits for a particular kind of sequence code.
type seqCodeInfoData struct {
predefTable []fseBaselineEntry // predefined FSE
predefTableBits int // number of bits in predefTable
maxSym int // max symbol value in FSE
maxBits int // max bits for FSE
// toBaseline converts from an FSE table to an FSE baseline table.
// The int argument is the block offset used when reporting an error;
// the result is written into the []fseBaselineEntry argument.
toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error
}
// seqCodeInfo is the seqCodeInfoData for each kind of sequence code.
// It is indexed by seqCode. setSeqTable uses maxBits to size the
// scratch and baseline table buffers (1<<maxBits entries), and passes
// maxSym and maxBits to readFSE as limits on the table being read.
var seqCodeInfo = [3]seqCodeInfoData{
seqLiteral: {
predefTable: predefinedLiteralTable[:],
predefTableBits: 6,
maxSym: 35,
maxBits: 9,
toBaseline: (*Reader).makeLiteralBaselineFSE,
},
seqOffset: {
predefTable: predefinedOffsetTable[:],
predefTableBits: 5,
maxSym: 31,
maxBits: 8,
toBaseline: (*Reader).makeOffsetBaselineFSE,
},
seqMatch: {
predefTable: predefinedMatchTable[:],
predefTableBits: 6,
maxSym: 52,
maxBits: 9,
toBaseline: (*Reader).makeMatchBaselineFSE,
},
}
// initSeqs parses the Sequences_Section_Header found at data[off:] and
// prepares the three FSE tables used to decode the sequence codes.
// It returns the number of sequences and the offset just past the
// header and any table descriptions. When the sequence count is zero
// no tables are set up. RFC 3.1.1.3.2.1.
func (r *Reader) initSeqs(data block, off int) (int, int, error) {
	if off >= len(data) {
		return 0, 0, r.makeEOFError(off)
	}

	first := data[off]
	off++
	if first == 0 {
		return 0, off, nil
	}

	// The sequence count is stored in one, two or three bytes,
	// selected by the value of the first byte.
	var seqCount int
	switch {
	case first < 128:
		seqCount = int(first)
	case first < 255:
		if off >= len(data) {
			return 0, 0, r.makeEOFError(off)
		}
		seqCount = (int(first)-128)<<8 + int(data[off])
		off++
	default:
		if off+1 >= len(data) {
			return 0, 0, r.makeEOFError(off)
		}
		seqCount = int(data[off]) + int(data[off+1])<<8 + 0x7f00
		off += 2
	}

	// Read the Symbol_Compression_Modes byte.
	// Its two low bits are required to be zero.
	if off >= len(data) {
		return 0, 0, r.makeEOFError(off)
	}
	symMode := data[off]
	if symMode&3 != 0 {
		return 0, 0, r.makeError(off, "invalid symbol compression mode")
	}
	off++

	// Set up the FSE tables used to decode the sequence codes.
	// Each kind has a 2-bit mode; the tables follow in this order.
	modes := [...]struct {
		kind  seqCode
		shift uint
	}{
		{seqLiteral, 6},
		{seqOffset, 4},
		{seqMatch, 2},
	}
	for _, m := range modes {
		var err error
		off, err = r.setSeqTable(data, off, m.kind, (symMode>>m.shift)&3)
		if err != nil {
			return 0, 0, err
		}
	}

	return seqCount, off, nil
}
// setSeqTable uses the Compression_Mode in mode to set up r.seqTables and
// r.seqTableBits for kind. We store these in the Reader because one of
// the modes simply reuses the value from the last block in the frame.
// It returns the offset in data just past whatever the mode consumed.
func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) {
info := &seqCodeInfo[kind]
switch mode {
case 0:
// Predefined_Mode
// Nothing is read from data; the fixed table is used directly.
r.seqTables[kind] = info.predefTable
r.seqTableBits[kind] = uint8(info.predefTableBits)
return off, nil
case 1:
// RLE_Mode
// A single byte gives the one symbol that every sequence uses.
if off >= len(data) {
return 0, r.makeEOFError(off)
}
rle := data[off]
off++
// Build a simple baseline table that always returns rle.
entry := []fseEntry{
{
sym: rle,
bits: 0,
base: 0,
},
}
// The per-kind buffer is allocated once at full size and then
// resliced, so later blocks can reuse it for any table size.
if cap(r.seqTableBuffers[kind]) == 0 {
r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
}
r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1]
if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil {
return 0, err
}
r.seqTables[kind] = r.seqTableBuffers[kind]
// A one-entry table needs no state bits.
r.seqTableBits[kind] = 0
return off, nil
case 2:
// FSE_Compressed_Mode
// The table description is read from data into the shared scratch
// table, then converted into this kind's baseline table.
if cap(r.fseScratch) < 1<<info.maxBits {
r.fseScratch = make([]fseEntry, 1<<info.maxBits)
}
r.fseScratch = r.fseScratch[:1<<info.maxBits]
tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch)
if err != nil {
return 0, err
}
// Only the first 1<<tableBits entries were filled in.
r.fseScratch = r.fseScratch[:1<<tableBits]
if cap(r.seqTableBuffers[kind]) == 0 {
r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
}
r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits]
if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil {
return 0, err
}
r.seqTables[kind] = r.seqTableBuffers[kind]
r.seqTableBits[kind] = uint8(tableBits)
return roff, nil
case 3:
// Repeat_Mode
// Keep whatever table is already stored; it is an error if
// there is none.
if len(r.seqTables[kind]) == 0 {
return 0, r.makeError(off, "missing repeat sequence FSE table")
}
return off, nil
}
// Callers in this file pass a 2-bit value, so all cases are covered above.
panic("unreachable")
}
// execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2.
// The sequence bitstream occupies data[off:] and is read backward from
// the last byte of data. litbuf holds the decoded literals and seqCount
// is the number of sequences to execute. Output is appended to r.buffer;
// literals left over after the last sequence are appended at the end.
func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error {
// Set up the initial states for the sequence code readers.
rbr, err := r.makeReverseBitReader(data, len(data)-1, off)
if err != nil {
return err
}
// The initial states are stored in the order literal, offset, match.
literalState, err := rbr.val(r.seqTableBits[seqLiteral])
if err != nil {
return err
}
offsetState, err := rbr.val(r.seqTableBits[seqOffset])
if err != nil {
return err
}
matchState, err := rbr.val(r.seqTableBits[seqMatch])
if err != nil {
return err
}
// Read and perform all the sequences. RFC 3.1.1.4.
seq := 0
for seq < seqCount {
// Bound the output: what has been produced so far plus the
// literals still to be emitted may not exceed 128K.
if len(r.buffer)+len(litbuf) > 128<<10 {
return rbr.makeError("uncompressed size too big")
}
ptoffset := &r.seqTables[seqOffset][offsetState]
ptmatch := &r.seqTables[seqMatch][matchState]
ptliteral := &r.seqTables[seqLiteral][literalState]
// The extra bits for each value are read in the order
// offset, match, literal.
add, err := rbr.val(ptoffset.basebits)
if err != nil {
return err
}
offset := ptoffset.baseline + add
add, err = rbr.val(ptmatch.basebits)
if err != nil {
return err
}
match := ptmatch.baseline + add
add, err = rbr.val(ptliteral.basebits)
if err != nil {
return err
}
literal := ptliteral.baseline + add
// Handle repeat offsets. RFC 3.1.1.5.
// See the comment in makeOffsetBaselineFSE.
if ptoffset.basebits > 1 {
// A real offset: it becomes the most recent repeated offset
// and the older ones shift down.
r.repeatedOffset3 = r.repeatedOffset2
r.repeatedOffset2 = r.repeatedOffset1
r.repeatedOffset1 = offset
} else {
// A repeat code. With no literals the code is shifted by one,
// so 3 becomes 4, which selects repeatedOffset1 - 1.
if literal == 0 {
offset++
}
switch offset {
case 1:
offset = r.repeatedOffset1
case 2:
offset = r.repeatedOffset2
r.repeatedOffset2 = r.repeatedOffset1
r.repeatedOffset1 = offset
case 3:
offset = r.repeatedOffset3
r.repeatedOffset3 = r.repeatedOffset2
r.repeatedOffset2 = r.repeatedOffset1
r.repeatedOffset1 = offset
case 4:
offset = r.repeatedOffset1 - 1
r.repeatedOffset3 = r.repeatedOffset2
r.repeatedOffset2 = r.repeatedOffset1
r.repeatedOffset1 = offset
}
}
seq++
// The last sequence is not followed by a state update.
if seq < seqCount {
// Update the states.
// The state bits are read in the order literal, match, offset.
add, err = rbr.val(ptliteral.bits)
if err != nil {
return err
}
literalState = uint32(ptliteral.base) + add
add, err = rbr.val(ptmatch.bits)
if err != nil {
return err
}
matchState = uint32(ptmatch.base) + add
add, err = rbr.val(ptoffset.bits)
if err != nil {
return err
}
offsetState = uint32(ptoffset.base) + add
}
// The next sequence is now in literal, offset, match.
if debug {
println("literal", literal, "offset", offset, "match", match)
}
// Copy literal bytes from litbuf.
if literal > uint32(len(litbuf)) {
return rbr.makeError("literal byte overflow")
}
if literal > 0 {
r.buffer = append(r.buffer, litbuf[:literal]...)
litbuf = litbuf[literal:]
}
if match > 0 {
if err := r.copyFromWindow(&rbr, offset, match); err != nil {
return err
}
}
}
// Any literals not consumed by a sequence go at the end of the output.
if len(litbuf) > 0 {
r.buffer = append(r.buffer, litbuf...)
}
// Every bit of the stream must have been consumed.
if rbr.cnt != 0 {
return r.makeError(off, "extraneous data after sequences")
}
return nil
}
// copyFromWindow appends match bytes to r.buffer, copied from the data
// that ends offset bytes before the current end of the output. The
// source may start in r.window (output of earlier blocks), in r.buffer
// (output of the current block), or span both. It may also overlap the
// bytes being written, in which case the copied data repeats with
// period offset. rbr is used only for error reporting.
//
// It returns an error if offset is zero or reaches back past the start
// of the window.
func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error {
	if offset == 0 {
		return rbr.makeError("invalid zero offset")
	}

	// Offset may point into the buffer or the window, and
	// match may extend past the end of the initial buffer.
	// |--r.window--|--r.buffer--|
	//        |<-----offset------|
	//        |------match----------->|
	bufferOffset := uint32(0)
	lenBlock := uint32(len(r.buffer))
	if lenBlock < offset {
		lenWindow := uint32(len(r.window))
		n := offset - lenBlock
		if n > lenWindow {
			return rbr.makeError("offset past window")
		}
		from := lenWindow - n
		if n > match {
			n = match
		}
		r.buffer = append(r.buffer, r.window[from:from+n]...)
		match -= n
		// If anything is left to copy, the source continues at the
		// start of r.buffer. This holds even when the buffer was empty
		// before the append: the bytes just copied from the window
		// then serve as the source for the remainder.
	} else {
		bufferOffset = lenBlock - offset
	}

	// We may be asked to copy data that we are adding to the buffer
	// in the same copy. Everything from bufferOffset to the current end
	// of the buffer repeats with period offset, so each round can copy
	// all of it, and the amount available grows every round.
	for match > 0 {
		n := uint32(len(r.buffer)) - bufferOffset
		if n > match {
			n = match
		}
		r.buffer = append(r.buffer, r.buffer[bufferOffset:bufferOffset+n]...)
		match -= n
	}
	return nil
}

437
src/internal/zstd/fse.go Normal file
View File

@ -0,0 +1,437 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zstd
import (
"math/bits"
)
// fseEntry is one entry in an FSE table.
// The table is indexed by the decoder state. An entry gives the symbol
// for that state and how to compute the next state.
type fseEntry struct {
sym uint8 // value that this entry records
bits uint8 // number of bits to read to determine next state
base uint16 // add those bits to this state to get the next state
}
// readFSE reads an FSE table from data starting at off.
// maxSym is the maximum symbol value.
// maxBits is the maximum number of bits permitted for symbols in the table.
// The FSE is written into table, which must be at least 1<<maxBits in size.
// This returns the number of bits in the FSE table and the new offset.
// RFC 4.1.1.
//
// The description is a 4-bit accuracy log followed by a variable-width
// count for each symbol. The counts are collected in norm and then
// handed to buildFSE to construct the decoding table.
func (r *Reader) readFSE(data block, off, maxSym, maxBits int, table []fseEntry) (tableBits, roff int, err error) {
br := r.makeBitReader(data, off)
if err := br.moreBits(); err != nil {
return 0, 0, err
}
// The low 4 bits hold the accuracy log, biased by 5.
accuracyLog := int(br.val(4)) + 5
if accuracyLog > maxBits {
return 0, 0, br.makeError("FSE accuracy log too large")
}
// The number of remaining probabilities, plus 1.
// This determines the number of bits to be read for the next value.
remaining := (1 << accuracyLog) + 1
// The current difference between small and large values,
// which depends on the number of remaining values.
// Small values use 1 less bit.
threshold := 1 << accuracyLog
// The number of bits needed to compute threshold.
bitsNeeded := accuracyLog + 1
// The next character value.
sym := 0
// Whether the last count was 0.
prev0 := false
// norm[s] is the count read for symbol s; -1 marks a symbol
// whose stored value was 0.
var norm [256]int16
for remaining > 1 && sym <= maxSym {
if err := br.moreBits(); err != nil {
return 0, 0, err
}
if prev0 {
// Previous count was 0, so there is a 2-bit
// repeat flag. If the 2-bit flag is 0b11,
// it adds 3 and then there is another repeat flag.
zsym := sym
// Fast path: six consecutive 0b11 flags at once.
for (br.bits & 0xfff) == 0xfff {
zsym += 3 * 6
br.bits >>= 12
br.cnt -= 12
if err := br.moreBits(); err != nil {
return 0, 0, err
}
}
for (br.bits & 3) == 3 {
zsym += 3
br.bits >>= 2
br.cnt -= 2
if err := br.moreBits(); err != nil {
return 0, 0, err
}
}
// We have at least 14 bits here,
// no need to call moreBits
zsym += int(br.val(2))
if zsym > maxSym {
return 0, 0, br.makeError("FSE symbol index overflow")
}
// All the skipped symbols have a count of 0.
for ; sym < zsym; sym++ {
norm[uint8(sym)] = 0
}
prev0 = false
continue
}
// Values below max fit in bitsNeeded-1 bits; anything else
// takes the full bitsNeeded bits.
max := (2*threshold - 1) - remaining
var count int
if int(br.bits&uint32(threshold-1)) < max {
// A small value.
count = int(br.bits & uint32((threshold - 1)))
br.bits >>= bitsNeeded - 1
br.cnt -= uint32(bitsNeeded - 1)
} else {
// A large value.
count = int(br.bits & uint32((2*threshold - 1)))
if count >= threshold {
count -= max
}
br.bits >>= bitsNeeded
br.cnt -= uint32(bitsNeeded)
}
// The stored value is the count plus 1, so a stored 0 yields -1.
// A count of -1 still uses up one of the remaining slots.
count--
if count >= 0 {
remaining -= count
} else {
remaining--
}
if sym >= 256 {
return 0, 0, br.makeError("FSE sym overflow")
}
norm[uint8(sym)] = int16(count)
sym++
prev0 = count == 0
// Fewer remaining slots means fewer bits for the next value.
for remaining < threshold {
bitsNeeded--
threshold >>= 1
}
}
// The counts must add up to exactly the table size.
if remaining != 1 {
return 0, 0, br.makeError("too many symbols in FSE table")
}
// Symbols that were never mentioned have a count of 0.
for ; sym <= maxSym; sym++ {
norm[uint8(sym)] = 0
}
// Hand back whole bytes that were buffered but not used, so that
// br.off is the offset of the first byte after the table description.
br.backup()
if err := r.buildFSE(off, norm[:maxSym+1], table, accuracyLog); err != nil {
return 0, 0, err
}
return accuracyLog, int(br.off), nil
}
// buildFSE builds an FSE decoding table from a list of probabilities.
// The probabilities are in norm, indexed by symbol. A negative value
// marks a symbol that gets exactly one cell, placed at the high end of
// the table. The number of bits in the table is tableBits, and the
// first 1<<tableBits entries of table are filled in. off is used only
// for error reporting.
func (r *Reader) buildFSE(off int, norm []int16, table []fseEntry, tableBits int) error {
tableSize := 1 << tableBits
// Cells above highThreshold are reserved for negative-count symbols.
highThreshold := tableSize - 1
// next[s] is the value from which the next cell holding symbol s
// derives its bits and base; it starts at the symbol's count.
var next [256]uint16
for i, n := range norm {
if n >= 0 {
next[uint8(i)] = uint16(n)
} else {
table[highThreshold].sym = uint8(i)
highThreshold--
next[uint8(i)] = 1
}
}
// Spread the remaining symbols over the table by stepping through it
// with a fixed stride, skipping the reserved cells at the top.
pos := 0
step := (tableSize >> 1) + (tableSize >> 3) + 3
mask := tableSize - 1
for i, n := range norm {
for j := 0; j < int(n); j++ {
table[pos].sym = uint8(i)
pos = (pos + step) & mask
for pos > highThreshold {
pos = (pos + step) & mask
}
}
}
// If the position has not returned to 0, the counts are rejected.
if pos != 0 {
return r.makeError(off, "FSE count error")
}
// Compute the number of bits to read and the base of the next state
// for every cell.
for i := 0; i < tableSize; i++ {
sym := table[i].sym
nextState := next[sym]
next[sym]++
// A cell holds a symbol whose count is 0.
if nextState == 0 {
return r.makeError(off, "FSE state error")
}
highBit := 15 - bits.LeadingZeros16(nextState)
// From here on bits names this local variable,
// not the math/bits package.
bits := tableBits - highBit
table[i].bits = uint8(bits)
table[i].base = (nextState << bits) - uint16(tableSize)
}
return nil
}
// fseBaselineEntry is an entry in an FSE baseline table.
// We use these for literal/match/length values.
// Those require mapping the symbol to a baseline value,
// and then reading zero or more bits and adding the value to the baseline.
// Rather than looking these up in separate tables,
// we convert the FSE table to an FSE baseline table.
type fseBaselineEntry struct {
baseline uint32 // baseline for value that this entry represents
basebits uint8 // number of bits to read to add to baseline
bits uint8 // number of bits to read to determine next state
base uint16 // add the bits to this base to get the next state
}
// Given a literal length code, we need to read a number of bits and
// add that to a baseline. For states 0 to 15 the baseline is the
// state and the number of bits is zero. RFC 3.1.1.3.2.1.1.
const literalLengthOffset = 16
// literalLengthBase holds the baseline and extra-bit count for literal
// length codes starting at literalLengthOffset; index i describes code
// literalLengthOffset+i. Each value packs the baseline in the low 24
// bits and the number of extra bits to read in the high 8 bits.
var literalLengthBase = []uint32{
16 | (1 << 24),
18 | (1 << 24),
20 | (1 << 24),
22 | (1 << 24),
24 | (2 << 24),
28 | (2 << 24),
32 | (3 << 24),
40 | (3 << 24),
48 | (4 << 24),
64 | (6 << 24),
128 | (7 << 24),
256 | (8 << 24),
512 | (9 << 24),
1024 | (10 << 24),
2048 | (11 << 24),
4096 | (12 << 24),
8192 | (13 << 24),
16384 | (14 << 24),
32768 | (15 << 24),
65536 | (16 << 24),
}
// makeLiteralBaselineFSE fills baselineTable with the baseline form of
// the literal length table fseTable, entry for entry. Symbols below
// literalLengthOffset stand for themselves with no extra bits; larger
// ones are looked up in literalLengthBase. A symbol above 35 is
// reported as an error at off.
func (r *Reader) makeLiteralBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
	for i := range fseTable {
		sym := fseTable[i].sym
		entry := fseBaselineEntry{
			bits: fseTable[i].bits,
			base: fseTable[i].base,
		}
		switch {
		case sym < literalLengthOffset:
			entry.baseline = uint32(sym)
			entry.basebits = 0
		case sym > 35:
			return r.makeError(off, "FSE baseline symbol overflow")
		default:
			// Low 24 bits: baseline. High 8 bits: extra bit count.
			packed := literalLengthBase[sym-literalLengthOffset]
			entry.baseline = packed & 0xffffff
			entry.basebits = uint8(packed >> 24)
		}
		baselineTable[i] = entry
	}
	return nil
}
// makeOffsetBaselineFSE fills baselineTable with the baseline form of
// the offset table fseTable, entry for entry. A symbol above 31 is
// reported as an error at off.
func (r *Reader) makeOffsetBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
	for i, e := range fseTable {
		if e.sym > 31 {
			return r.makeError(off, "FSE offset symbol overflow")
		}

		// Taken literally from the RFC, the baseline would be
		// 1 << e.sym and the number of extra bits would be e.sym.
		// The resulting offset value still has to be adjusted,
		// though: values above 3 have 3 subtracted from them, and
		// values 1, 2 and 3 select a repeated offset.
		//
		// Baselines are powers of 2 and never 0, so the low values
		// can only come from the entry with baseline 1, basebits 0
		// and the entry with baseline 2, basebits 1. Every other
		// entry has baseline >= 4 and basebits >= 2.
		//
		// That lets the hot loop detect an RFC offset <= 3 by
		// testing basebits <= 1, and lets us do the subtraction of
		// 3 once, here, instead of for every sequence.
		baseline := uint32(1) << e.sym
		if e.sym >= 2 {
			baseline -= 3
		}

		baselineTable[i] = fseBaselineEntry{
			baseline: baseline,
			basebits: e.sym,
			bits:     e.bits,
			base:     e.base,
		}
	}
	return nil
}
// Given a match length code, we need to read a number of bits and add
// that to a baseline. For states 0 to 31 the baseline is state+3 and
// the number of bits is zero. RFC 3.1.1.3.2.1.1.
const matchLengthOffset = 32
// matchLengthBase holds the baseline and extra-bit count for match
// length codes starting at matchLengthOffset; index i describes code
// matchLengthOffset+i. Each value packs the baseline in the low 24
// bits and the number of extra bits to read in the high 8 bits.
var matchLengthBase = []uint32{
35 | (1 << 24),
37 | (1 << 24),
39 | (1 << 24),
41 | (1 << 24),
43 | (2 << 24),
47 | (2 << 24),
51 | (3 << 24),
59 | (3 << 24),
67 | (4 << 24),
83 | (4 << 24),
99 | (5 << 24),
131 | (7 << 24),
259 | (8 << 24),
515 | (9 << 24),
1027 | (10 << 24),
2051 | (11 << 24),
4099 | (12 << 24),
8195 | (13 << 24),
16387 | (14 << 24),
32771 | (15 << 24),
65539 | (16 << 24),
}
// makeMatchBaselineFSE fills baselineTable with the baseline form of
// the match length table fseTable, entry for entry. Symbols below
// matchLengthOffset map to symbol+3 with no extra bits; larger ones
// are looked up in matchLengthBase. A symbol above 52 is reported as
// an error at off.
func (r *Reader) makeMatchBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
	for i := range fseTable {
		sym := fseTable[i].sym
		entry := fseBaselineEntry{
			bits: fseTable[i].bits,
			base: fseTable[i].base,
		}
		switch {
		case sym < matchLengthOffset:
			entry.baseline = uint32(sym) + 3
			entry.basebits = 0
		case sym > 52:
			return r.makeError(off, "FSE baseline symbol overflow")
		default:
			// Low 24 bits: baseline. High 8 bits: extra bit count.
			packed := matchLengthBase[sym-matchLengthOffset]
			entry.baseline = packed & 0xffffff
			entry.basebits = uint8(packed >> 24)
		}
		baselineTable[i] = entry
	}
	return nil
}
// predefinedLiteralTable is the predefined table to use for literal lengths.
// Generated from table in RFC 3.1.1.3.2.2.1.
// Checked by TestPredefinedTables.
// Each entry lists the fseBaselineEntry fields in declaration order:
// {baseline, basebits, bits, base}. The table has 64 entries and is
// indexed by the FSE state.
var predefinedLiteralTable = [...]fseBaselineEntry{
{0, 0, 4, 0}, {0, 0, 4, 16}, {1, 0, 5, 32},
{3, 0, 5, 0}, {4, 0, 5, 0}, {6, 0, 5, 0},
{7, 0, 5, 0}, {9, 0, 5, 0}, {10, 0, 5, 0},
{12, 0, 5, 0}, {14, 0, 6, 0}, {16, 1, 5, 0},
{20, 1, 5, 0}, {22, 1, 5, 0}, {28, 2, 5, 0},
{32, 3, 5, 0}, {48, 4, 5, 0}, {64, 6, 5, 32},
{128, 7, 5, 0}, {256, 8, 6, 0}, {1024, 10, 6, 0},
{4096, 12, 6, 0}, {0, 0, 4, 32}, {1, 0, 4, 0},
{2, 0, 5, 0}, {4, 0, 5, 32}, {5, 0, 5, 0},
{7, 0, 5, 32}, {8, 0, 5, 0}, {10, 0, 5, 32},
{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 1, 5, 32},
{18, 1, 5, 0}, {22, 1, 5, 32}, {24, 2, 5, 0},
{32, 3, 5, 32}, {40, 3, 5, 0}, {64, 6, 4, 0},
{64, 6, 4, 16}, {128, 7, 5, 32}, {512, 9, 6, 0},
{2048, 11, 6, 0}, {0, 0, 4, 48}, {1, 0, 4, 16},
{2, 0, 5, 32}, {3, 0, 5, 32}, {5, 0, 5, 32},
{6, 0, 5, 32}, {8, 0, 5, 32}, {9, 0, 5, 32},
{11, 0, 5, 32}, {12, 0, 5, 32}, {15, 0, 6, 0},
{18, 1, 5, 32}, {20, 1, 5, 32}, {24, 2, 5, 32},
{28, 2, 5, 32}, {40, 3, 5, 32}, {48, 4, 5, 32},
{65536, 16, 6, 0}, {32768, 15, 6, 0}, {16384, 14, 6, 0},
{8192, 13, 6, 0},
}
// predefinedOffsetTable is the predefined table to use for offsets.
// Generated from table in RFC 3.1.1.3.2.2.3.
// Checked by TestPredefinedTables.
// Each entry lists the fseBaselineEntry fields in declaration order:
// {baseline, basebits, bits, base}. The table has 32 entries and is
// indexed by the FSE state. Baselines already include the adjustment
// made by makeOffsetBaselineFSE.
var predefinedOffsetTable = [...]fseBaselineEntry{
{1, 0, 5, 0}, {61, 6, 4, 0}, {509, 9, 5, 0},
{32765, 15, 5, 0}, {2097149, 21, 5, 0}, {5, 3, 5, 0},
{125, 7, 4, 0}, {4093, 12, 5, 0}, {262141, 18, 5, 0},
{8388605, 23, 5, 0}, {29, 5, 5, 0}, {253, 8, 4, 0},
{16381, 14, 5, 0}, {1048573, 20, 5, 0}, {1, 2, 5, 0},
{125, 7, 4, 16}, {2045, 11, 5, 0}, {131069, 17, 5, 0},
{4194301, 22, 5, 0}, {13, 4, 5, 0}, {253, 8, 4, 16},
{8189, 13, 5, 0}, {524285, 19, 5, 0}, {2, 1, 5, 0},
{61, 6, 4, 16}, {1021, 10, 5, 0}, {65533, 16, 5, 0},
{268435453, 28, 5, 0}, {134217725, 27, 5, 0}, {67108861, 26, 5, 0},
{33554429, 25, 5, 0}, {16777213, 24, 5, 0},
}
// predefinedMatchTable is the predefined table to use for match lengths.
// Generated from table in RFC 3.1.1.3.2.2.2.
// Checked by TestPredefinedTables.
// Each entry lists the fseBaselineEntry fields in declaration order:
// {baseline, basebits, bits, base}. The table has 64 entries and is
// indexed by the FSE state.
var predefinedMatchTable = [...]fseBaselineEntry{
{3, 0, 6, 0}, {4, 0, 4, 0}, {5, 0, 5, 32},
{6, 0, 5, 0}, {8, 0, 5, 0}, {9, 0, 5, 0},
{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 0, 6, 0},
{19, 0, 6, 0}, {22, 0, 6, 0}, {25, 0, 6, 0},
{28, 0, 6, 0}, {31, 0, 6, 0}, {34, 0, 6, 0},
{37, 1, 6, 0}, {41, 1, 6, 0}, {47, 2, 6, 0},
{59, 3, 6, 0}, {83, 4, 6, 0}, {131, 7, 6, 0},
{515, 9, 6, 0}, {4, 0, 4, 16}, {5, 0, 4, 0},
{6, 0, 5, 32}, {7, 0, 5, 0}, {9, 0, 5, 32},
{10, 0, 5, 0}, {12, 0, 6, 0}, {15, 0, 6, 0},
{18, 0, 6, 0}, {21, 0, 6, 0}, {24, 0, 6, 0},
{27, 0, 6, 0}, {30, 0, 6, 0}, {33, 0, 6, 0},
{35, 1, 6, 0}, {39, 1, 6, 0}, {43, 2, 6, 0},
{51, 3, 6, 0}, {67, 4, 6, 0}, {99, 5, 6, 0},
{259, 8, 6, 0}, {4, 0, 4, 32}, {4, 0, 4, 48},
{5, 0, 4, 16}, {7, 0, 5, 32}, {8, 0, 5, 32},
{10, 0, 5, 32}, {11, 0, 5, 32}, {14, 0, 6, 0},
{17, 0, 6, 0}, {20, 0, 6, 0}, {23, 0, 6, 0},
{26, 0, 6, 0}, {29, 0, 6, 0}, {32, 0, 6, 0},
{65539, 16, 6, 0}, {32771, 15, 6, 0}, {16387, 14, 6, 0},
{8195, 13, 6, 0}, {4099, 12, 6, 0}, {2051, 11, 6, 0},
{1027, 10, 6, 0},
}

View File

@ -0,0 +1,89 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zstd
import (
"slices"
"testing"
)
// literalPredefinedDistribution is the predefined distribution table
// for literal lengths. RFC 3.1.1.3.2.2.1.
// It is indexed by symbol and passed to buildFSE as norm; a value of -1
// makes buildFSE give that symbol a single cell at the end of the table.
var literalPredefinedDistribution = []int16{
4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1,
2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1,
-1, -1, -1, -1,
}
// offsetPredefinedDistribution is the predefined distribution table
// for offsets. RFC 3.1.1.3.2.2.3.
// It is indexed by symbol and passed to buildFSE as norm; a value of -1
// makes buildFSE give that symbol a single cell at the end of the table.
var offsetPredefinedDistribution = []int16{
1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
}
// matchPredefinedDistribution is the predefined distribution table
// for match lengths. RFC 3.1.1.3.2.2.2.
// It is indexed by symbol and passed to buildFSE as norm; a value of -1
// makes buildFSE give that symbol a single cell at the end of the table.
var matchPredefinedDistribution = []int16{
1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1,
-1, -1, -1, -1, -1,
}
// TestPredefinedTables rebuilds the predefined literal, offset and
// match tables from the distributions given in RFC 8878 and checks
// that they equal the tables compiled into the package. It therefore
// exercises the predefined tables themselves, buildFSE, and the
// functions that convert an FSE table into a baseline table.
func TestPredefinedTables(t *testing.T) {
	type baselineFunc func(*Reader, int, []fseEntry, []fseBaselineEntry) error

	cases := []struct {
		name         string
		distribution []int16
		tableBits    int
		toBaseline   baselineFunc
		predef       []fseBaselineEntry
	}{
		{
			name:         "literal",
			distribution: literalPredefinedDistribution,
			tableBits:    6,
			toBaseline:   (*Reader).makeLiteralBaselineFSE,
			predef:       predefinedLiteralTable[:],
		},
		{
			name:         "offset",
			distribution: offsetPredefinedDistribution,
			tableBits:    5,
			toBaseline:   (*Reader).makeOffsetBaselineFSE,
			predef:       predefinedOffsetTable[:],
		},
		{
			name:         "match",
			distribution: matchPredefinedDistribution,
			tableBits:    6,
			toBaseline:   (*Reader).makeMatchBaselineFSE,
			predef:       predefinedMatchTable[:],
		},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			var r Reader

			// Distribution -> FSE table.
			fse := make([]fseEntry, 1<<tc.tableBits)
			err := r.buildFSE(0, tc.distribution, fse, tc.tableBits)
			if err != nil {
				t.Fatal(err)
			}

			// FSE table -> baseline table.
			got := make([]fseBaselineEntry, len(fse))
			err = tc.toBaseline(&r, 0, fse, got)
			if err != nil {
				t.Fatal(err)
			}

			if !slices.Equal(got, tc.predef) {
				t.Errorf("got %v, want %v", got, tc.predef)
			}
		})
	}
}

View File

@ -0,0 +1,140 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zstd
import (
"bytes"
"io"
"os"
"os/exec"
"testing"
)
// badStrings is some inputs that FuzzReader failed on earlier.
// Each entry is the raw bytes of one input, written as a Go string
// literal. FuzzReader adds every entry to its seed corpus.
var badStrings = []string{
"(\xb5/\xfdd00,\x05\x00\xc4\x0400000000000000000000000000000000000000000000000000000000000000000000000000000 \xa07100000000000000000000000000000000000000000000000000000000000000000000000000aM\x8a2y0B\b",
"(\xb5/\xfd00$\x05\x0020 00X70000a70000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"(\xb5/\xfd00$\x05\x0020 00B00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"(\xb5/\xfd00}\x00\x0020\x00\x9000000000000",
"(\xb5/\xfd00}\x00\x00&0\x02\x830!000000000",
"(\xb5/\xfd\x1002000$\x05\x0010\xcc0\xa8100000000100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
"(\xb5/\xfd\x1002000$\x05\x0000\xcc0\xa8100d\x0000001000000000000000000000000000000000000000000000000000000000000000000000000\x000000000000000000000000000000000000000000000000000000000000000000000000000000",
"(\xb5/\xfd001\x00\x0000000000000000000",
}
// FuzzReader feeds arbitrary bytes to the decompressor to see whether
// it panics. The seed corpus is the compressed form of every regular
// test case plus the inputs recorded in badStrings. The decompressed
// output and any error are discarded.
func FuzzReader(f *testing.F) {
	for _, tc := range tests {
		f.Add([]byte(tc.compressed))
	}
	for _, bad := range badStrings {
		f.Add([]byte(bad))
	}

	f.Fuzz(func(t *testing.T, input []byte) {
		zr := NewReader(bytes.NewReader(input))
		io.Copy(io.Discard, zr)
	})
}
// FuzzDecompressor checks that what we decompress is what was
// compressed: each input is compressed with the zstd program and the
// result is decompressed with this package and compared to the input.
// This isn't a great fuzz test because the fuzzer can't efficiently
// explore the space of decompressor behavior, since it can't see
// what the compressor is doing. But it's better than nothing.
// The test is skipped if /usr/bin/zstd does not exist.
func FuzzDecompressor(f *testing.F) {
	const zstdPath = "/usr/bin/zstd"

	if _, err := os.Stat(zstdPath); err != nil {
		f.Skip("skipping because /usr/bin/zstd does not exist")
	}

	for _, tc := range tests {
		f.Add([]byte(tc.uncompressed))
	}

	// Add some larger data, as that has more interesting compression.
	f.Add(bytes.Repeat([]byte("abcdefghijklmnop"), 256))
	allBytes := make([]byte, 256)
	for i := range allBytes {
		allBytes[i] = byte(i)
	}
	f.Add(bytes.Repeat(allBytes, 64))
	f.Add(bigData(f))

	f.Fuzz(func(t *testing.T, want []byte) {
		var compressed bytes.Buffer
		cmd := exec.Command(zstdPath, "-z")
		cmd.Stdin = bytes.NewReader(want)
		cmd.Stdout = &compressed
		cmd.Stderr = os.Stderr
		if err := cmd.Run(); err != nil {
			t.Errorf("running zstd failed: %v", err)
		}

		zr := NewReader(bytes.NewReader(compressed.Bytes()))
		got, err := io.ReadAll(zr)
		if err != nil {
			t.Fatal(err)
		}
		if bytes.Equal(got, want) {
			return
		}
		showDiffs(t, got, want)
	})
}
// FuzzReverse checks that if we can decompress some data, so can the
// zstd program, and that both produce the same result. Each input is
// decompressed with this package and with /usr/bin/zstd and the
// outputs are compared. The test is skipped if /usr/bin/zstd does
// not exist.
func FuzzReverse(f *testing.F) {
	if _, err := os.Stat("/usr/bin/zstd"); err != nil {
		f.Skip("skipping because /usr/bin/zstd does not exist")
	}

	for _, tc := range tests {
		f.Add([]byte(tc.compressed))
	}

	// Set a hook to reject some cases where we don't match zstd.
	fuzzing = true
	defer func() { fuzzing = false }()

	f.Fuzz(func(t *testing.T, b []byte) {
		goExp, goErr := io.ReadAll(NewReader(bytes.NewReader(b)))

		var uncompressed bytes.Buffer
		cmd := exec.Command("/usr/bin/zstd", "-d")
		cmd.Stdin = bytes.NewReader(b)
		cmd.Stdout = &uncompressed
		cmd.Stderr = os.Stderr
		zstdErr := cmd.Run()
		zstdExp := uncompressed.Bytes()

		if goErr == nil && zstdErr == nil {
			// Both succeeded: the outputs must be identical.
			if !bytes.Equal(zstdExp, goExp) {
				showDiffs(t, zstdExp, goExp)
			}
			return
		}

		// Ideally we should check that this package and
		// the zstd program both fail or both succeed,
		// and that if they both fail one byte sequence
		// is an exact prefix of the other.
		// Actually trying this proved to be frustrating,
		// as the zstd program appears to accept invalid
		// byte sequences using rules that are difficult
		// to determine.
		// So we just check the prefix.
		n := len(goExp)
		if len(zstdExp) < n {
			n = len(zstdExp)
		}
		goExp, zstdExp = goExp[:n], zstdExp[:n]
		if bytes.Equal(goExp, zstdExp) {
			return
		}
		t.Error("byte mismatch after error")
		t.Logf("Go error: %v\n", goErr)
		t.Logf("zstd error: %v\n", zstdErr)
		showDiffs(t, zstdExp, goExp)
	})
}

204
src/internal/zstd/huff.go Normal file
View File

@ -0,0 +1,204 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zstd
import (
"io"
"math/bits"
)
// maxHuffmanBits is the largest number of index bits a Huffman
// table may use; a table therefore has at most 1<<maxHuffmanBits
// entries.
const (
	maxHuffmanBits = 11
)
// readHuff reads a Huffman table from data starting at off into table.
// Each entry in a Huffman table is a pair of bytes.
// The high byte is the encoded value. The low byte is the number
// of bits used to encode that value. We index into the table
// with a value of size tableBits. A value that requires fewer bits
// appears in the table multiple times.
// This returns the number of bits in the Huffman table and the new offset.
// It returns an error if the data is truncated, if the weights are
// invalid, or if table has fewer than 1<<tableBits entries.
// RFC 4.2.1.
func (r *Reader) readHuff(data block, off int, table []uint16) (tableBits, roff int, err error) {
if off >= len(data) {
return 0, 0, r.makeEOFError(off)
}
// The header byte selects how the weights are stored:
// below 128 they are FSE compressed and hdr is the number of
// bytes used; otherwise they are stored directly as 4-bit values.
hdr := data[off]
off++
// weights[i] is the weight of symbol i; count is how many are set.
var weights [256]uint8
var count int
if hdr < 128 {
// The table is compressed using an FSE. RFC 4.2.1.2.
if len(r.fseScratch) < 1<<6 {
r.fseScratch = make([]fseEntry, 1<<6)
}
fseBits, noff, err := r.readFSE(data, off, 255, 6, r.fseScratch)
if err != nil {
return 0, 0, err
}
fseTable := r.fseScratch
if off+int(hdr) > len(data) {
return 0, 0, r.makeEOFError(off)
}
// The weights are read backward, from the last of the hdr bytes
// down to noff, the offset returned by readFSE.
rbr, err := r.makeReverseBitReader(data, off+int(hdr)-1, noff)
if err != nil {
return 0, 0, err
}
state1, err := rbr.val(uint8(fseBits))
if err != nil {
return 0, 0, err
}
state2, err := rbr.val(uint8(fseBits))
if err != nil {
return 0, 0, err
}
// There are two independent FSE streams, tracked by
// state1 and state2. We decode them alternately.
for {
pt := &fseTable[state1]
if !rbr.fetch(pt.bits) {
// Out of bits: emit the symbol of each state and stop.
if count >= 254 {
return 0, 0, rbr.makeError("Huffman count overflow")
}
weights[count] = pt.sym
weights[count+1] = fseTable[state2].sym
count += 2
break
}
v, err := rbr.val(pt.bits)
if err != nil {
return 0, 0, err
}
state1 = uint32(pt.base) + v
if count >= 255 {
return 0, 0, rbr.makeError("Huffman count overflow")
}
weights[count] = pt.sym
count++
pt = &fseTable[state2]
if !rbr.fetch(pt.bits) {
// Out of bits: emit the symbol of each state and stop.
if count >= 254 {
return 0, 0, rbr.makeError("Huffman count overflow")
}
weights[count] = pt.sym
weights[count+1] = fseTable[state1].sym
count += 2
break
}
v, err = rbr.val(pt.bits)
if err != nil {
return 0, 0, err
}
state2 = uint32(pt.base) + v
if count >= 255 {
return 0, 0, rbr.makeError("Huffman count overflow")
}
weights[count] = pt.sym
count++
}
// Skip past the hdr bytes that held the compressed weights.
off += int(hdr)
} else {
// The table is not compressed. Each weight is 4 bits.
count = int(hdr) - 127
if off+((count+1)/2) >= len(data) {
return 0, 0, io.ErrUnexpectedEOF
}
// Each byte holds two weights, the high nibble first.
// When count is odd the last low nibble lands in weights[count],
// which is overwritten below with the implied final weight.
for i := 0; i < count; i += 2 {
b := data[off]
off++
weights[i] = b >> 4
weights[i+1] = b & 0xf
}
}
// RFC 4.2.1.3.
// weightMark[w] counts the symbols of weight w, and weightMask
// accumulates 1<<(w-1) for every nonzero weight.
var weightMark [13]uint32
weightMask := uint32(0)
for _, w := range weights[:count] {
if w > 12 {
return 0, 0, r.makeError(off, "Huffman weight overflow")
}
weightMark[w]++
if w > 0 {
weightMask += 1 << (w - 1)
}
}
if weightMask == 0 {
return 0, 0, r.makeError(off, "bad Huffman weights")
}
// tableBits is the bit length of weightMask.
tableBits = 32 - bits.LeadingZeros32(weightMask)
if tableBits > maxHuffmanBits {
return 0, 0, r.makeError(off, "bad Huffman weights")
}
if len(table) < 1<<tableBits {
return 0, 0, r.makeError(off, "Huffman table too small")
}
// Work out the last weight value, which is omitted because
// the weights must sum to a power of two.
left := (uint32(1) << tableBits) - weightMask
if left == 0 {
return 0, 0, r.makeError(off, "bad Huffman weights")
}
// The remainder must itself be a single power of two.
highBit := 31 - bits.LeadingZeros32(left)
if uint32(1)<<highBit != left {
return 0, 0, r.makeError(off, "bad Huffman weights")
}
if count >= 256 {
return 0, 0, r.makeError(off, "Huffman weight overflow")
}
weights[count] = uint8(highBit + 1)
count++
weightMark[highBit+1]++
// The number of symbols of weight 1 must be even and at least two.
if weightMark[1] < 2 || weightMark[1]&1 != 0 {
return 0, 0, r.makeError(off, "bad Huffman weights")
}
// Change weightMark from a count of weights to the index of
// the first symbol for that weight. We shift the indexes to
// also store how many we have seen so far.
next := uint32(0)
for i := 0; i < tableBits; i++ {
cur := next
next += weightMark[i+1] << i
weightMark[i+1] = cur
}
// Fill in the table. A symbol of weight w occupies 1<<(w-1)
// consecutive entries, each holding the symbol in the high byte
// and tableBits+1-w in the low byte.
for i, w := range weights[:count] {
if w == 0 {
continue
}
length := uint32(1) << (w - 1)
tval := uint16(i)<<8 | (uint16(tableBits) + 1 - uint16(w))
start := weightMark[w]
for j := uint32(0); j < length; j++ {
table[start+j] = tval
}
weightMark[w] += length
}
return tableBits, off, nil
}

View File

@ -0,0 +1,330 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zstd
import (
"encoding/binary"
)
// readLiterals decodes the literals section of data that starts at off.
// The decoded literals are appended to outbuf, and the extended slice
// is returned together with the input offset just past the section.
// RFC 3.1.1.3.1.
func (r *Reader) readLiterals(data block, off int, outbuf []byte) (int, []byte, error) {
	if off >= len(data) {
		return 0, nil, r.makeEOFError(off)
	}

	// Literals section header. RFC 3.1.1.3.1.1.
	hdr := data[off]
	off++

	// The low two bits of the header give the literals block type.
	switch hdr & 3 {
	case 0, 1:
		return r.readRawRLELiterals(data, off, hdr, outbuf)
	default:
		return r.readHuffLiterals(data, off, hdr, outbuf)
	}
}
// readRawRLELiterals decodes a Raw_Literals_Block or an
// RLE_Literals_Block. hdr is the first header byte, which has already
// been consumed; off points at the byte after it. The literals are
// appended to outbuf, and the new input offset is returned with it.
// RFC 3.1.1.3.1.1.
func (r *Reader) readRawRLELiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) {
	isRaw := hdr&3 == 0

	// The size format decides how many additional header bytes
	// contribute to the regenerated size.
	var regeneratedSize int
	switch (hdr >> 2) & 3 {
	case 1:
		if off >= len(data) {
			return 0, nil, r.makeEOFError(off)
		}
		regeneratedSize = int(hdr>>4) + (int(data[off]) << 4)
		off++
	case 3:
		if off+1 >= len(data) {
			return 0, nil, r.makeEOFError(off)
		}
		regeneratedSize = int(hdr>>4) + (int(data[off]) << 4) + (int(data[off+1]) << 12)
		off += 2
	default:
		// Formats 0 and 2: the size fits in the first header byte.
		regeneratedSize = int(hdr >> 3)
	}

	// We are going to use the entire literal block in the output.
	// The maximum size of one decompressed block is 128K,
	// so we can't have more literals than that.
	if regeneratedSize > 128<<10 {
		return 0, nil, r.makeError(off, "literal size too large")
	}

	if !isRaw {
		// RFC 3.1.1.3.1.3: a single byte repeated regeneratedSize times.
		if off >= len(data) {
			return 0, nil, r.makeError(off, "RLE literal missing")
		}
		fill := data[off]
		off++
		for n := regeneratedSize; n > 0; n-- {
			outbuf = append(outbuf, fill)
		}
		return off, outbuf, nil
	}

	// RFC 3.1.1.3.1.2: the literals are stored verbatim.
	end := off + regeneratedSize
	if end > len(data) {
		return 0, nil, r.makeError(off, "raw literal size too large")
	}
	outbuf = append(outbuf, data[off:end]...)
	return end, outbuf, nil
}
// readHuffLiterals decodes a Compressed_Literals_Block or a
// Treeless_Literals_Block. hdr is the first header byte, which has
// already been consumed; off points at the byte after it. The
// literals are appended to outbuf, and the offset just past the
// compressed literals is returned with it. RFC 3.1.1.3.1.4.
func (r *Reader) readHuffLiterals(data block, off int, hdr byte, outbuf []byte) (int, []byte, error) {
	var regeneratedSize, compressedSize int

	// The size format selects the width of the two size fields.
	// Only format 0 uses a single stream; the others use four.
	sizeFormat := (hdr >> 2) & 3
	fourStreams := sizeFormat != 0
	low := int(hdr) >> 4

	switch sizeFormat {
	case 0, 1:
		if off+1 >= len(data) {
			return 0, nil, r.makeEOFError(off)
		}
		b0, b1 := int(data[off]), int(data[off+1])
		regeneratedSize = low | ((b0 & 0x3f) << 4)
		compressedSize = (b0 >> 6) | (b1 << 2)
		off += 2
	case 2:
		if off+2 >= len(data) {
			return 0, nil, r.makeEOFError(off)
		}
		b0, b1, b2 := int(data[off]), int(data[off+1]), int(data[off+2])
		regeneratedSize = low | (b0 << 4) | ((b1 & 3) << 12)
		compressedSize = (b1 >> 2) | (b2 << 6)
		off += 3
	case 3:
		if off+3 >= len(data) {
			return 0, nil, r.makeEOFError(off)
		}
		b0, b1 := int(data[off]), int(data[off+1])
		b2, b3 := int(data[off+2]), int(data[off+3])
		regeneratedSize = low | (b0 << 4) | ((b1 & 0x3f) << 12)
		compressedSize = (b1 >> 6) | (b2 << 2) | (b3 << 10)
		off += 4
	}

	// We are going to use the entire literal block in the output.
	// The maximum size of one decompressed block is 128K,
	// so we can't have more literals than that.
	if regeneratedSize > 128<<10 {
		return 0, nil, r.makeError(off, "literal size too large")
	}

	// roff is where the input continues after the compressed literals.
	roff := off + compressedSize
	if roff > len(data) || roff < 0 {
		return 0, nil, r.makeEOFError(off)
	}

	totalStreamsSize := compressedSize
	if (hdr & 3) != 2 {
		// Treeless_Literals_Block
		// Reuse previous Huffman tree.
		if r.huffmanTableBits == 0 {
			return 0, nil, r.makeError(off, "missing literals Huffman tree")
		}
	} else {
		// Compressed_Literals_Block.
		// Read new huffman tree.
		if len(r.huffmanTable) < 1<<maxHuffmanBits {
			r.huffmanTable = make([]uint16, 1<<maxHuffmanBits)
		}
		tableBits, hoff, err := r.readHuff(data, off, r.huffmanTable)
		if err != nil {
			return 0, nil, err
		}
		r.huffmanTableBits = tableBits

		// The tree description counts against the compressed size.
		treeSize := hoff - off
		if totalStreamsSize < treeSize {
			return 0, nil, r.makeError(off, "Huffman table too big")
		}
		totalStreamsSize -= treeSize
		off = hoff
	}

	// Decompress the remaining bytes of data at off using the
	// Huffman tree.
	var err error
	if fourStreams {
		outbuf, err = r.readLiteralsFourStreams(data, off, totalStreamsSize, regeneratedSize, outbuf)
	} else {
		outbuf, err = r.readLiteralsOneStream(data, off, totalStreamsSize, regeneratedSize, outbuf)
	}
	if err != nil {
		return 0, nil, err
	}

	return roff, outbuf, nil
}
// readLiteralsOneStream decodes regeneratedSize literals from a single
// Huffman-coded stream that occupies compressedSize bytes of data
// starting at off. The literals are appended to outbuf.
func (r *Reader) readLiteralsOneStream(data block, off, compressedSize, regeneratedSize int, outbuf []byte) ([]byte, error) {
	// We let the reverse bit reader read earlier bytes,
	// because the Huffman table ignores bits that it doesn't need.
	last := off + compressedSize - 1
	rbr, err := r.makeReverseBitReader(data, last, off-2)
	if err != nil {
		return nil, err
	}

	table := r.huffmanTable
	tableBits := uint32(r.huffmanTableBits)
	mask := (uint32(1) << tableBits) - 1

	for remaining := regeneratedSize; remaining > 0; remaining-- {
		if !rbr.fetch(uint8(tableBits)) {
			return nil, rbr.makeError("literals Huffman stream out of bits")
		}

		// Look up the next tableBits bits. The entry holds the
		// literal in its high byte and the number of bits actually
		// consumed in its low byte.
		entry := table[(rbr.bits>>(rbr.cnt-tableBits))&mask]
		outbuf = append(outbuf, byte(entry>>8))
		rbr.cnt -= uint32(entry & 0xff)
	}

	return outbuf, nil
}
// readLiteralsFourStreams reads four interleaved streams of
// compressed literals. The streams occupy totalStreamsSize bytes of
// data starting at off, beginning with a 6-byte jump table, and
// together regenerate regeneratedSize bytes that are appended to
// outbuf.
//
// It returns an error if the input is truncated, if the jump table is
// inconsistent with totalStreamsSize, if regeneratedSize is too small
// to be split over four streams, or if a stream runs out of bits.
func (r *Reader) readLiteralsFourStreams(data block, off, totalStreamsSize, regeneratedSize int, outbuf []byte) ([]byte, error) {
	// Read the jump table to find out where the streams are.
	// RFC 3.1.1.3.1.6.
	if off+5 >= len(data) {
		return nil, r.makeEOFError(off)
	}
	if totalStreamsSize < 6 {
		return nil, r.makeError(off, "total streams size too small for jump table")
	}

	// Each of the first three streams regenerates
	// (regeneratedSize+3)/4 bytes and the fourth supplies the rest.
	// If the first three alone would exceed regeneratedSize, the
	// writes below would run past the space reserved in outbuf,
	// so reject such input up front.
	regeneratedStreamSize := (regeneratedSize + 3) / 4
	if regeneratedSize < regeneratedStreamSize*3 {
		return nil, r.makeError(off, "regenerated size too small to decode streams")
	}

	streamSize1 := binary.LittleEndian.Uint16(data[off:])
	streamSize2 := binary.LittleEndian.Uint16(data[off+2:])
	streamSize3 := binary.LittleEndian.Uint16(data[off+4:])
	off += 6

	tot := uint64(streamSize1) + uint64(streamSize2) + uint64(streamSize3)
	if tot > uint64(totalStreamsSize)-6 {
		return nil, r.makeEOFError(off)
	}
	// The fourth stream takes whatever the jump table leaves over.
	streamSize4 := uint32(totalStreamsSize) - 6 - uint32(tot)

	// offN is the last byte of stream N and startN is its first byte.
	off--
	off1 := off + int(streamSize1)
	start1 := off + 1

	off2 := off1 + int(streamSize2)
	start2 := off1 + 1

	off3 := off2 + int(streamSize3)
	start3 := off2 + 1

	off4 := off3 + int(streamSize4)
	start4 := off3 + 1

	// We let the reverse bit readers read earlier bytes,
	// because the Huffman tables ignore bits that they don't need.

	rbr1, err := r.makeReverseBitReader(data, off1, start1-2)
	if err != nil {
		return nil, err
	}

	rbr2, err := r.makeReverseBitReader(data, off2, start2-2)
	if err != nil {
		return nil, err
	}

	rbr3, err := r.makeReverseBitReader(data, off3, start3-2)
	if err != nil {
		return nil, err
	}

	rbr4, err := r.makeReverseBitReader(data, off4, start4-2)
	if err != nil {
		return nil, err
	}

	// outN is the next output position for stream N.
	out1 := len(outbuf)
	out2 := out1 + regeneratedStreamSize
	out3 := out2 + regeneratedStreamSize
	out4 := out3 + regeneratedStreamSize

	// Non-negative thanks to the check above.
	regeneratedStreamSize4 := regeneratedSize - regeneratedStreamSize*3

	outbuf = append(outbuf, make([]byte, regeneratedSize)...)

	huffTable := r.huffmanTable
	huffBits := uint32(r.huffmanTableBits)
	huffMask := (uint32(1) << huffBits) - 1

	// fetchHuff looks up the table entry for the next huffBits bits
	// of rbr without consuming them; the caller subtracts the low
	// byte of the entry from rbr.cnt.
	fetchHuff := func(rbr *reverseBitReader) (uint16, error) {
		if !rbr.fetch(uint8(huffBits)) {
			return 0, rbr.makeError("literals Huffman stream out of bits")
		}
		idx := (rbr.bits >> (rbr.cnt - huffBits)) & huffMask
		return huffTable[idx], nil
	}

	for i := 0; i < regeneratedStreamSize; i++ {
		// The fourth stream may be shorter than the other three.
		use4 := i < regeneratedStreamSize4

		t1, err := fetchHuff(&rbr1)
		if err != nil {
			return nil, err
		}

		t2, err := fetchHuff(&rbr2)
		if err != nil {
			return nil, err
		}

		t3, err := fetchHuff(&rbr3)
		if err != nil {
			return nil, err
		}

		if use4 {
			t4, err := fetchHuff(&rbr4)
			if err != nil {
				return nil, err
			}
			outbuf[out4] = byte(t4 >> 8)
			out4++
			rbr4.cnt -= uint32(t4 & 0xff)
		}

		outbuf[out1] = byte(t1 >> 8)
		out1++
		rbr1.cnt -= uint32(t1 & 0xff)

		outbuf[out2] = byte(t2 >> 8)
		out2++
		rbr2.cnt -= uint32(t2 & 0xff)

		outbuf[out3] = byte(t3 >> 8)
		out3++
		rbr3.cnt -= uint32(t3 & 0xff)
	}

	return outbuf, nil
}

148
src/internal/zstd/xxhash.go Normal file
View File

@ -0,0 +1,148 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zstd
import (
"encoding/binary"
"math/bits"
)
// The five prime constants used by the 64-bit xxHash computation.
const (
	xxhPrime64c1 = 0x9e3779b185ebca87
	xxhPrime64c2 = 0xc2b2ae3d27d4eb4f
	xxhPrime64c3 = 0x165667b19e3779f9
	xxhPrime64c4 = 0x85ebca77c2b2ae63
	xxhPrime64c5 = 0x27d4eb2f165667c5
)

// xxhash64 is the state of a xxHash-64 checksum.
// Call reset before the first update.
type xxhash64 struct {
	len uint64    // total length hashed
	v   [4]uint64 // accumulators
	buf [32]byte  // input not yet folded into the accumulators
	cnt int       // number of valid bytes in buf
}

// reset discards the current state and prepares to compute a new hash.
// We assume a seed of 0 since that is what zstd uses.
func (xh *xxhash64) reset() {
	// The sum and the negation are computed in uint64 variables
	// because the equivalent constant expressions overflow.
	var sum uint64 = xxhPrime64c1
	sum += xxhPrime64c2
	var neg uint64 = xxhPrime64c1
	neg = -neg

	// Assigning a fresh value also clears len, buf and cnt.
	*xh = xxhash64{
		v: [4]uint64{sum, xxhPrime64c2, 0, neg},
	}
}

// update adds a buffer to the hash.
func (xh *xxhash64) update(b []byte) {
	xh.len += uint64(len(b))

	// Too little data to complete a 32-byte stripe: just buffer it.
	if xh.cnt+len(b) < len(xh.buf) {
		xh.cnt += copy(xh.buf[xh.cnt:], b)
		return
	}

	// Top up a partially filled buffer and consume it.
	if xh.cnt > 0 {
		n := copy(xh.buf[xh.cnt:], b)
		b = b[n:]
		xh.stripe(xh.buf[:])
		xh.cnt = 0
	}

	// Consume whole stripes directly from the input.
	for ; len(b) >= 32; b = b[32:] {
		xh.stripe(b)
	}

	// Keep the leftover for the next update or for digest.
	if len(b) > 0 {
		xh.cnt = copy(xh.buf[:], b)
	}
}

// stripe folds the first 32 bytes of p into the four accumulators,
// eight bytes per accumulator.
func (xh *xxhash64) stripe(p []byte) {
	for i := range xh.v {
		xh.v[i] = xh.round(xh.v[i], binary.LittleEndian.Uint64(p[8*i:]))
	}
}

// digest returns the final hash value.
// It does not modify the state.
func (xh *xxhash64) digest() uint64 {
	var h uint64
	if xh.len >= 32 {
		h = bits.RotateLeft64(xh.v[0], 1) +
			bits.RotateLeft64(xh.v[1], 7) +
			bits.RotateLeft64(xh.v[2], 12) +
			bits.RotateLeft64(xh.v[3], 18)
		for _, acc := range xh.v {
			h = xh.mergeRound(h, acc)
		}
	} else {
		h = xh.v[2] + xxhPrime64c5
	}

	h += xh.len

	// Mix in the buffered tail: first 8 bytes at a time,
	// then at most one group of 4, then single bytes.
	tail := xh.buf[:xh.len&31]
	for ; len(tail) >= 8; tail = tail[8:] {
		h ^= xh.round(0, binary.LittleEndian.Uint64(tail))
		h = bits.RotateLeft64(h, 27)*xxhPrime64c1 + xxhPrime64c4
	}
	if len(tail) >= 4 {
		h ^= uint64(binary.LittleEndian.Uint32(tail)) * xxhPrime64c1
		h = bits.RotateLeft64(h, 23)*xxhPrime64c2 + xxhPrime64c3
		tail = tail[4:]
	}
	for _, c := range tail {
		h ^= uint64(c) * xxhPrime64c5
		h = bits.RotateLeft64(h, 11) * xxhPrime64c1
	}

	// Final mixing of the bits.
	h ^= h >> 33
	h *= xxhPrime64c2
	h ^= h >> 29
	h *= xxhPrime64c3
	h ^= h >> 32

	return h
}

// round mixes the input word n into the accumulator v
// and returns the new accumulator value.
func (xh *xxhash64) round(v, n uint64) uint64 {
	return bits.RotateLeft64(v+n*xxhPrime64c2, 31) * xxhPrime64c1
}

// mergeRound folds the accumulator n into the hash v
// during the final round and returns the new hash value.
func (xh *xxhash64) mergeRound(v, n uint64) uint64 {
	return (v^xh.round(0, n))*xxhPrime64c1 + xxhPrime64c4
}

View File

@ -0,0 +1,105 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zstd
import (
"bytes"
"os"
"os/exec"
"strconv"
"testing"
)
// xxHashTests pairs input strings with the 64-bit hash
// that xxhash64 is expected to produce for them.
var xxHashTests = []struct {
	data string
	hash uint64
}{
	{data: "hello, world", hash: 0xb33a384e6d1b1242},
	{data: "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789$", hash: 0x1032d841e824f998},
}
// TestXXHash checks xxhash64 against every entry in xxHashTests,
// reusing a single state that is reset before each entry.
func TestXXHash(t *testing.T) {
	var xh xxhash64
	for i := range xxHashTests {
		tc := &xxHashTests[i]
		xh.reset()
		xh.update([]byte(tc.data))
		got := xh.digest()
		if got == tc.hash {
			continue
		}
		t.Errorf("#%d: got %#x want %#x", i, got, tc.hash)
	}
}
// TestLargeXXHash hashes the data returned by bigData in pieces of
// varying size and compares the result with a fixed expected hash.
// It is skipped in short mode.
func TestLargeXXHash(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping expensive test in short mode")
	}

	data := bigData(t)

	var xh xxhash64
	xh.reset()

	// Write varying amounts to test buffering.
	for pos := 0; pos < len(data); {
		end := pos + pos%4094 + 1
		if end > len(data) {
			end = len(data)
		}
		xh.update(data[pos:end])
		pos = end
	}

	const want = uint64(0xf0dd39fd7e063f82)
	if got := xh.digest(); got != want {
		t.Errorf("got %#x want %#x", got, want)
	}
}
// FuzzXXHash is a fuzz test that compares xxhash64 with the 64-bit
// hash printed by the xxhsum program for the same input.
// The test is skipped when /usr/bin/xxhsum does not exist.
func FuzzXXHash(f *testing.F) {
	if _, err := os.Stat("/usr/bin/xxhsum"); err != nil {
		f.Skip("skipping because /usr/bin/xxhsum does not exist")
	}

	for _, test := range xxHashTests {
		f.Add([]byte(test.data))
	}
	f.Add(bytes.Repeat([]byte("abcdefghijklmnop"), 256))
	var buf bytes.Buffer
	for i := 0; i < 256; i++ {
		buf.WriteByte(byte(i))
	}
	f.Add(bytes.Repeat(buf.Bytes(), 64))
	f.Add(bigData(f))

	f.Fuzz(func(t *testing.T, b []byte) {
		cmd := exec.Command("/usr/bin/xxhsum", "-H64")
		cmd.Stdin = bytes.NewReader(b)
		var xxhsumOut bytes.Buffer
		cmd.Stdout = &xxhsumOut
		if err := cmd.Run(); err != nil {
			t.Fatalf("running xxhsum failed: %v", err)
		}

		// The hash is the first whitespace-separated field of the
		// output. Guard against empty output so that a misbehaving
		// xxhsum is reported as a failure rather than a panic.
		fields := bytes.Fields(xxhsumOut.Bytes())
		if len(fields) == 0 {
			t.Fatalf("xxhsum produced no output")
		}
		hashText := fields[0]
		xxhsumHash, err := strconv.ParseUint(string(hashText), 16, 64)
		if err != nil {
			t.Fatalf("could not parse hash %q: %v", hashText, err)
		}

		var xh xxhash64
		xh.reset()
		xh.update(b)
		goHash := xh.digest()

		if goHash != xxhsumHash {
			t.Errorf("Go hash %#x != xxhsum hash %#x", goHash, xxhsumHash)
		}
	})
}

508
src/internal/zstd/zstd.go Normal file
View File

@ -0,0 +1,508 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package zstd provides a decompressor for zstd streams,
// described in RFC 8878. It does not support dictionaries.
package zstd
import (
"encoding/binary"
"errors"
"fmt"
"io"
)
// fuzzing is a hook that fuzz tests set to true while they run.
// When set, readFrameHeader rejects some window sizes so that
// this package agrees with the zstd program. It is false by default.
var fuzzing bool
// Reader implements [io.Reader] to read a zstd compressed stream.
// Use [NewReader] to create one and [Reader.Reset] to reuse it.
type Reader struct {
	// r is the underlying source of compressed data.
	r io.Reader

	// sawFrameHeader reports whether the header of the current
	// frame has been read. It matters when buffer is exhausted:
	// if true the next thing in the input is a block, otherwise
	// a frame header.
	sawFrameHeader bool

	// hasChecksum reports whether the current frame ends with
	// a checksum.
	hasChecksum bool

	// readOneFrame reports whether at least one frame header
	// has been read from the stream.
	readOneFrame bool

	// frameSizeUnknown is true if the frame header did not
	// declare the uncompressed size of the frame.
	frameSizeUnknown bool

	// remainingFrameSize is the number of uncompressed bytes
	// still expected in the current frame. It is not valid if
	// frameSizeUnknown is true.
	remainingFrameSize uint64

	// blockOffset is the number of bytes read from r up to the
	// start of the current block, for error reporting.
	blockOffset int64

	// buffer holds decompressed data that has not all been
	// returned yet; off is the current read offset into it.
	buffer []byte
	off    int

	// The current repeated offsets.
	repeatedOffset1, repeatedOffset2, repeatedOffset3 uint32

	// huffmanTable is the current Huffman table used to decode
	// literals and huffmanTableBits is its number of index bits.
	// huffmanTableBits is zero when the frame has no table yet.
	huffmanTable     []uint16
	huffmanTableBits int

	// The window for back references.
	windowSize int    // maximum required window size
	window     []byte // window data

	// compressedBuf is a buffer available to hold a compressed block.
	compressedBuf []byte

	// literals is a buffer for literals.
	literals []byte

	// Sequence decode FSE tables.
	seqTables    [3][]fseBaselineEntry
	seqTableBits [3]uint8

	// seqTableBuffers are buffers for sequence decode FSE tables.
	seqTableBuffers [3][]fseBaselineEntry

	// scratch is space used for small reads, to avoid allocation.
	scratch [16]byte

	// fseScratch is a scratch table for reading an FSE.
	// Only temporarily valid.
	fseScratch []fseEntry

	// checksum is the running checksum of the current frame.
	checksum xxhash64
}
// NewReader returns a Reader that decompresses the zstd stream
// read from input.
func NewReader(input io.Reader) *Reader {
	var r Reader
	r.Reset(input)
	return &r
}
// Reset discards the current state and starts reading a new stream
// from input. This permits reusing a Reader rather than allocating
// a new one. Decompressed data from the previous stream that has
// not been read yet is dropped.
func (r *Reader) Reset(input io.Reader) {
	r.r = input

	// Several fields are preserved to avoid allocation.
	// Others are always set before they are used.
	r.sawFrameHeader = false
	r.hasChecksum = false
	r.readOneFrame = false
	r.frameSizeUnknown = false
	r.remainingFrameSize = 0
	r.blockOffset = 0
	// Empty the buffer, keeping its capacity. Otherwise the next
	// Read would find off < len(buffer) and return bytes that were
	// decompressed from the previous stream.
	r.buffer = r.buffer[:0]
	r.off = 0
	// repeatedOffset1
	// repeatedOffset2
	// repeatedOffset3
	// huffmanTable
	// huffmanTableBits
	// windowSize
	// window
	// compressedBuf
	// literals
	// seqTables
	// seqTableBits
	// seqTableBuffers
	// scratch
	// fseScratch
}
// Read implements [io.Reader].
// It decompresses more input when no decompressed data is pending
// and then copies as much pending data as fits into p.
func (r *Reader) Read(p []byte) (int, error) {
	err := r.refillIfNeeded()
	if err != nil {
		return 0, err
	}
	pending := r.buffer[r.off:]
	n := copy(p, pending)
	r.off += n
	return n, nil
}
// ReadByte implements [io.ByteReader].
// It decompresses more input when no decompressed data is pending
// and then returns the next pending byte.
func (r *Reader) ReadByte() (byte, error) {
	err := r.refillIfNeeded()
	if err != nil {
		return 0, err
	}
	c := r.buffer[r.off]
	r.off++
	return c, nil
}
// refillIfNeeded decompresses blocks until the buffer holds unread
// data. It loops because a block may decompress to nothing.
func (r *Reader) refillIfNeeded() error {
	for {
		if r.off < len(r.buffer) {
			return nil
		}
		if err := r.refill(); err != nil {
			return err
		}
		r.off = 0
	}
}
// refill reads and decompresses the next block,
// reading a frame header first if a new frame is starting.
func (r *Reader) refill() error {
	if r.sawFrameHeader {
		return r.readBlock()
	}
	if err := r.readFrameHeader(); err != nil {
		return err
	}
	return r.readBlock()
}
// readFrameHeader reads the frame header and prepares to read a block.
// Skippable frames that precede the header are skipped.
//
// It returns io.EOF, unwrapped, if the input ends cleanly before the
// magic number and at least one frame has been read. Other failures
// are reported as decompression errors carrying an input offset.
func (r *Reader) readFrameHeader() error {
retry:
	relativeOffset := 0

	// Read magic number. RFC 3.1.1.
	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
		// We require that the stream contain at least one frame.
		if err == io.EOF && !r.readOneFrame {
			err = io.ErrUnexpectedEOF
		}
		return r.wrapError(relativeOffset, err)
	}

	if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
		if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
			// This is a skippable frame.
			r.blockOffset += int64(relativeOffset) + 4
			if err := r.skipFrame(); err != nil {
				return err
			}
			goto retry
		}

		return r.makeError(relativeOffset, "invalid magic number")
	}

	relativeOffset += 4

	// Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
	if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}
	descriptor := r.scratch[0]

	singleSegment := descriptor&(1<<5) != 0

	// The top two bits select a Frame_Content_Size field of
	// 1, 2, 4 or 8 bytes; the 1-byte form is only present
	// in a single segment frame.
	fcsFieldSize := 1 << (descriptor >> 6)
	if fcsFieldSize == 1 && !singleSegment {
		fcsFieldSize = 0
	}

	// A single segment frame has no Window_Descriptor byte.
	var windowDescriptorSize int
	if singleSegment {
		windowDescriptorSize = 0
	} else {
		windowDescriptorSize = 1
	}

	if descriptor&(1<<3) != 0 {
		return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
	}

	r.hasChecksum = descriptor&(1<<2) != 0
	if r.hasChecksum {
		r.checksum.reset()
	}

	if descriptor&3 != 0 {
		return r.makeError(relativeOffset, "dictionaries are not supported")
	}

	relativeOffset++

	headerSize := windowDescriptorSize + fcsFieldSize

	if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}

	// Figure out the maximum amount of data we need to retain
	// for backreferences.
	var windowSize uint64
	if !singleSegment {
		// Window descriptor. RFC 3.1.1.1.2.
		windowDescriptor := r.scratch[0]
		exponent := uint64(windowDescriptor >> 3)
		mantissa := uint64(windowDescriptor & 7)
		windowLog := exponent + 10
		windowBase := uint64(1) << windowLog
		windowAdd := (windowBase / 8) * mantissa
		windowSize = windowBase + windowAdd

		// Default zstd sets limits on the window size.
		if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
			return r.makeError(relativeOffset, "windowSize too large")
		}
	}

	// Frame_Content_Size. RFC 3.1.1.4.
	r.frameSizeUnknown = false
	r.remainingFrameSize = 0
	fb := r.scratch[windowDescriptorSize:]
	switch fcsFieldSize {
	case 0:
		r.frameSizeUnknown = true
	case 1:
		r.remainingFrameSize = uint64(fb[0])
	case 2:
		r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
	case 4:
		r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
	case 8:
		r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
	default:
		panic("unreachable")
	}

	// RFC 3.1.1.1.2: when Single_Segment_Flag is set there is no
	// Window_Descriptor and Window_Size is Frame_Content_Size.
	// A single segment frame may still contain several blocks whose
	// matches refer back to earlier blocks, so the window must be
	// large enough to retain them.
	if singleSegment {
		windowSize = r.remainingFrameSize
	}

	// RFC 8878 permits us to set an 8M max on window size.
	if windowSize > 8<<20 {
		windowSize = 8 << 20
	}
	r.windowSize = int(windowSize)

	relativeOffset += headerSize

	r.sawFrameHeader = true
	r.readOneFrame = true
	r.blockOffset += int64(relativeOffset)

	// Prepare to read blocks from the frame.
	r.repeatedOffset1 = 1
	r.repeatedOffset2 = 4
	r.repeatedOffset3 = 8
	r.huffmanTableBits = 0
	r.window = r.window[:0]
	r.seqTables[0] = nil
	r.seqTables[1] = nil
	r.seqTables[2] = nil

	return nil
}
// skipFrame skips a skippable frame. RFC 3.1.2.
// The magic number has already been read; this reads the 4-byte
// frame size and then discards that many bytes, seeking if the
// underlying reader is an [io.Seeker] and reading otherwise.
func (r *Reader) skipFrame() error {
	relativeOffset := 0

	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}

	relativeOffset += 4

	size := binary.LittleEndian.Uint32(r.scratch[:4])

	if seeker, ok := r.r.(io.Seeker); ok {
		if _, err := seeker.Seek(int64(size), io.SeekCurrent); err != nil {
			return err
		}
		r.blockOffset += int64(relativeOffset) + int64(size)
		return nil
	}

	// Discard the frame in chunks to bound the memory used.
	var skip []byte
	const chunk = 1 << 20 // 1M
	for size >= chunk {
		if len(skip) == 0 {
			skip = make([]byte, chunk)
		}
		if _, err := io.ReadFull(r.r, skip); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset += chunk
		size -= chunk
	}
	if size > 0 {
		if len(skip) == 0 {
			skip = make([]byte, size)
		}
		// skip may still be a full chunk from the loop above, so
		// read only the bytes that remain in the skippable frame;
		// reading all of skip would consume the following frames.
		if _, err := io.ReadFull(r.r, skip[:size]); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset += int(size)
	}

	r.blockOffset += int64(relativeOffset)

	return nil
}
// readBlock reads the next block of the current frame and leaves its
// decompressed contents in r.buffer. After the last block of a frame
// it verifies the frame size and checksum, when the frame has them,
// and arranges for the next refill to read a new frame header.
func (r *Reader) readBlock() error {
	relativeOffset := 0

	// Read Block_Header. RFC 3.1.1.2.
	hdr := r.scratch[:3]
	if _, err := io.ReadFull(r.r, hdr); err != nil {
		return r.wrapNonEOFError(relativeOffset, err)
	}
	relativeOffset += 3

	header := uint32(hdr[0]) | (uint32(hdr[1]) << 8) | (uint32(hdr[2]) << 16)
	lastBlock := header&1 != 0
	blockType := (header >> 1) & 3
	blockSize := int(header >> 3)

	// Maximum block size is smaller of window size and 128K.
	// A window size of zero is not used as a limit.
	// RFC 3.1.1.2.3, 3.1.1.2.4.
	tooLarge := blockSize > 128<<10
	if r.windowSize > 0 && blockSize > r.windowSize {
		tooLarge = true
	}
	if tooLarge {
		return r.makeError(relativeOffset, "block size too large")
	}

	// Handle different block types. RFC 3.1.1.2.2.
	switch blockType {
	case 0:
		// The block content is stored as is.
		r.setBufferSize(blockSize)
		if _, err := io.ReadFull(r.r, r.buffer); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset += blockSize
		r.blockOffset += int64(relativeOffset)

	case 1:
		// A single byte repeated blockSize times.
		r.setBufferSize(blockSize)
		if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
			return r.wrapNonEOFError(relativeOffset, err)
		}
		relativeOffset++
		fill := r.scratch[0]
		for i := range r.buffer {
			r.buffer[i] = fill
		}
		r.blockOffset += int64(relativeOffset)

	case 2:
		// A compressed block of blockSize input bytes.
		r.blockOffset += int64(relativeOffset)
		if err := r.compressedBlock(blockSize); err != nil {
			return err
		}
		r.blockOffset += int64(blockSize)

	case 3:
		return r.makeError(relativeOffset, "invalid block type")
	}

	// Account for the decompressed bytes against the frame size.
	produced := uint64(len(r.buffer))
	if !r.frameSizeUnknown {
		if produced > r.remainingFrameSize {
			return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
		}
		r.remainingFrameSize -= produced
	}

	if r.hasChecksum {
		r.checksum.update(r.buffer)
	}

	if !lastBlock {
		// Later blocks may refer back to this one.
		r.saveWindow(r.buffer)
		return nil
	}

	if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
		return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
	}

	// Check for checksum at end of frame. RFC 3.1.1.
	if r.hasChecksum {
		if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
			return r.wrapNonEOFError(0, err)
		}

		// Only the low 32 bits of the hash are stored in the frame.
		inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
		dataChecksum := uint32(r.checksum.digest())
		if inputChecksum != dataChecksum {
			return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
		}

		r.blockOffset += 4
	}

	r.sawFrameHeader = false
	return nil
}
// setBufferSize makes r.buffer exactly size bytes long,
// growing its capacity if necessary.
// When this is called the buffer is empty.
func (r *Reader) setBufferSize(size int) {
	if have := cap(r.buffer); have < size {
		extra := make([]byte, size-have)
		r.buffer = append(r.buffer[:have], extra...)
	}
	r.buffer = r.buffer[:size]
}
// saveWindow saves bytes in the backreference window.
// After the call the window holds the most recent data,
// ending with buf, and is at most r.windowSize bytes long.
// It does nothing when r.windowSize is zero.
// TODO: use a circular buffer for less data movement.
func (r *Reader) saveWindow(buf []byte) {
	if r.windowSize == 0 {
		return
	}

	// buf alone fills the window: keep only its tail.
	if len(buf) >= r.windowSize {
		from := len(buf) - r.windowSize
		r.window = append(r.window[:0], buf[from:]...)
		return
	}

	keep := r.windowSize - len(buf) // must be positive
	if keep < len(r.window) {
		// Slide the newest keep bytes to the front and drop the
		// rest. Without the truncation the window would grow past
		// windowSize and stale bytes would sit between the kept
		// data and buf.
		remove := len(r.window) - keep
		copy(r.window[:], r.window[remove:])
		r.window = r.window[:keep]
	}

	r.window = append(r.window, buf...)
}
// zstdError is an error while decompressing.
// It records the input offset at which the error was detected
// together with the underlying error.
type zstdError struct {
	offset int64
	err    error
}

// Error implements the error interface. The message includes
// the offset and the text of the underlying error.
func (ze *zstdError) Error() string {
	offset, cause := ze.offset, ze.err
	return fmt.Sprintf("zstd decompression error at %d: %v", offset, cause)
}

// Unwrap returns the underlying error,
// for use with errors.Is and errors.As.
func (ze *zstdError) Unwrap() error {
	cause := ze.err
	return cause
}
// makeEOFError returns an unexpected-EOF decompression error.
// off is relative to the start of the current block.
func (r *Reader) makeEOFError(off int) error {
	var cause error = io.ErrUnexpectedEOF
	return r.wrapError(off, cause)
}
// wrapNonEOFError wraps err as a decompression error for a read
// where the end of input is not acceptable: io.EOF is reported as
// io.ErrUnexpectedEOF. off is relative to the start of the current
// block.
func (r *Reader) wrapNonEOFError(off int, err error) error {
	if err != io.EOF {
		return r.wrapError(off, err)
	}
	return r.wrapError(off, io.ErrUnexpectedEOF)
}
// makeError returns a decompression error with the message msg.
// off is relative to the start of the current block.
func (r *Reader) makeError(off int, msg string) error {
	cause := errors.New(msg)
	return r.wrapError(off, cause)
}
// wrapError wraps err in a zstdError whose offset is off bytes past
// the start of the current block. io.EOF is returned unchanged so
// that callers can still compare against it directly.
func (r *Reader) wrapError(off int, err error) error {
	if err != io.EOF {
		return &zstdError{r.blockOffset + int64(off), err}
	}
	return err
}

View File

@ -0,0 +1,249 @@
// Copyright 2023 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package zstd
import (
"bytes"
"fmt"
"internal/race"
"internal/testenv"
"io"
"os"
"os/exec"
"strings"
"sync"
"testing"
)
// tests holds some simple test cases, including some found by fuzzing.
//
// Each entry gives the subtest name, the expected decompressed contents,
// and the compressed input. TestSamples decompresses the compressed field
// and compares the result with the uncompressed field.
var tests = []struct {
// name labels the subtest; uncompressed is the expected output of
// decompressing compressed.
name, uncompressed, compressed string
}{
{
// a short text string.
"hello",
"hello, world\n",
"\x28\xb5\x2f\xfd\x24\x0d\x69\x00\x00\x68\x65\x6c\x6c\x6f\x2c\x20\x77\x6f\x72\x6c\x64\x0a\x4c\x1f\xf9\xf1",
},
{
// a small compressed .debug_ranges section.
"ranges",
"\xcc\x11\x00\x00\x00\x00\x00\x00\xd5\x13\x00\x00\x00\x00\x00\x00" +
"\x1c\x14\x00\x00\x00\x00\x00\x00\x72\x14\x00\x00\x00\x00\x00\x00" +
"\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
"\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" +
"\x0c\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" +
"\x29\x14\x00\x00\x00\x00\x00\x00\x4e\x14\x00\x00\x00\x00\x00\x00" +
"\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
"\xfb\x12\x00\x00\x00\x00\x00\x00\x09\x13\x00\x00\x00\x00\x00\x00" +
"\x67\x13\x00\x00\x00\x00\x00\x00\xcb\x13\x00\x00\x00\x00\x00\x00" +
"\x9d\x14\x00\x00\x00\x00\x00\x00\xd5\x14\x00\x00\x00\x00\x00\x00" +
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
"\x5f\x0b\x00\x00\x00\x00\x00\x00\x6c\x0b\x00\x00\x00\x00\x00\x00" +
"\x7d\x0b\x00\x00\x00\x00\x00\x00\x7e\x0c\x00\x00\x00\x00\x00\x00" +
"\x38\x0f\x00\x00\x00\x00\x00\x00\x5c\x0f\x00\x00\x00\x00\x00\x00" +
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
"\x83\x0c\x00\x00\x00\x00\x00\x00\xfa\x0c\x00\x00\x00\x00\x00\x00" +
"\xfd\x0d\x00\x00\x00\x00\x00\x00\xef\x0e\x00\x00\x00\x00\x00\x00" +
"\x14\x0f\x00\x00\x00\x00\x00\x00\x38\x0f\x00\x00\x00\x00\x00\x00" +
"\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" +
"\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" +
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
"\xfd\x0d\x00\x00\x00\x00\x00\x00\xd8\x0e\x00\x00\x00\x00\x00\x00" +
"\x9f\x0f\x00\x00\x00\x00\x00\x00\xac\x0f\x00\x00\x00\x00\x00\x00" +
"\xdb\x0f\x00\x00\x00\x00\x00\x00\xff\x0f\x00\x00\x00\x00\x00\x00" +
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
"\xfa\x0c\x00\x00\x00\x00\x00\x00\xea\x0d\x00\x00\x00\x00\x00\x00" +
"\xef\x0e\x00\x00\x00\x00\x00\x00\x14\x0f\x00\x00\x00\x00\x00\x00" +
"\x5c\x0f\x00\x00\x00\x00\x00\x00\x9f\x0f\x00\x00\x00\x00\x00\x00" +
"\xac\x0f\x00\x00\x00\x00\x00\x00\xdb\x0f\x00\x00\x00\x00\x00\x00" +
"\xff\x0f\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" +
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
"\x60\x11\x00\x00\x00\x00\x00\x00\xd1\x16\x00\x00\x00\x00\x00\x00" +
"\x40\x0b\x00\x00\x00\x00\x00\x00\x2c\x10\x00\x00\x00\x00\x00\x00" +
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
"\x7a\x00\x00\x00\x00\x00\x00\x00\xb6\x00\x00\x00\x00\x00\x00\x00" +
"\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" +
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" +
"\x7a\x00\x00\x00\x00\x00\x00\x00\xa9\x00\x00\x00\x00\x00\x00\x00" +
"\x9f\x01\x00\x00\x00\x00\x00\x00\xa7\x01\x00\x00\x00\x00\x00\x00" +
"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
"\x28\xb5\x2f\xfd\x64\xa0\x01\x2d\x05\x00\xc4\x04\xcc\x11\x00\xd5" +
"\x13\x00\x1c\x14\x00\x72\x9d\xd5\xfb\x12\x00\x09\x0c\x13\xcb\x13" +
"\x29\x4e\x67\x5f\x0b\x6c\x0b\x7d\x0b\x7e\x0c\x38\x0f\x5c\x0f\x83" +
"\x0c\xfa\x0c\xfd\x0d\xef\x0e\x14\x38\x9f\x0f\xac\x0f\xdb\x0f\xff" +
"\x0f\xd8\x9f\xac\xdb\xff\xea\x5c\x2c\x10\x60\xd1\x16\x40\x0b\x7a" +
"\x00\xb6\x00\x9f\x01\xa7\x01\xa9\x36\x20\xa0\x83\x14\x34\x63\x4a" +
"\x21\x70\x8c\x07\x46\x03\x4e\x10\x62\x3c\x06\x4e\xc8\x8c\xb0\x32" +
"\x2a\x59\xad\xb2\xf1\x02\x82\x7c\x33\xcb\x92\x6f\x32\x4f\x9b\xb0" +
"\xa2\x30\xf0\xc0\x06\x1e\x98\x99\x2c\x06\x1e\xd8\xc0\x03\x56\xd8" +
"\xc0\x03\x0f\x6c\xe0\x01\xf1\xf0\xee\x9a\xc6\xc8\x97\x99\xd1\x6c" +
"\xb4\x21\x45\x3b\x10\xe4\x7b\x99\x4d\x8a\x36\x64\x5c\x77\x08\x02" +
"\xcb\xe0\xce",
},
{
// presumably one of the inputs found by fuzzing, going by its name.
"fuzz1",
"0\x00\x00\x00\x00\x000\x00\x00\x00\x00\x001\x00\x00\x00\x00\x000000",
"(\xb5/\xfd\x04X\x8d\x00\x00P0\x000\x001\x000000\x03T\x02\x00\x01\x01m\xf9\xb7G",
},
}
// TestSamples decompresses every entry of tests in its own subtest and
// checks that the output equals the expected uncompressed contents.
func TestSamples(t *testing.T) {
	for _, tc := range tests {
		tc := tc // per-iteration copy for the subtest closure
		t.Run(tc.name, func(t *testing.T) {
			zr := NewReader(strings.NewReader(tc.compressed))
			out, err := io.ReadAll(zr)
			if err != nil {
				t.Fatal(err)
			}
			if got := string(out); got != tc.uncompressed {
				t.Errorf("got %q want %q", got, tc.uncompressed)
			}
		})
	}
}
// State for bigData: the file is read at most once and the outcome,
// contents or error, is shared by every caller.
var (
	bigDataOnce  sync.Once
	bigDataBytes []byte
	bigDataErr   error
)

// bigData returns the contents of our large test file.
// The file is read on the first call only; if that read failed,
// every call fails the test with the saved error.
func bigData(t testing.TB) []byte {
	const path = "../../testdata/Isaac.Newton-Opticks.txt"

	bigDataOnce.Do(func() {
		contents, err := os.ReadFile(path)
		bigDataBytes, bigDataErr = contents, err
	})

	if err := bigDataErr; err != nil {
		t.Fatal(err)
	}
	return bigDataBytes
}
// State for zstdBigData: the external compressor is run at most once and
// the outcome (compressed bytes, skip, or error) is shared by every caller.
var (
	zstdBigOnce  sync.Once
	zstdBigBytes []byte
	zstdBigSkip  bool
	zstdBigErr   error
)

// zstdBigData returns the compressed contents of our large test file.
// The data is produced by running /usr/bin/zstd, so the calling test is
// skipped on systems where that program does not exist.
// That's OK as the package is GOOS-independent.
func zstdBigData(t testing.TB) []byte {
	input := bigData(t)

	zstdBigOnce.Do(func() {
		const zstdPath = "/usr/bin/zstd"
		if _, statErr := os.Stat(zstdPath); statErr != nil {
			zstdBigSkip = true
			return
		}

		// Feed the uncompressed data on stdin and collect the
		// compressed stream from stdout.
		var out bytes.Buffer
		cmd := exec.Command(zstdPath, "-z")
		cmd.Stdin = bytes.NewReader(input)
		cmd.Stdout = &out
		cmd.Stderr = os.Stderr
		if runErr := cmd.Run(); runErr != nil {
			zstdBigErr = fmt.Errorf("running zstd failed: %v", runErr)
			return
		}
		zstdBigBytes = out.Bytes()
	})

	switch {
	case zstdBigSkip:
		t.Skip("skipping because /usr/bin/zstd does not exist")
	case zstdBigErr != nil:
		t.Fatal(zstdBigErr)
	}
	return zstdBigBytes
}
// TestLarge decompresses a large file and compares the result with the
// original. We don't have a compressor, so the compressed input comes
// from zstdBigData, which skips the test when /usr/bin/zstd is missing.
// The test is also skipped in short mode.
func TestLarge(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping expensive test in short mode")
	}

	want := bigData(t)
	compressed := zstdBigData(t)
	t.Logf("/usr/bin/zstd compressed %d bytes to %d", len(want), len(compressed))

	got, err := io.ReadAll(NewReader(bytes.NewReader(compressed)))
	if err != nil {
		t.Fatal(err)
	}
	if bytes.Equal(got, want) {
		return
	}
	showDiffs(t, got, want)
}
// showDiffs marks the test as failed and logs the first few positions at
// which got and want differ. A length mismatch is reported separately;
// bytes are only compared over the common prefix of the two slices.
func showDiffs(t *testing.T, got, want []byte) {
	t.Error("data mismatch")
	if gotLen, wantLen := len(got), len(want); gotLen != wantLen {
		t.Errorf("got data length %d, want %d", gotLen, wantLen)
	}

	// Only indexes present in both slices can be compared.
	limit := len(got)
	if len(want) < limit {
		limit = len(want)
	}

	const maxReported = 20
	reported := 0
	for i := 0; i < limit; i++ {
		if got[i] == want[i] {
			continue
		}
		if reported == maxReported {
			return
		}
		reported++
		t.Logf("%d: %#x != %#x", i, got[i], want[i])
	}
}
// TestAlloc checks that decompressing through a reused Reader does not
// allocate. It is skipped when optimizations are off and under the race
// detector.
func TestAlloc(t *testing.T) {
	testenv.SkipIfOptimizationOff(t)
	if race.Enabled {
		t.Skip("skipping allocation test under race detector")
	}

	compressed := zstdBigData(t)
	src := bytes.NewReader(compressed)
	zr := NewReader(src)

	// Each run rewinds the input and resets the Reader, so the same
	// Reader and input objects are reused on every run.
	decompress := func() {
		src.Reset(compressed)
		zr.Reset(src)
		io.Copy(io.Discard, zr)
	}

	if allocs := testing.AllocsPerRun(10, decompress); allocs != 0 {
		t.Errorf("got %v allocs, want 0", allocs)
	}
}
// BenchmarkLarge measures decompression of the large test file using a
// single Reader that is reset on every iteration. The reported byte
// count per iteration is the size of the compressed input.
func BenchmarkLarge(b *testing.B) {
	// Keep the setup below out of the timed region.
	b.StopTimer()
	b.ReportAllocs()

	compressed := zstdBigData(b)
	b.SetBytes(int64(len(compressed)))
	src := bytes.NewReader(compressed)
	zr := NewReader(src)

	b.StartTimer()
	for iter := 0; iter < b.N; iter++ {
		src.Reset(compressed)
		zr.Reset(src)
		io.Copy(io.Discard, zr)
	}
}