1
0
mirror of https://github.com/golang/go synced 2024-11-18 16:14:46 -07:00

imports: extend findImports to return a boolean, rename, that tells

goimports to use the package name as a local qualifier in an import.
For example, if findImports("pkg", "X") returns ("foo/bar",
rename=true), then goimports adds the import line:
  import pkg "foo/bar"
to satisfy uses of pkg.X in the file.

This change doesn't add any implementations of rename=true, though one
is sketched in a TODO.

LGTM=crawshaw
R=crawshaw, rsc
CC=bradfitz, golang-codereviews
https://golang.org/cl/76400050
This commit is contained in:
Sameer Ajmani 2014-03-25 09:37:10 -04:00
parent a781b00b0d
commit a1c1cf19ba
2 changed files with 52 additions and 23 deletions

View File

@ -93,6 +93,7 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
searches := 0
type result struct {
ipath string
name string
err error
}
results := make(chan result)
@ -101,8 +102,12 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
continue // skip over packages already imported
}
go func(pkgName string, symbols map[string]bool) {
ipath, err := findImport(pkgName, symbols)
results <- result{ipath, err}
ipath, rename, err := findImport(pkgName, symbols)
r := result{ipath: ipath, err: err}
if rename {
r.name = pkgName
}
results <- r
}(pkgName, symbols)
searches++
}
@ -112,7 +117,11 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
return nil, result.err
}
if result.ipath != "" {
astutil.AddImport(fset, f, result.ipath)
if result.name != "" {
astutil.AddNamedImport(fset, f, result.name, result.ipath)
} else {
astutil.AddImport(fset, f, result.ipath)
}
added = append(added, result.ipath)
}
}
@ -270,14 +279,19 @@ func loadExportsGoPath(dir string) map[string]bool {
// extended by adding a file with an init function.
var findImport = findImportGoPath
func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
func findImportGoPath(pkgName string, symbols map[string]bool) (string, bool, error) {
// Fast path for the standard library.
// In the common case we hopefully never have to scan the GOPATH, which can
// be slow with moving disks.
if pkg, ok := findImportStdlib(pkgName, symbols); ok {
return pkg, nil
if pkg, rename, ok := findImportStdlib(pkgName, symbols); ok {
return pkg, rename, nil
}
// TODO(sameer): look at the import lines for other Go files in the
// local directory, since the user is likely to import the same packages
// in the current Go file. Return rename=true when the other Go files
// use a renamed package that's also used in the current file.
pkgIndexOnce.Do(loadPkgIndex)
// Collect exports for packages with matching names.
@ -311,7 +325,7 @@ func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
}
}
if len(pkgs) == 0 {
return "", nil
return "", false, nil
}
// If there are multiple candidate packages, the shortest one wins.
@ -323,7 +337,7 @@ func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
shortest = importPath
}
}
return shortest, nil
return shortest, false, nil
}
type visitFn func(node ast.Node) ast.Visitor
@ -332,17 +346,17 @@ func (fn visitFn) Visit(node ast.Node) ast.Visitor {
return fn(node)
}
func findImportStdlib(shortPkg string, symbols map[string]bool) (importPath string, ok bool) {
func findImportStdlib(shortPkg string, symbols map[string]bool) (importPath string, rename, ok bool) {
for symbol := range symbols {
path := stdlib[shortPkg+"."+symbol]
if path == "" {
return "", false
return "", false, false
}
if importPath != "" && importPath != path {
// Ambiguous. Symbols pointed to different things.
return "", false
return "", false, false
}
importPath = path
}
return importPath, importPath != ""
return importPath, false, importPath != ""
}

View File

@ -505,6 +505,20 @@ var (
b = gu.a
c = fmt.Printf
)
`,
},
{
name: "renamed package",
in: `package main
var _ = str.HasPrefix
`,
out: `package main
import str "strings"
var _ = str.HasPrefix
`,
},
}
@ -519,9 +533,10 @@ func TestFixImports(t *testing.T) {
"zip": "archive/zip",
"bytes": "bytes",
"snappy": "code.google.com/p/snappy-go/snappy",
"str": "strings",
}
findImport = func(pkgName string, symbols map[string]bool) (string, error) {
return simplePkgs[pkgName], nil
findImport = func(pkgName string, symbols map[string]bool) (string, bool, error) {
return simplePkgs[pkgName], pkgName == "str", nil
}
for _, tt := range tests {
@ -577,20 +592,20 @@ type Buffer2 struct {}
build.Default.GOPATH = oldGOPATH
}()
got, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true})
got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true})
if err != nil {
t.Fatal(err)
}
if got != bytesPkgPath {
t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, want "%s"`, got, bytesPkgPath)
if got != bytesPkgPath || rename {
t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath)
}
got, err = findImportGoPath("bytes", map[string]bool{"Missing": true})
got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true})
if err != nil {
t.Fatal(err)
}
if got != "" {
t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, want ""`, got)
if got != "" || rename {
t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, %t, want "", false`, got, rename)
}
}
@ -607,12 +622,12 @@ func TestFindImportStdlib(t *testing.T) {
{"ioutil", []string{"Discard"}, "io/ioutil"},
}
for _, tt := range tests {
got, ok := findImportStdlib(tt.pkg, strSet(tt.symbols))
got, rename, ok := findImportStdlib(tt.pkg, strSet(tt.symbols))
if (got != "") != ok {
t.Error("findImportStdlib return value inconsistent")
}
if got != tt.want {
t.Errorf("findImportStdlib(%q, %q) = %q; want %q", tt.pkg, tt.symbols, got, tt.want)
if got != tt.want || rename {
t.Errorf("findImportStdlib(%q, %q) = %q, %t; want %q, false", tt.pkg, tt.symbols, got, rename, tt.want)
}
}
}