1
0
mirror of https://github.com/golang/go synced 2024-11-22 19:54:39 -07:00

cmd/compile: allow mid-stack inlining when there is a cycle of recursion

We still disallow inlining for an immediately-recursive function, but allow
inlining if a function is in a recursion chain.

If all functions in the recursion chain are simple, then we could inline
forever down the recursion chain (eventually running out of stack on the
compiler), so we add a map to keep track of the functions we have
already inlined at a call site. We stop inlining when we reach a
function that we have already inlined in the recursive chain. Of course,
normally the inlining will have stopped earlier, because of the cost
function.

We could also limit the depth of inlining by a simple count (say, limit
max inlining of 10 at any given site). Would that limit other
opportunities too much?

Added a test in test/inline.go. runtime.BenchmarkStackCopyNoCache() is
also already a good test that triggers the check to stop inlining
when we reach the start of the recursive chain again.

For the bent benchmark suite, the performance improvement was mostly not
statistically significant, but the geomean averaged out to: -0.68%. The text size
increase was less than .1% for all bent benchmarks. The cmd/go text size increase
was 0.02% and the cmd/compile text size increase was .1%.

Fixes #29737

Change-Id: I892fa84bb07a947b3125ec8f25ed0e508bf2bdf5
Reviewed-on: https://go-review.googlesource.com/c/go/+/226818
Run-TryBot: Dan Scales <danscales@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Keith Randall <khr@golang.org>
This commit is contained in:
Dan Scales 2020-03-31 20:24:05 -07:00
parent 339e9c6400
commit ed7a8332c4
4 changed files with 69 additions and 18 deletions

View File

@ -496,7 +496,14 @@ func inlcalls(fn *Node) {
if countNodes(fn) >= inlineBigFunctionNodes { if countNodes(fn) >= inlineBigFunctionNodes {
maxCost = inlineBigFunctionMaxCost maxCost = inlineBigFunctionMaxCost
} }
fn = inlnode(fn, maxCost) // Map to keep track of functions that have been inlined at a particular
// call site, in order to stop inlining when we reach the beginning of a
// recursion cycle again. We don't inline immediately recursive functions,
// but allow inlining if there is a recursion cycle of many functions.
// Most likely, the inlining will stop before we even hit the beginning of
// the cycle again, but the map catches the unusual case.
inlMap := make(map[*Node]bool)
fn = inlnode(fn, maxCost, inlMap)
if fn != Curfn { if fn != Curfn {
Fatalf("inlnode replaced curfn") Fatalf("inlnode replaced curfn")
} }
@ -537,10 +544,10 @@ func inlconv2list(n *Node) []*Node {
return s return s
} }
func inlnodelist(l Nodes, maxCost int32) { func inlnodelist(l Nodes, maxCost int32, inlMap map[*Node]bool) {
s := l.Slice() s := l.Slice()
for i := range s { for i := range s {
s[i] = inlnode(s[i], maxCost) s[i] = inlnode(s[i], maxCost, inlMap)
} }
} }
@ -557,7 +564,7 @@ func inlnodelist(l Nodes, maxCost int32) {
// shorter and less complicated. // shorter and less complicated.
// The result of inlnode MUST be assigned back to n, e.g. // The result of inlnode MUST be assigned back to n, e.g.
// n.Left = inlnode(n.Left) // n.Left = inlnode(n.Left)
func inlnode(n *Node, maxCost int32) *Node { func inlnode(n *Node, maxCost int32, inlMap map[*Node]bool) *Node {
if n == nil { if n == nil {
return n return n
} }
@ -585,19 +592,19 @@ func inlnode(n *Node, maxCost int32) *Node {
lno := setlineno(n) lno := setlineno(n)
inlnodelist(n.Ninit, maxCost) inlnodelist(n.Ninit, maxCost, inlMap)
for _, n1 := range n.Ninit.Slice() { for _, n1 := range n.Ninit.Slice() {
if n1.Op == OINLCALL { if n1.Op == OINLCALL {
inlconv2stmt(n1) inlconv2stmt(n1)
} }
} }
n.Left = inlnode(n.Left, maxCost) n.Left = inlnode(n.Left, maxCost, inlMap)
if n.Left != nil && n.Left.Op == OINLCALL { if n.Left != nil && n.Left.Op == OINLCALL {
n.Left = inlconv2expr(n.Left) n.Left = inlconv2expr(n.Left)
} }
n.Right = inlnode(n.Right, maxCost) n.Right = inlnode(n.Right, maxCost, inlMap)
if n.Right != nil && n.Right.Op == OINLCALL { if n.Right != nil && n.Right.Op == OINLCALL {
if n.Op == OFOR || n.Op == OFORUNTIL { if n.Op == OFOR || n.Op == OFORUNTIL {
inlconv2stmt(n.Right) inlconv2stmt(n.Right)
@ -612,7 +619,7 @@ func inlnode(n *Node, maxCost int32) *Node {
} }
} }
inlnodelist(n.List, maxCost) inlnodelist(n.List, maxCost, inlMap)
if n.Op == OBLOCK { if n.Op == OBLOCK {
for _, n2 := range n.List.Slice() { for _, n2 := range n.List.Slice() {
if n2.Op == OINLCALL { if n2.Op == OINLCALL {
@ -628,7 +635,7 @@ func inlnode(n *Node, maxCost int32) *Node {
} }
} }
inlnodelist(n.Rlist, maxCost) inlnodelist(n.Rlist, maxCost, inlMap)
s := n.Rlist.Slice() s := n.Rlist.Slice()
for i1, n1 := range s { for i1, n1 := range s {
if n1.Op == OINLCALL { if n1.Op == OINLCALL {
@ -640,7 +647,7 @@ func inlnode(n *Node, maxCost int32) *Node {
} }
} }
inlnodelist(n.Nbody, maxCost) inlnodelist(n.Nbody, maxCost, inlMap)
for _, n := range n.Nbody.Slice() { for _, n := range n.Nbody.Slice() {
if n.Op == OINLCALL { if n.Op == OINLCALL {
inlconv2stmt(n) inlconv2stmt(n)
@ -663,12 +670,12 @@ func inlnode(n *Node, maxCost int32) *Node {
fmt.Printf("%v:call to func %+v\n", n.Line(), n.Left) fmt.Printf("%v:call to func %+v\n", n.Line(), n.Left)
} }
if n.Left.Func != nil && n.Left.Func.Inl != nil && !isIntrinsicCall(n) { // normal case if n.Left.Func != nil && n.Left.Func.Inl != nil && !isIntrinsicCall(n) { // normal case
n = mkinlcall(n, n.Left, maxCost) n = mkinlcall(n, n.Left, maxCost, inlMap)
} else if n.Left.isMethodExpression() && asNode(n.Left.Sym.Def) != nil { } else if n.Left.isMethodExpression() && asNode(n.Left.Sym.Def) != nil {
n = mkinlcall(n, asNode(n.Left.Sym.Def), maxCost) n = mkinlcall(n, asNode(n.Left.Sym.Def), maxCost, inlMap)
} else if n.Left.Op == OCLOSURE { } else if n.Left.Op == OCLOSURE {
if f := inlinableClosure(n.Left); f != nil { if f := inlinableClosure(n.Left); f != nil {
n = mkinlcall(n, f, maxCost) n = mkinlcall(n, f, maxCost, inlMap)
} }
} else if n.Left.Op == ONAME && n.Left.Name != nil && n.Left.Name.Defn != nil { } else if n.Left.Op == ONAME && n.Left.Name != nil && n.Left.Name.Defn != nil {
if d := n.Left.Name.Defn; d.Op == OAS && d.Right.Op == OCLOSURE { if d := n.Left.Name.Defn; d.Op == OAS && d.Right.Op == OCLOSURE {
@ -694,7 +701,7 @@ func inlnode(n *Node, maxCost int32) *Node {
} }
break break
} }
n = mkinlcall(n, f, maxCost) n = mkinlcall(n, f, maxCost, inlMap)
} }
} }
} }
@ -713,7 +720,7 @@ func inlnode(n *Node, maxCost int32) *Node {
Fatalf("no function definition for [%p] %+v\n", n.Left.Type, n.Left.Type) Fatalf("no function definition for [%p] %+v\n", n.Left.Type, n.Left.Type)
} }
n = mkinlcall(n, asNode(n.Left.Type.FuncType().Nname), maxCost) n = mkinlcall(n, asNode(n.Left.Type.FuncType().Nname), maxCost, inlMap)
} }
lineno = lno lineno = lno
@ -833,7 +840,7 @@ var inlgen int
// parameters. // parameters.
// The result of mkinlcall MUST be assigned back to n, e.g. // The result of mkinlcall MUST be assigned back to n, e.g.
// n.Left = mkinlcall(n.Left, fn, isddd) // n.Left = mkinlcall(n.Left, fn, isddd)
func mkinlcall(n, fn *Node, maxCost int32) *Node { func mkinlcall(n, fn *Node, maxCost int32, inlMap map[*Node]bool) *Node {
if fn.Func.Inl == nil { if fn.Func.Inl == nil {
// No inlinable body. // No inlinable body.
return n return n
@ -866,6 +873,16 @@ func mkinlcall(n, fn *Node, maxCost int32) *Node {
return n return n
} }
if inlMap[fn] {
if Debug['m'] > 1 {
fmt.Printf("%v: cannot inline %v into %v: repeated recursive cycle\n", n.Line(), fn, Curfn.funcname())
}
return n
}
inlMap[fn] = true
defer func() {
inlMap[fn] = false
}()
if Debug_typecheckinl == 0 { if Debug_typecheckinl == 0 {
typecheckinl(fn) typecheckinl(fn)
} }
@ -1129,7 +1146,7 @@ func mkinlcall(n, fn *Node, maxCost int32) *Node {
// instead we emit the things that the body needs // instead we emit the things that the body needs
// and each use must redo the inlining. // and each use must redo the inlining.
// luckily these are small. // luckily these are small.
inlnodelist(call.Nbody, maxCost) inlnodelist(call.Nbody, maxCost, inlMap)
for _, n := range call.Nbody.Slice() { for _, n := range call.Nbody.Slice() {
if n.Op == OINLCALL { if n.Op == OINLCALL {
inlconv2stmt(n) inlconv2stmt(n)

View File

@ -679,8 +679,12 @@ func Main(archInit func(*Arch)) {
if Debug['l'] != 0 { if Debug['l'] != 0 {
// Find functions that can be inlined and clone them before walk expands them. // Find functions that can be inlined and clone them before walk expands them.
visitBottomUp(xtop, func(list []*Node, recursive bool) { visitBottomUp(xtop, func(list []*Node, recursive bool) {
numfns := numNonClosures(list)
for _, n := range list { for _, n := range list {
if !recursive { if !recursive || numfns > 1 {
// We allow inlining if there is no
// recursion, or the recursion cycle is
// across more than one function.
caninl(n) caninl(n)
} else { } else {
if Debug['m'] > 1 { if Debug['m'] > 1 {
@ -824,6 +828,17 @@ func Main(archInit func(*Arch)) {
} }
} }
// numNonClosures returns the number of functions in list which are not closures.
func numNonClosures(list []*Node) int {
count := 0
for _, n := range list {
if n.Func.Closure == nil {
count++
}
}
return count
}
func writebench(filename string) error { func writebench(filename string) error {
f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666) f, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0666)
if err != nil { if err != nil {

View File

@ -180,3 +180,21 @@ func (T) meth2(int, int) { // not inlineable - has 2 calls.
runtime.GC() runtime.GC()
runtime.GC() runtime.GC()
} }
// Issue #29737 - make sure we can do inlining for a chain of recursive functions
func ee() { // ERROR "can inline ee"
ff(100) // ERROR "inlining call to ff" "inlining call to gg" "inlining call to hh"
}
func ff(x int) { // ERROR "can inline ff"
if x < 0 {
return
}
gg(x - 1)
}
func gg(x int) { // ERROR "can inline gg"
hh(x - 1)
}
func hh(x int) { // ERROR "can inline hh"
ff(x - 1) // ERROR "inlining call to ff" // ERROR "inlining call to gg"
}

View File

@ -67,6 +67,7 @@ func d2() {
d3() d3()
} }
//go:noinline
func d3() { func d3() {
x.f = y // ERROR "write barrier prohibited by caller" x.f = y // ERROR "write barrier prohibited by caller"
d4() d4()