1
0
mirror of https://github.com/golang/go synced 2024-11-18 16:04:44 -07:00

internal/imports, internal/lsp: quick fix import errors

Get quick fixes for the diagnostics related to import errors. These
fixes add, remove, or rename exactly one import.

This change exposes the individual fixes found by the imports package,
and then applies each of them separately to the source.  Since applying each
fix requires a new ast anyway, we pass in the source to be parsed each time.

Change-Id: Ibcbfa703d21b6983d774d2010716da8c25525d4f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/188059
Run-TryBot: Suzy Mueller <suzmue@golang.org>
Reviewed-by: Heschi Kreinick <heschi@google.com>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
This commit is contained in:
Suzy Mueller 2019-07-30 14:00:02 -04:00
parent 1e85ed8060
commit 5f95ed5921
4 changed files with 317 additions and 103 deletions

View File

@ -67,23 +67,27 @@ func importGroup(env *ProcessEnv, importPath string) int {
return 0 return 0
} }
type importFixType int type ImportFixType int
const ( const (
addImport importFixType = iota AddImport ImportFixType = iota
deleteImport DeleteImport
setImportName SetImportName
) )
type importFix struct { type ImportFix struct {
info importInfo // StmtInfo represents the import statement this fix will add, remove, or change.
fixType importFixType StmtInfo ImportInfo
// IdentName is the identifier that this fix will add or remove.
IdentName string
// FixType is the type of fix this is (AddImport, DeleteImport, SetImportName).
FixType ImportFixType
} }
// An importInfo represents a single import statement. // An ImportInfo represents a single import statement.
type importInfo struct { type ImportInfo struct {
importPath string // import path, e.g. "crypto/rand". ImportPath string // import path, e.g. "crypto/rand".
name string // import name, e.g. "crand", or "" if none. Name string // import name, e.g. "crand", or "" if none.
} }
// A packageInfo represents what's known about a package. // A packageInfo represents what's known about a package.
@ -183,8 +187,8 @@ func collectReferences(f *ast.File) references {
// collectImports returns all the imports in f. // collectImports returns all the imports in f.
// Unnamed imports (., _) and "C" are ignored. // Unnamed imports (., _) and "C" are ignored.
func collectImports(f *ast.File) []*importInfo { func collectImports(f *ast.File) []*ImportInfo {
var imports []*importInfo var imports []*ImportInfo
for _, imp := range f.Imports { for _, imp := range f.Imports {
var name string var name string
if imp.Name != nil { if imp.Name != nil {
@ -194,9 +198,9 @@ func collectImports(f *ast.File) []*importInfo {
continue continue
} }
path := strings.Trim(imp.Path.Value, `"`) path := strings.Trim(imp.Path.Value, `"`)
imports = append(imports, &importInfo{ imports = append(imports, &ImportInfo{
name: name, Name: name,
importPath: path, ImportPath: path,
}) })
} }
return imports return imports
@ -204,9 +208,9 @@ func collectImports(f *ast.File) []*importInfo {
// findMissingImport searches pass's candidates for an import that provides // findMissingImport searches pass's candidates for an import that provides
// pkg, containing all of syms. // pkg, containing all of syms.
func (p *pass) findMissingImport(pkg string, syms map[string]bool) *importInfo { func (p *pass) findMissingImport(pkg string, syms map[string]bool) *ImportInfo {
for _, candidate := range p.candidates { for _, candidate := range p.candidates {
pkgInfo, ok := p.knownPackages[candidate.importPath] pkgInfo, ok := p.knownPackages[candidate.ImportPath]
if !ok { if !ok {
continue continue
} }
@ -246,18 +250,18 @@ type pass struct {
otherFiles []*ast.File // sibling files. otherFiles []*ast.File // sibling files.
// Intermediate state, generated by load. // Intermediate state, generated by load.
existingImports map[string]*importInfo existingImports map[string]*ImportInfo
allRefs references allRefs references
missingRefs references missingRefs references
// Inputs to fix. These can be augmented between successive fix calls. // Inputs to fix. These can be augmented between successive fix calls.
lastTry bool // indicates that this is the last call and fix should clean up as best it can. lastTry bool // indicates that this is the last call and fix should clean up as best it can.
candidates []*importInfo // candidate imports in priority order. candidates []*ImportInfo // candidate imports in priority order.
knownPackages map[string]*packageInfo // information about all known packages. knownPackages map[string]*packageInfo // information about all known packages.
} }
// loadPackageNames saves the package names for everything referenced by imports. // loadPackageNames saves the package names for everything referenced by imports.
func (p *pass) loadPackageNames(imports []*importInfo) error { func (p *pass) loadPackageNames(imports []*ImportInfo) error {
if p.env.Debug { if p.env.Debug {
p.env.Logf("loading package names for %v packages", len(imports)) p.env.Logf("loading package names for %v packages", len(imports))
defer func() { defer func() {
@ -266,10 +270,10 @@ func (p *pass) loadPackageNames(imports []*importInfo) error {
} }
var unknown []string var unknown []string
for _, imp := range imports { for _, imp := range imports {
if _, ok := p.knownPackages[imp.importPath]; ok { if _, ok := p.knownPackages[imp.ImportPath]; ok {
continue continue
} }
unknown = append(unknown, imp.importPath) unknown = append(unknown, imp.ImportPath)
} }
names, err := p.env.GetResolver().loadPackageNames(unknown, p.srcDir) names, err := p.env.GetResolver().loadPackageNames(unknown, p.srcDir)
@ -289,24 +293,24 @@ func (p *pass) loadPackageNames(imports []*importInfo) error {
// importIdentifier returns the identifier that imp will introduce. It will // importIdentifier returns the identifier that imp will introduce. It will
// guess if the package name has not been loaded, e.g. because the source // guess if the package name has not been loaded, e.g. because the source
// is not available. // is not available.
func (p *pass) importIdentifier(imp *importInfo) string { func (p *pass) importIdentifier(imp *ImportInfo) string {
if imp.name != "" { if imp.Name != "" {
return imp.name return imp.Name
} }
known := p.knownPackages[imp.importPath] known := p.knownPackages[imp.ImportPath]
if known != nil && known.name != "" { if known != nil && known.name != "" {
return known.name return known.name
} }
return importPathToAssumedName(imp.importPath) return importPathToAssumedName(imp.ImportPath)
} }
// load reads in everything necessary to run a pass, and reports whether the // load reads in everything necessary to run a pass, and reports whether the
// file already has all the imports it needs. It fills in p.missingRefs with the // file already has all the imports it needs. It fills in p.missingRefs with the
// file's missing symbols, if any, or removes unused imports if not. // file's missing symbols, if any, or removes unused imports if not.
func (p *pass) load() ([]*importFix, bool) { func (p *pass) load() ([]*ImportFix, bool) {
p.knownPackages = map[string]*packageInfo{} p.knownPackages = map[string]*packageInfo{}
p.missingRefs = references{} p.missingRefs = references{}
p.existingImports = map[string]*importInfo{} p.existingImports = map[string]*ImportInfo{}
// Load basic information about the file in question. // Load basic information about the file in question.
p.allRefs = collectReferences(p.f) p.allRefs = collectReferences(p.f)
@ -361,9 +365,9 @@ func (p *pass) load() ([]*importFix, bool) {
// fix attempts to satisfy missing imports using p.candidates. If it finds // fix attempts to satisfy missing imports using p.candidates. If it finds
// everything, or if p.lastTry is true, it updates fixes to add the imports it found, // everything, or if p.lastTry is true, it updates fixes to add the imports it found,
// delete anything unused, and update import names, and returns true. // delete anything unused, and update import names, and returns true.
func (p *pass) fix() ([]*importFix, bool) { func (p *pass) fix() ([]*ImportFix, bool) {
// Find missing imports. // Find missing imports.
var selected []*importInfo var selected []*ImportInfo
for left, rights := range p.missingRefs { for left, rights := range p.missingRefs {
if imp := p.findMissingImport(left, rights); imp != nil { if imp := p.findMissingImport(left, rights); imp != nil {
selected = append(selected, imp) selected = append(selected, imp)
@ -375,7 +379,7 @@ func (p *pass) fix() ([]*importFix, bool) {
} }
// Found everything, or giving up. Add the new imports and remove any unused. // Found everything, or giving up. Add the new imports and remove any unused.
var fixes []*importFix var fixes []*ImportFix
for _, imp := range p.existingImports { for _, imp := range p.existingImports {
// We deliberately ignore globals here, because we can't be sure // We deliberately ignore globals here, because we can't be sure
// they're in the same package. People do things like put multiple // they're in the same package. People do things like put multiple
@ -383,32 +387,35 @@ func (p *pass) fix() ([]*importFix, bool) {
// remove imports if they happen to have the same name as a var in // remove imports if they happen to have the same name as a var in
// a different package. // a different package.
if _, ok := p.allRefs[p.importIdentifier(imp)]; !ok { if _, ok := p.allRefs[p.importIdentifier(imp)]; !ok {
fixes = append(fixes, &importFix{ fixes = append(fixes, &ImportFix{
info: *imp, StmtInfo: *imp,
fixType: deleteImport, IdentName: p.importIdentifier(imp),
FixType: DeleteImport,
}) })
continue continue
} }
// An existing import may need to update its import name to be correct. // An existing import may need to update its import name to be correct.
if name := p.importSpecName(imp); name != imp.name { if name := p.importSpecName(imp); name != imp.Name {
fixes = append(fixes, &importFix{ fixes = append(fixes, &ImportFix{
info: importInfo{ StmtInfo: ImportInfo{
name: name, Name: name,
importPath: imp.importPath, ImportPath: imp.ImportPath,
}, },
fixType: setImportName, IdentName: p.importIdentifier(imp),
FixType: SetImportName,
}) })
} }
} }
for _, imp := range selected { for _, imp := range selected {
fixes = append(fixes, &importFix{ fixes = append(fixes, &ImportFix{
info: importInfo{ StmtInfo: ImportInfo{
name: p.importSpecName(imp), Name: p.importSpecName(imp),
importPath: imp.importPath, ImportPath: imp.ImportPath,
}, },
fixType: addImport, IdentName: p.importIdentifier(imp),
FixType: AddImport,
}) })
} }
@ -419,42 +426,41 @@ func (p *pass) fix() ([]*importFix, bool) {
// //
// When the import identifier matches the assumed import name, the import name does // When the import identifier matches the assumed import name, the import name does
// not appear in the import spec. // not appear in the import spec.
func (p *pass) importSpecName(imp *importInfo) string { func (p *pass) importSpecName(imp *ImportInfo) string {
// If we did not load the real package names, or the name is already set, // If we did not load the real package names, or the name is already set,
// we just return the existing name. // we just return the existing name.
if !p.loadRealPackageNames || imp.name != "" { if !p.loadRealPackageNames || imp.Name != "" {
return imp.name return imp.Name
} }
ident := p.importIdentifier(imp) ident := p.importIdentifier(imp)
if ident == importPathToAssumedName(imp.importPath) { if ident == importPathToAssumedName(imp.ImportPath) {
return "" // ident not needed since the assumed and real names are the same. return "" // ident not needed since the assumed and real names are the same.
} }
return ident return ident
} }
// apply will perform the fixes on f in order. // apply will perform the fixes on f in order.
func apply(fset *token.FileSet, f *ast.File, fixes []*importFix) bool { func apply(fset *token.FileSet, f *ast.File, fixes []*ImportFix) {
for _, fix := range fixes { for _, fix := range fixes {
switch fix.fixType { switch fix.FixType {
case deleteImport: case DeleteImport:
astutil.DeleteNamedImport(fset, f, fix.info.name, fix.info.importPath) astutil.DeleteNamedImport(fset, f, fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
case addImport: case AddImport:
astutil.AddNamedImport(fset, f, fix.info.name, fix.info.importPath) astutil.AddNamedImport(fset, f, fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
case setImportName: case SetImportName:
// Find the matching import path and change the name. // Find the matching import path and change the name.
for _, spec := range f.Imports { for _, spec := range f.Imports {
path := strings.Trim(spec.Path.Value, `"`) path := strings.Trim(spec.Path.Value, `"`)
if path == fix.info.importPath { if path == fix.StmtInfo.ImportPath {
spec.Name = &ast.Ident{ spec.Name = &ast.Ident{
Name: fix.info.name, Name: fix.StmtInfo.Name,
NamePos: spec.Pos(), NamePos: spec.Pos(),
} }
} }
} }
} }
} }
return true
} }
// assumeSiblingImportsValid assumes that siblings' use of packages is valid, // assumeSiblingImportsValid assumes that siblings' use of packages is valid,
@ -463,15 +469,15 @@ func (p *pass) assumeSiblingImportsValid() {
for _, f := range p.otherFiles { for _, f := range p.otherFiles {
refs := collectReferences(f) refs := collectReferences(f)
imports := collectImports(f) imports := collectImports(f)
importsByName := map[string]*importInfo{} importsByName := map[string]*ImportInfo{}
for _, imp := range imports { for _, imp := range imports {
importsByName[p.importIdentifier(imp)] = imp importsByName[p.importIdentifier(imp)] = imp
} }
for left, rights := range refs { for left, rights := range refs {
if imp, ok := importsByName[left]; ok { if imp, ok := importsByName[left]; ok {
if _, ok := stdlib[imp.importPath]; ok { if _, ok := stdlib[imp.ImportPath]; ok {
// We have the stdlib in memory; no need to guess. // We have the stdlib in memory; no need to guess.
rights = stdlib[imp.importPath] rights = stdlib[imp.ImportPath]
} }
p.addCandidate(imp, &packageInfo{ p.addCandidate(imp, &packageInfo{
// no name; we already know it. // no name; we already know it.
@ -484,9 +490,9 @@ func (p *pass) assumeSiblingImportsValid() {
// addCandidate adds a candidate import to p, and merges in the information // addCandidate adds a candidate import to p, and merges in the information
// in pkg. // in pkg.
func (p *pass) addCandidate(imp *importInfo, pkg *packageInfo) { func (p *pass) addCandidate(imp *ImportInfo, pkg *packageInfo) {
p.candidates = append(p.candidates, imp) p.candidates = append(p.candidates, imp)
if existing, ok := p.knownPackages[imp.importPath]; ok { if existing, ok := p.knownPackages[imp.ImportPath]; ok {
if existing.name == "" { if existing.name == "" {
existing.name = pkg.name existing.name = pkg.name
} }
@ -494,7 +500,7 @@ func (p *pass) addCandidate(imp *importInfo, pkg *packageInfo) {
existing.exports[export] = true existing.exports[export] = true
} }
} else { } else {
p.knownPackages[imp.importPath] = pkg p.knownPackages[imp.ImportPath] = pkg
} }
} }
@ -516,7 +522,7 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *P
// getFixes gets the import fixes that need to be made to f in order to fix the imports. // getFixes gets the import fixes that need to be made to f in order to fix the imports.
// It does not modify the ast. // It does not modify the ast.
func getFixes(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) ([]*importFix, error) { func getFixes(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) ([]*ImportFix, error) {
abs, err := filepath.Abs(filename) abs, err := filepath.Abs(filename)
if err != nil { if err != nil {
return nil, err return nil, err
@ -682,7 +688,7 @@ func cmdDebugStr(cmd *exec.Cmd) string {
func addStdlibCandidates(pass *pass, refs references) { func addStdlibCandidates(pass *pass, refs references) {
add := func(pkg string) { add := func(pkg string) {
pass.addCandidate( pass.addCandidate(
&importInfo{importPath: pkg}, &ImportInfo{ImportPath: pkg},
&packageInfo{name: path.Base(pkg), exports: stdlib[pkg]}) &packageInfo{name: path.Base(pkg), exports: stdlib[pkg]})
} }
for left := range refs { for left := range refs {
@ -768,7 +774,7 @@ func addExternalCandidates(pass *pass, refs references, filename string) error {
// Search for imports matching potential package references. // Search for imports matching potential package references.
type result struct { type result struct {
imp *importInfo imp *ImportInfo
pkg *packageInfo pkg *packageInfo
} }
results := make(chan result, len(refs)) results := make(chan result, len(refs))
@ -802,8 +808,8 @@ func addExternalCandidates(pass *pass, refs references, filename string) error {
return // No matching package. return // No matching package.
} }
imp := &importInfo{ imp := &ImportInfo{
importPath: found.importPathShort, ImportPath: found.importPathShort,
} }
pkg := &packageInfo{ pkg := &packageInfo{

View File

@ -13,6 +13,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"go/ast" "go/ast"
"go/build"
"go/format" "go/format"
"go/parser" "go/parser"
"go/printer" "go/printer"
@ -42,19 +43,11 @@ type Options struct {
} }
// Process implements golang.org/x/tools/imports.Process with explicit context in env. // Process implements golang.org/x/tools/imports.Process with explicit context in env.
func Process(filename string, src []byte, opt *Options) ([]byte, error) { func Process(filename string, src []byte, opt *Options) (formatted []byte, err error) {
if src == nil { src, err = initialize(filename, src, opt)
b, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
return nil, err return nil, err
} }
src = b
}
// Set the logger if the user has not provided it.
if opt.Env.Logf == nil {
opt.Env.Logf = log.Printf
}
fileSet := token.NewFileSet() fileSet := token.NewFileSet()
file, adjust, err := parse(fileSet, filename, src, opt) file, adjust, err := parse(fileSet, filename, src, opt)
@ -67,7 +60,85 @@ func Process(filename string, src []byte, opt *Options) ([]byte, error) {
return nil, err return nil, err
} }
} }
return formatFile(fileSet, file, src, adjust, opt)
}
// FixImports returns a list of fixes to the imports that, when applied,
// will leave the imports in the same state as Process.
//
// Note that filename's directory influences which imports can be chosen,
// so it is important that filename be accurate.
func FixImports(filename string, src []byte, opt *Options) (fixes []*ImportFix, err error) {
src, err = initialize(filename, src, opt)
if err != nil {
return nil, err
}
fileSet := token.NewFileSet()
file, _, err := parse(fileSet, filename, src, opt)
if err != nil {
return nil, err
}
return getFixes(fileSet, file, filename, opt.Env)
}
// ApplyFix will apply all of the fixes to the file and format it.
func ApplyFixes(fixes []*ImportFix, filename string, src []byte, opt *Options) (formatted []byte, err error) {
src, err = initialize(filename, src, opt)
if err != nil {
return nil, err
}
fileSet := token.NewFileSet()
file, adjust, err := parse(fileSet, filename, src, opt)
if err != nil {
return nil, err
}
// Apply the fixes to the file.
apply(fileSet, file, fixes)
return formatFile(fileSet, file, src, adjust, opt)
}
// initialize sets the values for opt and src.
// If they are provided, they are not changed. Otherwise opt is set to the
// default values and src is read from the file system.
func initialize(filename string, src []byte, opt *Options) ([]byte, error) {
// Use defaults if opt is nil.
if opt == nil {
opt = &Options{
Env: &ProcessEnv{
GOPATH: build.Default.GOPATH,
GOROOT: build.Default.GOROOT,
},
AllErrors: opt.AllErrors,
Comments: opt.Comments,
FormatOnly: opt.FormatOnly,
Fragment: opt.Fragment,
TabIndent: opt.TabIndent,
TabWidth: opt.TabWidth,
}
}
// Set the logger if the user has not provided it.
if opt.Env.Logf == nil {
opt.Env.Logf = log.Printf
}
if src == nil {
b, err := ioutil.ReadFile(filename)
if err != nil {
return nil, err
}
src = b
}
return src, nil
}
func formatFile(fileSet *token.FileSet, file *ast.File, src []byte, adjust func(orig []byte, src []byte) []byte, opt *Options) ([]byte, error) {
sortImports(opt.Env, fileSet, file) sortImports(opt.Env, fileSet, file)
imps := astutil.Imports(fileSet, file) imps := astutil.Imports(fileSet, file)
var spacesBefore []string // import paths we need spaces before var spacesBefore []string // import paths we need spaces before
@ -95,7 +166,7 @@ func Process(filename string, src []byte, opt *Options) ([]byte, error) {
printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth} printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth}
var buf bytes.Buffer var buf bytes.Buffer
err = printConfig.Fprint(&buf, fileSet, file) err := printConfig.Fprint(&buf, fileSet, file)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"golang.org/x/tools/internal/imports"
"golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/protocol"
"golang.org/x/tools/internal/lsp/source" "golang.org/x/tools/internal/lsp/source"
"golang.org/x/tools/internal/lsp/telemetry" "golang.org/x/tools/internal/lsp/telemetry"
@ -47,7 +48,7 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara
var codeActions []protocol.CodeAction var codeActions []protocol.CodeAction
edits, err := organizeImports(ctx, view, spn) edits, editsPerFix, err := organizeImports(ctx, view, spn)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -66,20 +67,25 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara
// If we also have diagnostics for missing imports, we can associate them with quick fixes. // If we also have diagnostics for missing imports, we can associate them with quick fixes.
if findImportErrors(params.Context.Diagnostics) { if findImportErrors(params.Context.Diagnostics) {
// TODO(rstambler): Separate this into a set of codeActions per diagnostic, // Separate this into a set of codeActions per diagnostic, where
// where each action is the addition or removal of one import. // each action is the addition, removal, or renaming of one import.
// This can only be done when https://golang.org/issue/31493 is resolved. for _, importFix := range editsPerFix {
// Get the diagnostics this fix would affect.
if fixDiagnostics := importDiagnostics(importFix.fix, params.Context.Diagnostics); len(fixDiagnostics) > 0 {
codeActions = append(codeActions, protocol.CodeAction{ codeActions = append(codeActions, protocol.CodeAction{
Title: "Organize All Imports", // clarify that all imports will change Title: importFixTitle(importFix.fix),
Kind: protocol.QuickFix, Kind: protocol.QuickFix,
Edit: &protocol.WorkspaceEdit{ Edit: &protocol.WorkspaceEdit{
Changes: &map[string][]protocol.TextEdit{ Changes: &map[string][]protocol.TextEdit{
string(uri): edits, string(uri): importFix.edits,
}, },
}, },
Diagnostics: fixDiagnostics,
}) })
} }
} }
}
}
// Add the results of import organization as source.OrganizeImports. // Add the results of import organization as source.OrganizeImports.
if wanted[protocol.SourceOrganizeImports] { if wanted[protocol.SourceOrganizeImports] {
@ -97,16 +103,38 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara
return codeActions, nil return codeActions, nil
} }
func organizeImports(ctx context.Context, view source.View, s span.Span) ([]protocol.TextEdit, error) { type protocolImportFix struct {
fix *imports.ImportFix
edits []protocol.TextEdit
}
func organizeImports(ctx context.Context, view source.View, s span.Span) ([]protocol.TextEdit, []*protocolImportFix, error) {
f, m, rng, err := spanToRange(ctx, view, s) f, m, rng, err := spanToRange(ctx, view, s)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
edits, err := source.Imports(ctx, view, f, rng) edits, editsPerFix, err := source.AllImportsFixes(ctx, view, f, rng)
if err != nil { if err != nil {
return nil, err return nil, nil, err
} }
return ToProtocolEdits(m, edits) // Convert all source edits to protocol edits.
pEdits, err := ToProtocolEdits(m, edits)
if err != nil {
return nil, nil, err
}
pEditsPerFix := make([]*protocolImportFix, len(editsPerFix))
for i, fix := range editsPerFix {
pEdits, err := ToProtocolEdits(m, fix.Edits)
if err != nil {
return nil, nil, err
}
pEditsPerFix[i] = &protocolImportFix{
fix: fix.Fix,
edits: pEdits,
}
}
return pEdits, pEditsPerFix, nil
} }
// findImports determines if a given diagnostic represents an error that could // findImports determines if a given diagnostic represents an error that could
@ -131,6 +159,49 @@ func findImportErrors(diagnostics []protocol.Diagnostic) bool {
return false return false
} }
func importFixTitle(fix *imports.ImportFix) string {
var str string
switch fix.FixType {
case imports.AddImport:
str = fmt.Sprintf("Add import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
case imports.DeleteImport:
str = fmt.Sprintf("Delete import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
case imports.SetImportName:
str = fmt.Sprintf("Rename import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath)
}
return str
}
func importDiagnostics(fix *imports.ImportFix, diagnostics []protocol.Diagnostic) (results []protocol.Diagnostic) {
for _, diagnostic := range diagnostics {
switch {
// "undeclared name: X" may be an unresolved import.
case strings.HasPrefix(diagnostic.Message, "undeclared name: "):
ident := strings.TrimPrefix(diagnostic.Message, "undeclared name: ")
if ident == fix.IdentName {
results = append(results, diagnostic)
}
// "could not import: X" may be an invalid import.
case strings.HasPrefix(diagnostic.Message, "could not import: "):
ident := strings.TrimPrefix(diagnostic.Message, "could not import: ")
if ident == fix.IdentName {
results = append(results, diagnostic)
}
// "X imported but not used" is an unused import.
// "X imported but not used as Y" is an unused import.
case strings.Contains(diagnostic.Message, " imported but not used"):
idx := strings.Index(diagnostic.Message, " imported but not used")
importPath := diagnostic.Message[:idx]
if importPath == fmt.Sprintf("%q", fix.StmtInfo.ImportPath) {
results = append(results, diagnostic)
}
}
}
return results
}
func quickFixes(ctx context.Context, view source.View, gof source.GoFile) ([]protocol.CodeAction, error) { func quickFixes(ctx context.Context, view source.View, gof source.GoFile) ([]protocol.CodeAction, error) {
var codeActions []protocol.CodeAction var codeActions []protocol.CodeAction

View File

@ -108,6 +108,72 @@ func Imports(ctx context.Context, view View, f GoFile, rng span.Range) ([]TextEd
return computeTextEdits(ctx, f, string(formatted)), nil return computeTextEdits(ctx, f, string(formatted)), nil
} }
type ImportFix struct {
Fix *imports.ImportFix
Edits []TextEdit
}
// AllImportsFixes formats f for each possible fix to the imports.
// In addition to returning the result of applying all edits,
// it returns a list of fixes that could be applied to the file, with the
// corresponding TextEdits that would be needed to apply that fix.
func AllImportsFixes(ctx context.Context, view View, f GoFile, rng span.Range) (edits []TextEdit, editsPerFix []*ImportFix, err error) {
ctx, done := trace.StartSpan(ctx, "source.AllImportsFixes")
defer done()
data, _, err := f.Handle(ctx).Read(ctx)
if err != nil {
return nil, nil, err
}
pkg := f.GetPackage(ctx)
if pkg == nil || pkg.IsIllTyped() {
return nil, nil, fmt.Errorf("no package for file %s", f.URI())
}
if hasListErrors(pkg.GetErrors()) {
return nil, nil, fmt.Errorf("%s has list errors, not running goimports", f.URI())
}
options := &imports.Options{
// Defaults.
AllErrors: true,
Comments: true,
Fragment: true,
FormatOnly: false,
TabIndent: true,
TabWidth: 8,
}
importFn := func(opts *imports.Options) error {
fixes, err := imports.FixImports(f.URI().Filename(), data, opts)
if err != nil {
return err
}
// Apply all of the import fixes to the file.
formatted, err := imports.ApplyFixes(fixes, f.URI().Filename(), data, options)
if err != nil {
return err
}
edits = computeTextEdits(ctx, f, string(formatted))
// Add the edits for each fix to the result.
editsPerFix = make([]*ImportFix, len(fixes))
for i, fix := range fixes {
formatted, err := imports.ApplyFixes([]*imports.ImportFix{fix}, f.URI().Filename(), data, options)
if err != nil {
return err
}
editsPerFix[i] = &ImportFix{
Fix: fix,
Edits: computeTextEdits(ctx, f, string(formatted)),
}
}
return err
}
err = view.RunProcessEnvFunc(ctx, importFn, options)
if err != nil {
return nil, nil, err
}
return edits, editsPerFix, nil
}
func hasParseErrors(errors []packages.Error) bool { func hasParseErrors(errors []packages.Error) bool {
for _, err := range errors { for _, err := range errors {
if err.Kind == packages.ParseError { if err.Kind == packages.ParseError {