mirror of
https://github.com/golang/go
synced 2024-11-20 09:44:45 -07:00
encoding/csv, encoding/xml: report write errors
Fixes #3773. R=bradfitz, rsc CC=golang-dev https://golang.org/cl/6327053
This commit is contained in:
parent
689931c5b0
commit
32a0cbb881
@ -101,11 +101,10 @@ func (w *Writer) WriteAll(records [][]string) (err error) {
|
||||
for _, record := range records {
|
||||
err = w.Write(record)
|
||||
if err != nil {
|
||||
break
|
||||
return err
|
||||
}
|
||||
}
|
||||
w.Flush()
|
||||
return nil
|
||||
return w.w.Flush()
|
||||
}
|
||||
|
||||
// fieldNeedsQuotes returns true if our field must be enclosed in quotes.
|
||||
|
@ -83,9 +83,7 @@ func MarshalIndent(v interface{}, prefix, indent string) ([]byte, error) {
|
||||
enc := NewEncoder(&b)
|
||||
enc.prefix = prefix
|
||||
enc.indent = indent
|
||||
err := enc.marshalValue(reflect.ValueOf(v), nil)
|
||||
enc.Flush()
|
||||
if err != nil {
|
||||
if err := enc.Encode(v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b.Bytes(), nil
|
||||
@ -107,8 +105,10 @@ func NewEncoder(w io.Writer) *Encoder {
|
||||
// of Go values to XML.
|
||||
func (enc *Encoder) Encode(v interface{}) error {
|
||||
err := enc.marshalValue(reflect.ValueOf(v), nil)
|
||||
enc.Flush()
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return enc.Flush()
|
||||
}
|
||||
|
||||
type printer struct {
|
||||
@ -224,7 +224,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo) error {
|
||||
p.WriteString(name)
|
||||
p.WriteByte('>')
|
||||
|
||||
return nil
|
||||
return p.cachedWriteError()
|
||||
}
|
||||
|
||||
var timeType = reflect.TypeOf(time.Time{})
|
||||
@ -260,15 +260,15 @@ func (p *printer) marshalSimple(typ reflect.Type, val reflect.Value) error {
|
||||
default:
|
||||
return &UnsupportedTypeError{typ}
|
||||
}
|
||||
return nil
|
||||
return p.cachedWriteError()
|
||||
}
|
||||
|
||||
var ddBytes = []byte("--")
|
||||
|
||||
func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
|
||||
if val.Type() == timeType {
|
||||
p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano))
|
||||
return nil
|
||||
_, err := p.WriteString(val.Interface().(time.Time).Format(time.RFC3339Nano))
|
||||
return err
|
||||
}
|
||||
s := parentStack{printer: p}
|
||||
for i := range tinfo.fields {
|
||||
@ -353,7 +353,13 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
|
||||
}
|
||||
}
|
||||
s.trim(nil)
|
||||
return nil
|
||||
return p.cachedWriteError()
|
||||
}
|
||||
|
||||
// return the bufio Writer's cached write error
|
||||
func (p *printer) cachedWriteError() error {
|
||||
_, err := p.Write(nil)
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *printer) writeIndent(depthDelta int) {
|
||||
|
@ -5,6 +5,9 @@
|
||||
package xml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
@ -779,6 +782,55 @@ func TestUnmarshal(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type limitedBytesWriter struct {
|
||||
w io.Writer
|
||||
remain int // until writes fail
|
||||
}
|
||||
|
||||
func (lw *limitedBytesWriter) Write(p []byte) (n int, err error) {
|
||||
if lw.remain <= 0 {
|
||||
println("error")
|
||||
return 0, errors.New("write limit hit")
|
||||
}
|
||||
if len(p) > lw.remain {
|
||||
p = p[:lw.remain]
|
||||
n, _ = lw.w.Write(p)
|
||||
lw.remain = 0
|
||||
return n, errors.New("write limit hit")
|
||||
}
|
||||
n, err = lw.w.Write(p)
|
||||
lw.remain -= n
|
||||
return n, err
|
||||
}
|
||||
|
||||
func TestMarshalWriteErrors(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
const writeCap = 1024
|
||||
w := &limitedBytesWriter{&buf, writeCap}
|
||||
enc := NewEncoder(w)
|
||||
var err error
|
||||
var i int
|
||||
const n = 4000
|
||||
for i = 1; i <= n; i++ {
|
||||
err = enc.Encode(&Passenger{
|
||||
Name: []string{"Alice", "Bob"},
|
||||
Weight: 5,
|
||||
})
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err == nil {
|
||||
t.Error("expected an error")
|
||||
}
|
||||
if i == n {
|
||||
t.Errorf("expected to fail before the end")
|
||||
}
|
||||
if buf.Len() != writeCap {
|
||||
t.Errorf("buf.Len() = %d; want %d", buf.Len(), writeCap)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkMarshal(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
Marshal(atomValue)
|
||||
|
Loading…
Reference in New Issue
Block a user