diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 9de3887a2e9..418da42298f 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -5,8 +5,10 @@ package lsp import ( + "bytes" "context" "go/token" + "os/exec" "path/filepath" "reflect" "sort" @@ -47,6 +49,7 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { // collect results for certain tests expectedDiagnostics := make(map[string][]protocol.Diagnostic) expectedCompletions := make(map[token.Position]*protocol.CompletionItem) + expectedFormat := make(map[string]string) s := &server{ view: source.NewView(), @@ -78,54 +81,13 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { // Collect any data that needs to be used by subsequent tests. if err := exported.Expect(map[string]interface{}{ "diag": func(pos token.Position, msg string) { - line := float64(pos.Line - 1) - col := float64(pos.Column - 1) - want := protocol.Diagnostic{ - Range: protocol.Range{ - Start: protocol.Position{ - Line: line, - Character: col, - }, - End: protocol.Position{ - Line: line, - Character: col, - }, - }, - Severity: protocol.SeverityError, - Source: "LSP", - Message: msg, - } - if _, ok := expectedDiagnostics[pos.Filename]; ok { - expectedDiagnostics[pos.Filename] = append(expectedDiagnostics[pos.Filename], want) - } else { - t.Errorf("unexpected filename: %v", pos.Filename) - } + collectDiagnostics(t, expectedDiagnostics, pos, msg) }, "item": func(pos token.Position, label, detail, kind string) { - var k protocol.CompletionItemKind - switch kind { - case "struct": - k = protocol.StructCompletion - case "func": - k = protocol.FunctionCompletion - case "var": - k = protocol.VariableCompletion - case "type": - k = protocol.TypeParameterCompletion - case "field": - k = protocol.FieldCompletion - case "interface": - k = protocol.InterfaceCompletion - case "const": - k = protocol.ConstantCompletion - case "method": - k = protocol.MethodCompletion - } - expectedCompletions[pos] = &protocol.CompletionItem{ - Label: label, - Detail: detail, - Kind: float64(k), - } + collectCompletionItems(expectedCompletions, pos, label, detail, kind) + }, + "format": func(pos token.Position) { + collectFormat(expectedFormat, pos) }, }); err != nil { t.Fatal(err) @@ -145,26 +107,9 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { t.Fatal(err) } testDiagnostics(t, s.view, pkgs, expectedDiagnostics) -} -func testDiagnostics(t *testing.T, v *source.View, pkgs []*packages.Package, wants map[string][]protocol.Diagnostic) { - for _, pkg := range pkgs { - for _, filename := range pkg.GoFiles { - f := v.GetFile(source.ToURI(filename)) - diagnostics, err := source.Diagnostics(context.Background(), v, f) - if err != nil { - t.Fatal(err) - } - got := toProtocolDiagnostics(v, diagnostics[filename]) - sort.Slice(got, func(i int, j int) bool { - return got[i].Range.Start.Line < got[j].Range.Start.Line - }) - want := wants[filename] - if equal := reflect.DeepEqual(want, got); !equal { - t.Errorf("diagnostics failed for %s: (expected: %v), (got: %v)", filepath.Base(filename), want, got) - } - } - } + // test format + testFormat(t, s, expectedFormat) } func testCompletion(t *testing.T, exported *packagestest.Exported, s *server, wants map[token.Position]*protocol.CompletionItem) { @@ -197,3 +142,103 @@ func testCompletion(t *testing.T, exported *packagestest.Exported, s *server, wa t.Fatal(err) } } + +func collectCompletionItems(expectedCompletions map[token.Position]*protocol.CompletionItem, pos token.Position, label, detail, kind string) { + var k protocol.CompletionItemKind + switch kind { + case "struct": + k = protocol.StructCompletion + case "func": + k = protocol.FunctionCompletion + case "var": + k = protocol.VariableCompletion + case "type": + k = protocol.TypeParameterCompletion + case "field": + k = protocol.FieldCompletion + case "interface": + k = protocol.InterfaceCompletion + case "const": + k = protocol.ConstantCompletion + case "method": + k = protocol.MethodCompletion + } + expectedCompletions[pos] = &protocol.CompletionItem{ + Label: label, + Detail: detail, + Kind: float64(k), + } +} + +func testDiagnostics(t *testing.T, v *source.View, pkgs []*packages.Package, wants map[string][]protocol.Diagnostic) { + for _, pkg := range pkgs { + for _, filename := range pkg.GoFiles { + f := v.GetFile(source.ToURI(filename)) + diagnostics, err := source.Diagnostics(context.Background(), v, f) + if err != nil { + t.Fatal(err) + } + got := toProtocolDiagnostics(v, diagnostics[filename]) + sort.Slice(got, func(i int, j int) bool { + return got[i].Range.Start.Line < got[j].Range.Start.Line + }) + want := wants[filename] + if equal := reflect.DeepEqual(want, got); !equal { + t.Errorf("diagnostics failed for %s: (expected: %v), (got: %v)", filepath.Base(filename), want, got) + } + } + } +} + +func collectDiagnostics(t *testing.T, expectedDiagnostics map[string][]protocol.Diagnostic, pos token.Position, msg string) { + line := float64(pos.Line - 1) + col := float64(pos.Column - 1) + want := protocol.Diagnostic{ + Range: protocol.Range{ + Start: protocol.Position{ + Line: line, + Character: col, + }, + End: protocol.Position{ + Line: line, + Character: col, + }, + }, + Severity: protocol.SeverityError, + Source: "LSP", + Message: msg, + } + if _, ok := expectedDiagnostics[pos.Filename]; ok { + expectedDiagnostics[pos.Filename] = append(expectedDiagnostics[pos.Filename], want) + } else { + t.Errorf("unexpected filename: %v", pos.Filename) + } +} + +func testFormat(t *testing.T, s *server, expectedFormat map[string]string) { + for filename, gofmted := range expectedFormat { + edits, err := s.Formatting(context.Background(), &protocol.DocumentFormattingParams{ + TextDocument: protocol.TextDocumentIdentifier{ + URI: protocol.DocumentURI(source.ToURI(filename)), + }, + }) + if err != nil || len(edits) == 0 { + if gofmted != "" { + t.Error(err) + } + return + } + edit := edits[0] + if edit.NewText != gofmted { + t.Errorf("formatting failed: (got: %s), (expected: %s)", edit.NewText, gofmted) + } + } +} + +func collectFormat(expectedFormat map[string]string, pos token.Position) { + cmd := exec.Command("gofmt", pos.Filename) + stdout := bytes.NewBuffer(nil) + cmd.Stdout = stdout + cmd.Run() // ignore error, sometimes we have intentionally ungofmt-able files + expectedFormat[pos.Filename] = stdout.String() +} diff --git a/internal/lsp/server.go b/internal/lsp/server.go index 35d91915ea6..00c9075675d 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -43,11 +43,15 @@ func (s *server) Initialize(ctx context.Context, params *protocol.InitializePara s.initialized = true return &protocol.InitializeResult{ Capabilities: protocol.ServerCapabilities{ - CompletionProvider: protocol.CompletionOptions{}, + CompletionProvider: protocol.CompletionOptions{ + TriggerCharacters: []string{"."}, + }, DefinitionProvider: true, DocumentFormattingProvider: true, DocumentRangeFormattingProvider: true, - SignatureHelpProvider: protocol.SignatureHelpOptions{}, + SignatureHelpProvider: protocol.SignatureHelpOptions{ + TriggerCharacters: []string{"("}, + }, TextDocumentSync: protocol.TextDocumentSyncOptions{ Change: float64(protocol.Full), // full contents of file sent on each update OpenClose: true, diff --git a/internal/lsp/source/format.go b/internal/lsp/source/format.go index 5e7c1e7f96f..c7b46d236de 100644 --- a/internal/lsp/source/format.go +++ b/internal/lsp/source/format.go @@ -7,7 +7,11 @@ package source import ( "bytes" "context" + "fmt" + "go/ast" "go/format" + + "golang.org/x/tools/go/ast/astutil" ) // Format formats a document with a given range. @@ -16,17 +20,33 @@ func Format(ctx context.Context, f *File, rng Range) ([]TextEdit, error) { if err != nil { return nil, err } - - // TODO(rstambler): use astutil.PathEnclosingInterval to - // find the largest ast.Node n contained within start:end, and format the - // region n.Pos-n.End instead. - + path, exact := astutil.PathEnclosingInterval(fAST, rng.Start, rng.End) + if !exact || len(path) == 0 { + return nil, fmt.Errorf("no exact AST node matching the specified range") + } + node := path[0] + // format.Node can fail when the AST contains a bad expression or + // statement. For now, we preemptively check for one. + // TODO(rstambler): This should really return an error from format.Node. + var isBad bool + ast.Inspect(node, func(n ast.Node) bool { + switch n.(type) { + case *ast.BadDecl, *ast.BadExpr, *ast.BadStmt: + isBad = true + return false + default: + return true + } + }) + if isBad { + return nil, fmt.Errorf("unable to format file due to a badly formatted AST") + } // format.Node changes slightly from one release to another, so the version // of Go used to build the LSP server will determine how it formats code. // This should be acceptable for all users, who likely be prompted to rebuild // the LSP server on each Go release. buf := &bytes.Buffer{} - if err := format.Node(buf, f.view.Config.Fset, fAST); err != nil { + if err := format.Node(buf, f.view.Config.Fset, node); err != nil { return nil, err } // TODO(rstambler): Compute text edits instead of replacing whole file. diff --git a/internal/lsp/testdata/format/bad_format.go b/internal/lsp/testdata/format/bad_format.go new file mode 100644 index 00000000000..77f0861cada --- /dev/null +++ b/internal/lsp/testdata/format/bad_format.go @@ -0,0 +1,21 @@ +package format //@format("package") + +import ( + "fmt" + "runtime" + + "log" +) + +func hello() { + + var x int //@diag("x", "x declared but not used") +} + +func hi() { + + runtime.GOROOT() + fmt.Printf("") + + log.Printf("") +} diff --git a/internal/lsp/testdata/format/good_format.go b/internal/lsp/testdata/format/good_format.go new file mode 100644 index 00000000000..01cb1610ce8 --- /dev/null +++ b/internal/lsp/testdata/format/good_format.go @@ -0,0 +1,9 @@ +package format //@format("package") + +import ( + "log" +) + +func goodbye() { + log.Printf("byeeeee") +} diff --git a/internal/lsp/testdata/noparse_format/noparse_format.go.in b/internal/lsp/testdata/noparse_format/noparse_format.go.in new file mode 100644 index 00000000000..eb2ad2e6e4f --- /dev/null +++ b/internal/lsp/testdata/noparse_format/noparse_format.go.in @@ -0,0 +1,9 @@ +// +build go1.11 + +package noparse_format //@format("package") + +func what() { + var b int + if { hi() //@diag("{", "missing condition in if statement") + } +} \ No newline at end of file