1
0
mirror of https://github.com/golang/go synced 2024-09-30 18:08:33 -06:00

internal/imports, internal/lsp: quick fix import errors

Get quick fixes for the diagnostics related to import errors. These
fixes add, remove, or rename exactly one import.

This change exposes the individual fixes found by the imports package,
and then applies each of them separately to the source.  Since applying each
fix requires a new ast anyway, we pass in the source to be parsed each time.

Change-Id: Ibcbfa703d21b6983d774d2010716da8c25525d4f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/188059
Run-TryBot: Suzy Mueller <suzmue@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
This commit is contained in:
Suzy Mueller 2019-07-30 14:00:02 -04:00
parent 1e85ed8060
commit 5f95ed5921
4 changed files with 317 additions and 103 deletions

View File

@ -67,23 +67,27 @@ func importGroup(env *ProcessEnv, importPath string) int {
return 0
}
type importFixType int
type ImportFixType int
const (
addImport importFixType = iota
deleteImport
setImportName
AddImport ImportFixType = iota
DeleteImport
SetImportName
)
type importFix struct {
info importInfo
fixType importFixType
type ImportFix struct {
// StmtInfo represents the import statement this fix will add, remove, or change.
StmtInfo ImportInfo
// IdentName is the identifier that this fix will add or remove.
IdentName string
// FixType is the type of fix this is (AddImport, DeleteImport, SetImportName).
FixType ImportFixType
}
// An importInfo represents a single import statement.
type importInfo struct {
importPath string // import path, e.g. "crypto/rand".
name string // import name, e.g. "crand", or "" if none.
// An ImportInfo represents a single import statement.
type ImportInfo struct {
ImportPath string // import path, e.g. "crypto/rand".
Name string // import name, e.g. "crand", or "" if none.
}
// A packageInfo represents what's known about a package.
@ -183,8 +187,8 @@ func collectReferences(f *ast.File) references {
// collectImports returns all the imports in f.
// Unnamed imports (., _) and "C" are ignored.
func collectImports(f *ast.File) []*importInfo {
var imports []*importInfo
func collectImports(f *ast.File) []*ImportInfo {
var imports []*ImportInfo
for _, imp := range f.Imports {
var name string
if imp.Name != nil {
@ -194,9 +198,9 @@ func collectImports(f *ast.File) []*importInfo {
continue
}
path := strings.Trim(imp.Path.Value, `"`)
imports = append(imports, &importInfo{
name: name,
importPath: path,
imports = append(imports, &ImportInfo{
Name: name,
ImportPath: path,
})
}
return imports
@ -204,9 +208,9 @@ func collectImports(f *ast.File) []*importInfo {
// findMissingImport searches pass's candidates for an import that provides
// pkg, containing all of syms.
func (p *pass) findMissingImport(pkg string, syms map[string]bool) *importInfo {
func (p *pass) findMissingImport(pkg string, syms map[string]bool) *ImportInfo {
for _, candidate := range p.candidates {
pkgInfo, ok := p.knownPackages[candidate.importPath]
pkgInfo, ok := p.knownPackages[candidate.ImportPath]
if !ok {
continue
}
@ -246,18 +250,18 @@ type pass struct {
otherFiles []*ast.File // sibling files.
// Intermediate state, generated by load.
existingImports map[string]*importInfo
existingImports map[string]*ImportInfo
allRefs references
missingRefs references
// Inputs to fix. These can be augmented between successive fix calls.
lastTry bool // indicates that this is the last call and fix should clean up as best it can.
candidates []*importInfo // candidate imports in priority order.
candidates []*ImportInfo // candidate imports in priority order.
knownPackages map[string]*packageInfo // information about all known packages.
}
// loadPackageNames saves the package names for everything referenced by imports.
func (p *pass) loadPackageNames(imports []*importInfo) error {
func (p *pass) loadPackageNames(imports []*ImportInfo) error {
if p.env.Debug {
p.env.Logf("loading package names for %v packages", len(imports))
defer func() {
@ -266,10 +270,10 @@ func (p *pass) loadPackageNames(imports []*importInfo) error {
}
var unknown []string
for _, imp := range imports {
if _, ok := p.knownPackages[imp.importPath]; ok {
if _, ok := p.knownPackages[imp.ImportPath]; ok {
continue
}
unknown = append(unknown, imp.importPath)
unknown = append(unknown, imp.ImportPath)
}
names, err := p.env.GetResolver().loadPackageNames(unknown, p.srcDir)
@ -289,24 +293,24 @@ func (p *pass) loadPackageNames(imports []*importInfo) error {
// importIdentifier returns the identifier that imp will introduce. It will
// guess if the package name has not been loaded, e.g. because the source
// is not available.
func (p *pass) importIdentifier(imp *importInfo) string {
if imp.name != "" {
return imp.name
func (p *pass) importIdentifier(imp *ImportInfo) string {
if imp.Name != "" {
return imp.Name
}
known := p.knownPackages[imp.importPath]
known := p.knownPackages[imp.ImportPath]
if known != nil && known.name != "" {
return known.name
}
return importPathToAssumedName(imp.importPath)
return importPathToAssumedName(imp.ImportPath)
}
// load reads in everything necessary to run a pass, and reports whether the
// file already has all the imports it needs. It fills in p.missingRefs with the
// file's missing symbols, if any, or removes unused imports if not.
func (p *pass) load() ([]*importFix, bool) {
func (p *pass) load() ([]*ImportFix, bool) {
p.knownPackages = map[string]*packageInfo{}
p.missingRefs = references{}
p.existingImports = map[string]*importInfo{}
p.existingImports = map[string]*ImportInfo{}
// Load basic information about the file in question.
p.allRefs = collectReferences(p.f)
@ -361,9 +365,9 @@ func (p *pass) load() ([]*importFix, bool) {
// fix attempts to satisfy missing imports using p.candidates. If it finds
// everything, or if p.lastTry is true, it updates fixes to add the imports it found,
// delete anything unused, and update import names, and returns true.
func (p *pass) fix() ([]*importFix, bool) {
func (p *pass) fix() ([]*ImportFix, bool) {
// Find missing imports.
var selected []*importInfo
var selected []*ImportInfo
for left, rights := range p.missingRefs {
if imp := p.findMissingImport(left, rights); imp != nil {
selected = append(selected, imp)
@ -375,7 +379,7 @@ func (p *pass) fix() ([]*importFix, bool) {
}
// Found everything, or giving up. Add the new imports and remove any unused.
var fixes []*importFix
var fixes []*ImportFix
for _, imp := range p.existingImports {
// We deliberately ignore globals here, because we can't be sure
// they're in the same package. People do things like put multiple
@ -383,32 +387,35 @@ func (p *pass) fix() ([]*importFix, bool) {
// remove imports if they happen to have the same name as a var in
// a different package.
if _, ok := p.allRefs[p.importIdentifier(imp)]; !ok {
fixes = append(fixes, &importFix{
info: *imp,
fixType: deleteImport,
fixes = append(fixes, &ImportFix{
StmtInfo: *imp,
IdentName: p.importIdentifier(imp),
FixType: DeleteImport,
})
continue
}
// An existing import may need to update its import name to be correct.
if name := p.importSpecName(imp); name != imp.name {
fixes = append(fixes, &importFix{
info: importInfo{
name: name,
importPath: imp.importPath,
if name := p.importSpecName(imp); name != imp.Name {
fixes = append(fixes, &ImportFix{
StmtInfo: ImportInfo{
Name: name,
ImportPath: imp.ImportPath,
},
fixType: setImportName,
IdentName: p.importIdentifier(imp),
FixType: SetImportName,
})
}
}
for _, imp := range selected {
fixes = append(fixes, &importFix{
info: importInfo{
name: p.importSpecName(imp),
importPath: imp.importPath,
fixes = append(fixes, &ImportFix{
StmtInfo: ImportInfo{
Name: p.importSpecName(imp),
ImportPath: imp.ImportPath,
},
fixType: addImport,
IdentName: p.importIdentifier(imp),
FixType: AddImport,
})
}
@ -419,42 +426,41 @@ func (p *pass) fix() ([]*importFix, bool) {
//
// When the import identifier matches the assumed import name, the import name does
// not appear in the import spec.
func (p *pass) importSpecName(imp *importInfo) string {
func (p *pass) importSpecName(imp *ImportInfo) string {
// If we did not load the real package names, or the name is already set,
// we just return the existing name.
if !p.loadRealPackageNames || imp.name != "" {
return imp.name
if !p.loadRealPackageNames || imp.Name != "" {
return imp.Name
}
ident := p.importIdentifier(imp)
if ident == importPathToAssumedName(imp.importPath) {
if ident == importPathToAssumedName(imp.ImportPath) {
return "" // ident not needed since the assumed and real names are the same.
}
return ident
}
// apply will perform the fixes on f in order.
func apply(fset *token.FileSet, f *ast.File, fixes []*importFix) bool {
func apply(fset *token.FileSet, f *ast.File, fixes []*ImportFix) {
for _, fix := range fixes {
switch fix.fixType {
case deleteImport:
astutil.DeleteNamedImport(fset, f, fix.info.name, fix.info.importPath)
case addImport:
astutil.AddNamedImport(fset, f, fix.info.name, fix.info.importPath)
case setImportName:
switch fix.FixType {
case DeleteImport:
astutil.DeleteNamedImport(fset, f, fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
case AddImport:
astutil.AddNamedImport(fset, f, fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
case SetImportName:
// Find the matching import path and change the name.
for _, spec := range f.Imports {
path := strings.Trim(spec.Path.Value, `"`)
if path == fix.info.importPath {
if path == fix.StmtInfo.ImportPath {
spec.Name = &ast.Ident{
Name: fix.info.name,
Name: fix.StmtInfo.Name,
NamePos: spec.Pos(),
}
}
}
}
}
return true
}
// assumeSiblingImportsValid assumes that siblings' use of packages is valid,
@ -463,15 +469,15 @@ func (p *pass) assumeSiblingImportsValid() {
for _, f := range p.otherFiles {
refs := collectReferences(f)
imports := collectImports(f)
importsByName := map[string]*importInfo{}
importsByName := map[string]*ImportInfo{}
for _, imp := range imports {
importsByName[p.importIdentifier(imp)] = imp
}
for left, rights := range refs {
if imp, ok := importsByName[left]; ok {
if _, ok := stdlib[imp.importPath]; ok {
if _, ok := stdlib[imp.ImportPath]; ok {
// We have the stdlib in memory; no need to guess.
rights = stdlib[imp.importPath]
rights = stdlib[imp.ImportPath]
}
p.addCandidate(imp, &packageInfo{
// no name; we already know it.
@ -484,9 +490,9 @@ func (p *pass) assumeSiblingImportsValid() {
// addCandidate adds a candidate import to p, and merges in the information
// in pkg.
func (p *pass) addCandidate(imp *importInfo, pkg *packageInfo) {
func (p *pass) addCandidate(imp *ImportInfo, pkg *packageInfo) {
p.candidates = append(p.candidates, imp)
if existing, ok := p.knownPackages[imp.importPath]; ok {
if existing, ok := p.knownPackages[imp.ImportPath]; ok {
if existing.name == "" {
existing.name = pkg.name
}
@ -494,7 +500,7 @@ func (p *pass) addCandidate(imp *importInfo, pkg *packageInfo) {
existing.exports[export] = true
}
} else {
p.knownPackages[imp.importPath] = pkg
p.knownPackages[imp.ImportPath] = pkg
}
}
@ -516,7 +522,7 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *P
// getFixes gets the import fixes that need to be made to f in order to fix the imports.
// It does not modify the ast.
func getFixes(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) ([]*importFix, error) {
func getFixes(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) ([]*ImportFix, error) {
abs, err := filepath.Abs(filename)
if err != nil {
return nil, err
@ -682,7 +688,7 @@ func cmdDebugStr(cmd *exec.Cmd) string {
func addStdlibCandidates(pass *pass, refs references) {
add := func(pkg string) {
pass.addCandidate(
&importInfo{importPath: pkg},
&ImportInfo{ImportPath: pkg},
&packageInfo{name: path.Base(pkg), exports: stdlib[pkg]})
}
for left := range refs {
@ -768,7 +774,7 @@ func addExternalCandidates(pass *pass, refs references, filename string) error {
// Search for imports matching potential package references.
type result struct {
imp *importInfo
imp *ImportInfo
pkg *packageInfo
}
results := make(chan result, len(refs))
@ -802,8 +808,8 @@ func addExternalCandidates(pass *pass, refs references, filename string) error {
return // No matching package.
}
imp := &importInfo{
importPath: found.importPathShort,
imp := &ImportInfo{
ImportPath: found.importPathShort,
}
pkg := &packageInfo{

View File

@ -13,6 +13,7 @@ import (
"bytes"
"fmt"
"go/ast"
"go/build"
"go/format"
"go/parser"
"go/printer"
@ -42,18 +43,10 @@ type Options struct {
}
// Process implements golang.org/x/tools/imports.Process with explicit context in env.
func Process(filename string, src []byte, opt *Options) ([]byte, error) {
if src == nil {
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
src = b
}
// Set the logger if the user has not provided it.
if opt.Env.Logf == nil {
opt.Env.Logf = log.Printf
func Process(filename string, src []byte, opt *Options) (formatted []byte, err error) {
src, err = initialize(filename, src, opt)
if err != nil {
return nil, err
}
fileSet := token.NewFileSet()
@ -67,7 +60,85 @@ func Process(filename string, src []byte, opt *Options) ([]byte, error) {
return nil, err
}
}
return formatFile(fileSet, file, src, adjust, opt)
}
// FixImports returns a list of fixes to the imports that, when applied,
// will leave the imports in the same state as Process.
//
// Note that filename's directory influences which imports can be chosen,
// so it is important that filename be accurate.
func FixImports(filename string, src []byte, opt *Options) (fixes []*ImportFix, err error) {
src, err = initialize(filename, src, opt)
if err != nil {
return nil, err
}
fileSet := token.NewFileSet()
file, _, err := parse(fileSet, filename, src, opt)
if err != nil {
return nil, err
}
return getFixes(fileSet, file, filename, opt.Env)
}
// ApplyFix will apply all of the fixes to the file and format it.
func ApplyFixes(fixes []*ImportFix, filename string, src []byte, opt *Options) (formatted []byte, err error) {
src, err = initialize(filename, src, opt)
if err != nil {
return nil, err
}
fileSet := token.NewFileSet()
file, adjust, err := parse(fileSet, filename, src, opt)
if err != nil {
return nil, err
}
// Apply the fixes to the file.
apply(fileSet, file, fixes)
return formatFile(fileSet, file, src, adjust, opt)
}
// initialize sets the values for opt and src.
// If they are provided, they are not changed. Otherwise opt is set to the
// default values and src is read from the file system.
func initialize(filename string, src []byte, opt *Options) ([]byte, error) {
// Use defaults if opt is nil.
if opt == nil {
opt = &Options{
Env: &ProcessEnv{
GOPATH: build.Default.GOPATH,
GOROOT: build.Default.GOROOT,
},
AllErrors: opt.AllErrors,
Comments: opt.Comments,
FormatOnly: opt.FormatOnly,
Fragment: opt.Fragment,
TabIndent: opt.TabIndent,
TabWidth: opt.TabWidth,
}
}
// Set the logger if the user has not provided it.
if opt.Env.Logf == nil {
opt.Env.Logf = log.Printf
}
if src == nil {
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
src = b
}
return src, nil
}
func formatFile(fileSet *token.FileSet, file *ast.File, src []byte, adjust func(orig []byte, src []byte) []byte, opt *Options) ([]byte, error) {
sortImports(opt.Env, fileSet, file)
imps := astutil.Imports(fileSet, file)
var spacesBefore []string // import paths we need spaces before
@ -95,7 +166,7 @@ func Process(filename string, src []byte, opt *Options) ([]byte, error) {
printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth}
var buf bytes.Buffer
err = printConfig.Fprint(&buf, fileSet, file)
err := printConfig.Fprint(&buf, fileSet, file)
if err != nil {
return nil, err
}

View File

@ -9,6 +9,7 @@ import (
"fmt"
"strings"
"golang.org/x/tools/internal/imports"
"golang.org/x/tools/internal/lsp/protocol"
"golang.org/x/tools/internal/lsp/source"
"golang.org/x/tools/internal/lsp/telemetry"
@ -47,7 +48,7 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara
var codeActions []protocol.CodeAction
edits, err := organizeImports(ctx, view, spn)
edits, editsPerFix, err := organizeImports(ctx, view, spn)
if err != nil {
return nil, err
}
@ -66,18 +67,23 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara
// If we also have diagnostics for missing imports, we can associate them with quick fixes.
if findImportErrors(params.Context.Diagnostics) {
// TODO(rstambler): Separate this into a set of codeActions per diagnostic,
// where each action is the addition or removal of one import.
// This can only be done when https://golang.org/issue/31493 is resolved.
codeActions = append(codeActions, protocol.CodeAction{
Title: "Organize All Imports", // clarify that all imports will change
Kind: protocol.QuickFix,
Edit: &protocol.WorkspaceEdit{
Changes: &map[string][]protocol.TextEdit{
string(uri): edits,
},
},
})
// Separate this into a set of codeActions per diagnostic, where
// each action is the addition, removal, or renaming of one import.
for _, importFix := range editsPerFix {
// Get the diagnostics this fix would affect.
if fixDiagnostics := importDiagnostics(importFix.fix, params.Context.Diagnostics); len(fixDiagnostics) > 0 {
codeActions = append(codeActions, protocol.CodeAction{
Title: importFixTitle(importFix.fix),
Kind: protocol.QuickFix,
Edit: &protocol.WorkspaceEdit{
Changes: &map[string][]protocol.TextEdit{
string(uri): importFix.edits,
},
},
Diagnostics: fixDiagnostics,
})
}
}
}
}
@ -97,16 +103,38 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara
return codeActions, nil
}
func organizeImports(ctx context.Context, view source.View, s span.Span) ([]protocol.TextEdit, error) {
type protocolImportFix struct {
fix *imports.ImportFix
edits []protocol.TextEdit
}
func organizeImports(ctx context.Context, view source.View, s span.Span) ([]protocol.TextEdit, []*protocolImportFix, error) {
f, m, rng, err := spanToRange(ctx, view, s)
if err != nil {
return nil, err
return nil, nil, err
}
edits, err := source.Imports(ctx, view, f, rng)
edits, editsPerFix, err := source.AllImportsFixes(ctx, view, f, rng)
if err != nil {
return nil, err
return nil, nil, err
}
return ToProtocolEdits(m, edits)
// Convert all source edits to protocol edits.
pEdits, err := ToProtocolEdits(m, edits)
if err != nil {
return nil, nil, err
}
pEditsPerFix := make([]*protocolImportFix, len(editsPerFix))
for i, fix := range editsPerFix {
pEdits, err := ToProtocolEdits(m, fix.Edits)
if err != nil {
return nil, nil, err
}
pEditsPerFix[i] = &protocolImportFix{
fix: fix.Fix,
edits: pEdits,
}
}
return pEdits, pEditsPerFix, nil
}
// findImports determines if a given diagnostic represents an error that could
@ -131,6 +159,49 @@ func findImportErrors(diagnostics []protocol.Diagnostic) bool {
return false
}
func importFixTitle(fix *imports.ImportFix) string {
var str string
switch fix.FixType {
case imports.AddImport:
str = fmt.Sprintf("Add import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
case imports.DeleteImport:
str = fmt.Sprintf("Delete import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
case imports.SetImportName:
str = fmt.Sprintf("Rename import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
}
return str
}
func importDiagnostics(fix *imports.ImportFix, diagnostics []protocol.Diagnostic) (results []protocol.Diagnostic) {
for _, diagnostic := range diagnostics {
switch {
// "undeclared name: X" may be an unresolved import.
case strings.HasPrefix(diagnostic.Message, "undeclared name: "):
ident := strings.TrimPrefix(diagnostic.Message, "undeclared name: ")
if ident == fix.IdentName {
results = append(results, diagnostic)
}
// "could not import: X" may be an invalid import.
case strings.HasPrefix(diagnostic.Message, "could not import: "):
ident := strings.TrimPrefix(diagnostic.Message, "could not import: ")
if ident == fix.IdentName {
results = append(results, diagnostic)
}
// "X imported but not used" is an unused import.
// "X imported but not used as Y" is an unused import.
case strings.Contains(diagnostic.Message, " imported but not used"):
idx := strings.Index(diagnostic.Message, " imported but not used")
importPath := diagnostic.Message[:idx]
if importPath == fmt.Sprintf("%q", fix.StmtInfo.ImportPath) {
results = append(results, diagnostic)
}
}
}
return results
}
func quickFixes(ctx context.Context, view source.View, gof source.GoFile) ([]protocol.CodeAction, error) {
var codeActions []protocol.CodeAction

View File

@ -108,6 +108,72 @@ func Imports(ctx context.Context, view View, f GoFile, rng span.Range) ([]TextEd
return computeTextEdits(ctx, f, string(formatted)), nil
}
type ImportFix struct {
Fix *imports.ImportFix
Edits []TextEdit
}
// AllImportsFixes formats f for each possible fix to the imports.
// In addition to returning the result of applying all edits,
// it returns a list of fixes that could be applied to the file, with the
// corresponding TextEdits that would be needed to apply that fix.
func AllImportsFixes(ctx context.Context, view View, f GoFile, rng span.Range) (edits []TextEdit, editsPerFix []*ImportFix, err error) {
ctx, done := trace.StartSpan(ctx, "source.AllImportsFixes")
defer done()
data, _, err := f.Handle(ctx).Read(ctx)
if err != nil {
return nil, nil, err
}
pkg := f.GetPackage(ctx)
if pkg == nil || pkg.IsIllTyped() {
return nil, nil, fmt.Errorf("no package for file %s", f.URI())
}
if hasListErrors(pkg.GetErrors()) {
return nil, nil, fmt.Errorf("%s has list errors, not running goimports", f.URI())
}
options := &imports.Options{
// Defaults.
AllErrors: true,
Comments: true,
Fragment: true,
FormatOnly: false,
TabIndent: true,
TabWidth: 8,
}
importFn := func(opts *imports.Options) error {
fixes, err := imports.FixImports(f.URI().Filename(), data, opts)
if err != nil {
return err
}
// Apply all of the import fixes to the file.
formatted, err := imports.ApplyFixes(fixes, f.URI().Filename(), data, options)
if err != nil {
return err
}
edits = computeTextEdits(ctx, f, string(formatted))
// Add the edits for each fix to the result.
editsPerFix = make([]*ImportFix, len(fixes))
for i, fix := range fixes {
formatted, err := imports.ApplyFixes([]*imports.ImportFix{fix}, f.URI().Filename(), data, options)
if err != nil {
return err
}
editsPerFix[i] = &ImportFix{
Fix: fix,
Edits: computeTextEdits(ctx, f, string(formatted)),
}
}
return err
}
err = view.RunProcessEnvFunc(ctx, importFn, options)
if err != nil {
return nil, nil, err
}
return edits, editsPerFix, nil
}
func hasParseErrors(errors []packages.Error) bool {
for _, err := range errors {
if err.Kind == packages.ParseError {