From fb165eaffd1949aa7e0af75af5e3cc27c6e50508 Mon Sep 17 00:00:00 2001 From: Cholerae Hu Date: Thu, 17 Aug 2017 15:06:42 +0800 Subject: [PATCH] cmd/compile: combine x*n - y*n into (x-y)*n Do the similar thing to CL 55143 to reduce IMUL. Change-Id: I1bd38f618058e3cd74fac181f003610ea13f2294 Reviewed-on: https://go-review.googlesource.com/56252 Run-TryBot: Emmanuel Odeke TryBot-Result: Gobot Gobot Reviewed-by: Keith Randall --- src/cmd/compile/internal/gc/asm_test.go | 28 + .../compile/internal/ssa/gen/generic.rules | 6 + .../compile/internal/ssa/rewritegeneric.go | 520 +++++++++++++++++- test/mergemul.go | 48 +- 4 files changed, 576 insertions(+), 26 deletions(-) diff --git a/src/cmd/compile/internal/gc/asm_test.go b/src/cmd/compile/internal/gc/asm_test.go index 3bf8dcd42d..0445caba66 100644 --- a/src/cmd/compile/internal/gc/asm_test.go +++ b/src/cmd/compile/internal/gc/asm_test.go @@ -831,6 +831,20 @@ var linuxAMD64Tests = []*asmTest{ }`, pos: []string{"\tADDQ\t[$]19", "\tIMULQ"}, // (a+19)*n }, + { + ` + func mul4(n int) int { + return 23*n - 9*n + }`, + []string{"\tIMULQ\t[$]14"}, // 14*n + }, + { + ` + func mul5(a, n int) int { + return a*n - 19*n + }`, + []string{"\tADDQ\t[$]-19", "\tIMULQ"}, // (a-19)*n + }, // see issue 19595. // We want to merge load+op in f58, but not in f59. @@ -1150,6 +1164,20 @@ var linux386Tests = []*asmTest{ `, pos: []string{"TEXT\t.*, [$]0-4"}, }, + { + ` + func mul3(n int) int { + return 23*n - 9*n + }`, + []string{"\tIMULL\t[$]14"}, // 14*n + }, + { + ` + func mul4(a, n int) int { + return n*a - a*19 + }`, + []string{"\tADDL\t[$]-19", "\tIMULL"}, // (n-19)*a + }, } var linuxS390XTests = []*asmTest{ diff --git a/src/cmd/compile/internal/ssa/gen/generic.rules b/src/cmd/compile/internal/ssa/gen/generic.rules index 1faf6b3857..dd4018abe2 100644 --- a/src/cmd/compile/internal/ssa/gen/generic.rules +++ b/src/cmd/compile/internal/ssa/gen/generic.rules @@ -328,6 +328,12 @@ (Add16 (Mul16 x y) (Mul16 x z)) -> (Mul16 x (Add16 y z)) (Add8 (Mul8 x y) (Mul8 x z)) -> (Mul8 x (Add8 y z)) +// Rewrite x*y - x*z to x*(y-z) +(Sub64 (Mul64 x y) (Mul64 x z)) -> (Mul64 x (Sub64 y z)) +(Sub32 (Mul32 x y) (Mul32 x z)) -> (Mul32 x (Sub32 y z)) +(Sub16 (Mul16 x y) (Mul16 x z)) -> (Mul16 x (Sub16 y z)) +(Sub8 (Mul8 x y) (Mul8 x z)) -> (Mul8 x (Sub8 y z)) + // rewrite shifts of 8/16/32 bit consts into 64 bit consts to reduce // the number of the other rewrite rules for const shifts (Lsh64x32 x (Const32 [c])) -> (Lsh64x64 x (Const64 [int64(uint32(c))])) diff --git a/src/cmd/compile/internal/ssa/rewritegeneric.go b/src/cmd/compile/internal/ssa/rewritegeneric.go index fdd4c1e167..8310844287 100644 --- a/src/cmd/compile/internal/ssa/rewritegeneric.go +++ b/src/cmd/compile/internal/ssa/rewritegeneric.go @@ -23796,6 +23796,126 @@ func rewriteValuegeneric_OpSub16_0(v *Value) bool { v.AddArg(x) return true } + // match: (Sub16 (Mul16 x y) (Mul16 x z)) + // cond: + // result: (Mul16 x (Sub16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub16 (Mul16 y x) (Mul16 x z)) + // cond: + // result: (Mul16 x (Sub16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub16 (Mul16 x y) (Mul16 z x)) + // cond: + // result: (Mul16 x (Sub16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub16 (Mul16 y x) (Mul16 z x)) + // cond: + // result: (Mul16 x (Sub16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } // match: (Sub16 x x) // cond: // result: (Const16 [0]) @@ -23869,6 +23989,11 @@ func rewriteValuegeneric_OpSub16_0(v *Value) bool { v.AddArg(x) return true } + return false +} +func rewriteValuegeneric_OpSub16_10(v *Value) bool { + b := v.Block + _ = b // match: (Sub16 (Add16 y x) y) // cond: // result: x @@ -23977,11 +24102,6 @@ func rewriteValuegeneric_OpSub16_0(v *Value) bool { v.AddArg(x) return true } - return false -} -func rewriteValuegeneric_OpSub16_10(v *Value) bool { - b := v.Block - _ = b // match: (Sub16 (Const16 [c]) (Sub16 (Const16 [d]) x)) // cond: // result: (Add16 (Const16 [int64(int16(c-d))]) x) @@ -24060,6 +24180,126 @@ func rewriteValuegeneric_OpSub32_0(v *Value) bool { v.AddArg(x) return true } + // match: (Sub32 (Mul32 x y) (Mul32 x z)) + // cond: + // result: (Mul32 x (Sub32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub32 (Mul32 y x) (Mul32 x z)) + // cond: + // result: (Mul32 x (Sub32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub32 (Mul32 x y) (Mul32 z x)) + // cond: + // result: (Mul32 x (Sub32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub32 (Mul32 y x) (Mul32 z x)) + // cond: + // result: (Mul32 x (Sub32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } // match: (Sub32 x x) // cond: // result: (Const32 [0]) @@ -24133,6 +24373,11 @@ func rewriteValuegeneric_OpSub32_0(v *Value) bool { v.AddArg(x) return true } + return false +} +func rewriteValuegeneric_OpSub32_10(v *Value) bool { + b := v.Block + _ = b // match: (Sub32 (Add32 y x) y) // cond: // result: x @@ -24241,11 +24486,6 @@ func rewriteValuegeneric_OpSub32_0(v *Value) bool { v.AddArg(x) return true } - return false -} -func rewriteValuegeneric_OpSub32_10(v *Value) bool { - b := v.Block - _ = b // match: (Sub32 (Const32 [c]) (Sub32 (Const32 [d]) x)) // cond: // result: (Add32 (Const32 [int64(int32(c-d))]) x) @@ -24364,6 +24604,126 @@ func rewriteValuegeneric_OpSub64_0(v *Value) bool { v.AddArg(x) return true } + // match: (Sub64 (Mul64 x y) (Mul64 x z)) + // cond: + // result: (Mul64 x (Sub64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub64 (Mul64 y x) (Mul64 x z)) + // cond: + // result: (Mul64 x (Sub64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub64 (Mul64 x y) (Mul64 z x)) + // cond: + // result: (Mul64 x (Sub64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub64 (Mul64 y x) (Mul64 z x)) + // cond: + // result: (Mul64 x (Sub64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } // match: (Sub64 x x) // cond: // result: (Const64 [0]) @@ -24437,6 +24797,11 @@ func rewriteValuegeneric_OpSub64_0(v *Value) bool { v.AddArg(x) return true } + return false +} +func rewriteValuegeneric_OpSub64_10(v *Value) bool { + b := v.Block + _ = b // match: (Sub64 (Add64 y x) y) // cond: // result: x @@ -24545,11 +24910,6 @@ func rewriteValuegeneric_OpSub64_0(v *Value) bool { v.AddArg(x) return true } - return false -} -func rewriteValuegeneric_OpSub64_10(v *Value) bool { - b := v.Block - _ = b // match: (Sub64 (Const64 [c]) (Sub64 (Const64 [d]) x)) // cond: // result: (Add64 (Const64 [c-d]) x) @@ -24668,6 +25028,126 @@ func rewriteValuegeneric_OpSub8_0(v *Value) bool { v.AddArg(x) return true } + // match: (Sub8 (Mul8 x y) (Mul8 x z)) + // cond: + // result: (Mul8 x (Sub8 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub8 (Mul8 y x) (Mul8 x z)) + // cond: + // result: (Mul8 x (Sub8 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub8 (Mul8 x y) (Mul8 z x)) + // cond: + // result: (Mul8 x (Sub8 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Sub8 (Mul8 y x) (Mul8 z x)) + // cond: + // result: (Mul8 x (Sub8 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpSub8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } // match: (Sub8 x x) // cond: // result: (Const8 [0]) @@ -24741,6 +25221,11 @@ func rewriteValuegeneric_OpSub8_0(v *Value) bool { v.AddArg(x) return true } + return false +} +func rewriteValuegeneric_OpSub8_10(v *Value) bool { + b := v.Block + _ = b // match: (Sub8 (Add8 y x) y) // cond: // result: x @@ -24849,11 +25334,6 @@ func rewriteValuegeneric_OpSub8_0(v *Value) bool { v.AddArg(x) return true } - return false -} -func rewriteValuegeneric_OpSub8_10(v *Value) bool { - b := v.Block - _ = b // match: (Sub8 (Const8 [c]) (Sub8 (Const8 [d]) x)) // cond: // result: (Add8 (Const8 [int64(int8(c-d))]) x) diff --git a/test/mergemul.go b/test/mergemul.go index 86fbe676cb..a23115b612 100644 --- a/test/mergemul.go +++ b/test/mergemul.go @@ -24,7 +24,7 @@ import "fmt" // if a8 != b8 { // // print error msg and panic // } -func makeMergeTest(m1, m2, k int, size string) string { +func makeMergeAddTest(m1, m2, k int, size string) string { model := " a" + size + ", b" + size model += fmt.Sprintf(" = %%d*n%s + %%d*(n%s+%%d), (%%d+%%d)*n%s + (%%d*%%d)", size, size, size) @@ -32,7 +32,39 @@ func makeMergeTest(m1, m2, k int, size string) string { test := fmt.Sprintf(model, m1, m2, k, m1, m2, m2, k) test += fmt.Sprintf(` if a%s != b%s { - fmt.Printf("MergeTest(%d, %d, %d, %s) failed\n") + fmt.Printf("MergeAddTest(%d, %d, %d, %s) failed\n") + fmt.Printf("%%d != %%d\n", a%s, b%s) + panic("FAIL") + } +`, size, size, m1, m2, k, size, size, size) + return test + "\n" +} + +// Check that expressions like (c*n - d*(n+k)) get correctly merged by +// the compiler into (c-d)*n - d*k (with c-d and d*k computed at +// compile time). +// +// The merging is performed by a combination of the multiplication +// merge rules +// (c*n - d*n) -> (c-d)*n +// and the distributive multiplication rules +// c * (d-x) -> c*d - c*x + +// Generate a MergeTest that looks like this: +// +// a8, b8 = m1*n8 - m2*(n8+k), (m1-m2)*n8 - m2*k +// if a8 != b8 { +// // print error msg and panic +// } +func makeMergeSubTest(m1, m2, k int, size string) string { + + model := " a" + size + ", b" + size + model += fmt.Sprintf(" = %%d*n%s - %%d*(n%s+%%d), (%%d-%%d)*n%s - (%%d*%%d)", size, size, size) + + test := fmt.Sprintf(model, m1, m2, k, m1, m2, m2, k) + test += fmt.Sprintf(` + if a%s != b%s { + fmt.Printf("MergeSubTest(%d, %d, %d, %s) failed\n") fmt.Printf("%%d != %%d\n", a%s, b%s) panic("FAIL") } @@ -42,10 +74,14 @@ func makeMergeTest(m1, m2, k int, size string) string { func makeAllSizes(m1, m2, k int) string { var tests string - tests += makeMergeTest(m1, m2, k, "8") - tests += makeMergeTest(m1, m2, k, "16") - tests += makeMergeTest(m1, m2, k, "32") - tests += makeMergeTest(m1, m2, k, "64") + tests += makeMergeAddTest(m1, m2, k, "8") + tests += makeMergeAddTest(m1, m2, k, "16") + tests += makeMergeAddTest(m1, m2, k, "32") + tests += makeMergeAddTest(m1, m2, k, "64") + tests += makeMergeSubTest(m1, m2, k, "8") + tests += makeMergeSubTest(m1, m2, k, "16") + tests += makeMergeSubTest(m1, m2, k, "32") + tests += makeMergeSubTest(m1, m2, k, "64") tests += "\n" return tests }