1
0
mirror of https://github.com/golang/go synced 2024-11-25 13:27:57 -07:00

encoding/xml: call MarshalXML(), MarshalXMLAttr(), and MarshalText() defined with pointer receivers even for non-addressable values of non-pointer types on marshalling XML

This commit is contained in:
Dmitry Zenovich 2024-08-17 07:25:13 +03:00
parent f0d880e25c
commit 928e3d925d
2 changed files with 140 additions and 28 deletions

View File

@ -451,22 +451,25 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
if val.CanInterface() && typ.Implements(marshalerType) { if val.CanInterface() && typ.Implements(marshalerType) {
return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate)) return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate))
} }
var pv reflect.Value
if val.CanAddr() { if val.CanAddr() {
pv := val.Addr() pv = val.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerType) { } else {
return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate)) pv = reflect.New(typ)
} pv.Elem().Set(val)
}
if pv.CanInterface() && pv.Type().Implements(marshalerType) {
return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
} }
// Check for text marshaler. // Check for text marshaler.
if val.CanInterface() && typ.Implements(textMarshalerType) { if val.CanInterface() && typ.Implements(textMarshalerType) {
return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate)) return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate))
} }
if val.CanAddr() { if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
pv := val.Addr() return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
}
} }
// Slices and arrays iterate over the elements. They do not have an enclosing tag. // Slices and arrays iterate over the elements. They do not have an enclosing tag.
@ -589,18 +592,23 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
return nil return nil
} }
var pv reflect.Value
if val.CanAddr() { if val.CanAddr() {
pv := val.Addr() pv = val.Addr()
if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) { } else {
attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name) pv = reflect.New(val.Type())
if err != nil { pv.Elem().Set(val)
return err }
}
if attr.Name.Local != "" { if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
start.Attr = append(start.Attr, attr) attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
} if err != nil {
return nil return err
} }
if attr.Name.Local != "" {
start.Attr = append(start.Attr, attr)
}
return nil
} }
if val.CanInterface() && val.Type().Implements(textMarshalerType) { if val.CanInterface() && val.Type().Implements(textMarshalerType) {
@ -612,16 +620,13 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
return nil return nil
} }
if val.CanAddr() { if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
pv := val.Addr() text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { if err != nil {
text, err := pv.Interface().(encoding.TextMarshaler).MarshalText() return err
if err != nil {
return err
}
start.Attr = append(start.Attr, Attr{name, string(text)})
return nil
} }
start.Attr = append(start.Attr, Attr{name, string(text)})
return nil
} }
// Dereference or skip nil pointer, interface values. // Dereference or skip nil pointer, interface values.

View File

@ -6,6 +6,7 @@ package xml
import ( import (
"bytes" "bytes"
"encoding"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -2589,3 +2590,109 @@ func TestClose(t *testing.T) {
}) })
} }
} }
type structWithMarshalXML struct{ V int }
func (s *structWithMarshalXML) MarshalXML(e *Encoder, _ StartElement) error {
_ = e.EncodeToken(StartElement{Name: Name{Local: "marshalled"}})
_ = e.EncodeToken(CharData(strconv.Itoa(s.V)))
_ = e.EncodeToken(EndElement{Name: Name{Local: "marshalled"}})
return nil
}
var _ = Marshaler(&structWithMarshalXML{})
type embedderX struct {
V structWithMarshalXML
}
func TestMarshalXMLWithPointerXMLMarshalers(t *testing.T) {
for _, test := range []struct {
name string
v interface{}
expected string
}{
{name: "a value with MarshalXML", v: structWithMarshalXML{V: 1}, expected: `<marshalled>1</marshalled>`},
{name: "pointer to a value with MarshalXML", v: &structWithMarshalXML{V: 1}, expected: "<marshalled>1</marshalled>"},
{name: "a struct with a value with MarshalXML", v: embedderX{V: structWithMarshalXML{V: 1}}, expected: "<embedderX><marshalled>1</marshalled></embedderX>"},
{name: "a slice of structs with a value with MarshalXML", v: []embedderX{{V: structWithMarshalXML{V: 1}}}, expected: `<embedderX><marshalled>1</marshalled></embedderX>`},
} {
test := test
t.Run(test.name, func(t *testing.T) {
result, err := Marshal(test.v)
if err != nil {
t.Fatalf("Marshal error: %v", err)
}
if string(result) != test.expected {
t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected)
}
})
}
}
type structWithMarshalText struct{ V int }
func (s *structWithMarshalText) MarshalText() ([]byte, error) {
return []byte(fmt.Sprintf("marshalled(%d)", s.V)), nil
}
var _ = encoding.TextMarshaler(&structWithMarshalText{})
type embedderT struct {
V structWithMarshalText
}
func TestMarshalXMLWithPointerTextMarshalers(t *testing.T) {
for _, test := range []struct {
name string
v interface{}
expected string
}{
{name: "a value with MarshalText", v: structWithMarshalText{V: 1}, expected: "<structWithMarshalText>marshalled(1)</structWithMarshalText>"},
{name: "pointer to a value with MarshalText", v: &structWithMarshalText{V: 1}, expected: "<structWithMarshalText>marshalled(1)</structWithMarshalText>"},
{name: "a struct with a value with MarshalText", v: embedderT{V: structWithMarshalText{V: 1}}, expected: "<embedderT><V>marshalled(1)</V></embedderT>"},
{name: "a slice of structs with a value with MarshalText", v: []embedderT{{V: structWithMarshalText{V: 1}}}, expected: "<embedderT><V>marshalled(1)</V></embedderT>"},
} {
test := test
t.Run(test.name, func(t *testing.T) {
result, err := Marshal(test.v)
if err != nil {
t.Fatalf("Marshal error: %v", err)
}
if string(result) != test.expected {
t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected)
}
})
}
}
type structWithMarshalXMLAttr struct{ v int }
func (s *structWithMarshalXMLAttr) MarshalXMLAttr(name Name) (Attr, error) {
return Attr{Name: Name{Local: "marshalled"}, Value: strconv.Itoa(s.v)}, nil
}
var _ = MarshalerAttr(&structWithMarshalXMLAttr{})
type embedderAT struct {
X structWithMarshalXMLAttr `xml:"X,attr"`
T structWithMarshalText `xml:"T,attr"`
XP *structWithMarshalXMLAttr `xml:"XP,attr"`
XT *structWithMarshalText `xml:"XT,attr"`
}
func TestMarshalXMLWithPointerAttrMarshalers(t *testing.T) {
result, err := Marshal(embedderAT{
X: structWithMarshalXMLAttr{1},
T: structWithMarshalText{2},
XP: &structWithMarshalXMLAttr{3},
XT: &structWithMarshalText{4},
})
if err != nil {
t.Fatalf("Marshal error: %v", err)
}
expected := `<embedderAT marshalled="1" T="marshalled(2)" marshalled="3" XT="marshalled(4)"></embedderAT>`
if string(result) != expected {
t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, expected)
}
}