diff --git a/src/cmd/compile/internal/ssa/gen/RISCV64.rules b/src/cmd/compile/internal/ssa/gen/RISCV64.rules index 4290d1b85c..3379e1dac5 100644 --- a/src/cmd/compile/internal/ssa/gen/RISCV64.rules +++ b/src/cmd/compile/internal/ssa/gen/RISCV64.rules @@ -6,9 +6,7 @@ // * Use SLTI and SLTIU for comparisons to constants, instead of SLT/SLTU with constants in registers // * Use the zero register instead of moving 0 into a register. // * Add rules to avoid generating a temp bool value for (If (SLT[U] ...) ...). -// * Optimize left and right shift by simplifying SLTIU, Neg, and ADD for constants. // * Arrange for non-trivial Zero and Move lowerings to use aligned loads and stores. -// * Eliminate zero immediate shifts, adds, etc. // * Avoid using Neq32 for writeBarrier.enabled checks. // Lowering arithmetic @@ -229,7 +227,7 @@ (Rsh64x32 x y) => (SRA x (OR y (ADDI [-1] (SLTIU [64] (ZeroExt32to64 y))))) (Rsh64x64 x y) => (SRA x (OR y (ADDI [-1] (SLTIU [64] y)))) -// rotates +// Rotates. (RotateLeft8 x (MOVDconst [c])) => (Or8 (Lsh8x64 x (MOVDconst [c&7])) (Rsh8Ux64 x (MOVDconst [-c&7]))) (RotateLeft16 x (MOVDconst [c])) => (Or16 (Lsh16x64 x (MOVDconst [c&15])) (Rsh16Ux64 x (MOVDconst [-c&15]))) (RotateLeft32 x (MOVDconst [c])) => (Or32 (Lsh32x64 x (MOVDconst [c&31])) (Rsh32Ux64 x (MOVDconst [-c&31]))) @@ -707,19 +705,39 @@ (SUB x (MOVDconst [val])) && is32Bit(-val) => (ADDI [-val] x) // Subtraction of zero. -(SUB x (MOVDconst [0])) => x - -// Subtraction of zero with sign extension. +(SUB x (MOVDconst [0])) => x (SUBW x (MOVDconst [0])) => (ADDIW [0] x) // Subtraction from zero. -(SUB (MOVDconst [0]) x) => (NEG x) - -// Subtraction from zero with sign extension. +(SUB (MOVDconst [0]) x) => (NEG x) (SUBW (MOVDconst [0]) x) => (NEGW x) -// Addition of zero. +// Addition of zero or two constants. (ADDI [0] x) => x +(ADDI [x] (MOVDconst [y])) && is32Bit(x + y) => (MOVDconst [x + y]) + +// ANDI with all zeros, all ones or two constants. +(ANDI [0] x) => (MOVDconst [0]) +(ANDI [-1] x) => x +(ANDI [x] (MOVDconst [y])) => (MOVDconst [x & y]) + +// ORI with all zeroes, all ones or two constants. +(ORI [0] x) => x +(ORI [-1] x) => (MOVDconst [-1]) +(ORI [x] (MOVDconst [y])) => (MOVDconst [x | y]) + +// Negation of a constant. +(NEG (MOVDconst [x])) => (MOVDconst [-x]) +(NEGW (MOVDconst [x])) => (MOVDconst [int64(int32(-x))]) + +// Shift of a constant. +(SLLI [x] (MOVDconst [y])) && is32Bit(y << x) => (MOVDconst [y << x]) +(SRLI [x] (MOVDconst [y])) => (MOVDconst [int64(uint64(y) >> x)]) +(SRAI [x] (MOVDconst [y])) => (MOVDconst [int64(y) >> x]) + +// SLTI/SLTIU with constants. +(SLTI [x] (MOVDconst [y])) => (MOVDconst [b2i(int64(y) < int64(x))]) +(SLTIU [x] (MOVDconst [y])) => (MOVDconst [b2i(uint64(y) < uint64(x))]) // Merge negation into fused multiply-add and multiply-subtract. // diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go index f856a26d49..885bbaf4a1 100644 --- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go +++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go @@ -441,6 +441,8 @@ func rewriteValueRISCV64(v *Value) bool { return rewriteValueRISCV64_OpRISCV64ADDI(v) case OpRISCV64AND: return rewriteValueRISCV64_OpRISCV64AND(v) + case OpRISCV64ANDI: + return rewriteValueRISCV64_OpRISCV64ANDI(v) case OpRISCV64FMADDD: return rewriteValueRISCV64_OpRISCV64FMADDD(v) case OpRISCV64FMSUBD: @@ -495,14 +497,30 @@ func rewriteValueRISCV64(v *Value) bool { return rewriteValueRISCV64_OpRISCV64MOVWstore(v) case OpRISCV64MOVWstorezero: return rewriteValueRISCV64_OpRISCV64MOVWstorezero(v) + case OpRISCV64NEG: + return rewriteValueRISCV64_OpRISCV64NEG(v) + case OpRISCV64NEGW: + return rewriteValueRISCV64_OpRISCV64NEGW(v) case OpRISCV64OR: return rewriteValueRISCV64_OpRISCV64OR(v) + case OpRISCV64ORI: + return rewriteValueRISCV64_OpRISCV64ORI(v) case OpRISCV64SLL: return rewriteValueRISCV64_OpRISCV64SLL(v) + case OpRISCV64SLLI: + return rewriteValueRISCV64_OpRISCV64SLLI(v) + case OpRISCV64SLTI: + return rewriteValueRISCV64_OpRISCV64SLTI(v) + case OpRISCV64SLTIU: + return rewriteValueRISCV64_OpRISCV64SLTIU(v) case OpRISCV64SRA: return rewriteValueRISCV64_OpRISCV64SRA(v) + case OpRISCV64SRAI: + return rewriteValueRISCV64_OpRISCV64SRAI(v) case OpRISCV64SRL: return rewriteValueRISCV64_OpRISCV64SRL(v) + case OpRISCV64SRLI: + return rewriteValueRISCV64_OpRISCV64SRLI(v) case OpRISCV64SUB: return rewriteValueRISCV64_OpRISCV64SUB(v) case OpRISCV64SUBW: @@ -2822,6 +2840,22 @@ func rewriteValueRISCV64_OpRISCV64ADDI(v *Value) bool { v.copyOf(x) return true } + // match: (ADDI [x] (MOVDconst [y])) + // cond: is32Bit(x + y) + // result: (MOVDconst [x + y]) + for { + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVDconst { + break + } + y := auxIntToInt64(v_0.AuxInt) + if !(is32Bit(x + y)) { + break + } + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(x + y) + return true + } return false } func rewriteValueRISCV64_OpRISCV64AND(v *Value) bool { @@ -2849,6 +2883,42 @@ func rewriteValueRISCV64_OpRISCV64AND(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64ANDI(v *Value) bool { + v_0 := v.Args[0] + // match: (ANDI [0] x) + // result: (MOVDconst [0]) + for { + if auxIntToInt64(v.AuxInt) != 0 { + break + } + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(0) + return true + } + // match: (ANDI [-1] x) + // result: x + for { + if auxIntToInt64(v.AuxInt) != -1 { + break + } + x := v_0 + v.copyOf(x) + return true + } + // match: (ANDI [x] (MOVDconst [y])) + // result: (MOVDconst [x & y]) + for { + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVDconst { + break + } + y := auxIntToInt64(v_0.AuxInt) + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(x & y) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64FMADDD(v *Value) bool { v_2 := v.Args[2] v_1 := v.Args[1] @@ -4619,6 +4689,36 @@ func rewriteValueRISCV64_OpRISCV64MOVWstorezero(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64NEG(v *Value) bool { + v_0 := v.Args[0] + // match: (NEG (MOVDconst [x])) + // result: (MOVDconst [-x]) + for { + if v_0.Op != OpRISCV64MOVDconst { + break + } + x := auxIntToInt64(v_0.AuxInt) + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(-x) + return true + } + return false +} +func rewriteValueRISCV64_OpRISCV64NEGW(v *Value) bool { + v_0 := v.Args[0] + // match: (NEGW (MOVDconst [x])) + // result: (MOVDconst [int64(int32(-x))]) + for { + if v_0.Op != OpRISCV64MOVDconst { + break + } + x := auxIntToInt64(v_0.AuxInt) + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(int64(int32(-x))) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64OR(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] @@ -4644,6 +4744,42 @@ func rewriteValueRISCV64_OpRISCV64OR(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64ORI(v *Value) bool { + v_0 := v.Args[0] + // match: (ORI [0] x) + // result: x + for { + if auxIntToInt64(v.AuxInt) != 0 { + break + } + x := v_0 + v.copyOf(x) + return true + } + // match: (ORI [-1] x) + // result: (MOVDconst [-1]) + for { + if auxIntToInt64(v.AuxInt) != -1 { + break + } + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(-1) + return true + } + // match: (ORI [x] (MOVDconst [y])) + // result: (MOVDconst [x | y]) + for { + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVDconst { + break + } + y := auxIntToInt64(v_0.AuxInt) + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(x | y) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64SLL(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] @@ -4662,6 +4798,58 @@ func rewriteValueRISCV64_OpRISCV64SLL(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64SLLI(v *Value) bool { + v_0 := v.Args[0] + // match: (SLLI [x] (MOVDconst [y])) + // cond: is32Bit(y << x) + // result: (MOVDconst [y << x]) + for { + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVDconst { + break + } + y := auxIntToInt64(v_0.AuxInt) + if !(is32Bit(y << x)) { + break + } + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(y << x) + return true + } + return false +} +func rewriteValueRISCV64_OpRISCV64SLTI(v *Value) bool { + v_0 := v.Args[0] + // match: (SLTI [x] (MOVDconst [y])) + // result: (MOVDconst [b2i(int64(y) < int64(x))]) + for { + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVDconst { + break + } + y := auxIntToInt64(v_0.AuxInt) + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(b2i(int64(y) < int64(x))) + return true + } + return false +} +func rewriteValueRISCV64_OpRISCV64SLTIU(v *Value) bool { + v_0 := v.Args[0] + // match: (SLTIU [x] (MOVDconst [y])) + // result: (MOVDconst [b2i(uint64(y) < uint64(x))]) + for { + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVDconst { + break + } + y := auxIntToInt64(v_0.AuxInt) + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(b2i(uint64(y) < uint64(x))) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64SRA(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] @@ -4680,6 +4868,22 @@ func rewriteValueRISCV64_OpRISCV64SRA(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64SRAI(v *Value) bool { + v_0 := v.Args[0] + // match: (SRAI [x] (MOVDconst [y])) + // result: (MOVDconst [int64(y) >> x]) + for { + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVDconst { + break + } + y := auxIntToInt64(v_0.AuxInt) + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(int64(y) >> x) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64SRL(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] @@ -4698,6 +4902,22 @@ func rewriteValueRISCV64_OpRISCV64SRL(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64SRLI(v *Value) bool { + v_0 := v.Args[0] + // match: (SRLI [x] (MOVDconst [y])) + // result: (MOVDconst [int64(uint64(y) >> x)]) + for { + x := auxIntToInt64(v.AuxInt) + if v_0.Op != OpRISCV64MOVDconst { + break + } + y := auxIntToInt64(v_0.AuxInt) + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(int64(uint64(y) >> x)) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64SUB(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] diff --git a/test/codegen/shift.go b/test/codegen/shift.go index 8e87e96c9e..b3ed69d9e3 100644 --- a/test/codegen/shift.go +++ b/test/codegen/shift.go @@ -11,47 +11,47 @@ package codegen // ------------------ // func lshConst64x64(v int64) int64 { - // riscv64:"SLL","AND","SLTIU" + // riscv64:"SLLI",-"AND",-"SLTIU" return v << uint64(33) } func rshConst64Ux64(v uint64) uint64 { - // riscv64:"SRL","AND","SLTIU" + // riscv64:"SRLI",-"AND",-"SLTIU" return v >> uint64(33) } func rshConst64x64(v int64) int64 { - // riscv64:"SRA","OR","SLTIU" + // riscv64:"SRAI",-"OR",-"SLTIU" return v >> uint64(33) } func lshConst32x64(v int32) int32 { - // riscv64:"SLL","AND","SLTIU" + // riscv64:"SLLI",-"AND",-"SLTIU" return v << uint64(29) } func rshConst32Ux64(v uint32) uint32 { - // riscv64:"SRL","AND","SLTIU" + // riscv64:"SRLI",-"AND",-"SLTIU" return v >> uint64(29) } func rshConst32x64(v int32) int32 { - // riscv64:"SRA","OR","SLTIU" + // riscv64:"SRAI",-"OR",-"SLTIU" return v >> uint64(29) } func lshConst64x32(v int64) int64 { - // riscv64:"SLL","AND","SLTIU" + // riscv64:"SLLI",-"AND",-"SLTIU" return v << uint32(33) } func rshConst64Ux32(v uint64) uint64 { - // riscv64:"SRL","AND","SLTIU" + // riscv64:"SRLI",-"AND",-"SLTIU" return v >> uint32(33) } func rshConst64x32(v int64) int64 { - // riscv64:"SRA","OR","SLTIU" + // riscv64:"SRAI",-"OR",-"SLTIU" return v >> uint32(33) }