diff --git a/src/cmd/gofix/fix.go b/src/cmd/gofix/fix.go index 57c87531eb..0852ce21ed 100644 --- a/src/cmd/gofix/fix.go +++ b/src/cmd/gofix/fix.go @@ -9,6 +9,7 @@ import ( "go/ast" "go/token" "os" + "strconv" ) type fix struct { @@ -30,272 +31,297 @@ func register(f fix) { fixes = append(fixes, f) } -// rewrite walks the AST x, calling visit(y) for each node y in the tree but -// also with a pointer to each ast.Expr, in a bottom-up traversal. -func rewrite(x interface{}, visit func(interface{})) { - switch n := x.(type) { - case *ast.Expr: - rewrite(*n, visit) +// walk traverses the AST x, calling visit(y) for each node y in the tree but +// also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt, +// in a bottom-up traversal. +func walk(x interface{}, visit func(interface{})) { + walkBeforeAfter(x, nop, visit) +} - // everything else just recurses +func nop(interface{}) {} + +// walkBeforeAfter is like walk but calls before(x) before traversing +// x's children and after(x) afterward. +func walkBeforeAfter(x interface{}, before, after func(interface{})) { + before(x) + + switch n := x.(type) { default: - panic(fmt.Errorf("unexpected type %T in walk", x)) + panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x)) case nil: + // pointers to interfaces + case *ast.Decl: + walkBeforeAfter(*n, before, after) + case *ast.Expr: + walkBeforeAfter(*n, before, after) + case *ast.Spec: + walkBeforeAfter(*n, before, after) + case *ast.Stmt: + walkBeforeAfter(*n, before, after) + + // pointers to struct pointers + case **ast.BlockStmt: + walkBeforeAfter(*n, before, after) + case **ast.CallExpr: + walkBeforeAfter(*n, before, after) + case **ast.FieldList: + walkBeforeAfter(*n, before, after) + case **ast.FuncType: + walkBeforeAfter(*n, before, after) + + // pointers to slices + case *[]ast.Stmt: + walkBeforeAfter(*n, before, after) + case *[]ast.Expr: + walkBeforeAfter(*n, before, after) + case *[]ast.Decl: + walkBeforeAfter(*n, before, after) + case *[]ast.Spec: + walkBeforeAfter(*n, before, after) + case *[]*ast.File: + walkBeforeAfter(*n, before, after) + // These are ordered and grouped to match ../../pkg/go/ast/ast.go case *ast.Field: - rewrite(&n.Type, visit) + walkBeforeAfter(&n.Type, before, after) case *ast.FieldList: for _, field := range n.List { - rewrite(field, visit) + walkBeforeAfter(field, before, after) } case *ast.BadExpr: case *ast.Ident: case *ast.Ellipsis: case *ast.BasicLit: case *ast.FuncLit: - rewrite(n.Type, visit) - rewrite(n.Body, visit) + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Body, before, after) case *ast.CompositeLit: - rewrite(&n.Type, visit) - rewrite(n.Elts, visit) + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Elts, before, after) case *ast.ParenExpr: - rewrite(&n.X, visit) + walkBeforeAfter(&n.X, before, after) case *ast.SelectorExpr: - rewrite(&n.X, visit) + walkBeforeAfter(&n.X, before, after) case *ast.IndexExpr: - rewrite(&n.X, visit) - rewrite(&n.Index, visit) + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Index, before, after) case *ast.SliceExpr: - rewrite(&n.X, visit) + walkBeforeAfter(&n.X, before, after) if n.Low != nil { - rewrite(&n.Low, visit) + walkBeforeAfter(&n.Low, before, after) } if n.High != nil { - rewrite(&n.High, visit) + walkBeforeAfter(&n.High, before, after) } case *ast.TypeAssertExpr: - rewrite(&n.X, visit) - rewrite(&n.Type, visit) + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Type, before, after) case *ast.CallExpr: - rewrite(&n.Fun, visit) - rewrite(n.Args, visit) + walkBeforeAfter(&n.Fun, before, after) + walkBeforeAfter(&n.Args, before, after) case *ast.StarExpr: - rewrite(&n.X, visit) + walkBeforeAfter(&n.X, before, after) case *ast.UnaryExpr: - rewrite(&n.X, visit) + walkBeforeAfter(&n.X, before, after) case *ast.BinaryExpr: - rewrite(&n.X, visit) - rewrite(&n.Y, visit) + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Y, before, after) case *ast.KeyValueExpr: - rewrite(&n.Key, visit) - rewrite(&n.Value, visit) + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) case *ast.ArrayType: - rewrite(&n.Len, visit) - rewrite(&n.Elt, visit) + walkBeforeAfter(&n.Len, before, after) + walkBeforeAfter(&n.Elt, before, after) case *ast.StructType: - rewrite(n.Fields, visit) + walkBeforeAfter(&n.Fields, before, after) case *ast.FuncType: - rewrite(n.Params, visit) + walkBeforeAfter(&n.Params, before, after) if n.Results != nil { - rewrite(n.Results, visit) + walkBeforeAfter(&n.Results, before, after) } case *ast.InterfaceType: - rewrite(n.Methods, visit) + walkBeforeAfter(&n.Methods, before, after) case *ast.MapType: - rewrite(&n.Key, visit) - rewrite(&n.Value, visit) + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) case *ast.ChanType: - rewrite(&n.Value, visit) + walkBeforeAfter(&n.Value, before, after) case *ast.BadStmt: case *ast.DeclStmt: - rewrite(n.Decl, visit) + walkBeforeAfter(&n.Decl, before, after) case *ast.EmptyStmt: case *ast.LabeledStmt: - rewrite(n.Stmt, visit) + walkBeforeAfter(&n.Stmt, before, after) case *ast.ExprStmt: - rewrite(&n.X, visit) + walkBeforeAfter(&n.X, before, after) case *ast.SendStmt: - rewrite(&n.Chan, visit) - rewrite(&n.Value, visit) + walkBeforeAfter(&n.Chan, before, after) + walkBeforeAfter(&n.Value, before, after) case *ast.IncDecStmt: - rewrite(&n.X, visit) + walkBeforeAfter(&n.X, before, after) case *ast.AssignStmt: - rewrite(n.Lhs, visit) - if len(n.Lhs) == 2 && len(n.Rhs) == 1 { - rewrite(n.Rhs, visit) - } else { - rewrite(n.Rhs, visit) - } + walkBeforeAfter(&n.Lhs, before, after) + walkBeforeAfter(&n.Rhs, before, after) case *ast.GoStmt: - rewrite(n.Call, visit) + walkBeforeAfter(&n.Call, before, after) case *ast.DeferStmt: - rewrite(n.Call, visit) + walkBeforeAfter(&n.Call, before, after) case *ast.ReturnStmt: - rewrite(n.Results, visit) + walkBeforeAfter(&n.Results, before, after) case *ast.BranchStmt: case *ast.BlockStmt: - rewrite(n.List, visit) + walkBeforeAfter(&n.List, before, after) case *ast.IfStmt: - rewrite(n.Init, visit) - rewrite(&n.Cond, visit) - rewrite(n.Body, visit) - rewrite(n.Else, visit) + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Cond, before, after) + walkBeforeAfter(&n.Body, before, after) + walkBeforeAfter(&n.Else, before, after) case *ast.CaseClause: - rewrite(n.List, visit) - rewrite(n.Body, visit) + walkBeforeAfter(&n.List, before, after) + walkBeforeAfter(&n.Body, before, after) case *ast.SwitchStmt: - rewrite(n.Init, visit) - rewrite(&n.Tag, visit) - rewrite(n.Body, visit) + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Tag, before, after) + walkBeforeAfter(&n.Body, before, after) case *ast.TypeSwitchStmt: - rewrite(n.Init, visit) - rewrite(n.Assign, visit) - rewrite(n.Body, visit) + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Assign, before, after) + walkBeforeAfter(&n.Body, before, after) case *ast.CommClause: - rewrite(n.Comm, visit) - rewrite(n.Body, visit) + walkBeforeAfter(&n.Comm, before, after) + walkBeforeAfter(&n.Body, before, after) case *ast.SelectStmt: - rewrite(n.Body, visit) + walkBeforeAfter(&n.Body, before, after) case *ast.ForStmt: - rewrite(n.Init, visit) - rewrite(&n.Cond, visit) - rewrite(n.Post, visit) - rewrite(n.Body, visit) + walkBeforeAfter(&n.Init, before, after) + walkBeforeAfter(&n.Cond, before, after) + walkBeforeAfter(&n.Post, before, after) + walkBeforeAfter(&n.Body, before, after) case *ast.RangeStmt: - rewrite(&n.Key, visit) - rewrite(&n.Value, visit) - rewrite(&n.X, visit) - rewrite(n.Body, visit) + walkBeforeAfter(&n.Key, before, after) + walkBeforeAfter(&n.Value, before, after) + walkBeforeAfter(&n.X, before, after) + walkBeforeAfter(&n.Body, before, after) case *ast.ImportSpec: case *ast.ValueSpec: - rewrite(&n.Type, visit) - rewrite(n.Values, visit) + walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Values, before, after) case *ast.TypeSpec: - rewrite(&n.Type, visit) + walkBeforeAfter(&n.Type, before, after) case *ast.BadDecl: case *ast.GenDecl: - rewrite(n.Specs, visit) + walkBeforeAfter(&n.Specs, before, after) case *ast.FuncDecl: if n.Recv != nil { - rewrite(n.Recv, visit) + walkBeforeAfter(&n.Recv, before, after) } - rewrite(n.Type, visit) + walkBeforeAfter(&n.Type, before, after) if n.Body != nil { - rewrite(n.Body, visit) + walkBeforeAfter(&n.Body, before, after) } case *ast.File: - rewrite(n.Decls, visit) + walkBeforeAfter(&n.Decls, before, after) case *ast.Package: - for _, file := range n.Files { - rewrite(file, visit) - } + walkBeforeAfter(&n.Files, before, after) + case []*ast.File: + for i := range n { + walkBeforeAfter(&n[i], before, after) + } case []ast.Decl: - for _, d := range n { - rewrite(d, visit) + for i := range n { + walkBeforeAfter(&n[i], before, after) } case []ast.Expr: for i := range n { - rewrite(&n[i], visit) + walkBeforeAfter(&n[i], before, after) } case []ast.Stmt: - for _, s := range n { - rewrite(s, visit) + for i := range n { + walkBeforeAfter(&n[i], before, after) } case []ast.Spec: - for _, s := range n { - rewrite(s, visit) + for i := range n { + walkBeforeAfter(&n[i], before, after) } } - visit(x) + after(x) } +// imports returns true if f imports path. func imports(f *ast.File, path string) bool { - for _, decl := range f.Decls { - d, ok := decl.(*ast.GenDecl) - if !ok { - continue - } - for _, spec := range d.Specs { - s, ok := spec.(*ast.ImportSpec) - if !ok { - continue - } - if string(s.Path.Value) == `"`+path+`"` { - return true - } + for _, s := range f.Imports { + t, err := strconv.Unquote(s.Path.Value) + if err == nil && t == path { + return true } } return false } +// isPkgDot returns true if t is the expression "pkg.name" +// where pkg is an imported identifier. func isPkgDot(t ast.Expr, pkg, name string) bool { sel, ok := t.(*ast.SelectorExpr) - if !ok { - return false - } - return isTopName(sel.X, pkg) && sel.Sel.String() == name + return ok && isTopName(sel.X, pkg) && sel.Sel.String() == name } +// isPtrPkgDot returns true if f is the expression "*pkg.name" +// where pkg is an imported identifier. func isPtrPkgDot(t ast.Expr, pkg, name string) bool { ptr, ok := t.(*ast.StarExpr) - if !ok { - return false - } - return isPkgDot(ptr.X, pkg, name) + return ok && isPkgDot(ptr.X, pkg, name) } +// isTopName returns true if n is a top-level unresolved identifier with the given name. func isTopName(n ast.Expr, name string) bool { id, ok := n.(*ast.Ident) - if !ok { - return false - } - return id.Name == name && id.Obj == nil + return ok && id.Name == name && id.Obj == nil } +// isName returns true if n is an identifier with the given name. func isName(n ast.Expr, name string) bool { id, ok := n.(*ast.Ident) - if !ok { - return false - } - return id.String() == name + return ok && id.String() == name } +// isCall returns true if t is a call to pkg.name. func isCall(t ast.Expr, pkg, name string) bool { call, ok := t.(*ast.CallExpr) return ok && isPkgDot(call.Fun, pkg, name) } -func refersTo(n ast.Node, x *ast.Ident) bool { - id, ok := n.(*ast.Ident) - if !ok { - return false - } - return id.String() == x.String() +// If n is an *ast.Ident, isIdent returns it; otherwise isIdent returns nil. +func isIdent(n interface{}) *ast.Ident { + id, _ := n.(*ast.Ident) + return id } +// refersTo returns true if n is a reference to the same object as x. +func refersTo(n ast.Node, x *ast.Ident) bool { + id, ok := n.(*ast.Ident) + // The test of id.Name == x.Name handles top-level unresolved + // identifiers, which all have Obj == nil. + return ok && id.Obj == x.Obj && id.Name == x.Name +} + +// isBlank returns true if n is the blank identifier. func isBlank(n ast.Expr) bool { return isName(n, "_") } +// isEmptyString returns true if n is an empty string literal. func isEmptyString(n ast.Expr) bool { lit, ok := n.(*ast.BasicLit) - if !ok { - return false - } - if lit.Kind != token.STRING { - return false - } - s := string(lit.Value) - return s == `""` || s == "``" + return ok && lit.Kind == token.STRING && len(lit.Value) == 2 } func warn(pos token.Pos, msg string, args ...interface{}) { @@ -306,3 +332,91 @@ func warn(pos token.Pos, msg string, args ...interface{}) { } fmt.Fprintf(os.Stderr, msg+"\n", args...) } + +// countUses returns the number of uses of the identifier x in scope. +func countUses(x *ast.Ident, scope []ast.Stmt) int { + count := 0 + ff := func(n interface{}) { + if n, ok := n.(ast.Node); ok && refersTo(n, x) { + count++ + } + } + for _, n := range scope { + walk(n, ff) + } + return count +} + +// rewriteUses replaces all uses of the identifier x and !x in scope +// with f(x.Pos()) and fnot(x.Pos()). +func rewriteUses(x *ast.Ident, f, fnot func(token.Pos) ast.Expr, scope []ast.Stmt) { + var lastF ast.Expr + ff := func(n interface{}) { + ptr, ok := n.(*ast.Expr) + if !ok { + return + } + nn := *ptr + + // The child node was just walked and possibly replaced. + // If it was replaced and this is a negation, replace with fnot(p). + not, ok := nn.(*ast.UnaryExpr) + if ok && not.Op == token.NOT && not.X == lastF { + *ptr = fnot(nn.Pos()) + return + } + if refersTo(nn, x) { + lastF = f(nn.Pos()) + *ptr = lastF + } + } + for _, n := range scope { + walk(n, ff) + } +} + +// assignsTo returns true if any of the code in scope assigns to or takes the address of x. +func assignsTo(x *ast.Ident, scope []ast.Stmt) bool { + assigned := false + ff := func(n interface{}) { + if assigned { + return + } + switch n := n.(type) { + case *ast.UnaryExpr: + // use of &x + if n.Op == token.AND && refersTo(n.X, x) { + assigned = true + return + } + case *ast.AssignStmt: + for _, l := range n.Lhs { + if refersTo(l, x) { + assigned = true + return + } + } + } + } + for _, n := range scope { + if assigned { + break + } + walk(n, ff) + } + return assigned +} + +// newPkgDot returns an ast.Expr referring to "pkg.name" at position pos. +func newPkgDot(pos token.Pos, pkg, name string) ast.Expr { + return &ast.SelectorExpr{ + X: &ast.Ident{ + NamePos: pos, + Name: pkg, + }, + Sel: &ast.Ident{ + NamePos: pos, + Name: name, + }, + } +} diff --git a/src/cmd/gofix/httpserver.go b/src/cmd/gofix/httpserver.go index 79eea08c6f..37866e88b1 100644 --- a/src/cmd/gofix/httpserver.go +++ b/src/cmd/gofix/httpserver.go @@ -41,7 +41,7 @@ func httpserver(f *ast.File) bool { if !ok { continue } - rewrite(fn.Body, func(n interface{}) { + walk(fn.Body, func(n interface{}) { // Want to replace expression sometimes, // so record pointer to it for updating below. ptr, ok := n.(*ast.Expr) diff --git a/src/cmd/gofix/main.go b/src/cmd/gofix/main.go index c02137679f..4f7e923e3d 100644 --- a/src/cmd/gofix/main.go +++ b/src/cmd/gofix/main.go @@ -6,6 +6,7 @@ package main import ( "bytes" + "exec" "flag" "fmt" "go/parser" @@ -29,8 +30,10 @@ var allowedRewrites = flag.String("r", "", var allowed map[string]bool +var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files") + func usage() { - fmt.Fprintf(os.Stderr, "usage: gofix [-r fixname,...] [path ...]\n") + fmt.Fprintf(os.Stderr, "usage: gofix [-diff] [-r fixname,...] [path ...]\n") flag.PrintDefaults() fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n") for _, f := range fixes { @@ -85,10 +88,16 @@ const ( printerMode = printer.TabIndent | printer.UseSpaces ) +var printConfig = &printer.Config{ + printerMode, + tabWidth, +} func processFile(filename string, useStdin bool) os.Error { var f *os.File var err os.Error + var fixlog bytes.Buffer + var buf bytes.Buffer if useStdin { f = os.Stdin @@ -110,34 +119,77 @@ func processFile(filename string, useStdin bool) os.Error { return err } + // Apply all fixes to file. + newFile := file fixed := false - var buf bytes.Buffer for _, fix := range fixes { if allowed != nil && !allowed[fix.desc] { continue } - if fix.f(file) { + if fix.f(newFile) { fixed = true - fmt.Fprintf(&buf, " %s", fix.name) + fmt.Fprintf(&fixlog, " %s", fix.name) + + // AST changed. + // Print and parse, to update any missing scoping + // or position information for subsequent fixers. + buf.Reset() + _, err = printConfig.Fprint(&buf, fset, newFile) + if err != nil { + return err + } + newSrc := buf.Bytes() + newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode) + if err != nil { + return err + } } } if !fixed { return nil } - fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, buf.String()[1:]) + fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:]) + // Print AST. We did that after each fix, so this appears + // redundant, but it is necessary to generate gofmt-compatible + // source code in a few cases. The official gofmt style is the + // output of the printer run on a standard AST generated by the parser, + // but the source we generated inside the loop above is the + // output of the printer run on a mangled AST generated by a fixer. buf.Reset() - _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file) + _, err = printConfig.Fprint(&buf, fset, newFile) if err != nil { return err } + newSrc := buf.Bytes() - if useStdin { - os.Stdout.Write(buf.Bytes()) + if *doDiff { + data, err := diff(src, newSrc) + if err != nil { + return fmt.Errorf("computing diff: %s", err) + } + fmt.Printf("diff %s fixed/%s\n", filename, filename) + os.Stdout.Write(data) return nil } - return ioutil.WriteFile(f.Name(), buf.Bytes(), 0) + if useStdin { + os.Stdout.Write(newSrc) + return nil + } + + return ioutil.WriteFile(f.Name(), newSrc, 0) +} + +var gofmtBuf bytes.Buffer + +func gofmt(n interface{}) string { + gofmtBuf.Reset() + _, err := printConfig.Fprint(&gofmtBuf, fset, n) + if err != nil { + return "<" + err.String() + ">" + } + return gofmtBuf.String() } func report(err os.Error) { @@ -177,3 +229,36 @@ func isGoFile(f *os.FileInfo) bool { // ignore non-Go files return f.IsRegular() && !strings.HasPrefix(f.Name, ".") && strings.HasSuffix(f.Name, ".go") } + +func diff(b1, b2 []byte) (data []byte, err os.Error) { + f1, err := ioutil.TempFile("", "gofix") + if err != nil { + return nil, err + } + defer os.Remove(f1.Name()) + defer f1.Close() + + f2, err := ioutil.TempFile("", "gofix") + if err != nil { + return nil, err + } + defer os.Remove(f2.Name()) + defer f2.Close() + + f1.Write(b1) + f2.Write(b2) + + diffcmd, err := exec.LookPath("diff") + if err != nil { + return nil, err + } + + c, err := exec.Run(diffcmd, []string{"diff", f1.Name(), f2.Name()}, nil, "", + exec.DevNull, exec.Pipe, exec.MergeWithStdout) + if err != nil { + return nil, err + } + defer c.Close() + + return ioutil.ReadAll(c.Stdout) +} diff --git a/src/cmd/gofix/main_test.go b/src/cmd/gofix/main_test.go index e4d0f60cce..275778e5be 100644 --- a/src/cmd/gofix/main_test.go +++ b/src/cmd/gofix/main_test.go @@ -6,12 +6,10 @@ package main import ( "bytes" - "exec" "go/ast" "go/parser" "go/printer" - "io/ioutil" - "os" + "strings" "testing" ) @@ -28,6 +26,8 @@ func addTestCases(t []testCase) { testCases = append(testCases, t...) } +func fnop(*ast.File) bool { return false } + func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out string, fixed, ok bool) { file, err := parser.ParseFile(fset, desc, in, parserMode) if err != nil { @@ -42,7 +42,7 @@ func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out t.Errorf("%s: printing: %v", desc, err) return } - if s := buf.String(); in != s { + if s := buf.String(); in != s && fn != fnop { t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s", desc, desc, in, desc, s) tdiff(t, in, s) @@ -77,8 +77,17 @@ func TestRewrite(t *testing.T) { continue } + // reformat to get printing right + out, _, ok = parseFixPrint(t, fnop, tt.Name, out) + if !ok { + continue + } + if out != tt.Out { - t.Errorf("%s: incorrect output.\n--- have\n%s\n--- want\n%s", tt.Name, out, tt.Out) + t.Errorf("%s: incorrect output.\n", tt.Name) + if !strings.HasPrefix(tt.Name, "testdata/") { + t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out) + } tdiff(t, out, tt.Out) continue } @@ -108,44 +117,10 @@ func TestRewrite(t *testing.T) { } func tdiff(t *testing.T, a, b string) { - f1, err := ioutil.TempFile("", "gofix") + data, err := diff([]byte(a), []byte(b)) if err != nil { t.Error(err) return } - defer os.Remove(f1.Name()) - defer f1.Close() - - f2, err := ioutil.TempFile("", "gofix") - if err != nil { - t.Error(err) - return - } - defer os.Remove(f2.Name()) - defer f2.Close() - - f1.Write([]byte(a)) - f2.Write([]byte(b)) - - diffcmd, err := exec.LookPath("diff") - if err != nil { - t.Error(err) - return - } - - c, err := exec.Run(diffcmd, []string{"diff", f1.Name(), f2.Name()}, nil, "", - exec.DevNull, exec.Pipe, exec.MergeWithStdout) - if err != nil { - t.Error(err) - return - } - defer c.Close() - - data, err := ioutil.ReadAll(c.Stdout) - if err != nil { - t.Error(err) - return - } - t.Error(string(data)) } diff --git a/src/cmd/gofix/netdial.go b/src/cmd/gofix/netdial.go index d1531b647e..afa98953b9 100644 --- a/src/cmd/gofix/netdial.go +++ b/src/cmd/gofix/netdial.go @@ -47,7 +47,7 @@ func netdial(f *ast.File) bool { } fixed := false - rewrite(f, func(n interface{}) { + walk(f, func(n interface{}) { call, ok := n.(*ast.CallExpr) if !ok || !isPkgDot(call.Fun, "net", "Dial") || len(call.Args) != 3 { return @@ -70,7 +70,7 @@ func tlsdial(f *ast.File) bool { } fixed := false - rewrite(f, func(n interface{}) { + walk(f, func(n interface{}) { call, ok := n.(*ast.CallExpr) if !ok || !isPkgDot(call.Fun, "tls", "Dial") || len(call.Args) != 4 { return @@ -94,7 +94,7 @@ func netlookup(f *ast.File) bool { } fixed := false - rewrite(f, func(n interface{}) { + walk(f, func(n interface{}) { as, ok := n.(*ast.AssignStmt) if !ok || len(as.Lhs) != 3 || len(as.Rhs) != 1 { return diff --git a/src/cmd/gofix/osopen.go b/src/cmd/gofix/osopen.go index 49993d8f99..2acf1c4556 100644 --- a/src/cmd/gofix/osopen.go +++ b/src/cmd/gofix/osopen.go @@ -27,7 +27,7 @@ func osopen(f *ast.File) bool { } fixed := false - rewrite(f, func(n interface{}) { + walk(f, func(n interface{}) { // Rename O_CREAT to O_CREATE. if expr, ok := n.(ast.Expr); ok && isPkgDot(expr, "os", "O_CREAT") { expr.(*ast.SelectorExpr).Sel.Name = "O_CREATE" diff --git a/src/cmd/gofix/procattr.go b/src/cmd/gofix/procattr.go index 80b75d1d48..0e2190b1f4 100644 --- a/src/cmd/gofix/procattr.go +++ b/src/cmd/gofix/procattr.go @@ -28,7 +28,7 @@ func procattr(f *ast.File) bool { } fixed := false - rewrite(f, func(n interface{}) { + walk(f, func(n interface{}) { call, ok := n.(*ast.CallExpr) if !ok || len(call.Args) != 5 { return diff --git a/src/cmd/gofix/typecheck.go b/src/cmd/gofix/typecheck.go new file mode 100644 index 0000000000..d565e7b4bd --- /dev/null +++ b/src/cmd/gofix/typecheck.go @@ -0,0 +1,579 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "go/ast" + "go/token" + "os" + "reflect" + "strings" +) + +// Partial type checker. +// +// The fact that it is partial is very important: the input is +// an AST and a description of some type information to +// assume about one or more packages, but not all the +// packages that the program imports. The checker is +// expected to do as much as it can with what it has been +// given. There is not enough information supplied to do +// a full type check, but the type checker is expected to +// apply information that can be derived from variable +// declarations, function and method returns, and type switches +// as far as it can, so that the caller can still tell the types +// of expression relevant to a particular fix. +// +// TODO(rsc,gri): Replace with go/typechecker. +// Doing that could be an interesting test case for go/typechecker: +// the constraints about working with partial information will +// likely exercise it in interesting ways. The ideal interface would +// be to pass typecheck a map from importpath to package API text +// (Go source code), but for now we use data structures (TypeConfig, Type). +// +// The strings mostly use gofmt form. +// +// A Field or FieldList has as its type a comma-separated list +// of the types of the fields. For example, the field list +// x, y, z int +// has type "int, int, int". + +// The prefix "type " is the type of a type. +// For example, given +// var x int +// type T int +// x's type is "int" but T's type is "type int". +// mkType inserts the "type " prefix. +// getType removes it. +// isType tests for it. + +func mkType(t string) string { + return "type " + t +} + +func getType(t string) string { + if !isType(t) { + return "" + } + return t[len("type "):] +} + +func isType(t string) bool { + return strings.HasPrefix(t, "type ") +} + +// TypeConfig describes the universe of relevant types. +// For ease of creation, the types are all referred to by string +// name (e.g., "reflect.Value"). TypeByName is the only place +// where the strings are resolved. + +type TypeConfig struct { + Type map[string]*Type + Var map[string]string + Func map[string]string +} + +// typeof returns the type of the given name, which may be of +// the form "x" or "p.X". +func (cfg *TypeConfig) typeof(name string) string { + if cfg.Var != nil { + if t := cfg.Var[name]; t != "" { + return t + } + } + if cfg.Func != nil { + if t := cfg.Func[name]; t != "" { + return "func()" + t + } + } + return "" +} + +// Type describes the Fields and Methods of a type. +// If the field or method cannot be found there, it is next +// looked for in the Embed list. +type Type struct { + Field map[string]string // map field name to type + Method map[string]string // map method name to comma-separated return types + Embed []string // list of types this type embeds (for extra methods) +} + +// dot returns the type of "typ.name", making its decision +// using the type information in cfg. +func (typ *Type) dot(cfg *TypeConfig, name string) string { + if typ.Field != nil { + if t := typ.Field[name]; t != "" { + return t + } + } + if typ.Method != nil { + if t := typ.Method[name]; t != "" { + return t + } + } + + for _, e := range typ.Embed { + etyp := cfg.Type[e] + if etyp != nil { + if t := etyp.dot(cfg, name); t != "" { + return t + } + } + } + + return "" +} + +// typecheck type checks the AST f assuming the information in cfg. +// It returns a map from AST nodes to type information in gofmt string form. +func typecheck(cfg *TypeConfig, f *ast.File) map[interface{}]string { + typeof := make(map[interface{}]string) + + // gather function declarations + for _, decl := range f.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + typecheck1(cfg, fn.Type, typeof) + t := typeof[fn.Type] + if fn.Recv != nil { + // The receiver must be a type. + rcvr := typeof[fn.Recv] + if !isType(rcvr) { + if len(fn.Recv.List) != 1 { + continue + } + rcvr = mkType(gofmt(fn.Recv.List[0].Type)) + typeof[fn.Recv.List[0].Type] = rcvr + } + rcvr = getType(rcvr) + if rcvr != "" && rcvr[0] == '*' { + rcvr = rcvr[1:] + } + typeof[rcvr+"."+fn.Name.Name] = t + } else { + if isType(t) { + t = getType(t) + } else { + t = gofmt(fn.Type) + } + typeof[fn.Name] = t + + // Record typeof[fn.Name.Obj] for future references to fn.Name. + typeof[fn.Name.Obj] = t + } + } + + typecheck1(cfg, f, typeof) + return typeof +} + +func makeExprList(a []*ast.Ident) []ast.Expr { + var b []ast.Expr + for _, x := range a { + b = append(b, x) + } + return b +} + +// Typecheck1 is the recursive form of typecheck. +// It is like typecheck but adds to the information in typeof +// instead of allocating a new map. +func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { + // set sets the type of n to typ. + // If isDecl is true, n is being declared. + set := func(n ast.Expr, typ string, isDecl bool) { + if typeof[n] != "" || typ == "" { + return + } + typeof[n] = typ + + // If we obtained typ from the declaration of x + // propagate the type to all the uses. + // The !isDecl case is a cheat here, but it makes + // up in some cases for not paying attention to + // struct fields. The real type checker will be + // more accurate so we won't need the cheat. + if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") { + typeof[id.Obj] = typ + } + } + + // Type-check an assignment lhs = rhs. + // If isDecl is true, this is := so we can update + // the types of the objects that lhs refers to. + typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) { + if len(lhs) > 1 && len(rhs) == 1 { + if _, ok := rhs[0].(*ast.CallExpr); ok { + t := split(typeof[rhs[0]]) + // Lists should have same length but may not; pair what can be paired. + for i := 0; i < len(lhs) && i < len(t); i++ { + set(lhs[i], t[i], isDecl) + } + return + } + } + if len(lhs) == 1 && len(rhs) == 2 { + // x = y, ok + rhs = rhs[:1] + } else if len(lhs) == 2 && len(rhs) == 1 { + // x, ok = y + lhs = lhs[:1] + } + + // Match as much as we can. + for i := 0; i < len(lhs) && i < len(rhs); i++ { + x, y := lhs[i], rhs[i] + if typeof[y] != "" { + set(x, typeof[y], isDecl) + } else { + set(y, typeof[x], false) + } + } + } + + // The main type check is a recursive algorithm implemented + // by walkBeforeAfter(n, before, after). + // Most of it is bottom-up, but in a few places we need + // to know the type of the function we are checking. + // The before function records that information on + // the curfn stack. + var curfn []*ast.FuncType + + before := func(n interface{}) { + // push function type on stack + switch n := n.(type) { + case *ast.FuncDecl: + curfn = append(curfn, n.Type) + case *ast.FuncLit: + curfn = append(curfn, n.Type) + } + } + + // After is the real type checker. + after := func(n interface{}) { + if n == nil { + return + } + if false && reflect.Typeof(n).Kind() == reflect.Ptr { // debugging trace + defer func() { + if t := typeof[n]; t != "" { + pos := fset.Position(n.(ast.Node).Pos()) + fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos.String(), gofmt(n), t) + } + }() + } + + switch n := n.(type) { + case *ast.FuncDecl, *ast.FuncLit: + // pop function type off stack + curfn = curfn[:len(curfn)-1] + + case *ast.FuncType: + typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results]))) + + case *ast.FieldList: + // Field list is concatenation of sub-lists. + t := "" + for _, field := range n.List { + if t != "" { + t += ", " + } + t += typeof[field] + } + typeof[n] = t + + case *ast.Field: + // Field is one instance of the type per name. + all := "" + t := typeof[n.Type] + if !isType(t) { + // Create a type, because it is typically *T or *p.T + // and we might care about that type. + t = mkType(gofmt(n.Type)) + typeof[n.Type] = t + } + t = getType(t) + if len(n.Names) == 0 { + all = t + } else { + for _, id := range n.Names { + if all != "" { + all += ", " + } + all += t + typeof[id.Obj] = t + typeof[id] = t + } + } + typeof[n] = all + + case *ast.ValueSpec: + // var declaration. Use type if present. + if n.Type != nil { + t := typeof[n.Type] + if !isType(t) { + t = mkType(gofmt(n.Type)) + typeof[n.Type] = t + } + t = getType(t) + for _, id := range n.Names { + set(id, t, true) + } + } + // Now treat same as assignment. + typecheckAssign(makeExprList(n.Names), n.Values, true) + + case *ast.AssignStmt: + typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE) + + case *ast.Ident: + // Identifier can take its type from underlying object. + if t := typeof[n.Obj]; t != "" { + typeof[n] = t + } + + case *ast.SelectorExpr: + // Field or method. + name := n.Sel.Name + if t := typeof[n.X]; t != "" { + if strings.HasPrefix(t, "*") { + t = t[1:] // implicit * + } + if typ := cfg.Type[t]; typ != nil { + if t := typ.dot(cfg, name); t != "" { + typeof[n] = t + return + } + } + tt := typeof[t+"."+name] + if isType(tt) { + typeof[n] = getType(tt) + return + } + } + // Package selector. + if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil { + str := x.Name + "." + name + if cfg.Type[str] != nil { + typeof[n] = mkType(str) + return + } + if t := cfg.typeof(x.Name + "." + name); t != "" { + typeof[n] = t + return + } + } + + case *ast.CallExpr: + // make(T) has type T. + if isTopName(n.Fun, "make") && len(n.Args) >= 1 { + typeof[n] = gofmt(n.Args[0]) + return + } + // Otherwise, use type of function to determine arguments. + t := typeof[n.Fun] + in, out := splitFunc(t) + if in == nil && out == nil { + return + } + typeof[n] = join(out) + for i, arg := range n.Args { + if i >= len(in) { + break + } + if typeof[arg] == "" { + typeof[arg] = in[i] + } + } + + case *ast.TypeAssertExpr: + // x.(type) has type of x. + if n.Type == nil { + typeof[n] = typeof[n.X] + return + } + // x.(T) has type T. + if t := typeof[n.Type]; isType(t) { + typeof[n] = getType(t) + } + + case *ast.SliceExpr: + // x[i:j] has type of x. + typeof[n] = typeof[n.X] + + case *ast.IndexExpr: + // x[i] has key type of x's type. + t := typeof[n.X] + if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") { + // Lazy: assume there are no nested [] in the array + // length or map key type. + if i := strings.Index(t, "]"); i >= 0 { + typeof[n] = t[i+1:] + } + } + + case *ast.StarExpr: + // *x for x of type *T has type T when x is an expr. + // We don't use the result when *x is a type, but + // compute it anyway. + t := typeof[n.X] + if isType(t) { + typeof[n] = "type *" + getType(t) + } else if strings.HasPrefix(t, "*") { + typeof[n] = t[len("*"):] + } + + case *ast.UnaryExpr: + // &x for x of type T has type *T. + t := typeof[n.X] + if t != "" && n.Op == token.AND { + typeof[n] = "&" + t + } + + case *ast.CompositeLit: + // T{...} has type T. + typeof[n] = gofmt(n.Type) + + case *ast.ParenExpr: + // (x) has type of x. + typeof[n] = typeof[n.X] + + case *ast.TypeSwitchStmt: + // Type of variable changes for each case in type switch, + // but go/parser generates just one variable. + // Repeat type check for each case with more precise + // type information. + as, ok := n.Assign.(*ast.AssignStmt) + if !ok { + return + } + varx, ok := as.Lhs[0].(*ast.Ident) + if !ok { + return + } + t := typeof[varx] + for _, cas := range n.Body.List { + cas := cas.(*ast.CaseClause) + if len(cas.List) == 1 { + // Variable has specific type only when there is + // exactly one type in the case list. + if tt := typeof[cas.List[0]]; isType(tt) { + tt = getType(tt) + typeof[varx] = tt + typeof[varx.Obj] = tt + typecheck1(cfg, cas.Body, typeof) + } + } + } + // Restore t. + typeof[varx] = t + typeof[varx.Obj] = t + + case *ast.ReturnStmt: + if len(curfn) == 0 { + // Probably can't happen. + return + } + f := curfn[len(curfn)-1] + res := n.Results + if f.Results != nil { + t := split(typeof[f.Results]) + for i := 0; i < len(res) && i < len(t); i++ { + set(res[i], t[i], false) + } + } + } + } + walkBeforeAfter(f, before, after) +} + +// Convert between function type strings and lists of types. +// Using strings makes this a little harder, but it makes +// a lot of the rest of the code easier. This will all go away +// when we can use go/typechecker directly. + +// splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"]. +func splitFunc(s string) (in, out []string) { + if !strings.HasPrefix(s, "func(") { + return nil, nil + } + + i := len("func(") // index of beginning of 'in' arguments + nparen := 0 + for j := i; j < len(s); j++ { + switch s[j] { + case '(': + nparen++ + case ')': + nparen-- + if nparen < 0 { + // found end of parameter list + out := strings.TrimSpace(s[j+1:]) + if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' { + out = out[1 : len(out)-1] + } + return split(s[i:j]), split(out) + } + } + } + return nil, nil +} + +// joinFunc is the inverse of splitFunc. +func joinFunc(in, out []string) string { + outs := "" + if len(out) == 1 { + outs = " " + out[0] + } else if len(out) > 1 { + outs = " (" + join(out) + ")" + } + return "func(" + join(in) + ")" + outs +} + +// split splits "int, float" into ["int", "float"] and splits "" into []. +func split(s string) []string { + out := []string{} + i := 0 // current type being scanned is s[i:j]. + nparen := 0 + for j := 0; j < len(s); j++ { + switch s[j] { + case ' ': + if i == j { + i++ + } + case '(': + nparen++ + case ')': + nparen-- + if nparen < 0 { + // probably can't happen + return nil + } + case ',': + if nparen == 0 { + if i < j { + out = append(out, s[i:j]) + } + i = j + 1 + } + } + } + if nparen != 0 { + // probably can't happen + return nil + } + if i < len(s) { + out = append(out, s[i:]) + } + return out +} + +// join is the inverse of split. +func join(x []string) string { + return strings.Join(x, ", ") +}