1
0
mirror of https://github.com/golang/go synced 2024-11-18 07:04:52 -07:00

refactor/eg: Add support for multi line after statements to eg.

The semantics of this change are that the last line will be subsituted
in place of the expression, where as the lines before that will undergo
variable substitution and be prepended before the lowest (in the AST
tree sense) statement which included the expression.

Change-Id: Ie2571934dcc1b0a30b5cec157e690924a4ac2c5a
Reviewed-on: https://go-review.googlesource.com/77730
Reviewed-by: Alan Donovan <adonovan@google.com>
This commit is contained in:
Colin 2017-11-14 18:37:00 -05:00 committed by Michael Matloob
parent 96caea4103
commit 2226533658
10 changed files with 211 additions and 50 deletions

View File

@ -145,6 +145,7 @@ type Transformer struct {
env map[string]ast.Expr // maps parameter name to wildcard binding
importedObjs map[types.Object]*ast.SelectorExpr // objects imported by after().
before, after ast.Expr
afterStmts []ast.Stmt
allowWildcards bool
// Working state of Transform():
@ -198,7 +199,7 @@ func NewTransformer(fset *token.FileSet, tmplPkg *types.Package, tmplFile *ast.F
if err != nil {
return nil, fmt.Errorf("before: %s", err)
}
after, err := soleExpr(afterDecl)
afterStmts, after, err := stmtAndExpr(afterDecl)
if err != nil {
return nil, fmt.Errorf("after: %s", err)
}
@ -242,6 +243,7 @@ func NewTransformer(fset *token.FileSet, tmplPkg *types.Package, tmplFile *ast.F
importedObjs: make(map[types.Object]*ast.SelectorExpr),
before: before,
after: after,
afterStmts: afterStmts,
}
// Combine type info from the template and input packages, and
@ -279,6 +281,7 @@ func WriteAST(fset *token.FileSet, filename string, f *ast.File) (err error) {
if err != nil {
return err
}
defer func() {
if err2 := fh.Close(); err != nil {
err = err2 // prefer earlier error
@ -319,6 +322,33 @@ func soleExpr(fn *ast.FuncDecl) (ast.Expr, error) {
return nil, fmt.Errorf("must contain a single return or expression statement")
}
// stmtAndExpr returns the expression in the last return statement as well as the preceeding lines.
func stmtAndExpr(fn *ast.FuncDecl) ([]ast.Stmt, ast.Expr, error) {
if fn.Body == nil {
return nil, nil, fmt.Errorf("no body")
}
n := len(fn.Body.List)
if n == 0 {
return nil, nil, fmt.Errorf("must contain at least one statement")
}
stmts, last := fn.Body.List[:n-1], fn.Body.List[n-1]
switch last := last.(type) {
case *ast.ReturnStmt:
if len(last.Results) != 1 {
return nil, nil, fmt.Errorf("return statement must have a single operand")
}
return stmts, last.Results[0], nil
case *ast.ExprStmt:
return stmts, last.X, nil
}
return nil, nil, fmt.Errorf("must end with a single return or expression statement")
}
// mergeTypeInfo adds type info from src to dst.
func mergeTypeInfo(dst, src *types.Info) {
for k, v := range src.Types {

View File

@ -78,6 +78,12 @@ func Test(t *testing.T) {
"testdata/H.template",
"testdata/H1.go",
"testdata/I.template",
"testdata/I1.go",
"testdata/J.template",
"testdata/J1.go",
"testdata/bad_type.template",
"testdata/no_before.template",
"testdata/no_after_return.template",

View File

@ -22,6 +22,52 @@ import (
"golang.org/x/tools/go/ast/astutil"
)
// transformItem takes a reflect.Value representing a variable of type ast.Node
// transforms its child elements recursively with apply, and then transforms the
// actual element if it contains an expression.
func (tr *Transformer) transformItem(rv reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
// don't bother if val is invalid to start with
if !rv.IsValid() {
return reflect.Value{}, false, nil
}
rv, changed, newEnv := tr.apply(tr.transformItem, rv)
e := rvToExpr(rv)
if e == nil {
return rv, changed, newEnv
}
savedEnv := tr.env
tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs
if tr.matchExpr(tr.before, e) {
if tr.verbose {
fmt.Fprintf(os.Stderr, "%s matches %s",
astString(tr.fset, tr.before), astString(tr.fset, e))
if len(tr.env) > 0 {
fmt.Fprintf(os.Stderr, " with:")
for name, ast := range tr.env {
fmt.Fprintf(os.Stderr, " %s->%s",
name, astString(tr.fset, ast))
}
}
fmt.Fprintf(os.Stderr, "\n")
}
tr.nsubsts++
// Clone the replacement tree, performing parameter substitution.
// We update all positions to n.Pos() to aid comment placement.
rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
reflect.ValueOf(e.Pos()))
changed = true
newEnv = tr.env
}
tr.env = savedEnv
return rv, changed, newEnv
}
// Transform applies the transformation to the specified parsed file,
// whose type information is supplied in info, and returns the number
// of replacements that were made.
@ -43,48 +89,14 @@ func (tr *Transformer) Transform(info *types.Info, pkg *types.Package, file *ast
if tr.verbose {
fmt.Fprintf(os.Stderr, "before: %s\n", astString(tr.fset, tr.before))
fmt.Fprintf(os.Stderr, "after: %s\n", astString(tr.fset, tr.after))
fmt.Fprintf(os.Stderr, "afterStmts: %s\n", tr.afterStmts)
}
var f func(rv reflect.Value) reflect.Value
f = func(rv reflect.Value) reflect.Value {
// don't bother if val is invalid to start with
if !rv.IsValid() {
return reflect.Value{}
}
rv = apply(f, rv)
e := rvToExpr(rv)
if e != nil {
savedEnv := tr.env
tr.env = make(map[string]ast.Expr) // inefficient! Use a slice of k/v pairs
if tr.matchExpr(tr.before, e) {
if tr.verbose {
fmt.Fprintf(os.Stderr, "%s matches %s",
astString(tr.fset, tr.before), astString(tr.fset, e))
if len(tr.env) > 0 {
fmt.Fprintf(os.Stderr, " with:")
for name, ast := range tr.env {
fmt.Fprintf(os.Stderr, " %s->%s",
name, astString(tr.fset, ast))
}
}
fmt.Fprintf(os.Stderr, "\n")
}
tr.nsubsts++
// Clone the replacement tree, performing parameter substitution.
// We update all positions to n.Pos() to aid comment placement.
rv = tr.subst(tr.env, reflect.ValueOf(tr.after),
reflect.ValueOf(e.Pos()))
}
tr.env = savedEnv
}
return rv
o, changed, _ := tr.apply(tr.transformItem, reflect.ValueOf(file))
if changed {
panic("BUG")
}
file2 := apply(f, reflect.ValueOf(file)).Interface().(*ast.File)
file2 := o.Interface().(*ast.File)
// By construction, the root node is unchanged.
if file != file2 {
@ -150,45 +162,91 @@ var (
identType = reflect.TypeOf((*ast.Ident)(nil))
selectorExprType = reflect.TypeOf((*ast.SelectorExpr)(nil))
objectPtrType = reflect.TypeOf((*ast.Object)(nil))
statementType = reflect.TypeOf((*ast.Stmt)(nil)).Elem()
positionType = reflect.TypeOf(token.NoPos)
scopePtrType = reflect.TypeOf((*ast.Scope)(nil))
)
// apply replaces each AST field x in val with f(x), returning val.
// To avoid extra conversions, f operates on the reflect.Value form.
func apply(f func(reflect.Value) reflect.Value, val reflect.Value) reflect.Value {
// f takes a reflect.Value representing the variable to modify of type ast.Node.
// It returns a reflect.Value containing the transformed value of type ast.Node,
// whether any change was made, and a map of identifiers to ast.Expr (so we can
// do contextually correct substitutions in the parent statements).
func (tr *Transformer) apply(f func(reflect.Value) (reflect.Value, bool, map[string]ast.Expr), val reflect.Value) (reflect.Value, bool, map[string]ast.Expr) {
if !val.IsValid() {
return reflect.Value{}
return reflect.Value{}, false, nil
}
// *ast.Objects introduce cycles and are likely incorrect after
// rewrite; don't follow them but replace with nil instead
if val.Type() == objectPtrType {
return objectPtrNil
return objectPtrNil, false, nil
}
// similarly for scopes: they are likely incorrect after a rewrite;
// replace them with nil
if val.Type() == scopePtrType {
return scopePtrNil
return scopePtrNil, false, nil
}
switch v := reflect.Indirect(val); v.Kind() {
case reflect.Slice:
// no possible rewriting of statements.
if v.Type().Elem() != statementType {
changed := false
var envp map[string]ast.Expr
for i := 0; i < v.Len(); i++ {
e := v.Index(i)
o, localchanged, env := f(e)
if localchanged {
changed = true
// we clobber envp here,
// which means if we have two sucessive
// replacements inside the same statement
// we will only generate the setup for one of them.
envp = env
}
setValue(e, o)
}
return val, changed, envp
}
// statements are rewritten.
var out []ast.Stmt
for i := 0; i < v.Len(); i++ {
e := v.Index(i)
setValue(e, f(e))
o, changed, env := f(e)
if changed {
for _, s := range tr.afterStmts {
t := tr.subst(env, reflect.ValueOf(s), reflect.Value{}).Interface()
out = append(out, t.(ast.Stmt))
}
}
setValue(e, o)
out = append(out, e.Interface().(ast.Stmt))
}
return reflect.ValueOf(out), false, nil
case reflect.Struct:
changed := false
var envp map[string]ast.Expr
for i := 0; i < v.NumField(); i++ {
e := v.Field(i)
setValue(e, f(e))
o, localchanged, env := f(e)
if localchanged {
changed = true
envp = env
}
setValue(e, o)
}
return val, changed, envp
case reflect.Interface:
e := v.Elem()
setValue(v, f(e))
o, changed, env := f(e)
setValue(v, o)
return val, changed, env
}
return val
return val, false, nil
}
// subst returns a copy of (replacement) pattern with values from env

14
refactor/eg/testdata/I.template vendored Normal file
View File

@ -0,0 +1,14 @@
// +build ignore
package templates
import (
"errors"
"fmt"
)
func before(s string) error { return fmt.Errorf("%s", s) }
func after(s string) error {
n := fmt.Sprintf("error - %s", s)
return errors.New(n)
}

9
refactor/eg/testdata/I1.go vendored Normal file
View File

@ -0,0 +1,9 @@
// +build ignore
package I1
import "fmt"
func example() {
_ = fmt.Errorf("%s", "foo")
}

14
refactor/eg/testdata/I1.golden vendored Normal file
View File

@ -0,0 +1,14 @@
// +build ignore
package I1
import (
"errors"
"fmt"
)
func example() {
n := fmt.Sprintf("error - %s", "foo")
_ = errors.New(n)
}

11
refactor/eg/testdata/J.template vendored Normal file
View File

@ -0,0 +1,11 @@
// +build ignore
package templates
import ()
func before(x int) int { return x + x + x }
func after(x int) int {
temp := x + x
return temp + x
}

10
refactor/eg/testdata/J1.go vendored Normal file
View File

@ -0,0 +1,10 @@
// +build ignore
package I1
import "fmt"
func example() {
temp := 5
fmt.Print(temp + temp + temp)
}

11
refactor/eg/testdata/J1.golden vendored Normal file
View File

@ -0,0 +1,11 @@
// +build ignore
package I1
import "fmt"
func example() {
temp := 5
temp := temp + temp
fmt.Print(temp + temp)
}

View File

@ -1,6 +1,4 @@
package template
const shouldFail = "after: must contain a single statement"
func before() int { return 0 }
func after() int { println(); return 0 }