From 98607d01fc7220a768a9468b10cefc559fce5e51 Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Wed, 29 Jul 2009 17:24:25 -0700 Subject: [PATCH] handle unsupported types safely. R=rsc DELTA=154 (71 added, 6 deleted, 77 changed) OCL=32483 CL=32492 --- src/pkg/gob/codec_test.go | 33 ++++++++--- src/pkg/gob/encode.go | 50 ++++++++++------ src/pkg/gob/encoder.go | 9 ++- src/pkg/gob/encoder_test.go | 6 +- src/pkg/gob/type.go | 111 +++++++++++++++++++++--------------- src/pkg/gob/type_test.go | 6 +- 6 files changed, 140 insertions(+), 75 deletions(-) diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index 8263a9286c..9ad3b47278 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -564,7 +564,7 @@ func TestEndToEnd(t *testing.T) { b := new(bytes.Buffer); encode(b, t1); var _t1 T1; - decode(b, getTypeInfo(reflect.Typeof(_t1)).id, &_t1); + decode(b, getTypeInfoNoError(reflect.Typeof(_t1)).id, &_t1); if !reflect.DeepEqual(t1, &_t1) { t.Errorf("encode expected %v got %v", *t1, _t1); } @@ -580,7 +580,7 @@ func TestOverflow(t *testing.T) { } var it inputT; var err os.Error; - id := getTypeInfo(reflect.Typeof(it)).id; + id := getTypeInfoNoError(reflect.Typeof(it)).id; b := new(bytes.Buffer); // int8 @@ -733,7 +733,7 @@ func TestNesting(t *testing.T) { b := new(bytes.Buffer); encode(b, rt); var drt RT; - decode(b, getTypeInfo(reflect.Typeof(drt)).id, &drt); + decode(b, getTypeInfoNoError(reflect.Typeof(drt)).id, &drt); if drt.a != rt.a { t.Errorf("nesting: encode expected %v got %v", *rt, drt); } @@ -775,7 +775,7 @@ func TestAutoIndirection(t *testing.T) { b := new(bytes.Buffer); encode(b, t1); var t0 T0; - t0Id := getTypeInfo(reflect.Typeof(t0)).id; + t0Id := getTypeInfoNoError(reflect.Typeof(t0)).id; decode(b, t0Id, &t0); if t0.a != 17 || t0.b != 177 || t0.c != 1777 || t0.d != 17777 { t.Errorf("t1->t0: expected {17 177 1777 17777}; got %v", t0); @@ -800,7 +800,7 @@ func TestAutoIndirection(t *testing.T) { b.Reset(); encode(b, t0); t1 = T1{}; - t1Id := getTypeInfo(reflect.Typeof(t1)).id; + t1Id := getTypeInfoNoError(reflect.Typeof(t1)).id; decode(b, t1Id, &t1); if t1.a != 17 || *t1.b != 177 || **t1.c != 1777 || ***t1.d != 17777 { t.Errorf("t0->t1 expected {17 177 1777 17777}; got {%d %d %d %d}", t1.a, *t1.b, **t1.c, ***t1.d); @@ -810,7 +810,7 @@ func TestAutoIndirection(t *testing.T) { b.Reset(); encode(b, t0); t2 = T2{}; - t2Id := getTypeInfo(reflect.Typeof(t2)).id; + t2Id := getTypeInfoNoError(reflect.Typeof(t2)).id; decode(b, t2Id, &t2); if ***t2.a != 17 || **t2.b != 177 || *t2.c != 1777 || t2.d != 17777 { t.Errorf("t0->t2 expected {17 177 1777 17777}; got {%d %d %d %d}", ***t2.a, **t2.b, *t2.c, t2.d); @@ -848,7 +848,7 @@ func TestReorderedFields(t *testing.T) { rt0.c = 3.14159; b := new(bytes.Buffer); encode(b, rt0); - rt0Id := getTypeInfo(reflect.Typeof(rt0)).id; + rt0Id := getTypeInfoNoError(reflect.Typeof(rt0)).id; var rt1 RT1; // Wire type is RT0, local type is RT1. decode(b, rt0Id, &rt1); @@ -886,7 +886,7 @@ func TestIgnoredFields(t *testing.T) { b := new(bytes.Buffer); encode(b, it0); - rt0Id := getTypeInfo(reflect.Typeof(it0)).id; + rt0Id := getTypeInfoNoError(reflect.Typeof(it0)).id; var rt1 RT1; // Wire type is IT0, local type is RT1. err := decode(b, rt0Id, &rt1); @@ -897,3 +897,20 @@ func TestIgnoredFields(t *testing.T) { t.Errorf("rt1->rt0: expected %v; got %v", it0, rt1); } } + +type Bad0 struct { + inter interface{}; + c float; +} + +func TestInvalidField(t *testing.T) { + var bad0 Bad0; + bad0.inter = 17; + b := new(bytes.Buffer); + err := encode(b, &bad0); + if err == nil { + t.Error("expected error; got none") + } else if strings.Index(err.String(), "interface") < 0 { + t.Error("expected type error; got", err) + } +} diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index be3599770f..1af79b1fdb 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -335,11 +335,11 @@ var encOpMap = map[reflect.Type] encOp { valueKind("x"): encString, } -func getEncEngine(rt reflect.Type) *encEngine +func getEncEngine(rt reflect.Type) (*encEngine, os.Error) // Return the encoding op for the base type under rt and // the indirection count to reach it. -func encOpFor(rt reflect.Type) (encOp, int) { +func encOpFor(rt reflect.Type) (encOp, int, os.Error) { typ, indir := indirect(rt); op, ok := encOpMap[reflect.Typeof(typ)]; if !ok { @@ -352,7 +352,10 @@ func encOpFor(rt reflect.Type) (encOp, int) { break; } // Slices have a header; we decode it to find the underlying array. - elemOp, indir := encOpFor(t.Elem()); + elemOp, indir, err := encOpFor(t.Elem()); + if err != nil { + return nil, 0, err + } op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { slice := (*reflect.SliceHeader)(p); if slice.Len == 0 { @@ -363,15 +366,21 @@ func encOpFor(rt reflect.Type) (encOp, int) { }; case *reflect.ArrayType: // True arrays have size in the type. - elemOp, indir := encOpFor(t.Elem()); + elemOp, indir, err := encOpFor(t.Elem()); + if err != nil { + return nil, 0, err + } op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i); state.err = encodeArray(state.b, uintptr(p), elemOp, t.Elem().Size(), t.Len(), indir); }; case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. - engine := getEncEngine(typ); - info := getTypeInfo(typ); + engine, err := getEncEngine(typ); + if err != nil { + return nil, 0, err + } + info := getTypeInfoNoError(typ); op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i); // indirect through info to delay evaluation for recursive structs @@ -380,13 +389,13 @@ func encOpFor(rt reflect.Type) (encOp, int) { } } if op == nil { - panicln("can't happen: encode type", rt.String()); + return op, indir, os.ErrorString("gob enc: can't happen: encode type" + rt.String()); } - return op, indir + return op, indir, nil } // The local Type was compiled from the actual value, so we know it's compatible. -func compileEnc(rt reflect.Type) *encEngine { +func compileEnc(rt reflect.Type) (*encEngine, os.Error) { srt, ok := rt.(*reflect.StructType); if !ok { panicln("can't happen: non-struct"); @@ -395,23 +404,29 @@ func compileEnc(rt reflect.Type) *encEngine { engine.instr = make([]encInstr, srt.NumField()+1); // +1 for terminator for fieldnum := 0; fieldnum < srt.NumField(); fieldnum++ { f := srt.Field(fieldnum); - op, indir := encOpFor(f.Type); + op, indir, err := encOpFor(f.Type); + if err != nil { + return nil, err + } engine.instr[fieldnum] = encInstr{op, fieldnum, indir, uintptr(f.Offset)}; } engine.instr[srt.NumField()] = encInstr{encStructTerminator, 0, 0, 0}; - return engine; + return engine, nil; } // typeLock must be held (or we're in initialization and guaranteed single-threaded). // The reflection type must have all its indirections processed out. -func getEncEngine(rt reflect.Type) *encEngine { - info := getTypeInfo(rt); +func getEncEngine(rt reflect.Type) (*encEngine, os.Error) { + info, err := getTypeInfo(rt); + if err != nil { + return nil, err + } if info.encoder == nil { // mark this engine as underway before compiling to handle recursive types. info.encoder = new(encEngine); - info.encoder = compileEnc(rt); + info.encoder, err = compileEnc(rt); } - return info.encoder; + return info.encoder, err; } func encode(b *bytes.Buffer, e interface{}) os.Error { @@ -425,7 +440,10 @@ func encode(b *bytes.Buffer, e interface{}) os.Error { return os.ErrorString("gob: encode can't handle " + v.Type().String()) } typeLock.Lock(); - engine := getEncEngine(rt); + engine, err := getEncEngine(rt); typeLock.Unlock(); + if err != nil { + return err + } return encodeStruct(engine, b, v.Addr()); } diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index 3d8f3928cb..e3ee8b3cb1 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -73,7 +73,7 @@ Maps are not supported yet, but they will be. Interfaces, functions, and channels cannot be sent in a gob. Attempting to encode a value that contains one will - fail. (TODO(r): fix this - it panics now.) + fail. The rest of this comment documents the encoding, details that are not important for most users. Details are presented bottom-up. @@ -267,8 +267,12 @@ func (enc *Encoder) sendType(origt reflect.Type) { // Need to send it. typeLock.Lock(); - info := getTypeInfo(rt); + info, err := getTypeInfo(rt); typeLock.Unlock(); + if err != nil { + enc.state.err = err; + return; + } // Send the pair (-id, type) // Id: encodeInt(enc.state, -int64(info.id)); @@ -285,6 +289,7 @@ func (enc *Encoder) sendType(origt reflect.Type) { for i := 0; i < st.NumField(); i++ { enc.sendType(st.Field(i).Type); } + return } // Encode transmits the data item represented by the empty interface value, diff --git a/src/pkg/gob/encoder_test.go b/src/pkg/gob/encoder_test.go index e1c64e4dd7..6033fbb3fe 100644 --- a/src/pkg/gob/encoder_test.go +++ b/src/pkg/gob/encoder_test.go @@ -69,7 +69,7 @@ func TestBasicEncoder(t *testing.T) { if err != nil { t.Fatal("error decoding ET1 type:", err); } - info := getTypeInfo(reflect.Typeof(ET1{})); + info := getTypeInfoNoError(reflect.Typeof(ET1{})); trueWire1 := &wireType{s: info.id.gobType().(*structType)}; if !reflect.DeepEqual(wire1, trueWire1) { t.Fatalf("invalid wireType for ET1: expected %+v; got %+v\n", *trueWire1, *wire1); @@ -90,7 +90,7 @@ func TestBasicEncoder(t *testing.T) { if err != nil { t.Fatal("error decoding ET2 type:", err); } - info = getTypeInfo(reflect.Typeof(ET2{})); + info = getTypeInfoNoError(reflect.Typeof(ET2{})); trueWire2 := &wireType{s: info.id.gobType().(*structType)}; if !reflect.DeepEqual(wire2, trueWire2) { t.Fatalf("invalid wireType for ET2: expected %+v; got %+v\n", *trueWire2, *wire2); @@ -107,7 +107,7 @@ func TestBasicEncoder(t *testing.T) { } // 8) The value of et1 newEt1 := new(ET1); - et1Id := getTypeInfo(reflect.Typeof(*newEt1)).id; + et1Id := getTypeInfoNoError(reflect.Typeof(*newEt1)).id; err = decode(b, et1Id, newEt1); if err != nil { t.Fatal("error decoding ET1 value:", err); diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index f54746c323..585d4f7472 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -202,7 +202,7 @@ func newStructType(name string) *structType { } // Construction -func newType(name string, rt reflect.Type) gobType +func getType(name string, rt reflect.Type) (gobType, os.Error) // Step through the indirections on a type to discover the base type. // Return the number of indirections. @@ -219,55 +219,63 @@ func indirect(t reflect.Type) (rt reflect.Type, count int) { return; } -func newTypeObject(name string, rt reflect.Type) gobType { +func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { switch t := rt.(type) { // All basic types are easy: they are predefined. case *reflect.BoolType: - return tBool.gobType() + return tBool.gobType(), nil case *reflect.IntType: - return tInt.gobType() + return tInt.gobType(), nil case *reflect.Int8Type: - return tInt.gobType() + return tInt.gobType(), nil case *reflect.Int16Type: - return tInt.gobType() + return tInt.gobType(), nil case *reflect.Int32Type: - return tInt.gobType() + return tInt.gobType(), nil case *reflect.Int64Type: - return tInt.gobType() + return tInt.gobType(), nil case *reflect.UintType: - return tUint.gobType() + return tUint.gobType(), nil case *reflect.Uint8Type: - return tUint.gobType() + return tUint.gobType(), nil case *reflect.Uint16Type: - return tUint.gobType() + return tUint.gobType(), nil case *reflect.Uint32Type: - return tUint.gobType() + return tUint.gobType(), nil case *reflect.Uint64Type: - return tUint.gobType() + return tUint.gobType(), nil case *reflect.UintptrType: - return tUint.gobType() + return tUint.gobType(), nil case *reflect.FloatType: - return tFloat.gobType() + return tFloat.gobType(), nil case *reflect.Float32Type: - return tFloat.gobType() + return tFloat.gobType(), nil case *reflect.Float64Type: - return tFloat.gobType() + return tFloat.gobType(), nil case *reflect.StringType: - return tString.gobType() + return tString.gobType(), nil case *reflect.ArrayType: - return newArrayType(name, newType("", t.Elem()), t.Len()); + gt, err := getType("", t.Elem()); + if err != nil { + return nil, err + } + return newArrayType(name, gt, t.Len()), nil; case *reflect.SliceType: // []byte == []uint8 is a special case if _, ok := t.Elem().(*reflect.Uint8Type); ok { - return tBytes.gobType() + return tBytes.gobType(), nil } - return newSliceType(name, newType(t.Elem().Name(), t.Elem())); + gt, err := getType(t.Elem().Name(), t.Elem()); + if err != nil { + return nil, err + } + return newSliceType(name, gt), nil; case *reflect.StructType: // Install the struct type itself before the fields so recursive @@ -283,18 +291,24 @@ func newTypeObject(name string, rt reflect.Type) gobType { if tname == "" { tname = f.Type.String(); } - field[i] = &fieldType{ f.Name, newType(tname, f.Type).id() }; + gt, err := getType(tname, f.Type); + if err != nil { + return nil, err + } + field[i] = &fieldType{ f.Name, gt.id() }; } strType.field = field; - return strType; + return strType, nil; default: - panicln("gob NewTypeObject can't handle type", rt.String()); // TODO(r): panic? + return nil, os.ErrorString("gob NewTypeObject can't handle type: " + rt.String()); } - return nil + return nil, nil } -func newType(name string, rt reflect.Type) gobType { +// getType returns the Gob type describing the given reflect.Type. +// typeLock must be held. +func getType(name string, rt reflect.Type) (gobType, os.Error) { // Flatten the data structure by collapsing out pointers for { pt, ok := rt.(*reflect.PtrType); @@ -305,19 +319,13 @@ func newType(name string, rt reflect.Type) gobType { } typ, present := types[rt]; if present { - return typ + return typ, nil } - typ = newTypeObject(name, rt); - types[rt] = typ; - return typ -} - -// getType returns the Gob type describing the given reflect.Type. -// typeLock must be held. -func getType(name string, rt reflect.Type) gobType { - // Set lock; all code running under here is synchronized. - t := newType(name, rt); - return t; + typ, err := newTypeObject(name, rt); + if err == nil { + types[rt] = typ + } + return typ, err } func checkId(want, got typeId) { @@ -371,7 +379,7 @@ var typeInfoMap = make(map[reflect.Type] *typeInfo) // protected by typeLock // The reflection type must have all its indirections processed out. // typeLock must be held. -func getTypeInfo(rt reflect.Type) *typeInfo { +func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) { if pt, ok := rt.(*reflect.PtrType); ok { panicln("pointer type in getTypeInfo:", rt.String()) } @@ -379,12 +387,25 @@ func getTypeInfo(rt reflect.Type) *typeInfo { if !ok { info = new(typeInfo); name := rt.Name(); - info.id = getType(name, rt).id(); + gt, err := getType(name, rt); + if err != nil { + return nil, err + } + info.id = gt.id(); // assume it's a struct type info.wire = &wireType{info.id.gobType().(*structType)}; typeInfoMap[rt] = info; } - return info; + return info, nil; +} + +// Called only when a panic is acceptable and unexpected. +func getTypeInfoNoError(rt reflect.Type) *typeInfo { + t, err := getTypeInfo(rt); + if err != nil { + panicln("getTypeInfo:", err.String()); + } + return t } func init() { @@ -396,9 +417,9 @@ func init() { // The string for tBytes is "bytes" not "[]byte" to signify its specialness. tBytes = bootstrapType("bytes", make([]byte, 0), 5); tString= bootstrapType("string", "", 6); - tWireType = getTypeInfo(reflect.Typeof(wireType{})).id; + tWireType = getTypeInfoNoError(reflect.Typeof(wireType{})).id; checkId(7, tWireType); - checkId(8, getTypeInfo(reflect.Typeof(structType{})).id); - checkId(9, getTypeInfo(reflect.Typeof(commonType{})).id); - checkId(10, getTypeInfo(reflect.Typeof(fieldType{})).id); + checkId(8, getTypeInfoNoError(reflect.Typeof(structType{})).id); + checkId(9, getTypeInfoNoError(reflect.Typeof(commonType{})).id); + checkId(10, getTypeInfoNoError(reflect.Typeof(fieldType{})).id); } diff --git a/src/pkg/gob/type_test.go b/src/pkg/gob/type_test.go index 2f11ba3fea..b86fffa21a 100644 --- a/src/pkg/gob/type_test.go +++ b/src/pkg/gob/type_test.go @@ -27,7 +27,11 @@ var basicTypes = []typeT { func getTypeUnlocked(name string, rt reflect.Type) gobType { typeLock.Lock(); defer typeLock.Unlock(); - return getType(name, rt); + t, err := getType(name, rt); + if err != nil { + panicln("getTypeUnlocked:", err.String()) + } + return t; } // Sanity checks