mirror of
https://github.com/golang/go
synced 2024-11-18 16:14:46 -07:00
imports: extend findImports to return a boolean, rename, that tells
goimports to use the package name as a local qualifier in an import. For example, if findImports("pkg", "X") returns ("foo/bar", rename=true), then goimports adds the import line: import pkg "foo/bar" to satisfy uses of pkg.X in the file. This change doesn't add any implementations of rename=true, though one is sketched in a TODO. LGTM=crawshaw R=crawshaw, rsc CC=bradfitz, golang-codereviews https://golang.org/cl/76400050
This commit is contained in:
parent
a781b00b0d
commit
a1c1cf19ba
@ -93,6 +93,7 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
|
||||
searches := 0
|
||||
type result struct {
|
||||
ipath string
|
||||
name string
|
||||
err error
|
||||
}
|
||||
results := make(chan result)
|
||||
@ -101,8 +102,12 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
|
||||
continue // skip over packages already imported
|
||||
}
|
||||
go func(pkgName string, symbols map[string]bool) {
|
||||
ipath, err := findImport(pkgName, symbols)
|
||||
results <- result{ipath, err}
|
||||
ipath, rename, err := findImport(pkgName, symbols)
|
||||
r := result{ipath: ipath, err: err}
|
||||
if rename {
|
||||
r.name = pkgName
|
||||
}
|
||||
results <- r
|
||||
}(pkgName, symbols)
|
||||
searches++
|
||||
}
|
||||
@ -112,7 +117,11 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) {
|
||||
return nil, result.err
|
||||
}
|
||||
if result.ipath != "" {
|
||||
astutil.AddImport(fset, f, result.ipath)
|
||||
if result.name != "" {
|
||||
astutil.AddNamedImport(fset, f, result.name, result.ipath)
|
||||
} else {
|
||||
astutil.AddImport(fset, f, result.ipath)
|
||||
}
|
||||
added = append(added, result.ipath)
|
||||
}
|
||||
}
|
||||
@ -270,14 +279,19 @@ func loadExportsGoPath(dir string) map[string]bool {
|
||||
// extended by adding a file with an init function.
|
||||
var findImport = findImportGoPath
|
||||
|
||||
func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
|
||||
func findImportGoPath(pkgName string, symbols map[string]bool) (string, bool, error) {
|
||||
// Fast path for the standard library.
|
||||
// In the common case we hopefully never have to scan the GOPATH, which can
|
||||
// be slow with moving disks.
|
||||
if pkg, ok := findImportStdlib(pkgName, symbols); ok {
|
||||
return pkg, nil
|
||||
if pkg, rename, ok := findImportStdlib(pkgName, symbols); ok {
|
||||
return pkg, rename, nil
|
||||
}
|
||||
|
||||
// TODO(sameer): look at the import lines for other Go files in the
|
||||
// local directory, since the user is likely to import the same packages
|
||||
// in the current Go file. Return rename=true when the other Go files
|
||||
// use a renamed package that's also used in the current file.
|
||||
|
||||
pkgIndexOnce.Do(loadPkgIndex)
|
||||
|
||||
// Collect exports for packages with matching names.
|
||||
@ -311,7 +325,7 @@ func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
|
||||
}
|
||||
}
|
||||
if len(pkgs) == 0 {
|
||||
return "", nil
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
// If there are multiple candidate packages, the shortest one wins.
|
||||
@ -323,7 +337,7 @@ func findImportGoPath(pkgName string, symbols map[string]bool) (string, error) {
|
||||
shortest = importPath
|
||||
}
|
||||
}
|
||||
return shortest, nil
|
||||
return shortest, false, nil
|
||||
}
|
||||
|
||||
type visitFn func(node ast.Node) ast.Visitor
|
||||
@ -332,17 +346,17 @@ func (fn visitFn) Visit(node ast.Node) ast.Visitor {
|
||||
return fn(node)
|
||||
}
|
||||
|
||||
func findImportStdlib(shortPkg string, symbols map[string]bool) (importPath string, ok bool) {
|
||||
func findImportStdlib(shortPkg string, symbols map[string]bool) (importPath string, rename, ok bool) {
|
||||
for symbol := range symbols {
|
||||
path := stdlib[shortPkg+"."+symbol]
|
||||
if path == "" {
|
||||
return "", false
|
||||
return "", false, false
|
||||
}
|
||||
if importPath != "" && importPath != path {
|
||||
// Ambiguous. Symbols pointed to different things.
|
||||
return "", false
|
||||
return "", false, false
|
||||
}
|
||||
importPath = path
|
||||
}
|
||||
return importPath, importPath != ""
|
||||
return importPath, false, importPath != ""
|
||||
}
|
||||
|
@ -505,6 +505,20 @@ var (
|
||||
b = gu.a
|
||||
c = fmt.Printf
|
||||
)
|
||||
`,
|
||||
},
|
||||
|
||||
{
|
||||
name: "renamed package",
|
||||
in: `package main
|
||||
|
||||
var _ = str.HasPrefix
|
||||
`,
|
||||
out: `package main
|
||||
|
||||
import str "strings"
|
||||
|
||||
var _ = str.HasPrefix
|
||||
`,
|
||||
},
|
||||
}
|
||||
@ -519,9 +533,10 @@ func TestFixImports(t *testing.T) {
|
||||
"zip": "archive/zip",
|
||||
"bytes": "bytes",
|
||||
"snappy": "code.google.com/p/snappy-go/snappy",
|
||||
"str": "strings",
|
||||
}
|
||||
findImport = func(pkgName string, symbols map[string]bool) (string, error) {
|
||||
return simplePkgs[pkgName], nil
|
||||
findImport = func(pkgName string, symbols map[string]bool) (string, bool, error) {
|
||||
return simplePkgs[pkgName], pkgName == "str", nil
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
@ -577,20 +592,20 @@ type Buffer2 struct {}
|
||||
build.Default.GOPATH = oldGOPATH
|
||||
}()
|
||||
|
||||
got, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true})
|
||||
got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != bytesPkgPath {
|
||||
t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, want "%s"`, got, bytesPkgPath)
|
||||
if got != bytesPkgPath || rename {
|
||||
t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath)
|
||||
}
|
||||
|
||||
got, err = findImportGoPath("bytes", map[string]bool{"Missing": true})
|
||||
got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got != "" {
|
||||
t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, want ""`, got)
|
||||
if got != "" || rename {
|
||||
t.Errorf(`findImportGoPath("bytes", Missing ...)=%q, %t, want "", false`, got, rename)
|
||||
}
|
||||
}
|
||||
|
||||
@ -607,12 +622,12 @@ func TestFindImportStdlib(t *testing.T) {
|
||||
{"ioutil", []string{"Discard"}, "io/ioutil"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got, ok := findImportStdlib(tt.pkg, strSet(tt.symbols))
|
||||
got, rename, ok := findImportStdlib(tt.pkg, strSet(tt.symbols))
|
||||
if (got != "") != ok {
|
||||
t.Error("findImportStdlib return value inconsistent")
|
||||
}
|
||||
if got != tt.want {
|
||||
t.Errorf("findImportStdlib(%q, %q) = %q; want %q", tt.pkg, tt.symbols, got, tt.want)
|
||||
if got != tt.want || rename {
|
||||
t.Errorf("findImportStdlib(%q, %q) = %q, %t; want %q, false", tt.pkg, tt.symbols, got, rename, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user