diff --git a/src/cmd/compile/internal/ssa/fuse.go b/src/cmd/compile/internal/ssa/fuse.go index 236d5bbc55e..fec2ba87737 100644 --- a/src/cmd/compile/internal/ssa/fuse.go +++ b/src/cmd/compile/internal/ssa/fuse.go @@ -11,8 +11,8 @@ import ( // fuseEarly runs fuse(f, fuseTypePlain|fuseTypeIntInRange). func fuseEarly(f *Func) { fuse(f, fuseTypePlain|fuseTypeIntInRange) } -// fuseLate runs fuse(f, fuseTypePlain|fuseTypeIf). -func fuseLate(f *Func) { fuse(f, fuseTypePlain|fuseTypeIf) } +// fuseLate runs fuse(f, fuseTypePlain|fuseTypeIf|fuseTypeBranchRedirect). +func fuseLate(f *Func) { fuse(f, fuseTypePlain|fuseTypeIf|fuseTypeBranchRedirect) } type fuseType uint8 @@ -20,6 +20,7 @@ const ( fuseTypePlain fuseType = 1 << iota fuseTypeIf fuseTypeIntInRange + fuseTypeBranchRedirect fuseTypeShortCircuit ) @@ -43,6 +44,9 @@ func fuse(f *Func, typ fuseType) { changed = shortcircuitBlock(b) || changed } } + if typ&fuseTypeBranchRedirect != 0 { + changed = fuseBranchRedirect(f) || changed + } if changed { f.invalidateCFG() } diff --git a/src/cmd/compile/internal/ssa/fuse_branchredirect.go b/src/cmd/compile/internal/ssa/fuse_branchredirect.go new file mode 100644 index 00000000000..1b8b307bcac --- /dev/null +++ b/src/cmd/compile/internal/ssa/fuse_branchredirect.go @@ -0,0 +1,110 @@ +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssa + +// fuseBranchRedirect checks for a CFG in which the outbound branch +// of an If block can be derived from its predecessor If block, in +// some such cases, we can redirect the predecessor If block to the +// corresponding successor block directly. For example: +// p: +// v11 = Less64 v10 v8 +// If v11 goto b else u +// b: <- p ... +// v17 = Leq64 v10 v8 +// If v17 goto s else o +// We can redirect p to s directly. +// +// The implementation here borrows the framework of the prove pass. +// 1, Traverse all blocks of function f to find If blocks. +// 2, For any If block b, traverse all its predecessors to find If blocks. +// 3, For any If block predecessor p, update relationship p->b. +// 4, Traverse all successors of b. +// 5, For any successor s of b, try to update relationship b->s, if a +// contradiction is found then redirect p to another successor of b. +func fuseBranchRedirect(f *Func) bool { + ft := newFactsTable(f) + ft.checkpoint() + + changed := false + for i := len(f.Blocks) - 1; i >= 0; i-- { + b := f.Blocks[i] + if b.Kind != BlockIf { + continue + } + // b is either empty or only contains the control value. + // TODO: if b contains only OpCopy or OpNot related to b.Controls, + // such as Copy(Not(Copy(Less64(v1, v2)))), perhaps it can be optimized. + bCtl := b.Controls[0] + if bCtl.Block != b && len(b.Values) != 0 || (len(b.Values) != 1 || bCtl.Uses != 1) && bCtl.Block == b { + continue + } + + for k := 0; k < len(b.Preds); k++ { + pk := b.Preds[k] + p := pk.b + if p.Kind != BlockIf || p == b { + continue + } + pbranch := positive + if pk.i == 1 { + pbranch = negative + } + ft.checkpoint() + // Assume branch p->b is taken. + addBranchRestrictions(ft, p, pbranch) + // Check if any outgoing branch is unreachable based on the above condition. + parent := b + for j, bbranch := range [...]branch{positive, negative} { + ft.checkpoint() + // Try to update relationship b->child, and check if the contradiction occurs. + addBranchRestrictions(ft, parent, bbranch) + unsat := ft.unsat + ft.restore() + if !unsat { + continue + } + // This branch is impossible,so redirect p directly to another branch. + out := 1 ^ j + child := parent.Succs[out].b + if child == b { + continue + } + b.removePred(k) + p.Succs[pk.i] = Edge{child, len(child.Preds)} + // Fix up Phi value in b to have one less argument. + for _, v := range b.Values { + if v.Op != OpPhi { + continue + } + v.RemoveArg(k) + phielimValue(v) + } + // Fix up child to have one more predecessor. + child.Preds = append(child.Preds, Edge{p, pk.i}) + ai := b.Succs[out].i + for _, v := range child.Values { + if v.Op != OpPhi { + continue + } + v.AddArg(v.Args[ai]) + } + if b.Func.pass.debug > 0 { + b.Func.Warnl(b.Controls[0].Pos, "Redirect %s based on %s", b.Controls[0].Op, p.Controls[0].Op) + } + changed = true + k-- + break + } + ft.restore() + } + if len(b.Preds) == 0 && b != f.Entry { + // Block is now dead. + b.Kind = BlockInvalid + } + } + ft.restore() + ft.cleanup(f) + return changed +} diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index bcfdfc13f04..b203584c6b4 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -726,6 +726,20 @@ var ( } ) +// cleanup returns the posets to the free list +func (ft *factsTable) cleanup(f *Func) { + for _, po := range []*poset{ft.orderS, ft.orderU} { + // Make sure it's empty as it should be. A non-empty poset + // might cause errors and miscompilations if reused. + if checkEnabled { + if err := po.CheckEmpty(); err != nil { + f.Fatalf("poset not empty after function %s: %v", f.Name, err) + } + } + f.retPoset(po) + } +} + // prove removes redundant BlockIf branches that can be inferred // from previous dominating comparisons. // @@ -917,17 +931,7 @@ func prove(f *Func) { ft.restore() - // Return the posets to the free list - for _, po := range []*poset{ft.orderS, ft.orderU} { - // Make sure it's empty as it should be. A non-empty poset - // might cause errors and miscompilations if reused. - if checkEnabled { - if err := po.CheckEmpty(); err != nil { - f.Fatalf("prove poset not empty after function %s: %v", f.Name, err) - } - } - f.retPoset(po) - } + ft.cleanup(f) } // getBranch returns the range restrictions added by p diff --git a/test/fuse.go b/test/fuse.go new file mode 100644 index 00000000000..7d39c3cdb94 --- /dev/null +++ b/test/fuse.go @@ -0,0 +1,190 @@ +// +build amd64 arm64 +// errorcheck -0 -d=ssa/late_fuse/debug=1 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "strings" + +const Cf2 = 2.0 + +func fEqEq(a int, f float64) bool { + return a == 0 && f > Cf2 || a == 0 && f < -Cf2 // ERROR "Redirect Eq64 based on Eq64$" +} + +func fEqNeq(a int32, f float64) bool { + return a == 0 && f > Cf2 || a != 0 && f < -Cf2 // ERROR "Redirect Neq32 based on Eq32$" +} + +func fEqLess(a int8, f float64) bool { + return a == 0 && f > Cf2 || a < 0 && f < -Cf2 +} + +func fEqLeq(a float64, f float64) bool { + return a == 0 && f > Cf2 || a <= 0 && f < -Cf2 +} + +func fEqLessU(a uint, f float64) bool { + return a == 0 && f > Cf2 || a < 0 && f < -Cf2 +} + +func fEqLeqU(a uint64, f float64) bool { + return a == 0 && f > Cf2 || a <= 0 && f < -Cf2 // ERROR "Redirect Leq64U based on Eq64$" +} + +func fNeqEq(a int, f float64) bool { + return a != 0 && f > Cf2 || a == 0 && f < -Cf2 // ERROR "Redirect Eq64 based on Neq64$" +} + +func fNeqNeq(a int32, f float64) bool { + return a != 0 && f > Cf2 || a != 0 && f < -Cf2 // ERROR "Redirect Neq32 based on Neq32$" +} + +func fNeqLess(a float32, f float64) bool { + // TODO: Add support for floating point numbers in prove + return a != 0 && f > Cf2 || a < 0 && f < -Cf2 +} + +func fNeqLeq(a int16, f float64) bool { + return a != 0 && f > Cf2 || a <= 0 && f < -Cf2 // ERROR "Redirect Leq16 based on Neq16$" +} + +func fNeqLessU(a uint, f float64) bool { + return a != 0 && f > Cf2 || a < 0 && f < -Cf2 +} + +func fNeqLeqU(a uint32, f float64) bool { + return a != 0 && f > Cf2 || a <= 0 && f < -Cf2 // ERROR "Redirect Leq32U based on Neq32$" +} + +func fLessEq(a int, f float64) bool { + return a < 0 && f > Cf2 || a == 0 && f < -Cf2 +} + +func fLessNeq(a int32, f float64) bool { + return a < 0 && f > Cf2 || a != 0 && f < -Cf2 +} + +func fLessLess(a float32, f float64) bool { + return a < 0 && f > Cf2 || a < 0 && f < -Cf2 // ERROR "Redirect Less32F based on Less32F$" +} + +func fLessLeq(a float64, f float64) bool { + return a < 0 && f > Cf2 || a <= 0 && f < -Cf2 +} + +func fLeqEq(a float64, f float64) bool { + return a <= 0 && f > Cf2 || a == 0 && f < -Cf2 +} + +func fLeqNeq(a int16, f float64) bool { + return a <= 0 && f > Cf2 || a != 0 && f < -Cf2 // ERROR "Redirect Neq16 based on Leq16$" +} + +func fLeqLess(a float32, f float64) bool { + return a <= 0 && f > Cf2 || a < 0 && f < -Cf2 +} + +func fLeqLeq(a int8, f float64) bool { + return a <= 0 && f > Cf2 || a <= 0 && f < -Cf2 // ERROR "Redirect Leq8 based on Leq8$" +} + +func fLessUEq(a uint8, f float64) bool { + return a < 0 && f > Cf2 || a == 0 && f < -Cf2 +} + +func fLessUNeq(a uint16, f float64) bool { + return a < 0 && f > Cf2 || a != 0 && f < -Cf2 +} + +func fLessULessU(a uint32, f float64) bool { + return a < 0 && f > Cf2 || a < 0 && f < -Cf2 +} + +func fLessULeqU(a uint64, f float64) bool { + return a < 0 && f > Cf2 || a <= 0 && f < -Cf2 +} + +func fLeqUEq(a uint8, f float64) bool { + return a <= 0 && f > Cf2 || a == 0 && f < -Cf2 // ERROR "Redirect Eq8 based on Leq8U$" +} + +func fLeqUNeq(a uint16, f float64) bool { + return a <= 0 && f > Cf2 || a != 0 && f < -Cf2 // ERROR "Redirect Neq16 based on Leq16U$" +} + +func fLeqLessU(a uint32, f float64) bool { + return a <= 0 && f > Cf2 || a < 0 && f < -Cf2 +} + +func fLeqLeqU(a uint64, f float64) bool { + return a <= 0 && f > Cf2 || a <= 0 && f < -Cf2 // ERROR "Redirect Leq64U based on Leq64U$" +} + +// Arg tests are disabled because the op name is different on amd64 and arm64. + +func fEqPtrEqPtr(a, b *int, f float64) bool { + return a == b && f > Cf2 || a == b && f < -Cf2 // ERROR "Redirect EqPtr based on EqPtr$" +} + +func fEqPtrNeqPtr(a, b *int, f float64) bool { + return a == b && f > Cf2 || a != b && f < -Cf2 // ERROR "Redirect NeqPtr based on EqPtr$" +} + +func fNeqPtrEqPtr(a, b *int, f float64) bool { + return a != b && f > Cf2 || a == b && f < -Cf2 // ERROR "Redirect EqPtr based on NeqPtr$" +} + +func fNeqPtrNeqPtr(a, b *int, f float64) bool { + return a != b && f > Cf2 || a != b && f < -Cf2 // ERROR "Redirect NeqPtr based on NeqPtr$" +} + +func fEqInterEqInter(a interface{}, f float64) bool { + return a == nil && f > Cf2 || a == nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil$" +} + +func fEqInterNeqInter(a interface{}, f float64) bool { + return a == nil && f > Cf2 || a != nil && f < -Cf2 +} + +func fNeqInterEqInter(a interface{}, f float64) bool { + return a != nil && f > Cf2 || a == nil && f < -Cf2 +} + +func fNeqInterNeqInter(a interface{}, f float64) bool { + return a != nil && f > Cf2 || a != nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil$" +} + +func fEqSliceEqSlice(a []int, f float64) bool { + return a == nil && f > Cf2 || a == nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil$" +} + +func fEqSliceNeqSlice(a []int, f float64) bool { + return a == nil && f > Cf2 || a != nil && f < -Cf2 +} + +func fNeqSliceEqSlice(a []int, f float64) bool { + return a != nil && f > Cf2 || a == nil && f < -Cf2 +} + +func fNeqSliceNeqSlice(a []int, f float64) bool { + return a != nil && f > Cf2 || a != nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil$" +} + +func fPhi(a, b string) string { + aslash := strings.HasSuffix(a, "/") // ERROR "Redirect Phi based on Phi$" + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +func main() { +}