mirror of
https://github.com/golang/go
synced 2024-11-21 21:04:41 -07:00
encoding/json: support encoding.TextMarshaler, encoding.TextUnmarshaler
R=golang-dev, bradfitz CC=golang-dev https://golang.org/cl/12703043
This commit is contained in:
parent
5822e7848a
commit
7e886740d1
@ -8,6 +8,7 @@
|
||||
package json
|
||||
|
||||
import (
|
||||
"encoding"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -293,7 +294,7 @@ func (d *decodeState) value(v reflect.Value) {
|
||||
// until it gets to a non-pointer.
|
||||
// if it encounters an Unmarshaler, indirect stops and returns that.
|
||||
// if decodingNull is true, indirect stops at the last pointer so it can be set to nil.
|
||||
func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) {
|
||||
func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
|
||||
// If v is a named type and is addressable,
|
||||
// start with its address, so that if the type has pointer methods,
|
||||
// we find them.
|
||||
@ -322,28 +323,38 @@ func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler,
|
||||
v.Set(reflect.New(v.Type().Elem()))
|
||||
}
|
||||
if v.Type().NumMethod() > 0 {
|
||||
if unmarshaler, ok := v.Interface().(Unmarshaler); ok {
|
||||
return unmarshaler, reflect.Value{}
|
||||
if u, ok := v.Interface().(Unmarshaler); ok {
|
||||
return u, nil, reflect.Value{}
|
||||
}
|
||||
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
|
||||
return nil, u, reflect.Value{}
|
||||
}
|
||||
}
|
||||
v = v.Elem()
|
||||
}
|
||||
return nil, v
|
||||
return nil, nil, v
|
||||
}
|
||||
|
||||
// array consumes an array from d.data[d.off-1:], decoding into the value v.
|
||||
// the first byte of the array ('[') has been read already.
|
||||
func (d *decodeState) array(v reflect.Value) {
|
||||
// Check for unmarshaler.
|
||||
unmarshaler, pv := d.indirect(v, false)
|
||||
if unmarshaler != nil {
|
||||
u, ut, pv := d.indirect(v, false)
|
||||
if u != nil {
|
||||
d.off--
|
||||
err := unmarshaler.UnmarshalJSON(d.next())
|
||||
err := u.UnmarshalJSON(d.next())
|
||||
if err != nil {
|
||||
d.error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if ut != nil {
|
||||
d.saveError(&UnmarshalTypeError{"array", v.Type()})
|
||||
d.off--
|
||||
d.next()
|
||||
return
|
||||
}
|
||||
|
||||
v = pv
|
||||
|
||||
// Check type of target.
|
||||
@ -434,15 +445,21 @@ func (d *decodeState) array(v reflect.Value) {
|
||||
// the first byte of the object ('{') has been read already.
|
||||
func (d *decodeState) object(v reflect.Value) {
|
||||
// Check for unmarshaler.
|
||||
unmarshaler, pv := d.indirect(v, false)
|
||||
if unmarshaler != nil {
|
||||
u, ut, pv := d.indirect(v, false)
|
||||
if u != nil {
|
||||
d.off--
|
||||
err := unmarshaler.UnmarshalJSON(d.next())
|
||||
err := u.UnmarshalJSON(d.next())
|
||||
if err != nil {
|
||||
d.error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if ut != nil {
|
||||
d.saveError(&UnmarshalTypeError{"object", v.Type()})
|
||||
d.off--
|
||||
d.next() // skip over { } in input
|
||||
return
|
||||
}
|
||||
v = pv
|
||||
|
||||
// Decoding into nil interface? Switch to non-reflect code.
|
||||
@ -611,14 +628,37 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
|
||||
return
|
||||
}
|
||||
wantptr := item[0] == 'n' // null
|
||||
unmarshaler, pv := d.indirect(v, wantptr)
|
||||
if unmarshaler != nil {
|
||||
err := unmarshaler.UnmarshalJSON(item)
|
||||
u, ut, pv := d.indirect(v, wantptr)
|
||||
if u != nil {
|
||||
err := u.UnmarshalJSON(item)
|
||||
if err != nil {
|
||||
d.error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if ut != nil {
|
||||
if item[0] != '"' {
|
||||
if fromQuoted {
|
||||
d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
|
||||
} else {
|
||||
d.saveError(&UnmarshalTypeError{"string", v.Type()})
|
||||
}
|
||||
}
|
||||
s, ok := unquoteBytes(item)
|
||||
if !ok {
|
||||
if fromQuoted {
|
||||
d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type()))
|
||||
} else {
|
||||
d.error(errPhase)
|
||||
}
|
||||
}
|
||||
err := ut.UnmarshalText(s)
|
||||
if err != nil {
|
||||
d.error(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
v = pv
|
||||
|
||||
switch c := item[0]; c {
|
||||
|
@ -6,6 +6,7 @@ package json
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding"
|
||||
"fmt"
|
||||
"image"
|
||||
"reflect"
|
||||
@ -57,7 +58,7 @@ type unmarshaler struct {
|
||||
}
|
||||
|
||||
func (u *unmarshaler) UnmarshalJSON(b []byte) error {
|
||||
*u = unmarshaler{true} // All we need to see that UnmarshalJson is called.
|
||||
*u = unmarshaler{true} // All we need to see that UnmarshalJSON is called.
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -65,6 +66,26 @@ type ustruct struct {
|
||||
M unmarshaler
|
||||
}
|
||||
|
||||
type unmarshalerText struct {
|
||||
T bool
|
||||
}
|
||||
|
||||
// needed for re-marshaling tests
|
||||
func (u *unmarshalerText) MarshalText() ([]byte, error) {
|
||||
return []byte(""), nil
|
||||
}
|
||||
|
||||
func (u *unmarshalerText) UnmarshalText(b []byte) error {
|
||||
*u = unmarshalerText{true} // All we need to see that UnmarshalText is called.
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ encoding.TextUnmarshaler = (*unmarshalerText)(nil)
|
||||
|
||||
type ustructText struct {
|
||||
M unmarshalerText
|
||||
}
|
||||
|
||||
var (
|
||||
um0, um1 unmarshaler // target2 of unmarshaling
|
||||
ump = &um1
|
||||
@ -72,6 +93,13 @@ var (
|
||||
umslice = []unmarshaler{{true}}
|
||||
umslicep = new([]unmarshaler)
|
||||
umstruct = ustruct{unmarshaler{true}}
|
||||
|
||||
um0T, um1T unmarshalerText // target2 of unmarshaling
|
||||
umpT = &um1T
|
||||
umtrueT = unmarshalerText{true}
|
||||
umsliceT = []unmarshalerText{{true}}
|
||||
umslicepT = new([]unmarshalerText)
|
||||
umstructT = ustructText{unmarshalerText{true}}
|
||||
)
|
||||
|
||||
// Test data structures for anonymous fields.
|
||||
@ -261,6 +289,13 @@ var unmarshalTests = []unmarshalTest{
|
||||
{in: `[{"T":false}]`, ptr: &umslicep, out: &umslice},
|
||||
{in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct},
|
||||
|
||||
// UnmarshalText interface test
|
||||
{in: `"X"`, ptr: &um0T, out: umtrueT}, // use "false" so test will fail if custom unmarshaler is not called
|
||||
{in: `"X"`, ptr: &umpT, out: &umtrueT},
|
||||
{in: `["X"]`, ptr: &umsliceT, out: umsliceT},
|
||||
{in: `["X"]`, ptr: &umslicepT, out: &umsliceT},
|
||||
{in: `{"M":"X"}`, ptr: &umstructT, out: umstructT},
|
||||
|
||||
{
|
||||
in: `{
|
||||
"Level0": 1,
|
||||
@ -505,7 +540,7 @@ func TestUnmarshal(t *testing.T) {
|
||||
dec.UseNumber()
|
||||
}
|
||||
if err := dec.Decode(vv.Interface()); err != nil {
|
||||
t.Errorf("#%d: error re-unmarshaling: %v", i, err)
|
||||
t.Errorf("#%d: error re-unmarshaling %#q: %v", i, enc, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(v.Elem().Interface(), vv.Elem().Interface()) {
|
||||
@ -979,15 +1014,20 @@ func TestRefUnmarshal(t *testing.T) {
|
||||
// Ref is defined in encode_test.go.
|
||||
R0 Ref
|
||||
R1 *Ref
|
||||
R2 RefText
|
||||
R3 *RefText
|
||||
}
|
||||
want := S{
|
||||
R0: 12,
|
||||
R1: new(Ref),
|
||||
R2: 13,
|
||||
R3: new(RefText),
|
||||
}
|
||||
*want.R1 = 12
|
||||
*want.R3 = 13
|
||||
|
||||
var got S
|
||||
if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref"}`), &got); err != nil {
|
||||
if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref","R2":"ref","R3":"ref"}`), &got); err != nil {
|
||||
t.Fatalf("Unmarshal: %v", err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
|
@ -12,6 +12,7 @@ package json
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding"
|
||||
"encoding/base64"
|
||||
"math"
|
||||
"reflect"
|
||||
@ -361,17 +362,29 @@ func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc {
|
||||
if !vx.IsValid() {
|
||||
vx = reflect.New(t).Elem()
|
||||
}
|
||||
|
||||
_, ok := vx.Interface().(Marshaler)
|
||||
if ok {
|
||||
return valueIsMarshallerEncoder
|
||||
return marshalerEncoder
|
||||
}
|
||||
// T doesn't match the interface. Check against *T too.
|
||||
if vx.Kind() != reflect.Ptr && vx.CanAddr() {
|
||||
_, ok = vx.Addr().Interface().(Marshaler)
|
||||
if ok {
|
||||
return valueAddrIsMarshallerEncoder
|
||||
return addrMarshalerEncoder
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = vx.Interface().(encoding.TextMarshaler)
|
||||
if ok {
|
||||
return textMarshalerEncoder
|
||||
}
|
||||
if vx.Kind() != reflect.Ptr && vx.CanAddr() {
|
||||
_, ok = vx.Addr().Interface().(encoding.TextMarshaler)
|
||||
if ok {
|
||||
return addrTextMarshalerEncoder
|
||||
}
|
||||
}
|
||||
|
||||
switch vx.Kind() {
|
||||
case reflect.Bool:
|
||||
return boolEncoder
|
||||
@ -406,7 +419,7 @@ func invalidValueEncoder(e *encodeState, v reflect.Value, quoted bool) {
|
||||
e.WriteString("null")
|
||||
}
|
||||
|
||||
func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
|
||||
func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
e.WriteString("null")
|
||||
return
|
||||
@ -422,9 +435,9 @@ func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) {
|
||||
func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
|
||||
va := v.Addr()
|
||||
if va.Kind() == reflect.Ptr && va.IsNil() {
|
||||
if va.IsNil() {
|
||||
e.WriteString("null")
|
||||
return
|
||||
}
|
||||
@ -439,6 +452,37 @@ func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool)
|
||||
}
|
||||
}
|
||||
|
||||
func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
e.WriteString("null")
|
||||
return
|
||||
}
|
||||
m := v.Interface().(encoding.TextMarshaler)
|
||||
b, err := m.MarshalText()
|
||||
if err == nil {
|
||||
_, err = e.stringBytes(b)
|
||||
}
|
||||
if err != nil {
|
||||
e.error(&MarshalerError{v.Type(), err})
|
||||
}
|
||||
}
|
||||
|
||||
func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) {
|
||||
va := v.Addr()
|
||||
if va.IsNil() {
|
||||
e.WriteString("null")
|
||||
return
|
||||
}
|
||||
m := va.Interface().(encoding.TextMarshaler)
|
||||
b, err := m.MarshalText()
|
||||
if err == nil {
|
||||
_, err = e.stringBytes(b)
|
||||
}
|
||||
if err != nil {
|
||||
e.error(&MarshalerError{v.Type(), err})
|
||||
}
|
||||
}
|
||||
|
||||
func boolEncoder(e *encodeState, v reflect.Value, quoted bool) {
|
||||
if quoted {
|
||||
e.WriteByte('"')
|
||||
@ -728,6 +772,7 @@ func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] }
|
||||
func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) }
|
||||
func (sv stringValues) get(i int) string { return sv[i].String() }
|
||||
|
||||
// NOTE: keep in sync with stringBytes below.
|
||||
func (e *encodeState) string(s string) (int, error) {
|
||||
len0 := e.Len()
|
||||
e.WriteByte('"')
|
||||
@ -800,6 +845,79 @@ func (e *encodeState) string(s string) (int, error) {
|
||||
return e.Len() - len0, nil
|
||||
}
|
||||
|
||||
// NOTE: keep in sync with string above.
|
||||
func (e *encodeState) stringBytes(s []byte) (int, error) {
|
||||
len0 := e.Len()
|
||||
e.WriteByte('"')
|
||||
start := 0
|
||||
for i := 0; i < len(s); {
|
||||
if b := s[i]; b < utf8.RuneSelf {
|
||||
if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' {
|
||||
i++
|
||||
continue
|
||||
}
|
||||
if start < i {
|
||||
e.Write(s[start:i])
|
||||
}
|
||||
switch b {
|
||||
case '\\', '"':
|
||||
e.WriteByte('\\')
|
||||
e.WriteByte(b)
|
||||
case '\n':
|
||||
e.WriteByte('\\')
|
||||
e.WriteByte('n')
|
||||
case '\r':
|
||||
e.WriteByte('\\')
|
||||
e.WriteByte('r')
|
||||
default:
|
||||
// This encodes bytes < 0x20 except for \n and \r,
|
||||
// as well as < and >. The latter are escaped because they
|
||||
// can lead to security holes when user-controlled strings
|
||||
// are rendered into JSON and served to some browsers.
|
||||
e.WriteString(`\u00`)
|
||||
e.WriteByte(hex[b>>4])
|
||||
e.WriteByte(hex[b&0xF])
|
||||
}
|
||||
i++
|
||||
start = i
|
||||
continue
|
||||
}
|
||||
c, size := utf8.DecodeRune(s[i:])
|
||||
if c == utf8.RuneError && size == 1 {
|
||||
if start < i {
|
||||
e.Write(s[start:i])
|
||||
}
|
||||
e.WriteString(`\ufffd`)
|
||||
i += size
|
||||
start = i
|
||||
continue
|
||||
}
|
||||
// U+2028 is LINE SEPARATOR.
|
||||
// U+2029 is PARAGRAPH SEPARATOR.
|
||||
// They are both technically valid characters in JSON strings,
|
||||
// but don't work in JSONP, which has to be evaluated as JavaScript,
|
||||
// and can lead to security holes there. It is valid JSON to
|
||||
// escape them, so we do so unconditionally.
|
||||
// See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion.
|
||||
if c == '\u2028' || c == '\u2029' {
|
||||
if start < i {
|
||||
e.Write(s[start:i])
|
||||
}
|
||||
e.WriteString(`\u202`)
|
||||
e.WriteByte(hex[c&0xF])
|
||||
i += size
|
||||
start = i
|
||||
continue
|
||||
}
|
||||
i += size
|
||||
}
|
||||
if start < len(s) {
|
||||
e.Write(s[start:])
|
||||
}
|
||||
e.WriteByte('"')
|
||||
return e.Len() - len0, nil
|
||||
}
|
||||
|
||||
// A field represents a single field found in a struct.
|
||||
type field struct {
|
||||
name string
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type Optionals struct {
|
||||
@ -146,19 +147,46 @@ func (Val) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"val"`), nil
|
||||
}
|
||||
|
||||
// RefText has Marshaler and Unmarshaler methods with pointer receiver.
|
||||
type RefText int
|
||||
|
||||
func (*RefText) MarshalText() ([]byte, error) {
|
||||
return []byte(`"ref"`), nil
|
||||
}
|
||||
|
||||
func (r *RefText) UnmarshalText([]byte) error {
|
||||
*r = 13
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValText has Marshaler methods with value receiver.
|
||||
type ValText int
|
||||
|
||||
func (ValText) MarshalText() ([]byte, error) {
|
||||
return []byte(`"val"`), nil
|
||||
}
|
||||
|
||||
func TestRefValMarshal(t *testing.T) {
|
||||
var s = struct {
|
||||
R0 Ref
|
||||
R1 *Ref
|
||||
R2 RefText
|
||||
R3 *RefText
|
||||
V0 Val
|
||||
V1 *Val
|
||||
V2 ValText
|
||||
V3 *ValText
|
||||
}{
|
||||
R0: 12,
|
||||
R1: new(Ref),
|
||||
R2: 14,
|
||||
R3: new(RefText),
|
||||
V0: 13,
|
||||
V1: new(Val),
|
||||
V2: 15,
|
||||
V3: new(ValText),
|
||||
}
|
||||
const want = `{"R0":"ref","R1":"ref","V0":"val","V1":"val"}`
|
||||
const want = `{"R0":"ref","R1":"ref","R2":"\"ref\"","R3":"\"ref\"","V0":"val","V1":"val","V2":"\"val\"","V3":"\"val\""}`
|
||||
b, err := Marshal(&s)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal: %v", err)
|
||||
@ -175,15 +203,32 @@ func (C) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"<&>"`), nil
|
||||
}
|
||||
|
||||
// CText implements Marshaler and returns unescaped text.
|
||||
type CText int
|
||||
|
||||
func (CText) MarshalText() ([]byte, error) {
|
||||
return []byte(`"<&>"`), nil
|
||||
}
|
||||
|
||||
func TestMarshalerEscaping(t *testing.T) {
|
||||
var c C
|
||||
const want = `"\u003c\u0026\u003e"`
|
||||
want := `"\u003c\u0026\u003e"`
|
||||
b, err := Marshal(c)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal: %v", err)
|
||||
t.Fatalf("Marshal(c): %v", err)
|
||||
}
|
||||
if got := string(b); got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
t.Errorf("Marshal(c) = %#q, want %#q", got, want)
|
||||
}
|
||||
|
||||
var ct CText
|
||||
want = `"\"\u003c\u0026\u003e\""`
|
||||
b, err = Marshal(ct)
|
||||
if err != nil {
|
||||
t.Fatalf("Marshal(ct): %v", err)
|
||||
}
|
||||
if got := string(b); got != want {
|
||||
t.Errorf("Marshal(ct) = %#q, want %#q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
@ -310,3 +355,49 @@ func TestDuplicatedFieldDisappears(t *testing.T) {
|
||||
t.Fatalf("Marshal: got %s want %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringBytes(t *testing.T) {
|
||||
// Test that encodeState.stringBytes and encodeState.string use the same encoding.
|
||||
es := &encodeState{}
|
||||
var r []rune
|
||||
for i := '\u0000'; i <= unicode.MaxRune; i++ {
|
||||
r = append(r, i)
|
||||
}
|
||||
s := string(r) + "\xff\xff\xffhello" // some invalid UTF-8 too
|
||||
_, err := es.string(s)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
esBytes := &encodeState{}
|
||||
_, err = esBytes.stringBytes([]byte(s))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
enc := es.Buffer.String()
|
||||
encBytes := esBytes.Buffer.String()
|
||||
if enc != encBytes {
|
||||
i := 0
|
||||
for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] {
|
||||
i++
|
||||
}
|
||||
enc = enc[i:]
|
||||
encBytes = encBytes[i:]
|
||||
i = 0
|
||||
for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] {
|
||||
i++
|
||||
}
|
||||
enc = enc[:len(enc)-i]
|
||||
encBytes = encBytes[:len(encBytes)-i]
|
||||
|
||||
if len(enc) > 20 {
|
||||
enc = enc[:20] + "..."
|
||||
}
|
||||
if len(encBytes) > 20 {
|
||||
encBytes = encBytes[:20] + "..."
|
||||
}
|
||||
|
||||
t.Errorf("encodings differ at %#q vs %#q", enc, encBytes)
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user