From 8cab8a1319f0be9798e7fe78b15da75e5f94b2e9 Mon Sep 17 00:00:00 2001 From: Michael Fraenkel Date: Fri, 8 Dec 2017 09:26:14 -0500 Subject: [PATCH] imports: sibling imports must have matching references When selecting a sibling's import, the unresolved reference must have been also used otherwise use the normal search to determine the best possible package to import. Fixes golang/go#23001 Change-Id: I38a983569991464970ad5921fe7f280dd3e35a2c Reviewed-on: https://go-review.googlesource.com/82875 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick --- imports/fix.go | 56 ++++++++++++++++++++++++++++++++++++++++++--- imports/fix_test.go | 11 ++++++++- 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/imports/fix.go b/imports/fix.go index 8a0d9834b7a..8dbbd27a156 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -78,6 +78,10 @@ type importInfo struct { type packageInfo struct { Globals map[string]bool // symbol => true Imports map[string]importInfo // pkg base name or alias => info + // refs are a set of package references currently satisfied by imports. + // first key: either base package (e.g. "fmt") or renamed package + // second key: referenced package symbol (e.g. "Println") + Refs map[string]map[string]bool } // dirPackageInfo exposes the dirPackageInfoFile function so that it can be overridden. @@ -93,7 +97,13 @@ func dirPackageInfoFile(pkgName, srcDir, filename string) (*packageInfo, error) return nil, err } - info := &packageInfo{Globals: make(map[string]bool), Imports: make(map[string]importInfo)} + info := &packageInfo{ + Globals: make(map[string]bool), + Imports: make(map[string]importInfo), + Refs: make(map[string]map[string]bool), + } + + visitor := collectReferences(info.Refs) for _, fi := range packageFileInfos { if fi.Name() == fileBase || !strings.HasSuffix(fi.Name(), ".go") { continue @@ -132,10 +142,45 @@ func dirPackageInfoFile(pkgName, srcDir, filename string) (*packageInfo, error) } info.Imports[name] = impInfo } + + ast.Walk(visitor, root) } return info, nil } +// collectReferences returns a visitor that collects all exported package +// references +func collectReferences(refs map[string]map[string]bool) visitFn { + var visitor visitFn + visitor = func(node ast.Node) ast.Visitor { + if node == nil { + return visitor + } + switch v := node.(type) { + case *ast.SelectorExpr: + xident, ok := v.X.(*ast.Ident) + if !ok { + break + } + if xident.Obj != nil { + // if the parser can resolve it, it's not a package ref + break + } + pkgName := xident.Name + r := refs[pkgName] + if r == nil { + r = make(map[string]bool) + refs[pkgName] = r + } + if ast.IsExported(v.Sel.Name) { + r[v.Sel.Name] = true + } + } + return visitor + } + return visitor +} + func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []string, err error) { // refs are a set of possible package references currently unsatisfied by imports. // first key: either base package (e.g. "fmt") or renamed package @@ -249,8 +294,13 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri if packageInfo != nil { sibling := packageInfo.Imports[pkgName] if sibling.Path != "" { - results <- result{ipath: sibling.Path, name: sibling.Alias} - return + refs := packageInfo.Refs[pkgName] + for symbol := range symbols { + if refs[symbol] { + results <- result{ipath: sibling.Path, name: sibling.Alias} + return + } + } } } ipath, rename, err := findImport(pkgName, symbols, filename) diff --git a/imports/fix_test.go b/imports/fix_test.go index d39f28ff55b..309855281f5 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -1609,15 +1609,19 @@ func TestSiblingImports(t *testing.T) { const provide = `package siblingimporttest import "local/log" +import "my/bytes" func LogSomething() { log.Print("Something") + bytes.SomeFunc() } ` // need is the file being tested that needs the import. const need = `package siblingimporttest +var _ = bytes.Buffer{} + func LogSomethingElse() { log.Print("Something else") } @@ -1626,7 +1630,12 @@ func LogSomethingElse() { // want is the expected result file const want = `package siblingimporttest -import "local/log" +import ( + "bytes" + "local/log" +) + +var _ = bytes.Buffer{} func LogSomethingElse() { log.Print("Something else")