1
0
mirror of https://github.com/golang/go synced 2024-09-30 14:38:33 -06:00

go.tools/refactor/eg: an example-based refactoring tool.

See refactor/eg/eg.go for details.

LGTM=crawshaw
R=crawshaw, gri, kamil.kisiel, josharian
CC=golang-codereviews
https://golang.org/cl/81010043
This commit is contained in:
Alan Donovan 2014-04-02 12:24:55 -04:00
parent d7048bec64
commit bfcffc697d
28 changed files with 1512 additions and 0 deletions

121
cmd/eg/eg.go Normal file
View File

@ -0,0 +1,121 @@
// The eg command performs example-based refactoring.
package main
import (
"flag"
"fmt"
"go/parser"
"go/printer"
"go/token"
"os"
"path/filepath"
"code.google.com/p/go.tools/go/loader"
"code.google.com/p/go.tools/refactor/eg"
)
var (
helpFlag = flag.Bool("help", false, "show detailed help message")
templateFlag = flag.String("t", "", "template.go file specifying the refactoring")
transitiveFlag = flag.Bool("transitive", false, "apply refactoring to all dependencies too")
writeFlag = flag.Bool("w", false, "rewrite input files in place (by default, the results are printed to standard output)")
verboseFlag = flag.Bool("v", false, "show verbose matcher diagnostics")
)
const usage = `eg: an example-based refactoring tool.
Usage: eg -t template.go [-w] [-transitive] <args>...
-t template.go specifies the template file (use -help to see explanation)
-w causes files to be re-written in place.
-transitive causes all dependencies to be refactored too.
` + loader.FromArgsUsage
func main() {
if err := doMain(); err != nil {
fmt.Fprintf(os.Stderr, "%s: %s.\n", filepath.Base(os.Args[0]), err)
os.Exit(1)
}
}
func doMain() error {
flag.Parse()
args := flag.Args()
if *helpFlag {
fmt.Fprintf(os.Stderr, eg.Help)
os.Exit(2)
}
if *templateFlag == "" {
return fmt.Errorf("no -t template.go file specified")
}
conf := loader.Config{
Fset: token.NewFileSet(),
ParserMode: parser.ParseComments,
SourceImports: true,
}
// The first Created package is the template.
if err := conf.CreateFromFilenames("template", *templateFlag); err != nil {
return err // e.g. "foo.go:1: syntax error"
}
if len(args) == 0 {
fmt.Fprint(os.Stderr, usage)
os.Exit(1)
}
if _, err := conf.FromArgs(args, true); err != nil {
return err
}
// Load, parse and type-check the whole program.
iprog, err := conf.Load()
if err != nil {
return err
}
// Analyze the template.
template := iprog.Created[0]
xform, err := eg.NewTransformer(iprog.Fset, template, *verboseFlag)
if err != nil {
return err
}
// Apply it to the input packages.
var pkgs []*loader.PackageInfo
if *transitiveFlag {
for _, info := range iprog.AllPackages {
pkgs = append(pkgs, info)
}
} else {
pkgs = iprog.InitialPackages()
}
var hadErrors bool
for _, pkg := range pkgs {
if pkg == template {
continue
}
for _, file := range pkg.Files {
n := xform.Transform(&pkg.Info, pkg.Pkg, file)
if n == 0 {
continue
}
filename := iprog.Fset.File(file.Pos()).Name()
fmt.Fprintf(os.Stderr, "=== %s (%d matches):\n", filename, n)
if *writeFlag {
if err := eg.WriteAST(iprog.Fset, filename, file); err != nil {
fmt.Fprintf(os.Stderr, "Error: %s\n", err)
hadErrors = true
}
} else {
printer.Fprint(os.Stdout, iprog.Fset, file)
}
}
}
if hadErrors {
os.Exit(1)
}
return nil
}

1
refactor/README Normal file
View File

@ -0,0 +1 @@
code.google.com/p/go.tools/refactor: libraries for refactoring tools.

326
refactor/eg/eg.go Normal file
View File

@ -0,0 +1,326 @@
// Package eg implements the example-based refactoring tool whose
// command-line is defined in code.google.com/p/go.tools/cmd/eg.
package eg
import (
"bytes"
"fmt"
"go/ast"
"go/printer"
"go/token"
"os"
"code.google.com/p/go.tools/go/loader"
"code.google.com/p/go.tools/go/types"
)
const Help = `
This tool implements example-based refactoring of expressions.
The transformation is specified as a Go file defining two functions,
'before' and 'after', of identical types. Each function body consists
of a single statement: either a return statement with a single
(possibly multi-valued) expression, or an expression statement. The
'before' expression specifies a pattern and the 'after' expression its
replacement.
package P
import ( "errors"; "fmt" )
func before(s string) error { return fmt.Errorf("%s", s) }
func after(s string) error { return errors.New(s) }
The expression statement form is useful when the the expression has no
result, for example:
func before(msg string) { log.Fatalf("%s", msg) }
func after(msg string) { log.Fatal(msg) }
The parameters of both functions are wildcards that may match any
expression assignable to that type. If the pattern contains multiple
occurrences of the same parameter, each must match the same expression
in the input for the pattern to match. If the replacement contains
multiple occurrences of the same parameter, the expression will be
duplicated, possibly changing the side-effects.
The tool analyses all Go code in the packages specified by the
arguments, replacing all occurrences of the pattern with the
substitution.
So, the transform above would change this input:
err := fmt.Errorf("%s", "error: " + msg)
to this output:
err := errors.New("error: " + msg)
Identifiers, including qualified identifiers (p.X) are considered to
match only if they denote the same object. This allows correct
matching even in the presence of dot imports, named imports and
locally shadowed package names in the input program.
Matching of type syntax is semantic, not syntactic: type syntax in the
pattern matches type syntax in the input if the types are identical.
Thus, func(x int) matches func(y int).
This tool was inspired by other example-based refactoring tools,
'gofmt -r' for Go and Refaster for Java.
LIMITATIONS
===========
EXPRESSIVENESS
Only refactorings that replace one expression with another, regardless
of the expression's context, may be expressed. Refactoring arbitrary
statements (or sequences of statements) is a less well-defined problem
and is less amenable to this approach.
A pattern that contains a function literal (and hence statements)
never matches.
There is no way to generalize over related types, e.g. to express that
a wildcard may have any integer type, for example.
SAFETY
Verifying that a transformation does not introduce type errors is very
complex in the general case. An innocuous-looking replacement of one
constant by another (e.g. 1 to 2) may cause type errors relating to
array types and indices, for example. The tool performs only very
superficial checks of type preservation.
It is not possible to replace an expression by one of a different
type, even in contexts where this is legal, such as x in fmt.Print(x).
IMPORTS
Although the matching algorithm is fully aware of scoping rules, the
replacement algorithm is not, so the replacement code may contain
incorrect identifier syntax for imported objects if there are dot
imports, named imports or locally shadowed package names in the input
program.
Imports are added as needed, but they are not removed as needed.
Run 'goimports' on the modified file for now.
Dot imports are forbidden in the template.
`
// TODO(adonovan): allow the tool to be invoked using relative package
// directory names (./foo). Requires changes to go/loader.
// TODO(adonovan): expand upon the above documentation as an HTML page.
// TODO(adonovan): eliminate dependency on loader.PackageInfo.
// Move its ObjectOf/IsType/TypeOf methods into go/types.
// A Transformer represents a single example-based transformation.
type Transformer struct {
fset *token.FileSet
verbose bool
info loader.PackageInfo // combined type info for template/input/output ASTs
seenInfos map[*types.Info]bool
wildcards map[*types.Var]bool // set of parameters in func before()
env map[string]ast.Expr // maps parameter name to wildcard binding
importedObjs map[types.Object]*ast.SelectorExpr // objects imported by after().
before, after ast.Expr
allowWildcards bool
// Working state of Transform():
nsubsts int // number of substitutions made
currentPkg *types.Package // package of current call
}
// NewTransformer returns a transformer based on the specified template,
// a package containing "before" and "after" functions as described
// in the package documentation.
//
func NewTransformer(fset *token.FileSet, template *loader.PackageInfo, verbose bool) (*Transformer, error) {
// Check the template.
beforeSig := funcSig(template.Pkg, "before")
if beforeSig == nil {
return nil, fmt.Errorf("no 'before' func found in template")
}
afterSig := funcSig(template.Pkg, "after")
if afterSig == nil {
return nil, fmt.Errorf("no 'after' func found in template")
}
// TODO(adonovan): should we also check the names of the params match?
if !types.Identical(afterSig, beforeSig) {
return nil, fmt.Errorf("before %s and after %s functions have different signatures",
beforeSig, afterSig)
}
templateFile := template.Files[0]
for _, imp := range templateFile.Imports {
if imp.Name != nil && imp.Name.Name == "." {
// Dot imports are currently forbidden. We
// make the simplifying assumption that all
// imports are regular, without local renames.
//TODO document
return nil, fmt.Errorf("dot-import (of %s) in template", imp.Path.Value)
}
}
var beforeDecl, afterDecl *ast.FuncDecl
for _, decl := range templateFile.Decls {
if decl, ok := decl.(*ast.FuncDecl); ok {
switch decl.Name.Name {
case "before":
beforeDecl = decl
case "after":
afterDecl = decl
}
}
}
before, err := soleExpr(beforeDecl)
if err != nil {
return nil, fmt.Errorf("before: %s", err)
}
after, err := soleExpr(afterDecl)
if err != nil {
return nil, fmt.Errorf("after: %s", err)
}
wildcards := make(map[*types.Var]bool)
for i := 0; i < beforeSig.Params().Len(); i++ {
wildcards[beforeSig.Params().At(i)] = true
}
// checkExprTypes returns an error if Tb (type of before()) is not
// safe to replace with Ta (type of after()).
//
// Only superficial checks are performed, and they may result in both
// false positives and negatives.
//
// Ideally, we would only require that the replacement be assignable
// to the context of a specific pattern occurrence, but the type
// checker doesn't record that information and it's complex to deduce.
// A Go type cannot capture all the constraints of a given expression
// context, which may include the size, constness, signedness,
// namedness or constructor of its type, and even the specific value
// of the replacement. (Consider the rule that array literal keys
// must be unique.) So we cannot hope to prove the safety of a
// transformation in general.
Tb := template.TypeOf(before)
Ta := template.TypeOf(after)
if types.AssignableTo(Tb, Ta) {
// safe: replacement is assignable to pattern.
} else if tuple, ok := Tb.(*types.Tuple); ok && tuple.Len() == 0 {
// safe: pattern has void type (must appear in an ExprStmt).
} else {
return nil, fmt.Errorf("%s is not a safe replacement for %s", Ta, Tb)
}
tr := &Transformer{
fset: fset,
verbose: verbose,
wildcards: wildcards,
allowWildcards: true,
seenInfos: make(map[*types.Info]bool),
importedObjs: make(map[types.Object]*ast.SelectorExpr),
before: before,
after: after,
}
// Combine type info from the template and input packages, and
// type info for the synthesized ASTs too. This saves us
// having to book-keep where each ast.Node originated as we
// construct the resulting hybrid AST.
//
// TODO(adonovan): move type utility methods of PackageInfo to
// types.Info, or at least into go/types.typeutil.
tr.info.Info = types.Info{
Types: make(map[ast.Expr]types.TypeAndValue),
Defs: make(map[*ast.Ident]types.Object),
Uses: make(map[*ast.Ident]types.Object),
Selections: make(map[*ast.SelectorExpr]*types.Selection),
}
mergeTypeInfo(&tr.info.Info, &template.Info)
// Compute set of imported objects required by after().
// TODO reject dot-imports in pattern
ast.Inspect(after, func(n ast.Node) bool {
if n, ok := n.(*ast.SelectorExpr); ok {
sel := tr.info.Selections[n]
if sel.Kind() == types.PackageObj {
tr.importedObjs[sel.Obj()] = n
return false // prune
}
}
return true // recur
})
return tr, nil
}
// WriteAST is a convenience function that writes AST f to the specified file.
func WriteAST(fset *token.FileSet, filename string, f *ast.File) (err error) {
fh, err := os.Create(filename)
if err != nil {
return err
}
defer func() {
if err2 := fh.Close(); err != nil {
err = err2 // prefer earlier error
}
}()
return printer.Fprint(fh, fset, f)
}
// -- utilities --------------------------------------------------------
// funcSig returns the signature of the specified package-level function.
func funcSig(pkg *types.Package, name string) *types.Signature {
if f, ok := pkg.Scope().Lookup(name).(*types.Func); ok {
return f.Type().(*types.Signature)
}
return nil
}
// soleExpr returns the sole expression in the before/after template function.
func soleExpr(fn *ast.FuncDecl) (ast.Expr, error) {
if fn.Body == nil {
return nil, fmt.Errorf("no body")
}
if len(fn.Body.List) != 1 {
return nil, fmt.Errorf("must contain a single statement")
}
switch stmt := fn.Body.List[0].(type) {
case *ast.ReturnStmt:
if len(stmt.Results) != 1 {
return nil, fmt.Errorf("return statement must have a single operand")
}
return stmt.Results[0], nil
case *ast.ExprStmt:
return stmt.X, nil
}
return nil, fmt.Errorf("must contain a single return or expression statement")
}
// mergeTypeInfo adds type info from src to dst.
func mergeTypeInfo(dst, src *types.Info) {
for k, v := range src.Types {
dst.Types[k] = v
}
for k, v := range src.Defs {
dst.Defs[k] = v
}
for k, v := range src.Uses {
dst.Uses[k] = v
}
for k, v := range src.Selections {
dst.Selections[k] = v
}
}
// (debugging only)
func astString(fset *token.FileSet, n ast.Node) string {
var buf bytes.Buffer
printer.Fprint(&buf, fset, n)
return buf.String()
}

145
refactor/eg/eg_test.go Normal file
View File

@ -0,0 +1,145 @@
package eg_test
import (
"bytes"
"flag"
"go/parser"
"go/token"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"testing"
"code.google.com/p/go.tools/go/exact"
"code.google.com/p/go.tools/go/loader"
"code.google.com/p/go.tools/go/types"
"code.google.com/p/go.tools/refactor/eg"
)
// TODO(adonovan): more tests:
// - of command-line tool
// - of all parts of syntax
// - of applying a template to a package it imports:
// the replacement syntax should use unqualified names for its objects.
var (
updateFlag = flag.Bool("update", false, "update the golden files")
verboseFlag = flag.Bool("verbose", false, "show matcher information")
)
func Test(t *testing.T) {
switch runtime.GOOS {
case "windows":
t.Skipf("skipping test on %q (no /usr/bin/diff)", runtime.GOOS)
}
conf := loader.Config{
Fset: token.NewFileSet(),
ParserMode: parser.ParseComments,
SourceImports: true,
}
// Each entry is a single-file package.
// (Multi-file packages aren't interesting for this test.)
// Order matters: each non-template package is processed using
// the preceding template package.
for _, filename := range []string{
"testdata/A.template",
"testdata/A1.go",
"testdata/A2.go",
"testdata/B.template",
"testdata/B1.go",
"testdata/C.template",
"testdata/C1.go",
"testdata/D.template",
"testdata/D1.go",
"testdata/E.template",
"testdata/E1.go",
"testdata/bad_type.template",
"testdata/no_before.template",
"testdata/no_after_return.template",
"testdata/type_mismatch.template",
"testdata/expr_type_mismatch.template",
} {
pkgname := strings.TrimSuffix(filepath.Base(filename), ".go")
if err := conf.CreateFromFilenames(pkgname, filename); err != nil {
t.Fatal(err)
}
}
iprog, err := conf.Load()
if err != nil {
t.Fatal(err)
}
var xform *eg.Transformer
for _, info := range iprog.Created {
file := info.Files[0]
filename := iprog.Fset.File(file.Pos()).Name() // foo.go
if strings.HasSuffix(filename, "template") {
// a new template
shouldFail, _ := info.Pkg.Scope().Lookup("shouldFail").(*types.Const)
xform, err = eg.NewTransformer(iprog.Fset, info, *verboseFlag)
if err != nil {
if shouldFail == nil {
t.Errorf("NewTransformer(%s): %s", filename, err)
} else if want := exact.StringVal(shouldFail.Val()); !strings.Contains(err.Error(), want) {
t.Errorf("NewTransformer(%s): got error %q, want error %q", filename, err, want)
}
} else if shouldFail != nil {
t.Errorf("NewTransformer(%s) succeeded unexpectedly; want error %q",
filename, shouldFail.Val())
}
continue
}
if xform == nil {
t.Errorf("%s: no previous template", filename)
continue
}
// apply previous template to this package
n := xform.Transform(&info.Info, info.Pkg, file)
if n == 0 {
t.Errorf("%s: no matches", filename)
continue
}
got := filename + "t" // foo.got
golden := filename + "lden" // foo.golden
// Write actual output to foo.got.
if err := eg.WriteAST(iprog.Fset, got, file); err != nil {
t.Error(err)
}
// Compare foo.got with foo.golden.
var cmd *exec.Cmd
switch runtime.GOOS {
case "plan9":
cmd = exec.Command("/bin/diff", "-c", golden, got)
default:
cmd = exec.Command("/usr/bin/diff", "-u", "-N", golden, got)
}
buf := new(bytes.Buffer)
cmd.Stdout = buf
cmd.Stderr = os.Stderr
if err := cmd.Run(); err != nil {
t.Errorf("eg tests for %s failed: %s.\n%s\n", filename, err, buf)
if *updateFlag {
t.Logf("Updating %s...", golden)
if err := exec.Command("/bin/cp", got, golden).Run(); err != nil {
t.Errorf("Update failed: %s", err)
}
}
}
}
}

226
refactor/eg/match.go Normal file
View File

@ -0,0 +1,226 @@
package eg
import (
"fmt"
"go/ast"
"go/token"
"log"
"os"
"reflect"
"code.google.com/p/go.tools/go/exact"
"code.google.com/p/go.tools/go/loader"
"code.google.com/p/go.tools/go/types"
)
// matchExpr reports whether pattern x matches y.
//
// If tr.allowWildcards, Idents in x that refer to parameters are
// treated as wildcards, and match any y that is assignable to the
// parameter type; matchExpr records this correspondence in tr.env.
// Otherwise, matchExpr simply reports whether the two trees are
// equivalent.
//
// A wildcard appearing more than once in the pattern must
// consistently match the same tree.
//
func (tr *Transformer) matchExpr(x, y ast.Expr) bool {
if x == nil && y == nil {
return true
}
if x == nil || y == nil {
return false
}
x = unparen(x)
y = unparen(y)
// Is x a wildcard? (a reference to a 'before' parameter)
if x, ok := x.(*ast.Ident); ok && x != nil && tr.allowWildcards {
if xobj, ok := tr.info.Uses[x].(*types.Var); ok && tr.wildcards[xobj] {
return tr.matchWildcard(xobj, y)
}
}
// Object identifiers (including pkg-qualified ones)
// are handled semantically, not syntactically.
xobj := isRef(x, &tr.info)
yobj := isRef(y, &tr.info)
if xobj != nil {
return xobj == yobj
}
if yobj != nil {
return false
}
// TODO(adonovan): audit: we cannot assume these ast.Exprs
// contain non-nil pointers. e.g. ImportSpec.Name may be a
// nil *ast.Ident.
if reflect.TypeOf(x) != reflect.TypeOf(y) {
return false
}
switch x := x.(type) {
case *ast.Ident:
log.Fatalf("unexpected Ident: %s", astString(tr.fset, x))
case *ast.BasicLit:
y := y.(*ast.BasicLit)
xval := exact.MakeFromLiteral(x.Value, x.Kind)
yval := exact.MakeFromLiteral(y.Value, y.Kind)
return exact.Compare(xval, token.EQL, yval)
case *ast.FuncLit:
// func literals (and thus statement syntax) never match.
return false
case *ast.CompositeLit:
y := y.(*ast.CompositeLit)
return (x.Type == nil) == (y.Type == nil) &&
(x.Type == nil || tr.matchType(x.Type, y.Type)) &&
tr.matchExprs(x.Elts, y.Elts)
case *ast.SelectorExpr:
y := y.(*ast.SelectorExpr)
return tr.matchExpr(x.X, y.X) &&
tr.info.Selections[x].Obj() == tr.info.Selections[y].Obj()
case *ast.IndexExpr:
y := y.(*ast.IndexExpr)
return tr.matchExpr(x.X, y.X) &&
tr.matchExpr(x.Index, y.Index)
case *ast.SliceExpr:
y := y.(*ast.SliceExpr)
return tr.matchExpr(x.X, y.X) &&
tr.matchExpr(x.Low, y.Low) &&
tr.matchExpr(x.High, y.High) &&
tr.matchExpr(x.Max, y.Max) &&
x.Slice3 == y.Slice3
case *ast.TypeAssertExpr:
y := y.(*ast.TypeAssertExpr)
return tr.matchExpr(x.X, y.X) &&
tr.matchType(x.Type, y.Type)
case *ast.CallExpr:
y := y.(*ast.CallExpr)
match := tr.matchExpr // function call
if tr.info.IsType(x.Fun) {
match = tr.matchType // type conversion
}
return x.Ellipsis.IsValid() == y.Ellipsis.IsValid() &&
match(x.Fun, y.Fun) &&
tr.matchExprs(x.Args, y.Args)
case *ast.StarExpr:
y := y.(*ast.StarExpr)
return tr.matchExpr(x.X, y.X)
case *ast.UnaryExpr:
y := y.(*ast.UnaryExpr)
return x.Op == y.Op &&
tr.matchExpr(x.X, y.X)
case *ast.BinaryExpr:
y := y.(*ast.BinaryExpr)
return x.Op == y.Op &&
tr.matchExpr(x.X, y.X) &&
tr.matchExpr(x.Y, y.Y)
case *ast.KeyValueExpr:
y := y.(*ast.KeyValueExpr)
return tr.matchExpr(x.Key, y.Key) &&
tr.matchExpr(x.Value, y.Value)
}
panic(fmt.Sprintf("unhandled AST node type: %T", x))
}
func (tr *Transformer) matchExprs(xx, yy []ast.Expr) bool {
if len(xx) != len(yy) {
return false
}
for i := range xx {
if !tr.matchExpr(xx[i], yy[i]) {
return false
}
}
return true
}
// matchType reports whether the two type ASTs denote identical types.
func (tr *Transformer) matchType(x, y ast.Expr) bool {
tx := tr.info.Types[x].Type
ty := tr.info.Types[y].Type
return types.Identical(tx, ty)
}
func (tr *Transformer) matchWildcard(xobj *types.Var, y ast.Expr) bool {
name := xobj.Name()
if tr.verbose {
fmt.Fprintf(os.Stderr, "%s: wildcard %s -> %s?: ",
tr.fset.Position(y.Pos()), name, astString(tr.fset, y))
}
// Check that y is assignable to the declared type of the param.
if yt := tr.info.TypeOf(y); !types.AssignableTo(yt, xobj.Type()) {
if tr.verbose {
fmt.Fprintf(os.Stderr, "%s not assignable to %s\n", yt, xobj.Type())
}
return false
}
// A wildcard matches any expression.
// If it appears multiple times in the pattern, it must match
// the same expression each time.
if old, ok := tr.env[name]; ok {
// found existing binding
tr.allowWildcards = false
r := tr.matchExpr(old, y)
if tr.verbose {
fmt.Fprintf(os.Stderr, "%t secondary match, primary was %s\n",
r, astString(tr.fset, old))
}
tr.allowWildcards = true
return r
}
if tr.verbose {
fmt.Fprintf(os.Stderr, "primary match\n")
}
tr.env[name] = y // record binding
return true
}
// -- utilities --------------------------------------------------------
// unparen returns e with any enclosing parentheses stripped.
// TODO(adonovan): move to astutil package.
func unparen(e ast.Expr) ast.Expr {
for {
p, ok := e.(*ast.ParenExpr)
if !ok {
break
}
e = p.X
}
return e
}
// isRef returns the object referred to by this (possibly qualified)
// identifier, or nil if the node is not a referring identifier.
func isRef(n ast.Node, info *loader.PackageInfo) types.Object {
switch n := n.(type) {
case *ast.Ident:
return info.Uses[n]
case *ast.SelectorExpr:
sel := info.Selections[n]
if sel.Kind() == types.PackageObj {
return sel.Obj()
}
}
return nil
}

347
refactor/eg/rewrite.go Normal file
View File

@ -0,0 +1,347 @@
package eg
// This file defines the AST rewriting pass.
// Most of it was plundered directly from
// $GOROOT/src/cmd/gofmt/rewrite.go (after convergent evolution).
import (
"fmt"
"go/ast"
"go/token"
"os"
"reflect"
"sort"
"strconv"
"strings"
"code.google.com/p/go.tools/astutil"
"code.google.com/p/go.tools/go/types"
)
// Transform applies the transformation to the specified parsed file,
// whose type information is supplied in info, and returns the number
// of replacements that were made.
//
// It mutates the AST in place (the identity of the root node is
// unchanged), and may add nodes for which no type information is
// available in info.
//
// Derived from rewriteFile in $GOROOT/src/cmd/gofmt/rewrite.go.
//
func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast.File) int {
if !tr.seenInfos[info] {
tr.seenInfos[info] = true
mergeTypeInfo(&tr.info.Info, info)
}
tr.currentPkg = pkg
tr.nsubsts = 0
if tr.verbose {
fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
}
var f func(rv reflect.Value) reflect.Value
f = func(rv reflect.Value) reflect.Value {
// don't bother if val is invalid to start with
if !rv.IsValid() {
return reflect.Value{}
}
rv = apply(f, rv)
e := rvToExpr(rv)
if e != nil {
savedEnv := tr.env
tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs
if tr.matchExpr(tr.before, e) {
if tr.verbose {
fmt.Fprintf(os.Stderr, "%s matches %s",
astString(tr.fset, tr.before), astString(tr.fset, e))
if len(tr.env) > 0 {
fmt.Fprintf(os.Stderr, " with:")
for name, ast := range tr.env {
fmt.Fprintf(os.Stderr, " %s->%s",
name, astString(tr.fset, ast))
}
}
fmt.Fprintf(os.Stderr, "\n")
}
tr.nsubsts++
// Clone the replacement tree, performing parameter substitution.
// We update all positions to n.Pos() to aid comment placement.
rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
reflect.ValueOf(e.Pos()))
}
tr.env = savedEnv
}
return rv
}
file2 := apply(f, reflect.ValueOf(file)).Interface().(*ast.File)
// By construction, the root node is unchanged.
if file != file2 {
panic("BUG")
}
// Add any necessary imports.
// TODO(adonovan): remove no-longer needed imports too.
if tr.nsubsts > 0 {
pkgs := make(map[string]*types.Package)
for obj := range tr.importedObjs {
pkgs[obj.Pkg().Path()] = obj.Pkg()
}
for _, imp := range file.Imports {
path, _ := strconv.Unquote(imp.Path.Value)
delete(pkgs, path)
}
delete(pkgs, pkg.Path()) // don't import self
// NB: AddImport may completely replace the AST!
// It thus renders info and tr.info no longer relevant to file.
var paths []string
for path := range pkgs {
paths = append(paths, path)
}
sort.Strings(paths)
for _, path := range paths {
astutil.AddImport(tr.fset, file, path)
}
}
tr.currentPkg = nil
return tr.nsubsts
}
// setValue is a wrapper for x.SetValue(y); it protects
// the caller from panics if x cannot be changed to y.
func setValue(x, y reflect.Value) {
// don't bother if y is invalid to start with
if !y.IsValid() {
return
}
defer func() {
if x := recover(); x != nil {
if s, ok := x.(string); ok &&
(strings.Contains(s, "type mismatch") || strings.Contains(s, "not assignable")) {
// x cannot be set to y - ignore this rewrite
return
}
panic(x)
}
}()
x.Set(y)
}
// Values/types for special cases.
var (
objectPtrNil = reflect.ValueOf((*ast.Object)(nil))
scopePtrNil = reflect.ValueOf((*ast.Scope)(nil))
identType = reflect.TypeOf((*ast.Ident)(nil))
selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))
objectPtrType = reflect.TypeOf((*ast.Object)(nil))
positionType = reflect.TypeOf(token.NoPos)
callExprType = reflect.TypeOf((*ast.CallExpr)(nil))
scopePtrType = reflect.TypeOf((*ast.Scope)(nil))
)
// apply replaces each AST field x in val with f(x), returning val.
// To avoid extra conversions, f operates on the reflect.Value form.
func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
if !val.IsValid() {
return reflect.Value{}
}
// *ast.Objects introduce cycles and are likely incorrect after
// rewrite; don't follow them but replace with nil instead
if val.Type() == objectPtrType {
return objectPtrNil
}
// similarly for scopes: they are likely incorrect after a rewrite;
// replace them with nil
if val.Type() == scopePtrType {
return scopePtrNil
}
switch v := reflect.Indirect(val); v.Kind() {
case reflect.Slice:
for i := 0; i < v.Len(); i++ {
e := v.Index(i)
setValue(e, f(e))
}
case reflect.Struct:
for i := 0; i < v.NumField(); i++ {
e := v.Field(i)
setValue(e, f(e))
}
case reflect.Interface:
e := v.Elem()
setValue(v, f(e))
}
return val
}
// subst returns a copy of (replacement) pattern with values from env
// substituted in place of wildcards and pos used as the position of
// tokens from the pattern. if env == nil, subst returns a copy of
// pattern and doesn't change the line number information.
func (tr *Transformer) subst(env map[string]ast.Expr, pattern, pos reflect.Value) reflect.Value {
if !pattern.IsValid() {
return reflect.Value{}
}
// *ast.Objects introduce cycles and are likely incorrect after
// rewrite; don't follow them but replace with nil instead
if pattern.Type() == objectPtrType {
return objectPtrNil
}
// similarly for scopes: they are likely incorrect after a rewrite;
// replace them with nil
if pattern.Type() == scopePtrType {
return scopePtrNil
}
// Wildcard gets replaced with map value.
if env != nil && pattern.Type() == identType {
id := pattern.Interface().(*ast.Ident)
if old, ok := env[id.Name]; ok {
return tr.subst(nil, reflect.ValueOf(old), reflect.Value{})
}
}
// Emit qualified identifiers in the pattern by appropriate
// (possibly qualified) identifier in the input.
//
// The template cannot contain dot imports, so all identifiers
// for imported objects are explicitly qualified.
//
// We assume (unsoundly) that there are no dot or named
// imports in the input code, nor are any imported package
// names shadowed, so the usual normal qualified identifier
// syntax may be used.
// TODO(adonovan): fix: avoid this assumption.
//
// A refactoring may be applied to a package referenced by the
// template. Objects belonging to the current package are
// denoted by unqualified identifiers.
//
if tr.importedObjs != nil && pattern.Type() == selectorExprType {
obj := isRef(pattern.Interface().(*ast.SelectorExpr), &tr.info)
if obj != nil {
if sel, ok := tr.importedObjs[obj]; ok {
var id ast.Expr
if obj.Pkg() == tr.currentPkg {
id = sel.Sel // unqualified
} else {
id = sel // pkg-qualified
}
// Return a clone of id.
saved := tr.importedObjs
tr.importedObjs = nil // break cycle
r := tr.subst(nil, reflect.ValueOf(id), pos)
tr.importedObjs = saved
return r
}
}
}
if pos.IsValid() && pattern.Type() == positionType {
// use new position only if old position was valid in the first place
if old := pattern.Interface().(token.Pos); !old.IsValid() {
return pattern
}
return pos
}
// Otherwise copy.
switch p := pattern; p.Kind() {
case reflect.Slice:
v := reflect.MakeSlice(p.Type(), p.Len(), p.Len())
for i := 0; i < p.Len(); i++ {
v.Index(i).Set(tr.subst(env, p.Index(i), pos))
}
return v
case reflect.Struct:
v := reflect.New(p.Type()).Elem()
for i := 0; i < p.NumField(); i++ {
v.Field(i).Set(tr.subst(env, p.Field(i), pos))
}
return v
case reflect.Ptr:
v := reflect.New(p.Type()).Elem()
if elem := p.Elem(); elem.IsValid() {
v.Set(tr.subst(env, elem, pos).Addr())
}
// Duplicate type information for duplicated ast.Expr.
// All ast.Node implementations are *structs,
// so this case catches them all.
if e := rvToExpr(v); e != nil {
updateTypeInfo(&tr.info.Info, e, p.Interface().(ast.Expr))
}
return v
case reflect.Interface:
v := reflect.New(p.Type()).Elem()
if elem := p.Elem(); elem.IsValid() {
v.Set(tr.subst(env, elem, pos))
}
return v
}
return pattern
}
// -- utilitiies -------------------------------------------------------
func rvToExpr(rv reflect.Value) ast.Expr {
if rv.CanInterface() {
if e, ok := rv.Interface().(ast.Expr); ok {
return e
}
}
return nil
}
// updateTypeInfo duplicates type information for the existing AST old
// so that it also applies to duplicated AST new.
func updateTypeInfo(info *types.Info, new, old ast.Expr) {
switch new := new.(type) {
case *ast.Ident:
orig := old.(*ast.Ident)
if obj, ok := info.Defs[orig]; ok {
info.Defs[new] = obj
}
if obj, ok := info.Uses[orig]; ok {
info.Uses[new] = obj
}
case *ast.SelectorExpr:
orig := old.(*ast.SelectorExpr)
if sel, ok := info.Selections[orig]; ok {
info.Selections[new] = sel
}
}
if tv, ok := info.Types[old]; ok {
info.Types[new] = tv
}
}
func F() {}
func G() {}
func init() {
F()
}

13
refactor/eg/testdata/A.template vendored Normal file
View File

@ -0,0 +1,13 @@
// +build ignore
package template
// Basic test of type-aware expression refactoring.
import (
"errors"
"fmt"
)
func before(s string) error { return fmt.Errorf("%s", s) }
func after(s string) error { return errors.New(s) }

51
refactor/eg/testdata/A1.go vendored Normal file
View File

@ -0,0 +1,51 @@
// +build ignore
package A1
import (
. "fmt"
myfmt "fmt"
"os"
"strings"
)
func example(n int) {
x := "foo" + strings.Repeat("\t", n)
// Match, despite named import.
myfmt.Errorf("%s", x)
// Match, despite dot import.
Errorf("%s", x)
// Match: multiple matches in same function are possible.
myfmt.Errorf("%s", x)
// No match: wildcarded operand has the wrong type.
myfmt.Errorf("%s", 3)
// No match: function operand doesn't match.
myfmt.Printf("%s", x)
// No match again, dot import.
Printf("%s", x)
// Match.
myfmt.Fprint(os.Stderr, myfmt.Errorf("%s", x+"foo"))
// No match: though this literally matches the template,
// fmt doesn't resolve to a package here.
var fmt struct{ Errorf func(string, string) }
fmt.Errorf("%s", x)
// Recursive matching:
// Match: both matches are well-typed, so both succeed.
myfmt.Errorf("%s", myfmt.Errorf("%s", x+"foo").Error())
// Outer match succeeds, inner doesn't: 3 has wrong type.
myfmt.Errorf("%s", myfmt.Errorf("%s", 3).Error())
// Inner match succeeds, outer doesn't: the inner replacement
// has the wrong type (error not string).
myfmt.Errorf("%s", myfmt.Errorf("%s", x+"foo"))
}

52
refactor/eg/testdata/A1.golden vendored Normal file
View File

@ -0,0 +1,52 @@
// +build ignore
package A1
import (
. "fmt"
"errors"
myfmt "fmt"
"os"
"strings"
)
func example(n int) {
x := "foo" + strings.Repeat("\t", n)
// Match, despite named import.
errors.New(x)
// Match, despite dot import.
errors.New(x)
// Match: multiple matches in same function are possible.
errors.New(x)
// No match: wildcarded operand has the wrong type.
myfmt.Errorf("%s", 3)
// No match: function operand doesn't match.
myfmt.Printf("%s", x)
// No match again, dot import.
Printf("%s", x)
// Match.
myfmt.Fprint(os.Stderr, errors.New(x+"foo"))
// No match: though this literally matches the template,
// fmt doesn't resolve to a package here.
var fmt struct{ Errorf func(string, string) }
fmt.Errorf("%s", x)
// Recursive matching:
// Match: both matches are well-typed, so both succeed.
errors.New(errors.New(x + "foo").Error())
// Outer match succeeds, inner doesn't: 3 has wrong type.
errors.New(myfmt.Errorf("%s", 3).Error())
// Inner match succeeds, outer doesn't: the inner replacement
// has the wrong type (error not string).
myfmt.Errorf("%s", errors.New(x+"foo"))
}

12
refactor/eg/testdata/A2.go vendored Normal file
View File

@ -0,0 +1,12 @@
// +build ignore
package A2
// This refactoring causes addition of "errors" import.
// TODO(adonovan): fix: it should also remove "fmt".
import myfmt "fmt"
func example(n int) {
myfmt.Errorf("%s", "")
}

15
refactor/eg/testdata/A2.golden vendored Normal file
View File

@ -0,0 +1,15 @@
// +build ignore
package A2
// This refactoring causes addition of "errors" import.
// TODO(adonovan): fix: it should also remove "fmt".
import (
myfmt "fmt"
"errors"
)
func example(n int) {
errors.New("")
}

9
refactor/eg/testdata/B.template vendored Normal file
View File

@ -0,0 +1,9 @@
package template
// Basic test of expression refactoring.
// (Types are not important in this case; it could be done with gofmt -r.)
import "time"
func before(t time.Time) time.Duration { return time.Now().Sub(t) }
func after(t time.Time) time.Duration { return time.Since(t) }

17
refactor/eg/testdata/B1.go vendored Normal file
View File

@ -0,0 +1,17 @@
// +build ignore
package B1
import "time"
var startup = time.Now()
func example() time.Duration {
before := time.Now()
time.Sleep(1)
return time.Now().Sub(before)
}
func msSinceStartup() int64 {
return int64(time.Now().Sub(startup) / time.Millisecond)
}

17
refactor/eg/testdata/B1.golden vendored Normal file
View File

@ -0,0 +1,17 @@
// +build ignore
package B1
import "time"
var startup = time.Now()
func example() time.Duration {
before := time.Now()
time.Sleep(1)
return time.Since(before)
}
func msSinceStartup() int64 {
return int64(time.Since(startup) / time.Millisecond)
}

10
refactor/eg/testdata/C.template vendored Normal file
View File

@ -0,0 +1,10 @@
package template
// Test of repeated use of wildcard in pattern.
// NB: multiple patterns would be required to handle variants such as
// s[:len(s)], s[x:len(s)], etc, since a wildcard can't match nothing at all.
// TODO(adonovan): support multiple templates in a single pass.
func before(s string) string { return s[:len(s)] }
func after(s string) string { return s }

22
refactor/eg/testdata/C1.go vendored Normal file
View File

@ -0,0 +1,22 @@
// +build ignore
package C1
import "strings"
func example() {
x := "foo"
println(x[:len(x)])
// Match, but the transformation is not sound w.r.t. possible side effects.
println(strings.Repeat("*", 3)[:len(strings.Repeat("*", 3))])
// No match, since second use of wildcard doesn't match first.
println(strings.Repeat("*", 3)[:len(strings.Repeat("*", 2))])
// Recursive match demonstrating bottom-up rewrite:
// only after the inner replacement occurs does the outer syntax match.
println((x[:len(x)])[:len(x[:len(x)])])
// -> (x[:len(x)])
// -> x
}

22
refactor/eg/testdata/C1.golden vendored Normal file
View File

@ -0,0 +1,22 @@
// +build ignore
package C1
import "strings"
func example() {
x := "foo"
println(x)
// Match, but the transformation is not sound w.r.t. possible side effects.
println(strings.Repeat("*", 3))
// No match, since second use of wildcard doesn't match first.
println(strings.Repeat("*", 3)[:len(strings.Repeat("*", 2))])
// Recursive match demonstrating bottom-up rewrite:
// only after the inner replacement occurs does the outer syntax match.
println(x)
// -> (x[:len(x)])
// -> x
}

8
refactor/eg/testdata/D.template vendored Normal file
View File

@ -0,0 +1,8 @@
package template
import "fmt"
// Test of semantic (not syntactic) matching of basic literals.
func before() (int, error) { return fmt.Println(123, "a") }
func after() (int, error) { return fmt.Println(456, "!") }

12
refactor/eg/testdata/D1.go vendored Normal file
View File

@ -0,0 +1,12 @@
// +build ignore
package D1
import "fmt"
func example() {
fmt.Println(123, "a") // match
fmt.Println(0x7b, `a`) // match
fmt.Println(0173, "\x61") // match
fmt.Println(100+20+3, "a"+"") // no match: constant expressions, but not basic literals
}

12
refactor/eg/testdata/D1.golden vendored Normal file
View File

@ -0,0 +1,12 @@
// +build ignore
package D1
import "fmt"
func example() {
fmt.Println(456, "!") // match
fmt.Println(456, "!") // match
fmt.Println(456, "!") // match
fmt.Println(100+20+3, "a"+"") // no match: constant expressions, but not basic literals
}

12
refactor/eg/testdata/E.template vendored Normal file
View File

@ -0,0 +1,12 @@
package template
import (
"fmt"
"log"
"os"
)
// Replace call to void function by call to non-void function.
func before(x interface{}) { log.Fatal(x) }
func after(x interface{}) { fmt.Fprintf(os.Stderr, "warning: %v", x) }

9
refactor/eg/testdata/E1.go vendored Normal file
View File

@ -0,0 +1,9 @@
// +build ignore
package E1
import "log"
func example() {
log.Fatal("oops") // match
}

13
refactor/eg/testdata/E1.golden vendored Normal file
View File

@ -0,0 +1,13 @@
// +build ignore
package E1
import (
"log"
"os"
"fmt"
)
func example() {
fmt.Fprintf(os.Stderr, "warning: %v", "oops") // match
}

View File

@ -0,0 +1,8 @@
package template
// Test in which replacement has a different type.
const shouldFail = "int is not a safe replacement for string"
func before() interface{} { return "three" }
func after() interface{} { return 3 }

View File

@ -0,0 +1,15 @@
package template
import (
"crypto/x509"
"fmt"
)
// This test demonstrates a false negative: according to the language
// rules this replacement should be ok, but types.Assignable doesn't work
// in the expected way (elementwise assignability) for tuples.
// Perhaps that's even a type-checker bug?
const shouldFail = "(n int, err error) is not a safe replacement for (key interface{}, err error)"
func before() (interface{}, error) { return x509.ParsePKCS8PrivateKey(nil) }
func after() (interface{}, error) { return fmt.Print() }

View File

@ -0,0 +1,6 @@
package template
const shouldFail = "after: must contain a single statement"
func before() int { return 0 }
func after() int { println(); return 0 }

View File

@ -0,0 +1,5 @@
package template
const shouldFail = "no 'before' func found in template"
func Before() {}

View File

@ -0,0 +1,6 @@
package template
const shouldFail = "different signatures"
func before() int { return 0 }
func after() string { return "" }