diff --git a/go/packages/golist_fallback.go b/go/packages/golist_fallback.go index 835e2aba2f..54bb61f06d 100644 --- a/go/packages/golist_fallback.go +++ b/go/packages/golist_fallback.go @@ -127,9 +127,11 @@ func golistDriverFallback(cfg *Config, words ...string) (*driverResponse, error) if isRoot { response.Roots = append(response.Roots, xtestID) } - for i, imp := range p.XTestImports { + // Rewrite import to package under test to refer to test variant. + imports := importMap(p.XTestImports) + for imp := range imports { if imp == p.ImportPath { - p.XTestImports[i] = testID + imports[imp] = &Package{ID: testID} break } } @@ -139,7 +141,7 @@ func golistDriverFallback(cfg *Config, words ...string) (*driverResponse, error) GoFiles: absJoin(p.Dir, p.XTestGoFiles), CompiledGoFiles: absJoin(p.Dir, p.XTestGoFiles), PkgPath: pkgpath, - Imports: importMap(p.XTestImports), + Imports: imports, }) } } diff --git a/go/packages/packages110_test.go b/go/packages/packages110_test.go index d4005682d1..9bc72a2c67 100644 --- a/go/packages/packages110_test.go +++ b/go/packages/packages110_test.go @@ -6,6 +6,51 @@ package packages_test +import ( + "bytes" + "fmt" + "os" + "strings" + "testing" + + "golang.org/x/tools/go/packages" +) + func init() { usesOldGolist = true } + +func TestXTestImports(t *testing.T) { + tmp, cleanup := makeTree(t, map[string]string{ + "src/a/a_test.go": `package a_test; import "a"`, + "src/a/a.go": `package a`, + }) + defer cleanup() + + cfg := &packages.Config{ + Mode: packages.LoadImports, + Dir: tmp, + Env: append(os.Environ(), "GOPATH="+tmp, "GO111MODULE=off"), + Tests: true, + } + initial, err := packages.Load(cfg, "a") + if err != nil { + t.Fatal(err) + } + + var gotImports bytes.Buffer + for _, pkg := range initial { + var imports []string + for imp, pkg := range pkg.Imports { + imports = append(imports, fmt.Sprintf("%q: %q", imp, pkg.ID)) + } + fmt.Fprintf(&gotImports, "%s {%s}\n", pkg.ID, strings.Join(imports, ", ")) + } + wantImports := `a {} +a [a.test] {} +a_test [a.test] {"a": "a [a.test]"} +` + if gotImports.String() != wantImports { + t.Fatalf("wrong imports: got <<%s>>, want <<%s>>", gotImports.String(), wantImports) + } +}