From 5a4db102b21489c39b3a654e06cc25155432a38a Mon Sep 17 00:00:00 2001 From: Ian Lance Taylor Date: Wed, 2 Dec 2020 19:18:46 -0800 Subject: [PATCH] html/template: avoid race when escaping updates template Fixes #39807 Change-Id: Icf384f800e2541bc753507daa3a9bc7e5d1c3f79 Reviewed-on: https://go-review.googlesource.com/c/go/+/274450 Trust: Ian Lance Taylor Run-TryBot: Ian Lance Taylor TryBot-Result: Go Bot Reviewed-by: Roberto Clapis Reviewed-by: Russ Cox --- src/html/template/exec_test.go | 70 ++++++++++++++++++++++++++ src/html/template/template.go | 90 +++++++++++++++++++++++++++++----- 2 files changed, 147 insertions(+), 13 deletions(-) diff --git a/src/html/template/exec_test.go b/src/html/template/exec_test.go index 232945a0bb1..eb008242603 100644 --- a/src/html/template/exec_test.go +++ b/src/html/template/exec_test.go @@ -14,6 +14,7 @@ import ( "io" "reflect" "strings" + "sync" "testing" "text/template" ) @@ -1706,3 +1707,72 @@ func TestIssue31810(t *testing.T) { t.Errorf("%s got %q, expected %q", textCall, b.String(), "result") } } + +// Issue 39807. There was a race applying escapeTemplate. + +const raceText = ` +{{- define "jstempl" -}} +var v = "v"; +{{- end -}} + +` + +func TestEscapeRace(t *testing.T) { + tmpl := New("") + _, err := tmpl.New("templ.html").Parse(raceText) + if err != nil { + t.Fatal(err) + } + const count = 20 + for i := 0; i < count; i++ { + _, err := tmpl.New(fmt.Sprintf("x%d.html", i)).Parse(`{{ template "templ.html" .}}`) + if err != nil { + t.Fatal(err) + } + } + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < count; j++ { + sub := tmpl.Lookup(fmt.Sprintf("x%d.html", j)) + if err := sub.Execute(io.Discard, nil); err != nil { + t.Error(err) + } + } + }() + } + wg.Wait() +} + +func TestRecursiveExecute(t *testing.T) { + tmpl := New("") + + recur := func() (HTML, error) { + var sb strings.Builder + if err := tmpl.ExecuteTemplate(&sb, "subroutine", nil); err != nil { + t.Fatal(err) + } + return HTML(sb.String()), nil + } + + m := FuncMap{ + "recur": recur, + } + + top, err := tmpl.New("x.html").Funcs(m).Parse(`{{recur}}`) + if err != nil { + t.Fatal(err) + } + _, err = tmpl.New("subroutine").Parse(``) + if err != nil { + t.Fatal(err) + } + if err := top.Execute(io.Discard, nil); err != nil { + t.Fatal(err) + } +} diff --git a/src/html/template/template.go b/src/html/template/template.go index 69312d36fdb..09d71d43e2f 100644 --- a/src/html/template/template.go +++ b/src/html/template/template.go @@ -11,6 +11,7 @@ import ( "os" "path" "path/filepath" + "reflect" "sync" "text/template" "text/template/parse" @@ -26,7 +27,9 @@ type Template struct { // template's in sync. text *template.Template // The underlying template's parse tree, updated to be HTML-safe. - Tree *parse.Tree + Tree *parse.Tree + // The original functions, before wrapping. + funcMap FuncMap *nameSpace // common to all associated templates } @@ -35,7 +38,7 @@ var escapeOK = fmt.Errorf("template escaped correctly") // nameSpace is the data structure shared by all templates in an association. type nameSpace struct { - mu sync.Mutex + mu sync.RWMutex set map[string]*Template escaped bool esc escaper @@ -45,8 +48,8 @@ type nameSpace struct { // itself. func (t *Template) Templates() []*Template { ns := t.nameSpace - ns.mu.Lock() - defer ns.mu.Unlock() + ns.mu.RLock() + defer ns.mu.RUnlock() // Return a slice so we don't expose the map. m := make([]*Template, 0, len(ns.set)) for _, v := range ns.set { @@ -84,8 +87,8 @@ func (t *Template) checkCanParse() error { if t == nil { return nil } - t.nameSpace.mu.Lock() - defer t.nameSpace.mu.Unlock() + t.nameSpace.mu.RLock() + defer t.nameSpace.mu.RUnlock() if t.nameSpace.escaped { return fmt.Errorf("html/template: cannot Parse after Execute") } @@ -94,6 +97,16 @@ func (t *Template) checkCanParse() error { // escape escapes all associated templates. func (t *Template) escape() error { + t.nameSpace.mu.RLock() + escapeErr := t.escapeErr + t.nameSpace.mu.RUnlock() + if escapeErr != nil { + if escapeErr == escapeOK { + return nil + } + return escapeErr + } + t.nameSpace.mu.Lock() defer t.nameSpace.mu.Unlock() t.nameSpace.escaped = true @@ -121,6 +134,8 @@ func (t *Template) Execute(wr io.Writer, data interface{}) error { if err := t.escape(); err != nil { return err } + t.nameSpace.mu.RLock() + defer t.nameSpace.mu.RUnlock() return t.text.Execute(wr, data) } @@ -136,6 +151,8 @@ func (t *Template) ExecuteTemplate(wr io.Writer, name string, data interface{}) if err != nil { return err } + t.nameSpace.mu.RLock() + defer t.nameSpace.mu.RUnlock() return tmpl.text.Execute(wr, data) } @@ -143,13 +160,27 @@ func (t *Template) ExecuteTemplate(wr io.Writer, name string, data interface{}) // is escaped, or returns an error if it cannot be. It returns the named // template. func (t *Template) lookupAndEscapeTemplate(name string) (tmpl *Template, err error) { - t.nameSpace.mu.Lock() - defer t.nameSpace.mu.Unlock() - t.nameSpace.escaped = true + t.nameSpace.mu.RLock() tmpl = t.set[name] + var escapeErr error + if tmpl != nil { + escapeErr = tmpl.escapeErr + } + t.nameSpace.mu.RUnlock() + if tmpl == nil { return nil, fmt.Errorf("html/template: %q is undefined", name) } + if escapeErr != nil { + if escapeErr != escapeOK { + return nil, escapeErr + } + return tmpl, nil + } + + t.nameSpace.mu.Lock() + defer t.nameSpace.mu.Unlock() + t.nameSpace.escaped = true if tmpl.escapeErr != nil && tmpl.escapeErr != escapeOK { return nil, tmpl.escapeErr } @@ -229,6 +260,7 @@ func (t *Template) AddParseTree(name string, tree *parse.Tree) (*Template, error nil, text, text.Tree, + nil, t.nameSpace, } t.set[name] = ret @@ -259,8 +291,10 @@ func (t *Template) Clone() (*Template, error) { nil, textClone, textClone.Tree, + t.funcMap, ns, } + ret.wrapFuncs() ret.set[ret.Name()] = ret for _, x := range textClone.Templates() { name := x.Name() @@ -269,12 +303,15 @@ func (t *Template) Clone() (*Template, error) { return nil, fmt.Errorf("html/template: cannot Clone %q after it has executed", t.Name()) } x.Tree = x.Tree.Copy() - ret.set[name] = &Template{ + tc := &Template{ nil, x, x.Tree, + src.funcMap, ret.nameSpace, } + tc.wrapFuncs() + ret.set[name] = tc } // Return the template associated with the name of this template. return ret.set[ret.Name()], nil @@ -288,6 +325,7 @@ func New(name string) *Template { nil, template.New(name), nil, + nil, ns, } tmpl.set[name] = tmpl @@ -313,6 +351,7 @@ func (t *Template) new(name string) *Template { nil, t.text.New(name), nil, + nil, t.nameSpace, } if existing, ok := tmpl.set[name]; ok { @@ -343,10 +382,35 @@ type FuncMap map[string]interface{} // type. However, it is legal to overwrite elements of the map. The return // value is the template, so calls can be chained. func (t *Template) Funcs(funcMap FuncMap) *Template { - t.text.Funcs(template.FuncMap(funcMap)) + t.funcMap = funcMap + t.wrapFuncs() return t } +// wrapFuncs records the functions with text/template. We wrap them to +// unlock the nameSpace. See TestRecursiveExecute for a test case. +func (t *Template) wrapFuncs() { + if len(t.funcMap) == 0 { + return + } + tfuncs := make(template.FuncMap, len(t.funcMap)) + for name, fn := range t.funcMap { + fnv := reflect.ValueOf(fn) + wrapper := func(args []reflect.Value) []reflect.Value { + t.nameSpace.mu.RUnlock() + defer t.nameSpace.mu.RLock() + if fnv.Type().IsVariadic() { + return fnv.CallSlice(args) + } else { + return fnv.Call(args) + } + } + wrapped := reflect.MakeFunc(fnv.Type(), wrapper) + tfuncs[name] = wrapped.Interface() + } + t.text.Funcs(tfuncs) +} + // Delims sets the action delimiters to the specified strings, to be used in // subsequent calls to Parse, ParseFiles, or ParseGlob. Nested template // definitions will inherit the settings. An empty delimiter stands for the @@ -360,8 +424,8 @@ func (t *Template) Delims(left, right string) *Template { // Lookup returns the template with the given name that is associated with t, // or nil if there is no such template. func (t *Template) Lookup(name string) *Template { - t.nameSpace.mu.Lock() - defer t.nameSpace.mu.Unlock() + t.nameSpace.mu.RLock() + defer t.nameSpace.mu.RUnlock() return t.set[name] }