1
0
mirror of https://github.com/golang/go synced 2024-11-18 08:54:45 -07:00

internal/lsp: extract highlighted selection to variable

I add a code action that triggers upon request of the user. A variable
name is generated manually for the extracted code because the LSP does
not support a user's ability to provide a name.

Change-Id: Id1ec19b49562b7cfbc2cd416378bec9bd021d82f
Reviewed-on: https://go-review.googlesource.com/c/tools/+/240182
Run-TryBot: Josh Baum <joshbaum@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rebecca Stambler <rstambler@golang.org>
This commit is contained in:
Josh Baum 2020-06-24 09:52:23 -04:00
parent 416e8f4faf
commit 9c9572d6f9
13 changed files with 352 additions and 92 deletions

View File

@ -124,3 +124,77 @@ const (
NoResultValues TypeErrorPass = "noresultvalues"
UndeclaredName TypeErrorPass = "undeclaredname"
)
// StmtToInsertVarBefore returns the ast.Stmt before which we can safely insert a new variable.
// Some examples:
//
// Basic Example:
// z := 1
// y := z + x
// If x is undeclared, then this function would return `y := z + x`, so that we
// can insert `x := ` on the line before `y := z + x`.
//
// If stmt example:
// if z == 1 {
// } else if z == y {}
// If y is undeclared, then this function would return `if z == 1 {`, because we cannot
// insert a statement between an if and an else if statement. As a result, we need to find
// the top of the if chain to insert `y := ` before.
func StmtToInsertVarBefore(path []ast.Node) ast.Stmt {
enclosingIndex := -1
for i, p := range path {
if _, ok := p.(ast.Stmt); ok {
enclosingIndex = i
break
}
}
if enclosingIndex == -1 {
return nil
}
enclosingStmt := path[enclosingIndex]
switch enclosingStmt.(type) {
case *ast.IfStmt:
// The enclosingStmt is inside of the if declaration,
// We need to check if we are in an else-if stmt and
// get the base if statement.
return baseIfStmt(path, enclosingIndex)
case *ast.CaseClause:
// Get the enclosing switch stmt if the enclosingStmt is
// inside of the case statement.
for i := enclosingIndex + 1; i < len(path); i++ {
if node, ok := path[i].(*ast.SwitchStmt); ok {
return node
} else if node, ok := path[i].(*ast.TypeSwitchStmt); ok {
return node
}
}
}
if len(path) <= enclosingIndex+1 {
return enclosingStmt.(ast.Stmt)
}
// Check if the enclosing statement is inside another node.
switch expr := path[enclosingIndex+1].(type) {
case *ast.IfStmt:
// Get the base if statement.
return baseIfStmt(path, enclosingIndex+1)
case *ast.ForStmt:
if expr.Init == enclosingStmt || expr.Post == enclosingStmt {
return expr
}
}
return enclosingStmt.(ast.Stmt)
}
// baseIfStmt walks up the if/else-if chain until we get to
// the top of the current if chain.
func baseIfStmt(path []ast.Node, index int) ast.Stmt {
stmt := path[index]
for i := index + 1; i < len(path); i++ {
if node, ok := path[i].(*ast.IfStmt); ok && node.Else == stmt {
stmt = node
continue
}
break
}
return stmt.(ast.Stmt)
}

View File

@ -70,20 +70,9 @@ func run(pass *analysis.Pass) (interface{}, error) {
if _, ok := path[1].(*ast.CallExpr); ok {
continue
}
// Get the enclosing statement.
enclosingIndex := -1
for i, p := range path {
if _, ok := p.(ast.Stmt); ok && enclosingIndex == -1 {
enclosingIndex = i
break
}
}
if enclosingIndex == -1 {
continue
}
// Get the place to insert the new statement.
insertBeforeStmt := stmtToInsertVarBefore(path, enclosingIndex)
insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path)
if insertBeforeStmt == nil {
continue
}
@ -121,70 +110,6 @@ func run(pass *analysis.Pass) (interface{}, error) {
return nil, nil
}
// stmtToInsertVarBefore returns the ast.Stmt before which we can safely insert a new variable.
// Some examples:
//
// Basic Example:
// z := 1
// y := z + x
// If x is undeclared, then this function would return `y := z + x`, so that we
// can insert `x := ` on the line before `y := z + x`.
//
// If stmt example:
// if z == 1 {
// } else if z == y {}
// If y is undeclared, then this function would return `if z == 1 {`, because we cannot
// insert a statement between an if and an else if statement. As a result, we need to find
// the top of the if chain to insert `y := ` before.
func stmtToInsertVarBefore(path []ast.Node, enclosingIndex int) ast.Stmt {
enclosingStmt := path[enclosingIndex]
switch enclosingStmt.(type) {
case *ast.IfStmt:
// The enclosingStmt is inside of the if declaration,
// We need to check if we are in an else-if stmt and
// get the base if statement.
return baseIfStmt(path, enclosingIndex)
case *ast.CaseClause:
// Get the enclosing switch stmt if the enclosingStmt is
// inside of the case statement.
for i := enclosingIndex + 1; i < len(path); i++ {
if node, ok := path[i].(*ast.SwitchStmt); ok {
return node
} else if node, ok := path[i].(*ast.TypeSwitchStmt); ok {
return node
}
}
}
if len(path) <= enclosingIndex+1 {
return enclosingStmt.(ast.Stmt)
}
// Check if the enclosing statement is inside another node.
switch expr := path[enclosingIndex+1].(type) {
case *ast.IfStmt:
// Get the base if statement.
return baseIfStmt(path, enclosingIndex+1)
case *ast.ForStmt:
if expr.Init == enclosingStmt || expr.Post == enclosingStmt {
return expr
}
}
return enclosingStmt.(ast.Stmt)
}
// baseIfStmt walks up the if/else-if chain until we get to
// the top of the current if chain.
func baseIfStmt(path []ast.Node, index int) ast.Stmt {
stmt := path[index]
for i := index + 1; i < len(path); i++ {
if node, ok := path[i].(*ast.IfStmt); ok && node.Else == stmt {
stmt = node
continue
}
break
}
return stmt.(ast.Stmt)
}
func FixesError(msg string) bool {
return strings.HasPrefix(msg, undeclaredNamePrefix)
}

View File

@ -76,6 +76,10 @@ func (s *suggestedfix) Run(ctx context.Context, args ...string) error {
}
}
rng, err := file.mapper.Range(from)
if err != nil {
return err
}
p := protocol.CodeActionParams{
TextDocument: protocol.TextDocumentIdentifier{
URI: protocol.URIFromSpanURI(uri),
@ -84,6 +88,7 @@ func (s *suggestedfix) Run(ctx context.Context, args ...string) error {
Only: codeActionKinds,
Diagnostics: file.diagnostics,
},
Range: rng,
}
actions, err := conn.CodeAction(ctx, &p)
if err != nil {
@ -118,6 +123,15 @@ func (s *suggestedfix) Run(ctx context.Context, args ...string) error {
break
}
}
// If suggested fix is not a diagnostic, still must collect edits.
if len(a.Diagnostics) == 0 {
for _, c := range a.Edit.DocumentChanges {
if fileURI(c.TextDocument.URI) == uri {
edits = append(edits, c.Edits...)
}
}
}
}
sedits, err := source.FromProtocolEdits(file.mapper, edits)

View File

@ -162,6 +162,13 @@ func (s *Server) codeAction(ctx context.Context, params *protocol.CodeActionPara
}
codeActions = append(codeActions, fixes...)
}
if wanted[protocol.RefactorExtract] {
fixes, err := extractionFixes(ctx, snapshot, ph, uri, params.Range)
if err != nil {
return nil, err
}
codeActions = append(codeActions, fixes...)
}
default:
// Unsupported file kind for a code action.
return nil, nil
@ -385,6 +392,29 @@ func convenienceFixes(ctx context.Context, snapshot source.Snapshot, ph source.P
return codeActions, nil
}
func extractionFixes(ctx context.Context, snapshot source.Snapshot, ph source.PackageHandle, uri span.URI, rng protocol.Range) ([]protocol.CodeAction, error) {
fh, err := snapshot.GetFile(ctx, uri)
if err != nil {
return nil, nil
}
edits, err := source.ExtractVariable(ctx, snapshot, fh, rng)
if err != nil {
return nil, err
}
if len(edits) == 0 {
return nil, nil
}
return []protocol.CodeAction{
{
Title: "Extract to variable",
Kind: protocol.RefactorExtract,
Edit: protocol.WorkspaceEdit{
DocumentChanges: documentChanges(fh, edits),
},
},
}, nil
}
func documentChanges(fh source.FileHandle, edits []protocol.TextEdit) []protocol.TextDocumentEdit {
return []protocol.TextDocumentEdit{
{

View File

@ -955,21 +955,7 @@ func (c *completer) methodsAndFields(ctx context.Context, typ types.Type, addres
// lexical finds completions in the lexical environment.
func (c *completer) lexical(ctx context.Context) error {
var scopes []*types.Scope // scopes[i], where i<len(path), is the possibly nil Scope of path[i].
for _, n := range c.path {
// Include *FuncType scope if pos is inside the function body.
switch node := n.(type) {
case *ast.FuncDecl:
if node.Body != nil && nodeContains(node.Body, c.pos) {
n = node.Type
}
case *ast.FuncLit:
if node.Body != nil && nodeContains(node.Body, c.pos) {
n = node.Type
}
}
scopes = append(scopes, c.pkg.GetTypesInfo().Scopes[n])
}
scopes := collectScopes(c.pkg, c.path, c.pos)
scopes = append(scopes, c.pkg.GetTypes().Scope(), types.Universe)
var (
@ -1106,6 +1092,26 @@ func (c *completer) lexical(ctx context.Context) error {
return nil
}
func collectScopes(pkg Package, path []ast.Node, pos token.Pos) []*types.Scope {
// scopes[i], where i<len(path), is the possibly nil Scope of path[i].
var scopes []*types.Scope
for _, n := range path {
// Include *FuncType scope if pos is inside the function body.
switch node := n.(type) {
case *ast.FuncDecl:
if node.Body != nil && nodeContains(node.Body, pos) {
n = node.Type
}
case *ast.FuncLit:
if node.Body != nil && nodeContains(node.Body, pos) {
n = node.Type
}
}
scopes = append(scopes, pkg.GetTypesInfo().Scopes[n])
}
return scopes
}
func (c *completer) unimportedPackages(ctx context.Context, seen map[string]struct{}) error {
var prefix string
if c.surrounding != nil {

View File

@ -0,0 +1,140 @@
// Copyright 2020 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package source
import (
"bytes"
"context"
"fmt"
"go/ast"
"go/format"
"go/token"
"go/types"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/internal/analysisinternal"
"golang.org/x/tools/internal/lsp/protocol"
"golang.org/x/tools/internal/span"
)
func ExtractVariable(ctx context.Context, snapshot Snapshot, fh FileHandle, protoRng protocol.Range) ([]protocol.TextEdit, error) {
pkg, pgh, err := getParsedFile(ctx, snapshot, fh, NarrowestPackageHandle)
if err != nil {
return nil, fmt.Errorf("ExtractVariable: %v", err)
}
file, _, m, _, err := pgh.Cached()
if err != nil {
return nil, err
}
spn, err := m.RangeSpan(protoRng)
if err != nil {
return nil, err
}
rng, err := spn.Range(m.Converter)
if err != nil {
return nil, err
}
path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End)
if len(path) == 0 {
return nil, nil
}
fset := snapshot.View().Session().Cache().FileSet()
node := path[0]
tok := fset.File(node.Pos())
if tok == nil {
return nil, fmt.Errorf("ExtractVariable: no token.File for %s", fh.URI())
}
var content []byte
if content, err = fh.Read(); err != nil {
return nil, err
}
if rng.Start != node.Pos() || rng.End != node.End() {
return nil, nil
}
// Adjust new variable name until no collisons in scope.
scopes := collectScopes(pkg, path, node.Pos())
name := "x0"
idx := 0
for !isValidName(name, scopes) {
idx++
name = fmt.Sprintf("x%d", idx)
}
var assignment string
expr, ok := node.(ast.Expr)
if !ok {
return nil, nil
}
// Create new AST node for extracted code
switch expr.(type) {
case *ast.BasicLit, *ast.CompositeLit, *ast.IndexExpr,
*ast.SliceExpr, *ast.UnaryExpr, *ast.BinaryExpr, *ast.SelectorExpr: // TODO: stricter rules for selectorExpr
assignStmt := &ast.AssignStmt{
Lhs: []ast.Expr{ast.NewIdent(name)},
Tok: token.DEFINE,
Rhs: []ast.Expr{expr},
}
var buf bytes.Buffer
if err = format.Node(&buf, fset, assignStmt); err != nil {
return nil, err
}
assignment = buf.String()
case *ast.CallExpr: // TODO: find number of return values and do according actions.
return nil, nil
default:
return nil, nil
}
insertBeforeStmt := analysisinternal.StmtToInsertVarBefore(path)
if insertBeforeStmt == nil {
return nil, nil
}
// Convert token.Pos to protcol.Position
rng = span.NewRange(fset, insertBeforeStmt.Pos(), insertBeforeStmt.End())
spn, err = rng.Span()
if err != nil {
return nil, nil
}
beforeStmtStart, err := m.Position(spn.Start())
if err != nil {
return nil, nil
}
stmtBeforeRng := protocol.Range{
Start: beforeStmtStart,
End: beforeStmtStart,
}
// Calculate indentation for insertion
line := tok.Line(insertBeforeStmt.Pos())
lineOffset := tok.Offset(tok.LineStart(line))
stmtOffset := tok.Offset(insertBeforeStmt.Pos())
indent := content[lineOffset:stmtOffset] // space between these is indentation.
return []protocol.TextEdit{
{
Range: stmtBeforeRng,
NewText: assignment + "\n" + string(indent),
},
{
Range: protoRng,
NewText: name,
},
}, nil
}
// Check for variable collision in scope.
func isValidName(name string, scopes []*types.Scope) bool {
for _, scope := range scopes {
if scope == nil {
continue
}
if scope.Lookup(name) != nil {
return false
}
}
return true
}

View File

@ -94,6 +94,7 @@ func DefaultOptions() Options {
protocol.SourceOrganizeImports: true,
protocol.QuickFix: true,
protocol.RefactorRewrite: true,
protocol.RefactorExtract: true,
},
Mod: {
protocol.SourceOrganizeImports: true,

View File

@ -0,0 +1,6 @@
package extract
func _() {
var _ = 1 + 2 //@suggestedfix("1", "refactor.extract")
var _ = 3 + 4 //@suggestedfix("3 + 4", "refactor.extract")
}

View File

@ -0,0 +1,18 @@
-- suggestedfix_extract_basic_lit_4_10 --
package extract
func _() {
x0 := 1
var _ = x0 + 2 //@suggestedfix("1", "refactor.extract")
var _ = 3 + 4 //@suggestedfix("3 + 4", "refactor.extract")
}
-- suggestedfix_extract_basic_lit_5_10 --
package extract
func _() {
var _ = 1 + 2 //@suggestedfix("1", "refactor.extract")
x0 := 3 + 4
var _ = x0 //@suggestedfix("3 + 4", "refactor.extract")
}

View File

@ -0,0 +1,13 @@
package extract
import "go/ast"
func _() {
x0 := 0
if true {
y := ast.CompositeLit{} //@suggestedfix("ast.CompositeLit{}", "refactor.extract")
}
if true {
x1 := !false //@suggestedfix("!false", "refactor.extract")
}
}

View File

@ -0,0 +1,32 @@
-- suggestedfix_extract_scope_11_9 --
package extract
import "go/ast"
func _() {
x0 := 0
if true {
y := ast.CompositeLit{} //@suggestedfix("ast.CompositeLit{}", "refactor.extract")
}
if true {
x2 := !false
x1 := x2 //@suggestedfix("!false", "refactor.extract")
}
}
-- suggestedfix_extract_scope_8_8 --
package extract
import "go/ast"
func _() {
x0 := 0
if true {
x1 := ast.CompositeLit{}
y := x1 //@suggestedfix("ast.CompositeLit{}", "refactor.extract")
}
if true {
x1 := !false //@suggestedfix("!false", "refactor.extract")
}
}

View File

@ -11,7 +11,7 @@ DiagnosticsCount = 44
FoldingRangesCount = 2
FormatCount = 6
ImportCount = 8
SuggestedFixCount = 14
SuggestedFixCount = 18
DefinitionsCount = 53
TypeDefinitionsCount = 2
HighlightsCount = 69

View File

@ -217,6 +217,7 @@ func DefaultOptions() source.Options {
protocol.SourceOrganizeImports: true,
protocol.QuickFix: true,
protocol.RefactorRewrite: true,
protocol.RefactorExtract: true,
protocol.SourceFixAll: true,
},
source.Mod: {