diff --git a/src/archive/tar/writer.go b/src/archive/tar/writer.go index ffef29af10..3a8f102196 100644 --- a/src/archive/tar/writer.go +++ b/src/archive/tar/writer.go @@ -4,9 +4,6 @@ package tar -// TODO(dsymonds): -// - catch more errors (no first header, etc.) - import ( "bytes" "fmt" @@ -22,14 +19,16 @@ import ( // Call WriteHeader to begin a new file, and then call Write to supply that file's data, // writing at most hdr.Size bytes in total. type Writer struct { - w io.Writer - err error - nb int64 // number of unwritten bytes for current file entry - pad int64 // amount of padding to write after current file entry - closed bool - + w io.Writer + nb int64 // number of unwritten bytes for current file entry + pad int64 // amount of padding to write after current file entry hdr Header // Shallow copy of Header that is safe for mutations blk block // Buffer to use as temporary local storage + + // err is a persistent error. + // It is only the responsibility of every exported method of Writer to + // ensure that this error is sticky. + err error } // NewWriter creates a new Writer writing to w. @@ -41,10 +40,12 @@ func NewWriter(w io.Writer) *Writer { return &Writer{w: w} } // Deprecated: This is unecessary as the next call to WriteHeader or Close // will implicitly flush out the file's padding. func (tw *Writer) Flush() error { - if tw.nb > 0 { - tw.err = fmt.Errorf("tar: missed writing %d bytes", tw.nb) + if tw.err != nil { return tw.err } + if tw.nb > 0 { + return fmt.Errorf("archive/tar: missed writing %d bytes", tw.nb) + } if _, tw.err = tw.w.Write(zeroBlock[:tw.pad]); tw.err != nil { return tw.err } @@ -63,13 +64,16 @@ func (tw *Writer) WriteHeader(hdr *Header) error { tw.hdr = *hdr // Shallow copy of Header switch allowedFormats, paxHdrs := tw.hdr.allowedFormats(); { case allowedFormats&formatUSTAR != 0: - return tw.writeUSTARHeader(&tw.hdr) + tw.err = tw.writeUSTARHeader(&tw.hdr) + return tw.err case allowedFormats&formatPAX != 0: - return tw.writePAXHeader(&tw.hdr, paxHdrs) + tw.err = tw.writePAXHeader(&tw.hdr, paxHdrs) + return tw.err case allowedFormats&formatGNU != 0: - return tw.writeGNUHeader(&tw.hdr) + tw.err = tw.writeGNUHeader(&tw.hdr) + return tw.err default: - return ErrHeader + return ErrHeader // Non-fatal error } } @@ -273,45 +277,46 @@ func splitUSTARPath(name string) (prefix, suffix string, ok bool) { // Write writes to the current entry in the tar archive. // Write returns the error ErrWriteTooLong if more than -// hdr.Size bytes are written after WriteHeader. -func (tw *Writer) Write(b []byte) (n int, err error) { - if tw.closed { - err = ErrWriteAfterClose - return +// Header.Size bytes are written after WriteHeader. +// +// Calling Write on special types like TypeLink, TypeSymLink, TypeChar, +// TypeBlock, TypeDir, and TypeFifo returns (0, ErrWriteTooLong) regardless +// of what the Header.Size claims. +func (tw *Writer) Write(b []byte) (int, error) { + if tw.err != nil { + return 0, tw.err } - overwrite := false - if int64(len(b)) > tw.nb { - b = b[0:tw.nb] - overwrite = true + + overwrite := int64(len(b)) > tw.nb + if overwrite { + b = b[:tw.nb] } - n, err = tw.w.Write(b) + n, err := tw.w.Write(b) tw.nb -= int64(n) if err == nil && overwrite { - err = ErrWriteTooLong - return + return n, ErrWriteTooLong // Non-fatal error } tw.err = err - return + return n, err } // Close closes the tar archive, flushing any unwritten // data to the underlying writer. func (tw *Writer) Close() error { - if tw.err != nil || tw.closed { - return tw.err + if tw.err == ErrWriteAfterClose { + return nil } - tw.Flush() - tw.closed = true if tw.err != nil { return tw.err } - // trailer: two zero blocks - for i := 0; i < 2; i++ { - _, tw.err = tw.w.Write(zeroBlock[:]) - if tw.err != nil { - break - } + // Trailer: two zero blocks. + err := tw.Flush() + for i := 0; i < 2 && err == nil; i++ { + _, err = tw.w.Write(zeroBlock[:]) } - return tw.err + + // Ensure all future actions are invalid. + tw.err = ErrWriteAfterClose + return err // Report IO errors } diff --git a/src/archive/tar/writer_test.go b/src/archive/tar/writer_test.go index 9d92ab89a6..9cfc225611 100644 --- a/src/archive/tar/writer_test.go +++ b/src/archive/tar/writer_test.go @@ -576,40 +576,104 @@ func TestValidTypeflagWithPAXHeader(t *testing.T) { } } -func TestWriteHeaderOnly(t *testing.T) { - tw := NewWriter(new(bytes.Buffer)) - hdr := &Header{Name: "dir/", Typeflag: TypeDir} - if err := tw.WriteHeader(hdr); err != nil { - t.Fatalf("WriteHeader() = %v, want nil", err) - } - if _, err := tw.Write([]byte{0x00}); err != ErrWriteTooLong { - t.Fatalf("Write() = %v, want %v", err, ErrWriteTooLong) +// failOnceWriter fails exactly once and then always reports success. +type failOnceWriter bool + +func (w *failOnceWriter) Write(b []byte) (int, error) { + if !*w { + return 0, io.ErrShortWrite } + *w = true + return len(b), nil } -func TestWriteNegativeSize(t *testing.T) { - tw := NewWriter(new(bytes.Buffer)) - hdr := &Header{Name: "small.txt", Size: -1} - if err := tw.WriteHeader(hdr); err != ErrHeader { - t.Fatalf("WriteHeader() = nil, want %v", ErrHeader) - } -} +func TestWriterErrors(t *testing.T) { + t.Run("HeaderOnly", func(t *testing.T) { + tw := NewWriter(new(bytes.Buffer)) + hdr := &Header{Name: "dir/", Typeflag: TypeDir} + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("WriteHeader() = %v, want nil", err) + } + if _, err := tw.Write([]byte{0x00}); err != ErrWriteTooLong { + t.Fatalf("Write() = %v, want %v", err, ErrWriteTooLong) + } + }) -func TestWriteAfterClose(t *testing.T) { - var buffer bytes.Buffer - tw := NewWriter(&buffer) + t.Run("NegativeSize", func(t *testing.T) { + tw := NewWriter(new(bytes.Buffer)) + hdr := &Header{Name: "small.txt", Size: -1} + if err := tw.WriteHeader(hdr); err != ErrHeader { + t.Fatalf("WriteHeader() = nil, want %v", ErrHeader) + } + }) - hdr := &Header{ - Name: "small.txt", - Size: 5, - } - if err := tw.WriteHeader(hdr); err != nil { - t.Fatalf("Failed to write header: %s", err) - } - tw.Close() - if _, err := tw.Write([]byte("Kilts")); err != ErrWriteAfterClose { - t.Fatalf("Write: got %v; want ErrWriteAfterClose", err) - } + t.Run("BeforeHeader", func(t *testing.T) { + tw := NewWriter(new(bytes.Buffer)) + if _, err := tw.Write([]byte("Kilts")); err != ErrWriteTooLong { + t.Fatalf("Write() = %v, want %v", err, ErrWriteTooLong) + } + }) + + t.Run("AfterClose", func(t *testing.T) { + tw := NewWriter(new(bytes.Buffer)) + hdr := &Header{Name: "small.txt"} + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("WriteHeader() = %v, want nil", err) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close() = %v, want nil", err) + } + if _, err := tw.Write([]byte("Kilts")); err != ErrWriteAfterClose { + t.Fatalf("Write() = %v, want %v", err, ErrWriteAfterClose) + } + if err := tw.Flush(); err != ErrWriteAfterClose { + t.Fatalf("Flush() = %v, want %v", err, ErrWriteAfterClose) + } + if err := tw.Close(); err != nil { + t.Fatalf("Close() = %v, want nil", err) + } + }) + + t.Run("PrematureFlush", func(t *testing.T) { + tw := NewWriter(new(bytes.Buffer)) + hdr := &Header{Name: "small.txt", Size: 5} + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("WriteHeader() = %v, want nil", err) + } + if err := tw.Flush(); err == nil { + t.Fatalf("Flush() = %v, want non-nil error", err) + } + }) + + t.Run("PrematureClose", func(t *testing.T) { + tw := NewWriter(new(bytes.Buffer)) + hdr := &Header{Name: "small.txt", Size: 5} + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("WriteHeader() = %v, want nil", err) + } + if err := tw.Close(); err == nil { + t.Fatalf("Close() = %v, want non-nil error", err) + } + }) + + t.Run("Persistence", func(t *testing.T) { + tw := NewWriter(new(failOnceWriter)) + if err := tw.WriteHeader(&Header{}); err != io.ErrShortWrite { + t.Fatalf("WriteHeader() = %v, want %v", err, io.ErrShortWrite) + } + if err := tw.WriteHeader(&Header{Name: "small.txt"}); err == nil { + t.Errorf("WriteHeader() = got %v, want non-nil error", err) + } + if _, err := tw.Write(nil); err == nil { + t.Errorf("Write() = %v, want non-nil error", err) + } + if err := tw.Flush(); err == nil { + t.Errorf("Flush() = %v, want non-nil error", err) + } + if err := tw.Close(); err == nil { + t.Errorf("Close() = %v, want non-nil error", err) + } + }) } func TestSplitUSTARPath(t *testing.T) {