From 7bbe32067922643a30ac7adf8aa3da9785d89d13 Mon Sep 17 00:00:00 2001 From: Rob Pike Date: Wed, 21 Aug 2013 11:27:27 +1000 Subject: [PATCH] text/template: implement comparison of basic types Add eq, lt, etc. to allow one to do simple comparisons. It's basic types only (booleans, integers, unsigned integers, floats, complex, string) because that's easy, easy to define, and covers the great majority of useful cases, while leaving open the possibility of a more sweeping definition later. {{if eq .X .Y}}X and Y are equal{{else}}X and Y are unequal{{end}} R=golang-dev, adg CC=golang-dev https://golang.org/cl/13091045 --- src/pkg/text/template/doc.go | 28 +++++- src/pkg/text/template/exec_test.go | 107 ++++++++++++++++++++ src/pkg/text/template/funcs.go | 154 +++++++++++++++++++++++++++++ 3 files changed, 287 insertions(+), 2 deletions(-) diff --git a/src/pkg/text/template/doc.go b/src/pkg/text/template/doc.go index c9121f74d3..b952789d1c 100644 --- a/src/pkg/text/template/doc.go +++ b/src/pkg/text/template/doc.go @@ -301,8 +301,32 @@ Predefined global functions are named as follows. Returns the escaped value of the textual representation of its arguments in a form suitable for embedding in a URL query. -The boolean functions take any zero value to be false and a non-zero value to -be true. +The boolean functions take any zero value to be false and a non-zero +value to be true. + +There is also a set of binary comparison operators defined as +functions: + + eq + Returns the boolean truth of arg1 == arg2 + ne + Returns the boolean truth of arg1 != arg2 + lt + Returns the boolean truth of arg1 < arg2 + le + Returns the boolean truth of arg1 <= arg2 + gt + Returns the boolean truth of arg1 > arg2 + ge + Returns the boolean truth of arg1 >= arg2 + +These functions work on basic types only (or named basic types, +such as "type Celsius float32"). They implement the Go rules for +comparison of values, except that size and exact type are ignored, +so any integer value may be compared with any other integer value, +any unsigned integer value may be compared with any other unsigned +integer value, and so on. However, as usual, one may not compare +an int with a float32 and so on. Associated templates diff --git a/src/pkg/text/template/exec_test.go b/src/pkg/text/template/exec_test.go index 3d110af9cc..341c502173 100644 --- a/src/pkg/text/template/exec_test.go +++ b/src/pkg/text/template/exec_test.go @@ -863,3 +863,110 @@ func TestMessageForExecuteEmpty(t *testing.T) { t.Fatal(err) } } + +type cmpTest struct { + expr string + truth string + ok bool +} + +var cmpTests = []cmpTest{ + {"eq true true", "true", true}, + {"eq true false", "false", true}, + {"eq 1+2i 1+2i", "true", true}, + {"eq 1+2i 1+3i", "false", true}, + {"eq 1.5 1.5", "true", true}, + {"eq 1.5 2.5", "false", true}, + {"eq 1 1", "true", true}, + {"eq 1 2", "false", true}, + {"eq `xy` `xy`", "true", true}, + {"eq `xy` `xyz`", "false", true}, + {"eq .Xuint .Xuint", "true", true}, + {"eq .Xuint .Yuint", "false", true}, + {"ne true true", "false", true}, + {"ne true false", "true", true}, + {"ne 1+2i 1+2i", "false", true}, + {"ne 1+2i 1+3i", "true", true}, + {"ne 1.5 1.5", "false", true}, + {"ne 1.5 2.5", "true", true}, + {"ne 1 1", "false", true}, + {"ne 1 2", "true", true}, + {"ne `xy` `xy`", "false", true}, + {"ne `xy` `xyz`", "true", true}, + {"ne .Xuint .Xuint", "false", true}, + {"ne .Xuint .Yuint", "true", true}, + {"lt 1.5 1.5", "false", true}, + {"lt 1.5 2.5", "true", true}, + {"lt 1 1", "false", true}, + {"lt 1 2", "true", true}, + {"lt `xy` `xy`", "false", true}, + {"lt `xy` `xyz`", "true", true}, + {"lt .Xuint .Xuint", "false", true}, + {"lt .Xuint .Yuint", "true", true}, + {"le 1.5 1.5", "true", true}, + {"le 1.5 2.5", "true", true}, + {"le 2.5 1.5", "false", true}, + {"le 1 1", "true", true}, + {"le 1 2", "true", true}, + {"le 2 1", "false", true}, + {"le `xy` `xy`", "true", true}, + {"le `xy` `xyz`", "true", true}, + {"le `xyz` `xy`", "false", true}, + {"le .Xuint .Xuint", "true", true}, + {"le .Xuint .Yuint", "true", true}, + {"le .Yuint .Xuint", "false", true}, + {"gt 1.5 1.5", "false", true}, + {"gt 1.5 2.5", "false", true}, + {"gt 1 1", "false", true}, + {"gt 2 1", "true", true}, + {"gt 1 2", "false", true}, + {"gt `xy` `xy`", "false", true}, + {"gt `xy` `xyz`", "false", true}, + {"gt .Xuint .Xuint", "false", true}, + {"gt .Xuint .Yuint", "false", true}, + {"gt .Yuint .Xuint", "true", true}, + {"ge 1.5 1.5", "true", true}, + {"ge 1.5 2.5", "false", true}, + {"ge 2.5 1.5", "true", true}, + {"ge 1 1", "true", true}, + {"ge 1 2", "false", true}, + {"ge 2 1", "true", true}, + {"ge `xy` `xy`", "true", true}, + {"ge `xy` `xyz`", "false", true}, + {"ge `xyz` `xy`", "true", true}, + {"ge .Xuint .Xuint", "true", true}, + {"ge .Xuint .Yuint", "false", true}, + {"ge .Yuint .Xuint", "true", true}, + // Errors + {"eq 3 4 5", "", false}, // Too many arguments. + {"eq `xy` 1", "", false}, // Different types. + {"lt true true", "", false}, // Unordered types. + {"lt 1+0i 1+0i", "", false}, // Unordered types. +} + +func TestComparison(t *testing.T) { + b := new(bytes.Buffer) + var cmpStruct = struct { + Xuint, Yuint uint + }{3, 4} + for _, test := range cmpTests { + text := fmt.Sprintf("{{if %s}}true{{else}}false{{end}}", test.expr) + tmpl, err := New("empty").Parse(text) + if err != nil { + t.Fatal(err) + } + b.Reset() + err = tmpl.Execute(b, &cmpStruct) + if test.ok && err != nil { + t.Errorf("%s errored incorrectly: %s", test.expr, err) + continue + } + if !test.ok && err == nil { + t.Errorf("%s did not error") + continue + } + if b.String() != test.truth { + t.Errorf("%s: want %s; got %s", test.expr, test.truth, b.String()) + } + } +} diff --git a/src/pkg/text/template/funcs.go b/src/pkg/text/template/funcs.go index 643a728cb7..9402170bd0 100644 --- a/src/pkg/text/template/funcs.go +++ b/src/pkg/text/template/funcs.go @@ -6,6 +6,7 @@ package template import ( "bytes" + "errors" "fmt" "io" "net/url" @@ -35,6 +36,14 @@ var builtins = FuncMap{ "printf": fmt.Sprintf, "println": fmt.Sprintln, "urlquery": URLQueryEscaper, + + // Comparisons + "eq": eq, // == + "ge": ge, // >= + "gt": gt, // > + "le": le, // <= + "lt": lt, // < + "ne": ne, // != } var builtinFuncs = createValueFuncs(builtins) @@ -248,6 +257,151 @@ func not(arg interface{}) (truth bool) { return !truth } +// Comparison. + +// TODO: Perhaps allow comparison between signed and unsigned integers. + +var ( + errBadComparisonType = errors.New("invalid type for comparison") + errBadComparison = errors.New("incompatible types for comparison") +) + +type kind int + +const ( + invalidKind kind = iota + boolKind + complexKind + intKind + floatKind + integerKind + stringKind + uintKind +) + +func basicKind(v reflect.Value) (kind, error) { + switch v.Kind() { + case reflect.Bool: + return boolKind, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return intKind, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return uintKind, nil + case reflect.Float32, reflect.Float64: + return floatKind, nil + case reflect.Complex64, reflect.Complex128: + return complexKind, nil + case reflect.String: + return stringKind, nil + } + return invalidKind, errBadComparisonType +} + +// eq evaluates the comparison a == b. +func eq(arg1, arg2 interface{}) (bool, error) { + v1 := reflect.ValueOf(arg1) + k1, err := basicKind(v1) + if err != nil { + return false, err + } + v2 := reflect.ValueOf(arg2) + k2, err := basicKind(v2) + if err != nil { + return false, err + } + if k1 != k2 { + return false, errBadComparison + } + truth := false + switch k1 { + case boolKind: + truth = v1.Bool() == v2.Bool() + case complexKind: + truth = v1.Complex() == v2.Complex() + case floatKind: + truth = v1.Float() == v2.Float() + case intKind: + truth = v1.Int() == v2.Int() + case stringKind: + truth = v1.String() == v2.String() + case uintKind: + truth = v1.Uint() == v2.Uint() + default: + panic("invalid kind") + } + return truth, nil +} + +// ne evaluates the comparison a != b. +func ne(arg1, arg2 interface{}) (bool, error) { + // != is the inverse of ==. + equal, err := eq(arg1, arg2) + return !equal, err +} + +// lt evaluates the comparison a < b. +func lt(arg1, arg2 interface{}) (bool, error) { + v1 := reflect.ValueOf(arg1) + k1, err := basicKind(v1) + if err != nil { + return false, err + } + v2 := reflect.ValueOf(arg2) + k2, err := basicKind(v2) + if err != nil { + return false, err + } + if k1 != k2 { + return false, errBadComparison + } + truth := false + switch k1 { + case boolKind, complexKind: + return false, errBadComparisonType + case floatKind: + truth = v1.Float() < v2.Float() + case intKind: + truth = v1.Int() < v2.Int() + case stringKind: + truth = v1.String() < v2.String() + case uintKind: + truth = v1.Uint() < v2.Uint() + default: + panic("invalid kind") + } + return truth, nil +} + +// le evaluates the comparison <= b. +func le(arg1, arg2 interface{}) (bool, error) { + // <= is < or ==. + lessThan, err := lt(arg1, arg2) + if lessThan || err != nil { + return lessThan, err + } + return eq(arg1, arg2) +} + +// gt evaluates the comparison a > b. +func gt(arg1, arg2 interface{}) (bool, error) { + // > is the inverse of <=. + lessOrEqual, err := le(arg1, arg2) + if err != nil { + return false, err + } + return !lessOrEqual, nil +} + +// ge evaluates the comparison a >= b. +func ge(arg1, arg2 interface{}) (bool, error) { + // >= is the inverse of <. + lessThan, err := lt(arg1, arg2) + if err != nil { + return false, err + } + return !lessThan, nil +} + // HTML escaping. var (