diff --git a/src/pkg/gob/codec_test.go b/src/pkg/gob/codec_test.go index c822d6863ac..4562e19309d 100644 --- a/src/pkg/gob/codec_test.go +++ b/src/pkg/gob/codec_test.go @@ -303,7 +303,7 @@ func TestScalarEncInstructions(t *testing.T) { } } -func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p unsafe.Pointer) { +func execDec(typ string, instr *decInstr, state *decoderState, t *testing.T, p unsafe.Pointer) { defer testError(t) v := int(state.decodeUint()) if v+state.fieldnum != 6 { @@ -313,7 +313,7 @@ func execDec(typ string, instr *decInstr, state *decodeState, t *testing.T, p un state.fieldnum = 6 } -func newDecodeStateFromData(data []byte) *decodeState { +func newDecodeStateFromData(data []byte) *decoderState { b := bytes.NewBuffer(data) state := newDecodeState(nil, b) state.fieldnum = -1 diff --git a/src/pkg/gob/debug.go b/src/pkg/gob/debug.go index e4583901e92..69c83bda782 100644 --- a/src/pkg/gob/debug.go +++ b/src/pkg/gob/debug.go @@ -155,6 +155,16 @@ func (deb *debugger) dump(format string, args ...interface{}) { // Debug prints a human-readable representation of the gob data read from r. func Debug(r io.Reader) { + err := debug(r) + if err != nil { + fmt.Fprintf(os.Stderr, "gob debug: %s\n", err) + } +} + +// debug implements Debug, but catches panics and returns +// them as errors to be printed by Debug. +func debug(r io.Reader) (err os.Error) { + defer catchError(&err) fmt.Fprintln(os.Stderr, "Start of debugging") deb := &debugger{ r: newPeekReader(r), @@ -166,6 +176,7 @@ func Debug(r io.Reader) { deb.remainingKnown = true } deb.gobStream() + return } // note that we've consumed some bytes @@ -386,11 +397,15 @@ func (deb *debugger) typeDefinition(indent tab, id typeId) { // Field number 1 is type Id of key deb.delta(1) keyId := deb.typeId() - wire.SliceT = &sliceType{com, id} // Field number 2 is type Id of elem deb.delta(1) elemId := deb.typeId() wire.MapT = &mapType{com, keyId, elemId} + case 4: // GobEncoder type, one field of {{Common}} + // Field number 0 is CommonType + deb.delta(1) + com := deb.common() + wire.GobEncoderT = &gobEncoderType{com} default: errorf("bad field in type %d", fieldNum) } @@ -507,6 +522,8 @@ func (deb *debugger) printWireType(indent tab, wire *wireType) { for i, field := range wire.StructT.Field { fmt.Fprintf(os.Stderr, "%sfield %d:\t%s\tid=%d\n", indent+1, i, field.Name, field.Id) } + case wire.GobEncoderT != nil: + deb.printCommonType(indent, "GobEncoder", &wire.GobEncoderT.CommonType) } indent-- fmt.Fprintf(os.Stderr, "%s}\n", indent) @@ -538,6 +555,8 @@ func (deb *debugger) fieldValue(indent tab, id typeId) { deb.sliceValue(indent, wire) case wire.StructT != nil: deb.structValue(indent, id) + case wire.GobEncoderT != nil: + deb.gobEncoderValue(indent, id) default: panic("bad wire type for field") } @@ -654,3 +673,17 @@ func (deb *debugger) structValue(indent tab, id typeId) { fmt.Fprintf(os.Stderr, "%s} // end %s struct\n", indent, id.name()) deb.dump(">> End of struct value of type %d %q", id, id.name()) } + +// GobEncoderValue: +// uint(n) byte*n +func (deb *debugger) gobEncoderValue(indent tab, id typeId) { + len := deb.uint64() + deb.dump("GobEncoder value of %q id=%d, length %d\n", id.name(), id, len) + fmt.Fprintf(os.Stderr, "%s%s (implements GobEncoder)\n", indent, id.name()) + data := make([]byte, len) + _, err := deb.r.Read(data) + if err != nil { + errorf("gobEncoder data read: %s", err) + } + fmt.Fprintf(os.Stderr, "%s[% .2x]\n", indent+1, data) +} diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index 8f599e10041..37f49312a8d 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -24,9 +24,9 @@ var ( errRange = os.ErrorString("gob: internal error: field numbers out of bounds") ) -// The execution state of an instance of the decoder. A new state +// decoderState is the execution state of an instance of the decoder. A new state // is created for nested objects. -type decodeState struct { +type decoderState struct { dec *Decoder // The buffer is stored with an extra indirection because it may be replaced // if we load a type during decode (when reading an interface value). @@ -37,8 +37,8 @@ type decodeState struct { // We pass the bytes.Buffer separately for easier testing of the infrastructure // without requiring a full Decoder. -func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decodeState { - d := new(decodeState) +func newDecodeState(dec *Decoder, buf *bytes.Buffer) *decoderState { + d := new(decoderState) d.dec = dec d.b = buf d.buf = make([]byte, uint64Size) @@ -85,7 +85,7 @@ func decodeUintReader(r io.Reader, buf []byte) (x uint64, width int, err os.Erro // decodeUint reads an encoded unsigned integer from state.r. // Does not check for overflow. -func (state *decodeState) decodeUint() (x uint64) { +func (state *decoderState) decodeUint() (x uint64) { b, err := state.b.ReadByte() if err != nil { error(err) @@ -112,7 +112,7 @@ func (state *decodeState) decodeUint() (x uint64) { // decodeInt reads an encoded signed integer from state.r. // Does not check for overflow. -func (state *decodeState) decodeInt() int64 { +func (state *decoderState) decodeInt() int64 { x := state.decodeUint() if x&1 != 0 { return ^int64(x >> 1) @@ -120,7 +120,8 @@ func (state *decodeState) decodeInt() int64 { return int64(x >> 1) } -type decOp func(i *decInstr, state *decodeState, p unsafe.Pointer) +// decOp is the signature of a decoding operator for a given type. +type decOp func(i *decInstr, state *decoderState, p unsafe.Pointer) // The 'instructions' of the decoding machine type decInstr struct { @@ -150,26 +151,31 @@ func decIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { return p } -func ignoreUint(i *decInstr, state *decodeState, p unsafe.Pointer) { +// ignoreUint discards a uint value with no destination. +func ignoreUint(i *decInstr, state *decoderState, p unsafe.Pointer) { state.decodeUint() } -func ignoreTwoUints(i *decInstr, state *decodeState, p unsafe.Pointer) { +// ignoreTwoUints discards a uint value with no destination. It's used to skip +// complex values. +func ignoreTwoUints(i *decInstr, state *decoderState, p unsafe.Pointer) { state.decodeUint() state.decodeUint() } -func decBool(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decBool decodes a uiint and stores it as a boolean through p. +func decBool(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(bool)) } p = *(*unsafe.Pointer)(p) } - *(*bool)(p) = state.decodeInt() != 0 + *(*bool)(p) = state.decodeUint() != 0 } -func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decInt8 decodes an integer and stores it as an int8 through p. +func decInt8(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int8)) @@ -184,7 +190,8 @@ func decInt8(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decUint8 decodes an unsigned integer and stores it as a uint8 through p. +func decUint8(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint8)) @@ -199,7 +206,8 @@ func decUint8(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decInt16 decodes an integer and stores it as an int16 through p. +func decInt16(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int16)) @@ -214,7 +222,8 @@ func decInt16(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decUint16 decodes an unsigned integer and stores it as a uint16 through p. +func decUint16(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint16)) @@ -229,7 +238,8 @@ func decUint16(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decInt32 decodes an integer and stores it as an int32 through p. +func decInt32(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int32)) @@ -244,7 +254,8 @@ func decInt32(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decUint32 decodes an unsigned integer and stores it as a uint32 through p. +func decUint32(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint32)) @@ -259,7 +270,8 @@ func decUint32(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decInt64 decodes an integer and stores it as an int64 through p. +func decInt64(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(int64)) @@ -269,7 +281,8 @@ func decInt64(i *decInstr, state *decodeState, p unsafe.Pointer) { *(*int64)(p) = int64(state.decodeInt()) } -func decUint64(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decUint64 decodes an unsigned integer and stores it as a uint64 through p. +func decUint64(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(uint64)) @@ -294,7 +307,9 @@ func floatFromBits(u uint64) float64 { return math.Float64frombits(v) } -func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { +// storeFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point +// number, and stores it through p. It's a helper function for float32 and complex64. +func storeFloat32(i *decInstr, state *decoderState, p unsafe.Pointer) { v := floatFromBits(state.decodeUint()) av := v if av < 0 { @@ -308,7 +323,9 @@ func storeFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { } } -func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point +// number, and stores it through p. +func decFloat32(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float32)) @@ -318,7 +335,9 @@ func decFloat32(i *decInstr, state *decodeState, p unsafe.Pointer) { storeFloat32(i, state, p) } -func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decFloat64 decodes an unsigned integer, treats it as a 64-bit floating-point +// number, and stores it through p. +func decFloat64(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(float64)) @@ -328,8 +347,10 @@ func decFloat64(i *decInstr, state *decodeState, p unsafe.Pointer) { *(*float64)(p) = floatFromBits(uint64(state.decodeUint())) } -// Complex numbers are just a pair of floating-point numbers, real part first. -func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decComplex64 decodes a pair of unsigned integers, treats them as a +// pair of floating point numbers, and stores them as a complex64 through p. +// The real part comes first. +func decComplex64(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex64)) @@ -340,7 +361,10 @@ func decComplex64(i *decInstr, state *decodeState, p unsafe.Pointer) { storeFloat32(i, state, unsafe.Pointer(uintptr(p)+uintptr(unsafe.Sizeof(float32(0))))) } -func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) { +// decComplex128 decodes a pair of unsigned integers, treats them as a +// pair of floating point numbers, and stores them as a complex128 through p. +// The real part comes first. +func decComplex128(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new(complex128)) @@ -352,8 +376,10 @@ func decComplex128(i *decInstr, state *decodeState, p unsafe.Pointer) { *(*complex128)(p) = complex(real, imag) } +// decUint8Array decodes byte array and stores through p a slice header +// describing the data. // uint8 arrays are encoded as an unsigned count followed by the raw bytes. -func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { +func decUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]uint8)) @@ -365,8 +391,10 @@ func decUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { *(*[]uint8)(p) = b } +// decString decodes byte array and stores through p a string header +// describing the data. // Strings are encoded as an unsigned count followed by the raw bytes. -func decString(i *decInstr, state *decodeState, p unsafe.Pointer) { +func decString(i *decInstr, state *decoderState, p unsafe.Pointer) { if i.indir > 0 { if *(*unsafe.Pointer)(p) == nil { *(*unsafe.Pointer)(p) = unsafe.Pointer(new([]byte)) @@ -378,7 +406,8 @@ func decString(i *decInstr, state *decodeState, p unsafe.Pointer) { *(*string)(p) = string(b) } -func ignoreUint8Array(i *decInstr, state *decodeState, p unsafe.Pointer) { +// ignoreUint8Array skips over the data for a byte slice value with no destination. +func ignoreUint8Array(i *decInstr, state *decoderState, p unsafe.Pointer) { b := make([]byte, state.decodeUint()) state.b.Read(b) } @@ -409,8 +438,15 @@ func allocate(rtyp reflect.Type, p uintptr, indir int) uintptr { return *(*uintptr)(up) } +// decodeSingle decodes a top-level value that is not a struct and stores it through p. +// Such values are preceded by a zero, making them have the memory layout of a +// struct field (although with an illegal field number). func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) (err os.Error) { - p = allocate(ut.base, p, ut.indir) + indir := ut.indir + if ut.isGobDecoder { + indir = int(ut.decIndir) + } + p = allocate(ut.base, p, indir) state := newDecodeState(dec, &dec.buf) state.fieldnum = singletonField basep := p @@ -427,6 +463,7 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, p uintptr) return nil } +// decodeSingle decodes a top-level struct and stores it through p. // Indir is for the value, not the type. At the time of the call it may // differ from ut.indir, which was computed when the engine was built. // This state cannot arise for decodeSingle, which is called directly @@ -460,6 +497,7 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, p uintptr, return nil } +// ignoreStruct discards the data for a struct with no destination. func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { state := newDecodeState(dec, &dec.buf) state.fieldnum = -1 @@ -482,6 +520,8 @@ func (dec *Decoder) ignoreStruct(engine *decEngine) (err os.Error) { return nil } +// ignoreSingle discards the data for a top-level non-struct value with no +// destination. It's used when calling Decode with a nil value. func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) { state := newDecodeState(dec, &dec.buf) state.fieldnum = singletonField @@ -494,7 +534,8 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) (err os.Error) { return nil } -func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) { +// decodeArrayHelper does the work for decoding arrays and slices. +func (dec *Decoder) decodeArrayHelper(state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, elemIndir int, ovfl os.ErrorString) { instr := &decInstr{elemOp, 0, elemIndir, 0, ovfl} for i := 0; i < length; i++ { up := unsafe.Pointer(p) @@ -506,7 +547,10 @@ func (dec *Decoder) decodeArrayHelper(state *decodeState, p uintptr, elemOp decO } } -func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) { +// decodeArray decodes an array and stores it through p, that is, p points to the zeroth element. +// The length is an unsigned integer preceding the elements. Even though the length is redundant +// (it's part of the type), it's a useful check and is included in the encoding. +func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, length, indir, elemIndir int, ovfl os.ErrorString) { if indir > 0 { p = allocate(atyp, p, 1) // All but the last level has been allocated by dec.Indirect } @@ -516,7 +560,9 @@ func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p u dec.decodeArrayHelper(state, p, elemOp, elemWid, length, elemIndir, ovfl) } -func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { +// decodeIntoValue is a helper for map decoding. Since maps are decoded using reflection, +// unlike the other items we can't use a pointer directly. +func decodeIntoValue(state *decoderState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { instr := &decInstr{op, 0, indir, 0, ovfl} up := unsafe.Pointer(v.UnsafeAddr()) if indir > 1 { @@ -526,7 +572,11 @@ func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, o return v } -func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) { +// decodeMap decodes a map and stores its header through p. +// Maps are encoded as a length followed by key:value pairs. +// Because the internals of maps are not visible to us, we must +// use reflection rather than pointer magic. +func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decoderState, p uintptr, keyOp, elemOp decOp, indir, keyIndir, elemIndir int, ovfl os.ErrorString) { if indir > 0 { p = allocate(mtyp, p, 1) // All but the last level has been allocated by dec.Indirect } @@ -538,7 +588,7 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for // the iteration. - v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer((p)))).(*reflect.MapValue) + v := reflect.NewValue(unsafe.Unreflect(mtyp, unsafe.Pointer(p))).(*reflect.MapValue) n := int(state.decodeUint()) for i := 0; i < n; i++ { key := decodeIntoValue(state, keyOp, keyIndir, reflect.MakeZero(mtyp.Key()), ovfl) @@ -547,21 +597,24 @@ func (dec *Decoder) decodeMap(mtyp *reflect.MapType, state *decodeState, p uintp } } -func (dec *Decoder) ignoreArrayHelper(state *decodeState, elemOp decOp, length int) { +// ignoreArrayHelper does the work for discarding arrays and slices. +func (dec *Decoder) ignoreArrayHelper(state *decoderState, elemOp decOp, length int) { instr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} for i := 0; i < length; i++ { elemOp(instr, state, nil) } } -func (dec *Decoder) ignoreArray(state *decodeState, elemOp decOp, length int) { +// ignoreArray discards the data for an array value with no destination. +func (dec *Decoder) ignoreArray(state *decoderState, elemOp decOp, length int) { if n := state.decodeUint(); n != uint64(length) { errorf("gob: length mismatch in ignoreArray") } dec.ignoreArrayHelper(state, elemOp, length) } -func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) { +// ignoreMap discards the data for a map value with no destination. +func (dec *Decoder) ignoreMap(state *decoderState, keyOp, elemOp decOp) { n := int(state.decodeUint()) keyInstr := &decInstr{keyOp, 0, 0, 0, os.ErrorString("no error")} elemInstr := &decInstr{elemOp, 0, 0, 0, os.ErrorString("no error")} @@ -571,7 +624,9 @@ func (dec *Decoder) ignoreMap(state *decodeState, keyOp, elemOp decOp) { } } -func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) { +// decodeSlice decodes a slice and stores the slice header through p. +// Slices are encoded as an unsigned length followed by the elements. +func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decoderState, p uintptr, elemOp decOp, elemWid uintptr, indir, elemIndir int, ovfl os.ErrorString) { n := int(uintptr(state.decodeUint())) if indir > 0 { up := unsafe.Pointer(p) @@ -590,7 +645,8 @@ func (dec *Decoder) decodeSlice(atyp *reflect.SliceType, state *decodeState, p u dec.decodeArrayHelper(state, hdrp.Data, elemOp, elemWid, n, elemIndir, ovfl) } -func (dec *Decoder) ignoreSlice(state *decodeState, elemOp decOp) { +// ignoreSlice skips over the data for a slice value with no destination. +func (dec *Decoder) ignoreSlice(state *decoderState, elemOp decOp) { dec.ignoreArrayHelper(state, elemOp, int(state.decodeUint())) } @@ -609,9 +665,10 @@ func setInterfaceValue(ivalue *reflect.InterfaceValue, value reflect.Value) { ivalue.Set(value) } -// decodeInterface receives the name of a concrete type followed by its value. +// decodeInterface decodes an interface value and stores it through p. +// Interfaces are encoded as the name of a concrete type followed by a value. // If the name is empty, the value is nil and no value is sent. -func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeState, p uintptr, indir int) { +func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decoderState, p uintptr, indir int) { // Create an interface reflect.Value. We need one even for the nil case. ivalue := reflect.MakeZero(ityp).(*reflect.InterfaceValue) // Read the name of the concrete type. @@ -655,7 +712,8 @@ func (dec *Decoder) decodeInterface(ityp *reflect.InterfaceType, state *decodeSt *(*[2]uintptr)(unsafe.Pointer(p)) = ivalue.Get() } -func (dec *Decoder) ignoreInterface(state *decodeState) { +// ignoreInterface discards the data for an interface value with no destination. +func (dec *Decoder) ignoreInterface(state *decoderState) { // Read the name of the concrete type. b := make([]byte, state.decodeUint()) _, err := state.b.Read(b) @@ -670,6 +728,32 @@ func (dec *Decoder) ignoreInterface(state *decodeState) { state.b.Next(int(state.decodeUint())) } +// decodeGobDecoder decodes something implementing the GobDecoder interface. +// The data is encoded as a byte slice. +func (dec *Decoder) decodeGobDecoder(state *decoderState, v reflect.Value, index int) { + // Read the bytes for the value. + b := make([]byte, state.decodeUint()) + _, err := state.b.Read(b) + if err != nil { + error(err) + } + // We know it's a GobDecoder, so just call the method directly. + err = v.Interface().(_GobDecoder)._GobDecode(b) + if err != nil { + error(err) + } +} + +// ignoreGobDecoder discards the data for a GobDecoder value with no destination. +func (dec *Decoder) ignoreGobDecoder(state *decoderState) { + // Read the bytes for the value. + b := make([]byte, state.decodeUint()) + _, err := state.b.Read(b) + if err != nil { + error(err) + } +} + // Index by Go types. var decOpTable = [...]decOp{ reflect.Bool: decBool, @@ -699,10 +783,14 @@ var decIgnoreOpMap = map[typeId]decOp{ tComplex: ignoreTwoUints, } -// Return the decoding op for the base type under rt and +// decOpFor returns the decoding op for the base type under rt and // the indirection count to reach it. func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProgress map[reflect.Type]*decOp) (*decOp, int) { ut := userType(rt) + // If the type implements GobEncoder, we handle it without further processing. + if ut.isGobDecoder { + return dec.gobDecodeOpFor(ut) + } // If this type is already in progress, it's a recursive type (e.g. map[string]*T). // Return the pointer to the op we're already building. if opPtr := inProgress[rt]; opPtr != nil { @@ -724,7 +812,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg elemId := dec.wireType[wireId].ArrayT.Elem elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl := overflow(name) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { state.dec.decodeArray(t, state, uintptr(p), *elemOp, t.Elem().Size(), t.Len(), i.indir, elemIndir, ovfl) } @@ -735,7 +823,7 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg keyOp, keyIndir := dec.decOpFor(keyId, t.Key(), name, inProgress) elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl := overflow(name) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { up := unsafe.Pointer(p) state.dec.decodeMap(t, state, uintptr(up), *keyOp, *elemOp, i.indir, keyIndir, elemIndir, ovfl) } @@ -754,17 +842,17 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg } elemOp, elemIndir := dec.decOpFor(elemId, t.Elem(), name, inProgress) ovfl := overflow(name) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { state.dec.decodeSlice(t, state, uintptr(p), *elemOp, t.Elem().Size(), i.indir, elemIndir, ovfl) } case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. - enginePtr, err := dec.getDecEnginePtr(wireId, typ) + enginePtr, err := dec.getDecEnginePtr(wireId, userType(typ)) if err != nil { error(err) } - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { // indirect through enginePtr to delay evaluation for recursive structs. err = dec.decodeStruct(*enginePtr, userType(typ), uintptr(p), i.indir) if err != nil { @@ -772,8 +860,8 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg } } case *reflect.InterfaceType: - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - dec.decodeInterface(t, state, uintptr(p), i.indir) + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { + state.dec.decodeInterface(t, state, uintptr(p), i.indir) } } } @@ -783,15 +871,15 @@ func (dec *Decoder) decOpFor(wireId typeId, rt reflect.Type, name string, inProg return &op, indir } -// Return the decoding op for a field that has no destination. +// decIgnoreOpFor returns the decoding op for a field that has no destination. func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { op, ok := decIgnoreOpMap[wireId] if !ok { if wireId == tInterface { // Special case because it's a method: the ignored item might // define types and we need to record their state in the decoder. - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { - dec.ignoreInterface(state) + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { + state.dec.ignoreInterface(state) } return op } @@ -803,7 +891,7 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { case wire.ArrayT != nil: elemId := wire.ArrayT.Elem elemOp := dec.decIgnoreOpFor(elemId) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { state.dec.ignoreArray(state, elemOp, wire.ArrayT.Len) } @@ -812,14 +900,14 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { elemId := dec.wireType[wireId].MapT.Elem keyOp := dec.decIgnoreOpFor(keyId) elemOp := dec.decIgnoreOpFor(elemId) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { state.dec.ignoreMap(state, keyOp, elemOp) } case wire.SliceT != nil: elemId := wire.SliceT.Elem elemOp := dec.decIgnoreOpFor(elemId) - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { state.dec.ignoreSlice(state, elemOp) } @@ -829,10 +917,15 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { if err != nil { error(err) } - op = func(i *decInstr, state *decodeState, p unsafe.Pointer) { + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { // indirect through enginePtr to delay evaluation for recursive structs state.dec.ignoreStruct(*enginePtr) } + + case wire.GobEncoderT != nil: + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { + state.dec.ignoreGobDecoder(state) + } } } if op == nil { @@ -841,16 +934,56 @@ func (dec *Decoder) decIgnoreOpFor(wireId typeId) decOp { return op } -// Are these two gob Types compatible? -// Answers the question for basic types, arrays, and slices. +// gobDecodeOpFor returns the op for a type that is known to implement +// GobDecoder. +func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) (*decOp, int) { + rt := ut.user + if ut.decIndir != 0 { + errorf("gob: TODO: can't handle indirection to reach GobDecoder") + } + index := -1 + for i := 0; i < rt.NumMethod(); i++ { + if rt.Method(i).Name == gobDecodeMethodName { + index = i + break + } + } + if index < 0 { + panic("can't find GobDecode method") + } + var op decOp + op = func(i *decInstr, state *decoderState, p unsafe.Pointer) { + // Allocate the underlying data, but hold on to the address we have, + // since it's known to be the receiver's address. + // TODO: fix this up when decIndir can be non-zero. + allocate(ut.base, uintptr(p), ut.indir) + v := reflect.NewValue(unsafe.Unreflect(rt, p)) + state.dec.decodeGobDecoder(state, v, index) + } + return &op, int(ut.decIndir) + +} + +// compatibleType asks: Are these two gob Types compatible? +// Answers the question for basic types, arrays, maps and slices, plus +// GobEncoder/Decoder pairs. // Structs are considered ok; fields will be checked later. func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[reflect.Type]typeId) bool { if rhs, ok := inProgress[fr]; ok { return rhs == fw } inProgress[fr] = fw - fr = userType(fr).base - switch t := fr.(type) { + ut := userType(fr) + wire, ok := dec.wireType[fw] + // If fr is a GobDecoder, the wire type must be GobEncoder. + // And if fr is not a GobDecoder, the wire type must not be either. + if ut.isGobDecoder != (ok && wire.GobEncoderT != nil) { // the parentheses look odd but are correct. + return false + } + if ut.isGobDecoder { // This test trumps all others. + return true + } + switch t := ut.base.(type) { default: // chan, etc: cannot handle. return false @@ -869,14 +1002,12 @@ func (dec *Decoder) compatibleType(fr reflect.Type, fw typeId, inProgress map[re case *reflect.InterfaceType: return fw == tInterface case *reflect.ArrayType: - wire, ok := dec.wireType[fw] if !ok || wire.ArrayT == nil { return false } array := wire.ArrayT return t.Len() == array.Len && dec.compatibleType(t.Elem(), array.Elem, inProgress) case *reflect.MapType: - wire, ok := dec.wireType[fw] if !ok || wire.MapT == nil { return false } @@ -911,8 +1042,13 @@ func (dec *Decoder) typeString(remoteId typeId) string { return dec.wireType[remoteId].string() } - -func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { +// compileSingle compiles the decoder engine for a non-struct top-level value, including +// GobDecoders. +func (dec *Decoder) compileSingle(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err os.Error) { + rt := ut.base + if ut.isGobDecoder { + rt = ut.user + } engine = new(decEngine) engine.instr = make([]decInstr, 1) // one item name := rt.String() // best we can do @@ -926,6 +1062,7 @@ func (dec *Decoder) compileSingle(remoteId typeId, rt reflect.Type) (engine *dec return } +// compileIgnoreSingle compiles the decoder engine for a non-struct top-level value that will be discarded. func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err os.Error) { engine = new(decEngine) engine.instr = make([]decInstr, 1) // one item @@ -936,16 +1073,19 @@ func (dec *Decoder) compileIgnoreSingle(remoteId typeId) (engine *decEngine, err return } -// Is this an exported - upper case - name? +// isExported reports whether this is an exported - upper case - name. func isExported(name string) bool { rune, _ := utf8.DecodeRuneInString(name) return unicode.IsUpper(rune) } -func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEngine, err os.Error) { +// compileDec compiles the decoder engine for a value. If the value is not a struct, +// it calls out to compileSingle. +func (dec *Decoder) compileDec(remoteId typeId, ut *userTypeInfo) (engine *decEngine, err os.Error) { + rt := ut.base srt, ok := rt.(*reflect.StructType) - if !ok { - return dec.compileSingle(remoteId, rt) + if !ok || ut.isGobDecoder { + return dec.compileSingle(remoteId, ut) } var wireStruct *structType // Builtin types can come from global pool; the rest must be defined by the decoder. @@ -990,7 +1130,9 @@ func (dec *Decoder) compileDec(remoteId typeId, rt reflect.Type) (engine *decEng return } -func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr **decEngine, err os.Error) { +// getDecEnginePtr returns the engine for the specified type. +func (dec *Decoder) getDecEnginePtr(remoteId typeId, ut *userTypeInfo) (enginePtr **decEngine, err os.Error) { + rt := ut.base decoderMap, ok := dec.decoderCache[rt] if !ok { decoderMap = make(map[typeId]**decEngine) @@ -1000,7 +1142,7 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr // To handle recursive types, mark this engine as underway before compiling. enginePtr = new(*decEngine) decoderMap[remoteId] = enginePtr - *enginePtr, err = dec.compileDec(remoteId, rt) + *enginePtr, err = dec.compileDec(remoteId, ut) if err != nil { decoderMap[remoteId] = nil, false } @@ -1008,11 +1150,12 @@ func (dec *Decoder) getDecEnginePtr(remoteId typeId, rt reflect.Type) (enginePtr return } -// When ignoring struct data, in effect we compile it into this type +// emptyStruct is the type we compile into when ignoring a struct value. type emptyStruct struct{} var emptyStructType = reflect.Typeof(emptyStruct{}) +// getDecEnginePtr returns the engine for the specified type when the value is to be discarded. func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, err os.Error) { var ok bool if enginePtr, ok = dec.ignorerCache[wireId]; !ok { @@ -1021,7 +1164,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er dec.ignorerCache[wireId] = enginePtr wire := dec.wireType[wireId] if wire != nil && wire.StructT != nil { - *enginePtr, err = dec.compileDec(wireId, emptyStructType) + *enginePtr, err = dec.compileDec(wireId, userType(emptyStructType)) } else { *enginePtr, err = dec.compileIgnoreSingle(wireId) } @@ -1032,6 +1175,7 @@ func (dec *Decoder) getIgnoreEnginePtr(wireId typeId) (enginePtr **decEngine, er return } +// decodeValue decodes the data stream representing a value and stores it in val. func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) { defer catchError(&err) // If the value is nil, it means we should just ignore this item. @@ -1042,12 +1186,18 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) ut := userType(val.Type()) base := ut.base indir := ut.indir - enginePtr, err := dec.getDecEnginePtr(wireId, base) + if ut.isGobDecoder { + indir = int(ut.decIndir) + if indir != 0 { + errorf("TODO: can't handle indirection in GobDecoder value") + } + } + enginePtr, err := dec.getDecEnginePtr(wireId, ut) if err != nil { return err } engine := *enginePtr - if st, ok := base.(*reflect.StructType); ok { + if st, ok := base.(*reflect.StructType); ok && !ut.isGobDecoder { if engine.numInstr == 0 && st.NumField() > 0 && len(dec.wireType[wireId].StructT.Field) > 0 { name := base.Name() return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) @@ -1057,6 +1207,7 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) return dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr())) } +// decodeIgnoredValue decodes the data stream representing a value of the specified type and discards it. func (dec *Decoder) decodeIgnoredValue(wireId typeId) os.Error { enginePtr, err := dec.getIgnoreEnginePtr(wireId) if err != nil { diff --git a/src/pkg/gob/decoder.go b/src/pkg/gob/decoder.go index f7c994ffa78..71927458369 100644 --- a/src/pkg/gob/decoder.go +++ b/src/pkg/gob/decoder.go @@ -21,7 +21,7 @@ type Decoder struct { wireType map[typeId]*wireType // map from remote ID to local description decoderCache map[reflect.Type]map[typeId]**decEngine // cache of compiled engines ignorerCache map[typeId]**decEngine // ditto for ignored objects - countState *decodeState // reads counts from wire + countState *decoderState // reads counts from wire countBuf []byte // used for decoding integers while parsing messages tmp []byte // temporary storage for i/o; saves reallocating err os.Error diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index e92db74ffdd..d69e734ff9f 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -15,7 +15,7 @@ import ( const uint64Size = unsafe.Sizeof(uint64(0)) -// The global execution state of an instance of the encoder. +// encoderState is the global execution state of an instance of the encoder. // Field numbers are delta encoded and always increase. The field // number is initialized to -1 so 0 comes out as delta(1). A delta of // 0 terminates the structure. @@ -72,6 +72,7 @@ func (state *encoderState) encodeInt(i int64) { state.encodeUint(uint64(x)) } +// encOp is the signature of an encoding operator for a given type. type encOp func(i *encInstr, state *encoderState, p unsafe.Pointer) // The 'instructions' of the encoding machine @@ -82,8 +83,8 @@ type encInstr struct { offset uintptr // offset in the structure of the field to encode } -// Emit a field number and update the state to record its value for delta encoding. -// If the instruction pointer is nil, do nothing +// update emits a field number and updates the state to record its value for delta encoding. +// If the instruction pointer is nil, it does nothing func (state *encoderState) update(instr *encInstr) { if instr != nil { state.encodeUint(uint64(instr.field - state.fieldnum)) @@ -97,6 +98,7 @@ func (state *encoderState) update(instr *encInstr) { // Otherwise, the output (for a scalar) is the field number, as an encoded integer, // followed by the field data in its appropriate format. +// encIndirect dereferences p indir times and returns the result. func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { for ; indir > 0; indir-- { p = *(*unsafe.Pointer)(p) @@ -107,6 +109,7 @@ func encIndirect(p unsafe.Pointer, indir int) unsafe.Pointer { return p } +// encBool encodes the bool with address p as an unsigned 0 or 1. func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) { b := *(*bool)(p) if b || state.sendZero { @@ -119,6 +122,7 @@ func encBool(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encInt encodes the int with address p. func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int)(p)) if v != 0 || state.sendZero { @@ -127,6 +131,7 @@ func encInt(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUint encodes the uint with address p. func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint)(p)) if v != 0 || state.sendZero { @@ -135,6 +140,7 @@ func encUint(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encInt8 encodes the int8 with address p. func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int8)(p)) if v != 0 || state.sendZero { @@ -143,6 +149,7 @@ func encInt8(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUint8 encodes the uint8 with address p. func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint8)(p)) if v != 0 || state.sendZero { @@ -151,6 +158,7 @@ func encUint8(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encInt16 encodes the int16 with address p. func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int16)(p)) if v != 0 || state.sendZero { @@ -159,6 +167,7 @@ func encInt16(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUint16 encodes the uint16 with address p. func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint16)(p)) if v != 0 || state.sendZero { @@ -167,6 +176,7 @@ func encUint16(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encInt32 encodes the int32 with address p. func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) { v := int64(*(*int32)(p)) if v != 0 || state.sendZero { @@ -175,6 +185,7 @@ func encInt32(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUint encodes the uint32 with address p. func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uint32)(p)) if v != 0 || state.sendZero { @@ -183,6 +194,7 @@ func encUint32(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encInt64 encodes the int64 with address p. func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) { v := *(*int64)(p) if v != 0 || state.sendZero { @@ -191,6 +203,7 @@ func encInt64(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encInt64 encodes the uint64 with address p. func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) { v := *(*uint64)(p) if v != 0 || state.sendZero { @@ -199,6 +212,7 @@ func encUint64(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUintptr encodes the uintptr with address p. func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) { v := uint64(*(*uintptr)(p)) if v != 0 || state.sendZero { @@ -207,6 +221,7 @@ func encUintptr(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// floatBits returns a uint64 holding the bits of a floating-point number. // Floating-point numbers are transmitted as uint64s holding the bits // of the underlying representation. They are sent byte-reversed, with // the exponent end coming out first, so integer floating point numbers @@ -223,6 +238,7 @@ func floatBits(f float64) uint64 { return v } +// encFloat32 encodes the float32 with address p. func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) { f := *(*float32)(p) if f != 0 || state.sendZero { @@ -232,6 +248,7 @@ func encFloat32(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encFloat64 encodes the float64 with address p. func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) { f := *(*float64)(p) if f != 0 || state.sendZero { @@ -241,6 +258,7 @@ func encFloat64(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encComplex64 encodes the complex64 with address p. // Complex numbers are just a pair of floating-point numbers, real part first. func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) { c := *(*complex64)(p) @@ -253,6 +271,7 @@ func encComplex64(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encComplex128 encodes the complex128 with address p. func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) { c := *(*complex128)(p) if c != 0+0i || state.sendZero { @@ -264,6 +283,7 @@ func encComplex128(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encUint8Array encodes the byte slice whose header has address p. // Byte arrays are encoded as an unsigned count followed by the raw bytes. func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) { b := *(*[]byte)(p) @@ -274,6 +294,7 @@ func encUint8Array(i *encInstr, state *encoderState, p unsafe.Pointer) { } } +// encString encodes the string whose header has address p. // Strings are encoded as an unsigned count followed by the raw bytes. func encString(i *encInstr, state *encoderState, p unsafe.Pointer) { s := *(*string)(p) @@ -284,14 +305,15 @@ func encString(i *encInstr, state *encoderState, p unsafe.Pointer) { } } -// The end of a struct is marked by a delta field number of 0. +// encStructTerminator encodes the end of an encoded struct +// as delta field number of 0. func encStructTerminator(i *encInstr, state *encoderState, p unsafe.Pointer) { state.encodeUint(0) } // Execution engine -// The encoder engine is an array of instructions indexed by field number of the encoding +// encEngine an array of instructions indexed by field number of the encoding // data, typically a struct. It is executed top to bottom, walking the struct. type encEngine struct { instr []encInstr @@ -299,6 +321,7 @@ type encEngine struct { const singletonField = 0 +// encodeSingle encodes a single top-level non-struct value. func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintptr) { state := newEncoderState(enc, b) state.fieldnum = singletonField @@ -315,6 +338,7 @@ func (enc *Encoder) encodeSingle(b *bytes.Buffer, engine *encEngine, basep uintp instr.op(instr, state, p) } +// encodeStruct encodes a single struct value. func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintptr) { state := newEncoderState(enc, b) state.fieldnum = -1 @@ -330,6 +354,7 @@ func (enc *Encoder) encodeStruct(b *bytes.Buffer, engine *encEngine, basep uintp } } +// encodeArray encodes the array whose 0th element is at p. func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid uintptr, elemIndir int, length int) { state := newEncoderState(enc, b) state.fieldnum = -1 @@ -349,6 +374,7 @@ func (enc *Encoder) encodeArray(b *bytes.Buffer, p uintptr, op encOp, elemWid ui } } +// encodeReflectValue is a helper for maps. It encodes the value v. func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir int) { for i := 0; i < indir && v != nil; i++ { v = reflect.Indirect(v) @@ -359,6 +385,9 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in op(nil, state, unsafe.Pointer(v.UnsafeAddr())) } +// encodeMap encodes a map as unsigned count followed by key:value pairs. +// Because map internals are not exposed, we must use reflection rather than +// addresses. func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) { state := newEncoderState(enc, b) state.fieldnum = -1 @@ -371,6 +400,7 @@ func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elem } } +// encodeInterface encodes the interface value iv. // To send an interface, we send a string identifying the concrete type, followed // by the type identifier (which might require defining that type right now), followed // by the concrete value. A nil value gets sent as the empty string for the name, @@ -414,6 +444,21 @@ func (enc *Encoder) encodeInterface(b *bytes.Buffer, iv *reflect.InterfaceValue) } } +// encGobEncoder encodes a value that implements the GobEncoder interface. +// The data is sent as a byte array. +func (enc *Encoder) encodeGobEncoder(b *bytes.Buffer, v reflect.Value, index int) { + // TODO: should we catch panics from the called method? + // We know it's a GobEncoder, so just call the method directly. + data, err := v.Interface().(_GobEncoder)._GobEncode() + if err != nil { + error(err) + } + state := newEncoderState(enc, b) + state.fieldnum = -1 + state.encodeUint(uint64(len(data))) + state.b.Write(data) +} + var encOpTable = [...]encOp{ reflect.Bool: encBool, reflect.Int: encInt, @@ -434,10 +479,14 @@ var encOpTable = [...]encOp{ reflect.String: encString, } -// Return (a pointer to) the encoding op for the base type under rt and +// encOpFor returns (a pointer to) the encoding op for the base type under rt and // the indirection count to reach it. func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp) (*encOp, int) { ut := userType(rt) + // If the type implements GobEncoder, we handle it without further processing. + if ut.isGobEncoder { + return enc.gobEncodeOpFor(ut) + } // If this type is already in progress, it's a recursive type (e.g. map[string]*T). // Return the pointer to the op we're already building. if opPtr := inProgress[rt]; opPtr != nil { @@ -483,7 +532,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp // Maps cannot be accessed by moving addresses around the way // that slices etc. can. We must recover a full reflection value for // the iteration. - v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p)))) + v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p))) mv := reflect.Indirect(v).(*reflect.MapValue) if !state.sendZero && mv.Len() == 0 { return @@ -493,7 +542,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp } case *reflect.StructType: // Generate a closure that calls out to the engine for the nested type. - enc.getEncEngine(typ) + enc.getEncEngine(userType(typ)) info := mustGetTypeInfo(typ) op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { state.update(i) @@ -504,7 +553,7 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { // Interfaces transmit the name and contents of the concrete // value they contain. - v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer((p)))) + v := reflect.NewValue(unsafe.Unreflect(t, unsafe.Pointer(p))) iv := reflect.Indirect(v).(*reflect.InterfaceValue) if !state.sendZero && (iv == nil || iv.IsNil()) { return @@ -520,12 +569,43 @@ func (enc *Encoder) encOpFor(rt reflect.Type, inProgress map[reflect.Type]*encOp return &op, indir } -// The local Type was compiled from the actual value, so we know it's compatible. -func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine { - srt, isStruct := rt.(*reflect.StructType) +// gobEncodeOpFor returns the op for a type that is known to implement +// GobEncoder. +func (enc *Encoder) gobEncodeOpFor(ut *userTypeInfo) (*encOp, int) { + rt := ut.user + if ut.encIndir != 0 { + errorf("gob: TODO: can't handle indirection to reach GobEncoder") + } + index := -1 + for i := 0; i < rt.NumMethod(); i++ { + if rt.Method(i).Name == gobEncodeMethodName { + index = i + break + } + } + if index < 0 { + panic("can't find GobEncode method") + } + var op encOp + op = func(i *encInstr, state *encoderState, p unsafe.Pointer) { + // TODO: this will need fixing when ut.encIndr != 0. + v := reflect.NewValue(unsafe.Unreflect(rt, p)) + state.update(i) + state.enc.encodeGobEncoder(state.b, v, index) + } + return &op, int(ut.encIndir) +} + +// compileEnc returns the engine to compile the type. +func (enc *Encoder) compileEnc(ut *userTypeInfo) *encEngine { + srt, isStruct := ut.base.(*reflect.StructType) engine := new(encEngine) seen := make(map[reflect.Type]*encOp) - if isStruct { + rt := ut.base + if ut.isGobEncoder { + rt = ut.user + } + if !ut.isGobEncoder && isStruct { for fieldNum := 0; fieldNum < srt.NumField(); fieldNum++ { f := srt.Field(fieldNum) if !isExported(f.Name) { @@ -546,35 +626,43 @@ func (enc *Encoder) compileEnc(rt reflect.Type) *encEngine { return engine } +// getEncEngine returns the engine to compile the type. // typeLock must be held (or we're in initialization and guaranteed single-threaded). -// The reflection type must have all its indirections processed out. -func (enc *Encoder) getEncEngine(rt reflect.Type) *encEngine { - info, err1 := getTypeInfo(rt) +func (enc *Encoder) getEncEngine(ut *userTypeInfo) *encEngine { + info, err1 := getTypeInfo(ut) if err1 != nil { error(err1) } if info.encoder == nil { // mark this engine as underway before compiling to handle recursive types. info.encoder = new(encEngine) - info.encoder = enc.compileEnc(rt) + info.encoder = enc.compileEnc(ut) } return info.encoder } -// Put this in a function so we can hold the lock only while compiling, not when encoding. -func (enc *Encoder) lockAndGetEncEngine(rt reflect.Type) *encEngine { +// lockAndGetEncEngine is a function that locks and compiles. +// This lets us hold the lock only while compiling, not when encoding. +func (enc *Encoder) lockAndGetEncEngine(ut *userTypeInfo) *encEngine { typeLock.Lock() defer typeLock.Unlock() - return enc.getEncEngine(rt) + return enc.getEncEngine(ut) } func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInfo) (err os.Error) { defer catchError(&err) - for i := 0; i < ut.indir; i++ { + engine := enc.lockAndGetEncEngine(ut) + indir := ut.indir + if ut.isGobEncoder { + indir = int(ut.encIndir) + if indir != 0 { + errorf("TODO: can't handle indirection in GobEncoder value") + } + } + for i := 0; i < indir; i++ { value = reflect.Indirect(value) } - engine := enc.lockAndGetEncEngine(ut.base) - if value.Type().Kind() == reflect.Struct { + if !ut.isGobEncoder && value.Type().Kind() == reflect.Struct { enc.encodeStruct(b, engine, value.UnsafeAddr()) } else { enc.encodeSingle(b, engine, value.UnsafeAddr()) diff --git a/src/pkg/gob/encoder.go b/src/pkg/gob/encoder.go index 92d036c11c3..4bfcf15c7f9 100644 --- a/src/pkg/gob/encoder.go +++ b/src/pkg/gob/encoder.go @@ -78,12 +78,57 @@ func (enc *Encoder) writeMessage(w io.Writer, b *bytes.Buffer) { } } -func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { - // Drill down to the base type. - ut := userType(origt) - rt := ut.base +// sendActualType sends the requested type, without further investigation, unless +// it's been sent before. +func (enc *Encoder) sendActualType(w io.Writer, state *encoderState, ut *userTypeInfo, actual reflect.Type) (sent bool) { + if _, alreadySent := enc.sent[actual]; alreadySent { + return false + } + typeLock.Lock() + info, err := getTypeInfo(ut) + typeLock.Unlock() + if err != nil { + enc.setError(err) + return + } + // Send the pair (-id, type) + // Id: + state.encodeInt(-int64(info.id)) + // Type: + enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo) + enc.writeMessage(w, state.b) + if enc.err != nil { + return + } - switch rt := rt.(type) { + // Remember we've sent this type, both what the user gave us and the base type. + enc.sent[ut.base] = info.id + if ut.user != ut.base { + enc.sent[ut.user] = info.id + } + // Now send the inner types + switch st := actual.(type) { + case *reflect.StructType: + for i := 0; i < st.NumField(); i++ { + enc.sendType(w, state, st.Field(i).Type) + } + case reflect.ArrayOrSliceType: + enc.sendType(w, state, st.Elem()) + } + return true +} + +// sendType sends the type info to the other side, if necessary. +func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Type) (sent bool) { + ut := userType(origt) + if ut.isGobEncoder { + // The rules are different: regardless of the underlying type's representation, + // we need to tell the other side that this exact type is a GobEncoder. + return enc.sendActualType(w, state, ut, ut.user) + } + + // It's a concrete value, so drill down to the base type. + switch rt := ut.base.(type) { default: // Basic types and interfaces do not need to be described. return @@ -109,43 +154,7 @@ func (enc *Encoder) sendType(w io.Writer, state *encoderState, origt reflect.Typ return } - // Have we already sent this type? This time we ask about the base type. - if _, alreadySent := enc.sent[rt]; alreadySent { - return - } - - // Need to send it. - typeLock.Lock() - info, err := getTypeInfo(rt) - typeLock.Unlock() - if err != nil { - enc.setError(err) - return - } - // Send the pair (-id, type) - // Id: - state.encodeInt(-int64(info.id)) - // Type: - enc.encode(state.b, reflect.NewValue(info.wire), wireTypeUserInfo) - enc.writeMessage(w, state.b) - if enc.err != nil { - return - } - - // Remember we've sent this type. - enc.sent[rt] = info.id - // Remember we've sent the top-level, possibly indirect type too. - enc.sent[origt] = info.id - // Now send the inner types - switch st := rt.(type) { - case *reflect.StructType: - for i := 0; i < st.NumField(); i++ { - enc.sendType(w, state, st.Field(i).Type) - } - case reflect.ArrayOrSliceType: - enc.sendType(w, state, st.Elem()) - } - return true + return enc.sendActualType(w, state, ut, ut.base) } // Encode transmits the data item represented by the empty interface value, @@ -159,11 +168,17 @@ func (enc *Encoder) Encode(e interface{}) os.Error { // sent. func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *userTypeInfo) { // Make sure the type is known to the other side. - // First, have we already sent this (base) type? - base := ut.base - if _, alreadySent := enc.sent[base]; !alreadySent { + // First, have we already sent this type? + rt := ut.base + if ut.isGobEncoder { + rt = ut.user + if ut.encIndir != 0 { + panic("TODO: can't handle non-zero encIndir") + } + } + if _, alreadySent := enc.sent[rt]; !alreadySent { // No, so send it. - sent := enc.sendType(w, state, base) + sent := enc.sendType(w, state, rt) if enc.err != nil { return } @@ -172,13 +187,13 @@ func (enc *Encoder) sendTypeDescriptor(w io.Writer, state *encoderState, ut *use // need to send the type info but we do need to update enc.sent. if !sent { typeLock.Lock() - info, err := getTypeInfo(base) + info, err := getTypeInfo(ut) typeLock.Unlock() if err != nil { enc.setError(err) return } - enc.sent[base] = info.id + enc.sent[rt] = info.id } } } diff --git a/src/pkg/gob/gobencdec_test.go b/src/pkg/gob/gobencdec_test.go new file mode 100644 index 00000000000..dbe7d3fe313 --- /dev/null +++ b/src/pkg/gob/gobencdec_test.go @@ -0,0 +1,331 @@ +// Copyright 20011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This file contains tests of the GobEncoder/GobDecoder support. + +package gob + +import ( + "bytes" + "fmt" + "os" + "strings" + "testing" +) + +// Types that implement the GobEncoder/Decoder interfaces. + +type ByteStruct struct { + a byte // not an exported field +} + +type StringStruct struct { + s string // not an exported field +} + +type Gobber int + +type ValueGobber string // encodes with a value, decodes with a pointer. + +// The relevant methods + +func (g *ByteStruct) _GobEncode() ([]byte, os.Error) { + b := make([]byte, 3) + b[0] = g.a + b[1] = g.a + 1 + b[2] = g.a + 2 + return b, nil +} + +func (g *ByteStruct) _GobDecode(data []byte) os.Error { + if g == nil { + return os.ErrorString("NIL RECEIVER") + } + // Expect N sequential-valued bytes. + if len(data) == 0 { + return os.EOF + } + g.a = data[0] + for i, c := range data { + if c != g.a+byte(i) { + return os.ErrorString("invalid data sequence") + } + } + return nil +} + +func (g *StringStruct) _GobEncode() ([]byte, os.Error) { + return []byte(g.s), nil +} + +func (g *StringStruct) _GobDecode(data []byte) os.Error { + // Expect N sequential-valued bytes. + if len(data) == 0 { + return os.EOF + } + a := data[0] + for i, c := range data { + if c != a+byte(i) { + return os.ErrorString("invalid data sequence") + } + } + g.s = string(data) + return nil +} + +func (g *Gobber) _GobEncode() ([]byte, os.Error) { + return []byte(fmt.Sprintf("VALUE=%d", *g)), nil +} + +func (g *Gobber) _GobDecode(data []byte) os.Error { + _, err := fmt.Sscanf(string(data), "VALUE=%d", (*int)(g)) + return err +} + +func (v ValueGobber) _GobEncode() ([]byte, os.Error) { + return []byte(fmt.Sprintf("VALUE=%s", v)), nil +} + +func (v *ValueGobber) _GobDecode(data []byte) os.Error { + _, err := fmt.Sscanf(string(data), "VALUE=%s", (*string)(v)) + return err +} + +// Structs that include GobEncodable fields. + +type GobTest0 struct { + X int // guarantee we have something in common with GobTest* + G *ByteStruct +} + +type GobTest1 struct { + X int // guarantee we have something in common with GobTest* + G *StringStruct +} + +type GobTest2 struct { + X int // guarantee we have something in common with GobTest* + G string // not a GobEncoder - should give us errors +} + +type GobTest3 struct { + X int // guarantee we have something in common with GobTest* + G *Gobber // TODO: should be able to satisfy interface without a pointer +} + +type GobTest4 struct { + X int // guarantee we have something in common with GobTest* + V ValueGobber +} + +type GobTest5 struct { + X int // guarantee we have something in common with GobTest* + V *ValueGobber +} + +type GobTestIgnoreEncoder struct { + X int // guarantee we have something in common with GobTest* +} + +func TestGobEncoderField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + err := enc.Encode(GobTest0{17, &ByteStruct{'A'}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest0) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.G.a != 'A' { + t.Errorf("expected 'A' got %c", x.G.a) + } + // Now a field that's not a structure. + b.Reset() + gobber := Gobber(23) + err = enc.Encode(GobTest3{17, &gobber}) + if err != nil { + t.Fatal("encode error:", err) + } + y := new(GobTest3) + err = dec.Decode(y) + if err != nil { + t.Fatal("decode error:", err) + } + if *y.G != 23 { + t.Errorf("expected '23 got %d", *y.G) + } +} + +// As long as the fields have the same name and implement the +// interface, we can cross-connect them. Not sure it's useful +// and may even be bad but it works and it's hard to prevent +// without exposing the contents of the object, which would +// defeat the purpose. +func TestGobEncoderFieldsOfDifferentType(t *testing.T) { + // first, string in field to byte in field + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest0) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.G.a != 'A' { + t.Errorf("expected 'A' got %c", x.G.a) + } + // now the other direction, byte in field to string in field + b.Reset() + err = enc.Encode(GobTest0{17, &ByteStruct{'X'}}) + if err != nil { + t.Fatal("encode error:", err) + } + y := new(GobTest1) + err = dec.Decode(y) + if err != nil { + t.Fatal("decode error:", err) + } + if y.G.s != "XYZ" { + t.Fatalf("expected `XYZ` got %c", y.G.s) + } +} + +// Test that we can encode a value and decode into a pointer. +func TestGobEncoderValueEncoder(t *testing.T) { + // first, string in field to byte in field + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(GobTest4{17, ValueGobber("hello")}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTest5) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if *x.V != "hello" { + t.Errorf("expected `hello` got %s", x.V) + } +} + +func TestGobEncoderFieldTypeError(t *testing.T) { + // GobEncoder to non-decoder: error + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(GobTest1{17, &StringStruct{"ABC"}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := &GobTest2{} + err = dec.Decode(x) + if err == nil { + t.Fatal("expected decode error for mistmatched fields (encoder to non-decoder)") + } + if strings.Index(err.String(), "type") < 0 { + t.Fatal("expected type error; got", err) + } + // Non-encoder to GobDecoder: error + b.Reset() + err = enc.Encode(GobTest2{17, "ABC"}) + if err != nil { + t.Fatal("encode error:", err) + } + y := &GobTest1{} + err = dec.Decode(y) + if err == nil { + t.Fatal("expected decode error for mistmatched fields (non-encoder to decoder)") + } + if strings.Index(err.String(), "type") < 0 { + t.Fatal("expected type error; got", err) + } +} + +// Even though ByteStruct is a struct, it's treated as a singleton at the top level. +func TestGobEncoderStructSingleton(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + err := enc.Encode(&ByteStruct{'A'}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(ByteStruct) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.a != 'A' { + t.Errorf("expected 'A' got %c", x.a) + } +} + +func TestGobEncoderNonStructSingleton(t *testing.T) { + b := new(bytes.Buffer) + enc := NewEncoder(b) + g := Gobber(1234) // TODO: shouldn't need to take the address here. + err := enc.Encode(&g) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + var x Gobber + err = dec.Decode(&x) + if err != nil { + t.Fatal("decode error:", err) + } + if x != 1234 { + t.Errorf("expected 1234 got %c", x) + } +} + +func TestGobEncoderIgnoreStructField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + err := enc.Encode(GobTest0{17, &ByteStruct{'A'}}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestIgnoreEncoder) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.X != 17 { + t.Errorf("expected 17 got %c", x.X) + } +} + +func TestGobEncoderIgnoreNonStructField(t *testing.T) { + b := new(bytes.Buffer) + // First a field that's a structure. + enc := NewEncoder(b) + gobber := Gobber(23) + err := enc.Encode(GobTest3{17, &gobber}) + if err != nil { + t.Fatal("encode error:", err) + } + dec := NewDecoder(b) + x := new(GobTestIgnoreEncoder) + err = dec.Decode(x) + if err != nil { + t.Fatal("decode error:", err) + } + if x.X != 17 { + t.Errorf("expected 17 got %c", x.X) + } +} diff --git a/src/pkg/gob/type.go b/src/pkg/gob/type.go index 6e3f148b4e7..05d5f122e9a 100644 --- a/src/pkg/gob/type.go +++ b/src/pkg/gob/type.go @@ -15,9 +15,13 @@ import ( // to the package. It's computed once and stored in a map keyed by reflection // type. type userTypeInfo struct { - user reflect.Type // the type the user handed us - base reflect.Type // the base type after all indirections - indir int // number of indirections to reach the base type + user reflect.Type // the type the user handed us + base reflect.Type // the base type after all indirections + indir int // number of indirections to reach the base type + isGobEncoder bool // does the type implement _GobEncoder? + isGobDecoder bool // does the type implement _GobDecoder? + encIndir int8 // number of indirections to reach the receiver type; may be negative + decIndir int8 // number of indirections to reach the receiver type; may be negative } var ( @@ -68,10 +72,83 @@ func validUserType(rt reflect.Type) (ut *userTypeInfo, err os.Error) { } ut.indir++ } + ut.isGobEncoder, ut.encIndir = implementsGobEncoder(ut.user) + ut.isGobDecoder, ut.decIndir = implementsGobDecoder(ut.user) userTypeCache[rt] = ut + if ut.encIndir != 0 || ut.decIndir != 0 { + // There are checks in lots of other places, but putting this here means we won't even + // attempt to encode/decode this type. + // TODO: make it possible to handle types that are indirect to the implementation, + // such as a structure field of type T when *T implements GobDecoder. + return nil, os.ErrorString("TODO: gob can't handle indirections to GobEncoder/Decoder") + } return } +const ( + gobEncodeMethodName = "_GobEncode" + gobDecodeMethodName = "_GobDecode" +) + +// implementsGobEncoder reports whether the type implements the interface. It also +// returns the number of indirections required to get to the implementation. +// TODO: when reflection makes it possible, should also be prepared to climb up +// one level if we're not on a pointer (implementation could be on *T for our T). +// That will mean that indir could be < 0, which is sure to cause problems, but +// we ignore them now as indir is always >= 0 now. +func implementsGobEncoder(rt reflect.Type) (implements bool, indir int8) { + if rt == nil { + return + } + // The type might be a pointer, or it might not, and we need to keep + // dereferencing to the base type until we find an implementation. + for { + if rt.NumMethod() > 0 { // avoid allocations etc. unless there's some chance + if _, ok := reflect.MakeZero(rt).Interface().(_GobEncoder); ok { + return true, indir + } + } + if p, ok := rt.(*reflect.PtrType); ok { + indir++ + if indir > 100 { // insane number of indirections + return false, 0 + } + rt = p.Elem() + continue + } + break + } + return false, 0 +} + +// implementsGobDecoder reports whether the type implements the interface. It also +// returns the number of indirections required to get to the implementation. +// TODO: see comment on implementsGobEncoder. +func implementsGobDecoder(rt reflect.Type) (implements bool, indir int8) { + if rt == nil { + return + } + // The type might be a pointer, or it might not, and we need to keep + // dereferencing to the base type until we find an implementation. + for { + if rt.NumMethod() > 0 { // avoid allocations etc. unless there's some chance + if _, ok := reflect.MakeZero(rt).Interface().(_GobDecoder); ok { + return true, indir + } + } + if p, ok := rt.(*reflect.PtrType); ok { + indir++ + if indir > 100 { // insane number of indirections + return false, 0 + } + rt = p.Elem() + continue + } + break + } + return false, 0 +} + // userType returns, and saves, the information associated with user-provided type rt. // If the user type is not valid, it calls error. func userType(rt reflect.Type) *userTypeInfo { @@ -229,6 +306,23 @@ func (a *arrayType) safeString(seen map[typeId]bool) string { func (a *arrayType) string() string { return a.safeString(make(map[typeId]bool)) } +// GobEncoder type (something that implements the _GobEncoder interface) +type gobEncoderType struct { + CommonType +} + +func newGobEncoderType(name string) *gobEncoderType { + g := &gobEncoderType{CommonType{Name: name}} + setTypeId(g) + return g +} + +func (g *gobEncoderType) safeString(seen map[typeId]bool) string { + return g.Name +} + +func (g *gobEncoderType) string() string { return g.Name } + // Map type type mapType struct { CommonType @@ -328,7 +422,16 @@ func (s *structType) init(field []*fieldType) { s.Field = field } -func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { +// newTypeObject allocates a gobType for the reflection type rt. +// Unless ut represents a GobEncoder, rt should be the base type +// of ut. +// This is only called from the encoding side. The decoding side +// works through typeIds and userTypeInfos alone. +func newTypeObject(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) { + // Does this type implement GobEncoder? + if ut.isGobEncoder { + return newGobEncoderType(name), nil + } var err os.Error var type0, type1 gobType defer func() { @@ -364,7 +467,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { case *reflect.ArrayType: at := newArrayType(name) types[rt] = at - type0, err = getType("", t.Elem()) + type0, err = getBaseType("", t.Elem()) if err != nil { return nil, err } @@ -382,11 +485,11 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { case *reflect.MapType: mt := newMapType(name) types[rt] = mt - type0, err = getType("", t.Key()) + type0, err = getBaseType("", t.Key()) if err != nil { return nil, err } - type1, err = getType("", t.Elem()) + type1, err = getBaseType("", t.Elem()) if err != nil { return nil, err } @@ -400,7 +503,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { } st := newSliceType(name) types[rt] = st - type0, err = getType(t.Elem().Name(), t.Elem()) + type0, err = getBaseType(t.Elem().Name(), t.Elem()) if err != nil { return nil, err } @@ -413,6 +516,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { idToType[st.id()] = st field := make([]*fieldType, t.NumField()) for i := 0; i < t.NumField(); i++ { + // TODO: don't send unexported fields. f := t.Field(i) typ := userType(f.Type).base tname := typ.Name() @@ -420,7 +524,7 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { t := userType(f.Type).base tname = t.String() } - gt, err := getType(tname, f.Type) + gt, err := getBaseType(tname, f.Type) if err != nil { return nil, err } @@ -435,15 +539,24 @@ func newTypeObject(name string, rt reflect.Type) (gobType, os.Error) { return nil, nil } -// getType returns the Gob type describing the given reflect.Type. +// getBaseType returns the Gob type describing the given reflect.Type's base type. // typeLock must be held. -func getType(name string, rt reflect.Type) (gobType, os.Error) { - rt = userType(rt).base +func getBaseType(name string, rt reflect.Type) (gobType, os.Error) { + ut := userType(rt) + return getType(name, ut, ut.base) +} + +// getType returns the Gob type describing the given reflect.Type. +// Should be called only when handling GobEncoders/Decoders, +// which may be pointers. All other types are handled through the +// base type, never a pointer. +// typeLock must be held. +func getType(name string, ut *userTypeInfo, rt reflect.Type) (gobType, os.Error) { typ, present := types[rt] if present { return typ, nil } - typ, err := newTypeObject(name, rt) + typ, err := newTypeObject(name, ut, rt) if err == nil { types[rt] = typ } @@ -484,10 +597,11 @@ func bootstrapType(name string, e interface{}, expect typeId) typeId { // To maintain binary compatibility, if you extend this type, always put // the new fields last. type wireType struct { - ArrayT *arrayType - SliceT *sliceType - StructT *structType - MapT *mapType + ArrayT *arrayType + SliceT *sliceType + StructT *structType + MapT *mapType + GobEncoderT *gobEncoderType } func (w *wireType) string() string { @@ -504,6 +618,8 @@ func (w *wireType) string() string { return w.StructT.Name case w.MapT != nil: return w.MapT.Name + case w.GobEncoderT != nil: + return w.GobEncoderT.Name } return unknown } @@ -516,23 +632,43 @@ type typeInfo struct { var typeInfoMap = make(map[reflect.Type]*typeInfo) // protected by typeLock -// The reflection type must have all its indirections processed out. // typeLock must be held. -func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) { - if rt.Kind() == reflect.Ptr { - panic("pointer type in getTypeInfo: " + rt.String()) +func getTypeInfo(ut *userTypeInfo) (*typeInfo, os.Error) { + + if ut.isGobEncoder { + // TODO: clean up this code - too much duplication. + info, ok := typeInfoMap[ut.user] + if ok { + return info, nil + } + // We want the user type, not the base type. + userType, err := getType(ut.user.Name(), ut, ut.user) + if err != nil { + return nil, err + } + info = new(typeInfo) + gt, err := getBaseType(ut.base.Name(), ut.base) + if err != nil { + return nil, err + } + info.id = gt.id() + info.wire = &wireType{GobEncoderT: userType.id().gobType().(*gobEncoderType)} + typeInfoMap[ut.user] = info + return info, nil } - info, ok := typeInfoMap[rt] + + base := ut.base + info, ok := typeInfoMap[base] if !ok { info = new(typeInfo) - name := rt.Name() - gt, err := getType(name, rt) + name := base.Name() + gt, err := getBaseType(name, base) if err != nil { return nil, err } info.id = gt.id() t := info.id.gobType() - switch typ := rt.(type) { + switch typ := base.(type) { case *reflect.ArrayType: info.wire = &wireType{ArrayT: t.(*arrayType)} case *reflect.MapType: @@ -545,20 +681,27 @@ func getTypeInfo(rt reflect.Type) (*typeInfo, os.Error) { case *reflect.StructType: info.wire = &wireType{StructT: t.(*structType)} } - typeInfoMap[rt] = info + typeInfoMap[base] = info } return info, nil } // Called only when a panic is acceptable and unexpected. func mustGetTypeInfo(rt reflect.Type) *typeInfo { - t, err := getTypeInfo(rt) + t, err := getTypeInfo(userType(rt)) if err != nil { panic("getTypeInfo: " + err.String()) } return t } +type _GobEncoder interface { + _GobEncode() ([]byte, os.Error) +} // use _ prefix until we get it working properly +type _GobDecoder interface { + _GobDecode([]byte) os.Error +} // use _ prefix until we get it working properly + var ( nameToConcreteType = make(map[string]reflect.Type) concreteTypeToName = make(map[reflect.Type]string) diff --git a/src/pkg/gob/type_test.go b/src/pkg/gob/type_test.go index 5aecde103a5..6fe1ecf93e0 100644 --- a/src/pkg/gob/type_test.go +++ b/src/pkg/gob/type_test.go @@ -26,7 +26,7 @@ var basicTypes = []typeT{ func getTypeUnlocked(name string, rt reflect.Type) gobType { typeLock.Lock() defer typeLock.Unlock() - t, err := getType(name, rt) + t, err := getBaseType(name, rt) if err != nil { panic("getTypeUnlocked: " + err.String()) }