diff --git a/src/cmd/gofmt/Makefile b/src/cmd/gofmt/Makefile index a93b8c3726..dbc134f88e 100644 --- a/src/cmd/gofmt/Makefile +++ b/src/cmd/gofmt/Makefile @@ -7,6 +7,7 @@ include $(GOROOT)/src/Make.$(GOARCH) TARG=gofmt GOFILES=\ gofmt.go\ + rewrite.go\ include $(GOROOT)/src/Make.cmd diff --git a/src/cmd/gofmt/gofmt.go b/src/cmd/gofmt/gofmt.go index bec4c88918..d7c96dc3ac 100644 --- a/src/cmd/gofmt/gofmt.go +++ b/src/cmd/gofmt/gofmt.go @@ -8,6 +8,7 @@ import ( "bytes"; "flag"; "fmt"; + "go/ast"; "go/parser"; "go/printer"; "go/scanner"; @@ -20,8 +21,9 @@ import ( var ( // main operation modes - list = flag.Bool("l", false, "list files whose formatting differs from gofmt's"); - write = flag.Bool("w", false, "write result to (source) file instead of stdout"); + list = flag.Bool("l", false, "list files whose formatting differs from gofmt's"); + write = flag.Bool("w", false, "write result to (source) file instead of stdout"); + rewriteRule = flag.String("r", "", "rewrite rule (e.g., 'α[β:len(α)] -> α[β:]')"); // debugging support comments = flag.Bool("comments", true, "print comments"); @@ -34,6 +36,8 @@ var ( var exitCode = 0 +var rewrite func(*ast.File) *ast.File + func report(err os.Error) { scanner.PrintError(os.Stderr, err); @@ -86,6 +90,10 @@ func processFile(filename string) os.Error { return err } + if rewrite != nil { + file = rewrite(file) + } + var res bytes.Buffer; _, err = (&printer.Config{printerMode(), *tabwidth, nil}).Fprint(&res, file); if err != nil { @@ -154,6 +162,8 @@ func main() { os.Exit(2); } + initRewrite(); + if flag.NArg() == 0 { if err := processFile("/dev/stdin"); err != nil { report(err) diff --git a/src/cmd/gofmt/rewrite.go b/src/cmd/gofmt/rewrite.go new file mode 100644 index 0000000000..9399bcd49f --- /dev/null +++ b/src/cmd/gofmt/rewrite.go @@ -0,0 +1,226 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt"; + "go/ast"; + "go/parser"; + "go/token"; + "os"; + "reflect"; + "strings"; + "unicode"; + "utf8"; +) + + +func initRewrite() { + if *rewriteRule == "" { + return + } + f := strings.Split(*rewriteRule, "->", 0); + if len(f) != 2 { + fmt.Fprintf(os.Stderr, "rewrite rule must be of the form 'pattern -> replacement'\n"); + os.Exit(2); + } + pattern := parseExpr(f[0], "pattern"); + replace := parseExpr(f[1], "replacement"); + rewrite = func(p *ast.File) *ast.File { return rewriteFile(pattern, replace, p) }; +} + + +// parseExpr parses s as an expression. +// It might make sense to expand this to allow statement patterns, +// but there are problems with preserving formatting and also +// with what a wildcard for a statement looks like. +func parseExpr(s string, what string) ast.Expr { + stmts, err := parser.ParseStmtList("input", s); + if err != nil { + fmt.Fprintf(os.Stderr, "parsing %s %s: %s\n", what, s, err); + os.Exit(2); + } + if len(stmts) != 1 { + fmt.Fprintf(os.Stderr, "%s must be single expression\n", what); + os.Exit(2); + } + x, ok := stmts[0].(*ast.ExprStmt); + if !ok { + fmt.Fprintf(os.Stderr, "%s must be single expression\n", what); + os.Exit(2); + } + return x.X; +} + + +// rewriteFile applys the rewrite rule pattern -> replace to an entire file. +func rewriteFile(pattern, replace ast.Expr, p *ast.File) *ast.File { + m := make(map[string]reflect.Value); + pat := reflect.NewValue(pattern); + repl := reflect.NewValue(replace); + var f func(val reflect.Value) reflect.Value; // f is recursive + f = func(val reflect.Value) reflect.Value { + for k := range m { + m[k] = nil, false + } + if match(m, pat, val) { + return subst(m, repl) + } + return apply(f, val); + }; + return apply(f, reflect.NewValue(p)).Interface().(*ast.File); +} + + +var positionType = reflect.Typeof(token.Position{}) +var zeroPosition = reflect.NewValue(token.Position{}) +var identType = reflect.Typeof((*ast.Ident)(nil)) + + +func isWildcard(s string) bool { + rune, _ := utf8.DecodeRuneInString(s); + return unicode.Is(unicode.Greek, rune) && unicode.IsLower(rune); +} + + +// apply replaces each AST field x in val with f(x), returning val. +// To avoid extra conversions, f operates on the reflect.Value form. +func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value { + if val == nil { + return nil + } + switch v := reflect.Indirect(val).(type) { + case *reflect.SliceValue: + for i := 0; i < v.Len(); i++ { + e := v.Elem(i); + e.SetValue(f(e)); + } + case *reflect.StructValue: + for i := 0; i < v.NumField(); i++ { + e := v.Field(i); + e.SetValue(f(e)); + } + case *reflect.InterfaceValue: + e := v.Elem(); + v.SetValue(f(e)); + } + return val; +} + + +// match returns true if pattern matches val, +// recording wildcard submatches in m. +// If m == nil, match checks whether pattern == val. +func match(m map[string]reflect.Value, pattern, val reflect.Value) bool { + // Wildcard matches any expression. If it appears multiple + // times in the pattern, it must match the same expression + // each time. + if m != nil && pattern.Type() == identType { + name := pattern.Interface().(*ast.Ident).Value; + if isWildcard(name) { + if old, ok := m[name]; ok { + return match(nil, old, val) + } + m[name] = val; + return true; + } + } + + // Otherwise, the expressions must match recursively. + if pattern == nil || val == nil { + return pattern == nil && val == nil + } + if pattern.Type() != val.Type() { + return false + } + + // Token positions need not match. + if pattern.Type() == positionType { + return true + } + + p := reflect.Indirect(pattern); + v := reflect.Indirect(val); + + switch p := p.(type) { + case *reflect.SliceValue: + v := v.(*reflect.SliceValue); + for i := 0; i < p.Len(); i++ { + if !match(m, p.Elem(i), v.Elem(i)) { + return false + } + } + return true; + + case *reflect.StructValue: + v := v.(*reflect.StructValue); + for i := 0; i < p.NumField(); i++ { + if !match(m, p.Field(i), v.Field(i)) { + return false + } + } + return true; + + case *reflect.InterfaceValue: + v := v.(*reflect.InterfaceValue); + return match(m, p.Elem(), v.Elem()); + } + + // Handle token integers, etc. + return p.Interface() == v.Interface(); +} + + +// subst returns a copy of pattern with values from m substituted in place of wildcards. +// if m == nil, subst returns a copy of pattern. +// Either way, the returned value has no valid line number information. +func subst(m map[string]reflect.Value, pattern reflect.Value) reflect.Value { + if pattern == nil { + return nil + } + + // Wildcard gets replaced with map value. + if m != nil && pattern.Type() == identType { + name := pattern.Interface().(*ast.Ident).Value; + if isWildcard(name) { + if old, ok := m[name]; ok { + return subst(nil, old) + } + } + } + + if pattern.Type() == positionType { + return zeroPosition + } + + // Otherwise copy. + switch p := pattern.(type) { + case *reflect.SliceValue: + v := reflect.MakeSlice(p.Type().(*reflect.SliceType), p.Len(), p.Len()); + for i := 0; i < p.Len(); i++ { + v.Elem(i).SetValue(subst(m, p.Elem(i))) + } + return v; + + case *reflect.StructValue: + v := reflect.MakeZero(p.Type()).(*reflect.StructValue); + for i := 0; i < p.NumField(); i++ { + v.Field(i).SetValue(subst(m, p.Field(i))) + } + return v; + + case *reflect.PtrValue: + v := reflect.MakeZero(p.Type()).(*reflect.PtrValue); + v.PointTo(subst(m, p.Elem())); + return v; + + case *reflect.InterfaceValue: + v := reflect.MakeZero(p.Type()).(*reflect.InterfaceValue); + v.SetValue(subst(m, p.Elem())); + return v; + } + + return pattern; +} diff --git a/src/pkg/go/printer/nodes.go b/src/pkg/go/printer/nodes.go index 6304830bd3..1c7460313a 100644 --- a/src/pkg/go/printer/nodes.go +++ b/src/pkg/go/printer/nodes.go @@ -52,6 +52,17 @@ func (p *printer) linebreak(line, min, max int, ws whiteSpace, newSection bool) case n > max: n = max } + + // TODO(gri): try to avoid direct manipulation of p.pos + // demo of why this is necessary: run gofmt -r 'i < i -> i < j' x.go on this x.go: + // package main + // func main() { + // i < i; + // j < 10; + // } + // + p.pos.Line += n; + if n > 0 { p.print(ws); if newSection { @@ -199,7 +210,7 @@ func (p *printer) exprList(prev token.Position, list []ast.Expr, depth int, mode if mode&commaSep != 0 { p.print(token.COMMA) } - if prev < line { + if prev < line && prev > 0 && line > 0 { if p.linebreak(line, 1, 2, ws, true) { ws = ignore; *multiLine = true; @@ -564,8 +575,7 @@ func (p *printer) binaryExpr(x *ast.BinaryExpr, prec1, cutoff, depth int, multiL xline := p.pos.Line; // before the operator (it may be on the next line!) yline := x.Y.Pos().Line; p.print(x.OpPos, x.Op); - if xline != yline { - //println(x.OpPos.String()); + if xline != yline && xline > 0 && yline > 0 { // at least one line break, but respect an extra empty line // in the source if p.linebreak(yline, 1, 2, ws, true) {