mirror of
https://github.com/golang/go
synced 2024-11-17 21:54:49 -07:00
text/template: make reflect.Value indirections more robust
Always shadow or modify the original parameter name. With code like: func index(item reflect.Value, ... { v := indirectInterface(item) It was possible to incorrectly use 'item' and 'v' later in the function, which could result in subtle bugs. This is precisely the kind of mistake that led to #36199. Instead, don't keep both the old and new reflect.Value variables in scope. Always shadow or modify the original variable. While at it, simplify the signature of 'length', to receive a reflect.Value directly and save a few redundant lines. Change-Id: I01416636a9d49f81246d28b91aca6413b1ba1aa5 Reviewed-on: https://go-review.googlesource.com/c/go/+/212117 Run-TryBot: Daniel Martí <mvdan@mvdan.cc> TryBot-Result: Gobot Gobot <gobot@golang.org> Reviewed-by: Roberto Clapis <robclap8@gmail.com> Reviewed-by: Rob Pike <r@golang.org>
This commit is contained in:
parent
31acdcc701
commit
218f4572f5
@ -185,41 +185,41 @@ func indexArg(index reflect.Value, cap int) (int, error) {
|
|||||||
// arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
|
// arguments. Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
|
||||||
// indexed item must be a map, slice, or array.
|
// indexed item must be a map, slice, or array.
|
||||||
func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
|
func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
|
||||||
v := indirectInterface(item)
|
item = indirectInterface(item)
|
||||||
if !v.IsValid() {
|
if !item.IsValid() {
|
||||||
return reflect.Value{}, fmt.Errorf("index of untyped nil")
|
return reflect.Value{}, fmt.Errorf("index of untyped nil")
|
||||||
}
|
}
|
||||||
for _, i := range indexes {
|
for _, index := range indexes {
|
||||||
index := indirectInterface(i)
|
index = indirectInterface(index)
|
||||||
var isNil bool
|
var isNil bool
|
||||||
if v, isNil = indirect(v); isNil {
|
if item, isNil = indirect(item); isNil {
|
||||||
return reflect.Value{}, fmt.Errorf("index of nil pointer")
|
return reflect.Value{}, fmt.Errorf("index of nil pointer")
|
||||||
}
|
}
|
||||||
switch v.Kind() {
|
switch item.Kind() {
|
||||||
case reflect.Array, reflect.Slice, reflect.String:
|
case reflect.Array, reflect.Slice, reflect.String:
|
||||||
x, err := indexArg(index, v.Len())
|
x, err := indexArg(index, item.Len())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reflect.Value{}, err
|
return reflect.Value{}, err
|
||||||
}
|
}
|
||||||
v = v.Index(x)
|
item = item.Index(x)
|
||||||
case reflect.Map:
|
case reflect.Map:
|
||||||
index, err := prepareArg(index, v.Type().Key())
|
index, err := prepareArg(index, item.Type().Key())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return reflect.Value{}, err
|
return reflect.Value{}, err
|
||||||
}
|
}
|
||||||
if x := v.MapIndex(index); x.IsValid() {
|
if x := item.MapIndex(index); x.IsValid() {
|
||||||
v = x
|
item = x
|
||||||
} else {
|
} else {
|
||||||
v = reflect.Zero(v.Type().Elem())
|
item = reflect.Zero(item.Type().Elem())
|
||||||
}
|
}
|
||||||
case reflect.Invalid:
|
case reflect.Invalid:
|
||||||
// the loop holds invariant: v.IsValid()
|
// the loop holds invariant: item.IsValid()
|
||||||
panic("unreachable")
|
panic("unreachable")
|
||||||
default:
|
default:
|
||||||
return reflect.Value{}, fmt.Errorf("can't index item of type %s", v.Type())
|
return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return v, nil
|
return item, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Slicing.
|
// Slicing.
|
||||||
@ -229,29 +229,27 @@ func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error)
|
|||||||
// is x[:], "slice x 1" is x[1:], and "slice x 1 2 3" is x[1:2:3]. The first
|
// is x[:], "slice x 1" is x[1:], and "slice x 1 2 3" is x[1:2:3]. The first
|
||||||
// argument must be a string, slice, or array.
|
// argument must be a string, slice, or array.
|
||||||
func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
|
func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
|
||||||
var (
|
item = indirectInterface(item)
|
||||||
cap int
|
if !item.IsValid() {
|
||||||
v = indirectInterface(item)
|
|
||||||
)
|
|
||||||
if !v.IsValid() {
|
|
||||||
return reflect.Value{}, fmt.Errorf("slice of untyped nil")
|
return reflect.Value{}, fmt.Errorf("slice of untyped nil")
|
||||||
}
|
}
|
||||||
if len(indexes) > 3 {
|
if len(indexes) > 3 {
|
||||||
return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
|
return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
|
||||||
}
|
}
|
||||||
switch v.Kind() {
|
var cap int
|
||||||
|
switch item.Kind() {
|
||||||
case reflect.String:
|
case reflect.String:
|
||||||
if len(indexes) == 3 {
|
if len(indexes) == 3 {
|
||||||
return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
|
return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
|
||||||
}
|
}
|
||||||
cap = v.Len()
|
cap = item.Len()
|
||||||
case reflect.Array, reflect.Slice:
|
case reflect.Array, reflect.Slice:
|
||||||
cap = v.Cap()
|
cap = item.Cap()
|
||||||
default:
|
default:
|
||||||
return reflect.Value{}, fmt.Errorf("can't slice item of type %s", v.Type())
|
return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
|
||||||
}
|
}
|
||||||
// set default values for cases item[:], item[i:].
|
// set default values for cases item[:], item[i:].
|
||||||
idx := [3]int{0, v.Len()}
|
idx := [3]int{0, item.Len()}
|
||||||
for i, index := range indexes {
|
for i, index := range indexes {
|
||||||
x, err := indexArg(index, cap)
|
x, err := indexArg(index, cap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -264,32 +262,28 @@ func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error)
|
|||||||
return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
|
return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
|
||||||
}
|
}
|
||||||
if len(indexes) < 3 {
|
if len(indexes) < 3 {
|
||||||
return v.Slice(idx[0], idx[1]), nil
|
return item.Slice(idx[0], idx[1]), nil
|
||||||
}
|
}
|
||||||
// given item[i:j:k], make sure i <= j <= k.
|
// given item[i:j:k], make sure i <= j <= k.
|
||||||
if idx[1] > idx[2] {
|
if idx[1] > idx[2] {
|
||||||
return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
|
return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
|
||||||
}
|
}
|
||||||
return v.Slice3(idx[0], idx[1], idx[2]), nil
|
return item.Slice3(idx[0], idx[1], idx[2]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Length
|
// Length
|
||||||
|
|
||||||
// length returns the length of the item, with an error if it has no defined length.
|
// length returns the length of the item, with an error if it has no defined length.
|
||||||
func length(item interface{}) (int, error) {
|
func length(item reflect.Value) (int, error) {
|
||||||
v := reflect.ValueOf(item)
|
item, isNil := indirect(item)
|
||||||
if !v.IsValid() {
|
|
||||||
return 0, fmt.Errorf("len of untyped nil")
|
|
||||||
}
|
|
||||||
v, isNil := indirect(v)
|
|
||||||
if isNil {
|
if isNil {
|
||||||
return 0, fmt.Errorf("len of nil pointer")
|
return 0, fmt.Errorf("len of nil pointer")
|
||||||
}
|
}
|
||||||
switch v.Kind() {
|
switch item.Kind() {
|
||||||
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
|
case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
|
||||||
return v.Len(), nil
|
return item.Len(), nil
|
||||||
}
|
}
|
||||||
return 0, fmt.Errorf("len of type %s", v.Type())
|
return 0, fmt.Errorf("len of type %s", item.Type())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Function invocation
|
// Function invocation
|
||||||
@ -297,11 +291,11 @@ func length(item interface{}) (int, error) {
|
|||||||
// call returns the result of evaluating the first argument as a function.
|
// call returns the result of evaluating the first argument as a function.
|
||||||
// The function must return 1 result, or 2 results, the second of which is an error.
|
// The function must return 1 result, or 2 results, the second of which is an error.
|
||||||
func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
|
func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
|
||||||
v := indirectInterface(fn)
|
fn = indirectInterface(fn)
|
||||||
if !v.IsValid() {
|
if !fn.IsValid() {
|
||||||
return reflect.Value{}, fmt.Errorf("call of nil")
|
return reflect.Value{}, fmt.Errorf("call of nil")
|
||||||
}
|
}
|
||||||
typ := v.Type()
|
typ := fn.Type()
|
||||||
if typ.Kind() != reflect.Func {
|
if typ.Kind() != reflect.Func {
|
||||||
return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
|
return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
|
||||||
}
|
}
|
||||||
@ -322,7 +316,7 @@ func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
|
|||||||
}
|
}
|
||||||
argv := make([]reflect.Value, len(args))
|
argv := make([]reflect.Value, len(args))
|
||||||
for i, arg := range args {
|
for i, arg := range args {
|
||||||
value := indirectInterface(arg)
|
arg = indirectInterface(arg)
|
||||||
// Compute the expected type. Clumsy because of variadics.
|
// Compute the expected type. Clumsy because of variadics.
|
||||||
argType := dddType
|
argType := dddType
|
||||||
if !typ.IsVariadic() || i < numIn-1 {
|
if !typ.IsVariadic() || i < numIn-1 {
|
||||||
@ -330,11 +324,11 @@ func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if argv[i], err = prepareArg(value, argType); err != nil {
|
if argv[i], err = prepareArg(arg, argType); err != nil {
|
||||||
return reflect.Value{}, fmt.Errorf("arg %d: %s", i, err)
|
return reflect.Value{}, fmt.Errorf("arg %d: %s", i, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return safeCall(v, argv)
|
return safeCall(fn, argv)
|
||||||
}
|
}
|
||||||
|
|
||||||
// safeCall runs fun.Call(args), and returns the resulting value and error, if
|
// safeCall runs fun.Call(args), and returns the resulting value and error, if
|
||||||
@ -440,52 +434,52 @@ func basicKind(v reflect.Value) (kind, error) {
|
|||||||
|
|
||||||
// eq evaluates the comparison a == b || a == c || ...
|
// eq evaluates the comparison a == b || a == c || ...
|
||||||
func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
|
func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
|
||||||
v1 := indirectInterface(arg1)
|
arg1 = indirectInterface(arg1)
|
||||||
if v1 != zero {
|
if arg1 != zero {
|
||||||
if t1 := v1.Type(); !t1.Comparable() {
|
if t1 := arg1.Type(); !t1.Comparable() {
|
||||||
return false, fmt.Errorf("uncomparable type %s: %v", t1, v1)
|
return false, fmt.Errorf("uncomparable type %s: %v", t1, arg1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(arg2) == 0 {
|
if len(arg2) == 0 {
|
||||||
return false, errNoComparison
|
return false, errNoComparison
|
||||||
}
|
}
|
||||||
k1, _ := basicKind(v1)
|
k1, _ := basicKind(arg1)
|
||||||
for _, arg := range arg2 {
|
for _, arg := range arg2 {
|
||||||
v2 := indirectInterface(arg)
|
arg = indirectInterface(arg)
|
||||||
k2, _ := basicKind(v2)
|
k2, _ := basicKind(arg)
|
||||||
truth := false
|
truth := false
|
||||||
if k1 != k2 {
|
if k1 != k2 {
|
||||||
// Special case: Can compare integer values regardless of type's sign.
|
// Special case: Can compare integer values regardless of type's sign.
|
||||||
switch {
|
switch {
|
||||||
case k1 == intKind && k2 == uintKind:
|
case k1 == intKind && k2 == uintKind:
|
||||||
truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
|
truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
|
||||||
case k1 == uintKind && k2 == intKind:
|
case k1 == uintKind && k2 == intKind:
|
||||||
truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
|
truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
|
||||||
default:
|
default:
|
||||||
return false, errBadComparison
|
return false, errBadComparison
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
switch k1 {
|
switch k1 {
|
||||||
case boolKind:
|
case boolKind:
|
||||||
truth = v1.Bool() == v2.Bool()
|
truth = arg1.Bool() == arg.Bool()
|
||||||
case complexKind:
|
case complexKind:
|
||||||
truth = v1.Complex() == v2.Complex()
|
truth = arg1.Complex() == arg.Complex()
|
||||||
case floatKind:
|
case floatKind:
|
||||||
truth = v1.Float() == v2.Float()
|
truth = arg1.Float() == arg.Float()
|
||||||
case intKind:
|
case intKind:
|
||||||
truth = v1.Int() == v2.Int()
|
truth = arg1.Int() == arg.Int()
|
||||||
case stringKind:
|
case stringKind:
|
||||||
truth = v1.String() == v2.String()
|
truth = arg1.String() == arg.String()
|
||||||
case uintKind:
|
case uintKind:
|
||||||
truth = v1.Uint() == v2.Uint()
|
truth = arg1.Uint() == arg.Uint()
|
||||||
default:
|
default:
|
||||||
if v2 == zero {
|
if arg == zero {
|
||||||
truth = v1 == v2
|
truth = arg1 == arg
|
||||||
} else {
|
} else {
|
||||||
if t2 := v2.Type(); !t2.Comparable() {
|
if t2 := arg.Type(); !t2.Comparable() {
|
||||||
return false, fmt.Errorf("uncomparable type %s: %v", t2, v2)
|
return false, fmt.Errorf("uncomparable type %s: %v", t2, arg)
|
||||||
}
|
}
|
||||||
truth = v1.Interface() == v2.Interface()
|
truth = arg1.Interface() == arg.Interface()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -505,13 +499,13 @@ func ne(arg1, arg2 reflect.Value) (bool, error) {
|
|||||||
|
|
||||||
// lt evaluates the comparison a < b.
|
// lt evaluates the comparison a < b.
|
||||||
func lt(arg1, arg2 reflect.Value) (bool, error) {
|
func lt(arg1, arg2 reflect.Value) (bool, error) {
|
||||||
v1 := indirectInterface(arg1)
|
arg1 = indirectInterface(arg1)
|
||||||
k1, err := basicKind(v1)
|
k1, err := basicKind(arg1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
v2 := indirectInterface(arg2)
|
arg2 = indirectInterface(arg2)
|
||||||
k2, err := basicKind(v2)
|
k2, err := basicKind(arg2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@ -520,9 +514,9 @@ func lt(arg1, arg2 reflect.Value) (bool, error) {
|
|||||||
// Special case: Can compare integer values regardless of type's sign.
|
// Special case: Can compare integer values regardless of type's sign.
|
||||||
switch {
|
switch {
|
||||||
case k1 == intKind && k2 == uintKind:
|
case k1 == intKind && k2 == uintKind:
|
||||||
truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
|
truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
|
||||||
case k1 == uintKind && k2 == intKind:
|
case k1 == uintKind && k2 == intKind:
|
||||||
truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
|
truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
|
||||||
default:
|
default:
|
||||||
return false, errBadComparison
|
return false, errBadComparison
|
||||||
}
|
}
|
||||||
@ -531,13 +525,13 @@ func lt(arg1, arg2 reflect.Value) (bool, error) {
|
|||||||
case boolKind, complexKind:
|
case boolKind, complexKind:
|
||||||
return false, errBadComparisonType
|
return false, errBadComparisonType
|
||||||
case floatKind:
|
case floatKind:
|
||||||
truth = v1.Float() < v2.Float()
|
truth = arg1.Float() < arg2.Float()
|
||||||
case intKind:
|
case intKind:
|
||||||
truth = v1.Int() < v2.Int()
|
truth = arg1.Int() < arg2.Int()
|
||||||
case stringKind:
|
case stringKind:
|
||||||
truth = v1.String() < v2.String()
|
truth = arg1.String() < arg2.String()
|
||||||
case uintKind:
|
case uintKind:
|
||||||
truth = v1.Uint() < v2.Uint()
|
truth = arg1.Uint() < arg2.Uint()
|
||||||
default:
|
default:
|
||||||
panic("invalid kind")
|
panic("invalid kind")
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user