diff --git a/src/cmd/gofix/Makefile b/src/cmd/gofix/Makefile index 72d690f58d..fea50cccc5 100644 --- a/src/cmd/gofix/Makefile +++ b/src/cmd/gofix/Makefile @@ -6,6 +6,7 @@ include ../../Make.inc TARG=gofix GOFILES=\ + error.go\ filepath.go\ fix.go\ httpfinalurl.go\ diff --git a/src/cmd/gofix/error.go b/src/cmd/gofix/error.go new file mode 100644 index 0000000000..e0ced633d9 --- /dev/null +++ b/src/cmd/gofix/error.go @@ -0,0 +1,352 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "go/ast" + "regexp" + "strings" +) + +func init() { + fixes = append(fixes, errorFix) +} + +var errorFix = fix{ + "error", + errorFn, + `Use error instead of os.Error. + +This fix rewrites code using os.Error to use error: + + os.Error -> error + os.NewError -> errors.New + os.EOF -> io.EOF + +Seeing the old names above (os.Error and so on) triggers the following +heuristic rewrites. The heuristics can be forced using the -force=error flag. + +A top-level function, variable, or constant named error is renamed error_. + +Error implementations—those types used as os.Error or named +XxxError—have their String methods renamed to Error. Any existing +Error field or method is renamed to Err. + +Error values—those with type os.Error or named e, err, error, err1, +and so on—have method calls and field references rewritten just +as the types do (String to Error, Error to Err). Also, a type assertion +of the form err.(*os.Waitmsg) becomes err.(*exec.ExitError). + +http://codereview.appspot.com/5305066 +`, +} + +// At minimum, this fix applies the following rewrites: +// +// os.Error -> error +// os.NewError -> errors.New +// os.EOF -> io.EOF +// +// However, if can apply any of those rewrites, it assumes that the +// file predates the error type and tries to update the code to use +// the new definition for error - an Error method, not a String method. +// This more heuristic procedure may not be 100% accurate, so it is +// only run when the file needs updating anyway. The heuristic can +// be forced to run using -force=error. +// +// First, we must identify the implementations of os.Error. +// These include the type of any value returned as or assigned to an os.Error. +// To that set we add any type whose name contains "Error" or "error". +// The heuristic helps for implementations that are not used as os.Error +// in the file in which they are defined. +// +// In any implementation of os.Error, we rename an existing struct field +// or method named Error to Err and rename the String method to Error. +// +// Second, we must identify the values of type os.Error. +// These include any value that obviously has type os.Error. +// To that set we add any variable whose name is e or err or error +// possibly followed by _ or a numeric or capitalized suffix. +// The heuristic helps for variables that are initialized using calls +// to functions in other packages. The type checker does not have +// information about those packages available, and in general cannot +// (because the packages may themselves not compile). +// +// For any value of type os.Error, we replace a call to String with a call to Error. +// We also replace type assertion err.(*os.Waitmsg) with err.(*exec.ExitError). + +// Variables matching this regexp are assumed to have type os.Error. +var errVar = regexp.MustCompile(`^(e|err|error)_?([A-Z0-9].*)?$`) + +// Types matching this regexp are assumed to be implementations of os.Error. +var errType = regexp.MustCompile(`^\*?([Ee]rror|.*Error)$`) + +// Type-checking configuration: tell the type-checker this basic +// information about types, functions, and variables in external packages. +var errorTypeConfig = &TypeConfig{ + Type: map[string]*Type{ + "os.Error": &Type{}, + }, + Func: map[string]string{ + "fmt.Errorf": "os.Error", + "os.NewError": "os.Error", + }, + Var: map[string]string{ + "os.EPERM": "os.Error", + "os.ENOENT": "os.Error", + "os.ESRCH": "os.Error", + "os.EINTR": "os.Error", + "os.EIO": "os.Error", + "os.ENXIO": "os.Error", + "os.E2BIG": "os.Error", + "os.ENOEXEC": "os.Error", + "os.EBADF": "os.Error", + "os.ECHILD": "os.Error", + "os.EDEADLK": "os.Error", + "os.ENOMEM": "os.Error", + "os.EACCES": "os.Error", + "os.EFAULT": "os.Error", + "os.EBUSY": "os.Error", + "os.EEXIST": "os.Error", + "os.EXDEV": "os.Error", + "os.ENODEV": "os.Error", + "os.ENOTDIR": "os.Error", + "os.EISDIR": "os.Error", + "os.EINVAL": "os.Error", + "os.ENFILE": "os.Error", + "os.EMFILE": "os.Error", + "os.ENOTTY": "os.Error", + "os.EFBIG": "os.Error", + "os.ENOSPC": "os.Error", + "os.ESPIPE": "os.Error", + "os.EROFS": "os.Error", + "os.EMLINK": "os.Error", + "os.EPIPE": "os.Error", + "os.EAGAIN": "os.Error", + "os.EDOM": "os.Error", + "os.ERANGE": "os.Error", + "os.EADDRINUSE": "os.Error", + "os.ECONNREFUSED": "os.Error", + "os.ENAMETOOLONG": "os.Error", + "os.EAFNOSUPPORT": "os.Error", + "os.ETIMEDOUT": "os.Error", + "os.ENOTCONN": "os.Error", + }, +} + +func errorFn(f *ast.File) bool { + if !imports(f, "os") && !force["error"] { + return false + } + + // Fix gets called once to run the heuristics described above + // when we notice that this file definitely needs fixing + // (it mentions os.Error or something similar). + var fixed bool + var didHeuristic bool + heuristic := func() { + if didHeuristic { + return + } + didHeuristic = true + + // We have identified a necessary fix (like os.Error -> error) + // but have not applied it or any others yet. Prepare the file + // for fixing and apply heuristic fixes. + + // Rename error to error_ to make room for error. + fixed = renameTop(f, "error", "error_") || fixed + + // Use type checker to build list of error implementations. + typeof, assign := typecheck(errorTypeConfig, f) + + isError := map[string]bool{} + for _, val := range assign["os.Error"] { + t := typeof[val] + if strings.HasPrefix(t, "*") { + t = t[1:] + } + if t != "" && !strings.HasPrefix(t, "func(") { + isError[t] = true + } + } + + // We use both the type check results and the "Error" name heuristic + // to identify implementations of os.Error. + isErrorImpl := func(typ string) bool { + return isError[typ] || errType.MatchString(typ) + } + + isErrorVar := func(x ast.Expr) bool { + if typ := typeof[x]; typ != "" { + return isErrorImpl(typ) || typ == "os.Error" + } + if sel, ok := x.(*ast.SelectorExpr); ok { + return sel.Sel.Name == "Error" || sel.Sel.Name == "Err" + } + if id, ok := x.(*ast.Ident); ok { + return errVar.MatchString(id.Name) + } + return false + } + + walk(f, func(n interface{}) { + // In method declaration on error implementation type, + // rename String() to Error() and Error() to Err(). + fn, ok := n.(*ast.FuncDecl) + if ok && + fn.Recv != nil && + len(fn.Recv.List) == 1 && + isErrorImpl(typeName(fn.Recv.List[0].Type)) { + // Rename. + switch fn.Name.Name { + case "String": + fn.Name.Name = "Error" + fixed = true + case "Error": + fn.Name.Name = "Err" + fixed = true + } + return + } + + // In type definition of an error implementation type, + // rename Error field to Err to make room for method. + // Given type XxxError struct { ... Error T } rename field to Err. + d, ok := n.(*ast.GenDecl) + if ok { + for _, s := range d.Specs { + switch s := s.(type) { + case *ast.TypeSpec: + if isErrorImpl(typeName(s.Name)) { + st, ok := s.Type.(*ast.StructType) + if ok { + for _, f := range st.Fields.List { + for _, n := range f.Names { + if n.Name == "Error" { + n.Name = "Err" + fixed = true + } + } + } + } + } + } + } + } + + // For values that are an error implementation type, + // rename .Error to .Err and .String to .Error + sel, selok := n.(*ast.SelectorExpr) + if selok && isErrorImpl(typeof[sel.X]) { + switch sel.Sel.Name { + case "Error": + sel.Sel.Name = "Err" + fixed = true + case "String": + sel.Sel.Name = "Error" + fixed = true + } + } + + // Assume x.Err is an error value and rename .String to .Error + // Children have been processed so the rewrite from Error to Err + // has already happened there. + if selok { + if subsel, ok := sel.X.(*ast.SelectorExpr); ok && subsel.Sel.Name == "Err" && sel.Sel.Name == "String" { + sel.Sel.Name = "Error" + fixed = true + } + } + + // For values that are an error variable, rename .String to .Error. + if selok && isErrorVar(sel.X) && sel.Sel.Name == "String" { + sel.Sel.Name = "Error" + fixed = true + } + + // Rewrite composite literal of error type to turn Error: into Err:. + lit, ok := n.(*ast.CompositeLit) + if ok && isErrorImpl(typeof[lit]) { + for _, e := range lit.Elts { + if kv, ok := e.(*ast.KeyValueExpr); ok && isName(kv.Key, "Error") { + kv.Key.(*ast.Ident).Name = "Err" + fixed = true + } + } + } + + // Rename os.Waitmsg to exec.ExitError + // when used in a type assertion on an error. + ta, ok := n.(*ast.TypeAssertExpr) + if ok && isErrorVar(ta.X) && isPtrPkgDot(ta.Type, "os", "Waitmsg") { + addImport(f, "exec") + sel := ta.Type.(*ast.StarExpr).X.(*ast.SelectorExpr) + sel.X.(*ast.Ident).Name = "exec" + sel.Sel.Name = "ExitError" + fixed = true + } + + }) + } + + fix := func() { + if fixed { + return + } + fixed = true + heuristic() + } + + if force["error"] { + heuristic() + } + + walk(f, func(n interface{}) { + p, ok := n.(*ast.Expr) + if !ok { + return + } + sel, ok := (*p).(*ast.SelectorExpr) + if !ok { + return + } + switch { + case isPkgDot(sel, "os", "Error"): + fix() + *p = &ast.Ident{NamePos: sel.Pos(), Name: "error"} + case isPkgDot(sel, "os", "NewError"): + fix() + addImport(f, "errors") + sel.X.(*ast.Ident).Name = "errors" + sel.Sel.Name = "New" + case isPkgDot(sel, "os", "EOF"): + fix() + addImport(f, "io") + sel.X.(*ast.Ident).Name = "io" + } + }) + + if fixed && !usesImport(f, "os") { + deleteImport(f, "os") + } + + return fixed +} + +func typeName(typ ast.Expr) string { + if p, ok := typ.(*ast.StarExpr); ok { + typ = p.X + } + id, ok := typ.(*ast.Ident) + if ok { + return id.Name + } + sel, ok := typ.(*ast.SelectorExpr) + if ok { + return typeName(sel.X) + "." + sel.Sel.Name + } + return "" +} diff --git a/src/cmd/gofix/error_test.go b/src/cmd/gofix/error_test.go new file mode 100644 index 0000000000..eeab7e2ee1 --- /dev/null +++ b/src/cmd/gofix/error_test.go @@ -0,0 +1,232 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +func init() { + addTestCases(errorTests, errorFn) +} + +var errorTests = []testCase{ + { + Name: "error.0", + In: `package main + +func error() {} + +var error int +`, + Out: `package main + +func error() {} + +var error int +`, + }, + { + Name: "error.1", + In: `package main + +import "os" + +func f() os.Error { + return os.EOF +} + +func error() {} + +var error int + +func g() { + error := 1 + _ = error +} +`, + Out: `package main + +import "io" + +func f() error { + return io.EOF +} + +func error_() {} + +var error_ int + +func g() { + error := 1 + _ = error +} +`, + }, + { + Name: "error.2", + In: `package main + +import "os" + +func f() os.Error { + return os.EOF +} + +func g() string { + // these all convert because f is known + if err := f(); err != nil { + return err.String() + } + if err1 := f(); err1 != nil { + return err1.String() + } + if e := f(); e != nil { + return e.String() + } + if x := f(); x != nil { + return x.String() + } + + // only the error names (err, err1, e) convert; u is not known + if err := u(); err != nil { + return err.String() + } + if err1 := u(); err1 != nil { + return err1.String() + } + if e := u(); e != nil { + return e.String() + } + if x := u(); x != nil { + return x.String() + } + return "" +} + +type T int + +func (t T) String() string { return "t" } + +type PT int + +func (p *PT) String() string { return "pt" } + +type MyError int + +func (t MyError) String() string { return "myerror" } + +type PMyError int + +func (p *PMyError) String() string { return "pmyerror" } + +func error() {} + +var error int +`, + Out: `package main + +import "io" + +func f() error { + return io.EOF +} + +func g() string { + // these all convert because f is known + if err := f(); err != nil { + return err.Error() + } + if err1 := f(); err1 != nil { + return err1.Error() + } + if e := f(); e != nil { + return e.Error() + } + if x := f(); x != nil { + return x.Error() + } + + // only the error names (err, err1, e) convert; u is not known + if err := u(); err != nil { + return err.Error() + } + if err1 := u(); err1 != nil { + return err1.Error() + } + if e := u(); e != nil { + return e.Error() + } + if x := u(); x != nil { + return x.String() + } + return "" +} + +type T int + +func (t T) String() string { return "t" } + +type PT int + +func (p *PT) String() string { return "pt" } + +type MyError int + +func (t MyError) Error() string { return "myerror" } + +type PMyError int + +func (p *PMyError) Error() string { return "pmyerror" } + +func error_() {} + +var error_ int +`, + }, + { + Name: "error.3", + In: `package main + +import "os" + +func f() os.Error { + return os.EOF +} + +type PathError struct { + Name string + Error os.Error +} + +func (p *PathError) String() string { + return p.Name + ": " + p.Error.String() +} + +func (p *PathError) Error1() string { + p = &PathError{Error: nil} + return fmt.Sprint(p.Name, ": ", p.Error) +} +`, + Out: `package main + +import "io" + +func f() error { + return io.EOF +} + +type PathError struct { + Name string + Err error +} + +func (p *PathError) Error() string { + return p.Name + ": " + p.Err.Error() +} + +func (p *PathError) Error1() string { + p = &PathError{Err: nil} + return fmt.Sprint(p.Name, ": ", p.Err) +} +`, + }, +} diff --git a/src/cmd/gofix/fix.go b/src/cmd/gofix/fix.go index 4eaadac2b4..9a51085dd1 100644 --- a/src/cmd/gofix/fix.go +++ b/src/cmd/gofix/fix.go @@ -4,11 +4,20 @@ package main +/* +receiver named error +function named error +method on error +exiterror +slice of named type (go/scanner) +*/ + import ( "fmt" "go/ast" "go/token" "os" + "path" "strconv" "strings" ) @@ -97,6 +106,8 @@ func walkBeforeAfter(x interface{}, before, after func(interface{})) { walkBeforeAfter(*n, before, after) case **ast.Ident: walkBeforeAfter(*n, before, after) + case **ast.BasicLit: + walkBeforeAfter(*n, before, after) // pointers to slices case *[]ast.Decl: @@ -114,7 +125,9 @@ func walkBeforeAfter(x interface{}, before, after func(interface{})) { // These are ordered and grouped to match ../../pkg/go/ast/ast.go case *ast.Field: + walkBeforeAfter(&n.Names, before, after) walkBeforeAfter(&n.Type, before, after) + walkBeforeAfter(&n.Tag, before, after) case *ast.FieldList: for _, field := range n.List { walkBeforeAfter(field, before, after) @@ -484,23 +497,101 @@ func newPkgDot(pos token.Pos, pkg, name string) ast.Expr { } } +// renameTop renames all references to the top-level name top. +// It returns true if it makes any changes. +func renameTop(f *ast.File, old, new string) bool { + var fixed bool + + // Rename any conflicting imports + // (assuming package name is last element of path). + for _, s := range f.Imports { + if s.Name != nil { + if s.Name.Name == old { + s.Name.Name = new + fixed = true + } + } else { + _, thisName := path.Split(importPath(s)) + if thisName == old { + s.Name = ast.NewIdent(new) + fixed = true + } + } + } + + // Rename any top-level declarations. + for _, d := range f.Decls { + switch d := d.(type) { + case *ast.FuncDecl: + if d.Recv == nil && d.Name.Name == old { + d.Name.Name = new + d.Name.Obj.Name = new + fixed = true + } + case *ast.GenDecl: + for _, s := range d.Specs { + switch s := s.(type) { + case *ast.TypeSpec: + if s.Name.Name == old { + s.Name.Name = new + s.Name.Obj.Name = new + fixed = true + } + case *ast.ValueSpec: + for _, n := range s.Names { + if n.Name == old { + n.Name = new + n.Obj.Name = new + fixed = true + } + } + } + } + } + } + + // Rename top-level old to new, both unresolved names + // (probably defined in another file) and names that resolve + // to a declaration we renamed. + walk(f, func(n interface{}) { + id, ok := n.(*ast.Ident) + if ok && isTopName(id, old) { + id.Name = new + fixed = true + } + if ok && id.Obj != nil && id.Name == old && id.Obj.Name == new { + id.Name = id.Obj.Name + fixed = true + } + }) + + return fixed +} + // addImport adds the import path to the file f, if absent. -func addImport(f *ast.File, path string) { - if imports(f, path) { +func addImport(f *ast.File, ipath string) { + if imports(f, ipath) { return } + // Determine name of import. + // Assume added imports follow convention of using last element. + _, name := path.Split(ipath) + + // Rename any conflicting top-level references from name to name_. + renameTop(f, name, name+"_") + newImport := &ast.ImportSpec{ Path: &ast.BasicLit{ Kind: token.STRING, - Value: strconv.Quote(path), + Value: strconv.Quote(ipath), }, } var impdecl *ast.GenDecl // Find an import decl to add to. - var lastImport int = -1 + var lastImport = -1 for i, decl := range f.Decls { gen, ok := decl.(*ast.GenDecl) @@ -535,7 +626,7 @@ func addImport(f *ast.File, path string) { insertAt := len(impdecl.Specs) // default to end of specs for i, spec := range impdecl.Specs { impspec := spec.(*ast.ImportSpec) - if importPath(impspec) > path { + if importPath(impspec) > ipath { insertAt = i break } diff --git a/src/cmd/gofix/httpheaders.go b/src/cmd/gofix/httpheaders.go index 2e906d859c..e9856f5db4 100644 --- a/src/cmd/gofix/httpheaders.go +++ b/src/cmd/gofix/httpheaders.go @@ -31,7 +31,7 @@ func httpheaders(f *ast.File) bool { }) fixed := false - typeof := typecheck(headerTypeConfig, f) + typeof, _ := typecheck(headerTypeConfig, f) walk(f, func(ni interface{}) { switch n := ni.(type) { case *ast.SelectorExpr: diff --git a/src/cmd/gofix/imagecolor.go b/src/cmd/gofix/imagecolor.go index d6171196d9..c7900e4657 100644 --- a/src/cmd/gofix/imagecolor.go +++ b/src/cmd/gofix/imagecolor.go @@ -47,8 +47,6 @@ func imagecolor(f *ast.File) (fixed bool) { return } - importColor := false - walk(f, func(n interface{}) { s, ok := n.(*ast.SelectorExpr) @@ -66,20 +64,17 @@ func imagecolor(f *ast.File) (fixed bool) { default: for _, rename := range colorRenames { if sel == rename.in { + addImport(f, "image/color") s.X.(*ast.Ident).Name = "color" s.Sel.Name = rename.out fixed = true - importColor = true } } } }) - if importColor { - addImport(f, "image/color") - if !usesImport(f, "image") { - deleteImport(f, "image") - } + if fixed && !usesImport(f, "image") { + deleteImport(f, "image") } return } diff --git a/src/cmd/gofix/main.go b/src/cmd/gofix/main.go index e0709fc8ba..56232d708a 100644 --- a/src/cmd/gofix/main.go +++ b/src/cmd/gofix/main.go @@ -28,14 +28,18 @@ var ( var allowedRewrites = flag.String("r", "", "restrict the rewrites to this comma-separated list") -var allowed map[string]bool +var forceRewrites = flag.String("force", "", + "force these fixes to run even if the code looks updated") + +var allowed, force map[string]bool var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files") func usage() { - fmt.Fprintf(os.Stderr, "usage: gofix [-diff] [-r fixname,...] [path ...]\n") + fmt.Fprintf(os.Stderr, "usage: gofix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n") flag.PrintDefaults() fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n") + sort.Sort(fixes) for _, f := range fixes { fmt.Fprintf(os.Stderr, "\n%s\n", f.name) desc := strings.TrimSpace(f.desc) @@ -46,8 +50,6 @@ func usage() { } func main() { - sort.Sort(fixes) - flag.Usage = usage flag.Parse() @@ -58,6 +60,13 @@ func main() { } } + if *forceRewrites != "" { + force = make(map[string]bool) + for _, f := range strings.Split(*forceRewrites, ",") { + force[f] = true + } + } + if flag.NArg() == 0 { if err := processFile("standard input", true); err != nil { report(err) diff --git a/src/cmd/gofix/math.go b/src/cmd/gofix/math.go index 8af4e87c7d..a9a11ed615 100644 --- a/src/cmd/gofix/math.go +++ b/src/cmd/gofix/math.go @@ -4,14 +4,7 @@ package main -import ( - "fmt" - "os" - "go/ast" -) - -var _ fmt.Stringer -var _ os.Error +import "go/ast" var mathFix = fix{ "math", diff --git a/src/cmd/gofix/reflect.go b/src/cmd/gofix/reflect.go index c292543ab8..2227d69b44 100644 --- a/src/cmd/gofix/reflect.go +++ b/src/cmd/gofix/reflect.go @@ -95,7 +95,7 @@ func reflectFn(f *ast.File) bool { // Rewrite names in method calls. // Needs basic type information (see above). - typeof := typecheck(reflectTypeConfig, f) + typeof, _ := typecheck(reflectTypeConfig, f) walk(f, func(n interface{}) { switch n := n.(type) { case *ast.SelectorExpr: diff --git a/src/cmd/gofix/signal.go b/src/cmd/gofix/signal.go index aaad348259..9b548bd089 100644 --- a/src/cmd/gofix/signal.go +++ b/src/cmd/gofix/signal.go @@ -32,16 +32,14 @@ func signal(f *ast.File) (fixed bool) { sel := s.Sel.String() if sel == "Signal" || sel == "UnixSignal" || strings.HasPrefix(sel, "SIG") { + addImport(f, "os") s.X = &ast.Ident{Name: "os"} fixed = true } }) - if fixed { - addImport(f, "os") - if !usesImport(f, "os/signal") { - deleteImport(f, "os/signal") - } + if fixed && !usesImport(f, "os/signal") { + deleteImport(f, "os/signal") } return } diff --git a/src/cmd/gofix/typecheck.go b/src/cmd/gofix/typecheck.go index 2d81b9710e..23fc8bfe96 100644 --- a/src/cmd/gofix/typecheck.go +++ b/src/cmd/gofix/typecheck.go @@ -99,6 +99,7 @@ type Type struct { Field map[string]string // map field name to type Method map[string]string // map method name to comma-separated return types Embed []string // list of types this type embeds (for extra methods) + Def string // definition of named type } // dot returns the type of "typ.name", making its decision @@ -128,9 +129,15 @@ func (typ *Type) dot(cfg *TypeConfig, name string) string { } // typecheck type checks the AST f assuming the information in cfg. -// It returns a map from AST nodes to type information in gofmt string form. -func typecheck(cfg *TypeConfig, f *ast.File) map[interface{}]string { - typeof := make(map[interface{}]string) +// It returns two maps with type information: +// typeof maps AST nodes to type information in gofmt string form. +// assign maps type strings to lists of expressions that were assigned +// to values of another type that were assigned to that type. +func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, assign map[string][]interface{}) { + typeof = make(map[interface{}]string) + assign = make(map[string][]interface{}) + cfg1 := &TypeConfig{} + *cfg1 = *cfg // make copy so we can add locally // gather function declarations for _, decl := range f.Decls { @@ -138,7 +145,7 @@ func typecheck(cfg *TypeConfig, f *ast.File) map[interface{}]string { if !ok { continue } - typecheck1(cfg, fn.Type, typeof) + typecheck1(cfg, fn.Type, typeof, assign) t := typeof[fn.Type] if fn.Recv != nil { // The receiver must be a type. @@ -168,8 +175,42 @@ func typecheck(cfg *TypeConfig, f *ast.File) map[interface{}]string { } } - typecheck1(cfg, f, typeof) - return typeof + // gather struct declarations + for _, decl := range f.Decls { + d, ok := decl.(*ast.GenDecl) + if ok { + for _, s := range d.Specs { + switch s := s.(type) { + case *ast.TypeSpec: + if cfg1.Type[s.Name.Name] != nil { + break + } + if cfg1.Type == cfg.Type || cfg1.Type == nil { + // Copy map lazily: it's time. + cfg1.Type = make(map[string]*Type) + for k, v := range cfg.Type { + cfg1.Type[k] = v + } + } + t := &Type{Field: map[string]string{}} + cfg1.Type[s.Name.Name] = t + switch st := s.Type.(type) { + case *ast.StructType: + for _, f := range st.Fields.List { + for _, n := range f.Names { + t.Field[n.Name] = gofmt(f.Type) + } + } + case *ast.ArrayType, *ast.StarExpr, *ast.MapType: + t.Def = gofmt(st) + } + } + } + } + } + + typecheck1(cfg1, f, typeof, assign) + return typeof, assign } func makeExprList(a []*ast.Ident) []ast.Expr { @@ -183,11 +224,14 @@ func makeExprList(a []*ast.Ident) []ast.Expr { // Typecheck1 is the recursive form of typecheck. // It is like typecheck but adds to the information in typeof // instead of allocating a new map. -func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { +func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string, assign map[string][]interface{}) { // set sets the type of n to typ. // If isDecl is true, n is being declared. set := func(n ast.Expr, typ string, isDecl bool) { if typeof[n] != "" || typ == "" { + if typeof[n] != typ { + assign[typ] = append(assign[typ], n) + } return } typeof[n] = typ @@ -236,6 +280,14 @@ func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { } } + expand := func(s string) string { + typ := cfg.Type[s] + if typ != nil && typ.Def != "" { + return typ.Def + } + return s + } + // The main type check is a recursive algorithm implemented // by walkBeforeAfter(n, before, after). // Most of it is bottom-up, but in a few places we need @@ -263,7 +315,7 @@ func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { defer func() { if t := typeof[n]; t != "" { pos := fset.Position(n.(ast.Node).Pos()) - fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos.String(), gofmt(n), t) + fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t) } }() } @@ -405,6 +457,8 @@ func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { // x.(T) has type T. if t := typeof[n.Type]; isType(t) { typeof[n] = getType(t) + } else { + typeof[n] = gofmt(n.Type) } case *ast.SliceExpr: @@ -413,7 +467,7 @@ func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { case *ast.IndexExpr: // x[i] has key type of x's type. - t := typeof[n.X] + t := expand(typeof[n.X]) if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") { // Lazy: assume there are no nested [] in the array // length or map key type. @@ -426,7 +480,7 @@ func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { // *x for x of type *T has type T when x is an expr. // We don't use the result when *x is a type, but // compute it anyway. - t := typeof[n.X] + t := expand(typeof[n.X]) if isType(t) { typeof[n] = "type *" + getType(t) } else if strings.HasPrefix(t, "*") { @@ -448,6 +502,39 @@ func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { // (x) has type of x. typeof[n] = typeof[n.X] + case *ast.RangeStmt: + t := expand(typeof[n.X]) + if t == "" { + return + } + var key, value string + if t == "string" { + key, value = "int", "rune" + } else if strings.HasPrefix(t, "[") { + key = "int" + if i := strings.Index(t, "]"); i >= 0 { + value = t[i+1:] + } + } else if strings.HasPrefix(t, "map[") { + if i := strings.Index(t, "]"); i >= 0 { + key, value = t[4:i], t[i+1:] + } + } + changed := false + if n.Key != nil && key != "" { + changed = true + set(n.Key, key, n.Tok == token.DEFINE) + } + if n.Value != nil && value != "" { + changed = true + set(n.Value, value, n.Tok == token.DEFINE) + } + // Ugly failure of vision: already type-checked body. + // Do it again now that we have that type info. + if changed { + typecheck1(cfg, n.Body, typeof, assign) + } + case *ast.TypeSwitchStmt: // Type of variable changes for each case in type switch, // but go/parser generates just one variable. @@ -471,7 +558,7 @@ func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string) { tt = getType(tt) typeof[varx] = tt typeof[varx.Obj] = tt - typecheck1(cfg, cas.Body, typeof) + typecheck1(cfg, cas.Body, typeof, assign) } } } diff --git a/src/cmd/gofix/url.go b/src/cmd/gofix/url.go index 455b544b63..d90f2b0cc1 100644 --- a/src/cmd/gofix/url.go +++ b/src/cmd/gofix/url.go @@ -4,14 +4,7 @@ package main -import ( - "fmt" - "os" - "go/ast" -) - -var _ fmt.Stringer -var _ os.Error +import "go/ast" var urlFix = fix{ "url", @@ -42,12 +35,7 @@ func url(f *ast.File) bool { fixed := false // Update URL code. - var skip interface{} urlWalk := func(n interface{}) { - if n == skip { - skip = nil - return - } // Is it an identifier? if ident, ok := n.(*ast.Ident); ok && ident.Name == "url" { ident.Name = "url_" @@ -58,12 +46,6 @@ func url(f *ast.File) bool { fixed = urlDoFields(fn.Params) || fixed fixed = urlDoFields(fn.Results) || fixed } - // U{url: ...} is likely a struct field. - if kv, ok := n.(*ast.KeyValueExpr); ok { - if ident, ok := kv.Key.(*ast.Ident); ok && ident.Name == "url" { - skip = ident - } - } } // Fix up URL code and add import, at most once. @@ -71,8 +53,8 @@ func url(f *ast.File) bool { if fixed { return } - walkBeforeAfter(f, urlWalk, nop) addImport(f, "url") + walkBeforeAfter(f, urlWalk, nop) fixed = true } diff --git a/src/cmd/gofix/url_test.go b/src/cmd/gofix/url_test.go index ca886e983e..39827f780e 100644 --- a/src/cmd/gofix/url_test.go +++ b/src/cmd/gofix/url_test.go @@ -103,14 +103,14 @@ func h() (url string) { import "url" -type U struct{ url int } +type U struct{ url_ int } type M map[int]int func f() { url.Parse(a) var url_ = 23 url_, x := 45, y - _ = U{url: url_} + _ = U{url_: url_} _ = M{url_ + 1: url_} }