From 22c45c558b69e7f7bdb0ec37519e02d333188da4 Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Mon, 7 Mar 2011 12:08:31 -0800 Subject: [PATCH] gob: support GobEncoder for type T when the receiver is *T. Still to do: **T. R=rsc CC=golang-dev https://golang.org/cl/4247061 --- src/pkg/gob/decode.go | 18 +++++-- src/pkg/gob/encode.go | 18 +++++-- src/pkg/gob/gobencdec_test.go | 26 ++++++++++ src/pkg/gob/type.go | 93 +++++++++++++++++------------------ 4 files changed, 99 insertions(+), 56 deletions(-) diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index b7ae78200d..39af66698f 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -936,8 +936,11 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { // GobDecoder. func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) { rt := ut.user - if ut.decIndir != 0 { - errorf("gob: TODO: can't handle indirection to reach GobDecoder") + if ut.decIndir > 0 { + errorf("gob: TODO: can't handle >0 indirections to reach GobDecoder") + } + if ut.decIndir == -1 { + rt = reflect.PtrTo(rt) } index := -1 for i := 0; i < rt.NumMethod(); i++ { @@ -953,9 +956,16 @@ func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) { op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { // Allocate the underlying data, but hold on to the address we have, // since it's known to be the receiver's address. - // TODO: fix this up when decIndir can be non-zero. allocate(ut.base, uintptr(p), ut.indir) - v := reflect.NewValue(unsafe.Unreflect(rt, p)) + var v reflect.Value + switch { + case ut.decIndir == 0: + v = reflect.NewValue(unsafe.Unreflect(rt, p)) + case ut.decIndir == -1: + v = reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer(&p))) + default: + errorf("gob: TODO: can't handle >0 indirections to reach GobDecoder") + } state.dec.decodeGobDecoder(state, v, index) } return &op, int(ut.decIndir) diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index 9190d92035..2a4a96e5f2 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -573,8 +573,11 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp // GobEncoder. func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { rt := ut.user - if ut.encIndir != 0 { - errorf("gob: TODO: can't handle indirection to reach GobEncoder") + if ut.encIndir > 0 { + errorf("gob: TODO: can't handle >0 indirections to reach GobEncoder") + } + if ut.encIndir == -1 { + rt = reflect.PtrTo(rt) } index := -1 for i := 0; i < rt.NumMethod(); i++ { @@ -588,8 +591,15 @@ func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { } var op encOp op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { - // TODO: this will need fixing when ut.encIndr != 0. - v := reflect.NewValue(unsafe.Unreflect(rt, p)) + var v reflect.Value + switch { + case ut.encIndir == 0: + v = reflect.NewValue(unsafe.Unreflect(rt, p)) + case ut.encIndir == -1: + v = reflect.NewValue(unsafe.Unreflect(rt, unsafe.Pointer(&p))) + default: + errorf("gob: TODO: can't handle >0 indirections to reach GobEncoder") + } state.update(i) state.enc.encodeGobEncoder(state.b, v, index) } diff --git a/src/pkg/gob/gobencdec_test.go b/src/pkg/gob/gobencdec_test.go index 82ca68170e..1f60404f73 100644 --- a/src/pkg/gob/gobencdec_test.go +++ b/src/pkg/gob/gobencdec_test.go @@ -128,6 +128,11 @@ type GobTestIgnoreEncoder struct { X int // guarantee we have something in common with GobTest* } +type GobTestValueEncDec struct { + X int // guarantee we have something in common with GobTest* + G StringStruct // not a pointer. +} + func TestGobEncoderField(t *testing.T) { b := new(bytes.Buffer) // First a field that's a structure. @@ -162,6 +167,27 @@ func TestGobEncoderField(t *testing.T) { } } +// Even though the field is a value, we can still take its address +// and should be able to call the methods. +func TestGobEncoderValueField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + err := enc.Encode(GobTestValueEncDec{17, StringStruct{"HIJKL"}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestValueEncDec) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.G.s != "HIJKL" { + t.Errorf("expected `HIJKL` got %s", x.G.s) + } +} + // As long as the fields have the same name and implement the // interface, we can cross-connect them. Not sure it's useful // and may even be bad but it works and it's hard to prevent diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index a438139415..7fe0272403 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -74,14 +74,12 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { } ut.indir++ } - ut.isGobEncoder, ut.encIndir = implementsGobEncoder(ut.user) - ut.isGobDecoder, ut.decIndir = implementsGobDecoder(ut.user) + ut.isGobEncoder, ut.encIndir = implementsInterface(ut.user, gobEncoderCheck) + ut.isGobDecoder, ut.decIndir = implementsInterface(ut.user, gobDecoderCheck) userTypeCache[rt] = ut - if ut.encIndir != 0 || ut.decIndir != 0 { + if ut.encIndir > 0 || ut.decIndir > 0 { // There are checks in lots of other places, but putting this here means we won't even // attempt to encode/decode this type. - // TODO: make it possible to handle types that are indirect to the implementation, - // such as a structure field of type T when *T implements GobDecoder. return nil, os.ErrorString("TODO: gob can't handle indirections to GobEncoder/Decoder") } return @@ -92,51 +90,41 @@ const ( gobDecodeMethodName = "GobDecode" ) -// implementsGobEncoder reports whether the type implements the interface. It also -// returns the number of indirections required to get to the implementation. -// TODO: when reflection makes it possible, should also be prepared to climb up -// one level if we're not on a pointer (implementation could be on *T for our T). -// That will mean that indir could be < 0, which is sure to cause problems, but -// we ignore them now as indir is always >= 0 now. -func implementsGobEncoder(rt reflect.Type) (implements bool, indir int8) { - if rt == nil { - return +// implements returns whether the type implements the interface, as encoded +// in the check function. +func implements(typ reflect.Type, check func(typ reflect.Type) bool) bool { + if typ.NumMethod() == 0 { // avoid allocations etc. unless there's some chance + return false } - // The type might be a pointer, or it might not, and we need to keep - // dereferencing to the base type until we find an implementation. - for { - if rt.NumMethod() > 0 { // avoid allocations etc. unless there's some chance - if _, ok := reflect.MakeZero(rt).Interface().(GobEncoder); ok { - return true, indir - } - } - if p, ok := rt.(*reflect.PtrType); ok { - indir++ - if indir > 100 { // insane number of indirections - return false, 0 - } - rt = p.Elem() - continue - } - break - } - return false, 0 + return check(typ) } -// implementsGobDecoder reports whether the type implements the interface. It also -// returns the number of indirections required to get to the implementation. -// TODO: see comment on implementsGobEncoder. -func implementsGobDecoder(rt reflect.Type) (implements bool, indir int8) { - if rt == nil { +// gobEncoderCheck makes the type assertion a boolean function. +func gobEncoderCheck(typ reflect.Type) bool { + _, ok := reflect.MakeZero(typ).Interface().(GobEncoder) + return ok +} + +// gobDecoderCheck makes the type assertion a boolean function. +func gobDecoderCheck(typ reflect.Type) bool { + _, ok := reflect.MakeZero(typ).Interface().(GobDecoder) + return ok +} + +// implementsInterface reports whether the type implements the +// interface. (The actual check is done through the provided function.) +// It also returns the number of indirections required to get to the +// implementation. +func implementsInterface(typ reflect.Type, check func(typ reflect.Type) bool) (success bool, indir int8) { + if typ == nil { return } - // The type might be a pointer, or it might not, and we need to keep + rt := typ + // The type might be a pointer and we need to keep // dereferencing to the base type until we find an implementation. for { - if rt.NumMethod() > 0 { // avoid allocations etc. unless there's some chance - if _, ok := reflect.MakeZero(rt).Interface().(GobDecoder); ok { - return true, indir - } + if implements(typ, check) { + return true, indir } if p, ok := rt.(*reflect.PtrType); ok { indir++ @@ -148,6 +136,13 @@ func implementsGobDecoder(rt reflect.Type) (implements bool, indir int8) { } break } + // No luck yet, but if this is a base type (non-pointer), the pointer might satisfy. + if _, ok := typ.(*reflect.PtrType); !ok { + // Not a pointer, but does the pointer work? + if implements(reflect.PtrTo(typ), check) { + return true, -1 + } + } return false, 0 } @@ -702,10 +697,11 @@ func mustGetTypeInfo(rt reflect.Type) *typeInfo { // to guarantee the encoding used by a GobEncoder is stable as the // software evolves. For instance, it might make sense for GobEncode // to include a version number in the encoding. -// +// // Note: At the moment, the type implementing GobEncoder must -// be exactly the type passed to Encode. For example, if *T implements -// GobEncoder, the data item must be of type *T, not T or **T. +// be more indirect than the type passed to Decode. For example, if +// if *T implements GobDecoder, the data item must be of type *T or T, +// not **T or ***T. type GobEncoder interface { // GobEncode returns a byte slice representing the encoding of the // receiver for transmission to a GobDecoder, usually of the same @@ -717,8 +713,9 @@ type GobEncoder interface { // routine for decoding transmitted values sent by a GobEncoder. // // Note: At the moment, the type implementing GobDecoder must -// be exactly the type passed to Decode. For example, if *T implements -// GobDecoder, the data item must be of type *T, not T or **T. +// be more indirect than the type passed to Decode. For example, if +// if *T implements GobDecoder, the data item must be of type *T or T, +// not **T or ***T. type GobDecoder interface { // GobDecode overwrites the receiver, which must be a pointer, // with the value represented by the byte slice, which was written