diff --git a/go/packages/golist.go b/go/packages/golist.go index c4a1bd687f..54b0dbba50 100644 --- a/go/packages/golist.go +++ b/go/packages/golist.go @@ -130,12 +130,13 @@ extractQueries: response.Packages = append(response.Packages, p) } + var containsCandidates []string + if len(containFiles) != 0 { - containsResults, err := runContainsQueries(cfg, listfunc, isFallback, addPkg, containFiles) + containsCandidates, err = runContainsQueries(cfg, listfunc, isFallback, addPkg, containFiles) if err != nil { return nil, err } - response.Roots = append(response.Roots, containsResults...) } if len(packagesNamed) != 0 { @@ -146,12 +147,33 @@ extractQueries: response.Roots = append(response.Roots, namedResults...) } - needPkgs, err := processGolistOverlay(cfg, response) + modifiedPkgs, needPkgs, err := processGolistOverlay(cfg, response) if err != nil { return nil, err } + if len(containFiles) > 0 { + containsCandidates = append(containsCandidates, modifiedPkgs...) + containsCandidates = append(containsCandidates, needPkgs...) + } + if len(needPkgs) > 0 { addNeededOverlayPackages(cfg, listfunc, addPkg, needPkgs) + if err != nil { + return nil, err + } + } + // Check candidate packages for containFiles. + if len(containFiles) > 0 { + for _, id := range containsCandidates { + pkg := seenPkgs[id] + for _, f := range containFiles { + for _, g := range pkg.GoFiles { + if sameFile(f, g) { + response.Roots = append(response.Roots, id) + } + } + } + } } return response, nil diff --git a/go/packages/golist_overlay.go b/go/packages/golist_overlay.go index 60438a1acd..71ffcd9d55 100644 --- a/go/packages/golist_overlay.go +++ b/go/packages/golist_overlay.go @@ -16,9 +16,10 @@ import ( // - adding test and non-test files to test variants of packages // - determining the correct package to add given a new import path // - creating packages that don't exist -func processGolistOverlay(cfg *Config, response *driverResponse) (needPkgs []string, err error) { +func processGolistOverlay(cfg *Config, response *driverResponse) (modifiedPkgs, needPkgs []string, err error) { havePkgs := make(map[string]string) // importPath -> non-test package ID needPkgsSet := make(map[string]bool) + modifiedPkgsSet := make(map[string]bool) for _, pkg := range response.Packages { // This is an approximation of import path to id. This can be @@ -49,6 +50,7 @@ outer: if !fileExists { pkg.GoFiles = append(pkg.GoFiles, path) // TODO(matloob): should the file just be added to GoFiles? pkg.CompiledGoFiles = append(pkg.CompiledGoFiles, path) + modifiedPkgsSet[pkg.ID] = true } imports, err := extractImports(path, contents) if err != nil { @@ -77,7 +79,11 @@ outer: for pkg := range needPkgsSet { needPkgs = append(needPkgs, pkg) } - return needPkgs, err + modifiedPkgs = make([]string, 0, len(modifiedPkgsSet)) + for pkg := range modifiedPkgsSet { + modifiedPkgs = append(modifiedPkgs, pkg) + } + return modifiedPkgs, needPkgs, err } func extractImports(filename string, contents []byte) ([]string, error) { diff --git a/go/packages/packages_test.go b/go/packages/packages_test.go index 207996919c..39095bf918 100644 --- a/go/packages/packages_test.go +++ b/go/packages/packages_test.go @@ -1058,6 +1058,35 @@ func testContains(t *testing.T, exporter packagestest.Exporter) { } } +func TestContainsOverlay(t *testing.T) { packagestest.TestAll(t, testContainsOverlay) } +func testContainsOverlay(t *testing.T, exporter packagestest.Exporter) { + exported := packagestest.Export(t, exporter, []packagestest.Module{{ + Name: "golang.org/fake", + Files: map[string]interface{}{ + "a/a.go": `package a; import "golang.org/fake/b"`, + "b/b.go": `package b; import "golang.org/fake/c"`, + "c/c.go": `package c`, + }}}) + defer exported.Cleanup() + bOverlayFile := filepath.Join(filepath.Dir(exported.File("golang.org/fake", "b/b.go")), "b_overlay.go") + exported.Config.Mode = packages.LoadImports + exported.Config.Overlay = map[string][]byte{bOverlayFile: []byte(`package b;`)} + initial, err := packages.Load(exported.Config, "file="+bOverlayFile) + if err != nil { + t.Fatal(err) + } + + graph, _ := importGraph(initial) + wantGraph := ` +* golang.org/fake/b + golang.org/fake/c + golang.org/fake/b -> golang.org/fake/c +`[1:] + if graph != wantGraph { + t.Errorf("wrong import graph: got <<%s>>, want <<%s>>", graph, wantGraph) + } +} + // This test ensures that the effective GOARCH variable in the // application determines the Sizes function used by the type checker. // This behavior is a stop-gap until we make the build system's query