diff --git a/src/encoding/xml/marshal.go b/src/encoding/xml/marshal.go
index 05b5542dfb4..d6cd3c267af 100644
--- a/src/encoding/xml/marshal.go
+++ b/src/encoding/xml/marshal.go
@@ -451,22 +451,25 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
if val.CanInterface() && typ.Implements(marshalerType) {
return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate))
}
+
+ var pv reflect.Value
if val.CanAddr() {
- pv := val.Addr()
- if pv.CanInterface() && pv.Type().Implements(marshalerType) {
- return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
- }
+ pv = val.Addr()
+ } else {
+ pv = reflect.New(typ)
+ pv.Elem().Set(val)
+ }
+
+ if pv.CanInterface() && pv.Type().Implements(marshalerType) {
+ return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate))
}
// Check for text marshaler.
if val.CanInterface() && typ.Implements(textMarshalerType) {
return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate))
}
- if val.CanAddr() {
- pv := val.Addr()
- if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
- return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
- }
+ if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
+ return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate))
}
// Slices and arrays iterate over the elements. They do not have an enclosing tag.
@@ -589,18 +592,23 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
return nil
}
+ var pv reflect.Value
if val.CanAddr() {
- pv := val.Addr()
- if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
- attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
- if err != nil {
- return err
- }
- if attr.Name.Local != "" {
- start.Attr = append(start.Attr, attr)
- }
- return nil
+ pv = val.Addr()
+ } else {
+ pv = reflect.New(val.Type())
+ pv.Elem().Set(val)
+ }
+
+ if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) {
+ attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name)
+ if err != nil {
+ return err
}
+ if attr.Name.Local != "" {
+ start.Attr = append(start.Attr, attr)
+ }
+ return nil
}
if val.CanInterface() && val.Type().Implements(textMarshalerType) {
@@ -612,16 +620,13 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value)
return nil
}
- if val.CanAddr() {
- pv := val.Addr()
- if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
- text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
- if err != nil {
- return err
- }
- start.Attr = append(start.Attr, Attr{name, string(text)})
- return nil
+ if pv.CanInterface() && pv.Type().Implements(textMarshalerType) {
+ text, err := pv.Interface().(encoding.TextMarshaler).MarshalText()
+ if err != nil {
+ return err
}
+ start.Attr = append(start.Attr, Attr{name, string(text)})
+ return nil
}
// Dereference or skip nil pointer, interface values.
diff --git a/src/encoding/xml/marshal_test.go b/src/encoding/xml/marshal_test.go
index b8bce7170a6..3f6faaac9ef 100644
--- a/src/encoding/xml/marshal_test.go
+++ b/src/encoding/xml/marshal_test.go
@@ -6,6 +6,7 @@ package xml
import (
"bytes"
+ "encoding"
"errors"
"fmt"
"io"
@@ -2589,3 +2590,109 @@ func TestClose(t *testing.T) {
})
}
}
+
+type structWithMarshalXML struct{ V int }
+
+func (s *structWithMarshalXML) MarshalXML(e *Encoder, _ StartElement) error {
+ _ = e.EncodeToken(StartElement{Name: Name{Local: "marshalled"}})
+ _ = e.EncodeToken(CharData(strconv.Itoa(s.V)))
+ _ = e.EncodeToken(EndElement{Name: Name{Local: "marshalled"}})
+ return nil
+}
+
+var _ = Marshaler(&structWithMarshalXML{})
+
+type embedderX struct {
+ V structWithMarshalXML
+}
+
+func TestMarshalXMLWithPointerXMLMarshalers(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ v interface{}
+ expected string
+ }{
+ {name: "a value with MarshalXML", v: structWithMarshalXML{V: 1}, expected: `1`},
+ {name: "pointer to a value with MarshalXML", v: &structWithMarshalXML{V: 1}, expected: "1"},
+ {name: "a struct with a value with MarshalXML", v: embedderX{V: structWithMarshalXML{V: 1}}, expected: "1"},
+ {name: "a slice of structs with a value with MarshalXML", v: []embedderX{{V: structWithMarshalXML{V: 1}}}, expected: `1`},
+ } {
+ test := test
+ t.Run(test.name, func(t *testing.T) {
+ result, err := Marshal(test.v)
+ if err != nil {
+ t.Fatalf("Marshal error: %v", err)
+ }
+ if string(result) != test.expected {
+ t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected)
+ }
+ })
+ }
+}
+
+type structWithMarshalText struct{ V int }
+
+func (s *structWithMarshalText) MarshalText() ([]byte, error) {
+ return []byte(fmt.Sprintf("marshalled(%d)", s.V)), nil
+}
+
+var _ = encoding.TextMarshaler(&structWithMarshalText{})
+
+type embedderT struct {
+ V structWithMarshalText
+}
+
+func TestMarshalXMLWithPointerTextMarshalers(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ v interface{}
+ expected string
+ }{
+ {name: "a value with MarshalText", v: structWithMarshalText{V: 1}, expected: "marshalled(1)"},
+ {name: "pointer to a value with MarshalText", v: &structWithMarshalText{V: 1}, expected: "marshalled(1)"},
+ {name: "a struct with a value with MarshalText", v: embedderT{V: structWithMarshalText{V: 1}}, expected: "marshalled(1)"},
+ {name: "a slice of structs with a value with MarshalText", v: []embedderT{{V: structWithMarshalText{V: 1}}}, expected: "marshalled(1)"},
+ } {
+ test := test
+ t.Run(test.name, func(t *testing.T) {
+ result, err := Marshal(test.v)
+ if err != nil {
+ t.Fatalf("Marshal error: %v", err)
+ }
+ if string(result) != test.expected {
+ t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected)
+ }
+ })
+ }
+}
+
+type structWithMarshalXMLAttr struct{ v int }
+
+func (s *structWithMarshalXMLAttr) MarshalXMLAttr(name Name) (Attr, error) {
+ return Attr{Name: Name{Local: "marshalled"}, Value: strconv.Itoa(s.v)}, nil
+}
+
+var _ = MarshalerAttr(&structWithMarshalXMLAttr{})
+
+type embedderAT struct {
+ X structWithMarshalXMLAttr `xml:"X,attr"`
+ T structWithMarshalText `xml:"T,attr"`
+ XP *structWithMarshalXMLAttr `xml:"XP,attr"`
+ XT *structWithMarshalText `xml:"XT,attr"`
+}
+
+func TestMarshalXMLWithPointerAttrMarshalers(t *testing.T) {
+ result, err := Marshal(embedderAT{
+ X: structWithMarshalXMLAttr{1},
+ T: structWithMarshalText{2},
+ XP: &structWithMarshalXMLAttr{3},
+ XT: &structWithMarshalText{4},
+ })
+ if err != nil {
+ t.Fatalf("Marshal error: %v", err)
+ }
+ expected := ``
+ if string(result) != expected {
+ t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, expected)
+ }
+}