From 7e886740d1c4b62bf7aea2a71048be432aeff945 Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Wed, 14 Aug 2013 14:56:07 -0400 Subject: [PATCH] encoding/json: support encoding.TextMarshaler, encoding.TextUnmarshaler R=golang-dev, bradfitz CC=golang-dev https://golang.org/cl/12703043 --- src/pkg/encoding/json/decode.go | 66 +++++++++++--- src/pkg/encoding/json/decode_test.go | 46 +++++++++- src/pkg/encoding/json/encode.go | 130 +++++++++++++++++++++++++-- src/pkg/encoding/json/encode_test.go | 99 +++++++++++++++++++- 4 files changed, 315 insertions(+), 26 deletions(-) diff --git a/src/pkg/encoding/json/decode.go b/src/pkg/encoding/json/decode.go index 62ac294b89f..b6c23cc77a1 100644 --- a/src/pkg/encoding/json/decode.go +++ b/src/pkg/encoding/json/decode.go @@ -8,6 +8,7 @@ package json import ( + "encoding" "encoding/base64" "errors" "fmt" @@ -293,7 +294,7 @@ func (d *decodeState) value(v reflect.Value) { // until it gets to a non-pointer. // if it encounters an Unmarshaler, indirect stops and returns that. // if decodingNull is true, indirect stops at the last pointer so it can be set to nil. -func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, reflect.Value) { +func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { // If v is a named type and is addressable, // start with its address, so that if the type has pointer methods, // we find them. @@ -322,28 +323,38 @@ func (d *decodeState) indirect(v reflect.Value, decodingNull bool) (Unmarshaler, v.Set(reflect.New(v.Type().Elem())) } if v.Type().NumMethod() > 0 { - if unmarshaler, ok := v.Interface().(Unmarshaler); ok { - return unmarshaler, reflect.Value{} + if u, ok := v.Interface().(Unmarshaler); ok { + return u, nil, reflect.Value{} + } + if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { + return nil, u, reflect.Value{} } } v = v.Elem() } - return nil, v + return nil, nil, v } // array consumes an array from d.data[d.off-1:], decoding into the value v. // the first byte of the array ('[') has been read already. func (d *decodeState) array(v reflect.Value) { // Check for unmarshaler. - unmarshaler, pv := d.indirect(v, false) - if unmarshaler != nil { + u, ut, pv := d.indirect(v, false) + if u != nil { d.off-- - err := unmarshaler.UnmarshalJSON(d.next()) + err := u.UnmarshalJSON(d.next()) if err != nil { d.error(err) } return } + if ut != nil { + d.saveError(&UnmarshalTypeError{"array", v.Type()}) + d.off-- + d.next() + return + } + v = pv // Check type of target. @@ -434,15 +445,21 @@ func (d *decodeState) array(v reflect.Value) { // the first byte of the object ('{') has been read already. func (d *decodeState) object(v reflect.Value) { // Check for unmarshaler. - unmarshaler, pv := d.indirect(v, false) - if unmarshaler != nil { + u, ut, pv := d.indirect(v, false) + if u != nil { d.off-- - err := unmarshaler.UnmarshalJSON(d.next()) + err := u.UnmarshalJSON(d.next()) if err != nil { d.error(err) } return } + if ut != nil { + d.saveError(&UnmarshalTypeError{"object", v.Type()}) + d.off-- + d.next() // skip over { } in input + return + } v = pv // Decoding into nil interface? Switch to non-reflect code. @@ -611,14 +628,37 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool return } wantptr := item[0] == 'n' // null - unmarshaler, pv := d.indirect(v, wantptr) - if unmarshaler != nil { - err := unmarshaler.UnmarshalJSON(item) + u, ut, pv := d.indirect(v, wantptr) + if u != nil { + err := u.UnmarshalJSON(item) if err != nil { d.error(err) } return } + if ut != nil { + if item[0] != '"' { + if fromQuoted { + d.saveError(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.saveError(&UnmarshalTypeError{"string", v.Type()}) + } + } + s, ok := unquoteBytes(item) + if !ok { + if fromQuoted { + d.error(fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", item, v.Type())) + } else { + d.error(errPhase) + } + } + err := ut.UnmarshalText(s) + if err != nil { + d.error(err) + } + return + } + v = pv switch c := item[0]; c { diff --git a/src/pkg/encoding/json/decode_test.go b/src/pkg/encoding/json/decode_test.go index 3fa366500ff..6635ba6ec68 100644 --- a/src/pkg/encoding/json/decode_test.go +++ b/src/pkg/encoding/json/decode_test.go @@ -6,6 +6,7 @@ package json import ( "bytes" + "encoding" "fmt" "image" "reflect" @@ -57,7 +58,7 @@ type unmarshaler struct { } func (u *unmarshaler) UnmarshalJSON(b []byte) error { - *u = unmarshaler{true} // All we need to see that UnmarshalJson is called. + *u = unmarshaler{true} // All we need to see that UnmarshalJSON is called. return nil } @@ -65,6 +66,26 @@ type ustruct struct { M unmarshaler } +type unmarshalerText struct { + T bool +} + +// needed for re-marshaling tests +func (u *unmarshalerText) MarshalText() ([]byte, error) { + return []byte(""), nil +} + +func (u *unmarshalerText) UnmarshalText(b []byte) error { + *u = unmarshalerText{true} // All we need to see that UnmarshalText is called. + return nil +} + +var _ encoding.TextUnmarshaler = (*unmarshalerText)(nil) + +type ustructText struct { + M unmarshalerText +} + var ( um0, um1 unmarshaler // target2 of unmarshaling ump = &um1 @@ -72,6 +93,13 @@ var ( umslice = []unmarshaler{{true}} umslicep = new([]unmarshaler) umstruct = ustruct{unmarshaler{true}} + + um0T, um1T unmarshalerText // target2 of unmarshaling + umpT = &um1T + umtrueT = unmarshalerText{true} + umsliceT = []unmarshalerText{{true}} + umslicepT = new([]unmarshalerText) + umstructT = ustructText{unmarshalerText{true}} ) // Test data structures for anonymous fields. @@ -261,6 +289,13 @@ var unmarshalTests = []unmarshalTest{ {in: `[{"T":false}]`, ptr: &umslicep, out: &umslice}, {in: `{"M":{"T":false}}`, ptr: &umstruct, out: umstruct}, + // UnmarshalText interface test + {in: `"X"`, ptr: &um0T, out: umtrueT}, // use "false" so test will fail if custom unmarshaler is not called + {in: `"X"`, ptr: &umpT, out: &umtrueT}, + {in: `["X"]`, ptr: &umsliceT, out: umsliceT}, + {in: `["X"]`, ptr: &umslicepT, out: &umsliceT}, + {in: `{"M":"X"}`, ptr: &umstructT, out: umstructT}, + { in: `{ "Level0": 1, @@ -505,7 +540,7 @@ func TestUnmarshal(t *testing.T) { dec.UseNumber() } if err := dec.Decode(vv.Interface()); err != nil { - t.Errorf("#%d: error re-unmarshaling: %v", i, err) + t.Errorf("#%d: error re-unmarshaling %#q: %v", i, enc, err) continue } if !reflect.DeepEqual(v.Elem().Interface(), vv.Elem().Interface()) { @@ -979,15 +1014,20 @@ func TestRefUnmarshal(t *testing.T) { // Ref is defined in encode_test.go. R0 Ref R1 *Ref + R2 RefText + R3 *RefText } want := S{ R0: 12, R1: new(Ref), + R2: 13, + R3: new(RefText), } *want.R1 = 12 + *want.R3 = 13 var got S - if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref"}`), &got); err != nil { + if err := Unmarshal([]byte(`{"R0":"ref","R1":"ref","R2":"ref","R3":"ref"}`), &got); err != nil { t.Fatalf("Unmarshal: %v", err) } if !reflect.DeepEqual(got, want) { diff --git a/src/pkg/encoding/json/encode.go b/src/pkg/encoding/json/encode.go index a1127072694..f951250e980 100644 --- a/src/pkg/encoding/json/encode.go +++ b/src/pkg/encoding/json/encode.go @@ -12,6 +12,7 @@ package json import ( "bytes" + "encoding" "encoding/base64" "math" "reflect" @@ -361,17 +362,29 @@ func newTypeEncoder(t reflect.Type, vx reflect.Value) encoderFunc { if !vx.IsValid() { vx = reflect.New(t).Elem() } + _, ok := vx.Interface().(Marshaler) if ok { - return valueIsMarshallerEncoder + return marshalerEncoder } - // T doesn't match the interface. Check against *T too. if vx.Kind() != reflect.Ptr && vx.CanAddr() { _, ok = vx.Addr().Interface().(Marshaler) if ok { - return valueAddrIsMarshallerEncoder + return addrMarshalerEncoder } } + + _, ok = vx.Interface().(encoding.TextMarshaler) + if ok { + return textMarshalerEncoder + } + if vx.Kind() != reflect.Ptr && vx.CanAddr() { + _, ok = vx.Addr().Interface().(encoding.TextMarshaler) + if ok { + return addrTextMarshalerEncoder + } + } + switch vx.Kind() { case reflect.Bool: return boolEncoder @@ -406,7 +419,7 @@ func invalidValueEncoder(e *encodeState, v reflect.Value, quoted bool) { e.WriteString("null") } -func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) { +func marshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { if v.Kind() == reflect.Ptr && v.IsNil() { e.WriteString("null") return @@ -422,9 +435,9 @@ func valueIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) { } } -func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) { +func addrMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { va := v.Addr() - if va.Kind() == reflect.Ptr && va.IsNil() { + if va.IsNil() { e.WriteString("null") return } @@ -439,6 +452,37 @@ func valueAddrIsMarshallerEncoder(e *encodeState, v reflect.Value, quoted bool) } } +func textMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { + if v.Kind() == reflect.Ptr && v.IsNil() { + e.WriteString("null") + return + } + m := v.Interface().(encoding.TextMarshaler) + b, err := m.MarshalText() + if err == nil { + _, err = e.stringBytes(b) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err}) + } +} + +func addrTextMarshalerEncoder(e *encodeState, v reflect.Value, quoted bool) { + va := v.Addr() + if va.IsNil() { + e.WriteString("null") + return + } + m := va.Interface().(encoding.TextMarshaler) + b, err := m.MarshalText() + if err == nil { + _, err = e.stringBytes(b) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err}) + } +} + func boolEncoder(e *encodeState, v reflect.Value, quoted bool) { if quoted { e.WriteByte('"') @@ -728,6 +772,7 @@ func (sv stringValues) Swap(i, j int) { sv[i], sv[j] = sv[j], sv[i] } func (sv stringValues) Less(i, j int) bool { return sv.get(i) < sv.get(j) } func (sv stringValues) get(i int) string { return sv[i].String() } +// NOTE: keep in sync with stringBytes below. func (e *encodeState) string(s string) (int, error) { len0 := e.Len() e.WriteByte('"') @@ -800,6 +845,79 @@ func (e *encodeState) string(s string) (int, error) { return e.Len() - len0, nil } +// NOTE: keep in sync with string above. +func (e *encodeState) stringBytes(s []byte) (int, error) { + len0 := e.Len() + e.WriteByte('"') + start := 0 + for i := 0; i < len(s); { + if b := s[i]; b < utf8.RuneSelf { + if 0x20 <= b && b != '\\' && b != '"' && b != '<' && b != '>' && b != '&' { + i++ + continue + } + if start < i { + e.Write(s[start:i]) + } + switch b { + case '\\', '"': + e.WriteByte('\\') + e.WriteByte(b) + case '\n': + e.WriteByte('\\') + e.WriteByte('n') + case '\r': + e.WriteByte('\\') + e.WriteByte('r') + default: + // This encodes bytes < 0x20 except for \n and \r, + // as well as < and >. The latter are escaped because they + // can lead to security holes when user-controlled strings + // are rendered into JSON and served to some browsers. + e.WriteString(`\u00`) + e.WriteByte(hex[b>>4]) + e.WriteByte(hex[b&0xF]) + } + i++ + start = i + continue + } + c, size := utf8.DecodeRune(s[i:]) + if c == utf8.RuneError && size == 1 { + if start < i { + e.Write(s[start:i]) + } + e.WriteString(`\ufffd`) + i += size + start = i + continue + } + // U+2028 is LINE SEPARATOR. + // U+2029 is PARAGRAPH SEPARATOR. + // They are both technically valid characters in JSON strings, + // but don't work in JSONP, which has to be evaluated as JavaScript, + // and can lead to security holes there. It is valid JSON to + // escape them, so we do so unconditionally. + // See http://timelessrepo.com/json-isnt-a-javascript-subset for discussion. + if c == '\u2028' || c == '\u2029' { + if start < i { + e.Write(s[start:i]) + } + e.WriteString(`\u202`) + e.WriteByte(hex[c&0xF]) + i += size + start = i + continue + } + i += size + } + if start < len(s) { + e.Write(s[start:]) + } + e.WriteByte('"') + return e.Len() - len0, nil +} + // A field represents a single field found in a struct. type field struct { name string diff --git a/src/pkg/encoding/json/encode_test.go b/src/pkg/encoding/json/encode_test.go index 5be0a992e1c..7052e1db7c9 100644 --- a/src/pkg/encoding/json/encode_test.go +++ b/src/pkg/encoding/json/encode_test.go @@ -9,6 +9,7 @@ import ( "math" "reflect" "testing" + "unicode" ) type Optionals struct { @@ -146,19 +147,46 @@ func (Val) MarshalJSON() ([]byte, error) { return []byte(`"val"`), nil } +// RefText has Marshaler and Unmarshaler methods with pointer receiver. +type RefText int + +func (*RefText) MarshalText() ([]byte, error) { + return []byte(`"ref"`), nil +} + +func (r *RefText) UnmarshalText([]byte) error { + *r = 13 + return nil +} + +// ValText has Marshaler methods with value receiver. +type ValText int + +func (ValText) MarshalText() ([]byte, error) { + return []byte(`"val"`), nil +} + func TestRefValMarshal(t *testing.T) { var s = struct { R0 Ref R1 *Ref + R2 RefText + R3 *RefText V0 Val V1 *Val + V2 ValText + V3 *ValText }{ R0: 12, R1: new(Ref), + R2: 14, + R3: new(RefText), V0: 13, V1: new(Val), + V2: 15, + V3: new(ValText), } - const want = `{"R0":"ref","R1":"ref","V0":"val","V1":"val"}` + const want = `{"R0":"ref","R1":"ref","R2":"\"ref\"","R3":"\"ref\"","V0":"val","V1":"val","V2":"\"val\"","V3":"\"val\""}` b, err := Marshal(&s) if err != nil { t.Fatalf("Marshal: %v", err) @@ -175,15 +203,32 @@ func (C) MarshalJSON() ([]byte, error) { return []byte(`"<&>"`), nil } +// CText implements Marshaler and returns unescaped text. +type CText int + +func (CText) MarshalText() ([]byte, error) { + return []byte(`"<&>"`), nil +} + func TestMarshalerEscaping(t *testing.T) { var c C - const want = `"\u003c\u0026\u003e"` + want := `"\u003c\u0026\u003e"` b, err := Marshal(c) if err != nil { - t.Fatalf("Marshal: %v", err) + t.Fatalf("Marshal(c): %v", err) } if got := string(b); got != want { - t.Errorf("got %q, want %q", got, want) + t.Errorf("Marshal(c) = %#q, want %#q", got, want) + } + + var ct CText + want = `"\"\u003c\u0026\u003e\""` + b, err = Marshal(ct) + if err != nil { + t.Fatalf("Marshal(ct): %v", err) + } + if got := string(b); got != want { + t.Errorf("Marshal(ct) = %#q, want %#q", got, want) } } @@ -310,3 +355,49 @@ func TestDuplicatedFieldDisappears(t *testing.T) { t.Fatalf("Marshal: got %s want %s", got, want) } } + +func TestStringBytes(t *testing.T) { + // Test that encodeState.stringBytes and encodeState.string use the same encoding. + es := &encodeState{} + var r []rune + for i := '\u0000'; i <= unicode.MaxRune; i++ { + r = append(r, i) + } + s := string(r) + "\xff\xff\xffhello" // some invalid UTF-8 too + _, err := es.string(s) + if err != nil { + t.Fatal(err) + } + + esBytes := &encodeState{} + _, err = esBytes.stringBytes([]byte(s)) + if err != nil { + t.Fatal(err) + } + + enc := es.Buffer.String() + encBytes := esBytes.Buffer.String() + if enc != encBytes { + i := 0 + for i < len(enc) && i < len(encBytes) && enc[i] == encBytes[i] { + i++ + } + enc = enc[i:] + encBytes = encBytes[i:] + i = 0 + for i < len(enc) && i < len(encBytes) && enc[len(enc)-i-1] == encBytes[len(encBytes)-i-1] { + i++ + } + enc = enc[:len(enc)-i] + encBytes = encBytes[:len(encBytes)-i] + + if len(enc) > 20 { + enc = enc[:20] + "..." + } + if len(encBytes) > 20 { + encBytes = encBytes[:20] + "..." + } + + t.Errorf("encodings differ at %#q vs %#q", enc, encBytes) + } +}