diff --git a/src/pkg/exp/template/Makefile b/src/pkg/exp/template/Makefile index 50a0bd7234..8550b0d522 100644 --- a/src/pkg/exp/template/Makefile +++ b/src/pkg/exp/template/Makefile @@ -7,6 +7,7 @@ include ../../../Make.inc TARG=exp/template GOFILES=\ exec.go\ + funcs.go\ lex.go\ parse.go\ set.go\ diff --git a/src/pkg/exp/template/exec.go b/src/pkg/exp/template/exec.go index 3eaecd1941..636bc4c334 100644 --- a/src/pkg/exp/template/exec.go +++ b/src/pkg/exp/template/exec.go @@ -174,8 +174,11 @@ func (s *state) evalPipeline(data reflect.Value, pipe []*commandNode) reflect.Va func (s *state) evalCommand(data reflect.Value, cmd *commandNode, final reflect.Value) reflect.Value { firstWord := cmd.args[0] - if field, ok := firstWord.(*fieldNode); ok { - return s.evalFieldNode(data, field, cmd.args, final) + switch n := firstWord.(type) { + case *fieldNode: + return s.evalFieldNode(data, n, cmd.args, final) + case *identifierNode: + return s.evalFieldOrCall(data, n.ident, cmd.args, final) } if len(cmd.args) > 1 || final.IsValid() { // TODO: functions @@ -215,7 +218,7 @@ func (s *state) evalFieldNode(data reflect.Value, field *fieldNode, args []node, data = s.evalField(data, field.ident[i]) } // Now it can be a field or method and if a method, gets arguments. - return s.evalMethodOrField(data, field.ident[n-1], args, final) + return s.evalFieldOrCall(data, field.ident[n-1], args, final) } func (s *state) evalField(data reflect.Value, fieldName string) reflect.Value { @@ -238,14 +241,18 @@ func (s *state) evalField(data reflect.Value, fieldName string) reflect.Value { panic("not reached") } -func (s *state) evalMethodOrField(data reflect.Value, fieldName string, args []node, final reflect.Value) reflect.Value { +func (s *state) evalFieldOrCall(data reflect.Value, fieldName string, args []node, final reflect.Value) reflect.Value { + // Is it a function? + if function, ok := findFunction(fieldName, s.tmpl, s.set); ok { + return s.evalCall(data, function, fieldName, false, args, final) + } ptr := data for data.Kind() == reflect.Ptr { ptr, data = data, reflect.Indirect(data) } // Is it a method? We use the pointer because it has value methods too. if method, ok := ptr.Type().MethodByName(fieldName); ok { - return s.evalMethod(ptr, method, args, final) + return s.evalCall(ptr, method.Func, fieldName, true, args, final) } if len(args) > 1 || final.IsValid() { s.errorf("%s is not a method but has arguments", fieldName) @@ -263,31 +270,46 @@ var ( osErrorType = reflect.TypeOf(new(os.Error)).Elem() ) -func (s *state) evalMethod(v reflect.Value, method reflect.Method, args []node, final reflect.Value) reflect.Value { - typ := method.Type - fun := method.Func +func (s *state) evalCall(v, fun reflect.Value, name string, isMethod bool, args []node, final reflect.Value) reflect.Value { + typ := fun.Type() + if !isMethod && len(args) > 0 { // Args will be nil if it's a niladic call in an argument list + args = args[1:] // first arg is name of function; not used in call. + } numIn := len(args) if final.IsValid() { numIn++ } - if !typ.IsVariadic() && numIn < typ.NumIn()-1 || !typ.IsVariadic() && numIn != typ.NumIn() { - s.errorf("wrong number of args for %s: want %d got %d", method.Name, typ.NumIn(), len(args)) + numFixed := len(args) + if typ.IsVariadic() { + numFixed = typ.NumIn() - 1 // last arg is the variadic one. + if numIn < numFixed { + s.errorf("wrong number of args for %s: want at least %d got %d", name, typ.NumIn()-1, len(args)) + } + } else if numIn < typ.NumIn()-1 || !typ.IsVariadic() && numIn != typ.NumIn() { + s.errorf("wrong number of args for %s: want %d got %d", name, typ.NumIn(), len(args)) } - // We allow methods with 1 result or 2 results where the second is an os.Error. - switch { - case typ.NumOut() == 1: - case typ.NumOut() == 2 && typ.Out(1) == osErrorType: - default: - s.errorf("can't handle multiple results from method %q", method.Name) + if !goodFunc(typ) { + s.errorf("can't handle multiple results from method/function %q", name) } // Build the arg list. argv := make([]reflect.Value, numIn) // First arg is the receiver. - argv[0] = v - // Others must be evaluated. - for i := 1; i < len(args); i++ { + i := 0 + if isMethod { + argv[0] = v + i++ + } + // Others must be evaluated. Fixed args first. + for ; i < numFixed; i++ { argv[i] = s.evalArg(v, typ.In(i), args[i]) } + // And now the ... args. + if typ.IsVariadic() { + argType := typ.In(typ.NumIn() - 1).Elem() // Argument is a slice. + for ; i < len(args); i++ { + argv[i] = s.evalArg(v, argType, args[i]) + } + } // Add final value if necessary. if final.IsValid() { argv[len(args)] = final @@ -310,23 +332,27 @@ func (s *state) evalArg(data reflect.Value, typ reflect.Type, n node) reflect.Va } switch typ.Kind() { case reflect.Bool: - return s.evalBool(data, typ, n) + return s.evalBool(typ, n) case reflect.String: - return s.evalString(data, typ, n) + return s.evalString(typ, n) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - return s.evalInteger(data, typ, n) + return s.evalInteger(typ, n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: - return s.evalUnsignedInteger(data, typ, n) + return s.evalUnsignedInteger(typ, n) case reflect.Float32, reflect.Float64: - return s.evalFloat(data, typ, n) + return s.evalFloat(typ, n) case reflect.Complex64, reflect.Complex128: - return s.evalComplex(data, typ, n) + return s.evalComplex(typ, n) + case reflect.Interface: + if typ.NumMethod() == 0 { + return s.evalEmptyInterface(data, typ, n) + } } - s.errorf("can't handle node %s for method arg of type %s", n, typ) + s.errorf("can't handle %s for arg of type %s", n, typ) panic("not reached") } -func (s *state) evalBool(v reflect.Value, typ reflect.Type, n node) reflect.Value { +func (s *state) evalBool(typ reflect.Type, n node) reflect.Value { if n, ok := n.(*boolNode); ok { value := reflect.New(typ).Elem() value.SetBool(n.true) @@ -336,7 +362,7 @@ func (s *state) evalBool(v reflect.Value, typ reflect.Type, n node) reflect.Valu panic("not reached") } -func (s *state) evalString(v reflect.Value, typ reflect.Type, n node) reflect.Value { +func (s *state) evalString(typ reflect.Type, n node) reflect.Value { if n, ok := n.(*stringNode); ok { value := reflect.New(typ).Elem() value.SetString(n.text) @@ -346,7 +372,7 @@ func (s *state) evalString(v reflect.Value, typ reflect.Type, n node) reflect.Va panic("not reached") } -func (s *state) evalInteger(v reflect.Value, typ reflect.Type, n node) reflect.Value { +func (s *state) evalInteger(typ reflect.Type, n node) reflect.Value { if n, ok := n.(*numberNode); ok && n.isInt { value := reflect.New(typ).Elem() value.SetInt(n.int64) @@ -356,7 +382,7 @@ func (s *state) evalInteger(v reflect.Value, typ reflect.Type, n node) reflect.V panic("not reached") } -func (s *state) evalUnsignedInteger(v reflect.Value, typ reflect.Type, n node) reflect.Value { +func (s *state) evalUnsignedInteger(typ reflect.Type, n node) reflect.Value { if n, ok := n.(*numberNode); ok && n.isUint { value := reflect.New(typ).Elem() value.SetUint(n.uint64) @@ -366,7 +392,7 @@ func (s *state) evalUnsignedInteger(v reflect.Value, typ reflect.Type, n node) r panic("not reached") } -func (s *state) evalFloat(v reflect.Value, typ reflect.Type, n node) reflect.Value { +func (s *state) evalFloat(typ reflect.Type, n node) reflect.Value { if n, ok := n.(*numberNode); ok && n.isFloat { value := reflect.New(typ).Elem() value.SetFloat(n.float64) @@ -376,7 +402,7 @@ func (s *state) evalFloat(v reflect.Value, typ reflect.Type, n node) reflect.Val panic("not reached") } -func (s *state) evalComplex(v reflect.Value, typ reflect.Type, n node) reflect.Value { +func (s *state) evalComplex(typ reflect.Type, n node) reflect.Value { if n, ok := n.(*numberNode); ok && n.isComplex { value := reflect.New(typ).Elem() value.SetComplex(n.complex128) @@ -386,6 +412,34 @@ func (s *state) evalComplex(v reflect.Value, typ reflect.Type, n node) reflect.V panic("not reached") } +func (s *state) evalEmptyInterface(data reflect.Value, typ reflect.Type, n node) reflect.Value { + switch n := n.(type) { + case *boolNode: + return reflect.ValueOf(n.true) + case *fieldNode: + return s.evalFieldNode(data, n, nil, reflect.Value{}) + case *identifierNode: + return s.evalFieldOrCall(data, n.ident, nil, reflect.Value{}) + case *numberNode: + if n.isComplex { + return reflect.ValueOf(n.complex128) + } + if n.isInt { + return reflect.ValueOf(n.int64) + } + if n.isUint { + return reflect.ValueOf(n.uint64) + } + if n.isFloat { + return reflect.ValueOf(n.float64) + } + case *stringNode: + return reflect.ValueOf(n.text) + } + s.errorf("can't handle assignment of %s to empty interface argument", n) + panic("not reached") +} + // printValue writes the textual representation of the value to the output of // the template. func (s *state) printValue(n node, v reflect.Value) { diff --git a/src/pkg/exp/template/exec_test.go b/src/pkg/exp/template/exec_test.go index 6e4da692e9..8784a0b9fd 100644 --- a/src/pkg/exp/template/exec_test.go +++ b/src/pkg/exp/template/exec_test.go @@ -140,6 +140,16 @@ var execTests = []execTest{ {"if slice", "{{if .SI}}NON-EMPTY{{else}}EMPTY{{end}}", "NON-EMPTY", tVal, true}, {"if emptymap", "{{if .MSIEmpty}}NON-EMPTY{{else}}EMPTY{{end}}", "EMPTY", tVal, true}, {"if map", "{{if .MSI}}NON-EMPTY{{else}}EMPTY{{end}}", "NON-EMPTY", tVal, true}, + // Function calls. + {"printf", `{{printf "hello, printf"}}`, "hello, printf", tVal, true}, + {"printf int", `{{printf "%04x" 127}}`, "007f", tVal, true}, + {"printf float", `{{printf "%g" 3.5}}`, "3.5", tVal, true}, + {"printf complex", `{{printf "%g" 1+7i}}`, "(1+7i)", tVal, true}, + {"printf string", `{{printf "%s" "hello"}}`, "hello", tVal, true}, + {"printf function", `{{printf "%#q" gopher}}`, "`gopher`", tVal, true}, + {"printf field", `{{printf "%s" .U.V}}`, "v", tVal, true}, + {"printf method", `{{printf "%s" .Method0}}`, "resultOfMethod0", tVal, true}, + {"printf lots", `{{printf "%d %s %g %s" 127 "hello" 7-3i .Method0}}`, "127 hello (7-3i) resultOfMethod0", tVal, true}, // With. {"with true", "{{with true}}{{.}}{{end}}", "true", tVal, true}, {"with false", "{{with false}}{{.}}{{else}}FALSE{{end}}", "FALSE", tVal, true}, @@ -171,10 +181,15 @@ var execTests = []execTest{ {"error method, no error", "{{.EPERM false}}", "false", tVal, true}, } +func gopher() string { + return "gopher" +} + func testExecute(execTests []execTest, set *Set, t *testing.T) { b := new(bytes.Buffer) + funcs := FuncMap{"gopher": gopher} for _, test := range execTests { - tmpl := New(test.name) + tmpl := New(test.name).Funcs(funcs) err := tmpl.Parse(test.input) if err != nil { t.Errorf("%s: parse error: %s", test.name, err) diff --git a/src/pkg/exp/template/funcs.go b/src/pkg/exp/template/funcs.go new file mode 100644 index 0000000000..88f82f3b2c --- /dev/null +++ b/src/pkg/exp/template/funcs.go @@ -0,0 +1,63 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package template + +import ( + "fmt" + "reflect" +) + +// FuncMap is the type of the map defining the mapping from names to functions. +// Each function must have either a single return value, or two return values of +// which the second has type os.Error. +type FuncMap map[string]interface{} + +var funcs = map[string]reflect.Value{ + "printf": reflect.ValueOf(fmt.Sprintf), +} + +// addFuncs adds to values the functions in funcs, converting them to reflect.Values. +func addFuncs(values map[string]reflect.Value, funcMap FuncMap) { + for name, fn := range funcMap { + v := reflect.ValueOf(fn) + if v.Kind() != reflect.Func { + panic("value for " + name + " not a function") + } + if !goodFunc(v.Type()) { + panic(fmt.Errorf("can't handle multiple results from method/function %q", name)) + } + values[name] = v + } +} + +// goodFunc checks that the function or method has the right result signature. +func goodFunc(typ reflect.Type) bool { + // We allow functions with 1 result or 2 results where the second is an os.Error. + switch { + case typ.NumOut() == 1: + return true + case typ.NumOut() == 2 && typ.Out(1) == osErrorType: + return true + } + return false +} + +// findFunction looks for a function in the template, set, and global map. +func findFunction(name string, tmpl *Template, set *Set) (reflect.Value, bool) { + if tmpl != nil { + if fn := tmpl.funcs[name]; fn.IsValid() { + return fn, true + } + } + if set != nil { + if fn := set.funcs[name]; fn.IsValid() { + return fn, true + } + } + if fn := funcs[name]; fn.IsValid() { + return fn, true + } + return reflect.Value{}, false +} diff --git a/src/pkg/exp/template/parse.go b/src/pkg/exp/template/parse.go index 74b5f2c0ae..aaed411d49 100644 --- a/src/pkg/exp/template/parse.go +++ b/src/pkg/exp/template/parse.go @@ -8,6 +8,7 @@ import ( "bytes" "fmt" "os" + "reflect" "runtime" "strconv" "strings" @@ -16,8 +17,9 @@ import ( // Template is the representation of a parsed template. type Template struct { - name string - root *listNode + name string + root *listNode + funcs map[string]reflect.Value // Parsing. lex *lexer tokens <-chan item @@ -450,12 +452,23 @@ func (w *withNode) String() string { // New allocates a new template with the given name. func New(name string) *Template { return &Template{ - name: name, + name: name, + funcs: make(map[string]reflect.Value), } } +// Funcs adds to the template's function map the elements of the +// argument map. It panics if a value in the map is not a function +// with appropriate return type. +// The return value is the template, so calls can be chained. +func (t *Template) Funcs(funcMap FuncMap) *Template { + addFuncs(t.funcs, funcMap) + return t +} + // errorf formats the error and terminates processing. func (t *Template) errorf(format string, args ...interface{}) { + t.root = nil format = fmt.Sprintf("template: %s:%d: %s", t.name, t.lex.lineNumber(), format) panic(fmt.Errorf(format, args...)) } @@ -488,7 +501,6 @@ func (t *Template) recover(errp *os.Error) { panic(e) } t.stopParse() - t.root = nil *errp = e.(os.Error) } return diff --git a/src/pkg/exp/template/set.go b/src/pkg/exp/template/set.go index 13d93d03ca..3aaabaad5a 100644 --- a/src/pkg/exp/template/set.go +++ b/src/pkg/exp/template/set.go @@ -6,6 +6,7 @@ package template import ( "os" + "reflect" "runtime" "strconv" ) @@ -13,16 +14,27 @@ import ( // Set holds a set of related templates that can refer to one another by name. // A template may be a member of multiple sets. type Set struct { - tmpl map[string]*Template + tmpl map[string]*Template + funcs map[string]reflect.Value } // NewSet allocates a new, empty template set. func NewSet() *Set { return &Set{ - tmpl: make(map[string]*Template), + tmpl: make(map[string]*Template), + funcs: make(map[string]reflect.Value), } } +// Funcs adds to the set's function map the elements of the +// argument map. It panics if a value in the map is not a function +// with appropriate return type. +// The return value is the set, so calls can be chained. +func (s *Set) Funcs(funcMap FuncMap) *Set { + addFuncs(s.funcs, funcMap) + return s +} + // recover is the handler that turns panics into returns from the top // level of Parse. func (s *Set) recover(errp *os.Error) {