1
0
mirror of https://github.com/golang/go synced 2024-11-18 11:04:42 -07:00

internal/lsp/source: fix composite literal type name completion

Fix completion in the following cases:

    type foo struct{}

    // now we offer "&foo" instead of "foo"
    var _ *foo = fo<>{}

    struct { f *foo }{
      // now we offer "&foo" instead of "*foo"
      f: fo<>{},
    }

Composite literal type names are a bit special because they are part
of an arbitrary value expression rather than just a standalone type
name expression. In particular, they can be preceded by "&", which
affects how they relate to the surrounding context. The "&" doesn't
technically apply to the type name, but we must take it into account.

I made three changes to fix the behavior:
1. When we want to make a composite literal type name into a pointer,
   we use "&" instead of "*".
2. Record if a composite literal type is already has a "&" so we don't
   add it again.
3. Fix "var _ *foo = fo<>{}" to properly infer expected type of "*foo"
   by not stopping at *ast.CompositeLit searching up AST path when the
   position is in the type name (as opposed to within the curlies).

Change-Id: Iee828f259eb939646b68f5066614ea3a262585c2
Reviewed-on: https://go-review.googlesource.com/c/tools/+/247525
Run-TryBot: Muir Manders <muir@mnd.rs>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
This commit is contained in:
Muir Manders 2020-08-09 21:42:16 -07:00 committed by Robert Findley
parent c886c0b611
commit 74543c4034
4 changed files with 86 additions and 49 deletions

View File

@ -1445,7 +1445,7 @@ func enclosingCompositeLiteral(path []ast.Node, pos token.Pos, info *types.Info)
return &clInfo return &clInfo
default: default:
if breaksExpectedTypeInference(n) { if breaksExpectedTypeInference(n, pos) {
return nil return nil
} }
} }
@ -1535,11 +1535,11 @@ type typeModifier struct {
type typeMod int type typeMod int
const ( const (
star typeMod = iota // pointer indirection for expressions, pointer indicator for types dereference typeMod = iota // pointer indirection: "*"
address // address operator ("&") reference // adds level of pointer: "&" for values, "*" for type names
chanRead // channel read operator ("<-") chanRead // channel read operator ("<-")
slice // make a slice type ("[]" in "[]int") slice // make a slice type ("[]" in "[]int")
array // make an array type ("[2]" in "[2]int") array // make an array type ("[2]" in "[2]int")
) )
type objKind int type objKind int
@ -1651,6 +1651,10 @@ type typeNameInference struct {
// seenTypeSwitchCases tracks types that have already been used by // seenTypeSwitchCases tracks types that have already been used by
// the containing type switch. // the containing type switch.
seenTypeSwitchCases []types.Type seenTypeSwitchCases []types.Type
// compLitType is true if we are completing a composite literal type
// name, e.g "foo<>{}".
compLitType bool
} }
// expectedCandidate returns information about the expected candidate // expectedCandidate returns information about the expected candidate
@ -1862,11 +1866,11 @@ Nodes:
} }
return inf return inf
case *ast.StarExpr: case *ast.StarExpr:
inf.modifiers = append(inf.modifiers, typeModifier{mod: star}) inf.modifiers = append(inf.modifiers, typeModifier{mod: dereference})
case *ast.UnaryExpr: case *ast.UnaryExpr:
switch node.Op { switch node.Op {
case token.AND: case token.AND:
inf.modifiers = append(inf.modifiers, typeModifier{mod: address}) inf.modifiers = append(inf.modifiers, typeModifier{mod: reference})
case token.ARROW: case token.ARROW:
inf.modifiers = append(inf.modifiers, typeModifier{mod: chanRead}) inf.modifiers = append(inf.modifiers, typeModifier{mod: chanRead})
} }
@ -1874,7 +1878,7 @@ Nodes:
inf.objKind |= kindFunc inf.objKind |= kindFunc
return inf return inf
default: default:
if breaksExpectedTypeInference(node) { if breaksExpectedTypeInference(node, c.pos) {
return inf return inf
} }
} }
@ -1928,7 +1932,7 @@ func objChain(info *types.Info, e ast.Expr) []types.Object {
func (ci candidateInference) applyTypeModifiers(typ types.Type, addressable bool) types.Type { func (ci candidateInference) applyTypeModifiers(typ types.Type, addressable bool) types.Type {
for _, mod := range ci.modifiers { for _, mod := range ci.modifiers {
switch mod.mod { switch mod.mod {
case star: case dereference:
// For every "*" indirection operator, remove a pointer layer // For every "*" indirection operator, remove a pointer layer
// from candidate type. // from candidate type.
if ptr, ok := typ.Underlying().(*types.Pointer); ok { if ptr, ok := typ.Underlying().(*types.Pointer); ok {
@ -1936,7 +1940,7 @@ func (ci candidateInference) applyTypeModifiers(typ types.Type, addressable bool
} else { } else {
return nil return nil
} }
case address: case reference:
// For every "&" address operator, add another pointer layer to // For every "&" address operator, add another pointer layer to
// candidate type, if the candidate is addressable. // candidate type, if the candidate is addressable.
if addressable { if addressable {
@ -1961,8 +1965,7 @@ func (ci candidateInference) applyTypeModifiers(typ types.Type, addressable bool
func (ci candidateInference) applyTypeNameModifiers(typ types.Type) types.Type { func (ci candidateInference) applyTypeNameModifiers(typ types.Type) types.Type {
for _, mod := range ci.typeName.modifiers { for _, mod := range ci.typeName.modifiers {
switch mod.mod { switch mod.mod {
case star: case reference:
// For every "*" indicator, add a pointer layer to type name.
typ = types.NewPointer(typ) typ = types.NewPointer(typ)
case array: case array:
typ = types.NewArray(typ, mod.arrayLen) typ = types.NewArray(typ, mod.arrayLen)
@ -2006,9 +2009,17 @@ func findSwitchStmt(path []ast.Node, pos token.Pos, c *ast.CaseClause) ast.Stmt
// breaksExpectedTypeInference reports if an expression node's type is unrelated // breaksExpectedTypeInference reports if an expression node's type is unrelated
// to its child expression node types. For example, "Foo{Bar: x.Baz(<>)}" should // to its child expression node types. For example, "Foo{Bar: x.Baz(<>)}" should
// expect a function argument, not a composite literal value. // expect a function argument, not a composite literal value.
func breaksExpectedTypeInference(n ast.Node) bool { func breaksExpectedTypeInference(n ast.Node, pos token.Pos) bool {
switch n.(type) { switch n := n.(type) {
case *ast.FuncLit, *ast.CallExpr, *ast.IndexExpr, *ast.SliceExpr, *ast.CompositeLit: case *ast.CompositeLit:
// Doesn't break inference if pos is in type name.
// For example: "Foo<>{Bar: 123}"
return !nodeContains(n.Type, pos)
case *ast.CallExpr:
// Doesn't break inference if pos is in func name.
// For example: "Foo<>(123)"
return !nodeContains(n.Fun, pos)
case *ast.FuncLit, *ast.IndexExpr, *ast.SliceExpr:
return true return true
default: default:
return false return false
@ -2017,13 +2028,7 @@ func breaksExpectedTypeInference(n ast.Node) bool {
// expectTypeName returns information about the expected type name at position. // expectTypeName returns information about the expected type name at position.
func expectTypeName(c *completer) typeNameInference { func expectTypeName(c *completer) typeNameInference {
var ( var inf typeNameInference
wantTypeName bool
wantComparable bool
modifiers []typeModifier
assertableFrom types.Type
seenTypeSwitchCases []types.Type
)
Nodes: Nodes:
for i, p := range c.path { for i, p := range c.path {
@ -2034,7 +2039,7 @@ Nodes:
// InterfaceType. We don't need to worry about the field name // InterfaceType. We don't need to worry about the field name
// because completion bails out early if pos is in an *ast.Ident // because completion bails out early if pos is in an *ast.Ident
// that defines an object. // that defines an object.
wantTypeName = true inf.wantTypeName = true
break Nodes break Nodes
case *ast.CaseClause: case *ast.CaseClause:
// Expect type names in type switch case clauses. // Expect type names in type switch case clauses.
@ -2042,12 +2047,12 @@ Nodes:
// The case clause types must be assertable from the type switch parameter. // The case clause types must be assertable from the type switch parameter.
ast.Inspect(swtch.Assign, func(n ast.Node) bool { ast.Inspect(swtch.Assign, func(n ast.Node) bool {
if ta, ok := n.(*ast.TypeAssertExpr); ok { if ta, ok := n.(*ast.TypeAssertExpr); ok {
assertableFrom = c.pkg.GetTypesInfo().TypeOf(ta.X) inf.assertableFrom = c.pkg.GetTypesInfo().TypeOf(ta.X)
return false return false
} }
return true return true
}) })
wantTypeName = true inf.wantTypeName = true
// Track the types that have already been used in this // Track the types that have already been used in this
// switch's case statements so we don't recommend them. // switch's case statements so we don't recommend them.
@ -2060,7 +2065,7 @@ Nodes:
} }
if t := c.pkg.GetTypesInfo().TypeOf(typeExpr); t != nil { if t := c.pkg.GetTypesInfo().TypeOf(typeExpr); t != nil {
seenTypeSwitchCases = append(seenTypeSwitchCases, t) inf.seenTypeSwitchCases = append(inf.seenTypeSwitchCases, t)
} }
} }
} }
@ -2072,33 +2077,43 @@ Nodes:
// Expect type names in type assert expressions. // Expect type names in type assert expressions.
if n.Lparen < c.pos && c.pos <= n.Rparen { if n.Lparen < c.pos && c.pos <= n.Rparen {
// The type in parens must be assertable from the expression type. // The type in parens must be assertable from the expression type.
assertableFrom = c.pkg.GetTypesInfo().TypeOf(n.X) inf.assertableFrom = c.pkg.GetTypesInfo().TypeOf(n.X)
wantTypeName = true inf.wantTypeName = true
break Nodes break Nodes
} }
return typeNameInference{} return typeNameInference{}
case *ast.StarExpr: case *ast.StarExpr:
modifiers = append(modifiers, typeModifier{mod: star}) inf.modifiers = append(inf.modifiers, typeModifier{mod: reference})
case *ast.CompositeLit: case *ast.CompositeLit:
// We want a type name if position is in the "Type" part of a // We want a type name if position is in the "Type" part of a
// composite literal (e.g. "Foo<>{}"). // composite literal (e.g. "Foo<>{}").
if n.Type != nil && n.Type.Pos() <= c.pos && c.pos <= n.Type.End() { if n.Type != nil && n.Type.Pos() <= c.pos && c.pos <= n.Type.End() {
wantTypeName = true inf.wantTypeName = true
inf.compLitType = true
if i < len(c.path)-1 {
// Track preceding "&" operator. Technically it applies to
// the composite literal and not the type name, but if
// affects our type completion nonetheless.
if u, ok := c.path[i+1].(*ast.UnaryExpr); ok && u.Op == token.AND {
inf.modifiers = append(inf.modifiers, typeModifier{mod: reference})
}
}
} }
break Nodes break Nodes
case *ast.ArrayType: case *ast.ArrayType:
// If we are inside the "Elt" part of an array type, we want a type name. // If we are inside the "Elt" part of an array type, we want a type name.
if n.Elt.Pos() <= c.pos && c.pos <= n.Elt.End() { if n.Elt.Pos() <= c.pos && c.pos <= n.Elt.End() {
wantTypeName = true inf.wantTypeName = true
if n.Len == nil { if n.Len == nil {
// No "Len" expression means a slice type. // No "Len" expression means a slice type.
modifiers = append(modifiers, typeModifier{mod: slice}) inf.modifiers = append(inf.modifiers, typeModifier{mod: slice})
} else { } else {
// Try to get the array type using the constant value of "Len". // Try to get the array type using the constant value of "Len".
tv, ok := c.pkg.GetTypesInfo().Types[n.Len] tv, ok := c.pkg.GetTypesInfo().Types[n.Len]
if ok && tv.Value != nil && tv.Value.Kind() == constant.Int { if ok && tv.Value != nil && tv.Value.Kind() == constant.Int {
if arrayLen, ok := constant.Int64Val(tv.Value); ok { if arrayLen, ok := constant.Int64Val(tv.Value); ok {
modifiers = append(modifiers, typeModifier{mod: array, arrayLen: arrayLen}) inf.modifiers = append(inf.modifiers, typeModifier{mod: array, arrayLen: arrayLen})
} }
} }
} }
@ -2114,34 +2129,28 @@ Nodes:
break Nodes break Nodes
} }
case *ast.MapType: case *ast.MapType:
wantTypeName = true inf.wantTypeName = true
if n.Key != nil { if n.Key != nil {
wantComparable = nodeContains(n.Key, c.pos) inf.wantComparable = nodeContains(n.Key, c.pos)
} else { } else {
// If the key is empty, assume we are completing the key if // If the key is empty, assume we are completing the key if
// pos is directly after the "map[". // pos is directly after the "map[".
wantComparable = c.pos == n.Pos()+token.Pos(len("map[")) inf.wantComparable = c.pos == n.Pos()+token.Pos(len("map["))
} }
break Nodes break Nodes
case *ast.ValueSpec: case *ast.ValueSpec:
wantTypeName = nodeContains(n.Type, c.pos) inf.wantTypeName = nodeContains(n.Type, c.pos)
break Nodes break Nodes
case *ast.TypeSpec: case *ast.TypeSpec:
wantTypeName = nodeContains(n.Type, c.pos) inf.wantTypeName = nodeContains(n.Type, c.pos)
default: default:
if breaksExpectedTypeInference(p) { if breaksExpectedTypeInference(p, c.pos) {
return typeNameInference{} return typeNameInference{}
} }
} }
} }
return typeNameInference{ return inf
wantTypeName: wantTypeName,
wantComparable: wantComparable,
modifiers: modifiers,
assertableFrom: assertableFrom,
seenTypeSwitchCases: seenTypeSwitchCases,
}
} }
func (c *completer) fakeObj(T types.Type) *types.Var { func (c *completer) fakeObj(T types.Type) *types.Var {
@ -2519,7 +2528,15 @@ func (c *completer) matchingTypeName(cand *candidate) bool {
} }
if !isInterface(t) && typeMatches(types.NewPointer(t)) { if !isInterface(t) && typeMatches(types.NewPointer(t)) {
cand.makePointer = true if c.inference.typeName.compLitType {
// If we are completing a composite literal type as in
// "foo<>{}", to make a pointer we must prepend "&".
cand.takeAddress = true
} else {
// If we are completing a normal type name such as "foo<>", to
// make a pointer we must prepend "*".
cand.makePointer = true
}
return true return true
} }

View File

@ -94,6 +94,20 @@ func _() {
_ = position{X} //@complete("}", fieldX, varX) _ = position{X} //@complete("}", fieldX, varX)
} }
func _() {
type foo struct{} //@item(complitFoo, "foo", "struct{...}", "struct")
"&foo" //@item(complitAndFoo, "&foo", "struct{...}", "struct")
var _ *foo = &fo{} //@rank("{", complitFoo)
var _ *foo = fo{} //@rank("{", complitAndFoo)
struct { a, b *foo }{
a: &fo{}, //@rank("{", complitFoo)
b: fo{}, //@rank("{", complitAndFoo)
}
}
func _() { func _() {
_ := position{ _ := position{
X: 1, //@complete("X", fieldX),complete(" 1", exportedFunc, multilineWithPrefix, structPosition, cVar, exportedConst, exportedType) X: 1, //@complete("X", fieldX),complete(" 1", exportedFunc, multilineWithPrefix, structPosition, cVar, exportedConst, exportedType)

View File

@ -199,6 +199,12 @@ func _() {
ptrStruct{ ptrStruct{
p: &ptrSt, //@rank(",", litPtrStruct) p: &ptrSt, //@rank(",", litPtrStruct)
} }
&ptrStruct{} //@item(litPtrStructPtr, "&ptrStruct{}", "", "var")
&ptrStruct{
p: ptrSt, //@rank(",", litPtrStructPtr)
}
} }
func _() { func _() {

View File

@ -6,7 +6,7 @@ CompletionSnippetCount = 85
UnimportedCompletionsCount = 6 UnimportedCompletionsCount = 6
DeepCompletionsCount = 5 DeepCompletionsCount = 5
FuzzyCompletionsCount = 8 FuzzyCompletionsCount = 8
RankedCompletionsCount = 152 RankedCompletionsCount = 157
CaseSensitiveCompletionsCount = 4 CaseSensitiveCompletionsCount = 4
DiagnosticsCount = 44 DiagnosticsCount = 44
FoldingRangesCount = 2 FoldingRangesCount = 2