1
0
mirror of https://github.com/golang/go synced 2024-11-20 05:04:43 -07:00

gob: protect against pure recursive types.

There are further changes required for things like
recursive map types.  Recursive struct types work
but the mechanism needs generalization.  The
case handled in this CL is pathological since it
cannot be represented at all by gob, so it should
be handled separately. (Prior to this CL, encode
would recur forever.)

R=rsc
CC=golang-dev
https://golang.org/cl/4206041
This commit is contained in:
Rob Pike 2011-02-23 09:49:35 -08:00
parent da8e6eec9a
commit c9b90c9d70
4 changed files with 69 additions and 36 deletions

View File

@ -973,17 +973,31 @@ func TestIgnoredFields(t *testing.T) {
}
}
type Bad0 struct {
ch chan int
c float64
func TestBadRecursiveType(t *testing.T) {
type Rec ***Rec
var rec Rec
b := new(bytes.Buffer)
err := NewEncoder(b).Encode(&rec)
if err == nil {
t.Error("expected error; got none")
} else if strings.Index(err.String(), "recursive") < 0 {
t.Error("expected recursive type error; got", err)
}
// Can't test decode easily because we can't encode one, so we can't pass one to a Decoder.
}
type Bad0 struct {
CH chan int
C float64
}
var nilEncoder *Encoder
func TestInvalidField(t *testing.T) {
var bad0 Bad0
bad0.ch = make(chan int)
bad0.CH = make(chan int)
b := new(bytes.Buffer)
var nilEncoder *Encoder
err := nilEncoder.encode(b, reflect.NewValue(&bad0), userType(reflect.Typeof(&bad0)))
if err == nil {
t.Error("expected error; got none")

View File

@ -410,7 +410,6 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr {
}
func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) {
defer catchError(&err)
p = allocate(ut.base, p, ut.indir)
state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField
@ -433,7 +432,6 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr)
// This state cannot arise for decodeSingle, which is called directly
// from the user's value, not from the innards of an engine.
func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, indir int) (err os.Error) {
defer catchError(&err)
p = allocate(ut.base.(*reflect.StructType), p, indir)
state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1
@ -463,7 +461,6 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr,
}
func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
defer catchError(&err)
state := newDecodeState(dec, &dec.buf)
state.fieldnum = -1
for state.b.Len() > 0 {
@ -486,7 +483,6 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) {
}
func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) {
defer catchError(&err)
state := newDecodeState(dec, &dec.buf)
state.fieldnum = singletonField
delta := int(state.decodeUint())
@ -937,7 +933,6 @@ func isExported(name string) bool {
}
func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) {
defer catchError(&err)
srt, ok := rt.(*reflect.StructType)
if !ok {
return dec.compileSingle(remoteId, rt)
@ -1026,7 +1021,8 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er
return
}
func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) os.Error {
func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) {
defer catchError(&err)
// If the value is nil, it means we should just ignore this item.
if val == nil {
return dec.decodeIgnoredValue(wireId)

View File

@ -200,9 +200,12 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
// Remove any nested writers remaining due to previous errors.
enc.w = enc.w[0:1]
enc.err = nil
ut := userType(value.Type())
ut, err := validUserType(value.Type())
if err != nil {
return err
}
enc.err = nil
state := newEncoderState(enc, new(bytes.Buffer))
enc.sendTypeDescriptor(enc.writer(), state, ut)
@ -212,7 +215,7 @@ func (enc *Encoder) EncodeValue(value reflect.Value) os.Error {
}
// Encode the object.
err := enc.encode(state.b, value, ut)
err = enc.encode(state.b, value, ut)
if err != nil {
enc.setError(err)
} else {

View File

@ -27,28 +27,63 @@ var (
userTypeCache = make(map[reflect.Type]*userTypeInfo)
)
// userType returns, and saves, the information associated with user-provided type rt
func userType(rt reflect.Type) *userTypeInfo {
// validType returns, and saves, the information associated with user-provided type rt.
// If the user type is not valid, err will be non-nil. To be used when the error handler
// is not set up.
func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) {
userTypeLock.RLock()
ut := userTypeCache[rt]
ut = userTypeCache[rt]
userTypeLock.RUnlock()
if ut != nil {
return ut
return
}
// Now set the value under the write lock.
userTypeLock.Lock()
defer userTypeLock.Unlock()
if ut = userTypeCache[rt]; ut != nil {
// Lost the race; not a problem.
return ut
return
}
ut = new(userTypeInfo)
ut.base = rt
ut.user = rt
ut.base, ut.indir = indirect(rt)
// A type that is just a cycle of pointers (such as type T *T) cannot
// be represented in gobs, which need some concrete data. We use a
// cycle detection algorithm from Knuth, Vol 2, Section 3.1, Ex 6,
// pp 539-540. As we step through indirections, run another type at
// half speed. If they meet up, there's a cycle.
// TODO: still need to deal with self-referential non-structs such
// as type T map[string]T but that is a larger undertaking - and can
// be useful, not always erroneous.
slowpoke := ut.base // walks half as fast as ut.base
for {
pt, ok := ut.base.(*reflect.PtrType)
if !ok {
break
}
ut.base = pt.Elem()
if ut.base == slowpoke { // ut.base lapped slowpoke
// recursive pointer type.
return nil, os.ErrorString("can't represent recursive pointer type " + ut.base.String())
}
if ut.indir%2 == 0 {
slowpoke = slowpoke.(*reflect.PtrType).Elem()
}
ut.indir++
}
userTypeCache[rt] = ut
return ut
return
}
// userType returns, and saves, the information associated with user-provided type rt.
// If the user type is not valid, it calls error.
func userType(rt reflect.Type) *userTypeInfo {
ut, err := validUserType(rt)
if err != nil {
error(err)
}
return ut
}
// A typeId represents a gob Type as an integer that can be passed on the wire.
// Internally, typeIds are used as keys to a map to recover the underlying type info.
type typeId int32
@ -273,21 +308,6 @@ func newStructType(name string) *structType {
return s
}
// Step through the indirections on a type to discover the base type.
// Return the base type and the number of indirections.
func indirect(t reflect.Type) (rt reflect.Type, count int) {
rt = t
for {
pt, ok := rt.(*reflect.PtrType)
if !ok {
break
}
rt = pt.Elem()
count++
}
return
}
func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) {
switch t := rt.(type) {
// All basic types are easy: they are predefined.