From 1a53915c6b905ae5d0a4398362cc655a1406cf06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20Mart=C3=AD?= Date: Thu, 16 May 2019 11:26:40 +0100 Subject: [PATCH] cmd/compile: initial rulegen rewrite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit rulegen.go produces plaintext Go code directly, which was fine for a while. However, that's started being a bottleneck for making code generation more complex, as we can only generate code directly one line at a time. Some workarounds were used, like multiple layers of buffers to generate chunks of code, to then use strings.Contains to see whether variables need to be defined or not. However, that's error-prone, verbose, and difficult to work with. A better approach is to generate an intermediate syntax tree in memory, which we can inspect and modify easily. For example, we could run a number of "passes" on the syntax tree before writing to disk, such as removing unused variables, simplifying logic, or moving declarations closer to their uses. This is the first step in that direction, without changing any of the generated code. We didn't use go/ast directly, as it's too complex for our needs. In particular, we only need a few kinds of simple statements, but we do want to support arbitrary expressions. As such, define a simple set of statement structs, and add thin layers for printer.Fprint and ast.Inspect. A nice side effect of this change, besides removing some buffers and string handling, is that we can now avoid passing so many parameters around. And, while we add over a hundred lines of code, the tricky pieces of code are now a bit simpler to follow. While at it, apply some cleanups, such as replacing isVariable with token.IsIdentifier, and consistently using log.Fatalf. Follow-up CLs will start improving the generated code, also simplifying the rulegen code itself. I've added some TODOs for the low-hanging fruit that I intend to work on right after. Updates #30810. Change-Id: Ic371c192b29c85dfc4a001be7fbcbeec85facc9d Reviewed-on: https://go-review.googlesource.com/c/go/+/177539 Run-TryBot: Daniel Martí TryBot-Result: Gobot Gobot Reviewed-by: Keith Randall --- src/cmd/compile/internal/ssa/gen/rulegen.go | 830 ++++++++++++-------- 1 file changed, 487 insertions(+), 343 deletions(-) diff --git a/src/cmd/compile/internal/ssa/gen/rulegen.go b/src/cmd/compile/internal/ssa/gen/rulegen.go index 4ca6796f7c6..0e89af73e94 100644 --- a/src/cmd/compile/internal/ssa/gen/rulegen.go +++ b/src/cmd/compile/internal/ssa/gen/rulegen.go @@ -16,10 +16,15 @@ import ( "bytes" "flag" "fmt" + "go/ast" "go/format" + "go/parser" + "go/printer" + "go/token" "io" "io/ioutil" "log" + "math" "os" "regexp" "sort" @@ -47,9 +52,7 @@ import ( // If multiple rules match, the first one in file order is selected. -var ( - genLog = flag.Bool("log", false, "generate code that logs; for debugging only") -) +var genLog = flag.Bool("log", false, "generate code that logs; for debugging only") type Rule struct { rule string @@ -136,13 +139,13 @@ func genRulesSuffix(arch arch, suff string) { r := Rule{rule: rule3, loc: loc} if rawop := strings.Split(rule3, " ")[0][1:]; isBlock(rawop, arch) { blockrules[rawop] = append(blockrules[rawop], r) - } else { - // Do fancier value op matching. - match, _, _ := r.parse() - op, oparch, _, _, _, _ := parseValue(match, arch, loc) - opname := fmt.Sprintf("Op%s%s", oparch, op.name) - oprules[opname] = append(oprules[opname], r) + continue } + // Do fancier value op matching. + match, _, _ := r.parse() + op, oparch, _, _, _, _ := parseValue(match, arch, loc) + opname := fmt.Sprintf("Op%s%s", oparch, op.name) + oprules[opname] = append(oprules[opname], r) } } rule = "" @@ -162,256 +165,448 @@ func genRulesSuffix(arch arch, suff string) { } sort.Strings(ops) - // Start output buffer, write header. - w := new(bytes.Buffer) - fmt.Fprintf(w, "// Code generated from gen/%s%s.rules; DO NOT EDIT.\n", arch.name, suff) - fmt.Fprintln(w, "// generated with: cd gen; go run *.go") - fmt.Fprintln(w) - fmt.Fprintln(w, "package ssa") - fmt.Fprintln(w, "import \"fmt\"") - fmt.Fprintln(w, "import \"math\"") - fmt.Fprintln(w, "import \"cmd/internal/obj\"") - fmt.Fprintln(w, "import \"cmd/internal/objabi\"") - fmt.Fprintln(w, "import \"cmd/compile/internal/types\"") - fmt.Fprintln(w, "var _ = fmt.Println // in case not otherwise used") - fmt.Fprintln(w, "var _ = math.MinInt8 // in case not otherwise used") - fmt.Fprintln(w, "var _ = obj.ANOP // in case not otherwise used") - fmt.Fprintln(w, "var _ = objabi.GOROOT // in case not otherwise used") - fmt.Fprintln(w, "var _ = types.TypeMem // in case not otherwise used") - fmt.Fprintln(w) - + file := &File{arch: arch, suffix: suff} const chunkSize = 10 // Main rewrite routine is a switch on v.Op. - fmt.Fprintf(w, "func rewriteValue%s%s(v *Value) bool {\n", arch.name, suff) - fmt.Fprintf(w, "switch v.Op {\n") + fn := &Func{kind: "Value"} + + sw := &Switch{expr: exprf("v.Op")} for _, op := range ops { - fmt.Fprintf(w, "case %s:\n", op) - fmt.Fprint(w, "return ") + var ors []string for chunk := 0; chunk < len(oprules[op]); chunk += chunkSize { - if chunk > 0 { - fmt.Fprint(w, " || ") - } - fmt.Fprintf(w, "rewriteValue%s%s_%s_%d(v)", arch.name, suff, op, chunk) + ors = append(ors, fmt.Sprintf("rewriteValue%s%s_%s_%d(v)", arch.name, suff, op, chunk)) } - fmt.Fprintln(w) + swc := &Case{expr: exprf(op)} + swc.add(stmtf("return %s", strings.Join(ors, " || "))) + sw.add(swc) } - fmt.Fprintf(w, "}\n") - fmt.Fprintf(w, "return false\n") - fmt.Fprintf(w, "}\n") + fn.add(sw) + fn.add(stmtf("return false")) + file.add(fn) // Generate a routine per op. Note that we don't make one giant routine // because it is too big for some compilers. for _, op := range ops { - for chunk := 0; chunk < len(oprules[op]); chunk += chunkSize { - buf := new(bytes.Buffer) - var canFail bool + rules := oprules[op] + // rr is kept between chunks, so that a following chunk checks + // that the previous one ended with a rule that wasn't + // unconditional. + var rr *RuleRewrite + for chunk := 0; chunk < len(rules); chunk += chunkSize { endchunk := chunk + chunkSize - if endchunk > len(oprules[op]) { - endchunk = len(oprules[op]) + if endchunk > len(rules) { + endchunk = len(rules) } - for i, rule := range oprules[op][chunk:endchunk] { - match, cond, result := rule.parse() - fmt.Fprintf(buf, "// match: %s\n", match) - fmt.Fprintf(buf, "// cond: %s\n", cond) - fmt.Fprintf(buf, "// result: %s\n", result) - - canFail = false - fmt.Fprintf(buf, "for {\n") - pos, _, matchCanFail := genMatch(buf, arch, match, rule.loc) + fn := &Func{ + kind: "Value", + suffix: fmt.Sprintf("_%s_%d", op, chunk), + } + var rewrites bodyBase + for _, rule := range rules[chunk:endchunk] { + if rr != nil && !rr.canFail { + log.Fatalf("unconditional rule %s is followed by other rules", rr.match) + } + rr = &RuleRewrite{loc: rule.loc} + rr.match, rr.cond, rr.result = rule.parse() + pos, _ := genMatch(rr, arch, rr.match) if pos == "" { pos = "v.Pos" } - if matchCanFail { - canFail = true + if rr.cond != "" { + rr.add(breakf("!(%s)", rr.cond)) } - - if cond != "" { - fmt.Fprintf(buf, "if !(%s) {\nbreak\n}\n", cond) - canFail = true - } - if !canFail && i+chunk != len(oprules[op])-1 { - log.Fatalf("unconditional rule %s is followed by other rules", match) - } - - genResult(buf, arch, result, rule.loc, pos) + genResult(rr, arch, rr.result, pos) if *genLog { - fmt.Fprintf(buf, "logRule(\"%s\")\n", rule.loc) + rr.add(stmtf("logRule(%q)", rule.loc)) } - fmt.Fprintf(buf, "return true\n") - - fmt.Fprintf(buf, "}\n") - } - if canFail { - fmt.Fprintf(buf, "return false\n") + rewrites.add(rr) } - body := buf.String() - // Figure out whether we need b, config, fe, and/or types; provide them if so. - hasb := strings.Contains(body, " b.") - hasconfig := strings.Contains(body, "config.") || strings.Contains(body, "config)") - hasfe := strings.Contains(body, "fe.") - hastyps := strings.Contains(body, "typ.") - fmt.Fprintf(w, "func rewriteValue%s%s_%s_%d(v *Value) bool {\n", arch.name, suff, op, chunk) - if hasb || hasconfig || hasfe || hastyps { - fmt.Fprintln(w, "b := v.Block") + // TODO(mvdan): remove unused vars later instead + uses := make(map[string]int) + walk(&rewrites, func(node Node) { + switch node := node.(type) { + case *Declare: + // work around var shadowing + // TODO(mvdan): forbid it instead. + uses[node.name] = math.MinInt32 + case *ast.Ident: + uses[node.Name]++ + } + }) + if uses["b"]+uses["config"]+uses["fe"]+uses["typ"] > 0 { + fn.add(declf("b", "v.Block")) } - if hasconfig { - fmt.Fprintln(w, "config := b.Func.Config") + if uses["config"] > 0 { + fn.add(declf("config", "b.Func.Config")) } - if hasfe { - fmt.Fprintln(w, "fe := b.Func.fe") + if uses["fe"] > 0 { + fn.add(declf("fe", "b.Func.fe")) } - if hastyps { - fmt.Fprintln(w, "typ := &b.Func.Config.Types") + if uses["typ"] > 0 { + fn.add(declf("typ", "&b.Func.Config.Types")) } - fmt.Fprint(w, body) - fmt.Fprintf(w, "}\n") + fn.add(rewrites.list...) + if rr.canFail { + fn.add(stmtf("return false")) + } + file.add(fn) } } // Generate block rewrite function. There are only a few block types // so we can make this one function with a switch. - fmt.Fprintf(w, "func rewriteBlock%s%s(b *Block) bool {\n", arch.name, suff) - fmt.Fprintln(w, "config := b.Func.Config") - fmt.Fprintln(w, "typ := &config.Types") - fmt.Fprintln(w, "_ = typ") - fmt.Fprintln(w, "v := b.Control") - fmt.Fprintln(w, "_ = v") - fmt.Fprintf(w, "switch b.Kind {\n") - ops = nil + fn = &Func{kind: "Block"} + fn.add(declf("config", "b.Func.Config")) + // TODO(mvdan): declare these only if needed + fn.add(declf("typ", "&config.Types")) + fn.add(stmtf("_ = typ")) + fn.add(declf("v", "b.Control")) + fn.add(stmtf("_ = v")) + + sw = &Switch{expr: exprf("b.Kind")} + ops = ops[:0] for op := range blockrules { ops = append(ops, op) } sort.Strings(ops) for _, op := range ops { - fmt.Fprintf(w, "case %s:\n", blockName(op, arch)) + swc := &Case{expr: exprf("%s", blockName(op, arch))} for _, rule := range blockrules[op] { - match, cond, result := rule.parse() - fmt.Fprintf(w, "// match: %s\n", match) - fmt.Fprintf(w, "// cond: %s\n", cond) - fmt.Fprintf(w, "// result: %s\n", result) - - _, _, _, aux, s := extract(match) // remove parens, then split - - loopw := new(bytes.Buffer) - - // check match of control value - pos := "" - checkOp := "" - if s[0] != "nil" { - if strings.Contains(s[0], "(") { - pos, checkOp, _ = genMatch0(loopw, arch, s[0], "v", map[string]struct{}{}, rule.loc) - } else { - fmt.Fprintf(loopw, "%s := b.Control\n", s[0]) - } - } - if aux != "" { - fmt.Fprintf(loopw, "%s := b.Aux\n", aux) - } - - if cond != "" { - fmt.Fprintf(loopw, "if !(%s) {\nbreak\n}\n", cond) - } - - // Rule matches. Generate result. - outop, _, _, aux, t := extract(result) // remove parens, then split - newsuccs := t[1:] - - // Check if newsuccs is the same set as succs. - succs := s[1:] - m := map[string]bool{} - for _, succ := range succs { - if m[succ] { - log.Fatalf("can't have a repeat successor name %s in %s", succ, rule) - } - m[succ] = true - } - for _, succ := range newsuccs { - if !m[succ] { - log.Fatalf("unknown successor %s in %s", succ, rule) - } - delete(m, succ) - } - if len(m) != 0 { - log.Fatalf("unmatched successors %v in %s", m, rule) - } - - fmt.Fprintf(loopw, "b.Kind = %s\n", blockName(outop, arch)) - if t[0] == "nil" { - fmt.Fprintf(loopw, "b.SetControl(nil)\n") - } else { - if pos == "" { - pos = "v.Pos" - } - fmt.Fprintf(loopw, "b.SetControl(%s)\n", genResult0(loopw, arch, t[0], new(int), false, false, rule.loc, pos)) - } - if aux != "" { - fmt.Fprintf(loopw, "b.Aux = %s\n", aux) - } else { - fmt.Fprintln(loopw, "b.Aux = nil") - } - - succChanged := false - for i := 0; i < len(succs); i++ { - if succs[i] != newsuccs[i] { - succChanged = true - } - } - if succChanged { - if len(succs) != 2 { - log.Fatalf("changed successors, len!=2 in %s", rule) - } - if succs[0] != newsuccs[1] || succs[1] != newsuccs[0] { - log.Fatalf("can only handle swapped successors in %s", rule) - } - fmt.Fprintln(loopw, "b.swapSuccessors()") - } - - if *genLog { - fmt.Fprintf(loopw, "logRule(\"%s\")\n", rule.loc) - } - fmt.Fprintf(loopw, "return true\n") - - if checkOp != "" { - fmt.Fprintf(w, "for v.Op == %s {\n", checkOp) - } else { - fmt.Fprintf(w, "for {\n") - } - io.Copy(w, loopw) - - fmt.Fprintf(w, "}\n") + swc.add(genBlockRewrite(rule, arch)) } + sw.add(swc) } - fmt.Fprintf(w, "}\n") - fmt.Fprintf(w, "return false\n") - fmt.Fprintf(w, "}\n") + fn.add(sw) + fn.add(stmtf("return false")) + file.add(fn) // gofmt result - b := w.Bytes() + buf := new(bytes.Buffer) + fprint(buf, file) + b := buf.Bytes() src, err := format.Source(b) if err != nil { fmt.Printf("%s\n", b) panic(err) } - // Write to file - err = ioutil.WriteFile("../rewrite"+arch.name+suff+".go", src, 0666) - if err != nil { + if err := ioutil.WriteFile("../rewrite"+arch.name+suff+".go", src, 0666); err != nil { log.Fatalf("can't write output: %v\n", err) } } -// genMatch returns the variable whose source position should be used for the -// result (or "" if no opinion), and a boolean that reports whether the match can fail. -func genMatch(w io.Writer, arch arch, match string, loc string) (pos, checkOp string, canFail bool) { - return genMatch0(w, arch, match, "v", map[string]struct{}{}, loc) +func walk(node Node, fn func(Node)) { + fn(node) + switch node := node.(type) { + case *bodyBase: + case *File: + case *Func: + case *Switch: + walk(node.expr, fn) + case *Case: + walk(node.expr, fn) + case *RuleRewrite: + case *Declare: + walk(node.value, fn) + case *CondBreak: + walk(node.expr, fn) + case ast.Node: + ast.Inspect(node, func(node ast.Node) bool { + fn(node) + return true + }) + default: + log.Fatalf("cannot walk %T", node) + } + if wb, ok := node.(interface{ body() []Statement }); ok { + for _, node := range wb.body() { + walk(node, fn) + } + } + fn(nil) } -func genMatch0(w io.Writer, arch arch, match, v string, m map[string]struct{}, loc string) (pos, checkOp string, canFail bool) { - if match[0] != '(' || match[len(match)-1] != ')' { - panic("non-compound expr in genMatch0: " + match) +func fprint(w io.Writer, n Node) { + switch n := n.(type) { + case *File: + fmt.Fprintf(w, "// Code generated from gen/%s%s.rules; DO NOT EDIT.\n", n.arch.name, n.suffix) + fmt.Fprintf(w, "// generated with: cd gen; go run *.go\n") + fmt.Fprintf(w, "\npackage ssa\n") + // TODO(mvdan): keep the needed imports only + fmt.Fprintln(w, "import \"fmt\"") + fmt.Fprintln(w, "import \"math\"") + fmt.Fprintln(w, "import \"cmd/internal/obj\"") + fmt.Fprintln(w, "import \"cmd/internal/objabi\"") + fmt.Fprintln(w, "import \"cmd/compile/internal/types\"") + fmt.Fprintln(w, "var _ = fmt.Println // in case not otherwise used") + fmt.Fprintln(w, "var _ = math.MinInt8 // in case not otherwise used") + fmt.Fprintln(w, "var _ = obj.ANOP // in case not otherwise used") + fmt.Fprintln(w, "var _ = objabi.GOROOT // in case not otherwise used") + fmt.Fprintln(w, "var _ = types.TypeMem // in case not otherwise used") + fmt.Fprintln(w) + for _, f := range n.list { + f := f.(*Func) + fmt.Fprintf(w, "func rewrite%s%s%s%s(", f.kind, n.arch.name, n.suffix, f.suffix) + fmt.Fprintf(w, "%c *%s) bool {\n", strings.ToLower(f.kind)[0], f.kind) + for _, n := range f.list { + fprint(w, n) + } + fmt.Fprintf(w, "}\n") + } + case *Switch: + fmt.Fprintf(w, "switch ") + fprint(w, n.expr) + fmt.Fprintf(w, " {\n") + for _, n := range n.list { + fprint(w, n) + } + fmt.Fprintf(w, "}\n") + case *Case: + fmt.Fprintf(w, "case ") + fprint(w, n.expr) + fmt.Fprintf(w, ":\n") + for _, n := range n.list { + fprint(w, n) + } + case *RuleRewrite: + fmt.Fprintf(w, "// match: %s\n", n.match) + fmt.Fprintf(w, "// cond: %s\n", n.cond) + fmt.Fprintf(w, "// result: %s\n", n.result) + if n.checkOp != "" { + fmt.Fprintf(w, "for v.Op == %s {\n", n.checkOp) + } else { + fmt.Fprintf(w, "for {\n") + } + for _, n := range n.list { + fprint(w, n) + } + fmt.Fprintf(w, "return true\n}\n") + case *Declare: + fmt.Fprintf(w, "%s := ", n.name) + fprint(w, n.value) + fmt.Fprintln(w) + case *CondBreak: + fmt.Fprintf(w, "if ") + fprint(w, n.expr) + fmt.Fprintf(w, " {\nbreak\n}\n") + case ast.Node: + printer.Fprint(w, emptyFset, n) + if _, ok := n.(ast.Stmt); ok { + fmt.Fprintln(w) + } + default: + log.Fatalf("cannot print %T", n) } - op, oparch, typ, auxint, aux, args := parseValue(match, arch, loc) +} + +var emptyFset = token.NewFileSet() + +// Node can be a Statement or an ast.Expr. +type Node interface{} + +// Statement can be one of our high-level statement struct types, or an +// ast.Stmt under some limited circumstances. +type Statement interface{} + +// bodyBase is shared by all of our statement psuedo-node types which can +// contain other statements. +type bodyBase struct { + list []Statement + canFail bool +} + +func (w *bodyBase) body() []Statement { return w.list } +func (w *bodyBase) add(nodes ...Statement) { + w.list = append(w.list, nodes...) + for _, node := range nodes { + if _, ok := node.(*CondBreak); ok { + w.canFail = true + } + } +} + +// declared reports if the body contains a Declare with the given name. +func (w *bodyBase) declared(name string) bool { + for _, s := range w.list { + if decl, ok := s.(*Declare); ok && decl.name == name { + return true + } + } + return false +} + +// These types define some high-level statement struct types, which can be used +// as a Statement. This allows us to keep some node structs simpler, and have +// higher-level nodes such as an entire rule rewrite. +// +// Note that ast.Expr is always used as-is; we don't declare our own expression +// nodes. +type ( + File struct { + bodyBase // []*Func + arch arch + suffix string + } + Func struct { + bodyBase + kind string // "Value" or "Block" + suffix string + } + Switch struct { + bodyBase // []*Case + expr ast.Expr + } + Case struct { + bodyBase + expr ast.Expr + } + RuleRewrite struct { + bodyBase + match, cond, result string // top comments + checkOp string + + alloc int // for unique var names + loc string // file name & line number of the original rule + } + Declare struct { + name string + value ast.Expr + } + CondBreak struct { + expr ast.Expr + } +) + +// exprf parses a Go expression generated from fmt.Sprintf, panicking if an +// error occurs. +func exprf(format string, a ...interface{}) ast.Expr { + src := fmt.Sprintf(format, a...) + expr, err := parser.ParseExpr(src) + if err != nil { + log.Fatalf("expr parse error on %q: %v", src, err) + } + return expr +} + +// stmtf parses a Go statement generated from fmt.Sprintf. This function is only +// meant for simple statements that don't have a custom Statement node declared +// in this package, such as ast.ReturnStmt or ast.ExprStmt. +func stmtf(format string, a ...interface{}) Statement { + src := fmt.Sprintf(format, a...) + fsrc := "package p\nfunc _() {\n" + src + "\n}\n" + file, err := parser.ParseFile(token.NewFileSet(), "", fsrc, 0) + if err != nil { + log.Fatalf("stmt parse error on %q: %v", src, err) + } + return file.Decls[0].(*ast.FuncDecl).Body.List[0] +} + +// declf constructs a simple "name := value" declaration, using exprf for its +// value. +func declf(name, format string, a ...interface{}) *Declare { + return &Declare{name, exprf(format, a...)} +} + +// breakf constructs a simple "if cond { break }" statement, using exprf for its +// condition. +func breakf(format string, a ...interface{}) *CondBreak { + return &CondBreak{exprf(format, a...)} +} + +func genBlockRewrite(rule Rule, arch arch) *RuleRewrite { + rr := &RuleRewrite{loc: rule.loc} + rr.match, rr.cond, rr.result = rule.parse() + _, _, _, aux, s := extract(rr.match) // remove parens, then split + + // check match of control value + pos := "" + if s[0] != "nil" { + if strings.Contains(s[0], "(") { + pos, rr.checkOp = genMatch0(rr, arch, s[0], "v") + } else { + rr.add(declf(s[0], "b.Control")) + } + } + if aux != "" { + rr.add(declf(aux, "b.Aux")) + } + if rr.cond != "" { + rr.add(breakf("!(%s)", rr.cond)) + } + + // Rule matches. Generate result. + outop, _, _, aux, t := extract(rr.result) // remove parens, then split + newsuccs := t[1:] + + // Check if newsuccs is the same set as succs. + succs := s[1:] + m := map[string]bool{} + for _, succ := range succs { + if m[succ] { + log.Fatalf("can't have a repeat successor name %s in %s", succ, rule) + } + m[succ] = true + } + for _, succ := range newsuccs { + if !m[succ] { + log.Fatalf("unknown successor %s in %s", succ, rule) + } + delete(m, succ) + } + if len(m) != 0 { + log.Fatalf("unmatched successors %v in %s", m, rule) + } + + rr.add(stmtf("b.Kind = %s", blockName(outop, arch))) + if t[0] == "nil" { + rr.add(stmtf("b.SetControl(nil)")) + } else { + if pos == "" { + pos = "v.Pos" + } + v := genResult0(rr, arch, t[0], false, false, pos) + rr.add(stmtf("b.SetControl(%s)", v)) + } + if aux != "" { + rr.add(stmtf("b.Aux = %s", aux)) + } else { + rr.add(stmtf("b.Aux = nil")) + } + + succChanged := false + for i := 0; i < len(succs); i++ { + if succs[i] != newsuccs[i] { + succChanged = true + } + } + if succChanged { + if len(succs) != 2 { + log.Fatalf("changed successors, len!=2 in %s", rule) + } + if succs[0] != newsuccs[1] || succs[1] != newsuccs[0] { + log.Fatalf("can only handle swapped successors in %s", rule) + } + rr.add(stmtf("b.swapSuccessors()")) + } + + if *genLog { + rr.add(stmtf("logRule(%q)", rule.loc)) + } + return rr +} + +// genMatch returns the variable whose source position should be used for the +// result (or "" if no opinion), and a boolean that reports whether the match can fail. +func genMatch(rr *RuleRewrite, arch arch, match string) (pos, checkOp string) { + return genMatch0(rr, arch, match, "v") +} + +func genMatch0(rr *RuleRewrite, arch arch, match, v string) (pos, checkOp string) { + if match[0] != '(' || match[len(match)-1] != ')' { + log.Fatalf("non-compound expr in genMatch0: %q", match) + } + op, oparch, typ, auxint, aux, args := parseValue(match, arch, rr.loc) checkOp = fmt.Sprintf("Op%s%s", oparch, op.name) @@ -421,68 +616,40 @@ func genMatch0(w io.Writer, arch arch, match, v string, m map[string]struct{}, l } if typ != "" { - if !isVariable(typ) { - // code. We must match the results of this code. - fmt.Fprintf(w, "if %s.Type != %s {\nbreak\n}\n", v, typ) - canFail = true + if !token.IsIdentifier(typ) || rr.declared(typ) { + // code or variable + rr.add(breakf("%s.Type != %s", v, typ)) } else { - // variable - if _, ok := m[typ]; ok { - // must match previous variable - fmt.Fprintf(w, "if %s.Type != %s {\nbreak\n}\n", v, typ) - canFail = true - } else { - m[typ] = struct{}{} - fmt.Fprintf(w, "%s := %s.Type\n", typ, v) - } + rr.add(declf(typ, "%s.Type", v)) } } - if auxint != "" { - if !isVariable(auxint) { - // code - fmt.Fprintf(w, "if %s.AuxInt != %s {\nbreak\n}\n", v, auxint) - canFail = true + if !token.IsIdentifier(auxint) || rr.declared(auxint) { + // code or variable + rr.add(breakf("%s.AuxInt != %s", v, auxint)) } else { - // variable - if _, ok := m[auxint]; ok { - fmt.Fprintf(w, "if %s.AuxInt != %s {\nbreak\n}\n", v, auxint) - canFail = true - } else { - m[auxint] = struct{}{} - fmt.Fprintf(w, "%s := %s.AuxInt\n", auxint, v) - } + rr.add(declf(auxint, "%s.AuxInt", v)) } } - if aux != "" { - if !isVariable(aux) { - // code - fmt.Fprintf(w, "if %s.Aux != %s {\nbreak\n}\n", v, aux) - canFail = true + if !token.IsIdentifier(aux) || rr.declared(aux) { + // code or variable + rr.add(breakf("%s.Aux != %s", v, aux)) } else { - // variable - if _, ok := m[aux]; ok { - fmt.Fprintf(w, "if %s.Aux != %s {\nbreak\n}\n", v, aux) - canFail = true - } else { - m[aux] = struct{}{} - fmt.Fprintf(w, "%s := %s.Aux\n", aux, v) - } + rr.add(declf(aux, "%s.Aux", v)) } } // Access last argument first to minimize bounds checks. if n := len(args); n > 1 { a := args[n-1] - if _, set := m[a]; !set && a != "_" && isVariable(a) { - m[a] = struct{}{} - fmt.Fprintf(w, "%s := %s.Args[%d]\n", a, v, n-1) + if a != "_" && !rr.declared(a) && token.IsIdentifier(a) { + rr.add(declf(a, "%s.Args[%d]", v, n-1)) // delete the last argument so it is not reprocessed args = args[:n-1] } else { - fmt.Fprintf(w, "_ = %s.Args[%d]\n", v, n-1) + rr.add(stmtf("_ = %s.Args[%d]", v, n-1)) } } for i, arg := range args { @@ -491,41 +658,36 @@ func genMatch0(w io.Writer, arch arch, match, v string, m map[string]struct{}, l } if !strings.Contains(arg, "(") { // leaf variable - if _, ok := m[arg]; ok { + if rr.declared(arg) { // variable already has a definition. Check whether // the old definition and the new definition match. // For example, (add x x). Equality is just pointer equality // on Values (so cse is important to do before lowering). - fmt.Fprintf(w, "if %s != %s.Args[%d] {\nbreak\n}\n", arg, v, i) - canFail = true + rr.add(breakf("%s != %s.Args[%d]", arg, v, i)) } else { - // remember that this variable references the given value - m[arg] = struct{}{} - fmt.Fprintf(w, "%s := %s.Args[%d]\n", arg, v, i) + rr.add(declf(arg, "%s.Args[%d]", v, i)) } continue } // compound sexpr - var argname string + argname := fmt.Sprintf("%s_%d", v, i) colon := strings.Index(arg, ":") openparen := strings.Index(arg, "(") if colon >= 0 && openparen >= 0 && colon < openparen { // rule-specified name argname = arg[:colon] arg = arg[colon+1:] - } else { - // autogenerated name - argname = fmt.Sprintf("%s_%d", v, i) } if argname == "b" { log.Fatalf("don't name args 'b', it is ambiguous with blocks") } - fmt.Fprintf(w, "%s := %s.Args[%d]\n", argname, v, i) - w2 := new(bytes.Buffer) - argPos, argCheckOp, _ := genMatch0(w2, arch, arg, argname, m, loc) - fmt.Fprintf(w, "if %s.Op != %s {\nbreak\n}\n", argname, argCheckOp) - io.Copy(w, w2) + rr.add(declf(argname, "%s.Args[%d]", v, i)) + bexpr := exprf("%s.Op != addLater", argname) + rr.add(&CondBreak{expr: bexpr}) + rr.canFail = true // since we're not using breakf + argPos, argCheckOp := genMatch0(rr, arch, arg, argname) + bexpr.(*ast.BinaryExpr).Y.(*ast.Ident).Name = argCheckOp if argPos != "" { // Keep the argument in preference to the parent, as the @@ -535,28 +697,26 @@ func genMatch0(w io.Writer, arch arch, match, v string, m map[string]struct{}, l // in the program flow. pos = argPos } - canFail = true } if op.argLength == -1 { - fmt.Fprintf(w, "if len(%s.Args) != %d {\nbreak\n}\n", v, len(args)) - canFail = true + rr.add(breakf("len(%s.Args) != %d", v, len(args))) } - return pos, checkOp, canFail + return pos, checkOp } -func genResult(w io.Writer, arch arch, result string, loc string, pos string) { - move := false - if result[0] == '@' { +func genResult(rr *RuleRewrite, arch arch, result, pos string) { + move := result[0] == '@' + if move { // parse @block directive s := strings.SplitN(result[1:], " ", 2) - fmt.Fprintf(w, "b = %s\n", s[0]) + rr.add(stmtf("b = %s", s[0])) result = s[1] - move = true } - genResult0(w, arch, result, new(int), true, move, loc, pos) + genResult0(rr, arch, result, true, move, pos) } -func genResult0(w io.Writer, arch arch, result string, alloc *int, top, move bool, loc string, pos string) string { + +func genResult0(rr *RuleRewrite, arch arch, result string, top, move bool, pos string) string { // TODO: when generating a constant result, use f.constVal to avoid // introducing copies just to clean them up again. if result[0] != '(' { @@ -565,14 +725,14 @@ func genResult0(w io.Writer, arch arch, result string, alloc *int, top, move boo // It in not safe in general to move a variable between blocks // (and particularly not a phi node). // Introduce a copy. - fmt.Fprintf(w, "v.reset(OpCopy)\n") - fmt.Fprintf(w, "v.Type = %s.Type\n", result) - fmt.Fprintf(w, "v.AddArg(%s)\n", result) + rr.add(stmtf("v.reset(OpCopy)")) + rr.add(stmtf("v.Type = %s.Type", result)) + rr.add(stmtf("v.AddArg(%s)", result)) } return result } - op, oparch, typ, auxint, aux, args := parseValue(result, arch, loc) + op, oparch, typ, auxint, aux, args := parseValue(result, arch, rr.loc) // Find the type of the variable. typeOverride := typ != "" @@ -580,36 +740,35 @@ func genResult0(w io.Writer, arch arch, result string, alloc *int, top, move boo typ = typeName(op.typ) } - var v string + v := "v" if top && !move { - v = "v" - fmt.Fprintf(w, "v.reset(Op%s%s)\n", oparch, op.name) + rr.add(stmtf("v.reset(Op%s%s)", oparch, op.name)) if typeOverride { - fmt.Fprintf(w, "v.Type = %s\n", typ) + rr.add(stmtf("v.Type = %s", typ)) } } else { if typ == "" { - log.Fatalf("sub-expression %s (op=Op%s%s) at %s must have a type", result, oparch, op.name, loc) + log.Fatalf("sub-expression %s (op=Op%s%s) at %s must have a type", result, oparch, op.name, rr.loc) } - v = fmt.Sprintf("v%d", *alloc) - *alloc++ - fmt.Fprintf(w, "%s := b.NewValue0(%s, Op%s%s, %s)\n", v, pos, oparch, op.name, typ) + v = fmt.Sprintf("v%d", rr.alloc) + rr.alloc++ + rr.add(declf(v, "b.NewValue0(%s, Op%s%s, %s)", pos, oparch, op.name, typ)) if move && top { // Rewrite original into a copy - fmt.Fprintf(w, "v.reset(OpCopy)\n") - fmt.Fprintf(w, "v.AddArg(%s)\n", v) + rr.add(stmtf("v.reset(OpCopy)")) + rr.add(stmtf("v.AddArg(%s)", v)) } } if auxint != "" { - fmt.Fprintf(w, "%s.AuxInt = %s\n", v, auxint) + rr.add(stmtf("%s.AuxInt = %s", v, auxint)) } if aux != "" { - fmt.Fprintf(w, "%s.Aux = %s\n", v, aux) + rr.add(stmtf("%s.Aux = %s", v, aux)) } for _, arg := range args { - x := genResult0(w, arch, arg, alloc, false, move, loc, pos) - fmt.Fprintf(w, "%s.AddArg(%s)\n", v, x) + x := genResult0(rr, arch, arg, false, move, pos) + rr.add(stmtf("%s.AddArg(%s)", v, x)) } return v @@ -652,7 +811,7 @@ outer: } } if d != 0 { - panic("imbalanced expression: " + s) + log.Fatalf("imbalanced expression: %q", s) } if nonsp { r = append(r, strings.TrimSpace(s)) @@ -677,7 +836,7 @@ func isBlock(name string, arch arch) bool { return false } -func extract(val string) (op string, typ string, auxint string, aux string, args []string) { +func extract(val string) (op, typ, auxint, aux string, args []string) { val = val[1 : len(val)-1] // remove () // Split val up into regions. @@ -705,7 +864,7 @@ func extract(val string) (op string, typ string, auxint string, aux string, args // The value can be from the match or the result side. // It returns the op and unparsed strings for typ, auxint, and aux restrictions and for all args. // oparch is the architecture that op is located in, or "" for generic. -func parseValue(val string, arch arch, loc string) (op opData, oparch string, typ string, auxint string, aux string, args []string) { +func parseValue(val string, arch arch, loc string) (op opData, oparch, typ, auxint, aux string, args []string) { // Resolve the op. var s string s, typ, auxint, aux, args = extract(val) @@ -723,9 +882,8 @@ func parseValue(val string, arch arch, loc string) (op opData, oparch string, ty if x.argLength != -1 && int(x.argLength) != len(args) { if strict { return false - } else { - log.Printf("%s: op %s (%s) should have %d args, has %d", loc, s, archname, x.argLength, len(args)) } + log.Printf("%s: op %s (%s) should have %d args, has %d", loc, s, archname, x.argLength, len(args)) } return true } @@ -736,16 +894,14 @@ func parseValue(val string, arch arch, loc string) (op opData, oparch string, ty break } } - if arch.name != "generic" { - for _, x := range arch.ops { - if match(x, true, arch.name) { - if op.name != "" { - log.Fatalf("%s: matches for op %s found in both generic and %s", loc, op.name, arch.name) - } - op = x - oparch = arch.name - break + for _, x := range arch.ops { + if arch.name != "generic" && match(x, true, arch.name) { + if op.name != "" { + log.Fatalf("%s: matches for op %s found in both generic and %s", loc, op.name, arch.name) } + op = x + oparch = arch.name + break } } @@ -777,7 +933,6 @@ func parseValue(val string, arch arch, loc string) (op opData, oparch string, ty log.Fatalf("%s: op %s %s can't have aux", loc, op.name, op.aux) } } - return } @@ -795,7 +950,7 @@ func typeName(typ string) string { if typ[0] == '(' { ts := strings.Split(typ[1:len(typ)-1], ",") if len(ts) != 2 { - panic("Tuple expect 2 arguments") + log.Fatalf("Tuple expect 2 arguments") } return "types.NewTuple(" + typeName(ts[0]) + ", " + typeName(ts[1]) + ")" } @@ -809,29 +964,19 @@ func typeName(typ string) string { // unbalanced reports whether there aren't the same number of ( and ) in the string. func unbalanced(s string) bool { - var left, right int + balance := 0 for _, c := range s { if c == '(' { - left++ - } - if c == ')' { - right++ + balance++ + } else if c == ')' { + balance-- } } - return left != right + return balance != 0 } -// isVariable reports whether s is a single Go alphanumeric identifier. -func isVariable(s string) bool { - b, err := regexp.MatchString("^[A-Za-z_][A-Za-z_0-9]*$", s) - if err != nil { - panic("bad variable regexp") - } - return b -} - -// opRegexp is a regular expression to find the opcode portion of s-expressions. -var opRegexp = regexp.MustCompile(`[(](\w+[|])+\w+[)]`) +// findAllOpcode is a function to find the opcode portion of s-expressions. +var findAllOpcode = regexp.MustCompile(`[(](\w+[|])+\w+[)]`).FindAllStringIndex // excludeFromExpansion reports whether the substring s[idx[0]:idx[1]] in a rule // should be disregarded as a candidate for | expansion. @@ -859,7 +1004,7 @@ func expandOr(r string) []string { // Count width of |-forms. They must match. n := 1 - for _, idx := range opRegexp.FindAllStringIndex(r, -1) { + for _, idx := range findAllOpcode(r, -1) { if excludeFromExpansion(r, idx) { continue } @@ -882,7 +1027,7 @@ func expandOr(r string) []string { for i := 0; i < n; i++ { buf := new(strings.Builder) x := 0 - for _, idx := range opRegexp.FindAllStringIndex(r, -1) { + for _, idx := range findAllOpcode(r, -1) { if excludeFromExpansion(r, idx) { continue } @@ -913,7 +1058,7 @@ func commute(r string, arch arch) []string { if len(a) == 1 && normalizeWhitespace(r) != normalizeWhitespace(a[0]) { fmt.Println(normalizeWhitespace(r)) fmt.Println(normalizeWhitespace(a[0])) - panic("commute() is not the identity for noncommuting rule") + log.Fatalf("commute() is not the identity for noncommuting rule") } if false && len(a) > 1 { fmt.Println(r) @@ -925,18 +1070,17 @@ func commute(r string, arch arch) []string { } func commute1(m string, cnt map[string]int, arch arch) []string { - if m[0] == '<' || m[0] == '[' || m[0] == '{' || isVariable(m) { + if m[0] == '<' || m[0] == '[' || m[0] == '{' || token.IsIdentifier(m) { return []string{m} } // Split up input. var prefix string - colon := strings.Index(m, ":") - if colon >= 0 && isVariable(m[:colon]) { - prefix = m[:colon+1] - m = m[colon+1:] + if i := strings.Index(m, ":"); i >= 0 && token.IsIdentifier(m[:i]) { + prefix = m[:i+1] + m = m[i+1:] } if m[0] != '(' || m[len(m)-1] != ')' { - panic("non-compound expr in commute1: " + m) + log.Fatalf("non-compound expr in commute1: %q", m) } s := split(m[1 : len(m)-1]) op := s[0] @@ -978,7 +1122,7 @@ func commute1(m string, cnt map[string]int, arch arch) []string { } } if idx1 == 0 { - panic("couldn't find first two args of commutative op " + s[0]) + log.Fatalf("couldn't find first two args of commutative op %q", s[0]) } if cnt[s[idx0]] == 1 && cnt[s[idx1]] == 1 || s[idx0] == s[idx1] && cnt[s[idx0]] == 2 { // When we have (Add x y) with no other uses of x and y in the matching rule, @@ -1016,22 +1160,22 @@ func varCount(m string) map[string]int { varCount1(m, cnt) return cnt } + func varCount1(m string, cnt map[string]int) { if m[0] == '<' || m[0] == '[' || m[0] == '{' { return } - if isVariable(m) { + if token.IsIdentifier(m) { cnt[m]++ return } // Split up input. - colon := strings.Index(m, ":") - if colon >= 0 && isVariable(m[:colon]) { - cnt[m[:colon]]++ - m = m[colon+1:] + if i := strings.Index(m, ":"); i >= 0 && token.IsIdentifier(m[:i]) { + cnt[m[:i]]++ + m = m[i+1:] } if m[0] != '(' || m[len(m)-1] != ')' { - panic("non-compound expr in commute1: " + m) + log.Fatalf("non-compound expr in commute1: %q", m) } s := split(m[1 : len(m)-1]) for _, arg := range s[1:] {