1
0
mirror of https://github.com/golang/go synced 2024-11-18 16:54:43 -07:00

astutil: fix a comment corruption case

This fixes a case where adding an import when there are is no
existing import declaration can corrupt the position of
comments attached to types. This was the last known
goimports/astutil corruption case.

See golang.org/issue/6884 for more details.

Unfortunately this requires changing the API to add a
*token.FileSet, which we should've had before.  I will update
goimports (the only user of this API?) immediately after
submitting this.

This CL also contains a hack (used only in this case of no
imports): rather than fix the comment positions by hand
(something that only Robert might know how to do), it instead
just prints the AST, manipulates the source, and re-parses
the AST. We can fix up later.

Fixes golang/go#6884

R=golang-dev, gri
CC=golang-dev
https://golang.org/cl/38270043
This commit is contained in:
Brad Fitzpatrick 2013-12-16 14:43:29 -08:00
parent 5eb4fdc120
commit ae534bcb6c
2 changed files with 58 additions and 13 deletions

View File

@ -6,16 +6,22 @@
package astutil
import (
"bufio"
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"log"
"path"
"strconv"
"strings"
)
// AddImport adds the import path to the file f, if absent.
func AddImport(f *ast.File, ipath string) (added bool) {
return AddNamedImport(f, "", ipath)
func AddImport(fset *token.FileSet, f *ast.File, ipath string) (added bool) {
return AddNamedImport(fset, f, "", ipath)
}
// AddNamedImport adds the import path to the file f, if absent.
@ -25,7 +31,7 @@ func AddImport(f *ast.File, ipath string) (added bool) {
// AddNamedImport(f, "pathpkg", "path")
// adds
// import pathpkg "path"
func AddNamedImport(f *ast.File, name, ipath string) (added bool) {
func AddNamedImport(fset *token.FileSet, f *ast.File, name, ipath string) (added bool) {
if imports(f, ipath) {
return false
}
@ -46,10 +52,12 @@ func AddNamedImport(f *ast.File, name, ipath string) (added bool) {
lastImport = -1
impDecl *ast.GenDecl
impIndex = -1
hasImports = false
)
for i, decl := range f.Decls {
gen, ok := decl.(*ast.GenDecl)
if ok && gen.Tok == token.IMPORT {
hasImports = true
lastImport = i
// Do not add to import "C", to avoid disrupting the
// association with its doc comment, breaking cgo.
@ -72,6 +80,18 @@ func AddNamedImport(f *ast.File, name, ipath string) (added bool) {
// If no import decl found, add one after the last import.
if impDecl == nil {
// TODO(bradfitz): remove this hack. See comment below on
// addImportViaSourceModification.
if !hasImports {
f2, err := addImportViaSourceModification(fset, f, name, ipath)
if err == nil {
*f = *f2
return true
}
log.Printf("addImportViaSourceModification error: %v", err)
}
// TODO(bradfitz): fix above and resume using this old code:
impDecl = &ast.GenDecl{
Tok: token.IMPORT,
}
@ -110,7 +130,7 @@ func AddNamedImport(f *ast.File, name, ipath string) (added bool) {
}
// DeleteImport deletes the import path from the file f, if present.
func DeleteImport(f *ast.File, path string) (deleted bool) {
func DeleteImport(fset *token.FileSet, f *ast.File, path string) (deleted bool) {
oldImport := importSpec(f, path)
// Find the import node that imports path, if any.
@ -163,7 +183,7 @@ func DeleteImport(f *ast.File, path string) (deleted bool) {
}
// RewriteImport rewrites any import of path oldPath to path newPath.
func RewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) {
func RewriteImport(fset *token.FileSet, f *ast.File, oldPath, newPath string) (rewrote bool) {
for _, imp := range f.Imports {
if importPath(imp) == oldPath {
rewrote = true
@ -371,3 +391,29 @@ func Imports(fset *token.FileSet, f *ast.File) [][]*ast.ImportSpec {
return groups
}
// NOTE(bradfitz): this is a bit of a hack for golang.org/issue/6884
// because we can't get the comment positions correct. Instead of modifying
// the AST, we print it, modify the text, and re-parse it. Gross.
func addImportViaSourceModification(fset *token.FileSet, f *ast.File, name, ipath string) (*ast.File, error) {
var buf bytes.Buffer
if err := format.Node(&buf, fset, f); err != nil {
return nil, fmt.Errorf("Error formatting ast.File node: %v", err)
}
var out bytes.Buffer
sc := bufio.NewScanner(bytes.NewReader(buf.Bytes()))
didAdd := false
for sc.Scan() {
ln := sc.Text()
out.WriteString(ln)
out.WriteByte('\n')
if !didAdd && strings.HasPrefix(ln, "package ") {
fmt.Fprintf(&out, "\nimport %s %q\n\n", name, ipath)
didAdd = true
}
}
if err := sc.Err(); err != nil {
return nil, err
}
return parser.ParseFile(fset, "", out.Bytes(), parser.ParseComments)
}

View File

@ -176,9 +176,8 @@ import (
`,
},
{
broken: true,
name: "struct comment",
pkg: "time",
name: "struct comment",
pkg: "time",
in: `package main
// This is a comment before a struct.
@ -203,7 +202,7 @@ func TestAddImport(t *testing.T) {
file := parse(t, test.name, test.in)
var before bytes.Buffer
ast.Fprint(&before, fset, file, nil)
AddNamedImport(file, test.renamedPkg, test.pkg)
AddNamedImport(fset, file, test.renamedPkg, test.pkg)
if got := print(t, test.name, file); got != test.out {
if test.broken {
t.Logf("%s is known broken:\ngot: %s\nwant: %s", test.name, got, test.out)
@ -220,8 +219,8 @@ func TestAddImport(t *testing.T) {
func TestDoubleAddImport(t *testing.T) {
file := parse(t, "doubleimport", "package main\n")
AddImport(file, "os")
AddImport(file, "bytes")
AddImport(fset, file, "os")
AddImport(fset, file, "bytes")
want := `package main
import (
@ -416,7 +415,7 @@ import (
func TestDeleteImport(t *testing.T) {
for _, test := range deleteTests {
file := parse(t, test.name, test.in)
DeleteImport(file, test.pkg)
DeleteImport(fset, file, test.pkg)
if got := print(t, test.name, file); got != test.out {
t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out)
}
@ -545,7 +544,7 @@ var addr = flag.String("addr", ":1718", "http service address") // Q=17, R=18
func TestRewriteImport(t *testing.T) {
for _, test := range rewriteTests {
file := parse(t, test.name, test.in)
RewriteImport(file, test.srcPkg, test.dstPkg)
RewriteImport(fset, file, test.srcPkg, test.dstPkg)
if got := print(t, test.name, file); got != test.out {
t.Errorf("%s:\ngot: %s\nwant: %s", test.name, got, test.out)
}