From 30088ac9a391d8505a3e016f36aaa23170109f6f Mon Sep 17 00:00:00 2001 From: Keith Randall Date: Tue, 4 Oct 2016 14:35:45 -0700 Subject: [PATCH] cmd/compile: make CSE faster To refine a set of possibly equivalent values, the old CSE algorithm picked one value, compared it against all the others, and made two sets out of the results (the values that match the picked value and the values that didn't). Unfortunately, this leads to O(n^2) behavior. The picked value ends up being equal to no other values, we make size 1 and size n-1 sets, and then recurse on the size n-1 set. Instead, sort the set by the equivalence classes of its arguments. Then we just look for spots in the sorted list where the equivalence classes of the arguments change. This lets us do a multi-way split for O(n lg n) time. This change makes cmpDepth unnecessary. The refinement portion used to call the type comparator. That is unnecessary as the type was already part of the initial partition. Lowers time of 16361 from 8 sec to 3 sec. Lowers time of 15112 from 282 sec to 20 sec. That's kind of unfair, as CL 30257 changed it from 21 sec to 282 sec. But that CL fixed other bad compile times (issue #17127) by large factors, so net still a big win. Fixes #15112 Fixes #16361 Change-Id: I351ce111bae446608968c6d48710eeb6a3d8e527 Reviewed-on: https://go-review.googlesource.com/30354 Reviewed-by: Todd Neal --- src/cmd/compile/internal/ssa/cse.go | 125 ++++++++++++++++------------ 1 file changed, 73 insertions(+), 52 deletions(-) diff --git a/src/cmd/compile/internal/ssa/cse.go b/src/cmd/compile/internal/ssa/cse.go index 532232de57..24f071bcfd 100644 --- a/src/cmd/compile/internal/ssa/cse.go +++ b/src/cmd/compile/internal/ssa/cse.go @@ -9,10 +9,6 @@ import ( "sort" ) -const ( - cmpDepth = 1 -) - // cse does common-subexpression elimination on the Function. // Values are just relinked, nothing is deleted. A subsequent deadcode // pass is required to actually remove duplicate expressions. @@ -60,7 +56,8 @@ func cse(f *Func) { valueEqClass[v.ID] = -v.ID } } - for i, e := range partition { + var pNum ID = 1 + for _, e := range partition { if f.pass.debug > 1 && len(e) > 500 { fmt.Printf("CSE.large partition (%d): ", len(e)) for j := 0; j < 3; j++ { @@ -70,20 +67,22 @@ func cse(f *Func) { } for _, v := range e { - valueEqClass[v.ID] = ID(i) + valueEqClass[v.ID] = pNum } if f.pass.debug > 2 && len(e) > 1 { - fmt.Printf("CSE.partition #%d:", i) + fmt.Printf("CSE.partition #%d:", pNum) for _, v := range e { fmt.Printf(" %s", v.String()) } fmt.Printf("\n") } + pNum++ } - // Find an equivalence class where some members of the class have - // non-equivalent arguments. Split the equivalence class appropriately. - // Repeat until we can't find any more splits. + // Split equivalence classes at points where they have + // non-equivalent arguments. Repeat until we can't find any + // more splits. + var splitPoints []int for { changed := false @@ -91,39 +90,51 @@ func cse(f *Func) { // we process new additions as they arrive, avoiding O(n^2) behavior. for i := 0; i < len(partition); i++ { e := partition[i] - v := e[0] - // all values in this equiv class that are not equivalent to v get moved - // into another equiv class. - // To avoid allocating while building that equivalence class, - // move the values equivalent to v to the beginning of e - // and other values to the end of e. - allvals := e - eqloop: - for j := 1; j < len(e); { - w := e[j] - equivalent := true - for i := 0; i < len(v.Args); i++ { - if valueEqClass[v.Args[i].ID] != valueEqClass[w.Args[i].ID] { - equivalent = false + + // Sort by eq class of arguments. + sort.Sort(partitionByArgClass{e, valueEqClass}) + + // Find split points. + splitPoints = append(splitPoints[:0], 0) + for j := 1; j < len(e); j++ { + v, w := e[j-1], e[j] + eqArgs := true + for k, a := range v.Args { + b := w.Args[k] + if valueEqClass[a.ID] != valueEqClass[b.ID] { + eqArgs = false break } } - if !equivalent || v.Type.Compare(w.Type) != CMPeq { - // w is not equivalent to v. - // move it to the end and shrink e. - e[j], e[len(e)-1] = e[len(e)-1], e[j] - e = e[:len(e)-1] - valueEqClass[w.ID] = ID(len(partition)) - changed = true - continue eqloop + if !eqArgs { + splitPoints = append(splitPoints, j) } - // v and w are equivalent. Keep w in e. - j++ } - partition[i] = e - if len(e) < len(allvals) { - partition = append(partition, allvals[len(e):]) + if len(splitPoints) == 1 { + continue // no splits, leave equivalence class alone. } + + // Move another equivalence class down in place of e. + partition[i] = partition[len(partition)-1] + partition = partition[:len(partition)-1] + i-- + + // Add new equivalence classes for the parts of e we found. + splitPoints = append(splitPoints, len(e)) + for j := 0; j < len(splitPoints)-1; j++ { + f := e[splitPoints[j]:splitPoints[j+1]] + if len(f) == 1 { + // Don't add singletons. + valueEqClass[f[0].ID] = -f[0].ID + continue + } + for _, v := range f { + valueEqClass[v.ID] = pNum + } + pNum++ + partition = append(partition, f) + } + changed = true } if !changed { @@ -253,7 +264,7 @@ func partitionValues(a []*Value, auxIDs auxmap) []eqclass { j := 1 for ; j < len(a); j++ { w := a[j] - if cmpVal(v, w, auxIDs, cmpDepth) != CMPeq { + if cmpVal(v, w, auxIDs) != CMPeq { break } } @@ -274,7 +285,7 @@ func lt2Cmp(isLt bool) Cmp { type auxmap map[interface{}]int32 -func cmpVal(v, w *Value, auxIDs auxmap, depth int) Cmp { +func cmpVal(v, w *Value, auxIDs auxmap) Cmp { // Try to order these comparison by cost (cheaper first) if v.Op != w.Op { return lt2Cmp(v.Op < w.Op) @@ -308,18 +319,6 @@ func cmpVal(v, w *Value, auxIDs auxmap, depth int) Cmp { return lt2Cmp(auxIDs[v.Aux] < auxIDs[w.Aux]) } - if depth > 0 { - for i := range v.Args { - if v.Args[i] == w.Args[i] { - // skip comparing equal args - continue - } - if ac := cmpVal(v.Args[i], w.Args[i], auxIDs, depth-1); ac != CMPeq { - return ac - } - } - } - return CMPeq } @@ -334,7 +333,7 @@ func (sv sortvalues) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] } func (sv sortvalues) Less(i, j int) bool { v := sv.a[i] w := sv.a[j] - if cmp := cmpVal(v, w, sv.auxIDs, cmpDepth); cmp != CMPeq { + if cmp := cmpVal(v, w, sv.auxIDs); cmp != CMPeq { return cmp == CMPlt } @@ -354,3 +353,25 @@ func (sv partitionByDom) Less(i, j int) bool { w := sv.a[j] return sv.sdom.domorder(v.Block) < sv.sdom.domorder(w.Block) } + +type partitionByArgClass struct { + a []*Value // array of values + eqClass []ID // equivalence class IDs of values +} + +func (sv partitionByArgClass) Len() int { return len(sv.a) } +func (sv partitionByArgClass) Swap(i, j int) { sv.a[i], sv.a[j] = sv.a[j], sv.a[i] } +func (sv partitionByArgClass) Less(i, j int) bool { + v := sv.a[i] + w := sv.a[j] + for i, a := range v.Args { + b := w.Args[i] + if sv.eqClass[a.ID] < sv.eqClass[b.ID] { + return true + } + if sv.eqClass[a.ID] > sv.eqClass[b.ID] { + return false + } + } + return false +}