1
0
mirror of https://github.com/golang/go synced 2024-11-25 07:07:57 -07:00

compress/flate: do not use background goroutines

Programs expect that Read and Write are synchronous.
The background goroutines make the implementation
a little easier, but they introduce asynchrony that
trips up calling code.  Remove them.

R=golang-dev, krasin
CC=golang-dev
https://golang.org/cl/4548079
This commit is contained in:
Russ Cox 2011-06-02 09:32:38 -04:00
parent 422abf3b8e
commit 07acc02a29
3 changed files with 360 additions and 372 deletions

View File

@ -16,8 +16,10 @@ const (
fastCompression = 3 fastCompression = 3
BestCompression = 9 BestCompression = 9
DefaultCompression = -1 DefaultCompression = -1
logWindowSize = 15
windowSize = 1 << logWindowSize
windowMask = windowSize - 1
logMaxOffsetSize = 15 // Standard DEFLATE logMaxOffsetSize = 15 // Standard DEFLATE
wideLogMaxOffsetSize = 22 // Wide DEFLATE
minMatchLength = 3 // The smallest match that the compressor looks for minMatchLength = 3 // The smallest match that the compressor looks for
maxMatchLength = 258 // The longest match for the compressor maxMatchLength = 258 // The longest match for the compressor
minOffsetSize = 1 // The shortest offset that makes any sence minOffsetSize = 1 // The shortest offset that makes any sence
@ -32,22 +34,6 @@ const (
hashShift = (hashBits + minMatchLength - 1) / minMatchLength hashShift = (hashBits + minMatchLength - 1) / minMatchLength
) )
type syncPipeReader struct {
*io.PipeReader
closeChan chan bool
}
func (sr *syncPipeReader) CloseWithError(err os.Error) os.Error {
retErr := sr.PipeReader.CloseWithError(err)
sr.closeChan <- true // finish writer close
return retErr
}
type syncPipeWriter struct {
*io.PipeWriter
closeChan chan bool
}
type compressionLevel struct { type compressionLevel struct {
good, lazy, nice, chain, fastSkipHashing int good, lazy, nice, chain, fastSkipHashing int
} }
@ -68,105 +54,73 @@ var levels = []compressionLevel{
{32, 258, 258, 4096, math.MaxInt32}, {32, 258, 258, 4096, math.MaxInt32},
} }
func (sw *syncPipeWriter) Close() os.Error {
err := sw.PipeWriter.Close()
<-sw.closeChan // wait for reader close
return err
}
func syncPipe() (*syncPipeReader, *syncPipeWriter) {
r, w := io.Pipe()
sr := &syncPipeReader{r, make(chan bool, 1)}
sw := &syncPipeWriter{w, sr.closeChan}
return sr, sw
}
type compressor struct { type compressor struct {
level int compressionLevel
logWindowSize uint
w *huffmanBitWriter w *huffmanBitWriter
r io.Reader
// (1 << logWindowSize) - 1.
windowMask int
eof bool // has eof been reached on input? // compression algorithm
sync bool // writer wants to flush fill func(*compressor, []byte) int // copy data to window
syncChan chan os.Error step func(*compressor) // process window
sync bool // requesting flush
// Input hash chains
// hashHead[hashValue] contains the largest inputIndex with the specified hash value // hashHead[hashValue] contains the largest inputIndex with the specified hash value
hashHead []int
// If hashHead[hashValue] is within the current window, then // If hashHead[hashValue] is within the current window, then
// hashPrev[hashHead[hashValue] & windowMask] contains the previous index // hashPrev[hashHead[hashValue] & windowMask] contains the previous index
// with the same hash value. // with the same hash value.
chainHead int
hashHead []int
hashPrev []int hashPrev []int
// If we find a match of length >= niceMatch, then we don't bother searching // input window: unprocessed data is window[index:windowEnd]
// any further. index int
niceMatch int
// If we find a match of length >= goodMatch, we only do a half-hearted
// effort at doing lazy matching starting at the next character
goodMatch int
// The maximum number of chains we look at when finding a match
maxChainLength int
// The sliding window we use for matching
window []byte window []byte
// The index just past the last valid character
windowEnd int windowEnd int
blockStart int // window index where current tokens start
byteAvailable bool // if true, still need to process window[index-1].
// index in "window" at which current block starts // queued output tokens: tokens[:ti]
blockStart int tokens []token
ti int
// deflate state
length int
offset int
hash int
maxInsertIndex int
err os.Error
} }
func (d *compressor) flush() os.Error { func (d *compressor) fillDeflate(b []byte) int {
d.w.flush() if d.index >= 2*windowSize-(minMatchLength+maxMatchLength) {
return d.w.err // shift the window by windowSize
} copy(d.window, d.window[windowSize:2*windowSize])
d.index -= windowSize
func (d *compressor) fillWindow(index int) (int, os.Error) { d.windowEnd -= windowSize
if d.sync { if d.blockStart >= windowSize {
return index, nil d.blockStart -= windowSize
}
wSize := d.windowMask + 1
if index >= wSize+wSize-(minMatchLength+maxMatchLength) {
// shift the window by wSize
copy(d.window, d.window[wSize:2*wSize])
index -= wSize
d.windowEnd -= wSize
if d.blockStart >= wSize {
d.blockStart -= wSize
} else { } else {
d.blockStart = math.MaxInt32 d.blockStart = math.MaxInt32
} }
for i, h := range d.hashHead { for i, h := range d.hashHead {
v := h - wSize v := h - windowSize
if v < -1 { if v < -1 {
v = -1 v = -1
} }
d.hashHead[i] = v d.hashHead[i] = v
} }
for i, h := range d.hashPrev { for i, h := range d.hashPrev {
v := -h - wSize v := -h - windowSize
if v < -1 { if v < -1 {
v = -1 v = -1
} }
d.hashPrev[i] = v d.hashPrev[i] = v
} }
} }
count, err := d.r.Read(d.window[d.windowEnd:]) n := copy(d.window[d.windowEnd:], b)
d.windowEnd += count d.windowEnd += n
if count == 0 && err == nil { return n
d.sync = true
}
if err == os.EOF {
d.eof = true
err = nil
}
return index, err
} }
func (d *compressor) writeBlock(tokens []token, index int, eof bool) os.Error { func (d *compressor) writeBlock(tokens []token, index int, eof bool) os.Error {
@ -194,21 +148,21 @@ func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead
// We quit when we get a match that's at least nice long // We quit when we get a match that's at least nice long
nice := len(win) - pos nice := len(win) - pos
if d.niceMatch < nice { if d.nice < nice {
nice = d.niceMatch nice = d.nice
} }
// If we've got a match that's good enough, only look in 1/4 the chain. // If we've got a match that's good enough, only look in 1/4 the chain.
tries := d.maxChainLength tries := d.chain
length = prevLength length = prevLength
if length >= d.goodMatch { if length >= d.good {
tries >>= 2 tries >>= 2
} }
w0 := win[pos] w0 := win[pos]
w1 := win[pos+1] w1 := win[pos+1]
wEnd := win[pos+length] wEnd := win[pos+length]
minIndex := pos - (d.windowMask + 1) minIndex := pos - windowSize
for i := prevHead; tries > 0; tries-- { for i := prevHead; tries > 0; tries-- {
if w0 == win[i] && w1 == win[i+1] && wEnd == win[i+length] { if w0 == win[i] && w1 == win[i+1] && wEnd == win[i+length] {
@ -233,7 +187,7 @@ func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead
// hashPrev[i & windowMask] has already been overwritten, so stop now. // hashPrev[i & windowMask] has already been overwritten, so stop now.
break break
} }
if i = d.hashPrev[i&d.windowMask]; i < minIndex || i < 0 { if i = d.hashPrev[i&windowMask]; i < minIndex || i < 0 {
break break
} }
} }
@ -248,234 +202,224 @@ func (d *compressor) writeStoredBlock(buf []byte) os.Error {
return d.w.err return d.w.err
} }
func (d *compressor) storedDeflate() os.Error { func (d *compressor) initDeflate() {
buf := make([]byte, maxStoreBlockSize) d.hashHead = make([]int, hashSize)
for { d.hashPrev = make([]int, windowSize)
n, err := d.r.Read(buf) d.window = make([]byte, 2*windowSize)
if n == 0 && err == nil { fillInts(d.hashHead, -1)
d.sync = true d.tokens = make([]token, maxFlateBlockTokens, maxFlateBlockTokens+1)
} d.length = minMatchLength - 1
if n > 0 || d.sync { d.offset = 0
if err := d.writeStoredBlock(buf[0:n]); err != nil { d.byteAvailable = false
return err d.index = 0
} d.ti = 0
if d.sync { d.hash = 0
d.syncChan <- nil d.chainHead = -1
d.sync = false
}
}
if err != nil {
if err == os.EOF {
break
}
return err
}
}
return nil
} }
func (d *compressor) doDeflate() (err os.Error) { func (d *compressor) deflate() {
// init if d.windowEnd-d.index < minMatchLength+maxMatchLength && !d.sync {
d.windowMask = 1<<d.logWindowSize - 1
d.hashHead = make([]int, hashSize)
d.hashPrev = make([]int, 1<<d.logWindowSize)
d.window = make([]byte, 2<<d.logWindowSize)
fillInts(d.hashHead, -1)
tokens := make([]token, maxFlateBlockTokens, maxFlateBlockTokens+1)
l := levels[d.level]
d.goodMatch = l.good
d.niceMatch = l.nice
d.maxChainLength = l.chain
lazyMatch := l.lazy
length := minMatchLength - 1
offset := 0
byteAvailable := false
isFastDeflate := l.fastSkipHashing != 0
index := 0
// run
if index, err = d.fillWindow(index); err != nil {
return return
} }
maxOffset := d.windowMask + 1 // (1 << logWindowSize);
// only need to change when you refill the window
windowEnd := d.windowEnd
maxInsertIndex := windowEnd - (minMatchLength - 1)
ti := 0
hash := int(0) d.maxInsertIndex = d.windowEnd - (minMatchLength - 1)
if index < maxInsertIndex { if d.index < d.maxInsertIndex {
hash = int(d.window[index])<<hashShift + int(d.window[index+1]) d.hash = int(d.window[d.index])<<hashShift + int(d.window[d.index+1])
} }
chainHead := -1
Loop: Loop:
for { for {
if index > windowEnd { if d.index > d.windowEnd {
panic("index > windowEnd") panic("index > windowEnd")
} }
lookahead := windowEnd - index lookahead := d.windowEnd - d.index
if lookahead < minMatchLength+maxMatchLength { if lookahead < minMatchLength+maxMatchLength {
if index, err = d.fillWindow(index); err != nil { if !d.sync {
return break Loop
} }
windowEnd = d.windowEnd if d.index > d.windowEnd {
if index > windowEnd {
panic("index > windowEnd") panic("index > windowEnd")
} }
maxInsertIndex = windowEnd - (minMatchLength - 1)
lookahead = windowEnd - index
if lookahead == 0 { if lookahead == 0 {
// Flush current output block if any. // Flush current output block if any.
if byteAvailable { if d.byteAvailable {
// There is still one pending token that needs to be flushed // There is still one pending token that needs to be flushed
tokens[ti] = literalToken(uint32(d.window[index-1]) & 0xFF) d.tokens[d.ti] = literalToken(uint32(d.window[d.index-1]))
ti++ d.ti++
byteAvailable = false d.byteAvailable = false
} }
if ti > 0 { if d.ti > 0 {
if err = d.writeBlock(tokens[0:ti], index, false); err != nil { if d.err = d.writeBlock(d.tokens[0:d.ti], d.index, false); d.err != nil {
return return
} }
ti = 0 d.ti = 0
}
if d.sync {
d.w.writeStoredHeader(0, false)
d.w.flush()
d.syncChan <- d.w.err
d.sync = false
}
// If this was only a sync (not at EOF) keep going.
if !d.eof {
continue
} }
break Loop break Loop
} }
} }
if index < maxInsertIndex { if d.index < d.maxInsertIndex {
// Update the hash // Update the hash
hash = (hash<<hashShift + int(d.window[index+2])) & hashMask d.hash = (d.hash<<hashShift + int(d.window[d.index+2])) & hashMask
chainHead = d.hashHead[hash] d.chainHead = d.hashHead[d.hash]
d.hashPrev[index&d.windowMask] = chainHead d.hashPrev[d.index&windowMask] = d.chainHead
d.hashHead[hash] = index d.hashHead[d.hash] = d.index
} }
prevLength := length prevLength := d.length
prevOffset := offset prevOffset := d.offset
length = minMatchLength - 1 d.length = minMatchLength - 1
offset = 0 d.offset = 0
minIndex := index - maxOffset minIndex := d.index - windowSize
if minIndex < 0 { if minIndex < 0 {
minIndex = 0 minIndex = 0
} }
if chainHead >= minIndex && if d.chainHead >= minIndex &&
(isFastDeflate && lookahead > minMatchLength-1 || (d.fastSkipHashing != 0 && lookahead > minMatchLength-1 ||
!isFastDeflate && lookahead > prevLength && prevLength < lazyMatch) { d.fastSkipHashing == 0 && lookahead > prevLength && prevLength < d.lazy) {
if newLength, newOffset, ok := d.findMatch(index, chainHead, minMatchLength-1, lookahead); ok { if newLength, newOffset, ok := d.findMatch(d.index, d.chainHead, minMatchLength-1, lookahead); ok {
length = newLength d.length = newLength
offset = newOffset d.offset = newOffset
} }
} }
if isFastDeflate && length >= minMatchLength || if d.fastSkipHashing != 0 && d.length >= minMatchLength ||
!isFastDeflate && prevLength >= minMatchLength && length <= prevLength { d.fastSkipHashing == 0 && prevLength >= minMatchLength && d.length <= prevLength {
// There was a match at the previous step, and the current match is // There was a match at the previous step, and the current match is
// not better. Output the previous match. // not better. Output the previous match.
if isFastDeflate { if d.fastSkipHashing != 0 {
tokens[ti] = matchToken(uint32(length-minMatchLength), uint32(offset-minOffsetSize)) d.tokens[d.ti] = matchToken(uint32(d.length-minMatchLength), uint32(d.offset-minOffsetSize))
} else { } else {
tokens[ti] = matchToken(uint32(prevLength-minMatchLength), uint32(prevOffset-minOffsetSize)) d.tokens[d.ti] = matchToken(uint32(prevLength-minMatchLength), uint32(prevOffset-minOffsetSize))
} }
ti++ d.ti++
// Insert in the hash table all strings up to the end of the match. // Insert in the hash table all strings up to the end of the match.
// index and index-1 are already inserted. If there is not enough // index and index-1 are already inserted. If there is not enough
// lookahead, the last two strings are not inserted into the hash // lookahead, the last two strings are not inserted into the hash
// table. // table.
if length <= l.fastSkipHashing { if d.length <= d.fastSkipHashing {
var newIndex int var newIndex int
if isFastDeflate { if d.fastSkipHashing != 0 {
newIndex = index + length newIndex = d.index + d.length
} else { } else {
newIndex = prevLength - 1 newIndex = prevLength - 1
} }
for index++; index < newIndex; index++ { for d.index++; d.index < newIndex; d.index++ {
if index < maxInsertIndex { if d.index < d.maxInsertIndex {
hash = (hash<<hashShift + int(d.window[index+2])) & hashMask d.hash = (d.hash<<hashShift + int(d.window[d.index+2])) & hashMask
// Get previous value with the same hash. // Get previous value with the same hash.
// Our chain should point to the previous value. // Our chain should point to the previous value.
d.hashPrev[index&d.windowMask] = d.hashHead[hash] d.hashPrev[d.index&windowMask] = d.hashHead[d.hash]
// Set the head of the hash chain to us. // Set the head of the hash chain to us.
d.hashHead[hash] = index d.hashHead[d.hash] = d.index
} }
} }
if !isFastDeflate { if d.fastSkipHashing == 0 {
byteAvailable = false d.byteAvailable = false
length = minMatchLength - 1 d.length = minMatchLength - 1
} }
} else { } else {
// For matches this long, we don't bother inserting each individual // For matches this long, we don't bother inserting each individual
// item into the table. // item into the table.
index += length d.index += d.length
hash = (int(d.window[index])<<hashShift + int(d.window[index+1])) d.hash = (int(d.window[d.index])<<hashShift + int(d.window[d.index+1]))
} }
if ti == maxFlateBlockTokens { if d.ti == maxFlateBlockTokens {
// The block includes the current character // The block includes the current character
if err = d.writeBlock(tokens, index, false); err != nil { if d.err = d.writeBlock(d.tokens, d.index, false); d.err != nil {
return return
} }
ti = 0 d.ti = 0
} }
} else { } else {
if isFastDeflate || byteAvailable { if d.fastSkipHashing != 0 || d.byteAvailable {
i := index - 1 i := d.index - 1
if isFastDeflate { if d.fastSkipHashing != 0 {
i = index i = d.index
} }
tokens[ti] = literalToken(uint32(d.window[i]) & 0xFF) d.tokens[d.ti] = literalToken(uint32(d.window[i]))
ti++ d.ti++
if ti == maxFlateBlockTokens { if d.ti == maxFlateBlockTokens {
if err = d.writeBlock(tokens, i+1, false); err != nil { if d.err = d.writeBlock(d.tokens, i+1, false); d.err != nil {
return return
} }
ti = 0 d.ti = 0
} }
} }
index++ d.index++
if !isFastDeflate { if d.fastSkipHashing == 0 {
byteAvailable = true d.byteAvailable = true
} }
} }
} }
return
} }
func (d *compressor) compress(r io.Reader, w io.Writer, level int, logWindowSize uint) (err os.Error) { func (d *compressor) fillStore(b []byte) int {
d.r = r n := copy(d.window[d.windowEnd:], b)
d.windowEnd += n
return n
}
func (d *compressor) store() {
if d.windowEnd > 0 {
d.err = d.writeStoredBlock(d.window[:d.windowEnd])
}
d.windowEnd = 0
}
func (d *compressor) write(b []byte) (n int, err os.Error) {
n = len(b)
b = b[d.fill(d, b):]
for len(b) > 0 {
d.step(d)
b = b[d.fill(d, b):]
}
return n, d.err
}
func (d *compressor) syncFlush() os.Error {
d.sync = true
d.step(d)
if d.err == nil {
d.w.writeStoredHeader(0, false)
d.w.flush()
d.err = d.w.err
}
d.sync = false
return d.err
}
func (d *compressor) init(w io.Writer, level int) (err os.Error) {
d.w = newHuffmanBitWriter(w) d.w = newHuffmanBitWriter(w)
d.level = level
d.logWindowSize = logWindowSize
switch { switch {
case level == NoCompression: case level == NoCompression:
err = d.storedDeflate() d.window = make([]byte, maxStoreBlockSize)
d.fill = (*compressor).fillStore
d.step = (*compressor).store
case level == DefaultCompression: case level == DefaultCompression:
d.level = 6 level = 6
fallthrough fallthrough
case 1 <= level && level <= 9: case 1 <= level && level <= 9:
err = d.doDeflate() d.compressionLevel = levels[level]
d.initDeflate()
d.fill = (*compressor).fillDeflate
d.step = (*compressor).deflate
default: default:
return WrongValueError{"level", 0, 9, int32(level)} return WrongValueError{"level", 0, 9, int32(level)}
} }
return nil
if d.sync {
d.syncChan <- err
d.sync = false
} }
if err != nil {
return err func (d *compressor) close() os.Error {
d.sync = true
d.step(d)
if d.err != nil {
return d.err
} }
if d.w.writeStoredHeader(0, true); d.w.err != nil { if d.w.writeStoredHeader(0, true); d.w.err != nil {
return d.w.err return d.w.err
} }
return d.flush() d.w.flush()
return d.w.err
} }
// NewWriter returns a new Writer compressing // NewWriter returns a new Writer compressing
@ -486,14 +430,9 @@ func (d *compressor) compress(r io.Reader, w io.Writer, level int, logWindowSize
// compression; it only adds the necessary DEFLATE framing. // compression; it only adds the necessary DEFLATE framing.
func NewWriter(w io.Writer, level int) *Writer { func NewWriter(w io.Writer, level int) *Writer {
const logWindowSize = logMaxOffsetSize const logWindowSize = logMaxOffsetSize
var d compressor var dw Writer
d.syncChan = make(chan os.Error, 1) dw.d.init(w, level)
pr, pw := syncPipe() return &dw
go func() {
err := d.compress(pr, w, level, logWindowSize)
pr.CloseWithError(err)
}()
return &Writer{pw, &d}
} }
// NewWriterDict is like NewWriter but initializes the new // NewWriterDict is like NewWriter but initializes the new
@ -526,18 +465,13 @@ func (w *dictWriter) Write(b []byte) (n int, err os.Error) {
// A Writer takes data written to it and writes the compressed // A Writer takes data written to it and writes the compressed
// form of that data to an underlying writer (see NewWriter). // form of that data to an underlying writer (see NewWriter).
type Writer struct { type Writer struct {
w *syncPipeWriter d compressor
d *compressor
} }
// Write writes data to w, which will eventually write the // Write writes data to w, which will eventually write the
// compressed form of data to its underlying writer. // compressed form of data to its underlying writer.
func (w *Writer) Write(data []byte) (n int, err os.Error) { func (w *Writer) Write(data []byte) (n int, err os.Error) {
if len(data) == 0 { return w.d.write(data)
// no point, and nil interferes with sync
return
}
return w.w.Write(data)
} }
// Flush flushes any pending compressed data to the underlying writer. // Flush flushes any pending compressed data to the underlying writer.
@ -550,18 +484,10 @@ func (w *Writer) Write(data []byte) (n int, err os.Error) {
func (w *Writer) Flush() os.Error { func (w *Writer) Flush() os.Error {
// For more about flushing: // For more about flushing:
// http://www.bolet.org/~pornin/deflate-flush.html // http://www.bolet.org/~pornin/deflate-flush.html
if w.d.sync { return w.d.syncFlush()
panic("compress/flate: double Flush")
}
_, err := w.w.Write(nil)
err1 := <-w.d.syncChan
if err == nil {
err = err1
}
return err
} }
// Close flushes and closes the writer. // Close flushes and closes the writer.
func (w *Writer) Close() os.Error { func (w *Writer) Close() os.Error {
return w.w.Close() return w.d.close()
} }

View File

@ -57,7 +57,7 @@ var deflateInflateTests = []*deflateInflateTest{
&deflateInflateTest{[]byte{0x11, 0x12}}, &deflateInflateTest{[]byte{0x11, 0x12}},
&deflateInflateTest{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}}, &deflateInflateTest{[]byte{0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11}},
&deflateInflateTest{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}}, &deflateInflateTest{[]byte{0x11, 0x10, 0x13, 0x41, 0x21, 0x21, 0x41, 0x13, 0x87, 0x78, 0x13}},
&deflateInflateTest{getLargeDataChunk()}, &deflateInflateTest{largeDataChunk()},
} }
var reverseBitsTests = []*reverseBitsTest{ var reverseBitsTests = []*reverseBitsTest{
@ -71,23 +71,22 @@ var reverseBitsTests = []*reverseBitsTest{
&reverseBitsTest{29, 5, 23}, &reverseBitsTest{29, 5, 23},
} }
func getLargeDataChunk() []byte { func largeDataChunk() []byte {
result := make([]byte, 100000) result := make([]byte, 100000)
for i := range result { for i := range result {
result[i] = byte(int64(i) * int64(i) & 0xFF) result[i] = byte(i * i & 0xFF)
} }
return result return result
} }
func TestDeflate(t *testing.T) { func TestDeflate(t *testing.T) {
for _, h := range deflateTests { for _, h := range deflateTests {
buffer := bytes.NewBuffer(nil) var buf bytes.Buffer
w := NewWriter(buffer, h.level) w := NewWriter(&buf, h.level)
w.Write(h.in) w.Write(h.in)
w.Close() w.Close()
if bytes.Compare(buffer.Bytes(), h.out) != 0 { if !bytes.Equal(buf.Bytes(), h.out) {
t.Errorf("buffer is wrong; level = %v, buffer.Bytes() = %v, expected output = %v", t.Errorf("Deflate(%d, %x) = %x, want %x", h.level, h.in, buf.Bytes(), h.out)
h.level, buffer.Bytes(), h.out)
} }
} }
} }

View File

@ -195,9 +195,8 @@ type Reader interface {
// Decompress state. // Decompress state.
type decompressor struct { type decompressor struct {
// Input/output sources. // Input source.
r Reader r Reader
w io.Writer
roffset int64 roffset int64
woffset int64 woffset int64
@ -220,38 +219,79 @@ type decompressor struct {
// Temporary buffer (avoids repeated allocation). // Temporary buffer (avoids repeated allocation).
buf [4]byte buf [4]byte
// Next step in the decompression,
// and decompression state.
step func(*decompressor)
final bool
err os.Error
toRead []byte
hl, hd *huffmanDecoder
copyLen int
copyDist int
} }
func (f *decompressor) inflate() (err os.Error) { func (f *decompressor) nextBlock() {
final := false if f.final {
for err == nil && !final { if f.hw != f.hp {
f.flush((*decompressor).nextBlock)
return
}
f.err = os.EOF
return
}
for f.nb < 1+2 { for f.nb < 1+2 {
if err = f.moreBits(); err != nil { if f.err = f.moreBits(); f.err != nil {
return return
} }
} }
final = f.b&1 == 1 f.final = f.b&1 == 1
f.b >>= 1 f.b >>= 1
typ := f.b & 3 typ := f.b & 3
f.b >>= 2 f.b >>= 2
f.nb -= 1 + 2 f.nb -= 1 + 2
switch typ { switch typ {
case 0: case 0:
err = f.dataBlock() f.dataBlock()
case 1: case 1:
// compressed, fixed Huffman tables // compressed, fixed Huffman tables
err = f.decodeBlock(&fixedHuffmanDecoder, nil) f.hl = &fixedHuffmanDecoder
f.hd = nil
f.huffmanBlock()
case 2: case 2:
// compressed, dynamic Huffman tables // compressed, dynamic Huffman tables
if err = f.readHuffman(); err == nil { if f.err = f.readHuffman(); f.err != nil {
err = f.decodeBlock(&f.h1, &f.h2) break
} }
f.hl = &f.h1
f.hd = &f.h2
f.huffmanBlock()
default: default:
// 3 is reserved. // 3 is reserved.
err = CorruptInputError(f.roffset) f.err = CorruptInputError(f.roffset)
} }
} }
return
func (f *decompressor) Read(b []byte) (int, os.Error) {
for {
if len(f.toRead) > 0 {
n := copy(b, f.toRead)
f.toRead = f.toRead[n:]
return n, nil
}
if f.err != nil {
return 0, f.err
}
f.step(f)
}
panic("unreachable")
}
func (f *decompressor) Close() os.Error {
if f.err == os.EOF {
return nil
}
return f.err
} }
// RFC 1951 section 3.2.7. // RFC 1951 section 3.2.7.
@ -356,11 +396,12 @@ func (f *decompressor) readHuffman() os.Error {
// hl and hd are the Huffman states for the lit/length values // hl and hd are the Huffman states for the lit/length values
// and the distance values, respectively. If hd == nil, using the // and the distance values, respectively. If hd == nil, using the
// fixed distance encoding associated with fixed Huffman blocks. // fixed distance encoding associated with fixed Huffman blocks.
func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error { func (f *decompressor) huffmanBlock() {
for { for {
v, err := f.huffSym(hl) v, err := f.huffSym(f.hl)
if err != nil { if err != nil {
return err f.err = err
return
} }
var n uint // number of bits extra var n uint // number of bits extra
var length int var length int
@ -369,13 +410,15 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
f.hist[f.hp] = byte(v) f.hist[f.hp] = byte(v)
f.hp++ f.hp++
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
if err = f.flush(); err != nil { // After the flush, continue this loop.
return err f.flush((*decompressor).huffmanBlock)
} return
} }
continue continue
case v == 256: case v == 256:
return nil // Done with huffman block; read next block.
f.step = (*decompressor).nextBlock
return
// otherwise, reference to older data // otherwise, reference to older data
case v < 265: case v < 265:
length = v - (257 - 3) length = v - (257 - 3)
@ -402,7 +445,8 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
if n > 0 { if n > 0 {
for f.nb < n { for f.nb < n {
if err = f.moreBits(); err != nil { if err = f.moreBits(); err != nil {
return err f.err = err
return
} }
} }
length += int(f.b & uint32(1<<n-1)) length += int(f.b & uint32(1<<n-1))
@ -411,18 +455,20 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
} }
var dist int var dist int
if hd == nil { if f.hd == nil {
for f.nb < 5 { for f.nb < 5 {
if err = f.moreBits(); err != nil { if err = f.moreBits(); err != nil {
return err f.err = err
return
} }
} }
dist = int(reverseByte[(f.b&0x1F)<<3]) dist = int(reverseByte[(f.b&0x1F)<<3])
f.b >>= 5 f.b >>= 5
f.nb -= 5 f.nb -= 5
} else { } else {
if dist, err = f.huffSym(hd); err != nil { if dist, err = f.huffSym(f.hd); err != nil {
return err f.err = err
return
} }
} }
@ -430,14 +476,16 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
case dist < 4: case dist < 4:
dist++ dist++
case dist >= 30: case dist >= 30:
return CorruptInputError(f.roffset) f.err = CorruptInputError(f.roffset)
return
default: default:
nb := uint(dist-2) >> 1 nb := uint(dist-2) >> 1
// have 1 bit in bottom of dist, need nb more. // have 1 bit in bottom of dist, need nb more.
extra := (dist & 1) << nb extra := (dist & 1) << nb
for f.nb < nb { for f.nb < nb {
if err = f.moreBits(); err != nil { if err = f.moreBits(); err != nil {
return err f.err = err
return
} }
} }
extra |= int(f.b & uint32(1<<nb-1)) extra |= int(f.b & uint32(1<<nb-1))
@ -448,12 +496,14 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
// Copy history[-dist:-dist+length] into output. // Copy history[-dist:-dist+length] into output.
if dist > len(f.hist) { if dist > len(f.hist) {
return InternalError("bad history distance") f.err = InternalError("bad history distance")
return
} }
// No check on length; encoding can be prescient. // No check on length; encoding can be prescient.
if !f.hfull && dist > f.hp { if !f.hfull && dist > f.hp {
return CorruptInputError(f.roffset) f.err = CorruptInputError(f.roffset)
return
} }
p := f.hp - dist p := f.hp - dist
@ -465,9 +515,11 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
f.hp++ f.hp++
p++ p++
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
if err = f.flush(); err != nil { // After flush continue copying out of history.
return err f.copyLen = length - (i + 1)
} f.copyDist = dist
f.flush((*decompressor).copyHuff)
return
} }
if p == len(f.hist) { if p == len(f.hist) {
p = 0 p = 0
@ -477,8 +529,33 @@ func (f *decompressor) decodeBlock(hl, hd *huffmanDecoder) os.Error {
panic("unreached") panic("unreached")
} }
func (f *decompressor) copyHuff() {
length := f.copyLen
dist := f.copyDist
p := f.hp - dist
if p < 0 {
p += len(f.hist)
}
for i := 0; i < length; i++ {
f.hist[f.hp] = f.hist[p]
f.hp++
p++
if f.hp == len(f.hist) {
f.copyLen = length - (i + 1)
f.flush((*decompressor).copyHuff)
return
}
if p == len(f.hist) {
p = 0
}
}
// Continue processing Huffman block.
f.huffmanBlock()
}
// Copy a single uncompressed data block from input to output. // Copy a single uncompressed data block from input to output.
func (f *decompressor) dataBlock() os.Error { func (f *decompressor) dataBlock() {
// Uncompressed. // Uncompressed.
// Discard current half-byte. // Discard current half-byte.
f.nb = 0 f.nb = 0
@ -488,21 +565,30 @@ func (f *decompressor) dataBlock() os.Error {
nr, err := io.ReadFull(f.r, f.buf[0:4]) nr, err := io.ReadFull(f.r, f.buf[0:4])
f.roffset += int64(nr) f.roffset += int64(nr)
if err != nil { if err != nil {
return &ReadError{f.roffset, err} f.err = &ReadError{f.roffset, err}
return
} }
n := int(f.buf[0]) | int(f.buf[1])<<8 n := int(f.buf[0]) | int(f.buf[1])<<8
nn := int(f.buf[2]) | int(f.buf[3])<<8 nn := int(f.buf[2]) | int(f.buf[3])<<8
if uint16(nn) != uint16(^n) { if uint16(nn) != uint16(^n) {
return CorruptInputError(f.roffset) f.err = CorruptInputError(f.roffset)
return
} }
if n == 0 { if n == 0 {
// 0-length block means sync // 0-length block means sync
return f.flush() f.flush((*decompressor).nextBlock)
return
} }
// Read len bytes into history, f.copyLen = n
// writing as history fills. f.copyData()
}
func (f *decompressor) copyData() {
// Read f.dataLen bytes into history,
// pausing for reads as history fills.
n := f.copyLen
for n > 0 { for n > 0 {
m := len(f.hist) - f.hp m := len(f.hist) - f.hp
if m > n { if m > n {
@ -511,17 +597,18 @@ func (f *decompressor) dataBlock() os.Error {
m, err := io.ReadFull(f.r, f.hist[f.hp:f.hp+m]) m, err := io.ReadFull(f.r, f.hist[f.hp:f.hp+m])
f.roffset += int64(m) f.roffset += int64(m)
if err != nil { if err != nil {
return &ReadError{f.roffset, err} f.err = &ReadError{f.roffset, err}
return
} }
n -= m n -= m
f.hp += m f.hp += m
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
if err = f.flush(); err != nil { f.copyLen = n
return err f.flush((*decompressor).copyData)
return
} }
} }
} f.step = (*decompressor).nextBlock
return nil
} }
func (f *decompressor) setDict(dict []byte) { func (f *decompressor) setDict(dict []byte) {
@ -577,17 +664,8 @@ func (f *decompressor) huffSym(h *huffmanDecoder) (int, os.Error) {
} }
// Flush any buffered output to the underlying writer. // Flush any buffered output to the underlying writer.
func (f *decompressor) flush() os.Error { func (f *decompressor) flush(step func(*decompressor)) {
if f.hw == f.hp { f.toRead = f.hist[f.hw:f.hp]
return nil
}
n, err := f.w.Write(f.hist[f.hw:f.hp])
if n != f.hp-f.hw && err == nil {
err = io.ErrShortWrite
}
if err != nil {
return &WriteError{f.woffset, err}
}
f.woffset += int64(f.hp - f.hw) f.woffset += int64(f.hp - f.hw)
f.hw = f.hp f.hw = f.hp
if f.hp == len(f.hist) { if f.hp == len(f.hist) {
@ -595,7 +673,7 @@ func (f *decompressor) flush() os.Error {
f.hw = 0 f.hw = 0
f.hfull = true f.hfull = true
} }
return nil f.step = step
} }
func makeReader(r io.Reader) Reader { func makeReader(r io.Reader) Reader {
@ -605,30 +683,15 @@ func makeReader(r io.Reader) Reader {
return bufio.NewReader(r) return bufio.NewReader(r)
} }
// decompress reads DEFLATE-compressed data from r and writes
// the uncompressed data to w.
func (f *decompressor) decompress(r io.Reader, w io.Writer) os.Error {
f.r = makeReader(r)
f.w = w
f.woffset = 0
if err := f.inflate(); err != nil {
return err
}
if err := f.flush(); err != nil {
return err
}
return nil
}
// NewReader returns a new ReadCloser that can be used // NewReader returns a new ReadCloser that can be used
// to read the uncompressed version of r. It is the caller's // to read the uncompressed version of r. It is the caller's
// responsibility to call Close on the ReadCloser when // responsibility to call Close on the ReadCloser when
// finished reading. // finished reading.
func NewReader(r io.Reader) io.ReadCloser { func NewReader(r io.Reader) io.ReadCloser {
var f decompressor var f decompressor
pr, pw := io.Pipe() f.r = makeReader(r)
go func() { pw.CloseWithError(f.decompress(r, pw)) }() f.step = (*decompressor).nextBlock
return pr return &f
} }
// NewReaderDict is like NewReader but initializes the reader // NewReaderDict is like NewReader but initializes the reader
@ -639,7 +702,7 @@ func NewReader(r io.Reader) io.ReadCloser {
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
var f decompressor var f decompressor
f.setDict(dict) f.setDict(dict)
pr, pw := io.Pipe() f.r = makeReader(r)
go func() { pw.CloseWithError(f.decompress(r, pw)) }() f.step = (*decompressor).nextBlock
return pr return &f
} }