1
0
mirror of https://github.com/golang/go synced 2024-11-21 15:14:43 -07:00

exp/template: functions

Add the ability to attach functions to template and template sets.
Make variadic functions and methods work.
Still to come: static checking of function names during parse.

R=golang-dev, dsymonds
CC=golang-dev
https://golang.org/cl/4643068
This commit is contained in:
Rob Pike 2011-07-05 14:23:51 +10:00
parent 9cf37c3723
commit b177c97803
6 changed files with 196 additions and 39 deletions

View File

@ -7,6 +7,7 @@ include ../../../Make.inc
TARG=exp/template
GOFILES=\
exec.go\
funcs.go\
lex.go\
parse.go\
set.go\

View File

@ -174,8 +174,11 @@ func (s *state) evalPipeline(data reflect.Value, pipe []*commandNode) reflect.Va
func (s *state) evalCommand(data reflect.Value, cmd *commandNode, final reflect.Value) reflect.Value {
firstWord := cmd.args[0]
if field, ok := firstWord.(*fieldNode); ok {
return s.evalFieldNode(data, field, cmd.args, final)
switch n := firstWord.(type) {
case *fieldNode:
return s.evalFieldNode(data, n, cmd.args, final)
case *identifierNode:
return s.evalFieldOrCall(data, n.ident, cmd.args, final)
}
if len(cmd.args) > 1 || final.IsValid() {
// TODO: functions
@ -215,7 +218,7 @@ func (s *state) evalFieldNode(data reflect.Value, field *fieldNode, args []node,
data = s.evalField(data, field.ident[i])
}
// Now it can be a field or method and if a method, gets arguments.
return s.evalMethodOrField(data, field.ident[n-1], args, final)
return s.evalFieldOrCall(data, field.ident[n-1], args, final)
}
func (s *state) evalField(data reflect.Value, fieldName string) reflect.Value {
@ -238,14 +241,18 @@ func (s *state) evalField(data reflect.Value, fieldName string) reflect.Value {
panic("not reached")
}
func (s *state) evalMethodOrField(data reflect.Value, fieldName string, args []node, final reflect.Value) reflect.Value {
func (s *state) evalFieldOrCall(data reflect.Value, fieldName string, args []node, final reflect.Value) reflect.Value {
// Is it a function?
if function, ok := findFunction(fieldName, s.tmpl, s.set); ok {
return s.evalCall(data, function, fieldName, false, args, final)
}
ptr := data
for data.Kind() == reflect.Ptr {
ptr, data = data, reflect.Indirect(data)
}
// Is it a method? We use the pointer because it has value methods too.
if method, ok := ptr.Type().MethodByName(fieldName); ok {
return s.evalMethod(ptr, method, args, final)
return s.evalCall(ptr, method.Func, fieldName, true, args, final)
}
if len(args) > 1 || final.IsValid() {
s.errorf("%s is not a method but has arguments", fieldName)
@ -263,31 +270,46 @@ var (
osErrorType = reflect.TypeOf(new(os.Error)).Elem()
)
func (s *state) evalMethod(v reflect.Value, method reflect.Method, args []node, final reflect.Value) reflect.Value {
typ := method.Type
fun := method.Func
func (s *state) evalCall(v, fun reflect.Value, name string, isMethod bool, args []node, final reflect.Value) reflect.Value {
typ := fun.Type()
if !isMethod && len(args) > 0 { // Args will be nil if it's a niladic call in an argument list
args = args[1:] // first arg is name of function; not used in call.
}
numIn := len(args)
if final.IsValid() {
numIn++
}
if !typ.IsVariadic() && numIn < typ.NumIn()-1 || !typ.IsVariadic() && numIn != typ.NumIn() {
s.errorf("wrong number of args for %s: want %d got %d", method.Name, typ.NumIn(), len(args))
numFixed := len(args)
if typ.IsVariadic() {
numFixed = typ.NumIn() - 1 // last arg is the variadic one.
if numIn < numFixed {
s.errorf("wrong number of args for %s: want at least %d got %d", name, typ.NumIn()-1, len(args))
}
} else if numIn < typ.NumIn()-1 || !typ.IsVariadic() && numIn != typ.NumIn() {
s.errorf("wrong number of args for %s: want %d got %d", name, typ.NumIn(), len(args))
}
// We allow methods with 1 result or 2 results where the second is an os.Error.
switch {
case typ.NumOut() == 1:
case typ.NumOut() == 2 && typ.Out(1) == osErrorType:
default:
s.errorf("can't handle multiple results from method %q", method.Name)
if !goodFunc(typ) {
s.errorf("can't handle multiple results from method/function %q", name)
}
// Build the arg list.
argv := make([]reflect.Value, numIn)
// First arg is the receiver.
argv[0] = v
// Others must be evaluated.
for i := 1; i < len(args); i++ {
i := 0
if isMethod {
argv[0] = v
i++
}
// Others must be evaluated. Fixed args first.
for ; i < numFixed; i++ {
argv[i] = s.evalArg(v, typ.In(i), args[i])
}
// And now the ... args.
if typ.IsVariadic() {
argType := typ.In(typ.NumIn() - 1).Elem() // Argument is a slice.
for ; i < len(args); i++ {
argv[i] = s.evalArg(v, argType, args[i])
}
}
// Add final value if necessary.
if final.IsValid() {
argv[len(args)] = final
@ -310,23 +332,27 @@ func (s *state) evalArg(data reflect.Value, typ reflect.Type, n node) reflect.Va
}
switch typ.Kind() {
case reflect.Bool:
return s.evalBool(data, typ, n)
return s.evalBool(typ, n)
case reflect.String:
return s.evalString(data, typ, n)
return s.evalString(typ, n)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return s.evalInteger(data, typ, n)
return s.evalInteger(typ, n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
return s.evalUnsignedInteger(data, typ, n)
return s.evalUnsignedInteger(typ, n)
case reflect.Float32, reflect.Float64:
return s.evalFloat(data, typ, n)
return s.evalFloat(typ, n)
case reflect.Complex64, reflect.Complex128:
return s.evalComplex(data, typ, n)
return s.evalComplex(typ, n)
case reflect.Interface:
if typ.NumMethod() == 0 {
return s.evalEmptyInterface(data, typ, n)
}
}
s.errorf("can't handle node %s for method arg of type %s", n, typ)
s.errorf("can't handle %s for arg of type %s", n, typ)
panic("not reached")
}
func (s *state) evalBool(v reflect.Value, typ reflect.Type, n node) reflect.Value {
func (s *state) evalBool(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*boolNode); ok {
value := reflect.New(typ).Elem()
value.SetBool(n.true)
@ -336,7 +362,7 @@ func (s *state) evalBool(v reflect.Value, typ reflect.Type, n node) reflect.Valu
panic("not reached")
}
func (s *state) evalString(v reflect.Value, typ reflect.Type, n node) reflect.Value {
func (s *state) evalString(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*stringNode); ok {
value := reflect.New(typ).Elem()
value.SetString(n.text)
@ -346,7 +372,7 @@ func (s *state) evalString(v reflect.Value, typ reflect.Type, n node) reflect.Va
panic("not reached")
}
func (s *state) evalInteger(v reflect.Value, typ reflect.Type, n node) reflect.Value {
func (s *state) evalInteger(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*numberNode); ok && n.isInt {
value := reflect.New(typ).Elem()
value.SetInt(n.int64)
@ -356,7 +382,7 @@ func (s *state) evalInteger(v reflect.Value, typ reflect.Type, n node) reflect.V
panic("not reached")
}
func (s *state) evalUnsignedInteger(v reflect.Value, typ reflect.Type, n node) reflect.Value {
func (s *state) evalUnsignedInteger(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*numberNode); ok && n.isUint {
value := reflect.New(typ).Elem()
value.SetUint(n.uint64)
@ -366,7 +392,7 @@ func (s *state) evalUnsignedInteger(v reflect.Value, typ reflect.Type, n node) r
panic("not reached")
}
func (s *state) evalFloat(v reflect.Value, typ reflect.Type, n node) reflect.Value {
func (s *state) evalFloat(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*numberNode); ok && n.isFloat {
value := reflect.New(typ).Elem()
value.SetFloat(n.float64)
@ -376,7 +402,7 @@ func (s *state) evalFloat(v reflect.Value, typ reflect.Type, n node) reflect.Val
panic("not reached")
}
func (s *state) evalComplex(v reflect.Value, typ reflect.Type, n node) reflect.Value {
func (s *state) evalComplex(typ reflect.Type, n node) reflect.Value {
if n, ok := n.(*numberNode); ok && n.isComplex {
value := reflect.New(typ).Elem()
value.SetComplex(n.complex128)
@ -386,6 +412,34 @@ func (s *state) evalComplex(v reflect.Value, typ reflect.Type, n node) reflect.V
panic("not reached")
}
func (s *state) evalEmptyInterface(data reflect.Value, typ reflect.Type, n node) reflect.Value {
switch n := n.(type) {
case *boolNode:
return reflect.ValueOf(n.true)
case *fieldNode:
return s.evalFieldNode(data, n, nil, reflect.Value{})
case *identifierNode:
return s.evalFieldOrCall(data, n.ident, nil, reflect.Value{})
case *numberNode:
if n.isComplex {
return reflect.ValueOf(n.complex128)
}
if n.isInt {
return reflect.ValueOf(n.int64)
}
if n.isUint {
return reflect.ValueOf(n.uint64)
}
if n.isFloat {
return reflect.ValueOf(n.float64)
}
case *stringNode:
return reflect.ValueOf(n.text)
}
s.errorf("can't handle assignment of %s to empty interface argument", n)
panic("not reached")
}
// printValue writes the textual representation of the value to the output of
// the template.
func (s *state) printValue(n node, v reflect.Value) {

View File

@ -140,6 +140,16 @@ var execTests = []execTest{
{"if slice", "{{if .SI}}NON-EMPTY{{else}}EMPTY{{end}}", "NON-EMPTY", tVal, true},
{"if emptymap", "{{if .MSIEmpty}}NON-EMPTY{{else}}EMPTY{{end}}", "EMPTY", tVal, true},
{"if map", "{{if .MSI}}NON-EMPTY{{else}}EMPTY{{end}}", "NON-EMPTY", tVal, true},
// Function calls.
{"printf", `{{printf "hello, printf"}}`, "hello, printf", tVal, true},
{"printf int", `{{printf "%04x" 127}}`, "007f", tVal, true},
{"printf float", `{{printf "%g" 3.5}}`, "3.5", tVal, true},
{"printf complex", `{{printf "%g" 1+7i}}`, "(1+7i)", tVal, true},
{"printf string", `{{printf "%s" "hello"}}`, "hello", tVal, true},
{"printf function", `{{printf "%#q" gopher}}`, "`gopher`", tVal, true},
{"printf field", `{{printf "%s" .U.V}}`, "v", tVal, true},
{"printf method", `{{printf "%s" .Method0}}`, "resultOfMethod0", tVal, true},
{"printf lots", `{{printf "%d %s %g %s" 127 "hello" 7-3i .Method0}}`, "127 hello (7-3i) resultOfMethod0", tVal, true},
// With.
{"with true", "{{with true}}{{.}}{{end}}", "true", tVal, true},
{"with false", "{{with false}}{{.}}{{else}}FALSE{{end}}", "FALSE", tVal, true},
@ -171,10 +181,15 @@ var execTests = []execTest{
{"error method, no error", "{{.EPERM false}}", "false", tVal, true},
}
func gopher() string {
return "gopher"
}
func testExecute(execTests []execTest, set *Set, t *testing.T) {
b := new(bytes.Buffer)
funcs := FuncMap{"gopher": gopher}
for _, test := range execTests {
tmpl := New(test.name)
tmpl := New(test.name).Funcs(funcs)
err := tmpl.Parse(test.input)
if err != nil {
t.Errorf("%s: parse error: %s", test.name, err)

View File

@ -0,0 +1,63 @@
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package template
import (
"fmt"
"reflect"
)
// FuncMap is the type of the map defining the mapping from names to functions.
// Each function must have either a single return value, or two return values of
// which the second has type os.Error.
type FuncMap map[string]interface{}
var funcs = map[string]reflect.Value{
"printf": reflect.ValueOf(fmt.Sprintf),
}
// addFuncs adds to values the functions in funcs, converting them to reflect.Values.
func addFuncs(values map[string]reflect.Value, funcMap FuncMap) {
for name, fn := range funcMap {
v := reflect.ValueOf(fn)
if v.Kind() != reflect.Func {
panic("value for " + name + " not a function")
}
if !goodFunc(v.Type()) {
panic(fmt.Errorf("can't handle multiple results from method/function %q", name))
}
values[name] = v
}
}
// goodFunc checks that the function or method has the right result signature.
func goodFunc(typ reflect.Type) bool {
// We allow functions with 1 result or 2 results where the second is an os.Error.
switch {
case typ.NumOut() == 1:
return true
case typ.NumOut() == 2 && typ.Out(1) == osErrorType:
return true
}
return false
}
// findFunction looks for a function in the template, set, and global map.
func findFunction(name string, tmpl *Template, set *Set) (reflect.Value, bool) {
if tmpl != nil {
if fn := tmpl.funcs[name]; fn.IsValid() {
return fn, true
}
}
if set != nil {
if fn := set.funcs[name]; fn.IsValid() {
return fn, true
}
}
if fn := funcs[name]; fn.IsValid() {
return fn, true
}
return reflect.Value{}, false
}

View File

@ -8,6 +8,7 @@ import (
"bytes"
"fmt"
"os"
"reflect"
"runtime"
"strconv"
"strings"
@ -16,8 +17,9 @@ import (
// Template is the representation of a parsed template.
type Template struct {
name string
root *listNode
name string
root *listNode
funcs map[string]reflect.Value
// Parsing.
lex *lexer
tokens <-chan item
@ -450,12 +452,23 @@ func (w *withNode) String() string {
// New allocates a new template with the given name.
func New(name string) *Template {
return &Template{
name: name,
name: name,
funcs: make(map[string]reflect.Value),
}
}
// Funcs adds to the template's function map the elements of the
// argument map. It panics if a value in the map is not a function
// with appropriate return type.
// The return value is the template, so calls can be chained.
func (t *Template) Funcs(funcMap FuncMap) *Template {
addFuncs(t.funcs, funcMap)
return t
}
// errorf formats the error and terminates processing.
func (t *Template) errorf(format string, args ...interface{}) {
t.root = nil
format = fmt.Sprintf("template: %s:%d: %s", t.name, t.lex.lineNumber(), format)
panic(fmt.Errorf(format, args...))
}
@ -488,7 +501,6 @@ func (t *Template) recover(errp *os.Error) {
panic(e)
}
t.stopParse()
t.root = nil
*errp = e.(os.Error)
}
return

View File

@ -6,6 +6,7 @@ package template
import (
"os"
"reflect"
"runtime"
"strconv"
)
@ -13,16 +14,27 @@ import (
// Set holds a set of related templates that can refer to one another by name.
// A template may be a member of multiple sets.
type Set struct {
tmpl map[string]*Template
tmpl map[string]*Template
funcs map[string]reflect.Value
}
// NewSet allocates a new, empty template set.
func NewSet() *Set {
return &Set{
tmpl: make(map[string]*Template),
tmpl: make(map[string]*Template),
funcs: make(map[string]reflect.Value),
}
}
// Funcs adds to the set's function map the elements of the
// argument map. It panics if a value in the map is not a function
// with appropriate return type.
// The return value is the set, so calls can be chained.
func (s *Set) Funcs(funcMap FuncMap) *Set {
addFuncs(s.funcs, funcMap)
return s
}
// recover is the handler that turns panics into returns from the top
// level of Parse.
func (s *Set) recover(errp *os.Error) {