mirror of
https://github.com/golang/go
synced 2024-11-22 07:54:40 -07:00
Rework gobs to fix bad bug related to sharing of id's between encoder and decoder side.
Fix is to move all decoder state into the decoder object. Fixes #215. R=rsc CC=golang-dev https://golang.org/cl/155077
This commit is contained in:
parent
50c04132ac
commit
30b1b9a36a
@ -37,7 +37,6 @@ var encodeT = []EncodeT{
|
||||
EncodeT{1 << 63, []byte{0xF8, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
|
||||
}
|
||||
|
||||
|
||||
// Test basic encode/decode routines for unsigned integers
|
||||
func TestUintCodec(t *testing.T) {
|
||||
b := new(bytes.Buffer);
|
||||
@ -592,9 +591,15 @@ func TestEndToEnd(t *testing.T) {
|
||||
t: &T2{"this is T2"},
|
||||
};
|
||||
b := new(bytes.Buffer);
|
||||
encode(b, t1);
|
||||
err := NewEncoder(b).Encode(t1);
|
||||
if err != nil {
|
||||
t.Error("encode:", err)
|
||||
}
|
||||
var _t1 T1;
|
||||
decode(b, getTypeInfoNoError(reflect.Typeof(_t1)).id, &_t1);
|
||||
err = NewDecoder(b).Decode(&_t1);
|
||||
if err != nil {
|
||||
t.Fatal("decode:", err)
|
||||
}
|
||||
if !reflect.DeepEqual(t1, &_t1) {
|
||||
t.Errorf("encode expected %v got %v", *t1, _t1)
|
||||
}
|
||||
@ -610,8 +615,9 @@ func TestOverflow(t *testing.T) {
|
||||
}
|
||||
var it inputT;
|
||||
var err os.Error;
|
||||
id := getTypeInfoNoError(reflect.Typeof(it)).id;
|
||||
b := new(bytes.Buffer);
|
||||
enc := NewEncoder(b);
|
||||
dec := NewDecoder(b);
|
||||
|
||||
// int8
|
||||
b.Reset();
|
||||
@ -623,8 +629,8 @@ func TestOverflow(t *testing.T) {
|
||||
mini int8;
|
||||
}
|
||||
var o1 outi8;
|
||||
encode(b, it);
|
||||
err = decode(b, id, &o1);
|
||||
enc.Encode(it);
|
||||
err = dec.Decode(&o1);
|
||||
if err == nil || err.String() != `value for "maxi" out of range` {
|
||||
t.Error("wrong overflow error for int8:", err)
|
||||
}
|
||||
@ -632,8 +638,8 @@ func TestOverflow(t *testing.T) {
|
||||
mini: math.MinInt8 - 1,
|
||||
};
|
||||
b.Reset();
|
||||
encode(b, it);
|
||||
err = decode(b, id, &o1);
|
||||
enc.Encode(it);
|
||||
err = dec.Decode(&o1);
|
||||
if err == nil || err.String() != `value for "mini" out of range` {
|
||||
t.Error("wrong underflow error for int8:", err)
|
||||
}
|
||||
@ -648,8 +654,8 @@ func TestOverflow(t *testing.T) {
|
||||
mini int16;
|
||||
}
|
||||
var o2 outi16;
|
||||
encode(b, it);
|
||||
err = decode(b, id, &o2);
|
||||
enc.Encode(it);
|
||||
err = dec.Decode(&o2);
|
||||
if err == nil || err.String() != `value for "maxi" out of range` {
|
||||
t.Error("wrong overflow error for int16:", err)
|
||||
}
|
||||
@ -657,8 +663,8 @@ func TestOverflow(t *testing.T) {
|
||||
mini: math.MinInt16 - 1,
|
||||
};
|
||||
b.Reset();
|
||||
encode(b, it);
|
||||
err = decode(b, id, &o2);
|
||||
enc.Encode(it);
|
||||
err = dec.Decode(&o2);
|
||||
if err == nil || err.String() != `value for "mini" out of range` {
|
||||
t.Error("wrong underflow error for int16:", err)
|
||||
}
|
||||
@ -673,8 +679,8 @@ func TestOverflow(t *testing.T) {
|
||||
mini int32;
|
||||
}
|
||||
var o3 outi32;
|
||||
encode(b, it);
|
||||
err = decode(b, id, &o3);
|
||||
enc.Encode(it);
|
||||
err = dec.Decode(&o3);
|
||||
if err == nil || err.String() != `value for "maxi" out of range` {
|
||||
t.Error("wrong overflow error for int32:", err)
|
||||
}
|
||||
@ -682,8 +688,8 @@ func TestOverflow(t *testing.T) {
|
||||
mini: math.MinInt32 - 1,
|
||||
};
|
||||
b.Reset();
|
||||
encode(b, it);
|
||||
err = decode(b, id, &o3);
|
||||
enc.Encode(it);
|
||||
err = dec.Decode(&o3);
|
||||
if err == nil || err.String() != `value for "mini" out of range` {
|
||||
t.Error("wrong underflow error for int32:", err)
|
||||
}
|
||||
@ -697,8 +703,8 @@ func TestOverflow(t *testing.T) {
|
||||
maxu uint8;
|
||||
}
|
||||
var o4 outu8;
|
||||
encode(b, it);
|
||||
err = decode(b, id, &o4);
|
||||
enc.Encode(it);
|
||||
err = dec.Decode(&o4);
|
||||
if err == nil || err.String() != `value for "maxu" out of range` {
|
||||
t.Error("wrong overflow error for uint8:", err)
|
||||
}
|
||||
@ -712,8 +718,8 @@ func TestOverflow(t *testing.T) {
|
||||
maxu uint16;
|
||||
}
|
||||
var o5 outu16;
|
||||
encode(b, it);
|
||||
err = decode(b, id, &o5);
|
||||
enc.Encode(it);
|
||||
err = dec.Decode(&o5);
|
||||
if err == nil || err.String() != `value for "maxu" out of range` {
|
||||
t.Error("wrong overflow error for uint16:", err)
|
||||
}
|
||||
@ -727,8 +733,8 @@ func TestOverflow(t *testing.T) {
|
||||
maxu uint32;
|
||||
}
|
||||
var o6 outu32;
|
||||
encode(b, it);
|
||||
err = decode(b, id, &o6);
|
||||
enc.Encode(it);
|
||||
err = dec.Decode(&o6);
|
||||
if err == nil || err.String() != `value for "maxu" out of range` {
|
||||
t.Error("wrong overflow error for uint32:", err)
|
||||
}
|
||||
@ -743,8 +749,8 @@ func TestOverflow(t *testing.T) {
|
||||
minf float32;
|
||||
}
|
||||
var o7 outf32;
|
||||
encode(b, it);
|
||||
err = decode(b, id, &o7);
|
||||
enc.Encode(it);
|
||||
err = dec.Decode(&o7);
|
||||
if err == nil || err.String() != `value for "maxf" out of range` {
|
||||
t.Error("wrong overflow error for float32:", err)
|
||||
}
|
||||
@ -761,9 +767,13 @@ func TestNesting(t *testing.T) {
|
||||
rt.next = new(RT);
|
||||
rt.next.a = "level2";
|
||||
b := new(bytes.Buffer);
|
||||
encode(b, rt);
|
||||
NewEncoder(b).Encode(rt);
|
||||
var drt RT;
|
||||
decode(b, getTypeInfoNoError(reflect.Typeof(drt)).id, &drt);
|
||||
dec := NewDecoder(b);
|
||||
err := dec.Decode(&drt);
|
||||
if err != nil {
|
||||
t.Errorf("decoder error:", err)
|
||||
}
|
||||
if drt.a != rt.a {
|
||||
t.Errorf("nesting: encode expected %v got %v", *rt, drt)
|
||||
}
|
||||
@ -809,10 +819,11 @@ func TestAutoIndirection(t *testing.T) {
|
||||
**t1.d = new(int);
|
||||
***t1.d = 17777;
|
||||
b := new(bytes.Buffer);
|
||||
encode(b, t1);
|
||||
enc := NewEncoder(b);
|
||||
enc.Encode(t1);
|
||||
dec := NewDecoder(b);
|
||||
var t0 T0;
|
||||
t0Id := getTypeInfoNoError(reflect.Typeof(t0)).id;
|
||||
decode(b, t0Id, &t0);
|
||||
dec.Decode(&t0);
|
||||
if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
|
||||
t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0)
|
||||
}
|
||||
@ -830,9 +841,9 @@ func TestAutoIndirection(t *testing.T) {
|
||||
**t2.a = new(int);
|
||||
***t2.a = 17;
|
||||
b.Reset();
|
||||
encode(b, t2);
|
||||
enc.Encode(t2);
|
||||
t0 = T0{};
|
||||
decode(b, t0Id, &t0);
|
||||
dec.Decode(&t0);
|
||||
if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 {
|
||||
t.Errorf("t2->t0 expected {17 177 1777 17777}; got %v", t0)
|
||||
}
|
||||
@ -840,32 +851,30 @@ func TestAutoIndirection(t *testing.T) {
|
||||
// Now transfer t0 into t1
|
||||
t0 = T0{17, 177, 1777, 17777};
|
||||
b.Reset();
|
||||
encode(b, t0);
|
||||
enc.Encode(t0);
|
||||
t1 = T1{};
|
||||
t1Id := getTypeInfoNoError(reflect.Typeof(t1)).id;
|
||||
decode(b, t1Id, &t1);
|
||||
dec.Decode(&t1);
|
||||
if t1.a != 17 || *t1.b != 177 || **t1.c != 1777 || ***t1.d != 17777 {
|
||||
t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.a, *t1.b, **t1.c, ***t1.d)
|
||||
}
|
||||
|
||||
// Now transfer t0 into t2
|
||||
b.Reset();
|
||||
encode(b, t0);
|
||||
enc.Encode(t0);
|
||||
t2 = T2{};
|
||||
t2Id := getTypeInfoNoError(reflect.Typeof(t2)).id;
|
||||
decode(b, t2Id, &t2);
|
||||
dec.Decode(&t2);
|
||||
if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
|
||||
t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d)
|
||||
}
|
||||
|
||||
// Now do t2 again but without pre-allocated pointers.
|
||||
b.Reset();
|
||||
encode(b, t0);
|
||||
enc.Encode(t0);
|
||||
***t2.a = 0;
|
||||
**t2.b = 0;
|
||||
*t2.c = 0;
|
||||
t2.d = 0;
|
||||
decode(b, t2Id, &t2);
|
||||
dec.Decode(&t2);
|
||||
if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 {
|
||||
t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d)
|
||||
}
|
||||
@ -889,11 +898,14 @@ func TestReorderedFields(t *testing.T) {
|
||||
rt0.b = "hello";
|
||||
rt0.c = 3.14159;
|
||||
b := new(bytes.Buffer);
|
||||
encode(b, rt0);
|
||||
rt0Id := getTypeInfoNoError(reflect.Typeof(rt0)).id;
|
||||
NewEncoder(b).Encode(rt0);
|
||||
dec := NewDecoder(b);
|
||||
var rt1 RT1;
|
||||
// Wire type is RT0, local type is RT1.
|
||||
decode(b, rt0Id, &rt1);
|
||||
err := dec.Decode(&rt1);
|
||||
if err != nil {
|
||||
t.Error("decode error:", err)
|
||||
}
|
||||
if rt0.a != rt1.a || rt0.b != rt1.b || rt0.c != rt1.c {
|
||||
t.Errorf("rt1->rt0: expected %v; got %v", rt0, rt1)
|
||||
}
|
||||
@ -927,11 +939,11 @@ func TestIgnoredFields(t *testing.T) {
|
||||
it0.ignore_i = &RT1{3.1, "hi", 7, "hello"};
|
||||
|
||||
b := new(bytes.Buffer);
|
||||
encode(b, it0);
|
||||
rt0Id := getTypeInfoNoError(reflect.Typeof(it0)).id;
|
||||
NewEncoder(b).Encode(it0);
|
||||
dec := NewDecoder(b);
|
||||
var rt1 RT1;
|
||||
// Wire type is IT0, local type is RT1.
|
||||
err := decode(b, rt0Id, &rt1);
|
||||
err := dec.Decode(&rt1);
|
||||
if err != nil {
|
||||
t.Error("error: ", err)
|
||||
}
|
||||
|
@ -348,7 +348,7 @@ func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) {
|
||||
// Execution engine
|
||||
|
||||
// The encoder engine is an array of instructions indexed by field number of the incoming
|
||||
// data. It is executed with random access according to field number.
|
||||
// decoder. It is executed with random access according to field number.
|
||||
type decEngine struct {
|
||||
instr []decInstr;
|
||||
numInstr int; // the number of active instructions
|
||||
@ -515,7 +515,7 @@ var decIgnoreOpMap = map[typeId]decOp{
|
||||
|
||||
// Return the decoding op for the base type under rt and
|
||||
// the indirection count to reach it.
|
||||
func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) {
|
||||
func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error) {
|
||||
typ, indir := indirect(rt);
|
||||
op, ok := decOpMap[reflect.Typeof(typ)];
|
||||
if !ok {
|
||||
@ -528,7 +528,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
|
||||
break;
|
||||
}
|
||||
elemId := wireId.gobType().(*sliceType).Elem;
|
||||
elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name);
|
||||
elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name);
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@ -540,7 +540,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
|
||||
case *reflect.ArrayType:
|
||||
name = "element of " + name;
|
||||
elemId := wireId.gobType().(*arrayType).Elem;
|
||||
elemOp, elemIndir, err := decOpFor(elemId, t.Elem(), name);
|
||||
elemOp, elemIndir, err := dec.decOpFor(elemId, t.Elem(), name);
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@ -551,7 +551,7 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
|
||||
|
||||
case *reflect.StructType:
|
||||
// Generate a closure that calls out to the engine for the nested type.
|
||||
enginePtr, err := getDecEnginePtr(wireId, typ);
|
||||
enginePtr, err := dec.getDecEnginePtr(wireId, typ);
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
@ -568,14 +568,14 @@ func decOpFor(wireId typeId, rt reflect.Type, name string) (decOp, int, os.Error
|
||||
}
|
||||
|
||||
// Return the decoding op for a field that has no destination.
|
||||
func decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
|
||||
func (dec *Decoder) decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
|
||||
op, ok := decIgnoreOpMap[wireId];
|
||||
if !ok {
|
||||
// Special cases
|
||||
switch t := wireId.gobType().(type) {
|
||||
case *sliceType:
|
||||
elemId := wireId.gobType().(*sliceType).Elem;
|
||||
elemOp, err := decIgnoreOpFor(elemId);
|
||||
elemOp, err := dec.decIgnoreOpFor(elemId);
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -585,7 +585,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
|
||||
|
||||
case *arrayType:
|
||||
elemId := wireId.gobType().(*arrayType).Elem;
|
||||
elemOp, err := decIgnoreOpFor(elemId);
|
||||
elemOp, err := dec.decIgnoreOpFor(elemId);
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -595,7 +595,7 @@ func decIgnoreOpFor(wireId typeId) (decOp, os.Error) {
|
||||
|
||||
case *structType:
|
||||
// Generate a closure that calls out to the engine for the nested type.
|
||||
enginePtr, err := getIgnoreEnginePtr(wireId);
|
||||
enginePtr, err := dec.getIgnoreEnginePtr(wireId);
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -676,11 +676,18 @@ func compatibleType(fr reflect.Type, fw typeId) bool {
|
||||
return true;
|
||||
}
|
||||
|
||||
func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
|
||||
func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
|
||||
srt, ok1 := rt.(*reflect.StructType);
|
||||
wireStruct, ok2 := wireId.gobType().(*structType);
|
||||
if !ok1 || !ok2 {
|
||||
return nil, errNotStruct
|
||||
var wireStruct *structType;
|
||||
// Builtin types can come from global pool; the rest must be defined by the decoder
|
||||
if t, ok := builtinIdToType[remoteId]; ok {
|
||||
wireStruct = t.(*structType)
|
||||
} else {
|
||||
w, ok2 := dec.wireType[remoteId];
|
||||
if !ok1 || !ok2 {
|
||||
return nil, errNotStruct
|
||||
}
|
||||
wireStruct = w.s;
|
||||
}
|
||||
engine = new(decEngine);
|
||||
engine.instr = make([]decInstr, len(wireStruct.field));
|
||||
@ -692,7 +699,7 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
|
||||
ovfl := overflow(wireField.name);
|
||||
// TODO(r): anonymous names
|
||||
if !present {
|
||||
op, err := decIgnoreOpFor(wireField.id);
|
||||
op, err := dec.decIgnoreOpFor(wireField.id);
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -700,10 +707,10 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
|
||||
continue;
|
||||
}
|
||||
if !compatibleType(localField.Type, wireField.id) {
|
||||
details := " (" + wireField.id.String() + " incompatible with " + localField.Type.String() + ") in type " + wireId.Name();
|
||||
details := " (" + wireField.id.String() + " incompatible with " + localField.Type.String() + ") in type " + remoteId.Name();
|
||||
return nil, os.ErrorString("gob: wrong type for field " + wireField.name + details);
|
||||
}
|
||||
op, indir, err := decOpFor(wireField.id, localField.Type, localField.Name);
|
||||
op, indir, err := dec.decOpFor(wireField.id, localField.Type, localField.Name);
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -713,23 +720,19 @@ func compileDec(wireId typeId, rt reflect.Type) (engine *decEngine, err os.Error
|
||||
return;
|
||||
}
|
||||
|
||||
var decoderCache = make(map[reflect.Type]map[typeId]**decEngine)
|
||||
var ignorerCache = make(map[typeId]**decEngine)
|
||||
|
||||
// typeLock must be held.
|
||||
func getDecEnginePtr(wireId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) {
|
||||
decoderMap, ok := decoderCache[rt];
|
||||
func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) {
|
||||
decoderMap, ok := dec.decoderCache[rt];
|
||||
if !ok {
|
||||
decoderMap = make(map[typeId]**decEngine);
|
||||
decoderCache[rt] = decoderMap;
|
||||
dec.decoderCache[rt] = decoderMap;
|
||||
}
|
||||
if enginePtr, ok = decoderMap[wireId]; !ok {
|
||||
if enginePtr, ok = decoderMap[remoteId]; !ok {
|
||||
// To handle recursive types, mark this engine as underway before compiling.
|
||||
enginePtr = new(*decEngine);
|
||||
decoderMap[wireId] = enginePtr;
|
||||
*enginePtr, err = compileDec(wireId, rt);
|
||||
decoderMap[remoteId] = enginePtr;
|
||||
*enginePtr, err = dec.compileDec(remoteId, rt);
|
||||
if err != nil {
|
||||
decoderMap[wireId] = nil, false
|
||||
decoderMap[remoteId] = nil, false
|
||||
}
|
||||
}
|
||||
return;
|
||||
@ -740,35 +743,28 @@ type emptyStruct struct{}
|
||||
|
||||
var emptyStructType = reflect.Typeof(emptyStruct{})
|
||||
|
||||
// typeLock must be held.
|
||||
func getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
|
||||
func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) {
|
||||
var ok bool;
|
||||
if enginePtr, ok = ignorerCache[wireId]; !ok {
|
||||
if enginePtr, ok = dec.ignorerCache[wireId]; !ok {
|
||||
// To handle recursive types, mark this engine as underway before compiling.
|
||||
enginePtr = new(*decEngine);
|
||||
ignorerCache[wireId] = enginePtr;
|
||||
*enginePtr, err = compileDec(wireId, emptyStructType);
|
||||
dec.ignorerCache[wireId] = enginePtr;
|
||||
*enginePtr, err = dec.compileDec(wireId, emptyStructType);
|
||||
if err != nil {
|
||||
ignorerCache[wireId] = nil, false
|
||||
dec.ignorerCache[wireId] = nil, false
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error {
|
||||
func (dec *Decoder) decode(wireId typeId, e interface{}) os.Error {
|
||||
// Dereference down to the underlying struct type.
|
||||
rt, indir := indirect(reflect.Typeof(e));
|
||||
st, ok := rt.(*reflect.StructType);
|
||||
if !ok {
|
||||
return os.ErrorString("gob: decode can't handle " + rt.String())
|
||||
}
|
||||
typeLock.Lock();
|
||||
if _, ok := idToType[wireId]; !ok {
|
||||
typeLock.Unlock();
|
||||
return errBadType;
|
||||
}
|
||||
enginePtr, err := getDecEnginePtr(wireId, rt);
|
||||
typeLock.Unlock();
|
||||
enginePtr, err := dec.getDecEnginePtr(wireId, rt);
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -777,7 +773,7 @@ func decode(b *bytes.Buffer, wireId typeId, e interface{}) os.Error {
|
||||
name := rt.Name();
|
||||
return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name);
|
||||
}
|
||||
return decodeStruct(engine, st, b, uintptr(reflect.NewValue(e).Addr()), indir);
|
||||
return decodeStruct(engine, st, dec.state.b, uintptr(reflect.NewValue(e).Addr()), indir);
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
@ -8,17 +8,20 @@ import (
|
||||
"bytes";
|
||||
"io";
|
||||
"os";
|
||||
"reflect";
|
||||
"sync";
|
||||
)
|
||||
|
||||
// A Decoder manages the receipt of type and data information read from the
|
||||
// remote side of a connection.
|
||||
type Decoder struct {
|
||||
mutex sync.Mutex; // each item must be received atomically
|
||||
r io.Reader; // source of the data
|
||||
seen map[typeId]*wireType; // which types we've already seen described
|
||||
state *decodeState; // reads data from in-memory buffer
|
||||
countState *decodeState; // reads counts from wire
|
||||
mutex sync.Mutex; // each item must be received atomically
|
||||
r io.Reader; // source of the data
|
||||
wireType map[typeId]*wireType; // map from remote ID to local description
|
||||
decoderCache map[reflect.Type]map[typeId]**decEngine; // cache of compiled engines
|
||||
ignorerCache map[typeId]**decEngine; // ditto for ignored objects
|
||||
state *decodeState; // reads data from in-memory buffer
|
||||
countState *decodeState; // reads counts from wire
|
||||
buf []byte;
|
||||
oneByte []byte;
|
||||
}
|
||||
@ -27,8 +30,10 @@ type Decoder struct {
|
||||
func NewDecoder(r io.Reader) *Decoder {
|
||||
dec := new(Decoder);
|
||||
dec.r = r;
|
||||
dec.seen = make(map[typeId]*wireType);
|
||||
dec.wireType = make(map[typeId]*wireType);
|
||||
dec.state = newDecodeState(nil); // buffer set in Decode(); rest is unimportant
|
||||
dec.decoderCache = make(map[reflect.Type]map[typeId]**decEngine);
|
||||
dec.ignorerCache = make(map[typeId]**decEngine);
|
||||
dec.oneByte = make([]byte, 1);
|
||||
|
||||
return dec;
|
||||
@ -36,16 +41,16 @@ func NewDecoder(r io.Reader) *Decoder {
|
||||
|
||||
func (dec *Decoder) recvType(id typeId) {
|
||||
// Have we already seen this type? That's an error
|
||||
if _, alreadySeen := dec.seen[id]; alreadySeen {
|
||||
if _, alreadySeen := dec.wireType[id]; alreadySeen {
|
||||
dec.state.err = os.ErrorString("gob: duplicate type received");
|
||||
return;
|
||||
}
|
||||
|
||||
// Type:
|
||||
wire := new(wireType);
|
||||
decode(dec.state.b, tWireType, wire);
|
||||
dec.state.err = dec.decode(tWireType, wire);
|
||||
// Remember we've seen this type.
|
||||
dec.seen[id] = wire;
|
||||
dec.wireType[id] = wire;
|
||||
}
|
||||
|
||||
// Decode reads the next value from the connection and stores
|
||||
@ -97,7 +102,13 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
|
||||
}
|
||||
|
||||
// No, it's a value.
|
||||
dec.state.err = decode(dec.state.b, id, e);
|
||||
// Make sure the type has been defined already.
|
||||
_, ok := dec.wireType[id];
|
||||
if !ok {
|
||||
dec.state.err = errBadType;
|
||||
break;
|
||||
}
|
||||
dec.state.err = dec.decode(id, e);
|
||||
break;
|
||||
}
|
||||
return dec.state.err;
|
||||
|
@ -235,19 +235,15 @@ func (enc *Encoder) send() {
|
||||
enc.w.Write(enc.buf[0:total]);
|
||||
}
|
||||
|
||||
func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
|
||||
func (enc *Encoder) sendType(origt reflect.Type) {
|
||||
// Drill down to the base type.
|
||||
rt, _ := indirect(origt);
|
||||
|
||||
// We only send structs - everything else is basic or an error
|
||||
switch rt.(type) {
|
||||
switch rt := rt.(type) {
|
||||
default:
|
||||
// Basic types do not need to be described, but if this is a top-level
|
||||
// type, it's a user error, at least for now.
|
||||
if topLevel {
|
||||
enc.badType(rt)
|
||||
}
|
||||
return;
|
||||
// Basic types do not need to be described.
|
||||
return
|
||||
case *reflect.StructType:
|
||||
// Structs do need to be described.
|
||||
break
|
||||
@ -255,10 +251,9 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
|
||||
// Probably a bad field in a struct.
|
||||
enc.badType(rt);
|
||||
return;
|
||||
case *reflect.ArrayType, *reflect.SliceType:
|
||||
// Array and slice types are not sent, only their element types.
|
||||
// If we see one here it's user error; probably a bad top-level value.
|
||||
enc.badType(rt);
|
||||
// Array and slice types are not sent, only their element types.
|
||||
case reflect.ArrayOrSliceType:
|
||||
enc.sendType(rt.Elem());
|
||||
return;
|
||||
}
|
||||
|
||||
@ -289,7 +284,7 @@ func (enc *Encoder) sendType(origt reflect.Type, topLevel bool) {
|
||||
// Now send the inner types
|
||||
st := rt.(*reflect.StructType);
|
||||
for i := 0; i < st.NumField(); i++ {
|
||||
enc.sendType(st.Field(i).Type, false)
|
||||
enc.sendType(st.Field(i).Type)
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -301,6 +296,12 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
|
||||
panicln("Encoder: buffer not empty")
|
||||
}
|
||||
rt, _ := indirect(reflect.Typeof(e));
|
||||
// Must be a struct
|
||||
if _, ok := rt.(*reflect.StructType); !ok {
|
||||
enc.badType(rt);
|
||||
return enc.state.err;
|
||||
}
|
||||
|
||||
|
||||
// Make sure we're single-threaded through here.
|
||||
enc.mutex.Lock();
|
||||
@ -310,7 +311,7 @@ func (enc *Encoder) Encode(e interface{}) os.Error {
|
||||
// First, have we already sent this type?
|
||||
if _, alreadySent := enc.sent[rt]; !alreadySent {
|
||||
// No, so send it.
|
||||
enc.sendType(rt, true);
|
||||
enc.sendType(rt);
|
||||
if enc.state.err != nil {
|
||||
enc.state.b.Reset();
|
||||
enc.countState.b.Reset();
|
||||
|
@ -32,122 +32,10 @@ type ET3 struct {
|
||||
// Like ET1 but with a different type for a field
|
||||
type ET4 struct {
|
||||
a int;
|
||||
et2 *ET1;
|
||||
et2 float;
|
||||
next int;
|
||||
}
|
||||
|
||||
func TestBasicEncoder(t *testing.T) {
|
||||
b := new(bytes.Buffer);
|
||||
enc := NewEncoder(b);
|
||||
et1 := new(ET1);
|
||||
et1.a = 7;
|
||||
et1.et2 = new(ET2);
|
||||
enc.Encode(et1);
|
||||
if enc.state.err != nil {
|
||||
t.Error("encoder fail:", enc.state.err)
|
||||
}
|
||||
|
||||
// Decode the result by hand to verify;
|
||||
state := newDecodeState(b);
|
||||
// The output should be:
|
||||
// 0) The length, 38.
|
||||
length := decodeUint(state);
|
||||
if length != 38 {
|
||||
t.Fatal("0. expected length 38; got", length)
|
||||
}
|
||||
// 1) -7: the type id of ET1
|
||||
id1 := decodeInt(state);
|
||||
if id1 >= 0 {
|
||||
t.Fatal("expected ET1 negative id; got", id1)
|
||||
}
|
||||
// 2) The wireType for ET1
|
||||
wire1 := new(wireType);
|
||||
err := decode(b, tWireType, wire1);
|
||||
if err != nil {
|
||||
t.Fatal("error decoding ET1 type:", err)
|
||||
}
|
||||
info := getTypeInfoNoError(reflect.Typeof(ET1{}));
|
||||
trueWire1 := &wireType{s: info.id.gobType().(*structType)};
|
||||
if !reflect.DeepEqual(wire1, trueWire1) {
|
||||
t.Fatalf("invalid wireType for ET1: expected %+v; got %+v\n", *trueWire1, *wire1)
|
||||
}
|
||||
// 3) The length, 21.
|
||||
length = decodeUint(state);
|
||||
if length != 21 {
|
||||
t.Fatal("3. expected length 21; got", length)
|
||||
}
|
||||
// 4) -8: the type id of ET2
|
||||
id2 := decodeInt(state);
|
||||
if id2 >= 0 {
|
||||
t.Fatal("expected ET2 negative id; got", id2)
|
||||
}
|
||||
// 5) The wireType for ET2
|
||||
wire2 := new(wireType);
|
||||
err = decode(b, tWireType, wire2);
|
||||
if err != nil {
|
||||
t.Fatal("error decoding ET2 type:", err)
|
||||
}
|
||||
info = getTypeInfoNoError(reflect.Typeof(ET2{}));
|
||||
trueWire2 := &wireType{s: info.id.gobType().(*structType)};
|
||||
if !reflect.DeepEqual(wire2, trueWire2) {
|
||||
t.Fatalf("invalid wireType for ET2: expected %+v; got %+v\n", *trueWire2, *wire2)
|
||||
}
|
||||
// 6) The length, 6.
|
||||
length = decodeUint(state);
|
||||
if length != 6 {
|
||||
t.Fatal("6. expected length 6; got", length)
|
||||
}
|
||||
// 7) The type id for the et1 value
|
||||
newId1 := decodeInt(state);
|
||||
if newId1 != -id1 {
|
||||
t.Fatal("expected Et1 id", -id1, "got", newId1)
|
||||
}
|
||||
// 8) The value of et1
|
||||
newEt1 := new(ET1);
|
||||
et1Id := getTypeInfoNoError(reflect.Typeof(*newEt1)).id;
|
||||
err = decode(b, et1Id, newEt1);
|
||||
if err != nil {
|
||||
t.Fatal("error decoding ET1 value:", err)
|
||||
}
|
||||
if !reflect.DeepEqual(et1, newEt1) {
|
||||
t.Fatalf("invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1)
|
||||
}
|
||||
// 9) EOF
|
||||
if b.Len() != 0 {
|
||||
t.Error("not at eof;", b.Len(), "bytes left")
|
||||
}
|
||||
|
||||
// Now do it again. This time we should see only the type id and value.
|
||||
b.Reset();
|
||||
enc.Encode(et1);
|
||||
if enc.state.err != nil {
|
||||
t.Error("2nd round: encoder fail:", enc.state.err)
|
||||
}
|
||||
// The length.
|
||||
length = decodeUint(state);
|
||||
if length != 6 {
|
||||
t.Fatal("6. expected length 6; got", length)
|
||||
}
|
||||
// 5a) The type id for the et1 value
|
||||
newId1 = decodeInt(state);
|
||||
if newId1 != -id1 {
|
||||
t.Fatal("2nd round: expected Et1 id", -id1, "got", newId1)
|
||||
}
|
||||
// 6a) The value of et1
|
||||
newEt1 = new(ET1);
|
||||
err = decode(b, et1Id, newEt1);
|
||||
if err != nil {
|
||||
t.Fatal("2nd round: error decoding ET1 value:", err)
|
||||
}
|
||||
if !reflect.DeepEqual(et1, newEt1) {
|
||||
t.Fatalf("2nd round: invalid data for et1: expected %+v; got %+v\n", *et1, *newEt1)
|
||||
}
|
||||
// 7a) EOF
|
||||
if b.Len() != 0 {
|
||||
t.Error("2nd round: not at eof;", b.Len(), "bytes left")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncoderDecoder(t *testing.T) {
|
||||
b := new(bytes.Buffer);
|
||||
enc := NewEncoder(b);
|
||||
@ -215,7 +103,7 @@ func badTypeCheck(e interface{}, shouldFail bool, msg string, t *testing.T) {
|
||||
t.Error("expected error for", msg)
|
||||
}
|
||||
if !shouldFail && (dec.state.err != nil) {
|
||||
t.Error("unexpected error for", msg)
|
||||
t.Error("unexpected error for", msg, dec.state.err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -48,6 +48,7 @@ type gobType interface {
|
||||
|
||||
var types = make(map[reflect.Type]gobType)
|
||||
var idToType = make(map[typeId]gobType)
|
||||
var builtinIdToType map[typeId]gobType // set in init() after builtins are established
|
||||
|
||||
func setTypeId(typ gobType) {
|
||||
nextId++;
|
||||
@ -104,6 +105,10 @@ func init() {
|
||||
checkId(8, getTypeInfoNoError(reflect.Typeof(structType{})).id);
|
||||
checkId(9, getTypeInfoNoError(reflect.Typeof(commonType{})).id);
|
||||
checkId(10, getTypeInfoNoError(reflect.Typeof(fieldType{})).id);
|
||||
builtinIdToType = make(map[typeId]gobType);
|
||||
for k, v := range idToType {
|
||||
builtinIdToType[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
// Array type
|
||||
|
Loading…
Reference in New Issue
Block a user