mirror of
https://github.com/golang/go
synced 2024-11-24 05:10:19 -07:00
compress/flate: return error on closed stream write
Previously flate.Writer allowed writes after Close, and this behavior could lead to stream corruption. Fixes #27741 Change-Id: Iee1ac69f8199232f693dba77b275f7078257b582 Reviewed-on: https://go-review.googlesource.com/c/go/+/136475 Run-TryBot: Ian Lance Taylor <iant@golang.org> TryBot-Result: Gopher Robot <gobot@golang.org> Auto-Submit: Ian Lance Taylor <iant@google.com> Reviewed-by: Joseph Tsai <joetsai@digital-static.net> Reviewed-by: Carlos Amedee <carlos@golang.org> Reviewed-by: Ian Lance Taylor <iant@google.com> Run-TryBot: Ian Lance Taylor <iant@google.com>
This commit is contained in:
parent
fdf1d768f2
commit
2719b1a9a8
@ -5,6 +5,7 @@
|
||||
package flate
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
@ -699,17 +700,27 @@ func (w *dictWriter) Write(b []byte) (n int, err error) {
|
||||
return w.w.Write(b)
|
||||
}
|
||||
|
||||
var errWriteAfterClose = errors.New("compress/flate: write after close")
|
||||
|
||||
// A Writer takes data written to it and writes the compressed
|
||||
// form of that data to an underlying writer (see NewWriter).
|
||||
type Writer struct {
|
||||
d compressor
|
||||
dict []byte
|
||||
err error
|
||||
}
|
||||
|
||||
// Write writes data to w, which will eventually write the
|
||||
// compressed form of data to its underlying writer.
|
||||
func (w *Writer) Write(data []byte) (n int, err error) {
|
||||
return w.d.write(data)
|
||||
if w.err != nil {
|
||||
return 0, w.err
|
||||
}
|
||||
n, err = w.d.write(data)
|
||||
if err != nil {
|
||||
w.err = err
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// Flush flushes any pending data to the underlying writer.
|
||||
@ -724,18 +735,37 @@ func (w *Writer) Write(data []byte) (n int, err error) {
|
||||
func (w *Writer) Flush() error {
|
||||
// For more about flushing:
|
||||
// https://www.bolet.org/~pornin/deflate-flush.html
|
||||
return w.d.syncFlush()
|
||||
if w.err != nil {
|
||||
return w.err
|
||||
}
|
||||
if err := w.d.syncFlush(); err != nil {
|
||||
w.err = err
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close flushes and closes the writer.
|
||||
func (w *Writer) Close() error {
|
||||
return w.d.close()
|
||||
if w.err == errWriteAfterClose {
|
||||
return nil
|
||||
}
|
||||
if w.err != nil {
|
||||
return w.err
|
||||
}
|
||||
if err := w.d.close(); err != nil {
|
||||
w.err = err
|
||||
return err
|
||||
}
|
||||
w.err = errWriteAfterClose
|
||||
return nil
|
||||
}
|
||||
|
||||
// Reset discards the writer's state and makes it equivalent to
|
||||
// the result of NewWriter or NewWriterDict called with dst
|
||||
// and w's level and dictionary.
|
||||
func (w *Writer) Reset(dst io.Writer) {
|
||||
w.err = nil
|
||||
if dw, ok := w.d.w.writer.(*dictWriter); ok {
|
||||
// w was created with NewWriterDict
|
||||
dw.w = dst
|
||||
|
@ -125,6 +125,40 @@ func TestDeflate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriterClose(t *testing.T) {
|
||||
b := new(bytes.Buffer)
|
||||
zw, err := NewWriter(b, 6)
|
||||
if err != nil {
|
||||
t.Fatalf("NewWriter: %v", err)
|
||||
}
|
||||
|
||||
if c, err := zw.Write([]byte("Test")); err != nil || c != 4 {
|
||||
t.Fatalf("Write to not closed writer: %s, %d", err, c)
|
||||
}
|
||||
|
||||
if err := zw.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
|
||||
afterClose := b.Len()
|
||||
|
||||
if c, err := zw.Write([]byte("Test")); err == nil || c != 0 {
|
||||
t.Fatalf("Write to closed writer: %s, %d", err, c)
|
||||
}
|
||||
|
||||
if err := zw.Flush(); err == nil {
|
||||
t.Fatalf("Flush to closed writer: %s", err)
|
||||
}
|
||||
|
||||
if err := zw.Close(); err != nil {
|
||||
t.Fatalf("Close: %v", err)
|
||||
}
|
||||
|
||||
if afterClose != b.Len() {
|
||||
t.Fatalf("Writer wrote data after close. After close: %d. After writes on closed stream: %d", afterClose, b.Len())
|
||||
}
|
||||
}
|
||||
|
||||
// A sparseReader returns a stream consisting of 0s followed by 1<<16 1s.
|
||||
// This tests missing hash references in a very large input.
|
||||
type sparseReader struct {
|
||||
@ -683,7 +717,7 @@ func (w *failWriter) Write(b []byte) (int, error) {
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
func TestWriterPersistentError(t *testing.T) {
|
||||
func TestWriterPersistentWriteError(t *testing.T) {
|
||||
t.Parallel()
|
||||
d, err := os.ReadFile("../../testdata/Isaac.Newton-Opticks.txt")
|
||||
if err != nil {
|
||||
@ -706,12 +740,16 @@ func TestWriterPersistentError(t *testing.T) {
|
||||
|
||||
_, werr := zw.Write(d)
|
||||
cerr := zw.Close()
|
||||
ferr := zw.Flush()
|
||||
if werr != errIO && werr != nil {
|
||||
t.Errorf("test %d, mismatching Write error: got %v, want %v", i, werr, errIO)
|
||||
}
|
||||
if cerr != errIO && fw.n < 0 {
|
||||
t.Errorf("test %d, mismatching Close error: got %v, want %v", i, cerr, errIO)
|
||||
}
|
||||
if ferr != errIO && fw.n < 0 {
|
||||
t.Errorf("test %d, mismatching Flush error: got %v, want %v", i, ferr, errIO)
|
||||
}
|
||||
if fw.n >= 0 {
|
||||
// At this point, the failure threshold was sufficiently high enough
|
||||
// that we wrote the whole stream without any errors.
|
||||
@ -719,6 +757,54 @@ func TestWriterPersistentError(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
func TestWriterPersistentFlushError(t *testing.T) {
|
||||
zw, err := NewWriter(&failWriter{0}, DefaultCompression)
|
||||
if err != nil {
|
||||
t.Fatalf("NewWriter: %v", err)
|
||||
}
|
||||
flushErr := zw.Flush()
|
||||
closeErr := zw.Close()
|
||||
_, writeErr := zw.Write([]byte("Test"))
|
||||
checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)
|
||||
}
|
||||
|
||||
func TestWriterPersistentCloseError(t *testing.T) {
|
||||
// If underlying writer return error on closing stream we should persistent this error across all writer calls.
|
||||
zw, err := NewWriter(&failWriter{0}, DefaultCompression)
|
||||
if err != nil {
|
||||
t.Fatalf("NewWriter: %v", err)
|
||||
}
|
||||
closeErr := zw.Close()
|
||||
flushErr := zw.Flush()
|
||||
_, writeErr := zw.Write([]byte("Test"))
|
||||
checkErrors([]error{closeErr, flushErr, writeErr}, errIO, t)
|
||||
|
||||
// After closing writer we should persistent "write after close" error across Flush and Write calls, but return nil
|
||||
// on next Close calls.
|
||||
var b bytes.Buffer
|
||||
zw.Reset(&b)
|
||||
err = zw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("First call to close returned error: %s", err)
|
||||
}
|
||||
err = zw.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Second call to close returned error: %s", err)
|
||||
}
|
||||
|
||||
flushErr = zw.Flush()
|
||||
_, writeErr = zw.Write([]byte("Test"))
|
||||
checkErrors([]error{flushErr, writeErr}, errWriteAfterClose, t)
|
||||
}
|
||||
|
||||
func checkErrors(got []error, want error, t *testing.T) {
|
||||
t.Helper()
|
||||
for _, err := range got {
|
||||
if err != want {
|
||||
t.Errorf("Errors dosn't match\nWant: %s\nGot: %s", want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBestSpeedMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
Loading…
Reference in New Issue
Block a user