From f0d880e25cbe4497f7aeb368b1d35d721e586d7d Mon Sep 17 00:00:00 2001 From: Dmitry Zenovich Date: Sat, 17 Aug 2024 07:23:21 +0300 Subject: [PATCH] encoding/json: call MarshalJSON() and MarshalText() defined with pointer receivers even for non-addressable values of non-pointer types on marshalling JSON --- src/encoding/json/encode.go | 48 ++++++++----------- src/encoding/json/encode_test.go | 80 ++++++++++++++++++++++++++++++++ 2 files changed, 100 insertions(+), 28 deletions(-) diff --git a/src/encoding/json/encode.go b/src/encoding/json/encode.go index 988de716124..6b8b9134fa9 100644 --- a/src/encoding/json/encode.go +++ b/src/encoding/json/encode.go @@ -363,7 +363,7 @@ func typeEncoder(t reflect.Type) encoderFunc { } // Compute the real encoder and replace the indirect func with it. - f = newTypeEncoder(t, true) + f = newTypeEncoder(t) wg.Done() encoderCache.Store(t, f) return f @@ -375,20 +375,19 @@ var ( ) // newTypeEncoder constructs an encoderFunc for a type. -// The returned encoder only checks CanAddr when allowAddr is true. -func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { +func newTypeEncoder(t reflect.Type) encoderFunc { // If we have a non-pointer value whose type implements // Marshaler with a value receiver, then we're better off taking // the address of the value - otherwise we end up with an // allocation as we cast the value to an interface. - if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) { - return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) + if t.Kind() != reflect.Pointer && reflect.PointerTo(t).Implements(marshalerType) { + return addrMarshalerEncoder } if t.Implements(marshalerType) { return marshalerEncoder } - if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) { - return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false)) + if t.Kind() != reflect.Pointer && reflect.PointerTo(t).Implements(textMarshalerType) { + return addrTextMarshalerEncoder } if t.Implements(textMarshalerType) { return textMarshalerEncoder @@ -451,7 +450,13 @@ func marshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { } func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { - va := v.Addr() + var va reflect.Value + if v.CanAddr() { + va = v.Addr() + } else { + va = reflect.New(v.Type()) + va.Elem().Set(v) + } if va.IsNil() { e.WriteString("null") return @@ -487,7 +492,13 @@ func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { } func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { - va := v.Addr() + var va reflect.Value + if v.CanAddr() { + va = v.Addr() + } else { + va = reflect.New(v.Type()) + va.Elem().Set(v) + } if va.IsNil() { e.WriteString("null") return @@ -893,25 +904,6 @@ func newPtrEncoder(t reflect.Type) encoderFunc { return enc.encode } -type condAddrEncoder struct { - canAddrEnc, elseEnc encoderFunc -} - -func (ce condAddrEncoder) encode(e *encodeState, v reflect.Value, opts encOpts) { - if v.CanAddr() { - ce.canAddrEnc(e, v, opts) - } else { - ce.elseEnc(e, v, opts) - } -} - -// newCondAddrEncoder returns an encoder that checks whether its value -// CanAddr and delegates to canAddrEnc if so, else to elseEnc. -func newCondAddrEncoder(canAddrEnc, elseEnc encoderFunc) encoderFunc { - enc := condAddrEncoder{canAddrEnc: canAddrEnc, elseEnc: elseEnc} - return enc.encode -} - func isValidTag(s string) bool { if s == "" { return false diff --git a/src/encoding/json/encode_test.go b/src/encoding/json/encode_test.go index 23a14d0b172..ac12880109e 100644 --- a/src/encoding/json/encode_test.go +++ b/src/encoding/json/encode_test.go @@ -1219,3 +1219,83 @@ func TestIssue63379(t *testing.T) { } } } + +type structWithMarshalJSON struct{ v int } + +func (s *structWithMarshalJSON) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`"marshalled(%d)"`, s.v)), nil +} + +var _ = Marshaler(&structWithMarshalJSON{}) + +type embedderJ struct { + V structWithMarshalJSON +} + +func TestMarshalJSONWithPointerJSONMarshalers(t *testing.T) { + for _, test := range []struct { + name string + v interface{} + expected string + }{ + {name: "a value with MarshalJSON", v: structWithMarshalJSON{v: 1}, expected: `"marshalled(1)"`}, + {name: "pointer to a value with MarshalJSON", v: &structWithMarshalJSON{v: 1}, expected: `"marshalled(1)"`}, + {name: "a map with a value with MarshalJSON", v: map[string]interface{}{"v": structWithMarshalJSON{v: 1}}, expected: `{"v":"marshalled(1)"}`}, + {name: "a map with a pointer to a value with MarshalJSON", v: map[string]interface{}{"v": &structWithMarshalJSON{v: 1}}, expected: `{"v":"marshalled(1)"}`}, + {name: "a slice of maps with a value with MarshalJSON", v: []map[string]interface{}{{"v": structWithMarshalJSON{v: 1}}}, expected: `[{"v":"marshalled(1)"}]`}, + {name: "a slice of maps with a pointer to a value with MarshalJSON", v: []map[string]interface{}{{"v": &structWithMarshalJSON{v: 1}}}, expected: `[{"v":"marshalled(1)"}]`}, + {name: "a struct with a value with MarshalJSON", v: embedderJ{V: structWithMarshalJSON{v: 1}}, expected: `{"V":"marshalled(1)"}`}, + {name: "a slice of structs with a value with MarshalJSON", v: []embedderJ{{V: structWithMarshalJSON{v: 1}}}, expected: `[{"V":"marshalled(1)"}]`}, + } { + test := test + t.Run(test.name, func(t *testing.T) { + result, err := Marshal(test.v) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + if string(result) != test.expected { + t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected) + } + }) + } +} + +type structWithMarshalText struct{ v int } + +func (s *structWithMarshalText) MarshalText() ([]byte, error) { + return []byte(fmt.Sprintf("marshalled(%d)", s.v)), nil +} + +var _ = encoding.TextMarshaler(&structWithMarshalText{}) + +type embedderT struct { + V structWithMarshalText +} + +func TestMarshalJSONWithPointerTextMarshalers(t *testing.T) { + for _, test := range []struct { + name string + v interface{} + expected string + }{ + {name: "a value with MarshalText", v: structWithMarshalText{v: 1}, expected: `"marshalled(1)"`}, + {name: "pointer to a value with MarshalText", v: &structWithMarshalText{v: 1}, expected: `"marshalled(1)"`}, + {name: "a map with a value with MarshalText", v: map[string]interface{}{"v": structWithMarshalText{v: 1}}, expected: `{"v":"marshalled(1)"}`}, + {name: "a map with a pointer to a value with MarshalText", v: map[string]interface{}{"v": &structWithMarshalText{v: 1}}, expected: `{"v":"marshalled(1)"}`}, + {name: "a slice of maps with a value with MarshalText", v: []map[string]interface{}{{"v": structWithMarshalText{v: 1}}}, expected: `[{"v":"marshalled(1)"}]`}, + {name: "a slice of maps with a pointer to a value with MarshalText", v: []map[string]interface{}{{"v": &structWithMarshalText{v: 1}}}, expected: `[{"v":"marshalled(1)"}]`}, + {name: "a struct with a value with MarshalText", v: embedderT{V: structWithMarshalText{v: 1}}, expected: `{"V":"marshalled(1)"}`}, + {name: "a slice of structs with a value with MarshalText", v: []embedderT{{V: structWithMarshalText{v: 1}}}, expected: `[{"V":"marshalled(1)"}]`}, + } { + test := test + t.Run(test.name, func(t *testing.T) { + result, err := Marshal(test.v) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + if string(result) != test.expected { + t.Errorf("Marshal:\n\tgot: %s\n\twant: %s", result, test.expected) + } + }) + } +}