1
0
mirror of https://github.com/golang/go synced 2024-11-11 23:10:23 -07:00

compress/flate: fix deflate Reset consistency

Modify the overflow detection logic to shuffle the contents
of the table to a lower offset to avoid leaking the effects
of a previous use of compress.Writer past Reset calls.

Fixes #34121

Change-Id: I9963eadfa5482881e7b7adbad4c2cae146b669ab
GitHub-Last-Rev: 8b35798cdd
GitHub-Pull-Request: golang/go#34128
Reviewed-on: https://go-review.googlesource.com/c/go/+/193605
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
Run-TryBot: Joe Tsai <thebrokentoaster@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
This commit is contained in:
Klaus Post 2020-05-07 12:50:00 +00:00 committed by Russ Cox
parent c5d7f2f1cb
commit 8d4330742c
2 changed files with 97 additions and 15 deletions

View File

@ -4,6 +4,8 @@
package flate package flate
import "math"
// This encoding algorithm, which prioritizes speed over output size, is // This encoding algorithm, which prioritizes speed over output size, is
// based on Snappy's LZ77-style encoder: github.com/golang/snappy // based on Snappy's LZ77-style encoder: github.com/golang/snappy
@ -12,6 +14,13 @@ const (
tableSize = 1 << tableBits // Size of the table. tableSize = 1 << tableBits // Size of the table.
tableMask = tableSize - 1 // Mask for table indices. Redundant, but can eliminate bounds checks. tableMask = tableSize - 1 // Mask for table indices. Redundant, but can eliminate bounds checks.
tableShift = 32 - tableBits // Right-shift to get the tableBits most significant bits of a uint32. tableShift = 32 - tableBits // Right-shift to get the tableBits most significant bits of a uint32.
// Reset the buffer offset when reaching this.
// Offsets are stored between blocks as int32 values.
// Since the offset we are checking against is at the beginning
// of the buffer, we need to subtract the current and input
// buffer to not risk overflowing the int32.
bufferReset = math.MaxInt32 - maxStoreBlockSize*2
) )
func load32(b []byte, i int32) uint32 { func load32(b []byte, i int32) uint32 {
@ -59,8 +68,8 @@ func newDeflateFast() *deflateFast {
// to dst and returns the result. // to dst and returns the result.
func (e *deflateFast) encode(dst []token, src []byte) []token { func (e *deflateFast) encode(dst []token, src []byte) []token {
// Ensure that e.cur doesn't wrap. // Ensure that e.cur doesn't wrap.
if e.cur > 1<<30 { if e.cur >= bufferReset {
e.resetAll() e.shiftOffsets()
} }
// This check isn't in the Snappy implementation, but there, the caller // This check isn't in the Snappy implementation, but there, the caller
@ -264,22 +273,32 @@ func (e *deflateFast) reset() {
e.cur += maxMatchOffset e.cur += maxMatchOffset
// Protect against e.cur wraparound. // Protect against e.cur wraparound.
if e.cur > 1<<30 { if e.cur >= bufferReset {
e.resetAll() e.shiftOffsets()
} }
} }
// resetAll resets the deflateFast struct and is only called in rare // shiftOffsets will shift down all match offset.
// situations to prevent integer overflow. It manually resets each field // This is only called in rare situations to prevent integer overflow.
// to avoid causing large stack growth.
// //
// See https://golang.org/issue/18636. // See https://golang.org/issue/18636 and https://github.com/golang/go/issues/34121.
func (e *deflateFast) resetAll() { func (e *deflateFast) shiftOffsets() {
// This is equivalent to: if len(e.prev) == 0 {
// *e = deflateFast{cur: maxStoreBlockSize, prev: e.prev[:0]} // We have no history; just clear the table.
e.cur = maxStoreBlockSize for i := range e.table[:] {
e.prev = e.prev[:0] e.table[i] = tableEntry{}
for i := range e.table { }
e.table[i] = tableEntry{} e.cur = maxMatchOffset
return
} }
// Shift down everything in the table that isn't already too far away.
for i := range e.table[:] {
v := e.table[i].offset - e.cur + maxMatchOffset
if v < 0 {
v = 0
}
e.table[i].offset = v
}
e.cur = maxMatchOffset
} }

View File

@ -173,3 +173,66 @@ func testDeterministic(i int, t *testing.T) {
t.Errorf("level %d did not produce deterministic result, result mismatch, len(a) = %d, len(b) = %d", i, len(b1b), len(b2b)) t.Errorf("level %d did not produce deterministic result, result mismatch, len(a) = %d, len(b) = %d", i, len(b1b), len(b2b))
} }
} }
// TestDeflateFast_Reset will test that encoding is consistent
// across a warparound of the table offset.
// See https://github.com/golang/go/issues/34121
func TestDeflateFast_Reset(t *testing.T) {
buf := new(bytes.Buffer)
n := 65536
for i := 0; i < n; i++ {
fmt.Fprintf(buf, "asdfasdfasdfasdf%d%dfghfgujyut%dyutyu\n", i, i, i)
}
// This is specific to level 1.
const level = 1
in := buf.Bytes()
offset := 1
if testing.Short() {
offset = 256
}
// We do an encode with a clean buffer to compare.
var want bytes.Buffer
w, err := NewWriter(&want, level)
if err != nil {
t.Fatalf("NewWriter: level %d: %v", level, err)
}
// Output written 3 times.
w.Write(in)
w.Write(in)
w.Write(in)
w.Close()
for ; offset <= 256; offset *= 2 {
w, err := NewWriter(ioutil.Discard, level)
if err != nil {
t.Fatalf("NewWriter: level %d: %v", level, err)
}
// Reset until we are right before the wraparound.
// Each reset adds maxMatchOffset to the offset.
for i := 0; i < (bufferReset-len(in)-offset-maxMatchOffset)/maxMatchOffset; i++ {
// skip ahead to where we are close to wrap around...
w.d.reset(nil)
}
var got bytes.Buffer
w.Reset(&got)
// Write 3 times, close.
for i := 0; i < 3; i++ {
_, err = w.Write(in)
if err != nil {
t.Fatal(err)
}
}
err = w.Close()
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(got.Bytes(), want.Bytes()) {
t.Fatalf("output did not match at wraparound, len(want) = %d, len(got) = %d", want.Len(), got.Len())
}
}
}