diff --git a/internal/imports/fix.go b/internal/imports/fix.go index 7cb63369b7..4066565192 100644 --- a/internal/imports/fix.go +++ b/internal/imports/fix.go @@ -67,23 +67,27 @@ func importGroup(env *ProcessEnv, importPath string) int { return 0 } -type importFixType int +type ImportFixType int const ( - addImport importFixType = iota - deleteImport - setImportName + AddImport ImportFixType = iota + DeleteImport + SetImportName ) -type importFix struct { - info importInfo - fixType importFixType +type ImportFix struct { + // StmtInfo represents the import statement this fix will add, remove, or change. + StmtInfo ImportInfo + // IdentName is the identifier that this fix will add or remove. + IdentName string + // FixType is the type of fix this is (AddImport, DeleteImport, SetImportName). + FixType ImportFixType } -// An importInfo represents a single import statement. -type importInfo struct { - importPath string // import path, e.g. "crypto/rand". - name string // import name, e.g. "crand", or "" if none. +// An ImportInfo represents a single import statement. +type ImportInfo struct { + ImportPath string // import path, e.g. "crypto/rand". + Name string // import name, e.g. "crand", or "" if none. } // A packageInfo represents what's known about a package. @@ -183,8 +187,8 @@ func collectReferences(f *ast.File) references { // collectImports returns all the imports in f. // Unnamed imports (., _) and "C" are ignored. -func collectImports(f *ast.File) []*importInfo { - var imports []*importInfo +func collectImports(f *ast.File) []*ImportInfo { + var imports []*ImportInfo for _, imp := range f.Imports { var name string if imp.Name != nil { @@ -194,9 +198,9 @@ func collectImports(f *ast.File) []*importInfo { continue } path := strings.Trim(imp.Path.Value, `"`) - imports = append(imports, &importInfo{ - name: name, - importPath: path, + imports = append(imports, &ImportInfo{ + Name: name, + ImportPath: path, }) } return imports @@ -204,9 +208,9 @@ func collectImports(f *ast.File) []*importInfo { // findMissingImport searches pass's candidates for an import that provides // pkg, containing all of syms. -func (p *pass) findMissingImport(pkg string, syms map[string]bool) *importInfo { +func (p *pass) findMissingImport(pkg string, syms map[string]bool) *ImportInfo { for _, candidate := range p.candidates { - pkgInfo, ok := p.knownPackages[candidate.importPath] + pkgInfo, ok := p.knownPackages[candidate.ImportPath] if !ok { continue } @@ -246,18 +250,18 @@ type pass struct { otherFiles []*ast.File // sibling files. // Intermediate state, generated by load. - existingImports map[string]*importInfo + existingImports map[string]*ImportInfo allRefs references missingRefs references // Inputs to fix. These can be augmented between successive fix calls. lastTry bool // indicates that this is the last call and fix should clean up as best it can. - candidates []*importInfo // candidate imports in priority order. + candidates []*ImportInfo // candidate imports in priority order. knownPackages map[string]*packageInfo // information about all known packages. } // loadPackageNames saves the package names for everything referenced by imports. -func (p *pass) loadPackageNames(imports []*importInfo) error { +func (p *pass) loadPackageNames(imports []*ImportInfo) error { if p.env.Debug { p.env.Logf("loading package names for %v packages", len(imports)) defer func() { @@ -266,10 +270,10 @@ func (p *pass) loadPackageNames(imports []*importInfo) error { } var unknown []string for _, imp := range imports { - if _, ok := p.knownPackages[imp.importPath]; ok { + if _, ok := p.knownPackages[imp.ImportPath]; ok { continue } - unknown = append(unknown, imp.importPath) + unknown = append(unknown, imp.ImportPath) } names, err := p.env.GetResolver().loadPackageNames(unknown, p.srcDir) @@ -289,24 +293,24 @@ func (p *pass) loadPackageNames(imports []*importInfo) error { // importIdentifier returns the identifier that imp will introduce. It will // guess if the package name has not been loaded, e.g. because the source // is not available. -func (p *pass) importIdentifier(imp *importInfo) string { - if imp.name != "" { - return imp.name +func (p *pass) importIdentifier(imp *ImportInfo) string { + if imp.Name != "" { + return imp.Name } - known := p.knownPackages[imp.importPath] + known := p.knownPackages[imp.ImportPath] if known != nil && known.name != "" { return known.name } - return importPathToAssumedName(imp.importPath) + return importPathToAssumedName(imp.ImportPath) } // load reads in everything necessary to run a pass, and reports whether the // file already has all the imports it needs. It fills in p.missingRefs with the // file's missing symbols, if any, or removes unused imports if not. -func (p *pass) load() ([]*importFix, bool) { +func (p *pass) load() ([]*ImportFix, bool) { p.knownPackages = map[string]*packageInfo{} p.missingRefs = references{} - p.existingImports = map[string]*importInfo{} + p.existingImports = map[string]*ImportInfo{} // Load basic information about the file in question. p.allRefs = collectReferences(p.f) @@ -361,9 +365,9 @@ func (p *pass) load() ([]*importFix, bool) { // fix attempts to satisfy missing imports using p.candidates. If it finds // everything, or if p.lastTry is true, it updates fixes to add the imports it found, // delete anything unused, and update import names, and returns true. -func (p *pass) fix() ([]*importFix, bool) { +func (p *pass) fix() ([]*ImportFix, bool) { // Find missing imports. - var selected []*importInfo + var selected []*ImportInfo for left, rights := range p.missingRefs { if imp := p.findMissingImport(left, rights); imp != nil { selected = append(selected, imp) @@ -375,7 +379,7 @@ func (p *pass) fix() ([]*importFix, bool) { } // Found everything, or giving up. Add the new imports and remove any unused. - var fixes []*importFix + var fixes []*ImportFix for _, imp := range p.existingImports { // We deliberately ignore globals here, because we can't be sure // they're in the same package. People do things like put multiple @@ -383,32 +387,35 @@ func (p *pass) fix() ([]*importFix, bool) { // remove imports if they happen to have the same name as a var in // a different package. if _, ok := p.allRefs[p.importIdentifier(imp)]; !ok { - fixes = append(fixes, &importFix{ - info: *imp, - fixType: deleteImport, + fixes = append(fixes, &ImportFix{ + StmtInfo: *imp, + IdentName: p.importIdentifier(imp), + FixType: DeleteImport, }) continue } // An existing import may need to update its import name to be correct. - if name := p.importSpecName(imp); name != imp.name { - fixes = append(fixes, &importFix{ - info: importInfo{ - name: name, - importPath: imp.importPath, + if name := p.importSpecName(imp); name != imp.Name { + fixes = append(fixes, &ImportFix{ + StmtInfo: ImportInfo{ + Name: name, + ImportPath: imp.ImportPath, }, - fixType: setImportName, + IdentName: p.importIdentifier(imp), + FixType: SetImportName, }) } } for _, imp := range selected { - fixes = append(fixes, &importFix{ - info: importInfo{ - name: p.importSpecName(imp), - importPath: imp.importPath, + fixes = append(fixes, &ImportFix{ + StmtInfo: ImportInfo{ + Name: p.importSpecName(imp), + ImportPath: imp.ImportPath, }, - fixType: addImport, + IdentName: p.importIdentifier(imp), + FixType: AddImport, }) } @@ -419,42 +426,41 @@ func (p *pass) fix() ([]*importFix, bool) { // // When the import identifier matches the assumed import name, the import name does // not appear in the import spec. -func (p *pass) importSpecName(imp *importInfo) string { +func (p *pass) importSpecName(imp *ImportInfo) string { // If we did not load the real package names, or the name is already set, // we just return the existing name. - if !p.loadRealPackageNames || imp.name != "" { - return imp.name + if !p.loadRealPackageNames || imp.Name != "" { + return imp.Name } ident := p.importIdentifier(imp) - if ident == importPathToAssumedName(imp.importPath) { + if ident == importPathToAssumedName(imp.ImportPath) { return "" // ident not needed since the assumed and real names are the same. } return ident } // apply will perform the fixes on f in order. -func apply(fset *token.FileSet, f *ast.File, fixes []*importFix) bool { +func apply(fset *token.FileSet, f *ast.File, fixes []*ImportFix) { for _, fix := range fixes { - switch fix.fixType { - case deleteImport: - astutil.DeleteNamedImport(fset, f, fix.info.name, fix.info.importPath) - case addImport: - astutil.AddNamedImport(fset, f, fix.info.name, fix.info.importPath) - case setImportName: + switch fix.FixType { + case DeleteImport: + astutil.DeleteNamedImport(fset, f, fix.StmtInfo.Name, fix.StmtInfo.ImportPath) + case AddImport: + astutil.AddNamedImport(fset, f, fix.StmtInfo.Name, fix.StmtInfo.ImportPath) + case SetImportName: // Find the matching import path and change the name. for _, spec := range f.Imports { path := strings.Trim(spec.Path.Value, `"`) - if path == fix.info.importPath { + if path == fix.StmtInfo.ImportPath { spec.Name = &ast.Ident{ - Name: fix.info.name, + Name: fix.StmtInfo.Name, NamePos: spec.Pos(), } } } } } - return true } // assumeSiblingImportsValid assumes that siblings' use of packages is valid, @@ -463,15 +469,15 @@ func (p *pass) assumeSiblingImportsValid() { for _, f := range p.otherFiles { refs := collectReferences(f) imports := collectImports(f) - importsByName := map[string]*importInfo{} + importsByName := map[string]*ImportInfo{} for _, imp := range imports { importsByName[p.importIdentifier(imp)] = imp } for left, rights := range refs { if imp, ok := importsByName[left]; ok { - if _, ok := stdlib[imp.importPath]; ok { + if _, ok := stdlib[imp.ImportPath]; ok { // We have the stdlib in memory; no need to guess. - rights = stdlib[imp.importPath] + rights = stdlib[imp.ImportPath] } p.addCandidate(imp, &packageInfo{ // no name; we already know it. @@ -484,9 +490,9 @@ func (p *pass) assumeSiblingImportsValid() { // addCandidate adds a candidate import to p, and merges in the information // in pkg. -func (p *pass) addCandidate(imp *importInfo, pkg *packageInfo) { +func (p *pass) addCandidate(imp *ImportInfo, pkg *packageInfo) { p.candidates = append(p.candidates, imp) - if existing, ok := p.knownPackages[imp.importPath]; ok { + if existing, ok := p.knownPackages[imp.ImportPath]; ok { if existing.name == "" { existing.name = pkg.name } @@ -494,7 +500,7 @@ func (p *pass) addCandidate(imp *importInfo, pkg *packageInfo) { existing.exports[export] = true } } else { - p.knownPackages[imp.importPath] = pkg + p.knownPackages[imp.ImportPath] = pkg } } @@ -516,7 +522,7 @@ func fixImportsDefault(fset *token.FileSet, f *ast.File, filename string, env *P // getFixes gets the import fixes that need to be made to f in order to fix the imports. // It does not modify the ast. -func getFixes(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) ([]*importFix, error) { +func getFixes(fset *token.FileSet, f *ast.File, filename string, env *ProcessEnv) ([]*ImportFix, error) { abs, err := filepath.Abs(filename) if err != nil { return nil, err @@ -682,7 +688,7 @@ func cmdDebugStr(cmd *exec.Cmd) string { func addStdlibCandidates(pass *pass, refs references) { add := func(pkg string) { pass.addCandidate( - &importInfo{importPath: pkg}, + &ImportInfo{ImportPath: pkg}, &packageInfo{name: path.Base(pkg), exports: stdlib[pkg]}) } for left := range refs { @@ -768,7 +774,7 @@ func addExternalCandidates(pass *pass, refs references, filename string) error { // Search for imports matching potential package references. type result struct { - imp *importInfo + imp *ImportInfo pkg *packageInfo } results := make(chan result, len(refs)) @@ -802,8 +808,8 @@ func addExternalCandidates(pass *pass, refs references, filename string) error { return // No matching package. } - imp := &importInfo{ - importPath: found.importPathShort, + imp := &ImportInfo{ + ImportPath: found.importPathShort, } pkg := &packageInfo{ diff --git a/internal/imports/imports.go b/internal/imports/imports.go index a47a815f58..acf1461b03 100644 --- a/internal/imports/imports.go +++ b/internal/imports/imports.go @@ -13,6 +13,7 @@ import ( "bytes" "fmt" "go/ast" + "go/build" "go/format" "go/parser" "go/printer" @@ -42,18 +43,10 @@ type Options struct { } // Process implements golang.org/x/tools/imports.Process with explicit context in env. -func Process(filename string, src []byte, opt *Options) ([]byte, error) { - if src == nil { - b, err := ioutil.ReadFile(filename) - if err != nil { - return nil, err - } - src = b - } - - // Set the logger if the user has not provided it. - if opt.Env.Logf == nil { - opt.Env.Logf = log.Printf +func Process(filename string, src []byte, opt *Options) (formatted []byte, err error) { + src, err = initialize(filename, src, opt) + if err != nil { + return nil, err } fileSet := token.NewFileSet() @@ -67,7 +60,85 @@ func Process(filename string, src []byte, opt *Options) ([]byte, error) { return nil, err } } + return formatFile(fileSet, file, src, adjust, opt) +} +// FixImports returns a list of fixes to the imports that, when applied, +// will leave the imports in the same state as Process. +// +// Note that filename's directory influences which imports can be chosen, +// so it is important that filename be accurate. +func FixImports(filename string, src []byte, opt *Options) (fixes []*ImportFix, err error) { + src, err = initialize(filename, src, opt) + if err != nil { + return nil, err + } + + fileSet := token.NewFileSet() + file, _, err := parse(fileSet, filename, src, opt) + if err != nil { + return nil, err + } + + return getFixes(fileSet, file, filename, opt.Env) +} + +// ApplyFix will apply all of the fixes to the file and format it. +func ApplyFixes(fixes []*ImportFix, filename string, src []byte, opt *Options) (formatted []byte, err error) { + src, err = initialize(filename, src, opt) + if err != nil { + return nil, err + } + + fileSet := token.NewFileSet() + file, adjust, err := parse(fileSet, filename, src, opt) + if err != nil { + return nil, err + } + + // Apply the fixes to the file. + apply(fileSet, file, fixes) + + return formatFile(fileSet, file, src, adjust, opt) +} + +// initialize sets the values for opt and src. +// If they are provided, they are not changed. Otherwise opt is set to the +// default values and src is read from the file system. +func initialize(filename string, src []byte, opt *Options) ([]byte, error) { + // Use defaults if opt is nil. + if opt == nil { + opt = &Options{ + Env: &ProcessEnv{ + GOPATH: build.Default.GOPATH, + GOROOT: build.Default.GOROOT, + }, + AllErrors: opt.AllErrors, + Comments: opt.Comments, + FormatOnly: opt.FormatOnly, + Fragment: opt.Fragment, + TabIndent: opt.TabIndent, + TabWidth: opt.TabWidth, + } + } + + // Set the logger if the user has not provided it. + if opt.Env.Logf == nil { + opt.Env.Logf = log.Printf + } + + if src == nil { + b, err := ioutil.ReadFile(filename) + if err != nil { + return nil, err + } + src = b + } + + return src, nil +} + +func formatFile(fileSet *token.FileSet, file *ast.File, src []byte, adjust func(orig []byte, src []byte) []byte, opt *Options) ([]byte, error) { sortImports(opt.Env, fileSet, file) imps := astutil.Imports(fileSet, file) var spacesBefore []string // import paths we need spaces before @@ -95,7 +166,7 @@ func Process(filename string, src []byte, opt *Options) ([]byte, error) { printConfig := &printer.Config{Mode: printerMode, Tabwidth: opt.TabWidth} var buf bytes.Buffer - err = printConfig.Fprint(&buf, fileSet, file) + err := printConfig.Fprint(&buf, fileSet, file) if err != nil { return nil, err } diff --git a/internal/lsp/code_action.go b/internal/lsp/code_action.go index 64106369e7..4447532483 100644 --- a/internal/lsp/code_action.go +++ b/internal/lsp/code_action.go @@ -9,6 +9,7 @@ import ( "fmt" "strings" + "golang.org/x/tools/internal/imports" "golang.org/x/tools/internal/lsp/protocol" "golang.org/x/tools/internal/lsp/source" "golang.org/x/tools/internal/lsp/telemetry" @@ -47,7 +48,7 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara var codeActions []protocol.CodeAction - edits, err := organizeImports(ctx, view, spn) + edits, editsPerFix, err := organizeImports(ctx, view, spn) if err != nil { return nil, err } @@ -66,18 +67,23 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara // If we also have diagnostics for missing imports, we can associate them with quick fixes. if findImportErrors(params.Context.Diagnostics) { - // TODO(rstambler): Separate this into a set of codeActions per diagnostic, - // where each action is the addition or removal of one import. - // This can only be done when https://golang.org/issue/31493 is resolved. - codeActions = append(codeActions, protocol.CodeAction{ - Title: "Organize All Imports", // clarify that all imports will change - Kind: protocol.QuickFix, - Edit: &protocol.WorkspaceEdit{ - Changes: &map[string][]protocol.TextEdit{ - string(uri): edits, - }, - }, - }) + // Separate this into a set of codeActions per diagnostic, where + // each action is the addition, removal, or renaming of one import. + for _, importFix := range editsPerFix { + // Get the diagnostics this fix would affect. + if fixDiagnostics := importDiagnostics(importFix.fix, params.Context.Diagnostics); len(fixDiagnostics) > 0 { + codeActions = append(codeActions, protocol.CodeAction{ + Title: importFixTitle(importFix.fix), + Kind: protocol.QuickFix, + Edit: &protocol.WorkspaceEdit{ + Changes: &map[string][]protocol.TextEdit{ + string(uri): importFix.edits, + }, + }, + Diagnostics: fixDiagnostics, + }) + } + } } } @@ -97,16 +103,38 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara return codeActions, nil } -func organizeImports(ctx context.Context, view source.View, s span.Span) ([]protocol.TextEdit, error) { +type protocolImportFix struct { + fix *imports.ImportFix + edits []protocol.TextEdit +} + +func organizeImports(ctx context.Context, view source.View, s span.Span) ([]protocol.TextEdit, []*protocolImportFix, error) { f, m, rng, err := spanToRange(ctx, view, s) if err != nil { - return nil, err + return nil, nil, err } - edits, err := source.Imports(ctx, view, f, rng) + edits, editsPerFix, err := source.AllImportsFixes(ctx, view, f, rng) if err != nil { - return nil, err + return nil, nil, err } - return ToProtocolEdits(m, edits) + // Convert all source edits to protocol edits. + pEdits, err := ToProtocolEdits(m, edits) + if err != nil { + return nil, nil, err + } + + pEditsPerFix := make([]*protocolImportFix, len(editsPerFix)) + for i, fix := range editsPerFix { + pEdits, err := ToProtocolEdits(m, fix.Edits) + if err != nil { + return nil, nil, err + } + pEditsPerFix[i] = &protocolImportFix{ + fix: fix.Fix, + edits: pEdits, + } + } + return pEdits, pEditsPerFix, nil } // findImports determines if a given diagnostic represents an error that could @@ -131,6 +159,49 @@ func findImportErrors(diagnostics []protocol.Diagnostic) bool { return false } +func importFixTitle(fix *imports.ImportFix) string { + var str string + switch fix.FixType { + case imports.AddImport: + str = fmt.Sprintf("Add import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath) + case imports.DeleteImport: + str = fmt.Sprintf("Delete import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath) + case imports.SetImportName: + str = fmt.Sprintf("Rename import: %s %q", fix.StmtInfo.Name, fix.StmtInfo.ImportPath) + } + return str +} + +func importDiagnostics(fix *imports.ImportFix, diagnostics []protocol.Diagnostic) (results []protocol.Diagnostic) { + for _, diagnostic := range diagnostics { + switch { + // "undeclared name: X" may be an unresolved import. + case strings.HasPrefix(diagnostic.Message, "undeclared name: "): + ident := strings.TrimPrefix(diagnostic.Message, "undeclared name: ") + if ident == fix.IdentName { + results = append(results, diagnostic) + } + // "could not import: X" may be an invalid import. + case strings.HasPrefix(diagnostic.Message, "could not import: "): + ident := strings.TrimPrefix(diagnostic.Message, "could not import: ") + if ident == fix.IdentName { + results = append(results, diagnostic) + } + // "X imported but not used" is an unused import. + // "X imported but not used as Y" is an unused import. + case strings.Contains(diagnostic.Message, " imported but not used"): + idx := strings.Index(diagnostic.Message, " imported but not used") + importPath := diagnostic.Message[:idx] + if importPath == fmt.Sprintf("%q", fix.StmtInfo.ImportPath) { + results = append(results, diagnostic) + } + } + + } + + return results +} + func quickFixes(ctx context.Context, view source.View, gof source.GoFile) ([]protocol.CodeAction, error) { var codeActions []protocol.CodeAction diff --git a/internal/lsp/source/format.go b/internal/lsp/source/format.go index b4e24b4b34..1f100c5c8a 100644 --- a/internal/lsp/source/format.go +++ b/internal/lsp/source/format.go @@ -108,6 +108,72 @@ func Imports(ctx context.Context, view View, f GoFile, rng span.Range) ([]TextEd return computeTextEdits(ctx, f, string(formatted)), nil } +type ImportFix struct { + Fix *imports.ImportFix + Edits []TextEdit +} + +// AllImportsFixes formats f for each possible fix to the imports. +// In addition to returning the result of applying all edits, +// it returns a list of fixes that could be applied to the file, with the +// corresponding TextEdits that would be needed to apply that fix. +func AllImportsFixes(ctx context.Context, view View, f GoFile, rng span.Range) (edits []TextEdit, editsPerFix []*ImportFix, err error) { + ctx, done := trace.StartSpan(ctx, "source.AllImportsFixes") + defer done() + data, _, err := f.Handle(ctx).Read(ctx) + if err != nil { + return nil, nil, err + } + pkg := f.GetPackage(ctx) + if pkg == nil || pkg.IsIllTyped() { + return nil, nil, fmt.Errorf("no package for file %s", f.URI()) + } + if hasListErrors(pkg.GetErrors()) { + return nil, nil, fmt.Errorf("%s has list errors, not running goimports", f.URI()) + } + + options := &imports.Options{ + // Defaults. + AllErrors: true, + Comments: true, + Fragment: true, + FormatOnly: false, + TabIndent: true, + TabWidth: 8, + } + importFn := func(opts *imports.Options) error { + fixes, err := imports.FixImports(f.URI().Filename(), data, opts) + if err != nil { + return err + } + // Apply all of the import fixes to the file. + formatted, err := imports.ApplyFixes(fixes, f.URI().Filename(), data, options) + if err != nil { + return err + } + edits = computeTextEdits(ctx, f, string(formatted)) + // Add the edits for each fix to the result. + editsPerFix = make([]*ImportFix, len(fixes)) + for i, fix := range fixes { + formatted, err := imports.ApplyFixes([]*imports.ImportFix{fix}, f.URI().Filename(), data, options) + if err != nil { + return err + } + editsPerFix[i] = &ImportFix{ + Fix: fix, + Edits: computeTextEdits(ctx, f, string(formatted)), + } + } + return err + } + err = view.RunProcessEnvFunc(ctx, importFn, options) + if err != nil { + return nil, nil, err + } + + return edits, editsPerFix, nil +} + func hasParseErrors(errors []packages.Error) bool { for _, err := range errors { if err.Kind == packages.ParseError {