1
0
mirror of https://github.com/golang/go synced 2024-09-30 11:18:33 -06:00

encoding/xml: only initialize nil struct fields when decoding

fieldInfo.value used to initialize nil anonymous struct fields if they
were encountered. This behavior is wanted when decoding, but not when
encoding. When encoding, the value should never be modified, and these
nil fields should be skipped entirely.

To fix the bug, add a bool argument to the function which tells the
code whether we are encoding or decoding.

Finally, add a couple of tests to cover the edge cases pointed out in
the original issue.

Fixes #27240.

Change-Id: Ic97ae4bfe5f2062c8518e03d1dec07c3875e18f6
Reviewed-on: https://go-review.googlesource.com/c/go/+/196809
Run-TryBot: Daniel Martí <mvdan@mvdan.cc>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
This commit is contained in:
Daniel Martí 2019-09-24 18:14:10 +01:00
parent 107ebb1781
commit 8f4151ea67
4 changed files with 50 additions and 15 deletions

View File

@ -482,8 +482,11 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
xmlname := tinfo.xmlname xmlname := tinfo.xmlname
if xmlname.name != "" { if xmlname.name != "" {
start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name start.Name.Space, start.Name.Local = xmlname.xmlns, xmlname.name
} else if v, ok := xmlname.value(val).Interface().(Name); ok && v.Local != "" { } else {
start.Name = v fv := xmlname.value(val, dontInitNilPointers)
if v, ok := fv.Interface().(Name); ok && v.Local != "" {
start.Name = v
}
} }
} }
if start.Name.Local == "" && finfo != nil { if start.Name.Local == "" && finfo != nil {
@ -503,7 +506,7 @@ func (p *printer) marshalValue(val reflect.Value, finfo *fieldInfo, startTemplat
if finfo.flags&fAttr == 0 { if finfo.flags&fAttr == 0 {
continue continue
} }
fv := finfo.value(val) fv := finfo.value(val, dontInitNilPointers)
if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) { if finfo.flags&fOmitEmpty != 0 && isEmptyValue(fv) {
continue continue
@ -806,7 +809,12 @@ func (p *printer) marshalStruct(tinfo *typeInfo, val reflect.Value) error {
if finfo.flags&fAttr != 0 { if finfo.flags&fAttr != 0 {
continue continue
} }
vf := finfo.value(val) vf := finfo.value(val, dontInitNilPointers)
if !vf.IsValid() {
// The field is behind an anonymous struct field that's
// nil. Skip it.
continue
}
switch finfo.flags & fMode { switch finfo.flags & fMode {
case fCDATA, fCharData: case fCDATA, fCharData:

View File

@ -309,6 +309,11 @@ type ChardataEmptyTest struct {
Contents *string `xml:",chardata"` Contents *string `xml:",chardata"`
} }
type PointerAnonFields struct {
*MyInt
*NamedType
}
type MyMarshalerTest struct { type MyMarshalerTest struct {
} }
@ -889,6 +894,18 @@ var marshalTests = []struct {
`</EmbedA>`, `</EmbedA>`,
}, },
// Anonymous struct pointer field which is nil
{
Value: &EmbedB{},
ExpectXML: `<EmbedB><FieldB></FieldB></EmbedB>`,
},
// Other kinds of nil anonymous fields
{
Value: &PointerAnonFields{},
ExpectXML: `<PointerAnonFields></PointerAnonFields>`,
},
// Test that name casing matters // Test that name casing matters
{ {
Value: &NameCasing{Xy: "mixed", XY: "upper", XyA: "mixedA", XYA: "upperA"}, Value: &NameCasing{Xy: "mixed", XY: "upper", XyA: "mixedA", XYA: "upperA"},

View File

@ -435,7 +435,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
} }
return UnmarshalError(e) return UnmarshalError(e)
} }
fv := finfo.value(sv) fv := finfo.value(sv, initNilPointers)
if _, ok := fv.Interface().(Name); ok { if _, ok := fv.Interface().(Name); ok {
fv.Set(reflect.ValueOf(start.Name)) fv.Set(reflect.ValueOf(start.Name))
} }
@ -449,7 +449,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
finfo := &tinfo.fields[i] finfo := &tinfo.fields[i]
switch finfo.flags & fMode { switch finfo.flags & fMode {
case fAttr: case fAttr:
strv := finfo.value(sv) strv := finfo.value(sv, initNilPointers)
if a.Name.Local == finfo.name && (finfo.xmlns == "" || finfo.xmlns == a.Name.Space) { if a.Name.Local == finfo.name && (finfo.xmlns == "" || finfo.xmlns == a.Name.Space) {
if err := d.unmarshalAttr(strv, a); err != nil { if err := d.unmarshalAttr(strv, a); err != nil {
return err return err
@ -465,7 +465,7 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
} }
if !handled && any >= 0 { if !handled && any >= 0 {
finfo := &tinfo.fields[any] finfo := &tinfo.fields[any]
strv := finfo.value(sv) strv := finfo.value(sv, initNilPointers)
if err := d.unmarshalAttr(strv, a); err != nil { if err := d.unmarshalAttr(strv, a); err != nil {
return err return err
} }
@ -478,22 +478,22 @@ func (d *Decoder) unmarshal(val reflect.Value, start *StartElement) error {
switch finfo.flags & fMode { switch finfo.flags & fMode {
case fCDATA, fCharData: case fCDATA, fCharData:
if !saveData.IsValid() { if !saveData.IsValid() {
saveData = finfo.value(sv) saveData = finfo.value(sv, initNilPointers)
} }
case fComment: case fComment:
if !saveComment.IsValid() { if !saveComment.IsValid() {
saveComment = finfo.value(sv) saveComment = finfo.value(sv, initNilPointers)
} }
case fAny, fAny | fElement: case fAny, fAny | fElement:
if !saveAny.IsValid() { if !saveAny.IsValid() {
saveAny = finfo.value(sv) saveAny = finfo.value(sv, initNilPointers)
} }
case fInnerXML: case fInnerXML:
if !saveXML.IsValid() { if !saveXML.IsValid() {
saveXML = finfo.value(sv) saveXML = finfo.value(sv, initNilPointers)
if d.saved == nil { if d.saved == nil {
saveXMLIndex = 0 saveXMLIndex = 0
d.saved = new(bytes.Buffer) d.saved = new(bytes.Buffer)
@ -687,7 +687,7 @@ Loop:
} }
if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local { if len(finfo.parents) == len(parents) && finfo.name == start.Name.Local {
// It's a perfect match, unmarshal the field. // It's a perfect match, unmarshal the field.
return true, d.unmarshal(finfo.value(sv), start) return true, d.unmarshal(finfo.value(sv, initNilPointers), start)
} }
if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local { if len(finfo.parents) > len(parents) && finfo.parents[len(parents)] == start.Name.Local {
// It's a prefix for the field. Break and recurse // It's a prefix for the field. Break and recurse

View File

@ -344,15 +344,25 @@ func (e *TagPathError) Error() string {
return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2) return fmt.Sprintf("%s field %q with tag %q conflicts with field %q with tag %q", e.Struct, e.Field1, e.Tag1, e.Field2, e.Tag2)
} }
const (
initNilPointers = true
dontInitNilPointers = false
)
// value returns v's field value corresponding to finfo. // value returns v's field value corresponding to finfo.
// It's equivalent to v.FieldByIndex(finfo.idx), but initializes // It's equivalent to v.FieldByIndex(finfo.idx), but when passed
// and dereferences pointers as necessary. // initNilPointers, it initializes and dereferences pointers as necessary.
func (finfo *fieldInfo) value(v reflect.Value) reflect.Value { // When passed dontInitNilPointers and a nil pointer is reached, the function
// returns a zero reflect.Value.
func (finfo *fieldInfo) value(v reflect.Value, shouldInitNilPointers bool) reflect.Value {
for i, x := range finfo.idx { for i, x := range finfo.idx {
if i > 0 { if i > 0 {
t := v.Type() t := v.Type()
if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct { if t.Kind() == reflect.Ptr && t.Elem().Kind() == reflect.Struct {
if v.IsNil() { if v.IsNil() {
if !shouldInitNilPointers {
return reflect.Value{}
}
v.Set(reflect.New(v.Type().Elem())) v.Set(reflect.New(v.Type().Elem()))
} }
v = v.Elem() v = v.Elem()