From 95cbcc5c1c5db05e659b769784e57db777c4cd6a Mon Sep 17 00:00:00 2001 From: Ariel Mashraki Date: Fri, 6 Sep 2019 18:04:51 +0300 Subject: [PATCH] text/template: support all comparable types in eq Extends the built-in eq function to support all Go comparable types. Fixes #33740 Change-Id: I522310e313e251c4dc6a013d33d7c2034fe2ec8e Reviewed-on: https://go-review.googlesource.com/c/go/+/193837 Run-TryBot: Rob Pike TryBot-Result: Gobot Gobot Reviewed-by: Rob Pike --- src/text/template/doc.go | 14 ++++++------- src/text/template/exec_test.go | 36 +++++++++++++++++++++++++++------- src/text/template/funcs.go | 22 +++++++++++++-------- 3 files changed, 49 insertions(+), 23 deletions(-) diff --git a/src/text/template/doc.go b/src/text/template/doc.go index 22266143a6e..4b0efd2df87 100644 --- a/src/text/template/doc.go +++ b/src/text/template/doc.go @@ -385,14 +385,12 @@ returning in effect (Unlike with || in Go, however, eq is a function call and all the arguments will be evaluated.) -The comparison functions work on basic types only (or named basic -types, such as "type Celsius float32"). They implement the Go rules -for comparison of values, except that size and exact type are -ignored, so any integer value, signed or unsigned, may be compared -with any other integer value. (The arithmetic value is compared, -not the bit pattern, so all negative integers are less than all -unsigned integers.) However, as usual, one may not compare an int -with a float32 and so on. +The comparison functions work on any values whose type Go defines as +comparable. For basic types such as integers, the rules are relaxed: +size and exact type are ignored, so any integer value, signed or unsigned, +may be compared with any other integer value. (The arithmetic value is compared, +not the bit pattern, so all negative integers are less than all unsigned integers.) +However, as usual, one may not compare an int with a float32 and so on. Associated templates diff --git a/src/text/template/exec_test.go b/src/text/template/exec_test.go index 81f9e044764..7f2305ace00 100644 --- a/src/text/template/exec_test.go +++ b/src/text/template/exec_test.go @@ -1156,19 +1156,41 @@ var cmpTests = []cmpTest{ {"ge .Uthree .NegOne", "true", true}, {"eq (index `x` 0) 'x'", "true", true}, // The example that triggered this rule. {"eq (index `x` 0) 'y'", "false", true}, + {"eq .V1 .V2", "true", true}, + {"eq .Ptr .Ptr", "true", true}, + {"eq .Ptr .NilPtr", "false", true}, + {"eq .NilPtr .NilPtr", "true", true}, + {"eq .Iface1 .Iface1", "true", true}, + {"eq .Iface1 .Iface2", "false", true}, + {"eq .Iface2 .Iface2", "true", true}, // Errors - {"eq `xy` 1", "", false}, // Different types. - {"eq 2 2.0", "", false}, // Different types. - {"lt true true", "", false}, // Unordered types. - {"lt 1+0i 1+0i", "", false}, // Unordered types. + {"eq `xy` 1", "", false}, // Different types. + {"eq 2 2.0", "", false}, // Different types. + {"lt true true", "", false}, // Unordered types. + {"lt 1+0i 1+0i", "", false}, // Unordered types. + {"eq .Ptr 1", "", false}, // Incompatible types. + {"eq .Ptr .NegOne", "", false}, // Incompatible types. + {"eq .Map .Map", "", false}, // Uncomparable types. + {"eq .Map .V1", "", false}, // Uncomparable types. } func TestComparison(t *testing.T) { b := new(bytes.Buffer) var cmpStruct = struct { - Uthree, Ufour uint - NegOne, Three int - }{3, 4, -1, 3} + Uthree, Ufour uint + NegOne, Three int + Ptr, NilPtr *int + Map map[int]int + V1, V2 V + Iface1, Iface2 fmt.Stringer + }{ + Uthree: 3, + Ufour: 4, + NegOne: -1, + Three: 3, + Ptr: new(int), + Iface1: b, + } for _, test := range cmpTests { text := fmt.Sprintf("{{if %s}}true{{else}}false{{end}}", test.expr) tmpl, err := New("empty").Parse(text) diff --git a/src/text/template/funcs.go b/src/text/template/funcs.go index 248dbcf22ed..0985eda3175 100644 --- a/src/text/template/funcs.go +++ b/src/text/template/funcs.go @@ -441,19 +441,18 @@ func basicKind(v reflect.Value) (kind, error) { // eq evaluates the comparison a == b || a == c || ... func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) { v1 := indirectInterface(arg1) - k1, err := basicKind(v1) - if err != nil { - return false, err + if v1 != zero { + if t1 := v1.Type(); !t1.Comparable() { + return false, fmt.Errorf("uncomparable type %s: %v", t1, v1) + } } if len(arg2) == 0 { return false, errNoComparison } + k1, _ := basicKind(v1) for _, arg := range arg2 { v2 := indirectInterface(arg) - k2, err := basicKind(v2) - if err != nil { - return false, err - } + k2, _ := basicKind(v2) truth := false if k1 != k2 { // Special case: Can compare integer values regardless of type's sign. @@ -480,7 +479,14 @@ func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) { case uintKind: truth = v1.Uint() == v2.Uint() default: - panic("invalid kind") + if v2 == zero { + truth = v1 == v2 + } else { + if t2 := v2.Type(); !t2.Comparable() { + return false, fmt.Errorf("uncomparable type %s: %v", t2, v2) + } + truth = v1.Interface() == v2.Interface() + } } } if truth {