diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go index 990cfdd352..bccf80d6b5 100644 --- a/internal/lsp/source/completion.go +++ b/internal/lsp/source/completion.go @@ -130,17 +130,13 @@ type completer struct { // surrounding describes the identifier surrounding the position. surrounding *Selection - // expectedType is the type we expect the completion candidate to be. - // It may not be set. - expectedType types.Type + // expectedType conains information about the type we expect the completion + // candidate to be. It will be the zero value if no information is available. + expectedType typeInference // enclosingFunction is the function declaration enclosing the position. enclosingFunction *types.Signature - // preferTypeNames is true if we are completing at a position that expects a type, - // not a value. - preferTypeNames bool - // enclosingCompositeLiteral contains information about the composite literal // enclosing the position. enclosingCompositeLiteral *compLitInfo @@ -205,10 +201,10 @@ func (c *completer) found(obj types.Object, weight float64) { return } c.seen[obj] = true - if c.matchingType(obj.Type()) { + if c.matchingType(obj) { weight *= highScore } - if _, ok := obj.(*types.TypeName); !ok && c.preferTypeNames { + if c.wantTypeName() && !isTypeName(obj) { weight *= lowScore } c.items = append(c.items, c.item(obj, weight)) @@ -258,7 +254,6 @@ func Completion(ctx context.Context, f GoFile, pos token.Pos) ([]CompletionItem, pos: pos, seen: make(map[types.Object]bool), enclosingFunction: enclosingFunction(path, pos, pkg.GetTypesInfo()), - preferTypeNames: preferTypeNames(path, pos), enclosingCompositeLiteral: clInfo, } @@ -344,6 +339,10 @@ func (c *completer) wantStructFieldCompletions() bool { return clInfo.isStruct() && (clInfo.inKey || clInfo.maybeInFieldName) } +func (c *completer) wantTypeName() bool { + return c.expectedType.wantTypeName +} + // selector finds completions for the specified selector expression. func (c *completer) selector(sel *ast.SelectorExpr) error { // Is sel a qualified identifier? @@ -657,10 +656,25 @@ const ( chanRead // channel read ("<-") operator ) -// expectedType returns the expected type for an expression at the query position. -func expectedType(c *completer) types.Type { +// typeInference holds information we have inferred about a type that can be +// used at the current position. +type typeInference struct { + // objType is the desired type of an object used at the query position. + objType types.Type + + // wantTypeName is true if we expect the name of a type. + wantTypeName bool +} + +// expectedType returns information about the expected type for an expression at +// the query position. +func expectedType(c *completer) typeInference { + if ti := expectTypeName(c); ti.wantTypeName { + return ti + } + if c.enclosingCompositeLiteral != nil { - return c.expectedCompositeLiteralType() + return typeInference{objType: c.expectedCompositeLiteralType()} } var ( @@ -693,14 +707,14 @@ Nodes: break Nodes } } - return nil + return typeInference{} case *ast.CallExpr: // Only consider CallExpr args if position falls between parens. if node.Lparen <= c.pos && c.pos <= node.Rparen { if tv, ok := c.info.Types[node.Fun]; ok { if sig, ok := tv.Type.(*types.Signature); ok { if sig.Params().Len() == 0 { - return nil + return typeInference{} } i := indexExprAtPos(c.pos, node.Args) // Make sure not to run past the end of expected parameters. @@ -712,7 +726,7 @@ Nodes: } } } - return nil + return typeInference{} case *ast.ReturnStmt: if sig := c.enclosingFunction; sig != nil { // Find signature result that corresponds to our return statement. @@ -723,7 +737,7 @@ Nodes: } } } - return nil + return typeInference{} case *ast.CaseClause: if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, node).(*ast.SwitchStmt); ok { if tv, ok := c.info.Types[swtch.Tag]; ok { @@ -731,14 +745,14 @@ Nodes: break Nodes } } - return nil + return typeInference{} case *ast.SliceExpr: // Make sure position falls within the brackets (e.g. "foo[a:<>]"). if node.Lbrack < c.pos && c.pos <= node.Rbrack { typ = types.Typ[types.Int] break Nodes } - return nil + return typeInference{} case *ast.IndexExpr: // Make sure position falls within the brackets (e.g. "foo[<>]"). if node.Lbrack < c.pos && c.pos <= node.Rbrack { @@ -749,12 +763,12 @@ Nodes: case *types.Slice, *types.Array: typ = types.Typ[types.Int] default: - return nil + return typeInference{} } break Nodes } } - return nil + return typeInference{} case *ast.SendStmt: // Make sure we are on right side of arrow (e.g. "foo <- <>"). if c.pos > node.Arrow+1 { @@ -765,7 +779,7 @@ Nodes: } } } - return nil + return typeInference{} case *ast.StarExpr: modifiers = append(modifiers, dereference) case *ast.UnaryExpr: @@ -777,7 +791,7 @@ Nodes: } default: if breaksExpectedTypeInference(node) { - return nil + return typeInference{} } } } @@ -798,7 +812,9 @@ Nodes: } } - return typ + return typeInference{ + objType: typ, + } } // findSwitchStmt returns an *ast.CaseClause's corresponding *ast.SwitchStmt or @@ -837,50 +853,74 @@ func breaksExpectedTypeInference(n ast.Node) bool { } } -// preferTypeNames checks if given token position is inside func receiver, -// type params, or type results. For example: -// -// func (<>) foo(<>) (<>) {} -// -func preferTypeNames(path []ast.Node, pos token.Pos) bool { - for i, p := range path { +// expectTypeName returns information about the expected type name at position. +func expectTypeName(c *completer) typeInference { + var wantTypeName bool + +Nodes: + for i, p := range c.path { switch n := p.(type) { case *ast.FuncDecl: - if r := n.Recv; r != nil && r.Pos() <= pos && pos <= r.End() { - return true + // Expect type names in a function declaration receiver, params and results. + + if r := n.Recv; r != nil && r.Pos() <= c.pos && c.pos <= r.End() { + wantTypeName = true + break Nodes } if t := n.Type; t != nil { - if p := t.Params; p != nil && p.Pos() <= pos && pos <= p.End() { - return true + if p := t.Params; p != nil && p.Pos() <= c.pos && c.pos <= p.End() { + wantTypeName = true + break Nodes } - if r := t.Results; r != nil && r.Pos() <= pos && pos <= r.End() { - return true + if r := t.Results; r != nil && r.Pos() <= c.pos && c.pos <= r.End() { + wantTypeName = true + break Nodes } } - return false + return typeInference{} case *ast.CaseClause: - _, isTypeSwitch := findSwitchStmt(path[i+1:], pos, n).(*ast.TypeSwitchStmt) - return isTypeSwitch + // Expect type names in type switch case clauses. + if _, ok := findSwitchStmt(c.path[i+1:], c.pos, n).(*ast.TypeSwitchStmt); ok { + wantTypeName = true + break Nodes + } + return typeInference{} case *ast.TypeAssertExpr: - if n.Lparen < pos && pos <= n.Rparen { - return true + // Expect type names in type assert expressions. + if n.Lparen < c.pos && c.pos <= n.Rparen { + wantTypeName = true + break Nodes + } + return typeInference{} + default: + if breaksExpectedTypeInference(p) { + return typeInference{} } } } - return false + + return typeInference{ + wantTypeName: wantTypeName, + } } -// matchingTypes reports whether actual is a good candidate type -// for a completion in a context of the expected type. -func (c *completer) matchingType(actual types.Type) bool { - if c.expectedType == nil { - return false - } +// matchingType reports whether an object is a good completion candidate +// in the context of the expected type. +func (c *completer) matchingType(obj types.Object) bool { + actual := obj.Type() + // Use a function's return type as its type. if sig, ok := actual.(*types.Signature); ok { if sig.Results().Len() == 1 { actual = sig.Results().At(0).Type() } } - return types.Identical(types.Default(c.expectedType), types.Default(actual)) + + if c.expectedType.objType != nil { + // AssignableTo covers the case where the types are equal, but also handles + // cases like assigning a concrete type to an interface type. + return types.AssignableTo(types.Default(actual), types.Default(c.expectedType.objType)) + } + + return false } diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go index 81299b96d5..403d1c8276 100644 --- a/internal/lsp/source/util.go +++ b/internal/lsp/source/util.go @@ -74,7 +74,7 @@ func resolveInvalid(obj types.Object, node ast.Node, info *types.Info) types.Obj default: return nil } - typ := types.NewNamed(types.NewTypeName(token.NoPos, obj.Pkg(), typename, nil), nil, nil) + typ := types.NewNamed(types.NewTypeName(token.NoPos, obj.Pkg(), typename, nil), types.Typ[types.Invalid], nil) return types.NewVar(obj.Pos(), obj.Pkg(), obj.Name(), typ) } var resultExpr ast.Expr @@ -127,6 +127,11 @@ func deref(typ types.Type) types.Type { return typ } +func isTypeName(obj types.Object) bool { + _, ok := obj.(*types.TypeName) + return ok +} + func formatParams(tup *types.Tuple, variadic bool, qf types.Qualifier) []string { params := make([]string, 0, tup.Len()) for i := 0; i < tup.Len(); i++ { diff --git a/internal/lsp/testdata/interfacerank/interface_rank.go b/internal/lsp/testdata/interfacerank/interface_rank.go new file mode 100644 index 0000000000..968c1a6a0d --- /dev/null +++ b/internal/lsp/testdata/interfacerank/interface_rank.go @@ -0,0 +1,20 @@ +package interfacerank + +type foo interface { + foo() +} + +type fooImpl int + +func (*fooImpl) foo() {} + +func wantsFoo(foo) {} + +func _() { + var ( + aa string //@item(irAA, "aa", "string", "var") + ab *fooImpl //@item(irAB, "ab", "*fooImpl", "var") + ) + + wantsFoo(a) //@complete(")", irAB, irAA) +} diff --git a/internal/lsp/testdata/unresolved/unresolved.go.in b/internal/lsp/testdata/unresolved/unresolved.go.in new file mode 100644 index 0000000000..731d582541 --- /dev/null +++ b/internal/lsp/testdata/unresolved/unresolved.go.in @@ -0,0 +1,6 @@ +package unresolved + +func foo(interface{}) { //@item(unresolvedFoo, "foo(interface{})", "", "func") + // don't crash on fake "resolved" type + foo(func(i, j f //@complete(" //", unresolvedFoo) +} diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index 9454bf37b7..373d6755c1 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -25,7 +25,7 @@ import ( // We hardcode the expected number of test cases to ensure that all tests // are being executed. If a test is added, this number must be changed. const ( - ExpectedCompletionsCount = 122 + ExpectedCompletionsCount = 124 ExpectedCompletionSnippetCount = 14 ExpectedDiagnosticsCount = 17 ExpectedFormatCount = 5