diff --git a/src/encoding/xml/marshal.go b/src/encoding/xml/marshal.go index 05b5542dfb4..d6cd3c267af 100644 --- a/src/encoding/xml/marshal.go +++ b/src/encoding/xml/marshal.go @@ -451,22 +451,25 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat if val.CanInterface() && typ.Implements(marshalerType) { return p.marshalInterface(val.Interface().(Marshaler), defaultStart(typ, finfo, startTemplate)) } + + var pv reflect.Value if val.CanAddr() { - pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(marshalerType) { - return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate)) - } + pv = val.Addr() + } else { + pv = reflect.New(typ) + pv.Elem().Set(val) + } + + if pv.CanInterface() && pv.Type().Implements(marshalerType) { + return p.marshalInterface(pv.Interface().(Marshaler), defaultStart(pv.Type(), finfo, startTemplate)) } // Check for text marshaler. if val.CanInterface() && typ.Implements(textMarshalerType) { return p.marshalTextInterface(val.Interface().(encoding.TextMarshaler), defaultStart(typ, finfo, startTemplate)) } - if val.CanAddr() { - pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { - return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate)) - } + if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { + return p.marshalTextInterface(pv.Interface().(encoding.TextMarshaler), defaultStart(pv.Type(), finfo, startTemplate)) } // Slices and arrays iterate over the elements. They do not have an enclosing tag. @@ -589,18 +592,23 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value) return nil } + var pv reflect.Value if val.CanAddr() { - pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) { - attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name) - if err != nil { - return err - } - if attr.Name.Local != "" { - start.Attr = append(start.Attr, attr) - } - return nil + pv = val.Addr() + } else { + pv = reflect.New(val.Type()) + pv.Elem().Set(val) + } + + if pv.CanInterface() && pv.Type().Implements(marshalerAttrType) { + attr, err := pv.Interface().(MarshalerAttr).MarshalXMLAttr(name) + if err != nil { + return err } + if attr.Name.Local != "" { + start.Attr = append(start.Attr, attr) + } + return nil } if val.CanInterface() && val.Type().Implements(textMarshalerType) { @@ -612,16 +620,13 @@ func (p *printer) marshalAttr(start *StartElement, name Name, val reflect.Value) return nil } - if val.CanAddr() { - pv := val.Addr() - if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { - text, err := pv.Interface().(encoding.TextMarshaler).MarshalText() - if err != nil { - return err - } - start.Attr = append(start.Attr, Attr{name, string(text)}) - return nil + if pv.CanInterface() && pv.Type().Implements(textMarshalerType) { + text, err := pv.Interface().(encoding.TextMarshaler).MarshalText() + if err != nil { + return err } + start.Attr = append(start.Attr, Attr{name, string(text)}) + return nil } // Dereference or skip nil pointer, interface values. diff --git a/src/encoding/xml/marshal_test.go b/src/encoding/xml/marshal_test.go index b8bce7170a6..3f6faaac9ef 100644 --- a/src/encoding/xml/marshal_test.go +++ b/src/encoding/xml/marshal_test.go @@ -6,6 +6,7 @@ package xml import ( "bytes" + "encoding" "errors" "fmt" "io" @@ -2589,3 +2590,109 @@ func TestClose(t *testing.T) { }) } } + +type structWithMarshalXML struct{ V int } + +func (s *structWithMarshalXML) MarshalXML(e *Encoder, _ StartElement) error { + _ = e.EncodeToken(StartElement{Name: Name{Local: "marshalled"}}) + _ = e.EncodeToken(CharData(strconv.Itoa(s.V))) + _ = e.EncodeToken(EndElement{Name: Name{Local: "marshalled"}}) + return nil +} + +var _ = Marshaler(&structWithMarshalXML{}) + +type embedderX struct { + V structWithMarshalXML +} + +func TestMarshalXMLWithPointerXMLMarshalers(t *testing.T) { + for _, test := range []struct { + name string + v interface{} + expected string + }{ + {name: "a value with MarshalXML", v: structWithMarshalXML{V: 1}, expected: `1`}, + {name: "pointer to a value with MarshalXML", v: &structWithMarshalXML{V: 1}, expected: "1"}, + {name: "a struct with a value with MarshalXML", v: embedderX{V: structWithMarshalXML{V: 1}}, expected: "1"}, + {name: "a slice of structs with a value with MarshalXML", v: []embedderX{{V: structWithMarshalXML{V: 1}}}, expected: `1`}, + } { + test := test + t.Run(test.name, func(t *testing.T) { + result, err := Marshal(test.v) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + if string(result) != test.expected { + t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected) + } + }) + } +} + +type structWithMarshalText struct{ V int } + +func (s *structWithMarshalText) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("marshalled(%d)", s.V)), nil +} + +var _ = encoding.TextMarshaler(&structWithMarshalText{}) + +type embedderT struct { + V structWithMarshalText +} + +func TestMarshalXMLWithPointerTextMarshalers(t *testing.T) { + for _, test := range []struct { + name string + v interface{} + expected string + }{ + {name: "a value with MarshalText", v: structWithMarshalText{V: 1}, expected: "marshalled(1)"}, + {name: "pointer to a value with MarshalText", v: &structWithMarshalText{V: 1}, expected: "marshalled(1)"}, + {name: "a struct with a value with MarshalText", v: embedderT{V: structWithMarshalText{V: 1}}, expected: "marshalled(1)"}, + {name: "a slice of structs with a value with MarshalText", v: []embedderT{{V: structWithMarshalText{V: 1}}}, expected: "marshalled(1)"}, + } { + test := test + t.Run(test.name, func(t *testing.T) { + result, err := Marshal(test.v) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + if string(result) != test.expected { + t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected) + } + }) + } +} + +type structWithMarshalXMLAttr struct{ v int } + +func (s *structWithMarshalXMLAttr) MarshalXMLAttr(name Name) (Attr, error) { + return Attr{Name: Name{Local: "marshalled"}, Value: strconv.Itoa(s.v)}, nil +} + +var _ = MarshalerAttr(&structWithMarshalXMLAttr{}) + +type embedderAT struct { + X structWithMarshalXMLAttr `xml:"X,attr"` + T structWithMarshalText `xml:"T,attr"` + XP *structWithMarshalXMLAttr `xml:"XP,attr"` + XT *structWithMarshalText `xml:"XT,attr"` +} + +func TestMarshalXMLWithPointerAttrMarshalers(t *testing.T) { + result, err := Marshal(embedderAT{ + X: structWithMarshalXMLAttr{1}, + T: structWithMarshalText{2}, + XP: &structWithMarshalXMLAttr{3}, + XT: &structWithMarshalText{4}, + }) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + expected := `` + if string(result) != expected { + t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, expected) + } +}