diff --git a/src/pkg/reflect/all_test.go b/src/pkg/reflect/all_test.go index c83a9b75f63..94b0fb5b361 100644 --- a/src/pkg/reflect/all_test.go +++ b/src/pkg/reflect/all_test.go @@ -1050,6 +1050,12 @@ type Point struct { x, y int } +// This will be index 0. +func (p Point) AnotherMethod(scale int) int { + return -1 +} + +// This will be index 1. func (p Point) Dist(scale int) int { // println("Point.Dist", p.x, p.y, scale) return p.x*p.x*scale + p.y*p.y*scale @@ -1058,26 +1064,52 @@ func (p Point) Dist(scale int) int { func TestMethod(t *testing.T) { // Non-curried method of type. p := Point{3, 4} - i := TypeOf(p).Method(0).Func.Call([]Value{ValueOf(p), ValueOf(10)})[0].Int() + i := TypeOf(p).Method(1).Func.Call([]Value{ValueOf(p), ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Type Method returned %d; want 250", i) } - i = TypeOf(&p).Method(0).Func.Call([]Value{ValueOf(&p), ValueOf(10)})[0].Int() + m, ok := TypeOf(p).MethodByName("Dist") + if !ok { + t.Fatalf("method by name failed") + } + m.Func.Call([]Value{ValueOf(p), ValueOf(10)})[0].Int() + if i != 250 { + t.Errorf("Type MethodByName returned %d; want 250", i) + } + + i = TypeOf(&p).Method(1).Func.Call([]Value{ValueOf(&p), ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Pointer Type Method returned %d; want 250", i) } + m, ok = TypeOf(&p).MethodByName("Dist") + if !ok { + t.Fatalf("ptr method by name failed") + } + i = m.Func.Call([]Value{ValueOf(&p), ValueOf(10)})[0].Int() + if i != 250 { + t.Errorf("Pointer Type MethodByName returned %d; want 250", i) + } + // Curried method of value. - i = ValueOf(p).Method(0).Call([]Value{ValueOf(10)})[0].Int() + i = ValueOf(p).Method(1).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { t.Errorf("Value Method returned %d; want 250", i) } + i = ValueOf(p).MethodByName("Dist").Call([]Value{ValueOf(10)})[0].Int() + if i != 250 { + t.Errorf("Value MethodByName returned %d; want 250", i) + } // Curried method of pointer. - i = ValueOf(&p).Method(0).Call([]Value{ValueOf(10)})[0].Int() + i = ValueOf(&p).Method(1).Call([]Value{ValueOf(10)})[0].Int() if i != 250 { - t.Errorf("Value Method returned %d; want 250", i) + t.Errorf("Pointer Value Method returned %d; want 250", i) + } + i = ValueOf(&p).MethodByName("Dist").Call([]Value{ValueOf(10)})[0].Int() + if i != 250 { + t.Errorf("Pointer Value MethodByName returned %d; want 250", i) } // Curried method of interface value. @@ -1094,6 +1126,10 @@ func TestMethod(t *testing.T) { if i != 250 { t.Errorf("Interface Method returned %d; want 250", i) } + i = pv.MethodByName("Dist").Call([]Value{ValueOf(10)})[0].Int() + if i != 250 { + t.Errorf("Interface MethodByName returned %d; want 250", i) + } } func TestInterfaceSet(t *testing.T) { diff --git a/src/pkg/reflect/type.go b/src/pkg/reflect/type.go index 6c1ab60982d..f774f730752 100644 --- a/src/pkg/reflect/type.go +++ b/src/pkg/reflect/type.go @@ -47,6 +47,16 @@ type Type interface { // method signature, without a receiver, and the Func field is nil. Method(int) Method + // MethodByName returns the method with that name in the type's + // method set and a boolean indicating if the method was found. + // + // For a non-interface type T or *T, the returned Method's Type and Func + // fields describe a function whose first argument is the receiver. + // + // For an interface type, the returned Method's Type field gives the + // method signature, without a receiver, and the Func field is nil. + MethodByName(string) (Method, bool) + // NumMethod returns the number of methods in the type's method set. NumMethod() int @@ -344,6 +354,7 @@ type Method struct { Name string Type Type Func Value + Index int } // High bit says whether type has @@ -451,6 +462,7 @@ func (t *uncommonType) Method(i int) (m Method) { m.Type = toType(p.typ) fn := p.tfn m.Func = valueFromIword(flag, m.Type, iword(fn)) + m.Index = i return } @@ -461,6 +473,20 @@ func (t *uncommonType) NumMethod() int { return len(t.methods) } +func (t *uncommonType) MethodByName(name string) (m Method, ok bool) { + if t == nil { + return + } + var p *method + for i := range t.methods { + p = &t.methods[i] + if p.name != nil && *p.name == name { + return t.Method(i), true + } + } + return +} + // TODO(rsc): 6g supplies these, but they are not // as efficient as they could be: they have commonType // as the receiver instead of *commonType. @@ -480,6 +506,14 @@ func (t *commonType) Method(i int) (m Method) { return t.uncommonType.Method(i) } +func (t *commonType) MethodByName(name string) (m Method, ok bool) { + if t.Kind() == Interface { + tt := (*interfaceType)(unsafe.Pointer(t)) + return tt.MethodByName(name) + } + return t.uncommonType.MethodByName(name) +} + func (t *commonType) PkgPath() string { return t.uncommonType.PkgPath() } @@ -636,12 +670,28 @@ func (t *interfaceType) Method(i int) (m Method) { m.PkgPath = *p.pkgPath } m.Type = toType(p.typ) + m.Index = i return } // NumMethod returns the number of interface methods in the type's method set. func (t *interfaceType) NumMethod() int { return len(t.methods) } +// MethodByName method with the given name in the type's method set. +func (t *interfaceType) MethodByName(name string) (m Method, ok bool) { + if t == nil { + return + } + var p *imethod + for i := range t.methods { + p = &t.methods[i] + if *p.name == name { + return t.Method(i), true + } + } + return +} + type StructField struct { PkgPath string // empty for uppercase Name Name string diff --git a/src/pkg/reflect/value.go b/src/pkg/reflect/value.go index b1999aa6348..889d9455bda 100644 --- a/src/pkg/reflect/value.go +++ b/src/pkg/reflect/value.go @@ -1023,6 +1023,23 @@ func (v Value) Method(i int) Value { return Value{v.Internal, i + 1} } +// MethodByName returns a function value corresponding to the method +// of v with the given name. +// The arguments to a Call on the returned function should not include +// a receiver; the returned function will always use v as the receiver. +// It returns the zero Value if no method was found. +func (v Value) MethodByName(name string) Value { + iv := v.internal() + if iv.kind == Invalid { + panic(&ValueError{"reflect.Value.MethodByName", Invalid}) + } + m, ok := iv.typ.MethodByName(name) + if ok { + return Value{v.Internal, m.Index + 1} + } + return Value{} +} + // NumField returns the number of fields in the struct v. // It panics if v's Kind is not Struct. func (v Value) NumField() int {