diff --git a/imports/fix.go b/imports/fix.go index d13836ce746..be8382991bf 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -194,6 +194,54 @@ func (g gate) leave() { <-g } // Too much disk I/O -> too many threads -> swapping and bad scheduling. var fsgate = make(gate, 8) +var visitedSymlinks struct { + sync.Mutex + m map[string]struct{} +} + +// shouldTraverse checks if fi, found in dir, is a directory or a symlink to a directory. +// It makes sure symlinks were never visited before to avoid symlink loops. +func shouldTraverse(dir string, fi os.FileInfo) bool { + if fi.IsDir() { + return true + } + + if fi.Mode()&os.ModeSymlink == 0 { + return false + } + path := filepath.Join(dir, fi.Name()) + target, err := filepath.EvalSymlinks(path) + if err != nil { + fmt.Fprint(os.Stderr, err) + return false + } + ts, err := os.Stat(target) + if err != nil { + fmt.Fprint(os.Stderr, err) + return false + } + if !ts.IsDir() { + return false + } + + realParent, err := filepath.EvalSymlinks(dir) + if err != nil { + fmt.Fprint(os.Stderr, err) + return false + } + realPath := filepath.Join(realParent, fi.Name()) + visitedSymlinks.Lock() + defer visitedSymlinks.Unlock() + if visitedSymlinks.m == nil { + visitedSymlinks.m = make(map[string]struct{}) + } + if _, ok := visitedSymlinks.m[realPath]; ok { + return false + } + visitedSymlinks.m[realPath] = struct{}{} + return true +} + func loadPkgIndex() { pkgIndex.Lock() pkgIndex.m = make(map[string][]pkg) @@ -216,7 +264,7 @@ func loadPkgIndex() { continue } for _, child := range children { - if child.IsDir() { + if shouldTraverse(path, child) { wg.Add(1) go func(path, name string) { defer wg.Done() @@ -257,7 +305,7 @@ func loadPkg(wg *sync.WaitGroup, root, pkgrelpath string) { if strings.HasSuffix(name, ".go") { hasGo = true } - if child.IsDir() { + if shouldTraverse(dir, child) { wg.Add(1) go func(root, name string) { defer wg.Done() diff --git a/imports/fix_test.go b/imports/fix_test.go index 9c89ef67390..78a50a8cb4a 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -823,6 +823,76 @@ func TestFixImports(t *testing.T) { } } +// Test support for packages in GOPATH that are actually symlinks. +// Also test that a symlink loop does not block the process. +func TestImportSymlinks(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("Skipping test on Windows as there are no symlinks.") + } + + newGoPath, err := ioutil.TempDir("", "symlinktest") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(newGoPath) + + targetPath := newGoPath + "/target" + if err := os.MkdirAll(targetPath, 0755); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(targetPath+"/f.go", []byte("package mypkg\nvar Foo = 123\n"), 0666); err != nil { + t.Fatal(err) + } + + symlinkPath := newGoPath + "/src/x/mypkg" + if err := os.MkdirAll(filepath.Dir(symlinkPath), 0755); err != nil { + t.Fatal(err) + } + if err := os.Symlink(targetPath, symlinkPath); err != nil { + t.Fatal(err) + } + + // Add a symlink loop. + if err := os.Symlink(newGoPath+"/src/x", newGoPath+"/src/x/apkg"); err != nil { + t.Fatal(err) + } + + pkgIndexOnce = &sync.Once{} + oldGOPATH := build.Default.GOPATH + build.Default.GOPATH = newGoPath + defer func() { + build.Default.GOPATH = oldGOPATH + visitedSymlinks.m = nil + }() + + input := `package p + +var ( + _ = fmt.Print + _ = mypkg.Foo +) +` + output := `package p + +import ( + "fmt" + "x/mypkg" +) + +var ( + _ = fmt.Print + _ = mypkg.Foo +) +` + buf, err := Process(newGoPath+"/src/myotherpkg/toformat.go", []byte(input), &Options{}) + if err != nil { + t.Fatal(err) + } + if got := string(buf); got != output { + t.Fatalf("results differ\nGOT:\n%s\nWANT:\n%s\n", got, output) + } +} + // Test for correctly identifying the name of a vendored package when it // differs from its directory name. In this test, the import line // "mypkg.com/mypkg.v1" would be removed if goimports wasn't able to detect