1
0
mirror of https://github.com/golang/go synced 2024-10-01 04:18:33 -06:00

imports: fix races in findImportGoPath

Before this change, findImportGoPath used a field within the
(otherwise read-only) structs in the dirScan map to cache the distance
from the importing package to the candidate package to be imported. As
a result, the top-level imports.Process function was not safe to call
concurrently: one goroutine could overwrite the distances while
another was attempting to sort by them. Furthermore, there were some
internal write-after-write races (writing the same cached distance to
the same address) that otherwise violate the Go memory model.

This change fixes those races, simplifies the concurrency patterns,
and clarifies goroutine lifetimes. The functions in the imports
package now wait for the goroutines they spawn to finish before
returning, eliminating the need for an awkward test-only mutex that
could otherwise mask real races in the production code paths.

See also:
https://golang.org/wiki/CodeReviewComments#goroutine-lifetimes
https://golang.org/wiki/CodeReviewComments#synchronous-functions

Fixes golang/go#25030.

Change-Id: I8fec735e0d4ff7abab406dea9d0c11d1bd93d775
Reviewed-on: https://go-review.googlesource.com/109156
Run-TryBot: Bryan C. Mills <bcmills@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Bryan C. Mills 2018-04-24 16:05:29 -04:00
parent 8e070db38e
commit 165bdd618e
2 changed files with 155 additions and 130 deletions

View File

@ -7,6 +7,7 @@ package imports
import ( import (
"bufio" "bufio"
"bytes" "bytes"
"context"
"fmt" "fmt"
"go/ast" "go/ast"
"go/build" "go/build"
@ -28,11 +29,6 @@ import (
// Debug controls verbose logging. // Debug controls verbose logging.
var Debug = false var Debug = false
var (
inTests = false // set true by fix_test.go; if false, no need to use testMu
testMu sync.RWMutex // guards globals reset by tests; used only if inTests
)
// LocalPrefix is a comma-separated string of import path prefixes, which, if // LocalPrefix is a comma-separated string of import path prefixes, which, if
// set, instructs Process to sort the import paths with the given prefixes // set, instructs Process to sort the import paths with the given prefixes
// into another group after 3rd-party packages. // into another group after 3rd-party packages.
@ -293,15 +289,27 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri
} }
// Search for imports matching potential package references. // Search for imports matching potential package references.
searches := 0
type result struct { type result struct {
ipath string // import path (if err == nil) ipath string // import path
name string // optional name to rename import as name string // optional name to rename import as
err error
} }
results := make(chan result) results := make(chan result, len(refs))
ctx, cancel := context.WithCancel(context.TODO())
var wg sync.WaitGroup
defer func() {
cancel()
wg.Wait()
}()
var (
firstErr error
firstErrOnce sync.Once
)
for pkgName, symbols := range refs { for pkgName, symbols := range refs {
wg.Add(1)
go func(pkgName string, symbols map[string]bool) { go func(pkgName string, symbols map[string]bool) {
defer wg.Done()
if packageInfo != nil { if packageInfo != nil {
sibling := packageInfo.Imports[pkgName] sibling := packageInfo.Imports[pkgName]
if sibling.Path != "" { if sibling.Path != "" {
@ -314,21 +322,34 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri
} }
} }
} }
ipath, rename, err := findImport(pkgName, symbols, filename)
r := result{ipath: ipath, err: err} ipath, rename, err := findImport(ctx, pkgName, symbols, filename)
if err != nil {
firstErrOnce.Do(func() {
firstErr = err
cancel()
})
return
}
if ipath == "" {
return // No matching package.
}
r := result{ipath: ipath}
if rename { if rename {
r.name = pkgName r.name = pkgName
} }
results <- r results <- r
return
}(pkgName, symbols) }(pkgName, symbols)
searches++
} }
for i := 0; i < searches; i++ { go func() {
result := <-results wg.Wait()
if result.err != nil { close(results)
return nil, result.err }()
}
if result.ipath != "" { for result := range results {
if result.name != "" { if result.name != "" {
astutil.AddNamedImport(fset, f, result.name, result.ipath) astutil.AddNamedImport(fset, f, result.name, result.ipath)
} else { } else {
@ -336,8 +357,10 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri
} }
added = append(added, result.ipath) added = append(added, result.ipath)
} }
}
if firstErr != nil {
return nil, firstErr
}
return added, nil return added, nil
} }
@ -446,7 +469,7 @@ var (
populateIgnoreOnce sync.Once populateIgnoreOnce sync.Once
ignoredDirs []os.FileInfo ignoredDirs []os.FileInfo
dirScanMu sync.RWMutex dirScanMu sync.Mutex
dirScan map[string]*pkg // abs dir path => *pkg dirScan map[string]*pkg // abs dir path => *pkg
) )
@ -454,12 +477,16 @@ type pkg struct {
dir string // absolute file path to pkg directory ("/usr/lib/go/src/net/http") dir string // absolute file path to pkg directory ("/usr/lib/go/src/net/http")
importPath string // full pkg import path ("net/http", "foo/bar/vendor/a/b") importPath string // full pkg import path ("net/http", "foo/bar/vendor/a/b")
importPathShort string // vendorless import path ("net/http", "a/b") importPathShort string // vendorless import path ("net/http", "a/b")
}
type pkgDistance struct {
pkg *pkg
distance int // relative distance to target distance int // relative distance to target
} }
// byDistanceOrImportPathShortLength sorts by relative distance breaking ties // byDistanceOrImportPathShortLength sorts by relative distance breaking ties
// on the short import path length and then the import string itself. // on the short import path length and then the import string itself.
type byDistanceOrImportPathShortLength []*pkg type byDistanceOrImportPathShortLength []pkgDistance
func (s byDistanceOrImportPathShortLength) Len() int { return len(s) } func (s byDistanceOrImportPathShortLength) Len() int { return len(s) }
func (s byDistanceOrImportPathShortLength) Less(i, j int) bool { func (s byDistanceOrImportPathShortLength) Less(i, j int) bool {
@ -474,7 +501,7 @@ func (s byDistanceOrImportPathShortLength) Less(i, j int) bool {
return di < dj return di < dj
} }
vi, vj := s[i].importPathShort, s[j].importPathShort vi, vj := s[i].pkg.importPathShort, s[j].pkg.importPathShort
if len(vi) != len(vj) { if len(vi) != len(vj) {
return len(vi) < len(vj) return len(vi) < len(vj)
} }
@ -590,35 +617,26 @@ func shouldTraverse(dir string, fi os.FileInfo) bool {
var testHookScanDir = func(dir string) {} var testHookScanDir = func(dir string) {}
type goDirType string
const (
goRoot goDirType = "$GOROOT"
goPath goDirType = "$GOPATH"
)
var scanGoRootDone = make(chan struct{}) // closed when scanGoRoot is done var scanGoRootDone = make(chan struct{}) // closed when scanGoRoot is done
func scanGoRoot() { // scanGoDirs populates the dirScan map for the given directory type. It may be
go func() { // called concurrently (and usually is, if both directory types are needed).
scanGoDirs(true) func scanGoDirs(which goDirType) {
close(scanGoRootDone)
}()
}
func scanGoPath() { scanGoDirs(false) }
func scanGoDirs(goRoot bool) {
if Debug { if Debug {
which := "$GOROOT" log.Printf("scanning %s", which)
if !goRoot { defer log.Printf("scanned %s", which)
which = "$GOPATH"
} }
log.Printf("scanning " + which)
defer log.Printf("scanned " + which)
}
dirScanMu.Lock()
if dirScan == nil {
dirScan = make(map[string]*pkg)
}
dirScanMu.Unlock()
for _, srcDir := range build.Default.SrcDirs() { for _, srcDir := range build.Default.SrcDirs() {
isGoroot := srcDir == filepath.Join(build.Default.GOROOT, "src") isGoroot := srcDir == filepath.Join(build.Default.GOROOT, "src")
if isGoroot != goRoot { if isGoroot != (which == goRoot) {
continue continue
} }
testHookScanDir(srcDir) testHookScanDir(srcDir)
@ -633,16 +651,21 @@ func scanGoDirs(goRoot bool) {
if !strings.HasSuffix(path, ".go") { if !strings.HasSuffix(path, ".go") {
return nil return nil
} }
dirScanMu.Lock() dirScanMu.Lock()
if _, dup := dirScan[dir]; !dup { defer dirScanMu.Unlock()
if _, dup := dirScan[dir]; dup {
return nil
}
if dirScan == nil {
dirScan = make(map[string]*pkg)
}
importpath := filepath.ToSlash(dir[len(srcDir)+len("/"):]) importpath := filepath.ToSlash(dir[len(srcDir)+len("/"):])
dirScan[dir] = &pkg{ dirScan[dir] = &pkg{
importPath: importpath, importPath: importpath,
importPathShort: VendorlessPath(importpath), importPathShort: VendorlessPath(importpath),
dir: dir, dir: dir,
} }
}
dirScanMu.Unlock()
return nil return nil
} }
if typ == os.ModeDir { if typ == os.ModeDir {
@ -698,20 +721,20 @@ func VendorlessPath(ipath string) string {
// loadExports returns the set of exported symbols in the package at dir. // loadExports returns the set of exported symbols in the package at dir.
// It returns nil on error or if the package name in dir does not match expectPackage. // It returns nil on error or if the package name in dir does not match expectPackage.
var loadExports func(expectPackage, dir string) map[string]bool = loadExportsGoPath var loadExports func(ctx context.Context, expectPackage, dir string) (map[string]bool, error) = loadExportsGoPath
func loadExportsGoPath(expectPackage, dir string) map[string]bool { func loadExportsGoPath(ctx context.Context, expectPackage, dir string) (map[string]bool, error) {
if Debug { if Debug {
log.Printf("loading exports in dir %s (seeking package %s)", dir, expectPackage) log.Printf("loading exports in dir %s (seeking package %s)", dir, expectPackage)
} }
exports := make(map[string]bool) exports := make(map[string]bool)
ctx := build.Default buildCtx := build.Default
// ReadDir is like ioutil.ReadDir, but only returns *.go files // ReadDir is like ioutil.ReadDir, but only returns *.go files
// and filters out _test.go files since they're not relevant // and filters out _test.go files since they're not relevant
// and only slow things down. // and only slow things down.
ctx.ReadDir = func(dir string) (notTests []os.FileInfo, err error) { buildCtx.ReadDir = func(dir string) (notTests []os.FileInfo, err error) {
all, err := ioutil.ReadDir(dir) all, err := ioutil.ReadDir(dir)
if err != nil { if err != nil {
return nil, err return nil, err
@ -726,16 +749,22 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool {
return notTests, nil return notTests, nil
} }
files, err := ctx.ReadDir(dir) files, err := buildCtx.ReadDir(dir)
if err != nil { if err != nil {
log.Print(err) log.Print(err)
return nil return nil, err
} }
fset := token.NewFileSet() fset := token.NewFileSet()
for _, fi := range files { for _, fi := range files {
match, err := ctx.MatchFile(dir, fi.Name()) select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
match, err := buildCtx.MatchFile(dir, fi.Name())
if err != nil || !match { if err != nil || !match {
continue continue
} }
@ -745,19 +774,20 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool {
if Debug { if Debug {
log.Printf("Parsing %s: %v", fullFile, err) log.Printf("Parsing %s: %v", fullFile, err)
} }
return nil return nil, err
} }
pkgName := f.Name.Name pkgName := f.Name.Name
if pkgName == "documentation" { if pkgName == "documentation" {
// Special case from go/build.ImportDir, not // Special case from go/build.ImportDir, not
// handled by ctx.MatchFile. // handled by buildCtx.MatchFile.
continue continue
} }
if pkgName != expectPackage { if pkgName != expectPackage {
err := fmt.Errorf("scan of dir %v is not expected package %v (actually %v)", dir, expectPackage, pkgName)
if Debug { if Debug {
log.Printf("scan of dir %v is not expected package %v (actually %v)", dir, expectPackage, pkgName) log.Print(err)
} }
return nil return nil, err
} }
for name := range f.Scope.Objects { for name := range f.Scope.Objects {
if ast.IsExported(name) { if ast.IsExported(name) {
@ -774,7 +804,7 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool {
sort.Strings(exportList) sort.Strings(exportList)
log.Printf("loaded exports in dir %v (package %v): %v", dir, expectPackage, strings.Join(exportList, ", ")) log.Printf("loaded exports in dir %v (package %v): %v", dir, expectPackage, strings.Join(exportList, ", "))
} }
return exports return exports, nil
} }
// findImport searches for a package with the given symbols. // findImport searches for a package with the given symbols.
@ -789,16 +819,11 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool {
// import line: // import line:
// import pkg "foo/bar" // import pkg "foo/bar"
// to satisfy uses of pkg.X in the file. // to satisfy uses of pkg.X in the file.
var findImport func(pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) = findImportGoPath var findImport func(ctx context.Context, pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) = findImportGoPath
// findImportGoPath is the normal implementation of findImport. // findImportGoPath is the normal implementation of findImport.
// (Some companies have their own internally.) // (Some companies have their own internally.)
func findImportGoPath(pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) { func findImportGoPath(ctx context.Context, pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) {
if inTests {
testMu.RLock()
defer testMu.RUnlock()
}
pkgDir, err := filepath.Abs(filename) pkgDir, err := filepath.Abs(filename)
if err != nil { if err != nil {
return "", false, err return "", false, err
@ -836,18 +861,25 @@ func findImportGoPath(pkgName string, symbols map[string]bool, filename string)
// //
// TODO(bradfitz): run each $GOPATH entry async. But nobody // TODO(bradfitz): run each $GOPATH entry async. But nobody
// really has more than one anyway, so low priority. // really has more than one anyway, so low priority.
scanGoRootOnce.Do(scanGoRoot) // async scanGoRootOnce.Do(func() {
go func() {
scanGoDirs(goRoot)
close(scanGoRootDone)
}()
})
if !fileInDir(filename, build.Default.GOROOT) { if !fileInDir(filename, build.Default.GOROOT) {
scanGoPathOnce.Do(scanGoPath) // blocking scanGoPathOnce.Do(func() { scanGoDirs(goPath) })
} }
<-scanGoRootDone <-scanGoRootDone
// Find candidate packages, looking only at their directory names first. // Find candidate packages, looking only at their directory names first.
var candidates []*pkg var candidates []pkgDistance
for _, pkg := range dirScan { for _, pkg := range dirScan {
if pkgIsCandidate(filename, pkgName, pkg) { if pkgIsCandidate(filename, pkgName, pkg) {
pkg.distance = distance(pkgDir, pkg.dir) candidates = append(candidates, pkgDistance{
candidates = append(candidates, pkg) pkg: pkg,
distance: distance(pkgDir, pkg.dir),
})
} }
} }
@ -857,60 +889,63 @@ func findImportGoPath(pkgName string, symbols map[string]bool, filename string)
// there's no "penalty" for vendoring. // there's no "penalty" for vendoring.
sort.Sort(byDistanceOrImportPathShortLength(candidates)) sort.Sort(byDistanceOrImportPathShortLength(candidates))
if Debug { if Debug {
for i, pkg := range candidates { for i, c := range candidates {
log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), pkg.importPathShort, pkg.dir) log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir)
} }
} }
// Collect exports for packages with matching names. // Collect exports for packages with matching names.
done := make(chan struct{}) // closed when we find the answer
defer close(done)
rescv := make([]chan *pkg, len(candidates)) rescv := make([]chan *pkg, len(candidates))
for i := range candidates { for i := range candidates {
rescv[i] = make(chan *pkg) rescv[i] = make(chan *pkg, 1)
} }
const maxConcurrentPackageImport = 4 const maxConcurrentPackageImport = 4
loadExportsSem := make(chan struct{}, maxConcurrentPackageImport) loadExportsSem := make(chan struct{}, maxConcurrentPackageImport)
ctx, cancel := context.WithCancel(ctx)
var wg sync.WaitGroup
defer func() {
cancel()
wg.Wait()
}()
wg.Add(1)
go func() { go func() {
for i, pkg := range candidates { defer wg.Done()
for i, c := range candidates {
select { select {
case loadExportsSem <- struct{}{}: case loadExportsSem <- struct{}{}:
select { case <-ctx.Done():
case <-done:
return
default:
}
case <-done:
return return
} }
pkg := pkg
resc := rescv[i] wg.Add(1)
go func() { go func(c pkgDistance, resc chan<- *pkg) {
if inTests { defer func() {
testMu.RLock() <-loadExportsSem
defer testMu.RUnlock() wg.Done()
}()
exports, err := loadExports(ctx, pkgName, c.pkg.dir)
if err != nil {
resc <- nil
return
} }
defer func() { <-loadExportsSem }()
exports := loadExports(pkgName, pkg.dir)
// If it doesn't have the right // If it doesn't have the right
// symbols, send nil to mean no match. // symbols, send nil to mean no match.
for symbol := range symbols { for symbol := range symbols {
if !exports[symbol] { if !exports[symbol] {
pkg = nil resc <- nil
break return
} }
} }
select { resc <- c.pkg
case resc <- pkg: }(c, rescv[i])
case <-done:
}
}()
} }
}() }()
for _, resc := range rescv { for _, resc := range rescv {
pkg := <-resc pkg := <-resc
if pkg == nil { if pkg == nil {

View File

@ -6,6 +6,7 @@ package imports
import ( import (
"bytes" "bytes"
"context"
"flag" "flag"
"go/build" "go/build"
"io/ioutil" "io/ioutil"
@ -909,7 +910,7 @@ func TestFixImports(t *testing.T) {
defer func() { defer func() {
findImport = old findImport = old
}() }()
findImport = func(pkgName string, symbols map[string]bool, filename string) (string, bool, error) { findImport = func(_ context.Context, pkgName string, symbols map[string]bool, filename string) (string, bool, error) {
return simplePkgs[pkgName], pkgName == "str", nil return simplePkgs[pkgName], pkgName == "str", nil
} }
@ -1185,7 +1186,7 @@ type Buffer2 struct {}
} }
build.Default.GOROOT = goroot build.Default.GOROOT = goroot
got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}, "x.go") got, rename, err := findImportGoPath(context.Background(), "bytes", map[string]bool{"Buffer2": true}, "x.go")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1193,7 +1194,7 @@ type Buffer2 struct {}
t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath) t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath)
} }
got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}, "x.go") got, rename, err = findImportGoPath(context.Background(), "bytes", map[string]bool{"Missing": true}, "x.go")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1203,34 +1204,23 @@ type Buffer2 struct {}
}) })
} }
func init() {
inTests = true
}
func withEmptyGoPath(fn func()) { func withEmptyGoPath(fn func()) {
testMu.Lock()
dirScanMu.Lock()
populateIgnoreOnce = sync.Once{} populateIgnoreOnce = sync.Once{}
scanGoRootOnce = sync.Once{} scanGoRootOnce = sync.Once{}
scanGoPathOnce = sync.Once{} scanGoPathOnce = sync.Once{}
dirScan = nil dirScan = nil
ignoredDirs = nil ignoredDirs = nil
scanGoRootDone = make(chan struct{}) scanGoRootDone = make(chan struct{})
dirScanMu.Unlock()
oldGOPATH := build.Default.GOPATH oldGOPATH := build.Default.GOPATH
oldGOROOT := build.Default.GOROOT oldGOROOT := build.Default.GOROOT
build.Default.GOPATH = "" build.Default.GOPATH = ""
testHookScanDir = func(string) {} testHookScanDir = func(string) {}
testMu.Unlock()
defer func() { defer func() {
testMu.Lock()
testHookScanDir = func(string) {} testHookScanDir = func(string) {}
build.Default.GOPATH = oldGOPATH build.Default.GOPATH = oldGOPATH
build.Default.GOROOT = oldGOROOT build.Default.GOROOT = oldGOROOT
testMu.Unlock()
}() }()
fn() fn()
@ -1246,7 +1236,7 @@ func TestFindImportInternal(t *testing.T) {
t.Skip(err) t.Skip(err)
} }
got, rename, err := findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go")) got, rename, err := findImportGoPath(context.Background(), "race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1255,7 +1245,7 @@ func TestFindImportInternal(t *testing.T) {
} }
// should not be able to use internal from outside that tree // should not be able to use internal from outside that tree
got, rename, err = findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "x.go")) got, rename, err = findImportGoPath(context.Background(), "race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "x.go"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -1295,7 +1285,7 @@ func TestFindImportRandRead(t *testing.T) {
for _, sym := range tt.syms { for _, sym := range tt.syms {
m[sym] = true m[sym] = true
} }
got, _, err := findImportGoPath("rand", m, file) got, _, err := findImportGoPath(context.Background(), "rand", m, file)
if err != nil { if err != nil {
t.Errorf("for %q: %v", tt.syms, err) t.Errorf("for %q: %v", tt.syms, err)
continue continue
@ -1313,7 +1303,7 @@ func TestFindImportVendor(t *testing.T) {
"vendor/golang.org/x/net/http2/hpack/huffman.go": "package hpack\nfunc HuffmanDecode() { }\n", "vendor/golang.org/x/net/http2/hpack/huffman.go": "package hpack\nfunc HuffmanDecode() { }\n",
}, },
}.test(t, func(t *goimportTest) { }.test(t, func(t *goimportTest) {
got, rename, err := findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(t.goroot, "src/math/x.go")) got, rename, err := findImportGoPath(context.Background(), "hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(t.goroot, "src/math/x.go"))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }