From 2538eef75904eff384a2551359968e40c207d9d2 Mon Sep 17 00:00:00 2001 From: Zac Bergquist Date: Sat, 30 Mar 2019 14:18:22 -0600 Subject: [PATCH] internal/lsp: enhance document symbols support Make methods children of their receiver's type symbol. Add struct fields as children of the struct's type symbol. Also identify numeric, boolean, and string types. Updates golang/go#30915 Fixes golang/go#31202 Change-Id: I33c4ea7b953e981ea1e858505b77c7a3ba6ee399 Reviewed-on: https://go-review.googlesource.com/c/tools/+/170198 Run-TryBot: Rebecca Stambler Reviewed-by: Rebecca Stambler --- internal/lsp/lsp_test.go | 83 +++++++++++++----- internal/lsp/source/symbols.go | 116 ++++++++++++++++++++++++-- internal/lsp/symbols.go | 8 ++ internal/lsp/testdata/symbols/main.go | 38 ++++++--- 4 files changed, 204 insertions(+), 41 deletions(-) diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index a49060db109..b42d1d8efe5 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -93,7 +93,10 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { expectedDefinitions := make(definitions) expectedTypeDefinitions := make(definitions) expectedHighlights := make(highlights) - expectedSymbols := make(symbols) + expectedSymbols := &symbols{ + m: make(map[span.URI][]protocol.DocumentSymbol), + children: make(map[string][]protocol.DocumentSymbol), + } // Collect any data that needs to be used by subsequent tests. if err := exported.Expect(map[string]interface{}{ @@ -180,8 +183,8 @@ func testLSP(t *testing.T, exporter packagestest.Exporter) { t.Run("Symbols", func(t *testing.T) { t.Helper() if goVersion111 { // TODO(rstambler): Remove this when we no longer support Go 1.10. - if len(expectedSymbols) != expectedSymbolsCount { - t.Errorf("got %v symbols expected %v", len(expectedSymbols), expectedSymbolsCount) + if len(expectedSymbols.m) != expectedSymbolsCount { + t.Errorf("got %v symbols expected %v", len(expectedSymbols.m), expectedSymbolsCount) } } expectedSymbols.test(t, s) @@ -194,7 +197,10 @@ type completions map[token.Position][]token.Pos type formats map[string]string type definitions map[protocol.Location]protocol.Location type highlights map[string][]protocol.Location -type symbols map[span.URI][]protocol.DocumentSymbol +type symbols struct { + m map[span.URI][]protocol.DocumentSymbol + children map[string][]protocol.DocumentSymbol +} func (d diagnostics) test(t *testing.T, v source.View) int { count := 0 @@ -522,7 +528,7 @@ func (h highlights) test(t *testing.T, s *Server) { } } -func (s symbols) collect(e *packagestest.Exported, fset *token.FileSet, name string, rng span.Range, kind int64) { +func (s symbols) collect(e *packagestest.Exported, fset *token.FileSet, name string, rng span.Range, kind int64, parentName string) { f := fset.File(rng.Start) if f == nil { return @@ -544,15 +550,20 @@ func (s symbols) collect(e *packagestest.Exported, fset *token.FileSet, name str return } - s[spn.URI()] = append(s[spn.URI()], protocol.DocumentSymbol{ + sym := protocol.DocumentSymbol{ Name: name, Kind: protocol.SymbolKind(kind), SelectionRange: prng, - }) + } + if parentName == "" { + s.m[spn.URI()] = append(s.m[spn.URI()], sym) + } else { + s.children[parentName] = append(s.children[parentName], sym) + } } func (s symbols) test(t *testing.T, server *Server) { - for uri, expectedSymbols := range s { + for uri, expectedSymbols := range s.m { params := &protocol.DocumentSymbolParams{ TextDocument: protocol.TextDocumentIdentifier{ URI: string(uri), @@ -564,28 +575,58 @@ func (s symbols) test(t *testing.T, server *Server) { } if len(symbols) != len(expectedSymbols) { - t.Errorf("want %d symbols in %v, got %d", len(expectedSymbols), uri, len(symbols)) + t.Errorf("want %d top-level symbols in %v, got %d", len(expectedSymbols), uri, len(symbols)) continue } sort.Slice(symbols, func(i, j int) bool { return symbols[i].Name < symbols[j].Name }) sort.Slice(expectedSymbols, func(i, j int) bool { return expectedSymbols[i].Name < expectedSymbols[j].Name }) - for i, w := range expectedSymbols { - g := symbols[i] - if w.Name != g.Name { - t.Errorf("%s: want symbol %q, got %q", uri, w.Name, g.Name) - continue - } - if w.Kind != g.Kind { - t.Errorf("%s: want kind %v for %s, got %v", uri, w.Kind, w.Name, g.Kind) - } - if w.SelectionRange != g.SelectionRange { - t.Errorf("%s: want selection range %v for %s, got %v", uri, w.SelectionRange, w.Name, g.SelectionRange) - } + for i := range expectedSymbols { + children := s.children[expectedSymbols[i].Name] + sort.Slice(children, func(i, j int) bool { return children[i].Name < children[j].Name }) + expectedSymbols[i].Children = children + } + if diff := diffSymbols(uri, expectedSymbols, symbols); diff != "" { + t.Error(diff) } } } +func diffSymbols(uri span.URI, want, got []protocol.DocumentSymbol) string { + if len(got) != len(want) { + goto Failed + } + for i, w := range want { + g := got[i] + if w.Name != g.Name { + goto Failed + } + if w.Kind != g.Kind { + goto Failed + } + if w.SelectionRange != g.SelectionRange { + goto Failed + } + sort.Slice(g.Children, func(i, j int) bool { return g.Children[i].Name < g.Children[j].Name }) + if msg := diffSymbols(uri, w.Children, g.Children); msg != "" { + return fmt.Sprintf("children of %s: %s", w.Name, msg) + } + } + return "" + +Failed: + msg := &bytes.Buffer{} + fmt.Fprintf(msg, "document symbols failed for %s:\nexpected:\n", uri) + for _, s := range want { + fmt.Fprintf(msg, " %v %v %v\n", s.Name, s.Kind, s.SelectionRange) + } + fmt.Fprintf(msg, "got:\n") + for _, s := range got { + fmt.Fprintf(msg, " %v %v %v\n", s.Name, s.Kind, s.SelectionRange) + } + return msg.String() +} + func testLocation(e *packagestest.Exported, fset *token.FileSet, rng packagestest.Range) (span.Span, *protocol.ColumnMapper) { spn, err := span.NewRange(fset, rng.Start, rng.End).Span() if err != nil { diff --git a/internal/lsp/source/symbols.go b/internal/lsp/source/symbols.go index f95d3f384e7..4b4d2b8b1d9 100644 --- a/internal/lsp/source/symbols.go +++ b/internal/lsp/source/symbols.go @@ -6,6 +6,7 @@ package source import ( "context" + "errors" "fmt" "go/ast" "go/token" @@ -24,6 +25,10 @@ const ( FunctionSymbol MethodSymbol InterfaceSymbol + NumberSymbol + StringSymbol + BooleanSymbol + FieldSymbol ) type Symbol struct { @@ -42,19 +47,30 @@ func DocumentSymbols(ctx context.Context, f File) []Symbol { info := pkg.GetTypesInfo() q := qualifier(file, pkg.GetTypes(), info) + methodsToReceiver := make(map[types.Type][]Symbol) + symbolsToReceiver := make(map[types.Type]int) var symbols []Symbol for _, decl := range file.Decls { switch decl := decl.(type) { case *ast.FuncDecl: if obj := info.ObjectOf(decl.Name); obj != nil { - symbols = append(symbols, funcSymbol(decl, obj, fset, q)) + if fs := funcSymbol(decl, obj, fset, q); fs.Kind == MethodSymbol { + // Store methods separately, as we want them to appear as children + // of the corresponding type (which we may not have seen yet). + rtype := obj.Type().(*types.Signature).Recv().Type() + methodsToReceiver[rtype] = append(methodsToReceiver[rtype], fs) + } else { + symbols = append(symbols, fs) + } } case *ast.GenDecl: for _, spec := range decl.Specs { switch spec := spec.(type) { case *ast.TypeSpec: if obj := info.ObjectOf(spec.Name); obj != nil { - symbols = append(symbols, typeSymbol(spec, obj, fset, q)) + ts := typeSymbol(spec, obj, fset, q) + symbols = append(symbols, ts) + symbolsToReceiver[obj.Type()] = len(symbols) - 1 } case *ast.ValueSpec: for _, name := range spec.Names { @@ -66,6 +82,21 @@ func DocumentSymbols(ctx context.Context, f File) []Symbol { } } } + + // Attempt to associate methods to the corresponding type symbol. + for typ, methods := range methodsToReceiver { + if ptr, ok := typ.(*types.Pointer); ok { + typ = ptr.Elem() + } + + if i, ok := symbolsToReceiver[typ]; ok { + symbols[i].Children = append(symbols[i].Children, methods...) + } else { + // The type definition for the receiver of these methods was not in the document. + symbols = append(symbols, methods...) + } + } + return symbols } @@ -102,24 +133,88 @@ func funcSymbol(decl *ast.FuncDecl, obj types.Object, fset *token.FileSet, q typ return s } -func typeSymbol(spec *ast.TypeSpec, obj types.Object, fset *token.FileSet, q types.Qualifier) Symbol { - s := Symbol{ - Name: obj.Name(), - Kind: StructSymbol, - } - if types.IsInterface(obj.Type()) { +func setKind(s *Symbol, typ types.Type, q types.Qualifier) { + switch typ := typ.Underlying().(type) { + case *types.Interface: s.Kind = InterfaceSymbol + case *types.Struct: + s.Kind = StructSymbol + case *types.Signature: + s.Kind = FunctionSymbol + if typ.Recv() != nil { + s.Kind = MethodSymbol + } + case *types.Named: + setKind(s, typ.Underlying(), q) + case *types.Basic: + i := typ.Info() + switch { + case i&types.IsNumeric != 0: + s.Kind = NumberSymbol + case i&types.IsBoolean != 0: + s.Kind = BooleanSymbol + case i&types.IsString != 0: + s.Kind = StringSymbol + } + default: + s.Kind = VariableSymbol } +} + +func typeSymbol(spec *ast.TypeSpec, obj types.Object, fset *token.FileSet, q types.Qualifier) Symbol { + s := Symbol{Name: obj.Name()} + s.Detail, _ = formatType(obj.Type(), q) + setKind(&s, obj.Type(), q) + if span, err := nodeSpan(spec, fset); err == nil { s.Span = span } if span, err := nodeSpan(spec.Name, fset); err == nil { s.SelectionSpan = span } - s.Detail, _ = formatType(obj.Type(), q) + + if t, ok := obj.Type().Underlying().(*types.Struct); ok { + st := spec.Type.(*ast.StructType) + for i := 0; i < t.NumFields(); i++ { + f := t.Field(i) + child := Symbol{Name: f.Name(), Kind: FieldSymbol} + child.Detail, _ = formatType(f.Type(), q) + + spanNode, selectionNode := nodesForStructField(i, st) + if span, err := nodeSpan(spanNode, fset); err == nil { + child.Span = span + } + if span, err := nodeSpan(selectionNode, fset); err == nil { + child.SelectionSpan = span + } + + s.Children = append(s.Children, child) + } + } + return s } +func nodesForStructField(i int, st *ast.StructType) (span, selection ast.Node) { + j := 0 + for _, field := range st.Fields.List { + if len(field.Names) == 0 { + if i == j { + return field, field.Type + } + j++ + continue + } + for _, name := range field.Names { + if i == j { + return field, name + } + j++ + } + } + return nil, nil +} + func varSymbol(decl ast.Node, name *ast.Ident, obj types.Object, fset *token.FileSet, q types.Qualifier) Symbol { s := Symbol{ Name: obj.Name(), @@ -139,6 +234,9 @@ func varSymbol(decl ast.Node, name *ast.Ident, obj types.Object, fset *token.Fil } func nodeSpan(n ast.Node, fset *token.FileSet) (span.Span, error) { + if n == nil { + return span.Span{}, errors.New("no span for nil node") + } r := span.NewRange(fset, n.Pos(), n.End()) return r.Span() } diff --git a/internal/lsp/symbols.go b/internal/lsp/symbols.go index ae15e08c650..6ac09e0cf84 100644 --- a/internal/lsp/symbols.go +++ b/internal/lsp/symbols.go @@ -45,6 +45,14 @@ func toProtocolSymbolKind(kind source.SymbolKind) protocol.SymbolKind { return protocol.Method case source.InterfaceSymbol: return protocol.Interface + case source.NumberSymbol: + return protocol.Number + case source.StringSymbol: + return protocol.String + case source.BooleanSymbol: + return protocol.Boolean + case source.FieldSymbol: + return protocol.Field default: return 0 } diff --git a/internal/lsp/testdata/symbols/main.go b/internal/lsp/testdata/symbols/main.go index df11cb36028..93ace00cdf8 100644 --- a/internal/lsp/testdata/symbols/main.go +++ b/internal/lsp/testdata/symbols/main.go @@ -1,27 +1,43 @@ package main -var x = 42 //@symbol("x", "x", 13) +import "io" -const y = 43 //@symbol("y", "y", 14) +var x = 42 //@symbol("x", "x", 13, "") -type Foo struct { //@symbol("Foo", "Foo", 23) - Quux - Bar int - baz string +const y = 43 //@symbol("y", "y", 14, "") + +type Number int //@symbol("Number", "Number", 16, "") + +type Alias = string //@symbol("Alias", "Alias", 15, "") + +type NumberAlias = Number //@symbol("NumberAlias", "NumberAlias", 16, "") + +type ( + Boolean bool //@symbol("Boolean", "Boolean", 17, "") + BoolAlias = bool //@symbol("BoolAlias", "BoolAlias", 17, "") +) + +type Foo struct { //@symbol("Foo", "Foo", 23, "") + Quux //@symbol("Quux", "Quux", 8, "Foo") + W io.Writer //@symbol("W" , "W", 8, "Foo") + Bar int //@symbol("Bar", "Bar", 8, "Foo") + baz string //@symbol("baz", "baz", 8, "Foo") } -type Quux struct { //@symbol("Quux", "Quux", 23) - X float64 +type Quux struct { //@symbol("Quux", "Quux", 23, "") + X, Y float64 //@symbol("X", "X", 8, "Quux"), symbol("Y", "Y", 8, "Quux") } -func (f Foo) Baz() string { //@symbol("Baz", "Baz", 6) +func (f Foo) Baz() string { //@symbol("Baz", "Baz", 6, "Foo") return f.baz } -func main() { //@symbol("main", "main", 12) +func (q *Quux) Do() {} //@symbol("Do", "Do", 6, "Quux") + +func main() { //@symbol("main", "main", 12, "") } -type Stringer interface { //@symbol("Stringer", "Stringer", 11) +type Stringer interface { //@symbol("Stringer", "Stringer", 11, "") String() string }