diff --git a/go/packages/golist_overlay.go b/go/packages/golist_overlay.go index bdba230e1ed..3c99b6e48d7 100644 --- a/go/packages/golist_overlay.go +++ b/go/packages/golist_overlay.go @@ -5,6 +5,7 @@ import ( "fmt" "go/parser" "go/token" + "log" "os" "path/filepath" "sort" @@ -22,10 +23,15 @@ func (state *golistState) processGolistOverlay(response *responseDeduper) (modif needPkgsSet := make(map[string]bool) modifiedPkgsSet := make(map[string]bool) + pkgOfDir := make(map[string][]*Package) for _, pkg := range response.dr.Packages { // This is an approximation of import path to id. This can be // wrong for tests, vendored packages, and a number of other cases. havePkgs[pkg.PkgPath] = pkg.ID + x := commonDir(pkg.GoFiles) + if x != "" { + pkgOfDir[x] = append(pkgOfDir[x], pkg) + } } // If no new imports are added, it is safe to avoid loading any needPkgs. @@ -64,6 +70,9 @@ func (state *golistState) processGolistOverlay(response *responseDeduper) (modif // to the overlay. continue } + // if all the overlay files belong to a different package, change the package + // name to that package. Otherwise leave it alone; there will be an error message. + maybeFixPackageName(pkgName, pkgOfDir, dir) nextPackage: for _, p := range response.dr.Packages { if pkgName != p.Name && p.ID != "command-line-arguments" { @@ -384,3 +393,46 @@ func extractPackageName(filename string, contents []byte) (string, bool) { } return f.Name.Name, true } + +func commonDir(a []string) string { + seen := make(map[string]bool) + x := append([]string{}, a...) + for _, f := range x { + seen[filepath.Dir(f)] = true + } + if len(seen) > 1 { + log.Fatalf("commonDir saw %v for %v", seen, x) + } + for k := range seen { + // len(seen) == 1 + return k + } + return "" // no files +} + +// It is possible that the files in the disk directory dir have a different package +// name from newName, which is deduced from the overlays. If they all have a different +// package name, and they all have the same package name, then that name becomes +// the package name. +// It returns true if it changes the package name, false otherwise. +func maybeFixPackageName(newName string, pkgOfDir map[string][]*Package, dir string) bool { + names := make(map[string]int) + for _, p := range pkgOfDir[dir] { + names[p.Name]++ + } + if len(names) != 1 { + // some files are in different packages + return false + } + oldName := "" + for k := range names { + oldName = k + } + if newName == oldName { + return false + } + for _, p := range pkgOfDir[dir] { + p.Name = newName + } + return true +} diff --git a/go/packages/overlay_test.go b/go/packages/overlay_test.go new file mode 100644 index 00000000000..7459f2b8666 --- /dev/null +++ b/go/packages/overlay_test.go @@ -0,0 +1,134 @@ +package packages_test + +import ( + "log" + "path/filepath" + "testing" + + "golang.org/x/tools/go/packages" + "golang.org/x/tools/go/packages/packagestest" +) + +const commonMode = packages.NeedName | packages.NeedFiles | + packages.NeedCompiledGoFiles | packages.NeedImports | packages.NeedSyntax + +func TestOverlayChangesPackage(t *testing.T) { + log.SetFlags(log.Lshortfile) + exported := packagestest.Export(t, packagestest.GOPATH, []packagestest.Module{{ + Name: "fake", + Files: map[string]interface{}{ + "a.go": "package foo\nfunc f(){}\n", + }, + Overlay: map[string][]byte{ + "a.go": []byte("package foox\nfunc f(){}\n"), + }, + }}) + defer exported.Cleanup() + exported.Config.Mode = packages.NeedName + + initial, err := packages.Load(exported.Config, + filepath.Dir(exported.File("fake", "a.go"))) + if err != nil { + t.Fatalf("failed to load: %v", err) + } + if len(initial) != 1 || initial[0].ID != "fake" || initial[0].Name != "foox" { + t.Fatalf("got %v, expected [fake]", initial) + } + if len(initial[0].Errors) != 0 { + t.Fatalf("got %v, expected no errors", initial[0].Errors) + } + log.SetFlags(0) +} +func TestOverlayChangesBothPackages(t *testing.T) { + log.SetFlags(log.Lshortfile) + exported := packagestest.Export(t, packagestest.GOPATH, []packagestest.Module{{ + Name: "fake", + Files: map[string]interface{}{ + "a.go": "package foo\nfunc g(){}\n", + "a_test.go": "package foo\nfunc f(){}\n", + }, + Overlay: map[string][]byte{ + "a.go": []byte("package foox\nfunc g(){}\n"), + "a_test.go": []byte("package foox\nfunc f(){}\n"), + }, + }}) + defer exported.Cleanup() + exported.Config.Mode = commonMode + + initial, err := packages.Load(exported.Config, + filepath.Dir(exported.File("fake", "a.go"))) + if err != nil { + t.Fatalf("failed to load: %v", err) + } + if len(initial) != 3 { + t.Errorf("got %d packges, expected 3", len(initial)) + } + want := []struct { + id, name string + count int + }{ + {"fake", "foox", 1}, + {"fake [fake.test]", "foox", 2}, + {"fake.test", "main", 1}, + } + for i := 0; i < 3; i++ { + if ok := checkPkg(t, initial[i], want[i].id, want[i].name, want[i].count); !ok { + t.Errorf("%d: got {%s %s %d}, expected %v", i, initial[i].ID, + initial[i].Name, len(initial[i].Syntax), want[i]) + } + if len(initial[i].Errors) != 0 { + t.Errorf("%d: got %v, expected no errors", i, initial[i].Errors) + } + } + log.SetFlags(0) +} + +func TestOverlayChangesTestPackage(t *testing.T) { + log.SetFlags(log.Lshortfile) + exported := packagestest.Export(t, packagestest.GOPATH, []packagestest.Module{{ + Name: "fake", + Files: map[string]interface{}{ + "a_test.go": "package foo\nfunc f(){}\n", + }, + Overlay: map[string][]byte{ + "a_test.go": []byte("package foox\nfunc f(){}\n"), + }, + }}) + defer exported.Cleanup() + exported.Config.Mode = commonMode + + initial, err := packages.Load(exported.Config, + filepath.Dir(exported.File("fake", "a_test.go"))) + if err != nil { + t.Fatalf("failed to load: %v", err) + } + if len(initial) != 3 { + t.Errorf("got %d packges, expected 3", len(initial)) + } + want := []struct { + id, name string + count int + }{ + {"fake", "foo", 0}, + {"fake [fake.test]", "foox", 1}, + {"fake.test", "main", 1}, + } + for i := 0; i < 3; i++ { + if ok := checkPkg(t, initial[i], want[i].id, want[i].name, want[i].count); !ok { + t.Errorf("got {%s %s %d}, expected %v", initial[i].ID, + initial[i].Name, len(initial[i].Syntax), want[i]) + } + } + if len(initial[0].Errors) != 0 { + t.Fatalf("got %v, expected no errors", initial[0].Errors) + } + log.SetFlags(0) +} + +func checkPkg(t *testing.T, p *packages.Package, id, name string, syntax int) bool { + t.Helper() + if p.ID == id && p.Name == name && len(p.Syntax) == syntax { + return true + } + return false +}