From a4a130f6d065187e1b7f4963792af5d5e84efa3c Mon Sep 17 00:00:00 2001 From: "khr@golang.org" Date: Thu, 27 Jun 2024 15:53:24 -0700 Subject: [PATCH] cmd/compile: propagate constant ranges through multiplies and shifts Fixes #40704 Fixes #66826 Change-Id: Ia9c356e29b2ed6f2e3bc6e5eb9304cd4dccb4263 Reviewed-on: https://go-review.googlesource.com/c/go/+/599256 Reviewed-by: Michael Knyszek LUCI-TryBot-Result: Go LUCI Reviewed-by: David Chase --- src/cmd/compile/internal/ssa/prove.go | 70 ++++++++++++++++++++++++++- test/prove.go | 12 +++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index df3566985f..6091950be8 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -8,6 +8,7 @@ import ( "cmd/internal/src" "fmt" "math" + "math/bits" ) type branch int @@ -311,6 +312,40 @@ func (l limit) sub(l2 limit, b uint) limit { return r } +// same as add but for multiplication. +func (l limit) mul(l2 limit, b uint) limit { + r := noLimit + umaxhi, umaxlo := bits.Mul64(l.umax, l2.umax) + if umaxhi == 0 && fitsInBitsU(umaxlo, b) { + r.umax = umaxlo + r.umin = l.umin * l2.umin + // Note: if the code containing this multiply is + // unreachable, then we may have umin>umax, and this + // multiply may overflow. But that's ok for + // unreachable code. If this code is reachable, we + // know umin<=umax, so this multiply will not overflow + // because the max multiply didn't. + } + // Signed is harder, so don't bother. The only useful + // case is when we know both multiplicands are nonnegative, + // but that case is handled above because we would have then + // previously propagated signed info to the unsigned domain, + // and will propagate it back after the multiply. + return r +} + +// Similar to add, but compute 1 << l if it fits without overflow in b bits. +func (l limit) exp2(b uint) limit { + r := noLimit + if l.umax < uint64(b) { + r.umin = 1 << l.umin + r.umax = 1 << l.umax + // Same as above in mul, signed<->unsigned propagation + // will handle the signed case for us. + } + return r +} + var noLimit = limit{math.MinInt64, math.MaxInt64, 0, math.MaxUint64} // a limitFact is a limit known for a particular value. @@ -1548,6 +1583,39 @@ func (ft *factsTable) flowLimit(v *Value) bool { a := ft.limits[v.Args[0].ID] b := ft.limits[v.Args[1].ID] return ft.newLimit(v, a.sub(b, 8)) + case OpMul64: + a := ft.limits[v.Args[0].ID] + b := ft.limits[v.Args[1].ID] + return ft.newLimit(v, a.mul(b, 64)) + case OpMul32: + a := ft.limits[v.Args[0].ID] + b := ft.limits[v.Args[1].ID] + return ft.newLimit(v, a.mul(b, 32)) + case OpMul16: + a := ft.limits[v.Args[0].ID] + b := ft.limits[v.Args[1].ID] + return ft.newLimit(v, a.mul(b, 16)) + case OpMul8: + a := ft.limits[v.Args[0].ID] + b := ft.limits[v.Args[1].ID] + return ft.newLimit(v, a.mul(b, 8)) + case OpLsh64x64, OpLsh64x32, OpLsh64x16, OpLsh64x8: + a := ft.limits[v.Args[0].ID] + b := ft.limits[v.Args[1].ID] + return ft.newLimit(v, a.mul(b.exp2(64), 64)) + case OpLsh32x64, OpLsh32x32, OpLsh32x16, OpLsh32x8: + a := ft.limits[v.Args[0].ID] + b := ft.limits[v.Args[1].ID] + return ft.newLimit(v, a.mul(b.exp2(32), 32)) + case OpLsh16x64, OpLsh16x32, OpLsh16x16, OpLsh16x8: + a := ft.limits[v.Args[0].ID] + b := ft.limits[v.Args[1].ID] + return ft.newLimit(v, a.mul(b.exp2(16), 16)) + case OpLsh8x64, OpLsh8x32, OpLsh8x16, OpLsh8x8: + a := ft.limits[v.Args[0].ID] + b := ft.limits[v.Args[1].ID] + return ft.newLimit(v, a.mul(b.exp2(8), 8)) + case OpPhi: // Compute the union of all the input phis. // Often this will convey no information, because the block @@ -1834,7 +1902,7 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { OpRsh64x8, OpRsh64x16, OpRsh64x32, OpRsh64x64: // Check whether, for a >> b, we know that a is non-negative // and b is all of a's bits except the MSB. If so, a is shifted to zero. - bits := 8 * v.Type.Size() + bits := 8 * v.Args[0].Type.Size() if v.Args[1].isGenericIntConst() && v.Args[1].AuxInt >= bits-1 && ft.isNonNegative(v.Args[0]) { if b.Func.pass.debug > 0 { b.Func.Warnl(v.Pos, "Proved %v shifts to zero", v.Op) diff --git a/test/prove.go b/test/prove.go index 70466ce2c5..6cb30c6ce1 100644 --- a/test/prove.go +++ b/test/prove.go @@ -1147,6 +1147,18 @@ func inequalityPropagation(a [1]int, i, j uint) int { return 0 } +func issue66826a(a [21]byte) { + for i := 0; i <= 10; i++ { // ERROR "Induction variable: limits \[0,10\], increment 1$" + _ = a[2*i] // ERROR "Proved IsInBounds" + } +} +func issue66826b(a [31]byte, i int) { + if i < 0 || i > 10 { + return + } + _ = a[3*i] // ERROR "Proved IsInBounds" +} + //go:noinline func useInt(a int) { }