diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index 6054541c3b3..2bda780d026 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -296,6 +296,15 @@ func (ft *factsTable) update(v, w *Value, d domain, r relation) { } } +// isNonNegative returns true if v is known to be non-negative. +func (ft *factsTable) isNonNegative(v *Value) bool { + if isNonNegative(v) { + return true + } + l, has := ft.limits[v.ID] + return has && (l.min >= 0 || l.umax <= math.MaxInt64) +} + // checkpoint saves the current state of known relations. // Called when descending on a branch. func (ft *factsTable) checkpoint() { @@ -608,8 +617,7 @@ func simplifyBlock(ft *factsTable, b *Block) branch { // to the upper bound than this is proven. Most useful in cases such as: // if len(a) <= 1 { return } // do something with a[1] - // TODO: use constant bounds to do isNonNegative. - if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && isNonNegative(c.Args[0]) { + if (c.Op == OpIsInBounds || c.Op == OpIsSliceInBounds) && ft.isNonNegative(c.Args[0]) { m := ft.get(a0, a1, signed) if m != 0 && tr.r&m == m { if b.Func.pass.debug > 0 { diff --git a/test/prove.go b/test/prove.go index 4fc1d674d87..a78adf03dce 100644 --- a/test/prove.go +++ b/test/prove.go @@ -29,11 +29,30 @@ func f1(a []int) int { } func f1b(a []int, i int, j uint) int { - if i >= 0 && i < len(a) { // TODO: handle this case - return a[i] + if i >= 0 && i < len(a) { + return a[i] // ERROR "Proved non-negative bounds IsInBounds$" + } + if i >= 10 && i < len(a) { + return a[i] // ERROR "Proved non-negative bounds IsInBounds$" + } + if i >= 10 && i < len(a) { + return a[i] // ERROR "Proved non-negative bounds IsInBounds$" + } + if i >= 10 && i < len(a) { // todo: handle this case + return a[i-10] } if j < uint(len(a)) { - return a[j] // ERROR "Proved IsInBounds" + return a[j] // ERROR "Proved IsInBounds$" + } + return 0 +} + +func f1c(a []int, i int64) int { + c := uint64(math.MaxInt64 + 10) // overflows int + d := int64(c) + if i >= d && i < int64(len(a)) { + // d overflows, should not be handled. + return a[i] } return 0 }