diff --git a/internal/lsp/source/command.go b/internal/lsp/source/command.go index 0103ab49b8..211f562329 100644 --- a/internal/lsp/source/command.go +++ b/internal/lsp/source/command.go @@ -105,7 +105,10 @@ var ( Name: "extract_variable", Title: "Extract to variable", suggestedFixFn: extractVariable, - appliesFn: canExtractVariable, + appliesFn: func(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) bool { + _, _, ok, _ := canExtractVariable(fset, rng, src, file, pkg, info) + return ok + }, } // CommandExtractFunction extracts statements to a function. diff --git a/internal/lsp/source/extract.go b/internal/lsp/source/extract.go index 753a713b1f..f0acf8a37e 100644 --- a/internal/lsp/source/extract.go +++ b/internal/lsp/source/extract.go @@ -22,22 +22,12 @@ import ( ) func extractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (*analysis.SuggestedFix, error) { - if rng.Start == rng.End { - return nil, fmt.Errorf("extractVariable: start and end are equal (%v)", fset.Position(rng.Start)) - } - path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) - if len(path) == 0 { - return nil, fmt.Errorf("extractVariable: no path enclosing interval") - } - node := path[0] - if rng.Start != node.Pos() || rng.End != node.End() { - return nil, fmt.Errorf("extractVariable: node doesn't perfectly enclose range") - } - expr, ok := node.(ast.Expr) + expr, path, ok, err := canExtractVariable(fset, rng, src, file, pkg, info) if !ok { - return nil, fmt.Errorf("extractVariable: node is not an expression") + return nil, fmt.Errorf("extractVariable: cannot extract %s: %v", fset.Position(rng.Start), err) } - name := generateAvailableIdentifier(node.Pos(), file, path, info, "x", 0) + + name := generateAvailableIdentifier(expr.Pos(), file, path, info, "x", 0) // Create new AST node for extracted code. var assignment string @@ -65,7 +55,7 @@ func extractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast. return nil, nil } - tok := fset.File(node.Pos()) + tok := fset.File(expr.Pos()) if tok == nil { return nil, nil } @@ -88,22 +78,23 @@ func extractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast. // canExtractVariable reports whether the code in the given range can be // extracted to a variable. -// TODO(rstambler): De-duplicate the logic between extractVariable and -// canExtractVariable. -func canExtractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) bool { +func canExtractVariable(fset *token.FileSet, rng span.Range, src []byte, file *ast.File, pkg *types.Package, info *types.Info) (ast.Expr, []ast.Node, bool, error) { if rng.Start == rng.End { - return false + return nil, nil, false, fmt.Errorf("start and end are equal") } path, _ := astutil.PathEnclosingInterval(file, rng.Start, rng.End) if len(path) == 0 { - return false + return nil, nil, false, fmt.Errorf("no path enclosing interval") } node := path[0] if rng.Start != node.Pos() || rng.End != node.End() { - return false + return nil, nil, false, fmt.Errorf("range does not map to an AST node") } - _, ok := node.(ast.Expr) - return ok + expr, ok := node.(ast.Expr) + if !ok { + return nil, nil, false, fmt.Errorf("node is not an expression") + } + return expr, path, true, nil } // Calculate indentation for insertion. @@ -415,7 +406,8 @@ func extractFunction(fset *token.FileSet, rng span.Range, src []byte, file *ast. hasReturnValues := len(returns)+len(retVars) > 0 if hasReturnValues { extractedBlock.List = append(extractedBlock.List, &ast.ReturnStmt{ - Results: append(returns, getZeroVals(retVars)...)}) + Results: append(returns, getZeroVals(retVars)...), + }) } // Construct the appropriate call to the extracted function.