diff --git a/internal/lsp/source/completion.go b/internal/lsp/source/completion.go index d06f9c6b08..e060c2b188 100644 --- a/internal/lsp/source/completion.go +++ b/internal/lsp/source/completion.go @@ -727,6 +727,9 @@ type typeInference struct { // assertableFrom is a type that must be assertable to our candidate type. assertableFrom types.Type + + // convertibleTo is a type our candidate type must be convertible to. + convertibleTo types.Type } // expectedType returns information about the expected type for an expression at @@ -741,8 +744,9 @@ func expectedType(c *completer) typeInference { } var ( - modifiers []typeModifier - typ types.Type + modifiers []typeModifier + typ types.Type + convertibleTo types.Type ) Nodes: @@ -774,6 +778,13 @@ Nodes: case *ast.CallExpr: // Only consider CallExpr args if position falls between parens. if node.Lparen <= c.pos && c.pos <= node.Rparen { + // For type conversions like "int64(foo)" we can only infer our + // desired type is convertible to int64. + if typ := typeConversion(node, c.info); typ != nil { + convertibleTo = typ + break Nodes + } + if tv, ok := c.info.Types[node.Fun]; ok { if sig, ok := tv.Type.(*types.Signature); ok { if sig.Params().Len() == 0 { @@ -860,8 +871,9 @@ Nodes: } return typeInference{ - objType: typ, - modifiers: modifiers, + objType: typ, + modifiers: modifiers, + convertibleTo: convertibleTo, } } @@ -1045,6 +1057,10 @@ func (c *completer) matchingType(cand *candidate) bool { } } + if c.expectedType.convertibleTo != nil { + return types.ConvertibleTo(objType, c.expectedType.convertibleTo) + } + return false } diff --git a/internal/lsp/source/util.go b/internal/lsp/source/util.go index ef7d78286e..8fe0498a21 100644 --- a/internal/lsp/source/util.go +++ b/internal/lsp/source/util.go @@ -139,6 +139,27 @@ func isFunc(obj types.Object) bool { return ok } +// typeConversion returns the type being converted to if call is a type +// conversion expression. +func typeConversion(call *ast.CallExpr, info *types.Info) types.Type { + var ident *ast.Ident + switch expr := call.Fun.(type) { + case *ast.Ident: + ident = expr + case *ast.SelectorExpr: + ident = expr.Sel + default: + return nil + } + + // Type conversion (e.g. "float64(foo)"). + if fun, _ := info.ObjectOf(ident).(*types.TypeName); fun != nil { + return fun.Type() + } + + return nil +} + func formatParams(tup *types.Tuple, variadic bool, qf types.Qualifier) []string { params := make([]string, 0, tup.Len()) for i := 0; i < tup.Len(); i++ { diff --git a/internal/lsp/testdata/rank/convert_rank.go.in b/internal/lsp/testdata/rank/convert_rank.go.in new file mode 100644 index 0000000000..dc0a3a5ee7 --- /dev/null +++ b/internal/lsp/testdata/rank/convert_rank.go.in @@ -0,0 +1,12 @@ +package rank + +func _() { + type strList []string + wantsStrList := func(strList) {} + + var ( + convA string //@item(convertA, "convA", "string", "var") + convB []string //@item(convertB, "convB", "[]string", "var") + ) + wantsStrList(strList(conv)) //@complete("))", convertB, convertA) +}