From 165bdd618e6d174c61ee7bd73a28f3e5c6fdced1 Mon Sep 17 00:00:00 2001 From: "Bryan C. Mills" Date: Tue, 24 Apr 2018 16:05:29 -0400 Subject: [PATCH] imports: fix races in findImportGoPath Before this change, findImportGoPath used a field within the (otherwise read-only) structs in the dirScan map to cache the distance from the importing package to the candidate package to be imported. As a result, the top-level imports.Process function was not safe to call concurrently: one goroutine could overwrite the distances while another was attempting to sort by them. Furthermore, there were some internal write-after-write races (writing the same cached distance to the same address) that otherwise violate the Go memory model. This change fixes those races, simplifies the concurrency patterns, and clarifies goroutine lifetimes. The functions in the imports package now wait for the goroutines they spawn to finish before returning, eliminating the need for an awkward test-only mutex that could otherwise mask real races in the production code paths. See also: https://golang.org/wiki/CodeReviewComments#goroutine-lifetimes https://golang.org/wiki/CodeReviewComments#synchronous-functions Fixes golang/go#25030. Change-Id: I8fec735e0d4ff7abab406dea9d0c11d1bd93d775 Reviewed-on: https://go-review.googlesource.com/109156 Run-TryBot: Bryan C. Mills TryBot-Result: Gobot Gobot Reviewed-by: Joe Tsai --- imports/fix.go | 255 +++++++++++++++++++++++++------------------- imports/fix_test.go | 30 ++---- 2 files changed, 155 insertions(+), 130 deletions(-) diff --git a/imports/fix.go b/imports/fix.go index aacb1bd5d24..ebb228d1ae9 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -7,6 +7,7 @@ package imports import ( "bufio" "bytes" + "context" "fmt" "go/ast" "go/build" @@ -28,11 +29,6 @@ import ( // Debug controls verbose logging. var Debug = false -var ( - inTests = false // set true by fix_test.go; if false, no need to use testMu - testMu sync.RWMutex // guards globals reset by tests; used only if inTests -) - // LocalPrefix is a comma-separated string of import path prefixes, which, if // set, instructs Process to sort the import paths with the given prefixes // into another group after 3rd-party packages. @@ -293,15 +289,27 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri } // Search for imports matching potential package references. - searches := 0 type result struct { - ipath string // import path (if err == nil) + ipath string // import path name string // optional name to rename import as - err error } - results := make(chan result) + results := make(chan result, len(refs)) + + ctx, cancel := context.WithCancel(context.TODO()) + var wg sync.WaitGroup + defer func() { + cancel() + wg.Wait() + }() + var ( + firstErr error + firstErrOnce sync.Once + ) for pkgName, symbols := range refs { + wg.Add(1) go func(pkgName string, symbols map[string]bool) { + defer wg.Done() + if packageInfo != nil { sibling := packageInfo.Imports[pkgName] if sibling.Path != "" { @@ -314,30 +322,45 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri } } } - ipath, rename, err := findImport(pkgName, symbols, filename) - r := result{ipath: ipath, err: err} + + ipath, rename, err := findImport(ctx, pkgName, symbols, filename) + if err != nil { + firstErrOnce.Do(func() { + firstErr = err + cancel() + }) + return + } + + if ipath == "" { + return // No matching package. + } + + r := result{ipath: ipath} if rename { r.name = pkgName } results <- r + return }(pkgName, symbols) - searches++ } - for i := 0; i < searches; i++ { - result := <-results - if result.err != nil { - return nil, result.err - } - if result.ipath != "" { - if result.name != "" { - astutil.AddNamedImport(fset, f, result.name, result.ipath) - } else { - astutil.AddImport(fset, f, result.ipath) - } - added = append(added, result.ipath) + go func() { + wg.Wait() + close(results) + }() + + for result := range results { + if result.name != "" { + astutil.AddNamedImport(fset, f, result.name, result.ipath) + } else { + astutil.AddImport(fset, f, result.ipath) } + added = append(added, result.ipath) } + if firstErr != nil { + return nil, firstErr + } return added, nil } @@ -446,7 +469,7 @@ var ( populateIgnoreOnce sync.Once ignoredDirs []os.FileInfo - dirScanMu sync.RWMutex + dirScanMu sync.Mutex dirScan map[string]*pkg // abs dir path => *pkg ) @@ -454,12 +477,16 @@ type pkg struct { dir string // absolute file path to pkg directory ("/usr/lib/go/src/net/http") importPath string // full pkg import path ("net/http", "foo/bar/vendor/a/b") importPathShort string // vendorless import path ("net/http", "a/b") - distance int // relative distance to target +} + +type pkgDistance struct { + pkg *pkg + distance int // relative distance to target } // byDistanceOrImportPathShortLength sorts by relative distance breaking ties // on the short import path length and then the import string itself. -type byDistanceOrImportPathShortLength []*pkg +type byDistanceOrImportPathShortLength []pkgDistance func (s byDistanceOrImportPathShortLength) Len() int { return len(s) } func (s byDistanceOrImportPathShortLength) Less(i, j int) bool { @@ -474,7 +501,7 @@ func (s byDistanceOrImportPathShortLength) Less(i, j int) bool { return di < dj } - vi, vj := s[i].importPathShort, s[j].importPathShort + vi, vj := s[i].pkg.importPathShort, s[j].pkg.importPathShort if len(vi) != len(vj) { return len(vi) < len(vj) } @@ -590,35 +617,26 @@ func shouldTraverse(dir string, fi os.FileInfo) bool { var testHookScanDir = func(dir string) {} +type goDirType string + +const ( + goRoot goDirType = "$GOROOT" + goPath goDirType = "$GOPATH" +) + var scanGoRootDone = make(chan struct{}) // closed when scanGoRoot is done -func scanGoRoot() { - go func() { - scanGoDirs(true) - close(scanGoRootDone) - }() -} - -func scanGoPath() { scanGoDirs(false) } - -func scanGoDirs(goRoot bool) { +// scanGoDirs populates the dirScan map for the given directory type. It may be +// called concurrently (and usually is, if both directory types are needed). +func scanGoDirs(which goDirType) { if Debug { - which := "$GOROOT" - if !goRoot { - which = "$GOPATH" - } - log.Printf("scanning " + which) - defer log.Printf("scanned " + which) + log.Printf("scanning %s", which) + defer log.Printf("scanned %s", which) } - dirScanMu.Lock() - if dirScan == nil { - dirScan = make(map[string]*pkg) - } - dirScanMu.Unlock() for _, srcDir := range build.Default.SrcDirs() { isGoroot := srcDir == filepath.Join(build.Default.GOROOT, "src") - if isGoroot != goRoot { + if isGoroot != (which == goRoot) { continue } testHookScanDir(srcDir) @@ -633,16 +651,21 @@ func scanGoDirs(goRoot bool) { if !strings.HasSuffix(path, ".go") { return nil } + dirScanMu.Lock() - if _, dup := dirScan[dir]; !dup { - importpath := filepath.ToSlash(dir[len(srcDir)+len("/"):]) - dirScan[dir] = &pkg{ - importPath: importpath, - importPathShort: VendorlessPath(importpath), - dir: dir, - } + defer dirScanMu.Unlock() + if _, dup := dirScan[dir]; dup { + return nil + } + if dirScan == nil { + dirScan = make(map[string]*pkg) + } + importpath := filepath.ToSlash(dir[len(srcDir)+len("/"):]) + dirScan[dir] = &pkg{ + importPath: importpath, + importPathShort: VendorlessPath(importpath), + dir: dir, } - dirScanMu.Unlock() return nil } if typ == os.ModeDir { @@ -698,20 +721,20 @@ func VendorlessPath(ipath string) string { // loadExports returns the set of exported symbols in the package at dir. // It returns nil on error or if the package name in dir does not match expectPackage. -var loadExports func(expectPackage, dir string) map[string]bool = loadExportsGoPath +var loadExports func(ctx context.Context, expectPackage, dir string) (map[string]bool, error) = loadExportsGoPath -func loadExportsGoPath(expectPackage, dir string) map[string]bool { +func loadExportsGoPath(ctx context.Context, expectPackage, dir string) (map[string]bool, error) { if Debug { log.Printf("loading exports in dir %s (seeking package %s)", dir, expectPackage) } exports := make(map[string]bool) - ctx := build.Default + buildCtx := build.Default // ReadDir is like ioutil.ReadDir, but only returns *.go files // and filters out _test.go files since they're not relevant // and only slow things down. - ctx.ReadDir = func(dir string) (notTests []os.FileInfo, err error) { + buildCtx.ReadDir = func(dir string) (notTests []os.FileInfo, err error) { all, err := ioutil.ReadDir(dir) if err != nil { return nil, err @@ -726,16 +749,22 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool { return notTests, nil } - files, err := ctx.ReadDir(dir) + files, err := buildCtx.ReadDir(dir) if err != nil { log.Print(err) - return nil + return nil, err } fset := token.NewFileSet() for _, fi := range files { - match, err := ctx.MatchFile(dir, fi.Name()) + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + + match, err := buildCtx.MatchFile(dir, fi.Name()) if err != nil || !match { continue } @@ -745,19 +774,20 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool { if Debug { log.Printf("Parsing %s: %v", fullFile, err) } - return nil + return nil, err } pkgName := f.Name.Name if pkgName == "documentation" { // Special case from go/build.ImportDir, not - // handled by ctx.MatchFile. + // handled by buildCtx.MatchFile. continue } if pkgName != expectPackage { + err := fmt.Errorf("scan of dir %v is not expected package %v (actually %v)", dir, expectPackage, pkgName) if Debug { - log.Printf("scan of dir %v is not expected package %v (actually %v)", dir, expectPackage, pkgName) + log.Print(err) } - return nil + return nil, err } for name := range f.Scope.Objects { if ast.IsExported(name) { @@ -774,7 +804,7 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool { sort.Strings(exportList) log.Printf("loaded exports in dir %v (package %v): %v", dir, expectPackage, strings.Join(exportList, ", ")) } - return exports + return exports, nil } // findImport searches for a package with the given symbols. @@ -789,16 +819,11 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool { // import line: // import pkg "foo/bar" // to satisfy uses of pkg.X in the file. -var findImport func(pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) = findImportGoPath +var findImport func(ctx context.Context, pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) = findImportGoPath // findImportGoPath is the normal implementation of findImport. // (Some companies have their own internally.) -func findImportGoPath(pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) { - if inTests { - testMu.RLock() - defer testMu.RUnlock() - } - +func findImportGoPath(ctx context.Context, pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) { pkgDir, err := filepath.Abs(filename) if err != nil { return "", false, err @@ -836,18 +861,25 @@ func findImportGoPath(pkgName string, symbols map[string]bool, filename string) // // TODO(bradfitz): run each $GOPATH entry async. But nobody // really has more than one anyway, so low priority. - scanGoRootOnce.Do(scanGoRoot) // async + scanGoRootOnce.Do(func() { + go func() { + scanGoDirs(goRoot) + close(scanGoRootDone) + }() + }) if !fileInDir(filename, build.Default.GOROOT) { - scanGoPathOnce.Do(scanGoPath) // blocking + scanGoPathOnce.Do(func() { scanGoDirs(goPath) }) } <-scanGoRootDone // Find candidate packages, looking only at their directory names first. - var candidates []*pkg + var candidates []pkgDistance for _, pkg := range dirScan { if pkgIsCandidate(filename, pkgName, pkg) { - pkg.distance = distance(pkgDir, pkg.dir) - candidates = append(candidates, pkg) + candidates = append(candidates, pkgDistance{ + pkg: pkg, + distance: distance(pkgDir, pkg.dir), + }) } } @@ -857,60 +889,63 @@ func findImportGoPath(pkgName string, symbols map[string]bool, filename string) // there's no "penalty" for vendoring. sort.Sort(byDistanceOrImportPathShortLength(candidates)) if Debug { - for i, pkg := range candidates { - log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), pkg.importPathShort, pkg.dir) + for i, c := range candidates { + log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir) } } // Collect exports for packages with matching names. - done := make(chan struct{}) // closed when we find the answer - defer close(done) - rescv := make([]chan *pkg, len(candidates)) for i := range candidates { - rescv[i] = make(chan *pkg) + rescv[i] = make(chan *pkg, 1) } const maxConcurrentPackageImport = 4 loadExportsSem := make(chan struct{}, maxConcurrentPackageImport) + ctx, cancel := context.WithCancel(ctx) + var wg sync.WaitGroup + defer func() { + cancel() + wg.Wait() + }() + + wg.Add(1) go func() { - for i, pkg := range candidates { + defer wg.Done() + for i, c := range candidates { select { case loadExportsSem <- struct{}{}: - select { - case <-done: - return - default: - } - case <-done: + case <-ctx.Done(): return } - pkg := pkg - resc := rescv[i] - go func() { - if inTests { - testMu.RLock() - defer testMu.RUnlock() + + wg.Add(1) + go func(c pkgDistance, resc chan<- *pkg) { + defer func() { + <-loadExportsSem + wg.Done() + }() + + exports, err := loadExports(ctx, pkgName, c.pkg.dir) + if err != nil { + resc <- nil + return } - defer func() { <-loadExportsSem }() - exports := loadExports(pkgName, pkg.dir) // If it doesn't have the right // symbols, send nil to mean no match. for symbol := range symbols { if !exports[symbol] { - pkg = nil - break + resc <- nil + return } } - select { - case resc <- pkg: - case <-done: - } - }() + resc <- c.pkg + }(c, rescv[i]) } }() + for _, resc := range rescv { pkg := <-resc if pkg == nil { diff --git a/imports/fix_test.go b/imports/fix_test.go index 539f2cbbce8..62c6f3c32c7 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -6,6 +6,7 @@ package imports import ( "bytes" + "context" "flag" "go/build" "io/ioutil" @@ -909,7 +910,7 @@ func TestFixImports(t *testing.T) { defer func() { findImport = old }() - findImport = func(pkgName string, symbols map[string]bool, filename string) (string, bool, error) { + findImport = func(_ context.Context, pkgName string, symbols map[string]bool, filename string) (string, bool, error) { return simplePkgs[pkgName], pkgName == "str", nil } @@ -1185,7 +1186,7 @@ type Buffer2 struct {} } build.Default.GOROOT = goroot - got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}, "x.go") + got, rename, err := findImportGoPath(context.Background(), "bytes", map[string]bool{"Buffer2": true}, "x.go") if err != nil { t.Fatal(err) } @@ -1193,7 +1194,7 @@ type Buffer2 struct {} t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath) } - got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}, "x.go") + got, rename, err = findImportGoPath(context.Background(), "bytes", map[string]bool{"Missing": true}, "x.go") if err != nil { t.Fatal(err) } @@ -1203,34 +1204,23 @@ type Buffer2 struct {} }) } -func init() { - inTests = true -} - func withEmptyGoPath(fn func()) { - testMu.Lock() - - dirScanMu.Lock() populateIgnoreOnce = sync.Once{} scanGoRootOnce = sync.Once{} scanGoPathOnce = sync.Once{} dirScan = nil ignoredDirs = nil scanGoRootDone = make(chan struct{}) - dirScanMu.Unlock() oldGOPATH := build.Default.GOPATH oldGOROOT := build.Default.GOROOT build.Default.GOPATH = "" testHookScanDir = func(string) {} - testMu.Unlock() defer func() { - testMu.Lock() testHookScanDir = func(string) {} build.Default.GOPATH = oldGOPATH build.Default.GOROOT = oldGOROOT - testMu.Unlock() }() fn() @@ -1246,7 +1236,7 @@ func TestFindImportInternal(t *testing.T) { t.Skip(err) } - got, rename, err := findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go")) + got, rename, err := findImportGoPath(context.Background(), "race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go")) if err != nil { t.Fatal(err) } @@ -1255,7 +1245,7 @@ func TestFindImportInternal(t *testing.T) { } // should not be able to use internal from outside that tree - got, rename, err = findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "x.go")) + got, rename, err = findImportGoPath(context.Background(), "race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "x.go")) if err != nil { t.Fatal(err) } @@ -1295,7 +1285,7 @@ func TestFindImportRandRead(t *testing.T) { for _, sym := range tt.syms { m[sym] = true } - got, _, err := findImportGoPath("rand", m, file) + got, _, err := findImportGoPath(context.Background(), "rand", m, file) if err != nil { t.Errorf("for %q: %v", tt.syms, err) continue @@ -1313,7 +1303,7 @@ func TestFindImportVendor(t *testing.T) { "vendor/golang.org/x/net/http2/hpack/huffman.go": "package hpack\nfunc HuffmanDecode() { }\n", }, }.test(t, func(t *goimportTest) { - got, rename, err := findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(t.goroot, "src/math/x.go")) + got, rename, err := findImportGoPath(context.Background(), "hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(t.goroot, "src/math/x.go")) if err != nil { t.Fatal(err) } @@ -1650,8 +1640,8 @@ func TestImportPathToNameGoPathParse(t *testing.T) { func TestIgnoreConfiguration(t *testing.T) { testConfig{ gopathFiles: map[string]string{ - ".goimportsignore": "# comment line\n\n example.net", // tests comment, blank line, whitespace trimming - "example.net/pkg/pkg.go": "package pkg\nconst X = 1", + ".goimportsignore": "# comment line\n\n example.net", // tests comment, blank line, whitespace trimming + "example.net/pkg/pkg.go": "package pkg\nconst X = 1", "otherwise-longer-so-worse.example.net/foo/pkg/pkg.go": "package pkg\nconst X = 1", }, }.test(t, func(t *goimportTest) {