diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go index 81fbd991ad..fb8fb55dd7 100644 --- a/internal/lsp/source/completion.go +++ b/internal/lsp/source/completion.go @@ -211,10 +211,6 @@ func (c *completer) found(obj types.Object, score float64) { cand.score *= highScore } - if c.wantTypeName() && !isTypeName(obj) { - cand.score *= lowScore - } - c.items = append(c.items, c.item(cand)) } @@ -673,9 +669,9 @@ func (c *completer) expectedCompositeLiteralType() types.Type { type typeModifier int const ( - dereference typeModifier = iota // dereference ("*") operator - reference // reference ("&") operator - chanRead // channel read ("<-") operator + star typeModifier = iota // dereference operator for expressions, pointer indicator for types + reference // reference ("&") operator + chanRead // channel read ("<-") operator ) // typeInference holds information we have inferred about a type that can be @@ -690,6 +686,9 @@ type typeInference struct { // modifiers are prefixes such as "*", "&" or "<-" that influence how // a candidate type relates to the expected type. modifiers []typeModifier + + // assertableFrom is a type that must be assertable to our candidate type. + assertableFrom types.Type } // expectedType returns information about the expected type for an expression at @@ -807,7 +806,7 @@ Nodes: } return typeInference{} case *ast.StarExpr: - modifiers = append(modifiers, dereference) + modifiers = append(modifiers, star) case *ast.UnaryExpr: switch node.Op { case token.AND: @@ -832,7 +831,7 @@ Nodes: func (ti typeInference) applyTypeModifiers(typ types.Type) types.Type { for _, mod := range ti.modifiers { switch mod { - case dereference: + case star: // For every "*" deref operator, remove a pointer layer from candidate type. typ = deref(typ) case reference: @@ -848,6 +847,18 @@ func (ti typeInference) applyTypeModifiers(typ types.Type) types.Type { return typ } +// applyTypeNameModifiers applies the list of type modifiers to a type name. +func (ti typeInference) applyTypeNameModifiers(typ types.Type) types.Type { + for _, mod := range ti.modifiers { + switch mod { + case star: + // For every "*" indicator, add a pointer layer to type name. + typ = types.NewPointer(typ) + } + } + return typ +} + // findSwitchStmt returns an *ast.CaseClause's corresponding *ast.SwitchStmt or // *ast.TypeSwitchStmt. path should start from the case clause's first ancestor. func findSwitchStmt(path []ast.Node, pos token.Pos, c *ast.CaseClause) ast.Stmt { @@ -886,7 +897,11 @@ func breaksExpectedTypeInference(n ast.Node) bool { // expectTypeName returns information about the expected type name at position. func expectTypeName(c *completer) typeInference { - var wantTypeName bool + var ( + wantTypeName bool + modifiers []typeModifier + assertableFrom types.Type + ) Nodes: for i, p := range c.path { @@ -911,7 +926,15 @@ Nodes: return typeInference{} case *ast.CaseClause: // Expect type names in type switch case clauses. - if _, ok := findSwitchStmt(c.path[i+1:], c.pos, n).(*ast.TypeSwitchStmt); ok { + if swtch, ok := findSwitchStmt(c.path[i+1:], c.pos, n).(*ast.TypeSwitchStmt); ok { + // The case clause types must be assertable from the type switch parameter. + ast.Inspect(swtch.Assign, func(n ast.Node) bool { + if ta, ok := n.(*ast.TypeAssertExpr); ok { + assertableFrom = c.info.TypeOf(ta.X) + return false + } + return true + }) wantTypeName = true break Nodes } @@ -919,10 +942,14 @@ Nodes: case *ast.TypeAssertExpr: // Expect type names in type assert expressions. if n.Lparen < c.pos && c.pos <= n.Rparen { + // The type in parens must be assertable from the expression type. + assertableFrom = c.info.TypeOf(n.X) wantTypeName = true break Nodes } return typeInference{} + case *ast.StarExpr: + modifiers = append(modifiers, star) default: if breaksExpectedTypeInference(p) { return typeInference{} @@ -931,13 +958,19 @@ Nodes: } return typeInference{ - wantTypeName: wantTypeName, + wantTypeName: wantTypeName, + modifiers: modifiers, + assertableFrom: assertableFrom, } } // matchingType reports whether an object is a good completion candidate // in the context of the expected type. func (c *completer) matchingType(cand *candidate) bool { + if isTypeName(cand.obj) { + return c.matchingTypeName(cand) + } + objType := cand.obj.Type() // Default to invoking *types.Func candidates. This is so function @@ -976,3 +1009,29 @@ func (c *completer) matchingType(cand *candidate) bool { return false } + +func (c *completer) matchingTypeName(cand *candidate) bool { + if !c.wantTypeName() { + return false + } + + // Take into account any type name modifier prefixes. + actual := c.expectedType.applyTypeNameModifiers(cand.obj.Type()) + + if c.expectedType.assertableFrom != nil { + // Don't suggest the starting type in type assertions. For example, + // if "foo" is an io.Writer, don't suggest "foo.(io.Writer)". + if types.Identical(c.expectedType.assertableFrom, actual) { + return false + } + + if intf, ok := c.expectedType.assertableFrom.Underlying().(*types.Interface); ok { + if !types.AssertableTo(intf, actual) { + return false + } + } + } + + // Default to saying any type name is a match. + return true +} diff --git a/internal/lsp/testdata/good/good1.go b/internal/lsp/testdata/good/good1.go index f490b4d79e..b595950db7 100644 --- a/internal/lsp/testdata/good/good1.go +++ b/internal/lsp/testdata/good/good1.go @@ -12,7 +12,7 @@ func random() int { //@item(good_random, "random()", "int", "func") func random2(y int) int { //@item(good_random2, "random2(y int)", "int", "func"),item(good_y_param, "y", "int", "parameter") //@complete("", good_y_param, types_import, good_random, good_random2, good_stuff) var b types.Bob = &types.X{} - if _, ok := b.(*types.X); ok { //@complete("X", Bob_interface, X_struct, Y_struct) + if _, ok := b.(*types.X); ok { //@complete("X", X_struct, Y_struct, Bob_interface) } return y diff --git a/internal/lsp/testdata/typeassert/type_assert.go b/internal/lsp/testdata/typeassert/type_assert.go new file mode 100644 index 0000000000..8b55bf3d76 --- /dev/null +++ b/internal/lsp/testdata/typeassert/type_assert.go @@ -0,0 +1,24 @@ +package typeassert + +type abc interface { //@item(abcIntf, "abc", "interface{...}", "interface") + abc() +} + +type abcImpl struct{} //@item(abcImpl, "abcImpl", "struct{...}", "struct") +func (abcImpl) abc() + +type abcPtrImpl struct{} //@item(abcPtrImpl, "abcPtrImpl", "struct{...}", "struct") +func (*abcPtrImpl) abc() + +type abcNotImpl struct{} //@item(abcNotImpl, "abcNotImpl", "struct{...}", "struct") + +func _() { + var a abc + switch a.(type) { + case ab: //@complete(":", abcImpl, abcIntf, abcNotImpl, abcPtrImpl) + case *ab: //@complete(":", abcImpl, abcPtrImpl, abcIntf, abcNotImpl) + } + + a.(ab) //@complete(")", abcImpl, abcIntf, abcNotImpl, abcPtrImpl) + a.(*ab) //@complete(")", abcImpl, abcPtrImpl, abcIntf, abcNotImpl) +} diff --git a/internal/lsp/tests/tests.go b/internal/lsp/tests/tests.go index 5ecc71a240..24b026d452 100644 --- a/internal/lsp/tests/tests.go +++ b/internal/lsp/tests/tests.go @@ -25,7 +25,7 @@ import ( // We hardcode the expected number of test cases to ensure that all tests // are being executed. If a test is added, this number must be changed. const ( - ExpectedCompletionsCount = 128 + ExpectedCompletionsCount = 132 ExpectedCompletionSnippetCount = 14 ExpectedDiagnosticsCount = 17 ExpectedFormatCount = 5