diff --git a/src/pkg/encoding/gob/codec_test.go b/src/pkg/encoding/gob/codec_test.go index c7b2567ca02..4f17a28931d 100644 --- a/src/pkg/encoding/gob/codec_test.go +++ b/src/pkg/encoding/gob/codec_test.go @@ -323,7 +323,7 @@ func execDec(typ string, instr *decInstr, state *decoderState, t *testing.T, val if v+state.fieldnum != 6 { t.Fatalf("decoding field number %d, got %d", 6, v+state.fieldnum) } - instr.op(instr, state, value) + instr.op(instr, state, value.Elem()) state.fieldnum = 6 } diff --git a/src/pkg/encoding/gob/decode.go b/src/pkg/encoding/gob/decode.go index feed80513c0..76274a0cacd 100644 --- a/src/pkg/encoding/gob/decode.go +++ b/src/pkg/encoding/gob/decode.go @@ -155,6 +155,8 @@ func ignoreTwoUints(i *decInstr, state *decoderState, v reflect.Value) { // decAlloc takes a value and returns a settable value that can // be assigned to. If the value is a pointer, decAlloc guarantees it points to storage. +// The callers to the individual decoders are expected to have used decAlloc. +// The individual decoders don't need to it. func decAlloc(v reflect.Value) reflect.Value { for v.Kind() == reflect.Ptr { if v.IsNil() { @@ -167,7 +169,7 @@ func decAlloc(v reflect.Value) reflect.Value { // decBool decodes a uint and stores it as a boolean in value. func decBool(i *decInstr, state *decoderState, value reflect.Value) { - decAlloc(value).SetBool(state.decodeUint() != 0) + value.SetBool(state.decodeUint() != 0) } // decInt8 decodes an integer and stores it as an int8 in value. @@ -176,7 +178,7 @@ func decInt8(i *decInstr, state *decoderState, value reflect.Value) { if v < math.MinInt8 || math.MaxInt8 < v { error_(i.ovfl) } - decAlloc(value).SetInt(v) + value.SetInt(v) } // decUint8 decodes an unsigned integer and stores it as a uint8 in value. @@ -185,7 +187,7 @@ func decUint8(i *decInstr, state *decoderState, value reflect.Value) { if math.MaxUint8 < v { error_(i.ovfl) } - decAlloc(value).SetUint(v) + value.SetUint(v) } // decInt16 decodes an integer and stores it as an int16 in value. @@ -194,7 +196,7 @@ func decInt16(i *decInstr, state *decoderState, value reflect.Value) { if v < math.MinInt16 || math.MaxInt16 < v { error_(i.ovfl) } - decAlloc(value).SetInt(v) + value.SetInt(v) } // decUint16 decodes an unsigned integer and stores it as a uint16 in value. @@ -203,7 +205,7 @@ func decUint16(i *decInstr, state *decoderState, value reflect.Value) { if math.MaxUint16 < v { error_(i.ovfl) } - decAlloc(value).SetUint(v) + value.SetUint(v) } // decInt32 decodes an integer and stores it as an int32 in value. @@ -212,7 +214,7 @@ func decInt32(i *decInstr, state *decoderState, value reflect.Value) { if v < math.MinInt32 || math.MaxInt32 < v { error_(i.ovfl) } - decAlloc(value).SetInt(v) + value.SetInt(v) } // decUint32 decodes an unsigned integer and stores it as a uint32 in value. @@ -221,19 +223,19 @@ func decUint32(i *decInstr, state *decoderState, value reflect.Value) { if math.MaxUint32 < v { error_(i.ovfl) } - decAlloc(value).SetUint(v) + value.SetUint(v) } // decInt64 decodes an integer and stores it as an int64 in value. func decInt64(i *decInstr, state *decoderState, value reflect.Value) { v := state.decodeInt() - decAlloc(value).SetInt(v) + value.SetInt(v) } // decUint64 decodes an unsigned integer and stores it as a uint64 in value. func decUint64(i *decInstr, state *decoderState, value reflect.Value) { v := state.decodeUint() - decAlloc(value).SetUint(v) + value.SetUint(v) } // Floating-point numbers are transmitted as uint64s holding the bits @@ -271,13 +273,13 @@ func float32FromBits(i *decInstr, u uint64) float64 { // decFloat32 decodes an unsigned integer, treats it as a 32-bit floating-point // number, and stores it in value. func decFloat32(i *decInstr, state *decoderState, value reflect.Value) { - decAlloc(value).SetFloat(float32FromBits(i, state.decodeUint())) + value.SetFloat(float32FromBits(i, state.decodeUint())) } // decFloat64 decodes an unsigned integer, treats it as a 64-bit floating-point // number, and stores it in value. func decFloat64(i *decInstr, state *decoderState, value reflect.Value) { - decAlloc(value).SetFloat(float64FromBits(state.decodeUint())) + value.SetFloat(float64FromBits(state.decodeUint())) } // decComplex64 decodes a pair of unsigned integers, treats them as a @@ -286,7 +288,7 @@ func decFloat64(i *decInstr, state *decoderState, value reflect.Value) { func decComplex64(i *decInstr, state *decoderState, value reflect.Value) { real := float32FromBits(i, state.decodeUint()) imag := float32FromBits(i, state.decodeUint()) - decAlloc(value).SetComplex(complex(real, imag)) + value.SetComplex(complex(real, imag)) } // decComplex128 decodes a pair of unsigned integers, treats them as a @@ -295,7 +297,7 @@ func decComplex64(i *decInstr, state *decoderState, value reflect.Value) { func decComplex128(i *decInstr, state *decoderState, value reflect.Value) { real := float64FromBits(state.decodeUint()) imag := float64FromBits(state.decodeUint()) - decAlloc(value).SetComplex(complex(real, imag)) + value.SetComplex(complex(real, imag)) } // decUint8Slice decodes a byte slice and stores in value a slice header @@ -310,7 +312,6 @@ func decUint8Slice(i *decInstr, state *decoderState, value reflect.Value) { if n > state.b.Len() { errorf("%s data too long for buffer: %d", value.Type(), n) } - value = decAlloc(value) if value.Cap() < n { value.Set(reflect.MakeSlice(value.Type(), n, n)) } else { @@ -338,7 +339,7 @@ func decString(i *decInstr, state *decoderState, value reflect.Value) { if _, err := state.b.Read(data); err != nil { errorf("error decoding string: %s", err) } - decAlloc(value).SetString(string(data)) + value.SetString(string(data)) } // ignoreUint8Array skips over the data for a byte slice value with no destination. @@ -376,8 +377,6 @@ func (dec *Decoder) decodeSingle(engine *decEngine, ut *userTypeInfo, value refl // This state cannot arise for decodeSingle, which is called directly // from the user's value, not from the innards of an engine. func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, value reflect.Value) { - value = decAlloc(value) - // println(value.Kind() == reflect.Ptr) state := dec.newDecoderState(&dec.buf) defer dec.freeDecoderState(state) state.fieldnum = -1 @@ -399,6 +398,9 @@ func (dec *Decoder) decodeStruct(engine *decEngine, ut *userTypeInfo, value refl if instr.index != nil { // Otherwise the field is unknown to us and instr.op is an ignore op. field = value.FieldByIndex(instr.index) + if field.Kind() == reflect.Ptr { + field = decAlloc(field) + } } instr.op(instr, state, field) state.fieldnum = fieldnum @@ -447,11 +449,16 @@ func (dec *Decoder) ignoreSingle(engine *decEngine) { // decodeArrayHelper does the work for decoding arrays and slices. func (dec *Decoder) decodeArrayHelper(state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error) { instr := &decInstr{elemOp, 0, nil, ovfl} + isPtr := value.Type().Elem().Kind() == reflect.Ptr for i := 0; i < length; i++ { if state.b.Len() == 0 { errorf("decoding array or slice: length exceeds input size (%d elements)", length) } - elemOp(instr, state, value.Index(i)) + v := value.Index(i) + if isPtr { + v = decAlloc(v) + } + elemOp(instr, state, v) } } @@ -459,7 +466,6 @@ func (dec *Decoder) decodeArrayHelper(state *decoderState, value reflect.Value, // The length is an unsigned integer preceding the elements. Even though the length is redundant // (it's part of the type), it's a useful check and is included in the encoding. func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, value reflect.Value, elemOp decOp, length int, ovfl error) { - value = decAlloc(value) if n := state.decodeUint(); n != uint64(length) { errorf("length mismatch in decodeArray") } @@ -467,9 +473,13 @@ func (dec *Decoder) decodeArray(atyp reflect.Type, state *decoderState, value re } // decodeIntoValue is a helper for map decoding. -func decodeIntoValue(state *decoderState, op decOp, value reflect.Value, ovfl error) reflect.Value { +func decodeIntoValue(state *decoderState, op decOp, isPtr bool, value reflect.Value, ovfl error) reflect.Value { instr := &decInstr{op, 0, nil, ovfl} - op(instr, state, value) + v := value + if isPtr { + v = decAlloc(value) + } + op(instr, state, v) return value } @@ -478,15 +488,16 @@ func decodeIntoValue(state *decoderState, op decOp, value reflect.Value, ovfl er // Because the internals of maps are not visible to us, we must // use reflection rather than pointer magic. func (dec *Decoder) decodeMap(mtyp reflect.Type, state *decoderState, value reflect.Value, keyOp, elemOp decOp, ovfl error) { - value = decAlloc(value) if value.IsNil() { // Allocate map. value.Set(reflect.MakeMap(mtyp)) } n := int(state.decodeUint()) + keyIsPtr := mtyp.Key().Kind() == reflect.Ptr + elemIsPtr := mtyp.Elem().Kind() == reflect.Ptr for i := 0; i < n; i++ { - key := decodeIntoValue(state, keyOp, allocValue(mtyp.Key()), ovfl) - elem := decodeIntoValue(state, elemOp, allocValue(mtyp.Elem()), ovfl) + key := decodeIntoValue(state, keyOp, keyIsPtr, allocValue(mtyp.Key()), ovfl) + elem := decodeIntoValue(state, elemOp, elemIsPtr, allocValue(mtyp.Elem()), ovfl) value.SetMapIndex(key, elem) } } @@ -528,7 +539,6 @@ func (dec *Decoder) decodeSlice(state *decoderState, value reflect.Value, elemOp // of interfaces, there will be buffer reloads. errorf("length of %s is negative (%d bytes)", value.Type(), u) } - value = decAlloc(value) if value.Cap() < n { value.Set(reflect.MakeSlice(value.Type(), n, n)) } else { @@ -558,7 +568,6 @@ func (dec *Decoder) decodeInterface(ityp reflect.Type, state *decoderState, valu state.b.Read(b) name := string(b) // Allocate the destination interface value. - value = decAlloc(value) if name == "" { // Copy the nil interface value to the target. value.Set(reflect.Zero(value.Type())) @@ -834,7 +843,6 @@ func (dec *Decoder) gobDecodeOpFor(ut *userTypeInfo) *decOp { } var op decOp op = func(i *decInstr, state *decoderState, value reflect.Value) { - value = decAlloc(value) // We now have the base type. We need its address if the receiver is a pointer. if value.Kind() != reflect.Ptr && rcvrType.Kind() == reflect.Ptr { value = value.Addr() @@ -1072,6 +1080,7 @@ func (dec *Decoder) decodeValue(wireId typeId, value reflect.Value) { if dec.err != nil { return } + value = decAlloc(value) engine := *enginePtr if st := base; st.Kind() == reflect.Struct && ut.externalDec == 0 { if engine.numInstr == 0 && st.NumField() > 0 &&