diff --git a/src/pkg/compress/zlib/writer.go b/src/pkg/compress/zlib/writer.go index cd8dea460a4..99ff6549acb 100644 --- a/src/pkg/compress/zlib/writer.go +++ b/src/pkg/compress/zlib/writer.go @@ -70,6 +70,23 @@ func NewWriterLevelDict(w io.Writer, level int, dict []byte) (*Writer, error) { }, nil } +// Reset clears the state of the Writer z such that it is equivalent to its +// initial state from NewWriterLevel or NewWriterLevelDict, but instead writing +// to w. +func (z *Writer) Reset(w io.Writer) { + z.w = w + // z.level and z.dict left unchanged. + if z.compressor != nil { + z.compressor.Reset(w) + } + if z.digest != nil { + z.digest.Reset() + } + z.err = nil + z.scratch = [4]byte{} + z.wroteHeader = false +} + // writeHeader writes the ZLIB header. func (z *Writer) writeHeader() (err error) { z.wroteHeader = true @@ -111,11 +128,15 @@ func (z *Writer) writeHeader() (err error) { return err } } - z.compressor, err = flate.NewWriterDict(z.w, z.level, z.dict) - if err != nil { - return err + if z.compressor == nil { + // Initialize deflater unless the Writer is being reused + // after a Reset call. + z.compressor, err = flate.NewWriterDict(z.w, z.level, z.dict) + if err != nil { + return err + } + z.digest = adler32.New() } - z.digest = adler32.New() return nil } diff --git a/src/pkg/compress/zlib/writer_test.go b/src/pkg/compress/zlib/writer_test.go index aee1a5c2f54..cf9c8325455 100644 --- a/src/pkg/compress/zlib/writer_test.go +++ b/src/pkg/compress/zlib/writer_test.go @@ -89,6 +89,56 @@ func testLevelDict(t *testing.T, fn string, b0 []byte, level int, d string) { } } +func testFileLevelDictReset(t *testing.T, fn string, level int, dict []byte) { + var b0 []byte + var err error + if fn != "" { + b0, err = ioutil.ReadFile(fn) + if err != nil { + t.Errorf("%s (level=%d): %v", fn, level, err) + return + } + } + + // Compress once. + buf := new(bytes.Buffer) + var zlibw *Writer + if dict == nil { + zlibw, err = NewWriterLevel(buf, level) + } else { + zlibw, err = NewWriterLevelDict(buf, level, dict) + } + if err == nil { + _, err = zlibw.Write(b0) + } + if err == nil { + err = zlibw.Close() + } + if err != nil { + t.Errorf("%s (level=%d): %v", fn, level, err) + return + } + out := buf.String() + + // Reset and comprses again. + buf2 := new(bytes.Buffer) + zlibw.Reset(buf2) + _, err = zlibw.Write(b0) + if err == nil { + err = zlibw.Close() + } + if err != nil { + t.Errorf("%s (level=%d): %v", fn, level, err) + return + } + out2 := buf2.String() + + if out2 != out { + t.Errorf("%s (level=%d): different output after reset (got %d bytes, expected %d", + fn, level, len(out2), len(out)) + } +} + func TestWriter(t *testing.T) { for i, s := range data { b := []byte(s) @@ -122,6 +172,21 @@ func TestWriterDict(t *testing.T) { } } +func TestWriterReset(t *testing.T) { + const dictionary = "0123456789." + for _, fn := range filenames { + testFileLevelDictReset(t, fn, NoCompression, nil) + testFileLevelDictReset(t, fn, DefaultCompression, nil) + testFileLevelDictReset(t, fn, NoCompression, []byte(dictionary)) + testFileLevelDictReset(t, fn, DefaultCompression, []byte(dictionary)) + if !testing.Short() { + for level := BestSpeed; level <= BestCompression; level++ { + testFileLevelDictReset(t, fn, level, nil) + } + } + } +} + func TestWriterDictIsUsed(t *testing.T) { var input = []byte("Lorem ipsum dolor sit amet, consectetur adipisicing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.") var buf bytes.Buffer