diff --git a/internal/lsp/cmd/test/cmdtest.go b/internal/lsp/cmd/test/cmdtest.go index db4e502c6b..5b516884f0 100644 --- a/internal/lsp/cmd/test/cmdtest.go +++ b/internal/lsp/cmd/test/cmdtest.go @@ -90,6 +90,10 @@ func (r *runner) Link(t *testing.T, uri span.URI, wantLinks []tests.Link) { //TODO: add command line link tests when it works } +func (r *runner) Implementation(t *testing.T, spn span.Span, imp tests.Implementations) { + //TODO: add implements tests when it works +} + func CaptureStdOut(t testing.TB, f func()) string { r, out, err := os.Pipe() if err != nil { diff --git a/internal/lsp/implementation.go b/internal/lsp/implementation.go new file mode 100644 index 0000000000..dffc449289 --- /dev/null +++ b/internal/lsp/implementation.go @@ -0,0 +1,24 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package lsp + +import ( + "context" + + "golang.org/x/tools/internal/lsp/protocol" + "golang.org/x/tools/internal/lsp/source" + "golang.org/x/tools/internal/span" +) + +func (s *Server) implementation(ctx context.Context, params *protocol.ImplementationParams) ([]protocol.Location, error) { + uri := span.NewURI(params.TextDocument.URI) + view := s.session.ViewOf(uri) + f, err := view.GetFile(ctx, uri) + if err != nil { + return nil, err + } + + return source.Implementation(ctx, view, f, params.Position) +} diff --git a/internal/lsp/lsp_test.go b/internal/lsp/lsp_test.go index 568883b1c8..0a8fc0fbaa 100644 --- a/internal/lsp/lsp_test.go +++ b/internal/lsp/lsp_test.go @@ -426,6 +426,44 @@ func (r *runner) Definition(t *testing.T, spn span.Span, d tests.Definition) { } } +func (r *runner) Implementation(t *testing.T, spn span.Span, m tests.Implementations) { + sm, err := r.data.Mapper(m.Src.URI()) + if err != nil { + t.Fatal(err) + } + loc, err := sm.Location(m.Src) + if err != nil { + t.Fatalf("failed for %v: %v", m.Src, err) + } + tdpp := protocol.TextDocumentPositionParams{ + TextDocument: protocol.TextDocumentIdentifier{URI: loc.URI}, + Position: loc.Range.Start, + } + var locs []protocol.Location + params := &protocol.ImplementationParams{ + TextDocumentPositionParams: tdpp, + } + locs, err = r.server.Implementation(r.ctx, params) + if err != nil { + t.Fatalf("failed for %v: %v", m.Src, err) + } + if len(locs) != len(m.Implementations) { + t.Fatalf("got %d locations for implementation, expected %d", len(locs), len(m.Implementations)) + } + for i := range locs { + locURI := span.NewURI(locs[i].URI) + lm, err := r.data.Mapper(locURI) + if err != nil { + t.Fatal(err) + } + if imp, err := lm.Span(locs[i]); err != nil { + t.Fatalf("failed for %v: %v", locs[i], err) + } else if imp != m.Implementations[i] { + t.Errorf("for %dth implementation of %v got %v want %v", i, m.Src, imp, m.Implementations[i]) + } + } +} + func (r *runner) Highlight(t *testing.T, name string, locations []span.Span) { m, err := r.data.Mapper(locations[0].URI()) if err != nil { diff --git a/internal/lsp/server.go b/internal/lsp/server.go index fe096522f6..2a4345fb05 100644 --- a/internal/lsp/server.go +++ b/internal/lsp/server.go @@ -181,8 +181,8 @@ func (s *Server) TypeDefinition(ctx context.Context, params *protocol.TypeDefini return s.typeDefinition(ctx, params) } -func (s *Server) Implementation(context.Context, *protocol.ImplementationParams) ([]protocol.Location, error) { - return nil, notImplemented("Implementation") +func (s *Server) Implementation(ctx context.Context, params *protocol.ImplementationParams) ([]protocol.Location, error) { + return s.implementation(ctx, params) } func (s *Server) References(ctx context.Context, params *protocol.ReferenceParams) ([]protocol.Location, error) { diff --git a/internal/lsp/source/implementation.go b/internal/lsp/source/implementation.go new file mode 100644 index 0000000000..a28b38f682 --- /dev/null +++ b/internal/lsp/source/implementation.go @@ -0,0 +1,154 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// The code in this file is based largely on the code in +// cmd/guru/implements.go. The guru implementation supports +// looking up "implementers" of methods also, but that +// code has been cut out here for now for simplicity. + +package source + +import ( + "context" + "go/types" + "sort" + + "golang.org/x/tools/go/types/typeutil" + "golang.org/x/tools/internal/lsp/protocol" +) + +func Implementation(ctx context.Context, view View, f File, position protocol.Position) ([]protocol.Location, error) { + // Find all references to the identifier at the position. + ident, err := Identifier(ctx, view, f, position) + if err != nil { + return nil, err + } + + res, err := ident.implementations(ctx) + if err != nil { + return nil, err + } + + var locations []protocol.Location + for _, t := range res.to { + // We'll provide implementations that are named types and pointers to named types. + if p, ok := t.(*types.Pointer); ok { + t = p.Elem() + } + if n, ok := t.(*types.Named); ok { + ph, pkg, err := view.FindFileInPackage(ctx, f.URI(), ident.pkg) + if err != nil { + return nil, err + } + f, _, _, err := ph.Cached() + if err != nil { + return nil, err + } + ident, err := findIdentifier(ctx, view.Snapshot(), pkg, f, n.Obj().Pos()) + if err != nil { + return nil, err + } + decRange, err := ident.Declaration.Range() + if err != nil { + return nil, err + } + locations = append(locations, protocol.Location{ + URI: protocol.NewURI(ident.Declaration.URI()), + Range: decRange, + }) + } + } + + return locations, nil +} + +func (i *IdentifierInfo) implementations(ctx context.Context) (implementsResult, error) { + T := i.Type.Object.Type() + + // Find all named types, even local types (which can have + // methods due to promotion) and the built-in "error". + // We ignore aliases 'type M = N' to avoid duplicate + // reporting of the Named type N. + var allNamed []*types.Named + info := i.pkg.GetTypesInfo() + for _, obj := range info.Defs { + if obj, ok := obj.(*types.TypeName); ok && !obj.IsAlias() { + if named, ok := obj.Type().(*types.Named); ok { + allNamed = append(allNamed, named) + } + } + } + + allNamed = append(allNamed, types.Universe.Lookup("error").Type().(*types.Named)) + + var msets typeutil.MethodSetCache + + // TODO(matloob): We only use the to result for now. Figure out if we want to + // surface the from and fromPtr results to users. + // Test each named type. + var to, from, fromPtr []types.Type + for _, U := range allNamed { + if isInterface(T) { + if msets.MethodSet(T).Len() == 0 { + continue // empty interface + } + if isInterface(U) { + if msets.MethodSet(U).Len() == 0 { + continue // empty interface + } + + // T interface, U interface + if !types.Identical(T, U) { + if types.AssignableTo(U, T) { + to = append(to, U) + } + if types.AssignableTo(T, U) { + from = append(from, U) + } + } + } else { + // T interface, U concrete + if types.AssignableTo(U, T) { + to = append(to, U) + } else if pU := types.NewPointer(U); types.AssignableTo(pU, T) { + to = append(to, pU) + } + } + } else if isInterface(U) { + if msets.MethodSet(U).Len() == 0 { + continue // empty interface + } + + // T concrete, U interface + if types.AssignableTo(T, U) { + from = append(from, U) + } else if pT := types.NewPointer(T); types.AssignableTo(pT, U) { + fromPtr = append(fromPtr, U) + } + } + } + + // Sort types (arbitrarily) to ensure test determinism. + sort.Sort(typesByString(to)) + sort.Sort(typesByString(from)) + sort.Sort(typesByString(fromPtr)) + + // TODO(matloob): Perhaps support calling implements on methods instead of just interface types, + // as guru does. + + return implementsResult{to, from, fromPtr}, nil +} + +// implementsResult contains the results of an implements query. +type implementsResult struct { + to []types.Type // named or ptr-to-named types assignable to interface T + from []types.Type // named interfaces assignable from T + fromPtr []types.Type // named interfaces assignable only from *T +} + +type typesByString []types.Type + +func (p typesByString) Len() int { return len(p) } +func (p typesByString) Less(i, j int) bool { return p[i].String() < p[j].String() } +func (p typesByString) Swap(i, j int) { p[i], p[j] = p[j], p[i] } diff --git a/internal/lsp/source/source_test.go b/internal/lsp/source/source_test.go index f8ec138922..c4ab1bb0c7 100644 --- a/internal/lsp/source/source_test.go +++ b/internal/lsp/source/source_test.go @@ -541,6 +541,42 @@ func (r *runner) Definition(t *testing.T, spn span.Span, d tests.Definition) { } } +func (r *runner) Implementation(t *testing.T, spn span.Span, m tests.Implementations) { + ctx := r.ctx + f, err := r.view.GetFile(ctx, m.Src.URI()) + if err != nil { + t.Fatalf("failed for %v: %v", m.Src, err) + } + sm, err := r.data.Mapper(m.Src.URI()) + if err != nil { + t.Fatal(err) + } + loc, err := sm.Location(m.Src) + if err != nil { + t.Fatalf("failed for %v: %v", m.Src, err) + } + var locs []protocol.Location + locs, err = source.Implementation(r.ctx, r.view, f, loc.Range.Start) + if err != nil { + t.Fatalf("failed for %v: %v", m.Src, err) + } + if len(locs) != len(m.Implementations) { + t.Fatalf("got %d locations for implementation, expected %d", len(locs), len(m.Implementations)) + } + for i := range locs { + locURI := span.NewURI(locs[i].URI) + lm, err := r.data.Mapper(locURI) + if err != nil { + t.Fatal(err) + } + if imp, err := lm.Span(locs[i]); err != nil { + t.Fatalf("failed for %v: %v", locs[i], err) + } else if imp != m.Implementations[i] { + t.Errorf("for %dth implementation of %v got %v want %v", i, m.Src, imp, m.Implementations[i]) + } + } +} + func (r *runner) Highlight(t *testing.T, name string, locations []span.Span) { ctx := r.ctx src := locations[0] diff --git a/internal/lsp/testdata/implementation/implementation.go b/internal/lsp/testdata/implementation/implementation.go new file mode 100644 index 0000000000..9656ae483c --- /dev/null +++ b/internal/lsp/testdata/implementation/implementation.go @@ -0,0 +1,21 @@ +package implementation + +type ImpP struct{} //@ImpP + +func (*ImpP) Laugh() { + +} + +type ImpS struct{} //@ImpS + +func (ImpS) Laugh() { + +} + +type ImpI interface { //@ImpI + Laugh() +} + +type Laugher interface { //@implementations("augher", ImpP),implementations("augher", ImpI),implementations("augher", ImpS), + Laugh() +} diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index 34c6273647..64225f820b 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -51,6 +51,7 @@ type Formats []span.Span type Imports []span.Span type SuggestedFixes []span.Span type Definitions map[span.Span]Definition +type Implementationses map[span.Span]Implementations type Highlights map[string][]span.Span type References map[span.Span][]span.Span type Renames map[span.Span]string @@ -77,6 +78,7 @@ type Data struct { Imports Imports SuggestedFixes SuggestedFixes Definitions Definitions + Implementationses Implementationses Highlights Highlights References References Renames Renames @@ -109,6 +111,7 @@ type Tests interface { Import(*testing.T, span.Span) SuggestedFix(*testing.T, span.Span) Definition(*testing.T, span.Span, Definition) + Implementation(*testing.T, span.Span, Implementations) Highlight(*testing.T, string, []span.Span) References(*testing.T, span.Span, []span.Span) Rename(*testing.T, span.Span, string) @@ -125,6 +128,11 @@ type Definition struct { Src, Def span.Span } +type Implementations struct { + Src span.Span + Implementations []span.Span +} + type CompletionTestType int const ( @@ -202,6 +210,7 @@ func Load(t testing.TB, exporter packagestest.Exporter, dir string) *Data { RankCompletions: make(RankCompletions), CaseSensitiveCompletions: make(CaseSensitiveCompletions), Definitions: make(Definitions), + Implementationses: make(Implementationses), Highlights: make(Highlights), References: make(References), Renames: make(Renames), @@ -287,29 +296,30 @@ func Load(t testing.TB, exporter packagestest.Exporter, dir string) *Data { // Collect any data that needs to be used by subsequent tests. if err := data.Exported.Expect(map[string]interface{}{ - "diag": data.collectDiagnostics, - "item": data.collectCompletionItems, - "complete": data.collectCompletions(CompletionDefault), - "unimported": data.collectCompletions(CompletionUnimported), - "deep": data.collectCompletions(CompletionDeep), - "fuzzy": data.collectCompletions(CompletionFuzzy), - "casesensitive": data.collectCompletions(CompletionCaseSensitve), - "rank": data.collectCompletions(CompletionRank), - "snippet": data.collectCompletionSnippets, - "fold": data.collectFoldingRanges, - "format": data.collectFormats, - "import": data.collectImports, - "godef": data.collectDefinitions, - "typdef": data.collectTypeDefinitions, - "hover": data.collectHoverDefinitions, - "highlight": data.collectHighlights, - "refs": data.collectReferences, - "rename": data.collectRenames, - "prepare": data.collectPrepareRenames, - "symbol": data.collectSymbols, - "signature": data.collectSignatures, - "link": data.collectLinks, - "suggestedfix": data.collectSuggestedFixes, + "diag": data.collectDiagnostics, + "item": data.collectCompletionItems, + "complete": data.collectCompletions(CompletionDefault), + "unimported": data.collectCompletions(CompletionUnimported), + "deep": data.collectCompletions(CompletionDeep), + "fuzzy": data.collectCompletions(CompletionFuzzy), + "casesensitive": data.collectCompletions(CompletionCaseSensitve), + "rank": data.collectCompletions(CompletionRank), + "snippet": data.collectCompletionSnippets, + "fold": data.collectFoldingRanges, + "format": data.collectFormats, + "import": data.collectImports, + "godef": data.collectDefinitions, + "implementations": data.collectImplementations, + "typdef": data.collectTypeDefinitions, + "hover": data.collectHoverDefinitions, + "highlight": data.collectHighlights, + "refs": data.collectReferences, + "rename": data.collectRenames, + "prepare": data.collectPrepareRenames, + "symbol": data.collectSymbols, + "signature": data.collectSignatures, + "link": data.collectLinks, + "suggestedfix": data.collectSuggestedFixes, }); err != nil { t.Fatal(err) } @@ -469,6 +479,16 @@ func Run(t *testing.T, tests Tests, data *Data) { } }) + t.Run("Implementation", func(t *testing.T) { + t.Helper() + for spn, m := range data.Implementationses { + t.Run(spanName(spn), func(t *testing.T) { + t.Helper() + tests.Implementation(t, spn, m) + }) + } + }) + t.Run("Highlight", func(t *testing.T) { t.Helper() for name, locations := range data.Highlights { @@ -776,6 +796,14 @@ func (data *Data) collectDefinitions(src, target span.Span) { } } +func (data *Data) collectImplementations(src, target span.Span) { + // Add target to the list of expected implementations for src + imps := data.Implementationses[src] + imps.Src = src // Src is already set if imps already exists, but then we're setting it to the same thing. + imps.Implementations = append(imps.Implementations, target) + data.Implementationses[src] = imps +} + func (data *Data) collectHoverDefinitions(src, target span.Span) { data.Definitions[src] = Definition{ Src: src,