diff --git a/src/cmd/Makefile b/src/cmd/Makefile index 779bd44c79..fdb33f0702 100644 --- a/src/cmd/Makefile +++ b/src/cmd/Makefile @@ -41,6 +41,7 @@ CLEANDIRS=\ cgo\ ebnflint\ godoc\ + gofix\ gofmt\ goinstall\ gotype\ diff --git a/src/cmd/gofix/Makefile b/src/cmd/gofix/Makefile new file mode 100644 index 0000000000..020a6a2920 --- /dev/null +++ b/src/cmd/gofix/Makefile @@ -0,0 +1,16 @@ +# Copyright 2011 The Go Authors. All rights reserved. +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file. + +include ../../Make.inc + +TARG=gofix +GOFILES=\ + fix.go\ + main.go\ + httpserver.go\ + +include ../../Make.cmd + +test: + gotest diff --git a/src/cmd/gofix/doc.go b/src/cmd/gofix/doc.go new file mode 100644 index 0000000000..e267d5d7bf --- /dev/null +++ b/src/cmd/gofix/doc.go @@ -0,0 +1,34 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +Gofix finds Go programs that use old APIs and rewrites them to use +newer ones. After you update to a new Go release, gofix helps make +the necessary changes to your programs. + +Usage: + gofix [-r name,...] [path ...] + +Without an explicit path, gofix reads standard input and writes the +result to standard output. + +If the named path is a file, gofix rewrites the named files in place. +If the named path is a directory, gofix rewrites all .go files in that +directory tree. When gofix rewrites a file, it prints a line to standard +error giving the name of the file and the rewrite applied. + +The -r flag restricts the set of rewrites considered to those in the +named list. By default gofix considers all known rewrites. Gofix's +rewrites are idempotent, so that it is safe to apply gofix to updated +or partially updated code even without using the -r flag. + +Gofix prints the full list of fixes it can apply in its help output; +to see them, run godoc -?. + +Gofix does not make backup copies of the files that it edits. +Instead, use a version control system's ``diff'' functionality to inspect +the changes that gofix makes before committing them. + +*/ +package documentation diff --git a/src/cmd/gofix/fix.go b/src/cmd/gofix/fix.go new file mode 100644 index 0000000000..eadddbadc8 --- /dev/null +++ b/src/cmd/gofix/fix.go @@ -0,0 +1,297 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "go/ast" + "go/token" + "os" +) + +type fix struct { + name string + f func(*ast.File) bool + desc string +} + +// main runs sort.Sort(fixes) after init process is done. +type fixlist []fix + +func (f fixlist) Len() int { return len(f) } +func (f fixlist) Swap(i, j int) { f[i], f[j] = f[j], f[i] } +func (f fixlist) Less(i, j int) bool { return f[i].name < f[j].name } + +var fixes fixlist + +func register(f fix) { + fixes = append(fixes, f) +} + +// rewrite walks the AST x, calling visit(y) for each node y in the tree but +// also with a pointer to each ast.Expr, in a bottom-up traversal. +func rewrite(x interface{}, visit func(interface{})) { + switch n := x.(type) { + case *ast.Expr: + rewrite(*n, visit) + + // everything else just recurses + default: + panic(fmt.Errorf("unexpected type %T in walk", x, visit)) + + case nil: + + // These are ordered and grouped to match ../../pkg/go/ast/ast.go + case *ast.Field: + rewrite(&n.Type, visit) + case *ast.FieldList: + for _, field := range n.List { + rewrite(field, visit) + } + case *ast.BadExpr: + case *ast.Ident: + case *ast.Ellipsis: + case *ast.BasicLit: + case *ast.FuncLit: + rewrite(n.Type, visit) + rewrite(n.Body, visit) + case *ast.CompositeLit: + rewrite(&n.Type, visit) + rewrite(n.Elts, visit) + case *ast.ParenExpr: + rewrite(&n.X, visit) + case *ast.SelectorExpr: + rewrite(&n.X, visit) + case *ast.IndexExpr: + rewrite(&n.X, visit) + rewrite(&n.Index, visit) + case *ast.SliceExpr: + rewrite(&n.X, visit) + if n.Low != nil { + rewrite(&n.Low, visit) + } + if n.High != nil { + rewrite(&n.High, visit) + } + case *ast.TypeAssertExpr: + rewrite(&n.X, visit) + rewrite(&n.Type, visit) + case *ast.CallExpr: + rewrite(&n.Fun, visit) + rewrite(n.Args, visit) + case *ast.StarExpr: + rewrite(&n.X, visit) + case *ast.UnaryExpr: + rewrite(&n.X, visit) + case *ast.BinaryExpr: + rewrite(&n.X, visit) + rewrite(&n.Y, visit) + case *ast.KeyValueExpr: + rewrite(&n.Key, visit) + rewrite(&n.Value, visit) + + case *ast.ArrayType: + rewrite(&n.Len, visit) + rewrite(&n.Elt, visit) + case *ast.StructType: + rewrite(n.Fields, visit) + case *ast.FuncType: + rewrite(n.Params, visit) + if n.Results != nil { + rewrite(n.Results, visit) + } + case *ast.InterfaceType: + rewrite(n.Methods, visit) + case *ast.MapType: + rewrite(&n.Key, visit) + rewrite(&n.Value, visit) + case *ast.ChanType: + rewrite(&n.Value, visit) + + case *ast.BadStmt: + case *ast.DeclStmt: + rewrite(n.Decl, visit) + case *ast.EmptyStmt: + case *ast.LabeledStmt: + rewrite(n.Stmt, visit) + case *ast.ExprStmt: + rewrite(&n.X, visit) + case *ast.SendStmt: + rewrite(&n.Chan, visit) + rewrite(&n.Value, visit) + case *ast.IncDecStmt: + rewrite(&n.X, visit) + case *ast.AssignStmt: + rewrite(n.Lhs, visit) + if len(n.Lhs) == 2 && len(n.Rhs) == 1 { + rewrite(n.Rhs, visit) + } else { + rewrite(n.Rhs, visit) + } + case *ast.GoStmt: + rewrite(n.Call, visit) + case *ast.DeferStmt: + rewrite(n.Call, visit) + case *ast.ReturnStmt: + rewrite(n.Results, visit) + case *ast.BranchStmt: + case *ast.BlockStmt: + rewrite(n.List, visit) + case *ast.IfStmt: + rewrite(n.Init, visit) + rewrite(&n.Cond, visit) + rewrite(n.Body, visit) + rewrite(n.Else, visit) + case *ast.CaseClause: + rewrite(n.Values, visit) + rewrite(n.Body, visit) + case *ast.SwitchStmt: + rewrite(n.Init, visit) + rewrite(&n.Tag, visit) + rewrite(n.Body, visit) + case *ast.TypeCaseClause: + rewrite(n.Types, visit) + rewrite(n.Body, visit) + case *ast.TypeSwitchStmt: + rewrite(n.Init, visit) + rewrite(n.Assign, visit) + rewrite(n.Body, visit) + case *ast.CommClause: + rewrite(n.Comm, visit) + rewrite(n.Body, visit) + case *ast.SelectStmt: + rewrite(n.Body, visit) + case *ast.ForStmt: + rewrite(n.Init, visit) + rewrite(&n.Cond, visit) + rewrite(n.Post, visit) + rewrite(n.Body, visit) + case *ast.RangeStmt: + rewrite(&n.Key, visit) + rewrite(&n.Value, visit) + rewrite(&n.X, visit) + rewrite(n.Body, visit) + + case *ast.ImportSpec: + case *ast.ValueSpec: + rewrite(&n.Type, visit) + rewrite(n.Values, visit) + case *ast.TypeSpec: + rewrite(&n.Type, visit) + + case *ast.BadDecl: + case *ast.GenDecl: + rewrite(n.Specs, visit) + case *ast.FuncDecl: + if n.Recv != nil { + rewrite(n.Recv, visit) + } + rewrite(n.Type, visit) + if n.Body != nil { + rewrite(n.Body, visit) + } + + case *ast.File: + rewrite(n.Decls, visit) + + case *ast.Package: + for _, file := range n.Files { + rewrite(file, visit) + } + + case []ast.Decl: + for _, d := range n { + rewrite(d, visit) + } + case []ast.Expr: + for i := range n { + rewrite(&n[i], visit) + } + case []ast.Stmt: + for _, s := range n { + rewrite(s, visit) + } + case []ast.Spec: + for _, s := range n { + rewrite(s, visit) + } + } + visit(x) +} + +func imports(f *ast.File, path string) bool { + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if !ok { + continue + } + for _, spec := range d.Specs { + s, ok := spec.(*ast.ImportSpec) + if !ok { + continue + } + if string(s.Path.Value) == `"`+path+`"` { + return true + } + } + } + return false +} + +func isPkgDot(t ast.Expr, pkg, name string) bool { + sel, ok := t.(*ast.SelectorExpr) + if !ok { + return false + } + return isName(sel.X, pkg) && sel.Sel.String() == name +} + +func isPtrPkgDot(t ast.Expr, pkg, name string) bool { + ptr, ok := t.(*ast.StarExpr) + if !ok { + return false + } + return isPkgDot(ptr.X, pkg, name) +} + +func isName(n ast.Expr, name string) bool { + id, ok := n.(*ast.Ident) + if !ok { + return false + } + return id.String() == name +} + +func refersTo(n ast.Node, x *ast.Ident) bool { + id, ok := n.(*ast.Ident) + if !ok { + return false + } + return id.String() == x.String() +} + +func isBlank(n ast.Expr) bool { + return isName(n, "_") +} + +func isEmptyString(n ast.Expr) bool { + lit, ok := n.(*ast.BasicLit) + if !ok { + return false + } + if lit.Kind != token.STRING { + return false + } + s := string(lit.Value) + return s == `""` || s == "``" +} + +func warn(pos token.Pos, msg string, args ...interface{}) { + s := "" + if pos.IsValid() { + s = fmt.Sprintf("%s: ", fset.Position(pos).String()) + } + fmt.Fprintf(os.Stderr, "%s"+msg+"\n", append([]interface{}{s}, args...)) +} diff --git a/src/cmd/gofix/httpserver.go b/src/cmd/gofix/httpserver.go new file mode 100644 index 0000000000..88996532b4 --- /dev/null +++ b/src/cmd/gofix/httpserver.go @@ -0,0 +1,125 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" + "go/token" +) + +var httpserverFix = fix{ + "httpserver", + httpserver, +`Adapt http server methods and functions to changes +made to the http ResponseWriter interface. + +http://codereview.appspot.com/4245064 Hijacker +http://codereview.appspot.com/4239076 Header +http://codereview.appspot.com/4239077 Flusher +http://codereview.appspot.com/4248075 RemoteAddr, UsingTLS +`, +} + +func init() { + register(httpserverFix) +} + +func httpserver(f *ast.File) bool { + if !imports(f, "http") { + return false + } + + fixed := false + for _, decl := range f.Decls { + fn, ok := decl.(*ast.FuncDecl) + if !ok { + continue + } + w, req, ok := isServeHTTP(fn) + if !ok { + continue + } + rewrite(fn.Body, func(n interface{}) { + // Want to replace expression sometimes, + // so record pointer to it for updating below. + ptr, ok := n.(*ast.Expr) + if ok { + n = *ptr + } + + // Look for w.UsingTLS() and w.Remoteaddr(). + call, ok := n.(*ast.CallExpr) + if !ok || len(call.Args) != 0 { + return + } + sel, ok := call.Fun.(*ast.SelectorExpr) + if !ok { + return + } + if !refersTo(sel.X, w) { + return + } + switch sel.Sel.String() { + case "Hijack": + // replace w with w.(http.Hijacker) + sel.X = &ast.TypeAssertExpr{ + X: sel.X, + Type: ast.NewIdent("http.Hijacker"), + } + fixed = true + case "Flush": + // replace w with w.(http.Flusher) + sel.X = &ast.TypeAssertExpr{ + X: sel.X, + Type: ast.NewIdent("http.Flusher"), + } + fixed = true + case "UsingTLS": + if ptr == nil { + // can only replace expression if we have pointer to it + break + } + // replace with req.TLS != nil + *ptr = &ast.BinaryExpr{ + X: &ast.SelectorExpr{ + X: ast.NewIdent(req.String()), + Sel: ast.NewIdent("TLS"), + }, + Op: token.NEQ, + Y: ast.NewIdent("nil"), + } + fixed = true + case "RemoteAddr": + if ptr == nil { + // can only replace expression if we have pointer to it + break + } + // replace with req.RemoteAddr + *ptr = &ast.SelectorExpr{ + X: ast.NewIdent(req.String()), + Sel: ast.NewIdent("RemoteAddr"), + } + fixed = true + } + }) + } + return fixed +} + +func isServeHTTP(fn *ast.FuncDecl) (w, req *ast.Ident, ok bool) { + for _, field := range fn.Type.Params.List { + if isPkgDot(field.Type, "http", "ResponseWriter") { + w = field.Names[0] + continue + } + if isPtrPkgDot(field.Type, "http", "Request") { + req = field.Names[0] + continue + } + } + + ok = w != nil && req != nil + return +} diff --git a/src/cmd/gofix/httpserver_test.go b/src/cmd/gofix/httpserver_test.go new file mode 100644 index 0000000000..7e79056c50 --- /dev/null +++ b/src/cmd/gofix/httpserver_test.go @@ -0,0 +1,46 @@ +package main + +func init() { + addTestCases(httpserverTests) +} + +var httpserverTests = []testCase{ + { + Name: "httpserver.0", + Fn: httpserver, + In: `package main + +import "http" + +func f(xyz http.ResponseWriter, abc *http.Request, b string) { + xyz.Hijack() + xyz.Flush() + go xyz.Hijack() + defer xyz.Flush() + _ = xyz.UsingTLS() + _ = true == xyz.UsingTLS() + _ = xyz.RemoteAddr() + _ = xyz.RemoteAddr() == "hello" + if xyz.UsingTLS() { + } +} +`, + Out: `package main + +import "http" + +func f(xyz http.ResponseWriter, abc *http.Request, b string) { + xyz.(http.Hijacker).Hijack() + xyz.(http.Flusher).Flush() + go xyz.(http.Hijacker).Hijack() + defer xyz.(http.Flusher).Flush() + _ = abc.TLS != nil + _ = true == (abc.TLS != nil) + _ = abc.RemoteAddr + _ = abc.RemoteAddr == "hello" + if abc.TLS != nil { + } +} +`, + }, +} diff --git a/src/cmd/gofix/main.go b/src/cmd/gofix/main.go new file mode 100644 index 0000000000..40c86e8f21 --- /dev/null +++ b/src/cmd/gofix/main.go @@ -0,0 +1,179 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "flag" + "fmt" + "go/parser" + "go/printer" + "go/scanner" + "go/token" + "io/ioutil" + "os" + "path/filepath" + "sort" + "strings" +) + +var ( + fset = token.NewFileSet() + exitCode = 0 +) + +var allowedRewrites = flag.String("r", "", + "restrict the rewrites to this comma-separated list") + +var allowed map[string]bool + +func usage() { + fmt.Fprintf(os.Stderr, "usage: gofix [-r fixname,...] [path ...]\n") + flag.PrintDefaults() + fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n") + for _, f := range fixes { + fmt.Fprintf(os.Stderr, "\n%s\n", f.name) + desc := strings.TrimSpace(f.desc) + desc = strings.Replace(desc, "\n", "\n\t", -1) + fmt.Fprintf(os.Stderr, "\t%s\n", desc) + } + os.Exit(2) +} + +func main() { + sort.Sort(fixes) + + flag.Usage = usage + flag.Parse() + + if *allowedRewrites != "" { + allowed = make(map[string]bool) + for _, f := range strings.Split(*allowedRewrites, ",", -1) { + allowed[f] = true + } + } + + if flag.NArg() == 0 { + if err := processFile("standard input", true); err != nil { + report(err) + } + os.Exit(exitCode) + } + + for i := 0; i < flag.NArg(); i++ { + path := flag.Arg(i) + switch dir, err := os.Stat(path); { + case err != nil: + report(err) + case dir.IsRegular(): + if err := processFile(path, false); err != nil { + report(err) + } + case dir.IsDirectory(): + walkDir(path) + } + } + + os.Exit(exitCode) +} + +const ( + tabWidth = 8 + parserMode = parser.ParseComments + printerMode = printer.TabIndent +) + + +func processFile(filename string, useStdin bool) os.Error { + var f *os.File + var err os.Error + + if useStdin { + f = os.Stdin + } else { + f, err = os.Open(filename, os.O_RDONLY, 0) + if err != nil { + return err + } + defer f.Close() + } + + src, err := ioutil.ReadAll(f) + if err != nil { + return err + } + + file, err := parser.ParseFile(fset, filename, src, parserMode) + if err != nil { + return err + } + + fixed := false + var buf bytes.Buffer + for _, fix := range fixes { + if allowed != nil && !allowed[fix.desc] { + continue + } + if fix.f(file) { + fixed = true + fmt.Fprintf(&buf, " %s", fix.name) + } + } + if !fixed { + return nil + } + fmt.Fprintf(os.Stderr, "%s: %s\n", filename, buf.String()[1:]) + + buf.Reset() + _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file) + if err != nil { + return err + } + + if useStdin { + os.Stdout.Write(buf.Bytes()) + return nil + } + + return ioutil.WriteFile(f.Name(), buf.Bytes(), 0) +} + +func report(err os.Error) { + scanner.PrintError(os.Stderr, err) + exitCode = 2 +} + +func walkDir(path string) { + v := make(fileVisitor) + go func() { + filepath.Walk(path, v, v) + close(v) + }() + for err := range v { + if err != nil { + report(err) + } + } +} + +type fileVisitor chan os.Error + +func (v fileVisitor) VisitDir(path string, f *os.FileInfo) bool { + return true +} + +func (v fileVisitor) VisitFile(path string, f *os.FileInfo) { + if isGoFile(f) { + v <- nil // synchronize error handler + if err := processFile(path, false); err != nil { + v <- err + } + } +} + +func isGoFile(f *os.FileInfo) bool { + // ignore non-Go files + return f.IsRegular() && !strings.HasPrefix(f.Name, ".") && strings.HasSuffix(f.Name, ".go") +} diff --git a/src/cmd/gofix/main_test.go b/src/cmd/gofix/main_test.go new file mode 100644 index 0000000000..597bff22ac --- /dev/null +++ b/src/cmd/gofix/main_test.go @@ -0,0 +1,102 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "go/ast" + "go/parser" + "go/printer" + "testing" +) + +type testCase struct { + Name string + Fn func(*ast.File) bool + In string + Out string +} + +var testCases []testCase + +func addTestCases(t []testCase) { + testCases = append(testCases, t...) +} + +func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string) (out string, fixed, ok bool) { + file, err := parser.ParseFile(fset, desc, in, parserMode) + if err != nil { + t.Errorf("%s: parsing: %v", desc, err) + return + } + + var buf bytes.Buffer + buf.Reset() + _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file) + if err != nil { + t.Errorf("%s: printing: %v", desc, err) + return + } + if s := buf.String(); in != s { + t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s", + desc, desc, in, desc, s) + return + } + + if fn == nil { + for _, fix := range fixes { + if fix.f(file) { + fixed = true + } + } + } else { + fixed = fn(file) + } + + buf.Reset() + _, err = (&printer.Config{printerMode, tabWidth}).Fprint(&buf, fset, file) + if err != nil { + t.Errorf("%s: printing: %v", desc, err) + return + } + + return buf.String(), fixed, true +} + +func TestRewrite(t *testing.T) { + for _, tt := range testCases { + // Apply fix: should get tt.Out. + out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In) + if !ok { + continue + } + + if out != tt.Out { + t.Errorf("%s: incorrect output.\n--- have\n%s\n--- want\n%s", tt.Name, out, tt.Out) + continue + } + + if changed := out != tt.In; changed != fixed { + t.Errorf("%s: changed=%v != fixed=%v", tt.Name, changed, fixed) + continue + } + + // Should not change if run again. + out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out) + if !ok { + continue + } + + if fixed2 { + t.Errorf("%s: applied fixes during second round", tt.Name) + continue + } + + if out2 != out { + t.Errorf("%s: changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s", + tt.Name, out, out2) + } + } +} diff --git a/src/pkg/Makefile b/src/pkg/Makefile index 3edb1e60bd..24b304346d 100644 --- a/src/pkg/Makefile +++ b/src/pkg/Makefile @@ -155,6 +155,7 @@ DIRS=\ ../cmd/cgo\ ../cmd/ebnflint\ ../cmd/godoc\ + ../cmd/gofix\ ../cmd/gofmt\ ../cmd/gotype\ ../cmd/goinstall\