From e46acb091fa7dfbdb99e61801f37522bd8b80365 Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Thu, 3 Mar 2011 13:20:17 -0500 Subject: [PATCH] reflect: add PtrTo, add Value.Addr (old Addr is now UnsafeAddr) This change makes it possible to take the address of a struct field or slice element in order to call a method that requires a pointer receiver. Existing code that uses the Value.Addr method will have to change (as gob does in this CL) to call UnsafeAddr instead. R=r, rog CC=golang-dev https://golang.org/cl/4239052 --- src/pkg/gob/decode.go | 6 +- src/pkg/gob/encode.go | 6 +- src/pkg/reflect/all_test.go | 75 +++++++++++++++++ src/pkg/reflect/deepequal.go | 4 +- src/pkg/reflect/type.go | 88 ++++++++++++++++++- src/pkg/reflect/value.go | 158 +++++++++++++++++++++++++---------- 6 files changed, 285 insertions(+), 52 deletions(-) diff --git a/src/pkg/gob/decode.go b/src/pkg/gob/decode.go index ad03d176b8c..8f599e10041 100644 --- a/src/pkg/gob/decode.go +++ b/src/pkg/gob/decode.go @@ -518,7 +518,7 @@ func (dec *Decoder) decodeArray(atyp *reflect.ArrayType, state *decodeState, p u func decodeIntoValue(state *decodeState, op decOp, indir int, v reflect.Value, ovfl os.ErrorString) reflect.Value { instr := &decInstr{op, 0, indir, 0, ovfl} - up := unsafe.Pointer(v.Addr()) + up := unsafe.Pointer(v.UnsafeAddr()) if indir > 1 { up = decIndirect(up, indir) } @@ -1052,9 +1052,9 @@ func (dec *Decoder) decodeValue(wireId typeId, val reflect.Value) (err os.Error) name := base.Name() return os.ErrorString("gob: type mismatch: no fields matched compiling decoder for " + name) } - return dec.decodeStruct(engine, ut, uintptr(val.Addr()), indir) + return dec.decodeStruct(engine, ut, uintptr(val.UnsafeAddr()), indir) } - return dec.decodeSingle(engine, ut, uintptr(val.Addr())) + return dec.decodeSingle(engine, ut, uintptr(val.UnsafeAddr())) } func (dec *Decoder) decodeIgnoredValue(wireId typeId) os.Error { diff --git a/src/pkg/gob/encode.go b/src/pkg/gob/encode.go index 5f4fc6f34bd..e92db74ffdd 100644 --- a/src/pkg/gob/encode.go +++ b/src/pkg/gob/encode.go @@ -356,7 +356,7 @@ func encodeReflectValue(state *encoderState, v reflect.Value, op encOp, indir in if v == nil { errorf("gob: encodeReflectValue: nil element") } - op(nil, state, unsafe.Pointer(v.Addr())) + op(nil, state, unsafe.Pointer(v.UnsafeAddr())) } func (enc *Encoder) encodeMap(b *bytes.Buffer, mv *reflect.MapValue, keyOp, elemOp encOp, keyIndir, elemIndir int) { @@ -575,9 +575,9 @@ func (enc *Encoder) encode(b *bytes.Buffer, value reflect.Value, ut *userTypeInf } engine := enc.lockAndGetEncEngine(ut.base) if value.Type().Kind() == reflect.Struct { - enc.encodeStruct(b, engine, value.Addr()) + enc.encodeStruct(b, engine, value.UnsafeAddr()) } else { - enc.encodeSingle(b, engine, value.Addr()) + enc.encodeSingle(b, engine, value.UnsafeAddr()) } return nil } diff --git a/src/pkg/reflect/all_test.go b/src/pkg/reflect/all_test.go index 3675be6f1c1..7a97ea17376 100644 --- a/src/pkg/reflect/all_test.go +++ b/src/pkg/reflect/all_test.go @@ -1093,6 +1093,18 @@ func TestMethod(t *testing.T) { t.Errorf("Value Method returned %d; want 250", i) } + // Curried method of pointer. + i = NewValue(&p).Method(0).Call([]Value{NewValue(10)})[0].(*IntValue).Get() + if i != 250 { + t.Errorf("Value Method returned %d; want 250", i) + } + + // Curried method of pointer to value. + i = NewValue(p).Addr().Method(0).Call([]Value{NewValue(10)})[0].(*IntValue).Get() + if i != 250 { + t.Errorf("Value Method returned %d; want 250", i) + } + // Curried method of interface value. // Have to wrap interface value in a struct to get at it. // Passing it to NewValue directly would @@ -1390,3 +1402,66 @@ func TestEmbeddedMethods(t *testing.T) { t.Errorf("f(o) = %d, want 2", v) } } + +func TestPtrTo(t *testing.T) { + var i int + + typ := Typeof(i) + for i = 0; i < 100; i++ { + typ = PtrTo(typ) + } + for i = 0; i < 100; i++ { + typ = typ.(*PtrType).Elem() + } + if typ != Typeof(i) { + t.Errorf("after 100 PtrTo and Elem, have %s, want %s", typ, Typeof(i)) + } +} + +func TestAddr(t *testing.T) { + var p struct { + X, Y int + } + + v := NewValue(&p) + v = v.(*PtrValue).Elem() + v = v.Addr() + v = v.(*PtrValue).Elem() + v = v.(*StructValue).Field(0) + v.(*IntValue).Set(2) + if p.X != 2 { + t.Errorf("Addr.Elem.Set failed to set value") + } + + // Again but take address of the NewValue value. + // Exercises generation of PtrTypes not present in the binary. + v = NewValue(&p) + v = v.Addr() + v = v.(*PtrValue).Elem() + v = v.(*PtrValue).Elem() + v = v.Addr() + v = v.(*PtrValue).Elem() + v = v.(*StructValue).Field(0) + v.(*IntValue).Set(3) + if p.X != 3 { + t.Errorf("Addr.Elem.Set failed to set value") + } + + // Starting without pointer we should get changed value + // in interface. + v = NewValue(p) + v0 := v + v = v.Addr() + v = v.(*PtrValue).Elem() + v = v.(*StructValue).Field(0) + v.(*IntValue).Set(4) + if p.X != 3 { // should be unchanged from last time + t.Errorf("somehow value Set changed original p") + } + p = v0.Interface().(struct { + X, Y int + }) + if p.X != 4 { + t.Errorf("Addr.Elem.Set valued to set value in top value") + } +} diff --git a/src/pkg/reflect/deepequal.go b/src/pkg/reflect/deepequal.go index a50925e51e6..c9beec50666 100644 --- a/src/pkg/reflect/deepequal.go +++ b/src/pkg/reflect/deepequal.go @@ -31,8 +31,8 @@ func deepValueEqual(v1, v2 Value, visited map[uintptr]*visit, depth int) bool { // if depth > 10 { panic("deepValueEqual") } // for debugging - addr1 := v1.Addr() - addr2 := v2.Addr() + addr1 := v1.UnsafeAddr() + addr2 := v2.UnsafeAddr() if addr1 > addr2 { // Canonicalize order to reduce number of entries in visited. addr1, addr2 = addr2, addr1 diff --git a/src/pkg/reflect/type.go b/src/pkg/reflect/type.go index efe0238eaa6..2cc1f576aa0 100644 --- a/src/pkg/reflect/type.go +++ b/src/pkg/reflect/type.go @@ -18,6 +18,7 @@ package reflect import ( "runtime" "strconv" + "sync" "unsafe" ) @@ -251,6 +252,8 @@ type Type interface { // NumMethods returns the number of methods in the type's method set. NumMethod() int + + common() *commonType uncommon() *uncommonType } @@ -361,6 +364,8 @@ func (t *commonType) FieldAlign() int { return int(t.fieldAlign) } func (t *commonType) Kind() Kind { return Kind(t.kind & kindMask) } +func (t *commonType) common() *commonType { return t } + func (t *uncommonType) Method(i int) (m Method) { if t == nil || i < 0 || i >= len(t.methods) { return @@ -374,7 +379,7 @@ func (t *uncommonType) Method(i int) (m Method) { } m.Type = toType(*p.typ).(*FuncType) fn := p.tfn - m.Func = &FuncValue{value: value{m.Type, addr(&fn), true}} + m.Func = &FuncValue{value: value{m.Type, addr(&fn), canSet}} return } @@ -689,3 +694,84 @@ type ArrayOrSliceType interface { // Typeof returns the reflection Type of the value in the interface{}. func Typeof(i interface{}) Type { return toType(unsafe.Typeof(i)) } + +// ptrMap is the cache for PtrTo. +var ptrMap struct { + sync.RWMutex + m map[Type]*PtrType +} + +// runtimePtrType is the runtime layout for a *PtrType. +// The memory immediately before the *PtrType is always +// the canonical runtime.Type to be used for a *runtime.Type +// describing this PtrType. +type runtimePtrType struct { + runtime.Type + runtime.PtrType +} + +// PtrTo returns the pointer type with element t. +// For example, if t represents type Foo, PtrTo(t) represents *Foo. +func PtrTo(t Type) *PtrType { + // If t records its pointer-to type, use it. + ct := t.common() + if p := ct.ptrToThis; p != nil { + return toType(*p).(*PtrType) + } + + // Otherwise, synthesize one. + // This only happens for pointers with no methods. + // We keep the mapping in a map on the side, because + // this operation is rare and a separate map lets us keep + // the type structures in read-only memory. + ptrMap.RLock() + if m := ptrMap.m; m != nil { + if p := m[t]; p != nil { + ptrMap.RUnlock() + return p + } + } + ptrMap.RUnlock() + ptrMap.Lock() + if ptrMap.m == nil { + ptrMap.m = make(map[Type]*PtrType) + } + p := ptrMap.m[t] + if p != nil { + // some other goroutine won the race and created it + ptrMap.Unlock() + return p + } + + // runtime.Type value is always right before type structure. + // 2*ptrSize is size of interface header + rt := (*runtime.Type)(unsafe.Pointer(uintptr(unsafe.Pointer(ct)) - uintptr(unsafe.Sizeof(runtime.Type(nil))))) + + rp := new(runtimePtrType) + rp.Type = &rp.PtrType + + // initialize rp.PtrType using *byte's PtrType as a prototype. + // have to do assignment as PtrType, not runtime.PtrType, + // in order to write to unexported fields. + p = (*PtrType)(unsafe.Pointer(&rp.PtrType)) + bp := (*PtrType)(unsafe.Pointer(unsafe.Typeof((*byte)(nil)).(*runtime.PtrType))) + *p = *bp + + s := "*" + *ct.string + p.string = &s + + // For the type structures linked into the binary, the + // compiler provides a good hash of the string. + // Create a good hash for the new string by using + // the FNV-1 hash's mixing function to combine the + // old hash and the new "*". + p.hash = ct.hash*16777619 ^ '*' + + p.uncommonType = nil + p.ptrToThis = nil + p.elem = rt + + ptrMap.m[t] = (*PtrType)(unsafe.Pointer(&rp.PtrType)) + ptrMap.Unlock() + return p +} diff --git a/src/pkg/reflect/value.go b/src/pkg/reflect/value.go index 4d7d8723731..0b70b17f8fa 100644 --- a/src/pkg/reflect/value.go +++ b/src/pkg/reflect/value.go @@ -11,7 +11,7 @@ import ( ) const ptrSize = uintptr(unsafe.Sizeof((*byte)(nil))) -const cannotSet = "cannot set value obtained via unexported struct field" +const cannotSet = "cannot set value obtained from unexported struct field" type addr unsafe.Pointer @@ -51,20 +51,32 @@ type Value interface { // Interface returns the value as an interface{}. Interface() interface{} - // CanSet returns whether the value can be changed. + // CanSet returns true if the value can be changed. // Values obtained by the use of non-exported struct fields // can be used in Get but not Set. - // If CanSet() returns false, calling the type-specific Set - // will cause a crash. + // If CanSet returns false, calling the type-specific Set will panic. CanSet() bool // SetValue assigns v to the value; v must have the same type as the value. SetValue(v Value) - // Addr returns a pointer to the underlying data. - // It is for advanced clients that also - // import the "unsafe" package. - Addr() uintptr + // CanAddr returns true if the value's address can be obtained with Addr. + // Such values are called addressable. A value is addressable if it is + // an element of a slice, an element of an addressable array, + // a field of an addressable struct, the result of dereferencing a pointer, + // or the result of a call to NewValue, MakeChan, MakeMap, or MakeZero. + // If CanAddr returns false, calling Addr will panic. + CanAddr() bool + + // Addr returns the address of the value. + // If the value is not addressable, Addr panics. + // Addr is typically used to obtain a pointer to a struct field or slice element + // in order to call a method that requires a pointer receiver. + Addr() *PtrValue + + // UnsafeAddr returns a pointer to the underlying data. + // It is for advanced clients that also import the "unsafe" package. + UnsafeAddr() uintptr // Method returns a FuncValue corresponding to the value's i'th method. // The arguments to a Call on the returned FuncValue @@ -75,19 +87,42 @@ type Value interface { getAddr() addr } +// flags for value +const ( + canSet uint32 = 1 << iota // can set value (write to *v.addr) + canAddr // can take address of value + canStore // can store through value (write to **v.addr) +) + // value is the common implementation of most values. // It is embedded in other, public struct types, but always // with a unique tag like "uint" or "float" so that the client cannot // convert from, say, *UintValue to *FloatValue. type value struct { - typ Type - addr addr - canSet bool + typ Type + addr addr + flag uint32 } func (v *value) Type() Type { return v.typ } -func (v *value) Addr() uintptr { return uintptr(v.addr) } +func (v *value) Addr() *PtrValue { + if !v.CanAddr() { + panic("reflect: cannot take address of value") + } + a := v.addr + flag := canSet + if v.CanSet() { + flag |= canStore + } + // We could safely set canAddr here too - + // the caller would get the address of a - + // but it doesn't match the Go model. + // The language doesn't let you say &&v. + return newValue(PtrTo(v.typ), addr(&a), flag).(*PtrValue) +} + +func (v *value) UnsafeAddr() uintptr { return uintptr(v.addr) } func (v *value) getAddr() addr { return v.addr } @@ -109,7 +144,10 @@ func (v *value) Interface() interface{} { return unsafe.Unreflect(v.typ, unsafe.Pointer(v.addr)) } -func (v *value) CanSet() bool { return v.canSet } +func (v *value) CanSet() bool { return v.flag&canSet != 0 } + +func (v *value) CanAddr() bool { return v.flag&canAddr != 0 } + /* * basic types @@ -125,7 +163,7 @@ func (v *BoolValue) Get() bool { return *(*bool)(v.addr) } // Set sets v to the value x. func (v *BoolValue) Set(x bool) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } *(*bool)(v.addr) = x @@ -152,7 +190,7 @@ func (v *FloatValue) Get() float64 { // Set sets v to the value x. func (v *FloatValue) Set(x float64) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } switch v.typ.Kind() { @@ -197,7 +235,7 @@ func (v *ComplexValue) Get() complex128 { // Set sets v to the value x. func (v *ComplexValue) Set(x complex128) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } switch v.typ.Kind() { @@ -237,7 +275,7 @@ func (v *IntValue) Get() int64 { // Set sets v to the value x. func (v *IntValue) Set(x int64) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } switch v.typ.Kind() { @@ -282,7 +320,7 @@ func (v *StringValue) Get() string { return *(*string)(v.addr) } // Set sets v to the value x. func (v *StringValue) Set(x string) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } *(*string)(v.addr) = x @@ -317,7 +355,7 @@ func (v *UintValue) Get() uint64 { // Set sets v to the value x. func (v *UintValue) Set(x uint64) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } switch v.typ.Kind() { @@ -361,7 +399,7 @@ func (v *UnsafePointerValue) Get() uintptr { return uintptr(*(*unsafe.Pointer)(v // Set sets v to the value x. func (v *UnsafePointerValue) Set(x unsafe.Pointer) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } *(*unsafe.Pointer)(v.addr) = x @@ -473,7 +511,7 @@ func (v *ArrayValue) addr() addr { return v.value.addr } // Set assigns x to v. // The new value x must have the same type as v. func (v *ArrayValue) Set(x *ArrayValue) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } typesMustMatch(v.typ, x.typ) @@ -491,7 +529,7 @@ func (v *ArrayValue) Elem(i int) Value { panic("array index out of bounds") } p := addr(uintptr(v.addr()) + uintptr(i)*typ.Size()) - return newValue(typ, p, v.canSet) + return newValue(typ, p, v.flag) } /* @@ -537,7 +575,7 @@ func (v *SliceValue) SetLen(n int) { // Set assigns x to v. // The new value x must have the same type as v. func (v *SliceValue) Set(x *SliceValue) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } typesMustMatch(v.typ, x.typ) @@ -566,7 +604,14 @@ func (v *SliceValue) Slice(beg, end int) *SliceValue { s.Data = uintptr(v.addr()) + uintptr(beg)*typ.Elem().Size() s.Len = end - beg s.Cap = cap - beg - return newValue(typ, addr(s), v.canSet).(*SliceValue) + + // Like the result of Addr, we treat Slice as an + // unaddressable temporary, so don't set canAddr. + flag := canSet + if v.flag&canStore != 0 { + flag |= canStore + } + return newValue(typ, addr(s), flag).(*SliceValue) } // Elem returns the i'th element of v. @@ -577,7 +622,11 @@ func (v *SliceValue) Elem(i int) Value { panic("reflect: slice index out of range") } p := addr(uintptr(v.addr()) + uintptr(i)*typ.Size()) - return newValue(typ, p, v.canSet) + flag := canAddr + if v.flag&canStore != 0 { + flag |= canSet | canStore + } + return newValue(typ, p, flag) } // MakeSlice creates a new zero-initialized slice value @@ -588,7 +637,7 @@ func MakeSlice(typ *SliceType, len, cap int) *SliceValue { Len: len, Cap: cap, } - return newValue(typ, addr(s), true).(*SliceValue) + return newValue(typ, addr(s), canAddr|canSet|canStore).(*SliceValue) } /* @@ -606,7 +655,7 @@ func (v *ChanValue) IsNil() bool { return *(*uintptr)(v.addr) == 0 } // Set assigns x to v. // The new value x must have the same type as v. func (v *ChanValue) Set(x *ChanValue) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } typesMustMatch(v.typ, x.typ) @@ -733,7 +782,7 @@ func (v *FuncValue) Get() uintptr { return *(*uintptr)(v.addr) } // Set assigns x to v. // The new value x must have the same type as v. func (v *FuncValue) Set(x *FuncValue) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } typesMustMatch(v.typ, x.typ) @@ -754,7 +803,7 @@ func (v *value) Method(i int) *FuncValue { } p := &t.methods[i] fn := p.tfn - fv := &FuncValue{value: value{toType(*p.typ), addr(&fn), true}, first: v, isInterface: false} + fv := &FuncValue{value: value{toType(*p.typ), addr(&fn), 0}, first: v, isInterface: false} return fv } @@ -765,6 +814,17 @@ type tiny struct { b byte } +// Interface returns the fv as an interface value. +// If fv is a method obtained by invoking Value.Method +// (as opposed to Type.Method), Interface cannot return an +// interface value, so it panics. +func (fv *FuncValue) Interface() interface{} { + if fv.first != nil { + panic("FuncValue: cannot create interface value for method with bound receiver") + } + return fv.value.Interface() +} + // Call calls the function fv with input parameters in. // It returns the function's output parameters as Values. func (fv *FuncValue) Call(in []Value) []Value { @@ -902,7 +962,7 @@ func (v *InterfaceValue) Set(x Value) { if x != nil { i = x.Interface() } - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } // Two different representations; see comment in Get. @@ -933,11 +993,11 @@ func (v *InterfaceValue) Method(i int) *FuncValue { // Interface is two words: itable, data. tab := *(**runtime.Itable)(v.addr) - data := &value{Typeof((*byte)(nil)), addr(uintptr(v.addr) + ptrSize), true} + data := &value{Typeof((*byte)(nil)), addr(uintptr(v.addr) + ptrSize), 0} // Function pointer is at p.perm in the table. fn := tab.Fn[i] - fv := &FuncValue{value: value{toType(*p.typ), addr(&fn), true}, first: data, isInterface: true} + fv := &FuncValue{value: value{toType(*p.typ), addr(&fn), 0}, first: data, isInterface: true} return fv } @@ -956,7 +1016,7 @@ func (v *MapValue) IsNil() bool { return *(*uintptr)(v.addr) == 0 } // Set assigns x to v. // The new value x must have the same type as v. func (v *MapValue) Set(x *MapValue) { - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } if x == nil { @@ -1075,15 +1135,18 @@ func (v *PtrValue) IsNil() bool { return *(*uintptr)(v.addr) == 0 } func (v *PtrValue) Get() uintptr { return *(*uintptr)(v.addr) } // Set assigns x to v. -// The new value x must have the same type as v. +// The new value x must have the same type as v, and x.Elem().CanSet() must be true. func (v *PtrValue) Set(x *PtrValue) { if x == nil { *(**uintptr)(v.addr) = nil return } - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } + if x.flag&canStore == 0 { + panic("cannot copy pointer obtained from unexported struct field") + } typesMustMatch(v.typ, x.typ) // TODO: This will have to move into the runtime // once the new gc goes in @@ -1112,7 +1175,7 @@ func (v *PtrValue) PointTo(x Value) { typesMustMatch(v.typ.(*PtrType).Elem(), x.Type()) // TODO: This will have to move into the runtime // once the new gc goes in. - *(*uintptr)(v.addr) = x.Addr() + *(*uintptr)(v.addr) = x.UnsafeAddr() } // Elem returns the value that v points to. @@ -1121,7 +1184,11 @@ func (v *PtrValue) Elem() Value { if v.IsNil() { return nil } - return newValue(v.typ.(*PtrType).Elem(), *(*addr)(v.addr), v.canSet) + flag := canAddr + if v.flag&canStore != 0 { + flag |= canSet | canStore + } + return newValue(v.typ.(*PtrType).Elem(), *(*addr)(v.addr), flag) } // Indirect returns the value that v points to. @@ -1148,7 +1215,7 @@ type StructValue struct { func (v *StructValue) Set(x *StructValue) { // TODO: This will have to move into the runtime // once the gc goes in. - if !v.canSet { + if !v.CanSet() { panic(cannotSet) } typesMustMatch(v.typ, x.typ) @@ -1165,7 +1232,12 @@ func (v *StructValue) Field(i int) Value { return nil } f := t.Field(i) - return newValue(f.Type, addr(uintptr(v.addr)+f.Offset), v.canSet && f.PkgPath == "") + flag := v.flag + if f.PkgPath != "" { + // unexported field + flag &^= canSet | canStore + } + return newValue(f.Type, addr(uintptr(v.addr)+f.Offset), flag) } // FieldByIndex returns the nested field corresponding to index. @@ -1221,11 +1293,11 @@ func NewValue(i interface{}) Value { return nil } t, a := unsafe.Reflect(i) - return newValue(toType(t), addr(a), true) + return newValue(toType(t), addr(a), canSet|canAddr|canStore) } -func newValue(typ Type, addr addr, canSet bool) Value { - v := value{typ, addr, canSet} +func newValue(typ Type, addr addr, flag uint32) Value { + v := value{typ, addr, flag} switch typ.(type) { case *ArrayType: return &ArrayValue{v} @@ -1266,5 +1338,5 @@ func MakeZero(typ Type) Value { if typ == nil { return nil } - return newValue(typ, addr(unsafe.New(typ)), true) + return newValue(typ, addr(unsafe.New(typ)), canSet|canAddr|canStore) }