From 5925cd3d15c7e1eb71125964e681c4b4c2db750d Mon Sep 17 00:00:00 2001 From: "khr@golang.org" Date: Sun, 7 Jul 2024 14:58:47 -0700 Subject: [PATCH] cmd/compile: handle boolean and pointer relations The constant lattice for these types is pretty simple. We no longer need the old-style facts table, as the ordering table now has all that information. Change-Id: If0e118c27a4de8e9bfd727b78942185c2eb50c4b Reviewed-on: https://go-review.googlesource.com/c/go/+/599097 Reviewed-by: David Chase LUCI-TryBot-Result: Go LUCI Reviewed-by: Michael Knyszek --- src/cmd/compile/internal/ssa/prove.go | 365 +++++++++++++++++--------- test/fuse.go | 8 +- test/prove.go | 22 ++ 3 files changed, 263 insertions(+), 132 deletions(-) diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index 6091950be8..c8d2ab7a6f 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -128,9 +128,18 @@ type fact struct { } // a limit records known upper and lower bounds for a value. +// +// If we have min>max or umin>umax, then this limit is +// called "unsatisfiable". When we encounter such a limit, we +// know that any code for which that limit applies is unreachable. +// We don't particularly care how unsatisfiable limits propagate, +// including becoming satisfiable, because any optimization +// decisions based on those limits only apply to unreachable code. type limit struct { min, max int64 // min <= value <= max, signed umin, umax uint64 // umin <= value <= umax, unsigned + // For booleans, we use 0==false, 1==true for both ranges + // For pointers, we use 0,0,0,0 for nil and minInt64,maxInt64,1,maxUint64 for nonnil } func (l limit) String() string { @@ -359,8 +368,9 @@ type ordering struct { next *ordering // linked list of all known orderings for v. // Note: v is implicit here, determined by which linked list it is in. w *Value - d domain // one of signed or unsigned + d domain r relation // one of ==,!=,<,<=,>,>= + // if d is boolean or pointer, r can only be ==, != } // factsTable keeps track of relations between pairs of values. @@ -379,9 +389,6 @@ type factsTable struct { unsat bool // true if facts contains a contradiction unsatDepth int // number of unsat checkpoints - facts map[pair]relation // current known set of relation - stack []fact // previous sets of relations - // order* is a couple of partial order sets that record information // about relations between SSA values in the signed and unsigned // domain. @@ -423,8 +430,6 @@ func newFactsTable(f *Func) *factsTable { ft.orderS.SetUnsigned(false) ft.orderU.SetUnsigned(true) ft.orderings = make(map[ID]*ordering) - ft.facts = make(map[pair]relation) - ft.stack = make([]fact, 4) ft.limits = f.Cache.allocLimitSlice(f.NumValues()) for _, b := range f.Blocks { for _, v := range b.Values { @@ -471,6 +476,21 @@ func (ft *factsTable) unsignedMinMax(v *Value, min, max uint64) bool { return ft.newLimit(v, limit{min: math.MinInt64, max: math.MaxInt64, umin: min, umax: max}) } +func (ft *factsTable) booleanFalse(v *Value) bool { + return ft.newLimit(v, limit{min: 0, max: 0, umin: 0, umax: 0}) +} +func (ft *factsTable) booleanTrue(v *Value) bool { + return ft.newLimit(v, limit{min: 1, max: 1, umin: 1, umax: 1}) +} +func (ft *factsTable) pointerNil(v *Value) bool { + return ft.newLimit(v, limit{min: 0, max: 0, umin: 0, umax: 0}) +} +func (ft *factsTable) pointerNonNil(v *Value) bool { + l := noLimit + l.umin = 1 + return ft.newLimit(v, l) +} + // newLimit adds new limiting information for v. // Returns true if the new limit added any new information. func (ft *factsTable) newLimit(v *Value, newLim limit) bool { @@ -574,6 +594,38 @@ func (ft *factsTable) newLimit(v *Value, newLim limit) bool { } } } + case boolean: + switch o.r { + case eq: + if lim.min == 0 && lim.max == 0 { // constant false + ft.booleanFalse(o.w) + } + if lim.min == 1 && lim.max == 1 { // constant true + ft.booleanTrue(o.w) + } + case lt | gt: + if lim.min == 0 && lim.max == 0 { // constant false + ft.booleanTrue(o.w) + } + if lim.min == 1 && lim.max == 1 { // constant true + ft.booleanFalse(o.w) + } + } + case pointer: + switch o.r { + case eq: + if lim.umax == 0 { // nil + ft.pointerNil(o.w) + } + if lim.umin > 0 { // non-nil + ft.pointerNonNil(o.w) + } + case lt | gt: + if lim.umax == 0 { // nil + ft.pointerNonNil(o.w) + } + // note: not equal to non-nil doesn't tell us anything. + } } } @@ -647,122 +699,163 @@ func (ft *factsTable) update(parent *Block, v, w *Value, d domain, r relation) { ft.unsat = true return } - } else { - if lessByID(w, v) { - v, w = w, v - r = reverseBits[r] - } - - p := pair{v, w, d} - oldR, ok := ft.facts[p] - if !ok { - if v == w { - oldR = eq - } else { - oldR = lt | eq | gt + } + if d == boolean || d == pointer { + for o := ft.orderings[v.ID]; o != nil; o = o.next { + if o.d == d && o.w == w { + // We already know a relationship between v and w. + // Either it is a duplicate, or it is a contradiction, + // as we only allow eq and lt|gt for these domains, + if o.r != r { + ft.unsat = true + } + return } } - // No changes compared to information already in facts table. - if oldR == r { - return - } - ft.stack = append(ft.stack, fact{p, oldR}) - ft.facts[p] = oldR & r - // If this relation is not satisfiable, mark it and exit right away - if oldR&r == 0 { - if parent.Func.pass.debug > 2 { - parent.Func.Warnl(parent.Pos, "unsat %s %s %s", v, w, r) - } - ft.unsat = true - return - } + // TODO: this does not do transitive equality. + // We could use a poset like above, but somewhat degenerate (==,!= only). + ft.addOrdering(v, w, d, r) + ft.addOrdering(w, v, d, r) // note: reverseBits unnecessary for eq and lt|gt. } // Extract new constant limits based on the comparison. - if d == signed || d == unsigned { - vLimit := ft.limits[v.ID] - wLimit := ft.limits[w.ID] - // Note: all the +1/-1 below could overflow/underflow. Either will - // still generate correct results, it will just lead to imprecision. - // In fact if there is overflow/underflow, the corresponding - // code is unreachable because the known range is outside the range - // of the value's type. - switch d { - case signed: - switch r { - case eq: // v == w - ft.signedMinMax(v, wLimit.min, wLimit.max) - ft.signedMinMax(w, vLimit.min, vLimit.max) - case lt: // v < w - ft.signedMax(v, wLimit.max-1) - ft.signedMin(w, vLimit.min+1) - case lt | eq: // v <= w - ft.signedMax(v, wLimit.max) - ft.signedMin(w, vLimit.min) - case gt: // v > w - ft.signedMin(v, wLimit.min+1) - ft.signedMax(w, vLimit.max-1) - case gt | eq: // v >= w - ft.signedMin(v, wLimit.min) - ft.signedMax(w, vLimit.max) - case lt | gt: // v != w - if vLimit.min == vLimit.max { // v is a constant - c := vLimit.min - if wLimit.min == c { - ft.signedMin(w, c+1) - } - if wLimit.max == c { - ft.signedMax(w, c-1) - } + vLimit := ft.limits[v.ID] + wLimit := ft.limits[w.ID] + // Note: all the +1/-1 below could overflow/underflow. Either will + // still generate correct results, it will just lead to imprecision. + // In fact if there is overflow/underflow, the corresponding + // code is unreachable because the known range is outside the range + // of the value's type. + switch d { + case signed: + switch r { + case eq: // v == w + ft.signedMinMax(v, wLimit.min, wLimit.max) + ft.signedMinMax(w, vLimit.min, vLimit.max) + case lt: // v < w + ft.signedMax(v, wLimit.max-1) + ft.signedMin(w, vLimit.min+1) + case lt | eq: // v <= w + ft.signedMax(v, wLimit.max) + ft.signedMin(w, vLimit.min) + case gt: // v > w + ft.signedMin(v, wLimit.min+1) + ft.signedMax(w, vLimit.max-1) + case gt | eq: // v >= w + ft.signedMin(v, wLimit.min) + ft.signedMax(w, vLimit.max) + case lt | gt: // v != w + if vLimit.min == vLimit.max { // v is a constant + c := vLimit.min + if wLimit.min == c { + ft.signedMin(w, c+1) } - if wLimit.min == wLimit.max { // w is a constant - c := wLimit.min - if vLimit.min == c { - ft.signedMin(v, c+1) - } - if vLimit.max == c { - ft.signedMax(v, c-1) - } + if wLimit.max == c { + ft.signedMax(w, c-1) } } - case unsigned: - switch r { - case eq: // v == w - ft.unsignedMinMax(v, wLimit.umin, wLimit.umax) - ft.unsignedMinMax(w, vLimit.umin, vLimit.umax) - case lt: // v < w - ft.unsignedMax(v, wLimit.umax-1) - ft.unsignedMin(w, vLimit.umin+1) - case lt | eq: // v <= w - ft.unsignedMax(v, wLimit.umax) - ft.unsignedMin(w, vLimit.umin) - case gt: // v > w - ft.unsignedMin(v, wLimit.umin+1) - ft.unsignedMax(w, vLimit.umax-1) - case gt | eq: // v >= w - ft.unsignedMin(v, wLimit.umin) - ft.unsignedMax(w, vLimit.umax) - case lt | gt: // v != w - if vLimit.umin == vLimit.umax { // v is a constant - c := vLimit.umin - if wLimit.umin == c { - ft.unsignedMin(w, c+1) - } - if wLimit.umax == c { - ft.unsignedMax(w, c-1) - } + if wLimit.min == wLimit.max { // w is a constant + c := wLimit.min + if vLimit.min == c { + ft.signedMin(v, c+1) } - if wLimit.umin == wLimit.umax { // w is a constant - c := wLimit.umin - if vLimit.umin == c { - ft.unsignedMin(v, c+1) - } - if vLimit.umax == c { - ft.unsignedMax(v, c-1) - } + if vLimit.max == c { + ft.signedMax(v, c-1) } } } + case unsigned: + switch r { + case eq: // v == w + ft.unsignedMinMax(v, wLimit.umin, wLimit.umax) + ft.unsignedMinMax(w, vLimit.umin, vLimit.umax) + case lt: // v < w + ft.unsignedMax(v, wLimit.umax-1) + ft.unsignedMin(w, vLimit.umin+1) + case lt | eq: // v <= w + ft.unsignedMax(v, wLimit.umax) + ft.unsignedMin(w, vLimit.umin) + case gt: // v > w + ft.unsignedMin(v, wLimit.umin+1) + ft.unsignedMax(w, vLimit.umax-1) + case gt | eq: // v >= w + ft.unsignedMin(v, wLimit.umin) + ft.unsignedMax(w, vLimit.umax) + case lt | gt: // v != w + if vLimit.umin == vLimit.umax { // v is a constant + c := vLimit.umin + if wLimit.umin == c { + ft.unsignedMin(w, c+1) + } + if wLimit.umax == c { + ft.unsignedMax(w, c-1) + } + } + if wLimit.umin == wLimit.umax { // w is a constant + c := wLimit.umin + if vLimit.umin == c { + ft.unsignedMin(v, c+1) + } + if vLimit.umax == c { + ft.unsignedMax(v, c-1) + } + } + } + case boolean: + switch r { + case eq: // v == w + if vLimit.min == 1 { // v is true + ft.booleanTrue(w) + } + if vLimit.max == 0 { // v is false + ft.booleanFalse(w) + } + if wLimit.min == 1 { // w is true + ft.booleanTrue(v) + } + if wLimit.max == 0 { // w is false + ft.booleanFalse(v) + } + case lt | gt: // v != w + if vLimit.min == 1 { // v is true + ft.booleanFalse(w) + } + if vLimit.max == 0 { // v is false + ft.booleanTrue(w) + } + if wLimit.min == 1 { // w is true + ft.booleanFalse(v) + } + if wLimit.max == 0 { // w is false + ft.booleanTrue(v) + } + } + case pointer: + switch r { + case eq: // v == w + if vLimit.umax == 0 { // v is nil + ft.pointerNil(w) + } + if vLimit.umin > 0 { // v is non-nil + ft.pointerNonNil(w) + } + if wLimit.umax == 0 { // w is nil + ft.pointerNil(v) + } + if wLimit.umin > 0 { // w is non-nil + ft.pointerNonNil(v) + } + case lt | gt: // v != w + if vLimit.umax == 0 { // v is nil + ft.pointerNonNil(w) + } + if wLimit.umax == 0 { // w is nil + ft.pointerNonNil(v) + } + // Note: the other direction doesn't work. + // Being not equal to a non-nil pointer doesn't + // make you (necessarily) a nil pointer. + } } // Derived facts below here are only about numbers. @@ -970,7 +1063,6 @@ func (ft *factsTable) checkpoint() { if ft.unsat { ft.unsatDepth++ } - ft.stack = append(ft.stack, checkpointFact) ft.limitStack = append(ft.limitStack, checkpointBound) ft.orderS.Checkpoint() ft.orderU.Checkpoint() @@ -986,18 +1078,6 @@ func (ft *factsTable) restore() { } else { ft.unsat = false } - for { - old := ft.stack[len(ft.stack)-1] - ft.stack = ft.stack[:len(ft.stack)-1] - if old == checkpointFact { - break - } - if old.r == lt|eq|gt { - delete(ft.facts, old.p) - } else { - ft.facts[old.p] = old.r - } - } for { old := ft.limitStack[len(ft.limitStack)-1] ft.limitStack = ft.limitStack[:len(ft.limitStack)-1] @@ -1050,12 +1130,14 @@ var ( OpEq32: {signed | unsigned, eq}, OpEq64: {signed | unsigned, eq}, OpEqPtr: {pointer, eq}, + OpEqB: {boolean, eq}, OpNeq8: {signed | unsigned, lt | gt}, OpNeq16: {signed | unsigned, lt | gt}, OpNeq32: {signed | unsigned, lt | gt}, OpNeq64: {signed | unsigned, lt | gt}, OpNeqPtr: {pointer, lt | gt}, + OpNeqB: {boolean, lt | gt}, OpLess8: {signed, lt}, OpLess8U: {unsigned, lt}, @@ -1407,8 +1489,28 @@ func prove(f *Func) { // flowLimit, below, which computes additional constraints based on // ranges of opcode arguments). func initLimit(v *Value) limit { + if v.Type.IsBoolean() { + switch v.Op { + case OpConstBool: + b := v.AuxInt + return limit{min: b, max: b, umin: uint64(b), umax: uint64(b)} + default: + return limit{min: 0, max: 1, umin: 0, umax: 1} + } + } + if v.Type.IsPtrShaped() { // These are the types that EqPtr/NeqPtr operate on, except uintptr. + switch v.Op { + case OpConstNil: + return limit{min: 0, max: 0, umin: 0, umax: 0} + case OpAddr, OpLocalAddr: // TODO: others? + l := noLimit + l.umin = 1 + return l + default: + return noLimit + } + } if !v.Type.IsInteger() { - // TODO: boolean? return noLimit } @@ -1700,9 +1802,9 @@ func addBranchRestrictions(ft *factsTable, b *Block, br branch) { c := b.Controls[0] switch { case br == negative: - addRestrictions(b, ft, boolean, nil, c, eq) + ft.booleanFalse(c) case br == positive: - addRestrictions(b, ft, boolean, nil, c, lt|gt) + ft.booleanTrue(c) case br >= jumpTable0: idx := br - jumpTable0 val := int64(idx) @@ -1769,7 +1871,14 @@ func addBranchRestrictions(ft *factsTable, b *Block, br branch) { addRestrictions(b, ft, d, c.Args[0], c.Args[1], tr.r) } } - + } + if c.Op == OpIsNonNil { + switch br { + case positive: + ft.pointerNonNil(c.Args[0]) + case negative: + ft.pointerNil(c.Args[0]) + } } } @@ -1984,7 +2093,7 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { // Helps in cases where we reuse a value after branching on its equality. for i, arg := range v.Args { switch arg.Op { - case OpConst64, OpConst32, OpConst16, OpConst8: + case OpConst64, OpConst32, OpConst16, OpConst8, OpConstBool, OpConstNil: continue } lim := ft.limits[arg.ID] diff --git a/test/fuse.go b/test/fuse.go index e9205dcc23..9366b21858 100644 --- a/test/fuse.go +++ b/test/fuse.go @@ -148,11 +148,11 @@ func fEqInterEqInter(a interface{}, f float64) bool { } func fEqInterNeqInter(a interface{}, f float64) bool { - return a == nil && f > Cf2 || a != nil && f < -Cf2 + return a == nil && f > Cf2 || a != nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil" } func fNeqInterEqInter(a interface{}, f float64) bool { - return a != nil && f > Cf2 || a == nil && f < -Cf2 + return a != nil && f > Cf2 || a == nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil" } func fNeqInterNeqInter(a interface{}, f float64) bool { @@ -164,11 +164,11 @@ func fEqSliceEqSlice(a []int, f float64) bool { } func fEqSliceNeqSlice(a []int, f float64) bool { - return a == nil && f > Cf2 || a != nil && f < -Cf2 + return a == nil && f > Cf2 || a != nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil" } func fNeqSliceEqSlice(a []int, f float64) bool { - return a != nil && f > Cf2 || a == nil && f < -Cf2 + return a != nil && f > Cf2 || a == nil && f < -Cf2 // ERROR "Redirect IsNonNil based on IsNonNil" } func fNeqSliceNeqSlice(a []int, f float64) bool { diff --git a/test/prove.go b/test/prove.go index 6cb30c6ce1..b85ee5fe0d 100644 --- a/test/prove.go +++ b/test/prove.go @@ -1159,6 +1159,28 @@ func issue66826b(a [31]byte, i int) { _ = a[3*i] // ERROR "Proved IsInBounds" } +func f20(a, b bool) int { + if a == b { + if a { + if b { // ERROR "Proved Arg" + return 1 + } + } + } + return 0 +} + +func f21(a, b *int) int { + if a == b { + if a != nil { + if b != nil { // ERROR "Proved IsNonNil" + return 1 + } + } + } + return 0 +} + //go:noinline func useInt(a int) { }