diff --git a/imports/fix.go b/imports/fix.go index aacb1bd5d2..ebb228d1ae 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -7,6 +7,7 @@ package imports import ( "bufio" "bytes" + "context" "fmt" "go/ast" "go/build" @@ -28,11 +29,6 @@ import ( // Debug controls verbose logging. var Debug = false -var ( - inTests = false // set true by fix_test.go; if false, no need to use testMu - testMu sync.RWMutex // guards globals reset by tests; used only if inTests -) - // LocalPrefix is a comma-separated string of import path prefixes, which, if // set, instructs Process to sort the import paths with the given prefixes // into another group after 3rd-party packages. @@ -293,15 +289,27 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri } // Search for imports matching potential package references. - searches := 0 type result struct { - ipath string // import path (if err == nil) + ipath string // import path name string // optional name to rename import as - err error } - results := make(chan result) + results := make(chan result, len(refs)) + + ctx, cancel := context.WithCancel(context.TODO()) + var wg sync.WaitGroup + defer func() { + cancel() + wg.Wait() + }() + var ( + firstErr error + firstErrOnce sync.Once + ) for pkgName, symbols := range refs { + wg.Add(1) go func(pkgName string, symbols map[string]bool) { + defer wg.Done() + if packageInfo != nil { sibling := packageInfo.Imports[pkgName] if sibling.Path != "" { @@ -314,30 +322,45 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri } } } - ipath, rename, err := findImport(pkgName, symbols, filename) - r := result{ipath: ipath, err: err} + + ipath, rename, err := findImport(ctx, pkgName, symbols, filename) + if err != nil { + firstErrOnce.Do(func() { + firstErr = err + cancel() + }) + return + } + + if ipath == "" { + return // No matching package. + } + + r := result{ipath: ipath} if rename { r.name = pkgName } results <- r + return }(pkgName, symbols) - searches++ } - for i := 0; i < searches; i++ { - result := <-results - if result.err != nil { - return nil, result.err - } - if result.ipath != "" { - if result.name != "" { - astutil.AddNamedImport(fset, f, result.name, result.ipath) - } else { - astutil.AddImport(fset, f, result.ipath) - } - added = append(added, result.ipath) + go func() { + wg.Wait() + close(results) + }() + + for result := range results { + if result.name != "" { + astutil.AddNamedImport(fset, f, result.name, result.ipath) + } else { + astutil.AddImport(fset, f, result.ipath) } + added = append(added, result.ipath) } + if firstErr != nil { + return nil, firstErr + } return added, nil } @@ -446,7 +469,7 @@ var ( populateIgnoreOnce sync.Once ignoredDirs []os.FileInfo - dirScanMu sync.RWMutex + dirScanMu sync.Mutex dirScan map[string]*pkg // abs dir path => *pkg ) @@ -454,12 +477,16 @@ type pkg struct { dir string // absolute file path to pkg directory ("/usr/lib/go/src/net/http") importPath string // full pkg import path ("net/http", "foo/bar/vendor/a/b") importPathShort string // vendorless import path ("net/http", "a/b") - distance int // relative distance to target +} + +type pkgDistance struct { + pkg *pkg + distance int // relative distance to target } // byDistanceOrImportPathShortLength sorts by relative distance breaking ties // on the short import path length and then the import string itself. -type byDistanceOrImportPathShortLength []*pkg +type byDistanceOrImportPathShortLength []pkgDistance func (s byDistanceOrImportPathShortLength) Len() int { return len(s) } func (s byDistanceOrImportPathShortLength) Less(i, j int) bool { @@ -474,7 +501,7 @@ func (s byDistanceOrImportPathShortLength) Less(i, j int) bool { return di < dj } - vi, vj := s[i].importPathShort, s[j].importPathShort + vi, vj := s[i].pkg.importPathShort, s[j].pkg.importPathShort if len(vi) != len(vj) { return len(vi) < len(vj) } @@ -590,35 +617,26 @@ func shouldTraverse(dir string, fi os.FileInfo) bool { var testHookScanDir = func(dir string) {} +type goDirType string + +const ( + goRoot goDirType = "$GOROOT" + goPath goDirType = "$GOPATH" +) + var scanGoRootDone = make(chan struct{}) // closed when scanGoRoot is done -func scanGoRoot() { - go func() { - scanGoDirs(true) - close(scanGoRootDone) - }() -} - -func scanGoPath() { scanGoDirs(false) } - -func scanGoDirs(goRoot bool) { +// scanGoDirs populates the dirScan map for the given directory type. It may be +// called concurrently (and usually is, if both directory types are needed). +func scanGoDirs(which goDirType) { if Debug { - which := "$GOROOT" - if !goRoot { - which = "$GOPATH" - } - log.Printf("scanning " + which) - defer log.Printf("scanned " + which) + log.Printf("scanning %s", which) + defer log.Printf("scanned %s", which) } - dirScanMu.Lock() - if dirScan == nil { - dirScan = make(map[string]*pkg) - } - dirScanMu.Unlock() for _, srcDir := range build.Default.SrcDirs() { isGoroot := srcDir == filepath.Join(build.Default.GOROOT, "src") - if isGoroot != goRoot { + if isGoroot != (which == goRoot) { continue } testHookScanDir(srcDir) @@ -633,16 +651,21 @@ func scanGoDirs(goRoot bool) { if !strings.HasSuffix(path, ".go") { return nil } + dirScanMu.Lock() - if _, dup := dirScan[dir]; !dup { - importpath := filepath.ToSlash(dir[len(srcDir)+len("/"):]) - dirScan[dir] = &pkg{ - importPath: importpath, - importPathShort: VendorlessPath(importpath), - dir: dir, - } + defer dirScanMu.Unlock() + if _, dup := dirScan[dir]; dup { + return nil + } + if dirScan == nil { + dirScan = make(map[string]*pkg) + } + importpath := filepath.ToSlash(dir[len(srcDir)+len("/"):]) + dirScan[dir] = &pkg{ + importPath: importpath, + importPathShort: VendorlessPath(importpath), + dir: dir, } - dirScanMu.Unlock() return nil } if typ == os.ModeDir { @@ -698,20 +721,20 @@ func VendorlessPath(ipath string) string { // loadExports returns the set of exported symbols in the package at dir. // It returns nil on error or if the package name in dir does not match expectPackage. -var loadExports func(expectPackage, dir string) map[string]bool = loadExportsGoPath +var loadExports func(ctx context.Context, expectPackage, dir string) (map[string]bool, error) = loadExportsGoPath -func loadExportsGoPath(expectPackage, dir string) map[string]bool { +func loadExportsGoPath(ctx context.Context, expectPackage, dir string) (map[string]bool, error) { if Debug { log.Printf("loading exports in dir %s (seeking package %s)", dir, expectPackage) } exports := make(map[string]bool) - ctx := build.Default + buildCtx := build.Default // ReadDir is like ioutil.ReadDir, but only returns *.go files // and filters out _test.go files since they're not relevant // and only slow things down. - ctx.ReadDir = func(dir string) (notTests []os.FileInfo, err error) { + buildCtx.ReadDir = func(dir string) (notTests []os.FileInfo, err error) { all, err := ioutil.ReadDir(dir) if err != nil { return nil, err @@ -726,16 +749,22 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool { return notTests, nil } - files, err := ctx.ReadDir(dir) + files, err := buildCtx.ReadDir(dir) if err != nil { log.Print(err) - return nil + return nil, err } fset := token.NewFileSet() for _, fi := range files { - match, err := ctx.MatchFile(dir, fi.Name()) + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + match, err := buildCtx.MatchFile(dir, fi.Name()) if err != nil || !match { continue } @@ -745,19 +774,20 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool { if Debug { log.Printf("Parsing %s: %v", fullFile, err) } - return nil + return nil, err } pkgName := f.Name.Name if pkgName == "documentation" { // Special case from go/build.ImportDir, not - // handled by ctx.MatchFile. + // handled by buildCtx.MatchFile. continue } if pkgName != expectPackage { + err := fmt.Errorf("scan of dir %v is not expected package %v (actually %v)", dir, expectPackage, pkgName) if Debug { - log.Printf("scan of dir %v is not expected package %v (actually %v)", dir, expectPackage, pkgName) + log.Print(err) } - return nil + return nil, err } for name := range f.Scope.Objects { if ast.IsExported(name) { @@ -774,7 +804,7 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool { sort.Strings(exportList) log.Printf("loaded exports in dir %v (package %v): %v", dir, expectPackage, strings.Join(exportList, ", ")) } - return exports + return exports, nil } // findImport searches for a package with the given symbols. @@ -789,16 +819,11 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool { // import line: // import pkg "foo/bar" // to satisfy uses of pkg.X in the file. -var findImport func(pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) = findImportGoPath +var findImport func(ctx context.Context, pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) = findImportGoPath // findImportGoPath is the normal implementation of findImport. // (Some companies have their own internally.) -func findImportGoPath(pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) { - if inTests { - testMu.RLock() - defer testMu.RUnlock() - } - +func findImportGoPath(ctx context.Context, pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) { pkgDir, err := filepath.Abs(filename) if err != nil { return "", false, err @@ -836,18 +861,25 @@ func findImportGoPath(pkgName string, symbols map[string]bool, filename string) // // TODO(bradfitz): run each $GOPATH entry async. But nobody // really has more than one anyway, so low priority. - scanGoRootOnce.Do(scanGoRoot) // async + scanGoRootOnce.Do(func() { + go func() { + scanGoDirs(goRoot) + close(scanGoRootDone) + }() + }) if !fileInDir(filename, build.Default.GOROOT) { - scanGoPathOnce.Do(scanGoPath) // blocking + scanGoPathOnce.Do(func() { scanGoDirs(goPath) }) } <-scanGoRootDone // Find candidate packages, looking only at their directory names first. - var candidates []*pkg + var candidates []pkgDistance for _, pkg := range dirScan { if pkgIsCandidate(filename, pkgName, pkg) { - pkg.distance = distance(pkgDir, pkg.dir) - candidates = append(candidates, pkg) + candidates = append(candidates, pkgDistance{ + pkg: pkg, + distance: distance(pkgDir, pkg.dir), + }) } } @@ -857,60 +889,63 @@ func findImportGoPath(pkgName string, symbols map[string]bool, filename string) // there's no "penalty" for vendoring. sort.Sort(byDistanceOrImportPathShortLength(candidates)) if Debug { - for i, pkg := range candidates { - log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), pkg.importPathShort, pkg.dir) + for i, c := range candidates { + log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir) } } // Collect exports for packages with matching names. - done := make(chan struct{}) // closed when we find the answer - defer close(done) - rescv := make([]chan *pkg, len(candidates)) for i := range candidates { - rescv[i] = make(chan *pkg) + rescv[i] = make(chan *pkg, 1) } const maxConcurrentPackageImport = 4 loadExportsSem := make(chan struct{}, maxConcurrentPackageImport) + ctx, cancel := context.WithCancel(ctx) + var wg sync.WaitGroup + defer func() { + cancel() + wg.Wait() + }() + + wg.Add(1) go func() { - for i, pkg := range candidates { + defer wg.Done() + for i, c := range candidates { select { case loadExportsSem <- struct{}{}: - select { - case <-done: - return - default: - } - case <-done: + case <-ctx.Done(): return } - pkg := pkg - resc := rescv[i] - go func() { - if inTests { - testMu.RLock() - defer testMu.RUnlock() + + wg.Add(1) + go func(c pkgDistance, resc chan<- *pkg) { + defer func() { + <-loadExportsSem + wg.Done() + }() + + exports, err := loadExports(ctx, pkgName, c.pkg.dir) + if err != nil { + resc <- nil + return } - defer func() { <-loadExportsSem }() - exports := loadExports(pkgName, pkg.dir) // If it doesn't have the right // symbols, send nil to mean no match. for symbol := range symbols { if !exports[symbol] { - pkg = nil - break + resc <- nil + return } } - select { - case resc <- pkg: - case <-done: - } - }() + resc <- c.pkg + }(c, rescv[i]) } }() + for _, resc := range rescv { pkg := <-resc if pkg == nil { diff --git a/imports/fix_test.go b/imports/fix_test.go index 539f2cbbce..62c6f3c32c 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -6,6 +6,7 @@ package imports import ( "bytes" + "context" "flag" "go/build" "io/ioutil" @@ -909,7 +910,7 @@ func TestFixImports(t *testing.T) { defer func() { findImport = old }() - findImport = func(pkgName string, symbols map[string]bool, filename string) (string, bool, error) { + findImport = func(_ context.Context, pkgName string, symbols map[string]bool, filename string) (string, bool, error) { return simplePkgs[pkgName], pkgName == "str", nil } @@ -1185,7 +1186,7 @@ type Buffer2 struct {} } build.Default.GOROOT = goroot - got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}, "x.go") + got, rename, err := findImportGoPath(context.Background(), "bytes", map[string]bool{"Buffer2": true}, "x.go") if err != nil { t.Fatal(err) } @@ -1193,7 +1194,7 @@ type Buffer2 struct {} t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath) } - got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}, "x.go") + got, rename, err = findImportGoPath(context.Background(), "bytes", map[string]bool{"Missing": true}, "x.go") if err != nil { t.Fatal(err) } @@ -1203,34 +1204,23 @@ type Buffer2 struct {} }) } -func init() { - inTests = true -} - func withEmptyGoPath(fn func()) { - testMu.Lock() - - dirScanMu.Lock() populateIgnoreOnce = sync.Once{} scanGoRootOnce = sync.Once{} scanGoPathOnce = sync.Once{} dirScan = nil ignoredDirs = nil scanGoRootDone = make(chan struct{}) - dirScanMu.Unlock() oldGOPATH := build.Default.GOPATH oldGOROOT := build.Default.GOROOT build.Default.GOPATH = "" testHookScanDir = func(string) {} - testMu.Unlock() defer func() { - testMu.Lock() testHookScanDir = func(string) {} build.Default.GOPATH = oldGOPATH build.Default.GOROOT = oldGOROOT - testMu.Unlock() }() fn() @@ -1246,7 +1236,7 @@ func TestFindImportInternal(t *testing.T) { t.Skip(err) } - got, rename, err := findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go")) + got, rename, err := findImportGoPath(context.Background(), "race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go")) if err != nil { t.Fatal(err) } @@ -1255,7 +1245,7 @@ func TestFindImportInternal(t *testing.T) { } // should not be able to use internal from outside that tree - got, rename, err = findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "x.go")) + got, rename, err = findImportGoPath(context.Background(), "race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "x.go")) if err != nil { t.Fatal(err) } @@ -1295,7 +1285,7 @@ func TestFindImportRandRead(t *testing.T) { for _, sym := range tt.syms { m[sym] = true } - got, _, err := findImportGoPath("rand", m, file) + got, _, err := findImportGoPath(context.Background(), "rand", m, file) if err != nil { t.Errorf("for %q: %v", tt.syms, err) continue @@ -1313,7 +1303,7 @@ func TestFindImportVendor(t *testing.T) { "vendor/golang.org/x/net/http2/hpack/huffman.go": "package hpack\nfunc HuffmanDecode() { }\n", }, }.test(t, func(t *goimportTest) { - got, rename, err := findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(t.goroot, "src/math/x.go")) + got, rename, err := findImportGoPath(context.Background(), "hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(t.goroot, "src/math/x.go")) if err != nil { t.Fatal(err) } @@ -1650,8 +1640,8 @@ func TestImportPathToNameGoPathParse(t *testing.T) { func TestIgnoreConfiguration(t *testing.T) { testConfig{ gopathFiles: map[string]string{ - ".goimportsignore": "# comment line\n\n example.net", // tests comment, blank line, whitespace trimming - "example.net/pkg/pkg.go": "package pkg\nconst X = 1", + ".goimportsignore": "# comment line\n\n example.net", // tests comment, blank line, whitespace trimming + "example.net/pkg/pkg.go": "package pkg\nconst X = 1", "otherwise-longer-so-worse.example.net/foo/pkg/pkg.go": "package pkg\nconst X = 1", }, }.test(t, func(t *goimportTest) {