diff --git a/src/cmd/gofix/fix.go b/src/cmd/gofix/fix.go index f153da9701..0a49df1be3 100644 --- a/src/cmd/gofix/fix.go +++ b/src/cmd/gofix/fix.go @@ -552,6 +552,15 @@ func renameTop(f *ast.File, old, new string) bool { return fixed } +// matchLen returns the length of the longest prefix shared by x and y. +func matchLen(x, y string) int { + i := 0 + for i < len(x) && i < len(y) && x[i] == y[i] { + i++ + } + return i +} + // addImport adds the import path to the file f, if absent. func addImport(f *ast.File, ipath string) (added bool) { if imports(f, ipath) { @@ -572,53 +581,65 @@ func addImport(f *ast.File, ipath string) (added bool) { }, } - var impdecl *ast.GenDecl - // Find an import decl to add to. - var lastImport = -1 + var ( + bestMatch = -1 + lastImport = -1 + impDecl *ast.GenDecl + impIndex = -1 + ) for i, decl := range f.Decls { gen, ok := decl.(*ast.GenDecl) - if ok && gen.Tok == token.IMPORT { lastImport = i // Do not add to import "C", to avoid disrupting the // association with its doc comment, breaking cgo. - if !declImports(gen, "C") { - impdecl = gen - break + if declImports(gen, "C") { + continue + } + + // Compute longest shared prefix with imports in this block. + for j, spec := range gen.Specs { + impspec := spec.(*ast.ImportSpec) + n := matchLen(importPath(impspec), ipath) + if n > bestMatch { + bestMatch = n + impDecl = gen + impIndex = j + } } } } - // No import decl found. Add one. - if impdecl == nil { - impdecl = &ast.GenDecl{ + // If no import decl found, add one after the last import. + if impDecl == nil { + impDecl = &ast.GenDecl{ Tok: token.IMPORT, } f.Decls = append(f.Decls, nil) copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:]) - f.Decls[lastImport+1] = impdecl + f.Decls[lastImport+1] = impDecl } // Ensure the import decl has parentheses, if needed. - if len(impdecl.Specs) > 0 && !impdecl.Lparen.IsValid() { - impdecl.Lparen = impdecl.Pos() + if len(impDecl.Specs) > 0 && !impDecl.Lparen.IsValid() { + impDecl.Lparen = impDecl.Pos() } - // Assume the import paths are alphabetically ordered. - // If they are not, the result is ugly, but legal. - insertAt := len(impdecl.Specs) // default to end of specs - for i, spec := range impdecl.Specs { - impspec := spec.(*ast.ImportSpec) - if importPath(impspec) > ipath { - insertAt = i - break - } + insertAt := impIndex + 1 + if insertAt == 0 { + insertAt = len(impDecl.Specs) + } + impDecl.Specs = append(impDecl.Specs, nil) + copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:]) + impDecl.Specs[insertAt] = newImport + if insertAt > 0 { + // Assign same position as the previous import, + // so that the sorter sees it as being in the same block. + prev := impDecl.Specs[insertAt-1] + newImport.Path.ValuePos = prev.Pos() + newImport.EndPos = prev.Pos() } - - impdecl.Specs = append(impdecl.Specs, nil) - copy(impdecl.Specs[insertAt+1:], impdecl.Specs[insertAt:]) - impdecl.Specs[insertAt] = newImport f.Imports = append(f.Imports, newImport) return true @@ -682,6 +703,9 @@ func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) { for _, imp := range f.Imports { if importPath(imp) == oldPath { rewrote = true + // record old End, beacuse the default is to compute + // it using the length of imp.Path.Value. + imp.EndPos = imp.End() imp.Path.Value = strconv.Quote(newPath) } } diff --git a/src/cmd/gofix/go1pkgrename_test.go b/src/cmd/gofix/go1pkgrename_test.go index 464d67e7f0..32d659653b 100644 --- a/src/cmd/gofix/go1pkgrename_test.go +++ b/src/cmd/gofix/go1pkgrename_test.go @@ -93,6 +93,33 @@ import poot "html/template" var _ = cmplx.Sin var _ = poot.Poot +`, + }, + { + Name: "go1rename.2", + In: `package foo + +import ( + "fmt" + "http" + "url" + + "google/secret/project/go" +) + +func main() {} +`, + Out: `package foo + +import ( + "fmt" + "net/http" + "net/url" + + "google/secret/project/go" +) + +func main() {} `, }, } diff --git a/src/cmd/gofix/import_test.go b/src/cmd/gofix/import_test.go index a06dc821fb..a2ba2e7b92 100644 --- a/src/cmd/gofix/import_test.go +++ b/src/cmd/gofix/import_test.go @@ -348,17 +348,54 @@ import ( ) var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18 +`, + }, + { + Name: "import.3", + Fn: addImportFn("x/y/z", "x/a/c"), + In: `package main + +// Comment +import "C" + +import ( + "a" + "b" + + "x/w" + + "d/f" +) +`, + Out: `package main + +// Comment +import "C" + +import ( + "a" + "b" + + "x/a/c" + "x/w" + "x/y/z" + + "d/f" +) `, }, } -func addImportFn(path string) func(*ast.File) bool { +func addImportFn(path ...string) func(*ast.File) bool { return func(f *ast.File) bool { - if !imports(f, path) { - addImport(f, path) - return true + fixed := false + for _, p := range path { + if !imports(f, p) { + addImport(f, p) + fixed = true + } } - return false + return fixed } }