From bf084ef7580ee99a5efa3086138c942aca4aefd4 Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Fri, 11 Dec 2015 01:32:07 -0500 Subject: [PATCH] imports: add support for vendor directories Editor modes that invoke the goimports command on temporary copies of actual source files will need to invoke goimports -srcdir now to say where the real source directory is. Otherwise goimports will not consider vendored or internal packages when looking for new imports. In lieu of a test for cmd/goimports (because it has no tests), a command transcript: $ cd /tmp $ cat x.go package p var _ = hpack.HuffmanDecode $ $ GOPATH= goimports < x.go package p var _ = hpack.HuffmanDecode $ GOPATH= goimports x.go package p var _ = hpack.HuffmanDecode $ But with the new flag: $ GOPATH= goimports -srcdir $GOROOT/src/math < x.go package p import "golang.org/x/net/http2/hpack" var _ = hpack.HuffmanDecode $ GOPATH= goimports -srcdir $GOROOT/src/math x.go package p import "golang.org/x/net/http2/hpack" var _ = hpack.HuffmanDecode $ The tests in this CL and the above transcript assume that $GOROOT/src/vendor/golang.org/x/net/http2/hpack exists. It did in 40a26c9, but it does not today. It will again soon (once Go 1.7 opens). For golang/go#12278 (original request). Change-Id: I27b136041f54edcde4bf474215b48ebb0417f34d Reviewed-on: https://go-review.googlesource.com/17728 Run-TryBot: Russ Cox Reviewed-by: Andrew Gerrand --- cmd/goimports/goimports.go | 10 +++- imports/fix.go | 94 +++++++++++++++++++++++------------- imports/fix_test.go | 98 ++++++++++++++++++++++++++++++++++++-- imports/imports.go | 6 ++- 4 files changed, 170 insertions(+), 38 deletions(-) diff --git a/cmd/goimports/goimports.go b/cmd/goimports/goimports.go index b0b7aa8393..d7857d22f4 100644 --- a/cmd/goimports/goimports.go +++ b/cmd/goimports/goimports.go @@ -25,6 +25,7 @@ var ( list = flag.Bool("l", false, "list files whose formatting differs from goimport's") write = flag.Bool("w", false, "write result to (source) file instead of stdout") doDiff = flag.Bool("d", false, "display diffs instead of rewriting files") + srcdir = flag.String("srcdir", "", "choose imports as if source code is from `dir`") options = &imports.Options{ TabWidth: 8, @@ -78,7 +79,14 @@ func processFile(filename string, in io.Reader, out io.Writer, stdin bool) error return err } - res, err := imports.Process(filename, src, opt) + target := filename + if *srcdir != "" { + // Pretend that file is from *srcdir in order to decide + // visible imports correctly. + target = filepath.Join(*srcdir, filepath.Base(filename)) + } + + res, err := imports.Process(target, src, opt) if err != nil { return err } diff --git a/imports/fix.go b/imports/fix.go index 3ccee0ee1b..a201f9efee 100644 --- a/imports/fix.go +++ b/imports/fix.go @@ -45,7 +45,7 @@ func importGroup(importPath string) int { return 0 } -func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) { +func fixImports(fset *token.FileSet, f *ast.File, filename string) (added []string, err error) { // refs are a set of possible package references currently unsatisfied by imports. // first key: either base package (e.g. "fmt") or renamed package // second key: referenced package symbol (e.g. "Println") @@ -117,7 +117,7 @@ func fixImports(fset *token.FileSet, f *ast.File) (added []string, err error) { continue // skip over packages already imported } go func(pkgName string, symbols map[string]bool) { - ipath, rename, err := findImport(pkgName, symbols) + ipath, rename, err := findImport(pkgName, symbols, filename) r := result{ipath: ipath, err: err} if rename { r.name = pkgName @@ -304,7 +304,7 @@ func loadExportsGoPath(dir string) map[string]bool { // extended by adding a file with an init function. var findImport = findImportGoPath -func findImportGoPath(pkgName string, symbols map[string]bool) (string, bool, error) { +func findImportGoPath(pkgName string, symbols map[string]bool, filename string) (string, bool, error) { // Fast path for the standard library. // In the common case we hopefully never have to scan the GOPATH, which can // be slow with moving disks. @@ -320,51 +320,79 @@ func findImportGoPath(pkgName string, symbols map[string]bool) (string, bool, er pkgIndexOnce.Do(loadPkgIndex) // Collect exports for packages with matching names. - var wg sync.WaitGroup - var pkgsMu sync.Mutex // guards pkgs - // full importpath => exported symbol => True - // e.g. "net/http" => "Client" => True - pkgs := make(map[string]map[string]bool) + var ( + wg sync.WaitGroup + mu sync.Mutex + shortest string + ) pkgIndex.Lock() for _, pkg := range pkgIndex.m[pkgName] { + if !canUse(filename, pkg.dir) { + continue + } wg.Add(1) go func(importpath, dir string) { defer wg.Done() exports := loadExports(dir) - if exports != nil { - pkgsMu.Lock() - pkgs[importpath] = exports - pkgsMu.Unlock() + if exports == nil { + return } + // If it doesn't have the right symbols, stop. + for symbol := range symbols { + if !exports[symbol] { + return + } + } + + // Devendorize for use in import statement. + if i := strings.LastIndex(importpath, "/vendor/"); i >= 0 { + importpath = importpath[i+len("/vendor/"):] + } else if strings.HasPrefix(importpath, "vendor/") { + importpath = importpath[len("vendor/"):] + } + + // Save as the answer. + // If there are multiple candidates, the shortest wins, + // to prefer "bytes" over "github.com/foo/bytes". + mu.Lock() + if shortest == "" || len(importpath) < len(shortest) || len(importpath) == len(shortest) && importpath < shortest { + shortest = importpath + } + mu.Unlock() }(pkg.importpath, pkg.dir) } pkgIndex.Unlock() wg.Wait() - // Filter out packages missing required exported symbols. - for symbol := range symbols { - for importpath, exports := range pkgs { - if !exports[symbol] { - delete(pkgs, importpath) - } - } - } - if len(pkgs) == 0 { - return "", false, nil - } - - // If there are multiple candidate packages, the shortest one wins. - // This is a heuristic to prefer the standard library (e.g. "bytes") - // over e.g. "github.com/foo/bar/bytes". - shortest := "" - for importPath := range pkgs { - if shortest == "" || len(importPath) < len(shortest) { - shortest = importPath - } - } return shortest, false, nil } +func canUse(filename, dir string) bool { + dirSlash := filepath.ToSlash(dir) + if !strings.Contains(dirSlash, "/vendor/") && !strings.Contains(dirSlash, "/internal/") && !strings.HasSuffix(dirSlash, "/internal") { + return true + } + // Vendor or internal directory only visible from children of parent. + // That means the path from the current directory to the target directory + // can contain ../vendor or ../internal but not ../foo/vendor or ../foo/internal + // or bar/vendor or bar/internal. + // After stripping all the leading ../, the only okay place to see vendor or internal + // is at the very beginning of the path. + abs, err := filepath.Abs(filename) + if err != nil { + return false + } + rel, err := filepath.Rel(abs, dir) + if err != nil { + return false + } + relSlash := filepath.ToSlash(rel) + if i := strings.LastIndex(relSlash, "../"); i >= 0 { + relSlash = relSlash[i+len("../"):] + } + return !strings.Contains(relSlash, "/vendor/") && !strings.Contains(relSlash, "/internal/") && !strings.HasSuffix(relSlash, "/internal") +} + type visitFn func(node ast.Node) ast.Visitor func (fn visitFn) Visit(node ast.Node) ast.Visitor { diff --git a/imports/fix_test.go b/imports/fix_test.go index f087bc70bf..91ce9d4c10 100644 --- a/imports/fix_test.go +++ b/imports/fix_test.go @@ -5,11 +5,13 @@ package imports import ( + "bytes" "flag" "go/build" "io/ioutil" "os" "path/filepath" + "runtime" "sync" "testing" ) @@ -743,7 +745,11 @@ func TestFixImports(t *testing.T) { "user": "appengine/user", "zip": "archive/zip", } - findImport = func(pkgName string, symbols map[string]bool) (string, bool, error) { + old := findImport + defer func() { + findImport = old + }() + findImport = func(pkgName string, symbols map[string]bool, filename string) (string, bool, error) { return simplePkgs[pkgName], pkgName == "str", nil } @@ -813,7 +819,7 @@ type Buffer2 struct {} build.Default.GOPATH = oldGOPATH }() - got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}) + got, rename, err := findImportGoPath("bytes", map[string]bool{"Buffer2": true}, "x.go") if err != nil { t.Fatal(err) } @@ -821,7 +827,7 @@ type Buffer2 struct {} t.Errorf(`findImportGoPath("bytes", Buffer2 ...)=%q, %t, want "%s", false`, got, rename, bytesPkgPath) } - got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}) + got, rename, err = findImportGoPath("bytes", map[string]bool{"Missing": true}, "x.go") if err != nil { t.Fatal(err) } @@ -830,6 +836,92 @@ type Buffer2 struct {} } } +func TestFindImportInternal(t *testing.T) { + pkgIndexOnce = sync.Once{} + oldGOPATH := build.Default.GOPATH + build.Default.GOPATH = "" + defer func() { + build.Default.GOPATH = oldGOPATH + }() + + _, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/internal")) + if err != nil { + t.Skip(err) + } + + got, rename, err := findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go")) + if err != nil { + t.Fatal(err) + } + if got != "internal/race" || rename { + t.Errorf(`findImportGoPath("race", Acquire ...)=%q, %t, want "internal/race", false`, got, rename) + } + + // should not be able to use internal from outside that tree + got, rename, err = findImportGoPath("race", map[string]bool{"Acquire": true}, filepath.Join(runtime.GOROOT(), "x.go")) + if err != nil { + t.Fatal(err) + } + if got != "" || rename { + t.Errorf(`findImportGoPath("race", Acquire ...)=%q, %t, want "", false`, got, rename) + } +} + +func TestFindImportVendor(t *testing.T) { + pkgIndexOnce = sync.Once{} + oldGOPATH := build.Default.GOPATH + build.Default.GOPATH = "" + defer func() { + build.Default.GOPATH = oldGOPATH + }() + + _, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/vendor")) + if err != nil { + t.Skip(err) + } + + got, rename, err := findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(runtime.GOROOT(), "src/math/x.go")) + if err != nil { + t.Fatal(err) + } + if got != "golang.org/x/net/http2/hpack" || rename { + t.Errorf(`findImportGoPath("hpack", HuffmanDecode ...)=%q, %t, want "golang.org/x/net/http2/hpack", false`, got, rename) + } + + // should not be able to use vendor from outside that tree + got, rename, err = findImportGoPath("hpack", map[string]bool{"HuffmanDecode": true}, filepath.Join(runtime.GOROOT(), "x.go")) + if err != nil { + t.Fatal(err) + } + if got != "" || rename { + t.Errorf(`findImportGoPath("hpack", HuffmanDecode ...)=%q, %t, want "", false`, got, rename) + } +} + +func TestProcessVendor(t *testing.T) { + pkgIndexOnce = sync.Once{} + oldGOPATH := build.Default.GOPATH + build.Default.GOPATH = "" + defer func() { + build.Default.GOPATH = oldGOPATH + }() + + _, err := os.Stat(filepath.Join(runtime.GOROOT(), "src/vendor")) + if err != nil { + t.Skip(err) + } + + target := filepath.Join(runtime.GOROOT(), "src/math/x.go") + out, err := Process(target, []byte("package http\nimport \"bytes\"\nfunc f() { strings.NewReader(); hpack.HuffmanDecode() }\n"), nil) + + if err != nil { + t.Fatal(err) + } + if !bytes.Contains(out, []byte("\"golang.org/x/net/http2/hpack\"")) { + t.Fatalf("Process(%q) did not add expected hpack import:\n%s", target, out) + } +} + func TestFindImportStdlib(t *testing.T) { tests := []struct { pkg string diff --git a/imports/imports.go b/imports/imports.go index e30946bc63..c0d6860163 100644 --- a/imports/imports.go +++ b/imports/imports.go @@ -35,6 +35,10 @@ type Options struct { // Process formats and adjusts imports for the provided file. // If opt is nil the defaults are used. +// +// Note that filename's directory influences which imports can be chosen, +// so it is important that filename be accurate. +// To process data ``as if'' it were in filename, pass the data as a non-nil src. func Process(filename string, src []byte, opt *Options) ([]byte, error) { if opt == nil { opt = &Options{Comments: true, TabIndent: true, TabWidth: 8} @@ -46,7 +50,7 @@ func Process(filename string, src []byte, opt *Options) ([]byte, error) { return nil, err } - _, err = fixImports(fileSet, file) + _, err = fixImports(fileSet, file, filename) if err != nil { return nil, err }