diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index c09736221e..480d3df075 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -973,17 +973,31 @@ func TestIgnoredFields(t *testing.T) { } } -type Bad0 struct { - ch chan int - c float64 + +func TestBadRecursiveType(t *testing.T) { + type Rec ***Rec + var rec Rec + b := new(bytes.Buffer) + err := NewEncoder(b).Encode(&rec) + if err == nil { + t.Error("expected error; got none") + } else if strings.Index(err.String(), "recursive") < 0 { + t.Error("expected recursive type error; got", err) + } + // Can't test decode easily because we can't encode one, so we can't pass one to a Decoder. +} + +type Bad0 struct { + CH chan int + C float64 } -var nilEncoder *Encoder func TestInvalidField(t *testing.T) { var bad0 Bad0 - bad0.ch = make(chan int) + bad0.CH = make(chan int) b := new(bytes.Buffer) + var nilEncoder *Encoder err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0))) if err == nil { t.Error("expected error; got none") diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index d3f87144da..655a28bfe1 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -410,7 +410,6 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { } func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) { - defer catchError(&err) p = allocate(ut.base, p, ut.indir) state := newDecodeState(dec, &dec.buf) state.fieldnum = singletonField @@ -433,7 +432,6 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) // This state cannot arise for decodeSingle, which is called directly // from the user's value, not from the innards of an engine. func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) { - defer catchError(&err) p = allocate(ut.base.(*reflect.StructType), p, indir) state := newDecodeState(dec, &dec.buf) state.fieldnum = -1 @@ -463,7 +461,6 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, } func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { - defer catchError(&err) state := newDecodeState(dec, &dec.buf) state.fieldnum = -1 for state.b.Len() > 0 { @@ -486,7 +483,6 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { } func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) { - defer catchError(&err) state := newDecodeState(dec, &dec.buf) state.fieldnum = singletonField delta := int(state.decodeUint()) @@ -937,7 +933,6 @@ func isExported(name string) bool { } func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { - defer catchError(&err) srt, ok := rt.(*reflect.StructType) if !ok { return dec.compileSingle(remoteId, rt) @@ -1026,7 +1021,8 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er return } -func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error { +func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) { + defer catchError(&err) // If the value is nil, it means we should just ignore this item. if val == nil { return dec.decodeIgnoredValue(wireId) diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index 1419a27844..92d036c11c 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -200,9 +200,12 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { // Remove any nested writers remaining due to previous errors. enc.w = enc.w[0:1] - enc.err = nil - ut := userType(value.Type()) + ut, err := validUserType(value.Type()) + if err != nil { + return err + } + enc.err = nil state := newEncoderState(enc, new(bytes.Buffer)) enc.sendTypeDescriptor(enc.writer(), state, ut) @@ -212,7 +215,7 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error { } // Encode the object. - err := enc.encode(state.b, value, ut) + err = enc.encode(state.b, value, ut) if err != nil { enc.setError(err) } else { diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index c9c116abf8..3ed4cce924 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -27,28 +27,63 @@ var ( userTypeCache = make(map[reflect.Type]*userTypeInfo) ) -// userType returns, and saves, the information associated with user-provided type rt -func userType(rt reflect.Type) *userTypeInfo { +// validType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, err will be non-nil. To be used when the error handler +// is not set up. +func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { userTypeLock.RLock() - ut := userTypeCache[rt] + ut = userTypeCache[rt] userTypeLock.RUnlock() if ut != nil { - return ut + return } // Now set the value under the write lock. userTypeLock.Lock() defer userTypeLock.Unlock() if ut = userTypeCache[rt]; ut != nil { // Lost the race; not a problem. - return ut + return } ut = new(userTypeInfo) + ut.base = rt ut.user = rt - ut.base, ut.indir = indirect(rt) + // A type that is just a cycle of pointers (such as type T *T) cannot + // be represented in gobs, which need some concrete data. We use a + // cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6, + // pp 539-540. As we step through indirections, run another type at + // half speed. If they meet up, there's a cycle. + // TODO: still need to deal with self-referential non-structs such + // as type T map[string]T but that is a larger undertaking - and can + // be useful, not always erroneous. + slowpoke := ut.base // walks half as fast as ut.base + for { + pt, ok := ut.base.(*reflect.PtrType) + if !ok { + break + } + ut.base = pt.Elem() + if ut.base == slowpoke { // ut.base lapped slowpoke + // recursive pointer type. + return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String()) + } + if ut.indir%2 == 0 { + slowpoke = slowpoke.(*reflect.PtrType).Elem() + } + ut.indir++ + } userTypeCache[rt] = ut - return ut + return } +// userType returns, and saves, the information associated with user-provided type rt. +// If the user type is not valid, it calls error. +func userType(rt reflect.Type) *userTypeInfo { + ut, err := validUserType(rt) + if err != nil { + error(err) + } + return ut +} // A typeId represents a gob Type as an integer that can be passed on the wire. // Internally, typeIds are used as keys to a map to recover the underlying type info. type typeId int32 @@ -273,21 +308,6 @@ func newStructType(name string) *structType { return s } -// Step through the indirections on a type to discover the base type. -// Return the base type and the number of indirections. -func indirect(t reflect.Type) (rt reflect.Type, count int) { - rt = t - for { - pt, ok := rt.(*reflect.PtrType) - if !ok { - break - } - rt = pt.Elem() - count++ - } - return -} - func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { switch t := rt.(type) { // All basic types are easy: they are predefined.