From a0453a180fc3555843185385e9d4ad9d57f1d36a Mon Sep 17 00:00:00 2001
From: Alberto Donizetti
Date: Mon, 14 Aug 2017 11:44:09 +0200
Subject: [PATCH] cmd/compile: combine x*n + y*n into (x+y)*n

There are a few cases where this can be useful. Apart from the obvious
(and silly) 100*n + 200*n, where we generate one IMUL instead of two,
consider:

    15*n + 31*n

Currently, the compiler strength-reduces both imuls, generating:

    0x0000 00000 MOVQ "".n+8(SP), AX
    0x0005 00005 MOVQ AX, CX
    0x0008 00008 SHLQ $4, AX
    0x000c 00012 SUBQ CX, AX
    0x000f 00015 MOVQ CX, DX
    0x0012 00018 SHLQ $5, CX
    0x0016 00022 SUBQ DX, CX
    0x0019 00025 ADDQ CX, AX
    0x001c 00028 MOVQ AX, "".~r1+16(SP)
    0x0021 00033 RET

But combining the imuls is both faster and shorter, even without
strength reduction:

    0x0000 00000 MOVQ "".n+8(SP), AX
    0x0005 00005 IMULQ $46, AX
    0x0009 00009 MOVQ AX, "".~r1+16(SP)
    0x000e 00014 RET

Moreover, consider:

    5*n + 7*(n+1) + 11*(n+2)

We already have a rule that rewrites 7*(n+1) into 7*n+7, so the
generated code (without imul merging) looks like this:

    0x0000 00000 MOVQ "".n+8(SP), AX
    0x0005 00005 LEAQ (AX)(AX*4), CX
    0x0009 00009 MOVQ AX, DX
    0x000c 00012 NEGQ AX
    0x000f 00015 LEAQ (AX)(DX*8), AX
    0x0013 00019 ADDQ CX, AX
    0x0016 00022 LEAQ (DX)(CX*2), CX
    0x001a 00026 LEAQ 29(AX)(CX*1), AX
    0x001f 00031 MOVQ AX, "".~r1+16(SP)

But with imul merging, the 5*n, 7*n and 11*n factors are merged, and
the generated code looks like this:

    0x0000 00000 MOVQ "".n+8(SP), AX
    0x0005 00005 IMULQ $23, AX
    0x0009 00009 ADDQ $29, AX
    0x000d 00013 MOVQ AX, "".~r1+16(SP)
    0x0012 00018 RET

This is both faster and shorter; it is also exactly the code that clang
and the Intel C compiler generate for the above expression.

Change-Id: Ib4d5503f05d2f2efe31a1be14e2fe6cac33730a9
Reviewed-on: https://go-review.googlesource.com/55143
Reviewed-by: Keith Randall
---
 src/cmd/compile/internal/gc/asm_test.go        |   38 +
 .../compile/internal/ssa/gen/generic.rules     |    6 +
 .../compile/internal/ssa/rewritegeneric.go     | 1068 ++++++++++++++++-
 test/mergemul.go                               |   81 ++
 4 files changed, 1149 insertions(+), 44 deletions(-)
 create mode 100644 test/mergemul.go

diff --git a/src/cmd/compile/internal/gc/asm_test.go b/src/cmd/compile/internal/gc/asm_test.go
index 08ec638f44e..23b70ae41d0 100644
--- a/src/cmd/compile/internal/gc/asm_test.go
+++ b/src/cmd/compile/internal/gc/asm_test.go
@@ -741,6 +741,29 @@ var linuxAMD64Tests = []*asmTest{
 		}`,
 		[]string{"\tPOPCNTQ\t", "support_popcnt"},
 	},
+	// multiplication merging tests
+	{
+		`
+		func mul1(n int) int {
+			return 15*n + 31*n
+		}`,
+		[]string{"\tIMULQ\t[$]46"}, // 46*n
+	},
+	{
+		`
+		func mul2(n int) int {
+			return 5*n + 7*(n+1) + 11*(n+2)
+		}`,
+		[]string{"\tIMULQ\t[$]23", "\tADDQ\t[$]29"}, // 23*n + 29
+	},
+	{
+		`
+		func mul3(a, n int) int {
+			return a*n + 19*n
+		}`,
+		[]string{"\tADDQ\t[$]19", "\tIMULQ"}, // (a+19)*n
+	},
+
 	// see issue 19595.
 	// We want to merge load+op in f58, but not in f59.
{ @@ -928,6 +951,21 @@ var linux386Tests = []*asmTest{ `, []string{"\tMOVL\t\\(.*\\)\\(.*\\*1\\),"}, }, + // multiplication merging tests + { + ` + func mul1(n int) int { + return 9*n + 14*n + }`, + []string{"\tIMULL\t[$]23"}, // 23*n + }, + { + ` + func mul2(a, n int) int { + return 19*a + a*n + }`, + []string{"\tADDL\t[$]19", "\tIMULL"}, // (n+19)*a + }, } var linuxS390XTests = []*asmTest{ diff --git a/src/cmd/compile/internal/ssa/gen/generic.rules b/src/cmd/compile/internal/ssa/gen/generic.rules index bb4a83738f5..126e6ee6a1e 100644 --- a/src/cmd/compile/internal/ssa/gen/generic.rules +++ b/src/cmd/compile/internal/ssa/gen/generic.rules @@ -322,6 +322,12 @@ (Mul32 (Const32 [c]) (Add32 (Const32 [d]) x)) -> (Add32 (Const32 [int64(int32(c*d))]) (Mul32 (Const32 [c]) x)) +// Rewrite x*y + x*z to x*(y+z) +(Add64 (Mul64 x y) (Mul64 x z)) -> (Mul64 x (Add64 y z)) +(Add32 (Mul32 x y) (Mul32 x z)) -> (Mul32 x (Add32 y z)) +(Add16 (Mul16 x y) (Mul16 x z)) -> (Mul16 x (Add16 y z)) +(Add8 (Mul8 x y) (Mul8 x z)) -> (Mul8 x (Add8 y z)) + // rewrite shifts of 8/16/32 bit consts into 64 bit consts to reduce // the number of the other rewrite rules for const shifts (Lsh64x32 x (Const32 [c])) -> (Lsh64x64 x (Const64 [int64(uint32(c))])) diff --git a/src/cmd/compile/internal/ssa/rewritegeneric.go b/src/cmd/compile/internal/ssa/rewritegeneric.go index 0812449019e..61fad279ab8 100644 --- a/src/cmd/compile/internal/ssa/rewritegeneric.go +++ b/src/cmd/compile/internal/ssa/rewritegeneric.go @@ -16,17 +16,17 @@ var _ = types.TypeMem // in case not otherwise used func rewriteValuegeneric(v *Value) bool { switch v.Op { case OpAdd16: - return rewriteValuegeneric_OpAdd16_0(v) || rewriteValuegeneric_OpAdd16_10(v) || rewriteValuegeneric_OpAdd16_20(v) + return rewriteValuegeneric_OpAdd16_0(v) || rewriteValuegeneric_OpAdd16_10(v) || rewriteValuegeneric_OpAdd16_20(v) || rewriteValuegeneric_OpAdd16_30(v) case OpAdd32: - return rewriteValuegeneric_OpAdd32_0(v) || rewriteValuegeneric_OpAdd32_10(v) || rewriteValuegeneric_OpAdd32_20(v) + return rewriteValuegeneric_OpAdd32_0(v) || rewriteValuegeneric_OpAdd32_10(v) || rewriteValuegeneric_OpAdd32_20(v) || rewriteValuegeneric_OpAdd32_30(v) case OpAdd32F: return rewriteValuegeneric_OpAdd32F_0(v) case OpAdd64: - return rewriteValuegeneric_OpAdd64_0(v) || rewriteValuegeneric_OpAdd64_10(v) || rewriteValuegeneric_OpAdd64_20(v) + return rewriteValuegeneric_OpAdd64_0(v) || rewriteValuegeneric_OpAdd64_10(v) || rewriteValuegeneric_OpAdd64_20(v) || rewriteValuegeneric_OpAdd64_30(v) case OpAdd64F: return rewriteValuegeneric_OpAdd64F_0(v) case OpAdd8: - return rewriteValuegeneric_OpAdd8_0(v) || rewriteValuegeneric_OpAdd8_10(v) || rewriteValuegeneric_OpAdd8_20(v) + return rewriteValuegeneric_OpAdd8_0(v) || rewriteValuegeneric_OpAdd8_10(v) || rewriteValuegeneric_OpAdd8_20(v) || rewriteValuegeneric_OpAdd8_30(v) case OpAddPtr: return rewriteValuegeneric_OpAddPtr_0(v) case OpAnd16: @@ -467,6 +467,251 @@ func rewriteValuegeneric_OpAdd16_0(v *Value) bool { v.AuxInt = int64(int16(c + d)) return true } + // match: (Add16 (Mul16 x y) (Mul16 x z)) + // cond: + // result: (Mul16 x (Add16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + 
// match: (Add16 (Mul16 y x) (Mul16 x z)) + // cond: + // result: (Mul16 x (Add16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add16 (Mul16 x y) (Mul16 z x)) + // cond: + // result: (Mul16 x (Add16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add16 (Mul16 y x) (Mul16 z x)) + // cond: + // result: (Mul16 x (Add16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add16 (Mul16 x z) (Mul16 x y)) + // cond: + // result: (Mul16 x (Add16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + z := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + y := v_1.Args[1] + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add16 (Mul16 z x) (Mul16 x y)) + // cond: + // result: (Mul16 x (Add16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + z := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + y := v_1.Args[1] + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add16 (Mul16 x z) (Mul16 y x)) + // cond: + // result: (Mul16 x (Add16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + z := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + y := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add16 (Mul16 z x) (Mul16 y x)) + // cond: + // result: (Mul16 x (Add16 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul16 { + break + } + _ = v_0.Args[1] + z := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul16 { + break + } + _ = v_1.Args[1] + y := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul16) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd16, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + return false +} +func 
rewriteValuegeneric_OpAdd16_10(v *Value) bool { + b := v.Block + _ = b // match: (Add16 (Const16 [0]) x) // cond: // result: x @@ -657,11 +902,6 @@ func rewriteValuegeneric_OpAdd16_0(v *Value) bool { v.AddArg(v0) return true } - return false -} -func rewriteValuegeneric_OpAdd16_10(v *Value) bool { - b := v.Block - _ = b // match: (Add16 (Sub16 i:(Const16 ) z) x) // cond: (z.Op != OpConst16 && x.Op != OpConst16) // result: (Add16 i (Sub16 x z)) @@ -718,6 +958,11 @@ func rewriteValuegeneric_OpAdd16_10(v *Value) bool { v.AddArg(v0) return true } + return false +} +func rewriteValuegeneric_OpAdd16_20(v *Value) bool { + b := v.Block + _ = b // match: (Add16 x (Sub16 i:(Const16 ) z)) // cond: (z.Op != OpConst16 && x.Op != OpConst16) // result: (Add16 i (Sub16 x z)) @@ -950,11 +1195,6 @@ func rewriteValuegeneric_OpAdd16_10(v *Value) bool { v.AddArg(x) return true } - return false -} -func rewriteValuegeneric_OpAdd16_20(v *Value) bool { - b := v.Block - _ = b // match: (Add16 (Add16 (Const16 [d]) x) (Const16 [c])) // cond: // result: (Add16 (Const16 [int64(int16(c+d))]) x) @@ -1019,6 +1259,11 @@ func rewriteValuegeneric_OpAdd16_20(v *Value) bool { v.AddArg(x) return true } + return false +} +func rewriteValuegeneric_OpAdd16_30(v *Value) bool { + b := v.Block + _ = b // match: (Add16 (Const16 [c]) (Sub16 (Const16 [d]) x)) // cond: // result: (Sub16 (Const16 [int64(int16(c+d))]) x) @@ -1190,6 +1435,251 @@ func rewriteValuegeneric_OpAdd32_0(v *Value) bool { v.AuxInt = int64(int32(c + d)) return true } + // match: (Add32 (Mul32 x y) (Mul32 x z)) + // cond: + // result: (Mul32 x (Add32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add32 (Mul32 y x) (Mul32 x z)) + // cond: + // result: (Mul32 x (Add32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add32 (Mul32 x y) (Mul32 z x)) + // cond: + // result: (Mul32 x (Add32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add32 (Mul32 y x) (Mul32 z x)) + // cond: + // result: (Mul32 x (Add32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add32 
(Mul32 x z) (Mul32 x y)) + // cond: + // result: (Mul32 x (Add32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + z := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + y := v_1.Args[1] + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add32 (Mul32 z x) (Mul32 x y)) + // cond: + // result: (Mul32 x (Add32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + z := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + y := v_1.Args[1] + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add32 (Mul32 x z) (Mul32 y x)) + // cond: + // result: (Mul32 x (Add32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + z := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + y := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add32 (Mul32 z x) (Mul32 y x)) + // cond: + // result: (Mul32 x (Add32 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul32 { + break + } + _ = v_0.Args[1] + z := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul32 { + break + } + _ = v_1.Args[1] + y := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul32) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd32, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + return false +} +func rewriteValuegeneric_OpAdd32_10(v *Value) bool { + b := v.Block + _ = b // match: (Add32 (Const32 [0]) x) // cond: // result: x @@ -1380,11 +1870,6 @@ func rewriteValuegeneric_OpAdd32_0(v *Value) bool { v.AddArg(v0) return true } - return false -} -func rewriteValuegeneric_OpAdd32_10(v *Value) bool { - b := v.Block - _ = b // match: (Add32 (Sub32 i:(Const32 ) z) x) // cond: (z.Op != OpConst32 && x.Op != OpConst32) // result: (Add32 i (Sub32 x z)) @@ -1441,6 +1926,11 @@ func rewriteValuegeneric_OpAdd32_10(v *Value) bool { v.AddArg(v0) return true } + return false +} +func rewriteValuegeneric_OpAdd32_20(v *Value) bool { + b := v.Block + _ = b // match: (Add32 x (Sub32 i:(Const32 ) z)) // cond: (z.Op != OpConst32 && x.Op != OpConst32) // result: (Add32 i (Sub32 x z)) @@ -1673,11 +2163,6 @@ func rewriteValuegeneric_OpAdd32_10(v *Value) bool { v.AddArg(x) return true } - return false -} -func rewriteValuegeneric_OpAdd32_20(v *Value) bool { - b := v.Block - _ = b // match: (Add32 (Add32 (Const32 [d]) x) (Const32 [c])) // cond: // result: (Add32 (Const32 [int64(int32(c+d))]) x) @@ -1742,6 +2227,11 @@ func rewriteValuegeneric_OpAdd32_20(v *Value) bool { v.AddArg(x) return true } + return false +} +func rewriteValuegeneric_OpAdd32_30(v *Value) bool { + b := v.Block + _ = b // match: (Add32 (Const32 [c]) (Sub32 (Const32 [d]) x)) // cond: // result: (Sub32 (Const32 [int64(int32(c+d))]) x) @@ -1990,6 +2480,251 @@ func rewriteValuegeneric_OpAdd64_0(v *Value) bool { v.AuxInt = c + d return true } + // match: (Add64 (Mul64 
x y) (Mul64 x z)) + // cond: + // result: (Mul64 x (Add64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add64 (Mul64 y x) (Mul64 x z)) + // cond: + // result: (Mul64 x (Add64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add64 (Mul64 x y) (Mul64 z x)) + // cond: + // result: (Mul64 x (Add64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add64 (Mul64 y x) (Mul64 z x)) + // cond: + // result: (Mul64 x (Add64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add64 (Mul64 x z) (Mul64 x y)) + // cond: + // result: (Mul64 x (Add64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + z := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + y := v_1.Args[1] + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add64 (Mul64 z x) (Mul64 x y)) + // cond: + // result: (Mul64 x (Add64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + z := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + y := v_1.Args[1] + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add64 (Mul64 x z) (Mul64 y x)) + // cond: + // result: (Mul64 x (Add64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + z := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + y := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add64 (Mul64 z x) (Mul64 y x)) + // 
cond: + // result: (Mul64 x (Add64 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul64 { + break + } + _ = v_0.Args[1] + z := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul64 { + break + } + _ = v_1.Args[1] + y := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul64) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd64, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + return false +} +func rewriteValuegeneric_OpAdd64_10(v *Value) bool { + b := v.Block + _ = b // match: (Add64 (Const64 [0]) x) // cond: // result: x @@ -2180,11 +2915,6 @@ func rewriteValuegeneric_OpAdd64_0(v *Value) bool { v.AddArg(v0) return true } - return false -} -func rewriteValuegeneric_OpAdd64_10(v *Value) bool { - b := v.Block - _ = b // match: (Add64 (Sub64 i:(Const64 ) z) x) // cond: (z.Op != OpConst64 && x.Op != OpConst64) // result: (Add64 i (Sub64 x z)) @@ -2241,6 +2971,11 @@ func rewriteValuegeneric_OpAdd64_10(v *Value) bool { v.AddArg(v0) return true } + return false +} +func rewriteValuegeneric_OpAdd64_20(v *Value) bool { + b := v.Block + _ = b // match: (Add64 x (Sub64 i:(Const64 ) z)) // cond: (z.Op != OpConst64 && x.Op != OpConst64) // result: (Add64 i (Sub64 x z)) @@ -2473,11 +3208,6 @@ func rewriteValuegeneric_OpAdd64_10(v *Value) bool { v.AddArg(x) return true } - return false -} -func rewriteValuegeneric_OpAdd64_20(v *Value) bool { - b := v.Block - _ = b // match: (Add64 (Add64 (Const64 [d]) x) (Const64 [c])) // cond: // result: (Add64 (Const64 [c+d]) x) @@ -2542,6 +3272,11 @@ func rewriteValuegeneric_OpAdd64_20(v *Value) bool { v.AddArg(x) return true } + return false +} +func rewriteValuegeneric_OpAdd64_30(v *Value) bool { + b := v.Block + _ = b // match: (Add64 (Const64 [c]) (Sub64 (Const64 [d]) x)) // cond: // result: (Sub64 (Const64 [c+d]) x) @@ -2790,6 +3525,251 @@ func rewriteValuegeneric_OpAdd8_0(v *Value) bool { v.AuxInt = int64(int8(c + d)) return true } + // match: (Add8 (Mul8 x y) (Mul8 x z)) + // cond: + // result: (Mul8 x (Add8 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add8 (Mul8 y x) (Mul8 x z)) + // cond: + // result: (Mul8 x (Add8 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + z := v_1.Args[1] + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add8 (Mul8 x y) (Mul8 z x)) + // cond: + // result: (Mul8 x (Add8 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + y := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add8 (Mul8 y x) (Mul8 z x)) + // cond: + // result: (Mul8 x (Add8 y z)) + for { + t := 
v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + y := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + z := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add8 (Mul8 x z) (Mul8 x y)) + // cond: + // result: (Mul8 x (Add8 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + z := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + y := v_1.Args[1] + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add8 (Mul8 z x) (Mul8 x y)) + // cond: + // result: (Mul8 x (Add8 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + z := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + if x != v_1.Args[0] { + break + } + y := v_1.Args[1] + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add8 (Mul8 x z) (Mul8 y x)) + // cond: + // result: (Mul8 x (Add8 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + x := v_0.Args[0] + z := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + y := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + // match: (Add8 (Mul8 z x) (Mul8 y x)) + // cond: + // result: (Mul8 x (Add8 y z)) + for { + t := v.Type + _ = v.Args[1] + v_0 := v.Args[0] + if v_0.Op != OpMul8 { + break + } + _ = v_0.Args[1] + z := v_0.Args[0] + x := v_0.Args[1] + v_1 := v.Args[1] + if v_1.Op != OpMul8 { + break + } + _ = v_1.Args[1] + y := v_1.Args[0] + if x != v_1.Args[1] { + break + } + v.reset(OpMul8) + v.AddArg(x) + v0 := b.NewValue0(v.Pos, OpAdd8, t) + v0.AddArg(y) + v0.AddArg(z) + v.AddArg(v0) + return true + } + return false +} +func rewriteValuegeneric_OpAdd8_10(v *Value) bool { + b := v.Block + _ = b // match: (Add8 (Const8 [0]) x) // cond: // result: x @@ -2980,11 +3960,6 @@ func rewriteValuegeneric_OpAdd8_0(v *Value) bool { v.AddArg(v0) return true } - return false -} -func rewriteValuegeneric_OpAdd8_10(v *Value) bool { - b := v.Block - _ = b // match: (Add8 (Sub8 i:(Const8 ) z) x) // cond: (z.Op != OpConst8 && x.Op != OpConst8) // result: (Add8 i (Sub8 x z)) @@ -3041,6 +4016,11 @@ func rewriteValuegeneric_OpAdd8_10(v *Value) bool { v.AddArg(v0) return true } + return false +} +func rewriteValuegeneric_OpAdd8_20(v *Value) bool { + b := v.Block + _ = b // match: (Add8 x (Sub8 i:(Const8 ) z)) // cond: (z.Op != OpConst8 && x.Op != OpConst8) // result: (Add8 i (Sub8 x z)) @@ -3273,11 +4253,6 @@ func rewriteValuegeneric_OpAdd8_10(v *Value) bool { v.AddArg(x) return true } - return false -} -func rewriteValuegeneric_OpAdd8_20(v *Value) bool { - b := v.Block - _ = b // match: (Add8 (Add8 (Const8 [d]) x) (Const8 [c])) // cond: // result: (Add8 (Const8 [int64(int8(c+d))]) x) @@ -3342,6 +4317,11 @@ func rewriteValuegeneric_OpAdd8_20(v *Value) bool { v.AddArg(x) return true 
 	}
+	return false
+}
+func rewriteValuegeneric_OpAdd8_30(v *Value) bool {
+	b := v.Block
+	_ = b
 	// match: (Add8 (Const8 [c]) (Sub8 (Const8 [d]) x))
 	// cond:
 	// result: (Sub8 (Const8 [int64(int8(c+d))]) x)
diff --git a/test/mergemul.go b/test/mergemul.go
new file mode 100644
index 00000000000..86fbe676cb8
--- /dev/null
+++ b/test/mergemul.go
@@ -0,0 +1,81 @@
+// runoutput
+
+// Copyright 2017 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package main
+
+import "fmt"
+
+// Check that expressions like (c*n + d*(n+k)) get correctly merged by
+// the compiler into (c+d)*n + d*k (with c+d and d*k computed at
+// compile time).
+//
+// The merging is performed by a combination of the multiplication
+// merge rules
+//   (c*n + d*n) -> (c+d)*n
+// and the distributive multiplication rules
+//   c * (d+x) -> c*d + c*x
+
+// Generate a MergeTest that looks like this:
+//
+//   a8, b8 = m1*n8 + m2*(n8+k), (m1+m2)*n8 + m2*k
+//   if a8 != b8 {
+//     // print error msg and panic
+//   }
+func makeMergeTest(m1, m2, k int, size string) string {
+
+	model := " a" + size + ", b" + size
+	model += fmt.Sprintf(" = %%d*n%s + %%d*(n%s+%%d), (%%d+%%d)*n%s + (%%d*%%d)", size, size, size)
+
+	test := fmt.Sprintf(model, m1, m2, k, m1, m2, m2, k)
+	test += fmt.Sprintf(`
+	if a%s != b%s {
+		fmt.Printf("MergeTest(%d, %d, %d, %s) failed\n")
+		fmt.Printf("%%d != %%d\n", a%s, b%s)
+		panic("FAIL")
+	}
+`, size, size, m1, m2, k, size, size, size)
+	return test + "\n"
+}
+
+func makeAllSizes(m1, m2, k int) string {
+	var tests string
+	tests += makeMergeTest(m1, m2, k, "8")
+	tests += makeMergeTest(m1, m2, k, "16")
+	tests += makeMergeTest(m1, m2, k, "32")
+	tests += makeMergeTest(m1, m2, k, "64")
+	tests += "\n"
+	return tests
+}
+
+func main() {
+	fmt.Println(`package main
+
+import "fmt"
+
+var n8 int8 = 42
+var n16 int16 = 42
+var n32 int32 = 42
+var n64 int64 = 42
+
+func main() {
+	var a8, b8 int8
+	var a16, b16 int16
+	var a32, b32 int32
+	var a64, b64 int64
+`)
+
+	fmt.Println(makeAllSizes(03, 05, 0)) // 3*n + 5*n
+	fmt.Println(makeAllSizes(17, 33, 0))
+	fmt.Println(makeAllSizes(80, 45, 0))
+	fmt.Println(makeAllSizes(32, 64, 0))
+
+	fmt.Println(makeAllSizes(7, 11, +1)) // 7*n + 11*(n+1)
+	fmt.Println(makeAllSizes(9, 13, +2))
+	fmt.Println(makeAllSizes(11, 16, -1))
+	fmt.Println(makeAllSizes(17, 9, -2))
+
+	fmt.Println("}")
+}
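
Editor's note, not part of the patch: to make the generator in test/mergemul.go easier to follow, here is roughly what one generated check looks like. This expansion of makeMergeTest(7, 11, 1, "8") is inferred from the format strings above, not copied from an actual run; whitespace may differ.

	// Inferred expansion of makeMergeTest(7, 11, 1, "8"): the left-hand
	// side is what the compiler must merge at compile time, the
	// right-hand side is the pre-merged reference value.
	a8, b8 = 7*n8 + 11*(n8+1), (7+11)*n8 + (11*1)
	if a8 != b8 {
		fmt.Printf("MergeTest(7, 11, 1, 8) failed\n")
		fmt.Printf("%d != %d\n", a8, b8)
		panic("FAIL")
	}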
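
Editor's note, not part of the patch: a minimal sketch for reproducing the codegen comparison from the commit message by hand, assuming a toolchain with this change applied. The file name mul.go is arbitrary; `go tool compile -S mul.go` prints the assembly, which should contain the single IMULQ $46 and the IMULQ $23 / ADDQ $29 sequences quoted above instead of the shift and LEAQ chains.

// mul.go: illustrative only, mirrors the two examples in the commit message.
package p

// With the new rule, 15*n + 31*n should compile to a single IMULQ $46 on amd64.
func mul1(n int) int {
	return 15*n + 31*n
}

// The factors 5, 7 and 11 should merge into IMULQ $23 followed by ADDQ $29.
func mul2(n int) int {
	return 5*n + 7*(n+1) + 11*(n+2)
}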