1
0
mirror of https://github.com/golang/go synced 2024-10-01 01:48:32 -06:00

imports: fix races in findImportGoPath

Before this change, findImportGoPath used a field within the
(otherwise read-only) structs in the dirScan map to cache the distance
from the importing package to the candidate package to be imported. As
a result, the top-level imports.Process function was not safe to call
concurrently: one goroutine could overwrite the distances while
another was attempting to sort by them. Furthermore, there were some
internal write-after-write races (writing the same cached distance to
the same address) that otherwise violate the Go memory model.

This change fixes those races, simplifies the concurrency patterns,
and clarifies goroutine lifetimes. The functions in the imports
package now wait for the goroutines they spawn to finish before
returning, eliminating the need for an awkward test-only mutex that
could otherwise mask real races in the production code paths.

See also:
https://golang.org/wiki/CodeReviewComments#goroutine-lifetimes
https://golang.org/wiki/CodeReviewComments#synchronous-functions

Fixes golang/go#25030.

Change-Id: I8fec735e0d4ff7abab406dea9d0c11d1bd93d775
Reviewed-on: https://go-review.googlesource.com/109156
Run-TryBot: Bryan C. Mills <bcmills@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Joe Tsai <thebrokentoaster@gmail.com>
This commit is contained in:
Bryan C. Mills 2018-04-24 16:05:29 -04:00
parent 8e070db38e
commit 165bdd618e
2 changed files with 155 additions and 130 deletions

View File

@ -7,6 +7,7 @@ package imports
import (
"bufio"
"bytes"
"context"
"fmt"
"go/ast"
"go/build"
@ -28,11 +29,6 @@ import (
// Debug controls verbose logging.
var Debug = false
var (
inTests = false // set true by fix_test.go; if false, no need to use testMu
testMu sync.RWMutex // guards globals reset by tests; used only if inTests
)
// LocalPrefix is a comma-separated string of import path prefixes, which, if
// set, instructs Process to sort the import paths with the given prefixes
// into another group after 3rd-party packages.
@ -293,15 +289,27 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri
}
// Search for imports matching potential package references.
searches := 0
type result struct {
ipath string // import path (if err == nil)
ipath string // import path
name string // optional name to rename import as
err error
}
results := make(chan result)
results := make(chan result, len(refs))
ctx, cancel := context.WithCancel(context.TODO())
var wg sync.WaitGroup
defer func() {
cancel()
wg.Wait()
}()
var (
firstErr error
firstErrOnce sync.Once
)
for pkgName, symbols := range refs {
wg.Add(1)
go func(pkgName string, symbols map[string]bool) {
defer wg.Done()
if packageInfo != nil {
sibling := packageInfo.Imports[pkgName]
if sibling.Path != "" {
@ -314,21 +322,34 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri
}
}
}
ipath, rename, err := findImport(pkgName, symbols, filename)
r := result{ipath: ipath, err: err}
ipath, rename, err := findImport(ctx, pkgName, symbols, filename)
if err != nil {
firstErrOnce.Do(func() {
firstErr = err
cancel()
})
return
}
if ipath == "" {
return // No matching package.
}
r := result{ipath: ipath}
if rename {
r.name = pkgName
}
results <- r
return
}(pkgName, symbols)
searches++
}
for i := 0; i < searches; i++ {
result := <-results
if result.err != nil {
return nil, result.err
}
if result.ipath != "" {
go func() {
wg.Wait()
close(results)
}()
for result := range results {
if result.name != "" {
astutil.AddNamedImport(fset, f, result.name, result.ipath)
} else {
@ -336,8 +357,10 @@ func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []stri
}
added = append(added, result.ipath)
}
}
if firstErr != nil {
return nil, firstErr
}
return added, nil
}
@ -446,7 +469,7 @@ var (
populateIgnoreOnce sync.Once
ignoredDirs []os.FileInfo
dirScanMu sync.RWMutex
dirScanMu sync.Mutex
dirScan map[string]*pkg // abs dir path => *pkg
)
@ -454,12 +477,16 @@ type pkg struct {
dir string // absolute file path to pkg directory ("/usr/lib/go/src/net/http")
importPath string // full pkg import path ("net/http", "foo/bar/vendor/a/b")
importPathShort string // vendorless import path ("net/http", "a/b")
}
type pkgDistance struct {
pkg *pkg
distance int // relative distance to target
}
// byDistanceOrImportPathShortLength sorts by relative distance breaking ties
// on the short import path length and then the import string itself.
type byDistanceOrImportPathShortLength []*pkg
type byDistanceOrImportPathShortLength []pkgDistance
func (s byDistanceOrImportPathShortLength) Len() int { return len(s) }
func (s byDistanceOrImportPathShortLength) Less(i, j int) bool {
@ -474,7 +501,7 @@ func (s byDistanceOrImportPathShortLength) Less(i, j int) bool {
return di < dj
}
vi, vj := s[i].importPathShort, s[j].importPathShort
vi, vj := s[i].pkg.importPathShort, s[j].pkg.importPathShort
if len(vi) != len(vj) {
return len(vi) < len(vj)
}
@ -590,35 +617,26 @@ func shouldTraverse(dir string, fi os.FileInfo) bool {
var testHookScanDir = func(dir string) {}
type goDirType string
const (
goRoot goDirType = "$GOROOT"
goPath goDirType = "$GOPATH"
)
var scanGoRootDone = make(chan struct{}) // closed when scanGoRoot is done
func scanGoRoot() {
go func() {
scanGoDirs(true)
close(scanGoRootDone)
}()
}
func scanGoPath() { scanGoDirs(false) }
func scanGoDirs(goRoot bool) {
// scanGoDirs populates the dirScan map for the given directory type. It may be
// called concurrently (and usually is, if both directory types are needed).
func scanGoDirs(which goDirType) {
if Debug {
which := "$GOROOT"
if !goRoot {
which = "$GOPATH"
log.Printf("scanning %s", which)
defer log.Printf("scanned %s", which)
}
log.Printf("scanning " + which)
defer log.Printf("scanned " + which)
}
dirScanMu.Lock()
if dirScan == nil {
dirScan = make(map[string]*pkg)
}
dirScanMu.Unlock()
for _, srcDir := range build.Default.SrcDirs() {
isGoroot := srcDir == filepath.Join(build.Default.GOROOT, "src")
if isGoroot != goRoot {
if isGoroot != (which == goRoot) {
continue
}
testHookScanDir(srcDir)
@ -633,16 +651,21 @@ func scanGoDirs(goRoot bool) {
if !strings.HasSuffix(path, ".go") {
return nil
}
dirScanMu.Lock()
if _, dup := dirScan[dir]; !dup {
defer dirScanMu.Unlock()
if _, dup := dirScan[dir]; dup {
return nil
}
if dirScan == nil {
dirScan = make(map[string]*pkg)
}
importpath := filepath.ToSlash(dir[len(srcDir)+len("/"):])
dirScan[dir] = &pkg{
importPath: importpath,
importPathShort: VendorlessPath(importpath),
dir: dir,
}
}
dirScanMu.Unlock()
return nil
}
if typ == os.ModeDir {
@ -698,20 +721,20 @@ func VendorlessPath(ipath string) string {
// loadExports returns the set of exported symbols in the package at dir.
// It returns nil on error or if the package name in dir does not match expectPackage.
var loadExports func(expectPackage, dir string) map[string]bool = loadExportsGoPath
var loadExports func(ctx context.Context, expectPackage, dir string) (map[string]bool, error) = loadExportsGoPath
func loadExportsGoPath(expectPackage, dir string) map[string]bool {
func loadExportsGoPath(ctx context.Context, expectPackage, dir string) (map[string]bool, error) {
if Debug {
log.Printf("loading exports in dir %s (seeking package %s)", dir, expectPackage)
}
exports := make(map[string]bool)
ctx := build.Default
buildCtx := build.Default
// ReadDir is like ioutil.ReadDir, but only returns *.go files
// and filters out _test.go files since they're not relevant
// and only slow things down.
ctx.ReadDir = func(dir string) (notTests []os.FileInfo, err error) {
buildCtx.ReadDir = func(dir string) (notTests []os.FileInfo, err error) {
all, err := ioutil.ReadDir(dir)
if err != nil {
return nil, err
@ -726,16 +749,22 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool {
return notTests, nil
}
files, err := ctx.ReadDir(dir)
files, err := buildCtx.ReadDir(dir)
if err != nil {
log.Print(err)
return nil
return nil, err
}
fset := token.NewFileSet()
for _, fi := range files {
match, err := ctx.MatchFile(dir, fi.Name())
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
match, err := buildCtx.MatchFile(dir, fi.Name())
if err != nil || !match {
continue
}
@ -745,19 +774,20 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool {
if Debug {
log.Printf("Parsing %s: %v", fullFile, err)
}
return nil
return nil, err
}
pkgName := f.Name.Name
if pkgName == "documentation" {
// Special case from go/build.ImportDir, not
// handled by ctx.MatchFile.
// handled by buildCtx.MatchFile.
continue
}
if pkgName != expectPackage {
err := fmt.Errorf("scan of dir %v is not expected package %v (actually %v)", dir, expectPackage, pkgName)
if Debug {
log.Printf("scan of dir %v is not expected package %v (actually %v)", dir, expectPackage, pkgName)
log.Print(err)
}
return nil
return nil, err
}
for name := range f.Scope.Objects {
if ast.IsExported(name) {
@ -774,7 +804,7 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool {
sort.Strings(exportList)
log.Printf("loaded exports in dir %v (package %v): %v", dir, expectPackage, strings.Join(exportList, ", "))
}
return exports
return exports, nil
}
// findImport searches for a package with the given symbols.
@ -789,16 +819,11 @@ func loadExportsGoPath(expectPackage, dir string) map[string]bool {
// import line:
// import pkg "foo/bar"
// to satisfy uses of pkg.X in the file.
var findImport func(pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) = findImportGoPath
var findImport func(ctx context.Context, pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) = findImportGoPath
// findImportGoPath is the normal implementation of findImport.
// (Some companies have their own internally.)
func findImportGoPath(pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) {
if inTests {
testMu.RLock()
defer testMu.RUnlock()
}
func findImportGoPath(ctx context.Context, pkgName string, symbols map[string]bool, filename string) (foundPkg string, rename bool, err error) {
pkgDir, err := filepath.Abs(filename)
if err != nil {
return "", false, err
@ -836,18 +861,25 @@ func findImportGoPath(pkgName string, symbols map[string]bool, filename string)
//
// TODO(bradfitz): run each $GOPATH entry async. But nobody
// really has more than one anyway, so low priority.
scanGoRootOnce.Do(scanGoRoot) // async
scanGoRootOnce.Do(func() {
go func() {
scanGoDirs(goRoot)
close(scanGoRootDone)
}()
})
if !fileInDir(filename, build.Default.GOROOT) {
scanGoPathOnce.Do(scanGoPath) // blocking
scanGoPathOnce.Do(func() { scanGoDirs(goPath) })
}
<-scanGoRootDone
// Find candidate packages, looking only at their directory names first.
var candidates []*pkg
var candidates []pkgDistance
for _, pkg := range dirScan {
if pkgIsCandidate(filename, pkgName, pkg) {
pkg.distance = distance(pkgDir, pkg.dir)
candidates = append(candidates, pkg)
candidates = append(candidates, pkgDistance{
pkg: pkg,
distance: distance(pkgDir, pkg.dir),
})
}
}
@ -857,60 +889,63 @@ func findImportGoPath(pkgName string, symbols map[string]bool, filename string)
// there's no "penalty" for vendoring.
sort.Sort(byDistanceOrImportPathShortLength(candidates))
if Debug {
for i, pkg := range candidates {
log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), pkg.importPathShort, pkg.dir)
for i, c := range candidates {
log.Printf("%s candidate %d/%d: %v in %v", pkgName, i+1, len(candidates), c.pkg.importPathShort, c.pkg.dir)
}
}
// Collect exports for packages with matching names.
done := make(chan struct{}) // closed when we find the answer
defer close(done)
rescv := make([]chan *pkg, len(candidates))
for i := range candidates {
rescv[i] = make(chan *pkg)
rescv[i] = make(chan *pkg, 1)
}
const maxConcurrentPackageImport = 4
loadExportsSem := make(chan struct{}, maxConcurrentPackageImport)
ctx, cancel := context.WithCancel(ctx)
var wg sync.WaitGroup
defer func() {
cancel()
wg.Wait()
}()
wg.Add(1)
go func() {
for i, pkg := range candidates {
defer wg.Done()
for i, c := range candidates {
select {
case loadExportsSem <- struct{}{}:
select {
case <-done:
return
default:
}
case <-done:
case <-ctx.Done():
return
}
pkg := pkg
resc := rescv[i]
go func() {
if inTests {
testMu.RLock()
defer testMu.RUnlock()
wg.Add(1)
go func(c pkgDistance, resc chan<- *pkg) {
defer func() {
<-loadExportsSem
wg.Done()
}()
exports, err := loadExports(ctx, pkgName, c.pkg.dir)
if err != nil {
resc <- nil
return
}
defer func() { <-loadExportsSem }()
exports := loadExports(pkgName, pkg.dir)
// If it doesn't have the right
// symbols, send nil to mean no match.
for symbol := range symbols {
if !exports[symbol] {
pkg = nil
break
resc <- nil
return
}
}
select {
case resc <- pkg:
case <-done:
}
}()
resc <- c.pkg
}(c, rescv[i])
}
}()
for _, resc := range rescv {
pkg := <-resc
if pkg == nil {

View File

@ -6,6 +6,7 @@ package imports
import (
"bytes"
"context"
"flag"
"go/build"
"io/ioutil"
@ -909,7 +910,7 @@ func TestFixImports(t *testing.T) {
defer func() {
findImport = old
}()
findImport = func(pkgName string, symbols map[string]bool, filename string) (string, bool, error) {
findImport = func(_ context.Context, pkgName string, symbols map[string]bool, filename string) (string, bool, error) {
return simplePkgs[pkgName], pkgName == "str", nil
}
@ -1185,7 +1186,7 @@ type Buffer2 struct {}
}
build.Default.GOROOT = goroot
got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}, "x.go")
got, rename, err := findImportGoPath(context.Background(), "bytes", map[string]bool{"Buffer2": true}, "x.go")
if err != nil {
t.Fatal(err)
}
@ -1193,7 +1194,7 @@ type Buffer2 struct {}
t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath)
}
got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}, "x.go")
got, rename, err = findImportGoPath(context.Background(), "bytes", map[string]bool{"Missing": true}, "x.go")
if err != nil {
t.Fatal(err)
}
@ -1203,34 +1204,23 @@ type Buffer2 struct {}
})
}
func init() {
inTests = true
}
func withEmptyGoPath(fn func()) {
testMu.Lock()
dirScanMu.Lock()
populateIgnoreOnce = sync.Once{}
scanGoRootOnce = sync.Once{}
scanGoPathOnce = sync.Once{}
dirScan = nil
ignoredDirs = nil
scanGoRootDone = make(chan struct{})
dirScanMu.Unlock()
oldGOPATH := build.Default.GOPATH
oldGOROOT := build.Default.GOROOT
build.Default.GOPATH = ""
testHookScanDir = func(string) {}
testMu.Unlock()
defer func() {
testMu.Lock()
testHookScanDir = func(string) {}
build.Default.GOPATH = oldGOPATH
build.Default.GOROOT = oldGOROOT
testMu.Unlock()
}()
fn()
@ -1246,7 +1236,7 @@ func TestFindImportInternal(t *testing.T) {
t.Skip(err)
}
got, rename, err := findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go"))
got, rename, err := findImportGoPath(context.Background(), "race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go"))
if err != nil {
t.Fatal(err)
}
@ -1255,7 +1245,7 @@ func TestFindImportInternal(t *testing.T) {
}
// should not be able to use internal from outside that tree
got, rename, err = findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "x.go"))
got, rename, err = findImportGoPath(context.Background(), "race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "x.go"))
if err != nil {
t.Fatal(err)
}
@ -1295,7 +1285,7 @@ func TestFindImportRandRead(t *testing.T) {
for _, sym := range tt.syms {
m[sym] = true
}
got, _, err := findImportGoPath("rand", m, file)
got, _, err := findImportGoPath(context.Background(), "rand", m, file)
if err != nil {
t.Errorf("for %q: %v", tt.syms, err)
continue
@ -1313,7 +1303,7 @@ func TestFindImportVendor(t *testing.T) {
"vendor/golang.org/x/net/http2/hpack/huffman.go": "package hpack\nfunc HuffmanDecode() { }\n",
},
}.test(t, func(t *goimportTest) {
got, rename, err := findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(t.goroot, "src/math/x.go"))
got, rename, err := findImportGoPath(context.Background(), "hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(t.goroot, "src/math/x.go"))
if err != nil {
t.Fatal(err)
}