diff --git a/src/cmd/gofix/fix.go b/src/cmd/gofix/fix.go index 9e4fd56a6ef..4eaadac2b4a 100644 --- a/src/cmd/gofix/fix.go +++ b/src/cmd/gofix/fix.go @@ -19,7 +19,7 @@ type fix struct { desc string } -// main runs sort.Sort(fixes) after init process is done. +// main runs sort.Sort(fixes) before printing list of fixes. type fixlist []fix func (f fixlist) Len() int { return len(f) } @@ -316,6 +316,20 @@ func importPath(s *ast.ImportSpec) string { return "" } +// declImports reports whether gen contains an import of path. +func declImports(gen *ast.GenDecl, path string) bool { + if gen.Tok != token.IMPORT { + return false + } + for _, spec := range gen.Specs { + impspec := spec.(*ast.ImportSpec) + if importPath(impspec) == path { + return true + } + } + return false +} + // isPkgDot returns true if t is the expression "pkg.name" // where pkg is an imported identifier. func isPkgDot(t ast.Expr, pkg, name string) bool { @@ -486,12 +500,18 @@ func addImport(f *ast.File, path string) { var impdecl *ast.GenDecl // Find an import decl to add to. - for _, decl := range f.Decls { + var lastImport int = -1 + for i, decl := range f.Decls { gen, ok := decl.(*ast.GenDecl) if ok && gen.Tok == token.IMPORT { - impdecl = gen - break + lastImport = i + // Do not add to import "C", to avoid disrupting the + // association with its doc comment, breaking cgo. + if !declImports(gen, "C") { + impdecl = gen + break + } } } @@ -501,8 +521,8 @@ func addImport(f *ast.File, path string) { Tok: token.IMPORT, } f.Decls = append(f.Decls, nil) - copy(f.Decls[1:], f.Decls) - f.Decls[0] = impdecl + copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:]) + f.Decls[lastImport+1] = impdecl } // Ensure the import decl has parentheses, if needed. @@ -540,7 +560,6 @@ func deleteImport(f *ast.File, path string) { } for j, spec := range gen.Specs { impspec := spec.(*ast.ImportSpec) - if oldImport != impspec { continue } @@ -558,7 +577,13 @@ func deleteImport(f *ast.File, path string) { } else if len(gen.Specs) == 1 { gen.Lparen = token.NoPos // drop parens } - + if j > 0 { + // We deleted an entry but now there will be + // a blank line-sized hole where the import was. + // Close the hole by making the previous + // import appear to "end" where this one did. + gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End() + } break } } diff --git a/src/cmd/gofix/import_test.go b/src/cmd/gofix/import_test.go new file mode 100644 index 00000000000..f878c0ccfb7 --- /dev/null +++ b/src/cmd/gofix/import_test.go @@ -0,0 +1,269 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "go/ast" + +func init() { + addTestCases(importTests, nil) +} + +var importTests = []testCase{ + { + Name: "import.0", + Fn: addImportFn("os"), + In: `package main + +import ( + "os" +) +`, + Out: `package main + +import ( + "os" +) +`, + }, + { + Name: "import.1", + Fn: addImportFn("os"), + In: `package main +`, + Out: `package main + +import "os" +`, + }, + { + Name: "import.2", + Fn: addImportFn("os"), + In: `package main + +// Comment +import "C" +`, + Out: `package main + +// Comment +import "C" +import "os" +`, + }, + { + Name: "import.3", + Fn: addImportFn("os"), + In: `package main + +// Comment +import "C" + +import ( + "io" + "utf8" +) +`, + Out: `package main + +// Comment +import "C" + +import ( + "io" + "os" + "utf8" +) +`, + }, + { + Name: "import.4", + Fn: deleteImportFn("os"), + In: `package main + +import ( + "os" +) +`, + Out: `package main +`, + }, + { + Name: "import.5", + Fn: deleteImportFn("os"), + In: `package main + +// Comment +import "C" +import "os" +`, + Out: `package main + +// Comment +import "C" +`, + }, + { + Name: "import.6", + Fn: deleteImportFn("os"), + In: `package main + +// Comment +import "C" + +import ( + "io" + "os" + "utf8" +) +`, + Out: `package main + +// Comment +import "C" + +import ( + "io" + "utf8" +) +`, + }, + { + Name: "import.7", + Fn: deleteImportFn("io"), + In: `package main + +import ( + "io" // a + "os" // b + "utf8" // c +) +`, + Out: `package main + +import ( + // a + "os" // b + "utf8" // c +) +`, + }, + { + Name: "import.8", + Fn: deleteImportFn("os"), + In: `package main + +import ( + "io" // a + "os" // b + "utf8" // c +) +`, + Out: `package main + +import ( + "io" // a + // b + "utf8" // c +) +`, + }, + { + Name: "import.9", + Fn: deleteImportFn("utf8"), + In: `package main + +import ( + "io" // a + "os" // b + "utf8" // c +) +`, + Out: `package main + +import ( + "io" // a + "os" // b + // c +) +`, + }, + { + Name: "import.10", + Fn: deleteImportFn("io"), + In: `package main + +import ( + "io" + "os" + "utf8" +) +`, + Out: `package main + +import ( + "os" + "utf8" +) +`, + }, + { + Name: "import.11", + Fn: deleteImportFn("os"), + In: `package main + +import ( + "io" + "os" + "utf8" +) +`, + Out: `package main + +import ( + "io" + "utf8" +) +`, + }, + { + Name: "import.12", + Fn: deleteImportFn("utf8"), + In: `package main + +import ( + "io" + "os" + "utf8" +) +`, + Out: `package main + +import ( + "io" + "os" +) +`, + }, +} + +func addImportFn(path string) func(*ast.File) bool { + return func(f *ast.File) bool { + if !imports(f, path) { + addImport(f, path) + return true + } + return false + } +} + +func deleteImportFn(path string) func(*ast.File) bool { + return func(f *ast.File) bool { + if imports(f, path) { + deleteImport(f, path) + return true + } + return false + } +} diff --git a/src/pkg/go/ast/ast.go b/src/pkg/go/ast/ast.go index 22bd5ee2266..f8caafc179a 100644 --- a/src/pkg/go/ast/ast.go +++ b/src/pkg/go/ast/ast.go @@ -752,6 +752,7 @@ type ( Name *Ident // local package name (including "."); or nil Path *BasicLit // import path Comment *CommentGroup // line comments; or nil + EndPos token.Pos // end of spec (overrides Path.Pos if nonzero) } // A ValueSpec node represents a constant or variable declaration @@ -785,7 +786,13 @@ func (s *ImportSpec) Pos() token.Pos { func (s *ValueSpec) Pos() token.Pos { return s.Names[0].Pos() } func (s *TypeSpec) Pos() token.Pos { return s.Name.Pos() } -func (s *ImportSpec) End() token.Pos { return s.Path.End() } +func (s *ImportSpec) End() token.Pos { + if s.EndPos != 0 { + return s.EndPos + } + return s.Path.End() +} + func (s *ValueSpec) End() token.Pos { if n := len(s.Values); n > 0 { return s.Values[n-1].End() diff --git a/src/pkg/go/parser/parser.go b/src/pkg/go/parser/parser.go index be82b2f801a..c78c6b56ecb 100644 --- a/src/pkg/go/parser/parser.go +++ b/src/pkg/go/parser/parser.go @@ -1909,7 +1909,7 @@ func parseImportSpec(p *parser, doc *ast.CommentGroup, _ int) ast.Spec { p.expectSemi() // call before accessing p.linecomment // collect imports - spec := &ast.ImportSpec{doc, ident, path, p.lineComment} + spec := &ast.ImportSpec{doc, ident, path, p.lineComment, token.NoPos} p.imports = append(p.imports, spec) return spec diff --git a/src/pkg/go/printer/nodes.go b/src/pkg/go/printer/nodes.go index 364530634aa..248e43d4e72 100644 --- a/src/pkg/go/printer/nodes.go +++ b/src/pkg/go/printer/nodes.go @@ -1278,6 +1278,7 @@ func (p *printer) spec(spec ast.Spec, n int, doIndent bool, multiLine *bool) { } p.expr(s.Path, multiLine) p.setComment(s.Comment) + p.print(s.EndPos) case *ast.ValueSpec: if n != 1 {