diff --git a/src/cmd/compile/internal/riscv64/ssa.go b/src/cmd/compile/internal/riscv64/ssa.go
index 2eb1e7ffa01..143e7c525a3 100644
--- a/src/cmd/compile/internal/riscv64/ssa.go
+++ b/src/cmd/compile/internal/riscv64/ssa.go
@@ -694,6 +694,9 @@ func ssaGenValue(s *ssagen.State, v *ssa.Value) {
 		p.To.Sym = ir.Syms.Duffcopy
 		p.To.Offset = v.AuxInt

+	case ssa.OpRISCV64LoweredRound32F, ssa.OpRISCV64LoweredRound64F:
+		// input is already rounded
+
 	case ssa.OpClobber, ssa.OpClobberReg:
 		// TODO: implement for clobberdead experiment. Nop is ok for now.

diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
index 181b46a7ced..ac68dfed76e 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64.rules
@@ -103,7 +103,7 @@

 (CvtBoolToUint8 ...) => (Copy ...)

-(Round(64|32)F ...) => (Copy ...)
+(Round(32|64)F ...) => (LoweredRound(32|64)F ...)

 (Slicemask x) => (SRAI [63] (NEG x))

@@ -780,6 +780,9 @@
 (Select0 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MULHU x y)
 (Select1 m:(LoweredMuluhilo x y)) && m.Uses == 1 => (MUL x y)

+(FADDD a (FMULD x y)) && a.Block.Func.useFMA(v) => (FMADDD x y a)
+(FSUBD a (FMULD x y)) && a.Block.Func.useFMA(v) => (FNMSUBD x y a)
+(FSUBD (FMULD x y) a) && a.Block.Func.useFMA(v) => (FMSUBD x y a)
 // Merge negation into fused multiply-add and multiply-subtract.
 //
 // Key:
diff --git a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
index 52e87cbe72c..69f2950a88f 100644
--- a/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
+++ b/src/cmd/compile/internal/ssa/_gen/RISCV64Ops.go
@@ -237,6 +237,10 @@ func init() {
 		// gets correctly ordered with respect to GC safepoints.
 		{name: "MOVconvert", argLength: 2, reg: gp11, asm: "MOV"}, // arg0, but converted to int/ptr as appropriate; arg1=mem

+		// Round ops to block fused-multiply-add extraction.
+		{name: "LoweredRound32F", argLength: 1, reg: fp11, resultInArg0: true},
+		{name: "LoweredRound64F", argLength: 1, reg: fp11, resultInArg0: true},
+
 		// Calls
 		{name: "CALLstatic", argLength: -1, reg: call, aux: "CallOff", call: true},               // call static function aux.(*gc.Sym). last arg=mem, auxint=argsize, returns mem
 		{name: "CALLtail", argLength: -1, reg: call, aux: "CallOff", call: true, tailCall: true}, // tail call static function aux.(*gc.Sym). last arg=mem, auxint=argsize, returns mem
diff --git a/src/cmd/compile/internal/ssa/opGen.go b/src/cmd/compile/internal/ssa/opGen.go
index b599df05259..12d8214ae1d 100644
--- a/src/cmd/compile/internal/ssa/opGen.go
+++ b/src/cmd/compile/internal/ssa/opGen.go
@@ -2400,6 +2400,8 @@ const (
 	OpRISCV64SLTU
 	OpRISCV64SLTIU
 	OpRISCV64MOVconvert
+	OpRISCV64LoweredRound32F
+	OpRISCV64LoweredRound64F
 	OpRISCV64CALLstatic
 	OpRISCV64CALLtail
 	OpRISCV64CALLclosure
@@ -32198,6 +32200,32 @@
 			},
 		},
 	},
+	{
+		name:         "LoweredRound32F",
+		argLen:       1,
+		resultInArg0: true,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+			},
+			outputs: []outputInfo{
+				{0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+			},
+		},
+	},
+	{
+		name:         "LoweredRound64F",
+		argLen:       1,
+		resultInArg0: true,
+		reg: regInfo{
+			inputs: []inputInfo{
+				{0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+			},
+			outputs: []outputInfo{
+				{0, 9223372034707292160}, // F0 F1 F2 F3 F4 F5 F6 F7 F8 F9 F10 F11 F12 F13 F14 F15 F16 F17 F18 F19 F20 F21 F22 F23 F24 F25 F26 F27 F28 F29 F30 F31
+			},
+		},
+	},
 	{
 		name:    "CALLstatic",
 		auxType: auxCallOff,
diff --git a/src/cmd/compile/internal/ssa/rewriteRISCV64.go b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
index e8002599efe..17af023db34 100644
--- a/src/cmd/compile/internal/ssa/rewriteRISCV64.go
+++ b/src/cmd/compile/internal/ssa/rewriteRISCV64.go
@@ -440,6 +440,8 @@ func rewriteValueRISCV64(v *Value) bool {
 		return rewriteValueRISCV64_OpRISCV64AND(v)
 	case OpRISCV64ANDI:
 		return rewriteValueRISCV64_OpRISCV64ANDI(v)
+	case OpRISCV64FADDD:
+		return rewriteValueRISCV64_OpRISCV64FADDD(v)
 	case OpRISCV64FMADDD:
 		return rewriteValueRISCV64_OpRISCV64FMADDD(v)
 	case OpRISCV64FMSUBD:
@@ -448,6 +450,8 @@
 		return rewriteValueRISCV64_OpRISCV64FMSUBD(v)
 	case OpRISCV64FNMADDD:
 		return rewriteValueRISCV64_OpRISCV64FNMADDD(v)
 	case OpRISCV64FNMSUBD:
 		return rewriteValueRISCV64_OpRISCV64FNMSUBD(v)
+	case OpRISCV64FSUBD:
+		return rewriteValueRISCV64_OpRISCV64FSUBD(v)
 	case OpRISCV64MOVBUload:
 		return rewriteValueRISCV64_OpRISCV64MOVBUload(v)
 	case OpRISCV64MOVBUreg:
 		return rewriteValueRISCV64_OpRISCV64MOVBUreg(v)
@@ -541,10 +545,10 @@
 	case OpRotateLeft8:
 		return rewriteValueRISCV64_OpRotateLeft8(v)
 	case OpRound32F:
-		v.Op = OpCopy
+		v.Op = OpRISCV64LoweredRound32F
 		return true
 	case OpRound64F:
-		v.Op = OpCopy
+		v.Op = OpRISCV64LoweredRound64F
 		return true
 	case OpRsh16Ux16:
 		return rewriteValueRISCV64_OpRsh16Ux16(v)
@@ -3335,6 +3339,31 @@
 	}
 	return false
 }
+func rewriteValueRISCV64_OpRISCV64FADDD(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (FADDD a (FMULD x y))
+	// cond: a.Block.Func.useFMA(v)
+	// result: (FMADDD x y a)
+	for {
+		for _i0 := 0; _i0 <= 1; _i0, v_0, v_1 = _i0+1, v_1, v_0 {
+			a := v_0
+			if v_1.Op != OpRISCV64FMULD {
+				continue
+			}
+			y := v_1.Args[1]
+			x := v_1.Args[0]
+			if !(a.Block.Func.useFMA(v)) {
+				continue
+			}
+			v.reset(OpRISCV64FMADDD)
+			v.AddArg3(x, y, a)
+			return true
+		}
+		break
+	}
+	return false
+}
 func rewriteValueRISCV64_OpRISCV64FMADDD(v *Value) bool {
 	v_2 := v.Args[2]
 	v_1 := v.Args[1]
@@ -3515,6 +3544,45 @@
 	}
 	return false
 }
+func rewriteValueRISCV64_OpRISCV64FSUBD(v *Value) bool {
+	v_1 := v.Args[1]
+	v_0 := v.Args[0]
+	// match: (FSUBD a (FMULD x y))
+	// cond: a.Block.Func.useFMA(v)
+	// result: (FNMSUBD x y a)
+	for {
+		a := v_0
+		if v_1.Op != OpRISCV64FMULD {
+			break
+		}
+		y := v_1.Args[1]
+		x := v_1.Args[0]
+		if !(a.Block.Func.useFMA(v)) {
+			break
+		}
+		v.reset(OpRISCV64FNMSUBD)
+		v.AddArg3(x, y, a)
+		return true
+	}
+	// match: (FSUBD (FMULD x y) a)
+	// cond: a.Block.Func.useFMA(v)
+	// result: (FMSUBD x y a)
+	for {
+		if v_0.Op != OpRISCV64FMULD {
+			break
+		}
+		y := v_0.Args[1]
+		x := v_0.Args[0]
+		a := v_1
+		if !(a.Block.Func.useFMA(v)) {
+			break
+		}
+		v.reset(OpRISCV64FMSUBD)
+		v.AddArg3(x, y, a)
+		return true
+	}
+	return false
+}
 func rewriteValueRISCV64_OpRISCV64MOVBUload(v *Value) bool {
 	v_1 := v.Args[1]
 	v_0 := v.Args[0]
diff --git a/test/codegen/floats.go b/test/codegen/floats.go
index 9cb62e031a7..1c5fc8a31a9 100644
--- a/test/codegen/floats.go
+++ b/test/codegen/floats.go
@@ -88,17 +88,20 @@
 func FusedAdd64(x, y, z float64) float64 {
 	// s390x:"FMADD\t"
 	// ppc64x:"FMADD\t"
 	// arm64:"FMADDD"
+	// riscv64:"FMADDD\t"
 	return x*y + z
 }

 func FusedSub64_a(x, y, z float64) float64 {
 	// s390x:"FMSUB\t"
 	// ppc64x:"FMSUB\t"
+	// riscv64:"FMSUBD\t"
 	return x*y - z
 }

 func FusedSub64_b(x, y, z float64) float64 {
 	// arm64:"FMSUBD"
+	// riscv64:"FNMSUBD\t"
 	return z - x*y
 }
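For illustration, here is the behaviour the patch aims for, as a minimal sketch (the function names below are hypothetical, not part of the patch): per the Go spec, x*y + z may be evaluated as a single fused multiply-add, but an explicit floating-point conversion rounds its operand and prevents fusion. Lowering Round(32|64)F to the opaque LoweredRound(32|64)F ops, rather than to Copy, is what keeps the new FADDD/FSUBD rules from seeing through such a conversion.

	package demo

	// May be fused into a single FMADDD on riscv64 (subject to the
	// compiler's useFMA check): the FMULD feeds the FADDD directly,
	// so the (FADDD a (FMULD x y)) rule matches and the product is
	// rounded only once.
	func fused(x, y, z float64) float64 {
		return x*y + z
	}

	// Must not be fused: float64(x*y) explicitly rounds the product
	// (SSA Round64F, now lowered to LoweredRound64F), so the addition
	// sees a LoweredRound64F value rather than an FMULD and the rule
	// cannot match.
	func notFused(x, y, z float64) float64 {
		return float64(x*y) + z
	}

Two design points are visible in the generated code: FADDD is commutative, so rewriteValueRISCV64_OpRISCV64FADDD tries both argument orders via the _i0 swap loop, while FSUBD is not and needs two separate patterns, FNMSUBD for a - x*y (computed as -(x*y) + a) and FMSUBD for x*y - a. The test/codegen hunks are asmcheck annotations: each // riscv64:"..." comment is a pattern that must match the assembly generated for the enclosing function on that GOARCH, so they double as a regression test that the new rules actually fire.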