diff --git a/src/cmd/compile/internal/ssa/prove.go b/src/cmd/compile/internal/ssa/prove.go index 26176af07c..3d1b5007b3 100644 --- a/src/cmd/compile/internal/ssa/prove.go +++ b/src/cmd/compile/internal/ssa/prove.go @@ -1222,13 +1222,13 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { // Replace OpSlicemask operations in b with constants where possible. x, delta := isConstDelta(v.Args[0]) if x == nil { - continue + break } // slicemask(x + y) // if x is larger than -y (y is negative), then slicemask is -1. lim, ok := ft.limits[x.ID] if !ok { - continue + break } if lim.umin > uint64(-delta) { if v.Args[0].Op == OpAdd64 { @@ -1248,7 +1248,7 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { x := v.Args[0] lim, ok := ft.limits[x.ID] if !ok { - continue + break } if lim.umin > 0 || lim.min > 0 || lim.max < 0 { if b.Func.pass.debug > 0 { @@ -1280,7 +1280,7 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { panic("unexpected integer size") } v.AuxInt = 0 - continue // Be sure not to fallthrough - this is no longer OpRsh. + break // Be sure not to fallthrough - this is no longer OpRsh. } // If the Rsh hasn't been replaced with 0, still check if it is bounded. fallthrough @@ -1297,7 +1297,7 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { by := v.Args[1] lim, ok := ft.limits[by.ID] if !ok { - continue + break } bits := 8 * v.Args[0].Type.Size() if lim.umax < uint64(bits) || (lim.max < bits && ft.isNonNegative(by)) { @@ -1331,6 +1331,60 @@ func simplifyBlock(sdom SparseTree, ft *factsTable, b *Block) { } } } + // Fold provable constant results. + // Helps in cases where we reuse a value after branching on its equality. + for i, arg := range v.Args { + switch arg.Op { + case OpConst64, OpConst32, OpConst16, OpConst8: + continue + } + lim, ok := ft.limits[arg.ID] + if !ok { + continue + } + + var constValue int64 + typ := arg.Type + bits := 8 * typ.Size() + switch { + case lim.min == lim.max: + constValue = lim.min + case lim.umin == lim.umax: + // truncate then sign extand + switch bits { + case 64: + constValue = int64(lim.umin) + case 32: + constValue = int64(int32(lim.umin)) + case 16: + constValue = int64(int16(lim.umin)) + case 8: + constValue = int64(int8(lim.umin)) + default: + panic("unexpected integer size") + } + default: + continue + } + var c *Value + f := b.Func + switch bits { + case 64: + c = f.ConstInt64(typ, constValue) + case 32: + c = f.ConstInt32(typ, int32(constValue)) + case 16: + c = f.ConstInt16(typ, int16(constValue)) + case 8: + c = f.ConstInt8(typ, int8(constValue)) + default: + panic("unexpected integer size") + } + v.SetArg(i, c) + if b.Func.pass.debug > 1 { + b.Func.Warnl(v.Pos, "Proved %v's arg %d (%v) is constant %d", v, i, arg, constValue) + } + } } if b.Kind != BlockIf { diff --git a/test/codegen/comparisons.go b/test/codegen/comparisons.go index fd32ea335c..181bb93496 100644 --- a/test/codegen/comparisons.go +++ b/test/codegen/comparisons.go @@ -161,7 +161,7 @@ func CmpZero4(a int64, ptr *int) { } } -func CmpToZero(a, b, d int32, e, f int64) int32 { +func CmpToZero(a, b, d int32, e, f int64, deOptC0, deOptC1 bool) int32 { // arm:`TST`,-`AND` // arm64:`TSTW`,-`AND` // 386:`TESTL`,-`ANDL` @@ -201,13 +201,17 @@ func CmpToZero(a, b, d int32, e, f int64) int32 { } else if c4 { return 5 } else if c5 { - return b + d + return 6 } else if c6 { - return a & d - } else if c7 { return 7 + } else if c7 { + return 9 } else if c8 { - return 8 + return 10 + } else if deOptC0 { + return b + d + } else if deOptC1 { + return a & d } else { return 0 } diff --git a/test/prove_constant_folding.go b/test/prove_constant_folding.go new file mode 100644 index 0000000000..d4bdb20d83 --- /dev/null +++ b/test/prove_constant_folding.go @@ -0,0 +1,32 @@ +// +build amd64 +// errorcheck -0 -d=ssa/prove/debug=2 + +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +func f0i(x int) int { + if x == 20 { + return x // ERROR "Proved.+is constant 20$" + } + + if (x + 20) == 20 { + return x + 5 // ERROR "Proved.+is constant 0$" + } + + return x / 2 +} + +func f0u(x uint) uint { + if x == 20 { + return x // ERROR "Proved.+is constant 20$" + } + + if (x + 20) == 20 { + return x + 5 // ERROR "Proved.+is constant 0$" + } + + return x / 2 +}