diff --git a/go/analysis/passes/assign/assign.go b/go/analysis/passes/assign/assign.go index 4dff2908c3..a9aefe52fe 100644 --- a/go/analysis/passes/assign/assign.go +++ b/go/analysis/passes/assign/assign.go @@ -9,6 +9,7 @@ package assign // methods that are on T instead of *T. import ( + "fmt" "go/ast" "go/token" "reflect" @@ -59,7 +60,12 @@ func run(pass *analysis.Pass) (interface{}, error) { le := analysisutil.Format(pass.Fset, lhs) re := analysisutil.Format(pass.Fset, rhs) if le == re { - pass.Reportf(stmt.Pos(), "self-assignment of %s to %s", re, le) + pass.Report(analysis.Diagnostic{ + Pos: stmt.Pos(), Message: fmt.Sprintf("self-assignment of %s to %s", re, le), + SuggestedFixes: []analysis.SuggestedFix{ + {Message: "Remove", TextEdits: []analysis.TextEdit{{stmt.Pos(), stmt.End(), []byte{}}}}, + }, + }) } } }) diff --git a/internal/lsp/cache/pkg.go b/internal/lsp/cache/pkg.go index c50421d11a..64264afb07 100644 --- a/internal/lsp/cache/pkg.go +++ b/internal/lsp/cache/pkg.go @@ -39,7 +39,7 @@ type pkg struct { analyses map[*analysis.Analyzer]*analysisEntry diagMu sync.Mutex - diagnostics []source.Diagnostic + diagnostics map[*analysis.Analyzer][]source.Diagnostic } // packageID is a type that abstracts a package ID. @@ -180,14 +180,22 @@ func (pkg *pkg) IsIllTyped() bool { return pkg.types == nil || pkg.typesInfo == nil || pkg.typesSizes == nil } -func (pkg *pkg) SetDiagnostics(diags []source.Diagnostic) { +func (pkg *pkg) SetDiagnostics(a *analysis.Analyzer, diags []source.Diagnostic) { pkg.diagMu.Lock() defer pkg.diagMu.Unlock() - pkg.diagnostics = diags + if pkg.diagnostics == nil { + pkg.diagnostics = make(map[*analysis.Analyzer][]source.Diagnostic) + } + pkg.diagnostics[a] = diags } func (pkg *pkg) GetDiagnostics() []source.Diagnostic { pkg.diagMu.Lock() defer pkg.diagMu.Unlock() - return pkg.diagnostics + + var diags []source.Diagnostic + for _, d := range pkg.diagnostics { + diags = append(diags, d...) + } + return diags } diff --git a/internal/lsp/cmd/cmd_test.go b/internal/lsp/cmd/cmd_test.go index 1a79b37eef..b9e2fa6bfb 100644 --- a/internal/lsp/cmd/cmd_test.go +++ b/internal/lsp/cmd/cmd_test.go @@ -86,6 +86,10 @@ func (r *runner) Import(t *testing.T, data tests.Imports) { //TODO: add command line imports tests when it works } +func (r *runner) SuggestedFix(t *testing.T, data tests.SuggestedFixes) { + //TODO: add suggested fix tests when it works +} + func captureStdOut(t testing.TB, f func()) string { r, out, err := os.Pipe() if err != nil { diff --git a/internal/lsp/code_action.go b/internal/lsp/code_action.go index 5e00666be0..39eec9d637 100644 --- a/internal/lsp/code_action.go +++ b/internal/lsp/code_action.go @@ -77,18 +77,15 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara // If the user wants to see quickfixes. if wanted[protocol.QuickFix] { // First, add the quick fixes reported by go/analysis. - // TODO: Enable this when this actually works. For now, it's needless work. - if s.session.Options().SuggestedFixes { - gof, ok := f.(source.GoFile) - if !ok { - return nil, fmt.Errorf("%s is not a Go file", f.URI()) - } - qf, err := quickFixes(ctx, view, gof) - if err != nil { - log.Error(ctx, "quick fixes failed", err, telemetry.File.Of(uri)) - } - codeActions = append(codeActions, qf...) + gof, ok := f.(source.GoFile) + if !ok { + return nil, fmt.Errorf("%s is not a Go file", f.URI()) } + qf, err := quickFixes(ctx, view, gof) + if err != nil { + log.Error(ctx, "quick fixes failed", err, telemetry.File.Of(uri)) + } + codeActions = append(codeActions, qf...) // If we also have diagnostics for missing imports, we can associate them with quick fixes. if findImportErrors(params.Context.Diagnostics) { @@ -204,7 +201,7 @@ func quickFixes(ctx context.Context, view source.View, gof source.GoFile) ([]pro // TODO: This is technically racy because the diagnostics provided by the code action // may not be the same as the ones that gopls is aware of. // We need to figure out some way to solve this problem. - pkg, err := gof.GetPackage(ctx) + pkg, err := gof.GetCachedPackage(ctx) if err != nil { return nil, err } diff --git a/internal/lsp/general.go b/internal/lsp/general.go index be68b28dd9..0f49faa1b7 100644 --- a/internal/lsp/general.go +++ b/internal/lsp/general.go @@ -302,9 +302,6 @@ func (s *Server) processConfig(ctx context.Context, view source.View, options *s } } - // Check if the user wants to see suggested fixes from go/analysis. - setBool(&options.SuggestedFixes, c, "wantSuggestedFixes") - // Check if the user has explicitly disabled any analyses. if disabledAnalyses, ok := c["experimentalDisabledAnalyses"].([]interface{}); ok { options.DisabledAnalyses = make(map[string]struct{}) diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 2079e23b86..f1803329a0 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -491,6 +491,58 @@ func (r *runner) Import(t *testing.T, data tests.Imports) { } } +func (r *runner) SuggestedFix(t *testing.T, data tests.SuggestedFixes) { + for _, spn := range data { + uri := spn.URI() + filename := uri.Filename() + v := r.server.session.ViewOf(uri) + fixed := string(r.data.Golden("suggestedfix", filename, func() ([]byte, error) { + cmd := exec.Command("suggestedfix", filename) // TODO(matloob): what do we do here? + out, _ := cmd.Output() // ignore error, sometimes we have intentionally ungofmt-able files + return out, nil + })) + f, err := getGoFile(r.ctx, v, uri) + if err != nil { + t.Fatal(err) + } + results, err := source.Diagnostics(r.ctx, v, f, nil) + if err != nil { + t.Fatal(err) + } + _ = results + actions, err := r.server.CodeAction(r.ctx, &protocol.CodeActionParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: protocol.NewURI(uri), + }, + Context: protocol.CodeActionContext{Only: []protocol.CodeActionKind{protocol.QuickFix}}, + }) + if err != nil { + if fixed != "" { + t.Error(err) + } + continue + } + m, err := r.mapper(f.URI()) + if err != nil { + t.Fatal(err) + } + var edits []protocol.TextEdit + for _, a := range actions { + if a.Title == "Remove" { + edits = (*a.Edit.Changes)[string(uri)] + } + } + sedits, err := source.FromProtocolEdits(m, edits) + if err != nil { + t.Error(err) + } + got := diff.ApplyEdits(string(m.Content), sedits) + if fixed != got { + t.Errorf("suggested fixes failed for %s, expected:\n%v\ngot:\n%v", filename, fixed, got) + } + } +} + func (r *runner) Definition(t *testing.T, data tests.Definitions) { for _, d := range data { sm, err := r.mapper(d.Src.URI()) diff --git a/internal/lsp/source/diagnostics.go b/internal/lsp/source/diagnostics.go index 4d0a8d56ac..70108f3901 100644 --- a/internal/lsp/source/diagnostics.go +++ b/internal/lsp/source/diagnostics.go @@ -386,7 +386,7 @@ func runAnalyses(ctx context.Context, view View, cph CheckPackageHandle, disable if err != nil { return err } - pkg.SetDiagnostics(sdiags) + pkg.SetDiagnostics(r.Analyzer, sdiags) } return nil } diff --git a/internal/lsp/source/options.go b/internal/lsp/source/options.go index ae9d16db84..6770902031 100644 --- a/internal/lsp/source/options.go +++ b/internal/lsp/source/options.go @@ -33,7 +33,6 @@ type SessionOptions struct { BuildFlags []string UsePlaceholders bool HoverKind HoverKind - SuggestedFixes bool DisabledAnalyses map[string]struct{} WatchFileChanges bool diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index 14cc98b994..6271706892 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -494,6 +494,9 @@ func (r *runner) Import(t *testing.T, data tests.Imports) { } } +func (r *runner) SuggestedFix(t *testing.T, data tests.SuggestedFixes) { +} + func (r *runner) Definition(t *testing.T, data tests.Definitions) { ctx := r.ctx for _, d := range data { diff --git a/internal/lsp/source/view.go b/internal/lsp/source/view.go index 0f1a7b494c..d7b8b8295f 100644 --- a/internal/lsp/source/view.go +++ b/internal/lsp/source/view.go @@ -311,7 +311,7 @@ type Package interface { GetTypesSizes() types.Sizes IsIllTyped() bool GetDiagnostics() []Diagnostic - SetDiagnostics(diags []Diagnostic) + SetDiagnostics(a *analysis.Analyzer, diag []Diagnostic) // GetImport returns the CheckPackageHandle for a package imported by this package. GetImport(ctx context.Context, pkgPath string) (Package, error) diff --git a/internal/lsp/testdata/suggestedfix/has_suggested_fix.go b/internal/lsp/testdata/suggestedfix/has_suggested_fix.go new file mode 100644 index 0000000000..9ade674108 --- /dev/null +++ b/internal/lsp/testdata/suggestedfix/has_suggested_fix.go @@ -0,0 +1,11 @@ +package suggestedfix + +import ( + "log" +) + +func goodbye() { + s := "hiiiiiii" + s = s //@suggestedfix("s = s") + log.Printf(s) +} diff --git a/internal/lsp/testdata/suggestedfix/has_suggested_fix.go.golden b/internal/lsp/testdata/suggestedfix/has_suggested_fix.go.golden new file mode 100644 index 0000000000..10ec450d33 --- /dev/null +++ b/internal/lsp/testdata/suggestedfix/has_suggested_fix.go.golden @@ -0,0 +1,13 @@ +-- suggestedfix -- +package suggestedfix + +import ( + "log" +) + +func goodbye() { + s := "hiiiiiii" + //@suggestedfix("s = s") + log.Printf(s) +} + diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index ca33dcaadb..d81f42e94c 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -34,6 +34,7 @@ const ( ExpectedDiagnosticsCount = 21 ExpectedFormatCount = 6 ExpectedImportCount = 2 + ExpectedSuggestedFixCount = 1 ExpectedDefinitionsCount = 39 ExpectedTypeDefinitionsCount = 2 ExpectedFoldingRangesCount = 2 @@ -62,6 +63,7 @@ type CompletionSnippets map[span.Span]CompletionSnippet type FoldingRanges []span.Span type Formats []span.Span type Imports []span.Span +type SuggestedFixes []span.Span type Definitions map[span.Span]Definition type Highlights map[string][]span.Span type References map[span.Span][]span.Span @@ -82,6 +84,7 @@ type Data struct { FoldingRanges FoldingRanges Formats Formats Imports Imports + SuggestedFixes SuggestedFixes Definitions Definitions Highlights Highlights References References @@ -104,6 +107,7 @@ type Tests interface { FoldingRange(*testing.T, FoldingRanges) Format(*testing.T, Formats) Import(*testing.T, Imports) + SuggestedFix(*testing.T, SuggestedFixes) Definition(*testing.T, Definitions) Highlight(*testing.T, Highlights) Reference(*testing.T, References) @@ -228,23 +232,24 @@ func Load(t testing.TB, exporter packagestest.Exporter, dir string) *Data { // Collect any data that needs to be used by subsequent tests. if err := data.Exported.Expect(map[string]interface{}{ - "diag": data.collectDiagnostics, - "item": data.collectCompletionItems, - "complete": data.collectCompletions, - "fold": data.collectFoldingRanges, - "format": data.collectFormats, - "import": data.collectImports, - "godef": data.collectDefinitions, - "typdef": data.collectTypeDefinitions, - "hover": data.collectHoverDefinitions, - "highlight": data.collectHighlights, - "refs": data.collectReferences, - "rename": data.collectRenames, - "prepare": data.collectPrepareRenames, - "symbol": data.collectSymbols, - "signature": data.collectSignatures, - "snippet": data.collectCompletionSnippets, - "link": data.collectLinks, + "diag": data.collectDiagnostics, + "item": data.collectCompletionItems, + "complete": data.collectCompletions, + "fold": data.collectFoldingRanges, + "format": data.collectFormats, + "import": data.collectImports, + "godef": data.collectDefinitions, + "typdef": data.collectTypeDefinitions, + "hover": data.collectHoverDefinitions, + "highlight": data.collectHighlights, + "refs": data.collectReferences, + "rename": data.collectRenames, + "prepare": data.collectPrepareRenames, + "symbol": data.collectSymbols, + "signature": data.collectSignatures, + "snippet": data.collectCompletionSnippets, + "link": data.collectLinks, + "suggestedfix": data.collectSuggestedFixes, }); err != nil { t.Fatal(err) } @@ -313,6 +318,14 @@ func Run(t *testing.T, tests Tests, data *Data) { tests.Import(t, data.Imports) }) + t.Run("SuggestedFix", func(t *testing.T) { + t.Helper() + if len(data.SuggestedFixes) != ExpectedSuggestedFixCount { + t.Errorf("got %v suggested fixes expected %v", len(data.SuggestedFixes), ExpectedSuggestedFixCount) + } + tests.SuggestedFix(t, data.SuggestedFixes) + }) + t.Run("Definition", func(t *testing.T) { t.Helper() if len(data.Definitions) != ExpectedDefinitionsCount { @@ -587,6 +600,10 @@ func (data *Data) collectImports(spn span.Span) { data.Imports = append(data.Imports, spn) } +func (data *Data) collectSuggestedFixes(spn span.Span) { + data.SuggestedFixes = append(data.SuggestedFixes, spn) +} + func (data *Data) collectDefinitions(src, target span.Span) { data.Definitions[src] = Definition{ Src: src,