diff --git a/src/pkg/encoding/xml/read.go b/src/pkg/encoding/xml/read.go index f960f5649ca..698bf1a22ef 100644 --- a/src/pkg/encoding/xml/read.go +++ b/src/pkg/encoding/xml/read.go @@ -7,6 +7,7 @@ package xml import ( "bytes" "errors" + "fmt" "reflect" "strconv" "strings" @@ -137,6 +138,100 @@ type UnmarshalError string func (e UnmarshalError) Error() string { return string(e) } +// Unmarshaler is the interface implemented by objects that can unmarshal +// an XML element description of themselves. +// +// UnmarshalXML decodes a single XML element +// beginning with the given start element. +// If it returns an error, the outer call to Unmarshal stops and +// returns that error. +// UnmarshalXML must consume exactly one XML element. +// One common implementation strategy is to unmarshal into +// a separate value with a layout matching the expected XML +// using d.DecodeElement, and then to copy the data from +// that value into the receiver. +// Another common strategy is to use d.Token to process the +// XML object one token at a time. +// UnmarshalXML may not use d.RawToken. +type Unmarshaler interface { + UnmarshalXML(d *Decoder, start StartElement) error +} + +// UnmarshalerAttr is the interface implemented by objects that can unmarshal +// an XML attribute description of themselves. +// +// UnmarshalXMLAttr decodes a single XML attribute. +// If it returns an error, the outer call to Unmarshal stops and +// returns that error. +// UnmarshalXMLAttr is used only for struct fields with the +// "attr" option in the field tag. +type UnmarshalerAttr interface { + UnmarshalXMLAttr(attr Attr) error +} + +// receiverType returns the receiver type to use in an expression like "%s.MethodName". +func receiverType(val interface{}) string { + t := reflect.TypeOf(val) + if t.Name() != "" { + return t.String() + } + return "(" + t.String() + ")" +} + +// unmarshalInterface unmarshals a single XML element into val, +// which is known to implement Unmarshaler. +// start is the opening tag of the element. +func (p *Decoder) unmarshalInterface(val Unmarshaler, start *StartElement) error { + // Record that decoder must stop at end tag corresponding to start. + p.pushEOF() + + p.unmarshalDepth++ + err := val.UnmarshalXML(p, *start) + p.unmarshalDepth-- + if err != nil { + p.popEOF() + return err + } + + if !p.popEOF() { + return fmt.Errorf("xml: %s.UnmarshalXML did not consume entire <%s> element", receiverType(val), start.Name.Local) + } + + return nil +} + +// unmarshalAttr unmarshals a single XML attribute into val. +func (p *Decoder) unmarshalAttr(val reflect.Value, attr Attr) error { + if val.Kind() == reflect.Ptr { + if val.IsNil() { + val.Set(reflect.New(val.Type().Elem())) + } + val = val.Elem() + } + + if val.CanInterface() && val.Type().Implements(unmarshalerAttrType) { + // This is an unmarshaler with a non-pointer receiver, + // so it's likely to be incorrect, but we do what we're told. + return val.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr) + } + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(unmarshalerAttrType) { + return pv.Interface().(UnmarshalerAttr).UnmarshalXMLAttr(attr) + } + } + + // TODO: Check for and use encoding.TextUnmarshaler. + + copyValue(val, []byte(attr.Value)) + return nil +} + +var ( + unmarshalerType = reflect.TypeOf((*Unmarshaler)(nil)).Elem() + unmarshalerAttrType = reflect.TypeOf((*UnmarshalerAttr)(nil)).Elem() +) + // Unmarshal a single XML element into val. func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { // Find start element if we need it. @@ -153,13 +248,28 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { } } - if pv := val; pv.Kind() == reflect.Ptr { - if pv.IsNil() { - pv.Set(reflect.New(pv.Type().Elem())) + if val.Kind() == reflect.Ptr { + if val.IsNil() { + val.Set(reflect.New(val.Type().Elem())) } - val = pv.Elem() + val = val.Elem() } + if val.CanInterface() && val.Type().Implements(unmarshalerType) { + // This is an unmarshaler with a non-pointer receiver, + // so it's likely to be incorrect, but we do what we're told. + return p.unmarshalInterface(val.Interface().(Unmarshaler), start) + } + + if val.CanAddr() { + pv := val.Addr() + if pv.CanInterface() && pv.Type().Implements(unmarshalerType) { + return p.unmarshalInterface(pv.Interface().(Unmarshaler), start) + } + } + + // TODO: Check for and use encoding.TextUnmarshaler. + var ( data []byte saveData reflect.Value @@ -264,7 +374,9 @@ func (p *Decoder) unmarshal(val reflect.Value, start *StartElement) error { // Look for attribute. for _, a := range start.Attr { if a.Name.Local == finfo.name && (finfo.xmlns == "" || finfo.xmlns == a.Name.Space) { - copyValue(strv, []byte(a.Value)) + if err := p.unmarshalAttr(strv, a); err != nil { + return err + } break } } diff --git a/src/pkg/encoding/xml/read_test.go b/src/pkg/encoding/xml/read_test.go index 7d28c5d7d6c..1404c900f50 100644 --- a/src/pkg/encoding/xml/read_test.go +++ b/src/pkg/encoding/xml/read_test.go @@ -5,6 +5,7 @@ package xml import ( + "io" "reflect" "strings" "testing" @@ -621,3 +622,66 @@ func TestMarshalNSAttr(t *testing.T) { t.Errorf("Unmarshal = %q, want %q", dst, src) } } + +type MyCharData struct { + body string +} + +func (m *MyCharData) UnmarshalXML(d *Decoder, start StartElement) error { + for { + t, err := d.Token() + if err == io.EOF { // found end of element + break + } + if err != nil { + return err + } + if char, ok := t.(CharData); ok { + m.body += string(char) + } + } + return nil +} + +var _ Unmarshaler = (*MyCharData)(nil) + +func (m *MyCharData) UnmarshalXMLAttr(attr Attr) error { + panic("must not call") +} + +type MyAttr struct { + attr string +} + +func (m *MyAttr) UnmarshalXMLAttr(attr Attr) error { + m.attr = attr.Value + return nil +} + +var _ UnmarshalerAttr = (*MyAttr)(nil) + +type MyStruct struct { + Data *MyCharData + Attr *MyAttr `xml:",attr"` + + Data2 MyCharData + Attr2 MyAttr `xml:",attr"` +} + +func TestUnmarshaler(t *testing.T) { + xml := ` + + hello world + howdy world + + ` + + var m MyStruct + if err := Unmarshal([]byte(xml), &m); err != nil { + t.Fatal(err) + } + + if m.Data == nil || m.Attr == nil || m.Data.body != "hello world" || m.Attr.attr != "attr1" || m.Data2.body != "howdy world" || m.Attr2.attr != "attr2" { + t.Errorf("m=%#+v\n", m) + } +} diff --git a/src/pkg/encoding/xml/xml.go b/src/pkg/encoding/xml/xml.go index 2f36604797d..da8eb2e5f9d 100644 --- a/src/pkg/encoding/xml/xml.go +++ b/src/pkg/encoding/xml/xml.go @@ -16,6 +16,7 @@ package xml import ( "bufio" "bytes" + "errors" "fmt" "io" "strconv" @@ -174,18 +175,19 @@ type Decoder struct { // the attribute xmlns="DefaultSpace". DefaultSpace string - r io.ByteReader - buf bytes.Buffer - saved *bytes.Buffer - stk *stack - free *stack - needClose bool - toClose Name - nextToken Token - nextByte int - ns map[string]string - err error - line int + r io.ByteReader + buf bytes.Buffer + saved *bytes.Buffer + stk *stack + free *stack + needClose bool + toClose Name + nextToken Token + nextByte int + ns map[string]string + err error + line int + unmarshalDepth int } // NewDecoder creates a new XML parser reading from r. @@ -223,10 +225,14 @@ func NewDecoder(r io.Reader) *Decoder { // If Token encounters an unrecognized name space prefix, // it uses the prefix as the Space rather than report an error. func (d *Decoder) Token() (t Token, err error) { + if d.stk != nil && d.stk.kind == stkEOF { + err = io.EOF + return + } if d.nextToken != nil { t = d.nextToken d.nextToken = nil - } else if t, err = d.RawToken(); err != nil { + } else if t, err = d.rawToken(); err != nil { return } @@ -322,6 +328,7 @@ type stack struct { const ( stkStart = iota stkNs + stkEOF ) func (d *Decoder) push(kind int) *stack { @@ -347,6 +354,43 @@ func (d *Decoder) pop() *stack { return s } +// Record that after the current element is finished +// (that element is already pushed on the stack) +// Token should return EOF until popEOF is called. +func (d *Decoder) pushEOF() { + // Walk down stack to find Start. + // It might not be the top, because there might be stkNs + // entries above it. + start := d.stk + for start.kind != stkStart { + start = start.next + } + // The stkNs entries below a start are associated with that + // element too; skip over them. + for start.next != nil && start.next.kind == stkNs { + start = start.next + } + s := d.free + if s != nil { + d.free = s.next + } else { + s = new(stack) + } + s.kind = stkEOF + s.next = start.next + start.next = s +} + +// Undo a pushEOF. +// The element must have been finished, so the EOF should be at the top of the stack. +func (d *Decoder) popEOF() bool { + if d.stk == nil || d.stk.kind != stkEOF { + return false + } + d.pop() + return true +} + // Record that we are starting an element with the given name. func (d *Decoder) pushElement(name Name) { s := d.push(stkStart) @@ -395,9 +439,9 @@ func (d *Decoder) popElement(t *EndElement) bool { return false } - // Pop stack until a Start is on the top, undoing the + // Pop stack until a Start or EOF is on the top, undoing the // translations that were associated with the element we just closed. - for d.stk != nil && d.stk.kind != stkStart { + for d.stk != nil && d.stk.kind != stkStart && d.stk.kind != stkEOF { s := d.pop() if s.ok { d.ns[s.name.Local] = s.name.Space @@ -429,10 +473,19 @@ func (d *Decoder) autoClose(t Token) (Token, bool) { return nil, false } +var errRawToken = errors.New("xml: cannot use RawToken from UnmarshalXML method") + // RawToken is like Token but does not verify that // start and end elements match and does not translate // name space prefixes to their corresponding URLs. func (d *Decoder) RawToken() (Token, error) { + if d.unmarshalDepth > 0 { + return nil, errRawToken + } + return d.rawToken() +} + +func (d *Decoder) rawToken() (Token, error) { if d.err != nil { return nil, d.err } @@ -484,8 +537,7 @@ func (d *Decoder) RawToken() (Token, error) { case '?': // ': + esc = esc_gt + case '\t': + esc = esc_tab + case '\n': + esc = esc_nl + case '\r': + esc = esc_cr + default: + if !isInCharacterRange(r) || (r == 0xFFFD && width == 1) { + esc = esc_fffd + break + } + continue + } + p.WriteString(s[last : i-width]) + p.Write(esc) + last = i + } + p.WriteString(s[last:]) +} + // Escape is like EscapeText but omits the error return value. // It is provided for backwards compatibility with Go 1.0. // Code targeting Go 1.1 or later should use EscapeText.