diff --git a/src/cmd/compile/internal/ssa/gen/RISCV64.rules b/src/cmd/compile/internal/ssa/gen/RISCV64.rules index 5bc47ee1cc1..9d2d785d0ea 100644 --- a/src/cmd/compile/internal/ssa/gen/RISCV64.rules +++ b/src/cmd/compile/internal/ssa/gen/RISCV64.rules @@ -52,6 +52,10 @@ (Hmul32 x y) => (SRAI [32] (MUL (SignExt32to64 x) (SignExt32to64 y))) (Hmul32u x y) => (SRLI [32] (MUL (ZeroExt32to64 x) (ZeroExt32to64 y))) +(Select0 (Add64carry x y c)) => (ADD (ADD x y) c) +(Select1 (Add64carry x y c)) => + (OR (SLTU s:(ADD x y) x) (SLTU (ADD s c) s)) + // (x + y) / 2 => (x / 2) + (y / 2) + (x & y & 1) (Avg64u x y) => (ADD (ADD (SRLI [1] x) (SRLI [1] y)) (ANDI [1] (AND x y))) @@ -743,6 +747,9 @@ (SLTI [x] (MOVDconst [y])) => (MOVDconst [b2i(int64(y) < int64(x))]) (SLTIU [x] (MOVDconst [y])) => (MOVDconst [b2i(uint64(y) < uint64(x))]) +(SLT x x) => (MOVDconst [0]) +(SLTU x x) => (MOVDconst [0]) + // deadcode for LoweredMuluhilo (Select0 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MULHU x y) (Select1 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MUL x y) diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go index 9253d2d7296..e4e4003f34e 100644 --- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go +++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go @@ -509,10 +509,14 @@ func rewriteValueRISCV64(v *Value) bool { return rewriteValueRISCV64_OpRISCV64SLL(v) case OpRISCV64SLLI: return rewriteValueRISCV64_OpRISCV64SLLI(v) + case OpRISCV64SLT: + return rewriteValueRISCV64_OpRISCV64SLT(v) case OpRISCV64SLTI: return rewriteValueRISCV64_OpRISCV64SLTI(v) case OpRISCV64SLTIU: return rewriteValueRISCV64_OpRISCV64SLTIU(v) + case OpRISCV64SLTU: + return rewriteValueRISCV64_OpRISCV64SLTU(v) case OpRISCV64SRA: return rewriteValueRISCV64_OpRISCV64SRA(v) case OpRISCV64SRAI: @@ -4864,6 +4868,22 @@ func rewriteValueRISCV64_OpRISCV64SLLI(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64SLT(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (SLT x x) + // result: (MOVDconst [0]) + for { + x := v_0 + if x != v_1 { + break + } + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(0) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64SLTI(v *Value) bool { v_0 := v.Args[0] // match: (SLTI [x] (MOVDconst [y])) @@ -4896,6 +4916,22 @@ func rewriteValueRISCV64_OpRISCV64SLTIU(v *Value) bool { } return false } +func rewriteValueRISCV64_OpRISCV64SLTU(v *Value) bool { + v_1 := v.Args[1] + v_0 := v.Args[0] + // match: (SLTU x x) + // result: (MOVDconst [0]) + for { + x := v_0 + if x != v_1 { + break + } + v.reset(OpRISCV64MOVDconst) + v.AuxInt = int64ToAuxInt(0) + return true + } + return false +} func rewriteValueRISCV64_OpRISCV64SRA(v *Value) bool { v_1 := v.Args[1] v_0 := v.Args[0] @@ -6036,6 +6072,23 @@ func rewriteValueRISCV64_OpRsh8x8(v *Value) bool { } func rewriteValueRISCV64_OpSelect0(v *Value) bool { v_0 := v.Args[0] + b := v.Block + typ := &b.Func.Config.Types + // match: (Select0 (Add64carry x y c)) + // result: (ADD (ADD x y) c) + for { + if v_0.Op != OpAdd64carry { + break + } + c := v_0.Args[2] + x := v_0.Args[0] + y := v_0.Args[1] + v.reset(OpRISCV64ADD) + v0 := b.NewValue0(v.Pos, OpRISCV64ADD, typ.UInt64) + v0.AddArg2(x, y) + v.AddArg2(v0, c) + return true + } // match: (Select0 m:(LoweredMuluhilo x y)) // cond: m.Uses == 1 // result: (MULHU x y) @@ -6057,6 +6110,29 @@ func rewriteValueRISCV64_OpSelect0(v *Value) bool { } func rewriteValueRISCV64_OpSelect1(v *Value) bool { v_0 := v.Args[0] + b := v.Block + typ := &b.Func.Config.Types + // match: (Select1 (Add64carry x y c)) + // result: (OR (SLTU s:(ADD x y) x) (SLTU (ADD s c) s)) + for { + if v_0.Op != OpAdd64carry { + break + } + c := v_0.Args[2] + x := v_0.Args[0] + y := v_0.Args[1] + v.reset(OpRISCV64OR) + v0 := b.NewValue0(v.Pos, OpRISCV64SLTU, typ.UInt64) + s := b.NewValue0(v.Pos, OpRISCV64ADD, typ.UInt64) + s.AddArg2(x, y) + v0.AddArg2(s, x) + v2 := b.NewValue0(v.Pos, OpRISCV64SLTU, typ.UInt64) + v3 := b.NewValue0(v.Pos, OpRISCV64ADD, typ.UInt64) + v3.AddArg2(s, c) + v2.AddArg2(v3, s) + v.AddArg2(v0, v2) + return true + } // match: (Select1 m:(LoweredMuluhilo x y)) // cond: m.Uses == 1 // result: (MUL x y) diff --git a/src/cmd/compile/internal/ssagen/ssa.go b/src/cmd/compile/internal/ssagen/ssa.go index dda813518a5..107944170fc 100644 --- a/src/cmd/compile/internal/ssagen/ssa.go +++ b/src/cmd/compile/internal/ssagen/ssa.go @@ -4726,8 +4726,8 @@ func InitTables() { func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return s.newValue3(ssa.OpAdd64carry, types.NewTuple(types.Types[types.TUINT64], types.Types[types.TUINT64]), args[0], args[1], args[2]) }, - sys.AMD64, sys.ARM64, sys.PPC64, sys.S390X) - alias("math/bits", "Add", "math/bits", "Add64", sys.ArchAMD64, sys.ArchARM64, sys.ArchPPC64, sys.ArchPPC64LE, sys.ArchS390X) + sys.AMD64, sys.ARM64, sys.PPC64, sys.S390X, sys.RISCV64) + alias("math/bits", "Add", "math/bits", "Add64", sys.ArchAMD64, sys.ArchARM64, sys.ArchPPC64, sys.ArchPPC64LE, sys.ArchS390X, sys.ArchRISCV64) addF("math/bits", "Sub64", func(s *state, n *ir.CallExpr, args []*ssa.Value) *ssa.Value { return s.newValue3(ssa.OpSub64borrow, types.NewTuple(types.Types[types.TUINT64], types.Types[types.TUINT64]), args[0], args[1], args[2]) diff --git a/test/codegen/mathbits.go b/test/codegen/mathbits.go index a507d32843d..f36916ad03a 100644 --- a/test/codegen/mathbits.go +++ b/test/codegen/mathbits.go @@ -442,6 +442,7 @@ func Add(x, y, ci uint) (r, co uint) { // ppc64: "ADDC", "ADDE", "ADDZE" // ppc64le: "ADDC", "ADDE", "ADDZE" // s390x:"ADDE","ADDC\t[$]-1," + // riscv64: "ADD","SLTU" return bits.Add(x, y, ci) } @@ -451,6 +452,7 @@ func AddC(x, ci uint) (r, co uint) { // ppc64: "ADDC", "ADDE", "ADDZE" // ppc64le: "ADDC", "ADDE", "ADDZE" // s390x:"ADDE","ADDC\t[$]-1," + // riscv64: "ADD","SLTU" return bits.Add(x, 7, ci) } @@ -460,6 +462,7 @@ func AddZ(x, y uint) (r, co uint) { // ppc64: "ADDC", -"ADDE", "ADDZE" // ppc64le: "ADDC", -"ADDE", "ADDZE" // s390x:"ADDC",-"ADDC\t[$]-1," + // riscv64: "ADD","SLTU" return bits.Add(x, y, 0) } @@ -469,6 +472,7 @@ func AddR(x, y, ci uint) uint { // ppc64: "ADDC", "ADDE", -"ADDZE" // ppc64le: "ADDC", "ADDE", -"ADDZE" // s390x:"ADDE","ADDC\t[$]-1," + // riscv64: "ADD",-"SLTU" r, _ := bits.Add(x, y, ci) return r } @@ -489,6 +493,7 @@ func Add64(x, y, ci uint64) (r, co uint64) { // ppc64: "ADDC", "ADDE", "ADDZE" // ppc64le: "ADDC", "ADDE", "ADDZE" // s390x:"ADDE","ADDC\t[$]-1," + // riscv64: "ADD","SLTU" return bits.Add64(x, y, ci) } @@ -498,6 +503,7 @@ func Add64C(x, ci uint64) (r, co uint64) { // ppc64: "ADDC", "ADDE", "ADDZE" // ppc64le: "ADDC", "ADDE", "ADDZE" // s390x:"ADDE","ADDC\t[$]-1," + // riscv64: "ADD","SLTU" return bits.Add64(x, 7, ci) } @@ -507,6 +513,7 @@ func Add64Z(x, y uint64) (r, co uint64) { // ppc64: "ADDC", -"ADDE", "ADDZE" // ppc64le: "ADDC", -"ADDE", "ADDZE" // s390x:"ADDC",-"ADDC\t[$]-1," + // riscv64: "ADD","SLTU" return bits.Add64(x, y, 0) } @@ -516,6 +523,7 @@ func Add64R(x, y, ci uint64) uint64 { // ppc64: "ADDC", "ADDE", -"ADDZE" // ppc64le: "ADDC", "ADDE", -"ADDZE" // s390x:"ADDE","ADDC\t[$]-1," + // riscv64: "ADD",-"SLTU" r, _ := bits.Add64(x, y, ci) return r }