1
0
mirror of https://github.com/golang/go synced 2024-09-30 16:08:36 -06:00

go/ast/astutil: add Apply, for rewriting Go ASTs

This change implements an AST walker, Apply, with the capability
to perform various actions (including rewrites) during the AST
walk, in pre- and post-iteration order.

Ideally we would like to have this functionality in the std lib
(go/ast), but as this change contains significant new API, it
is better to first gain some experience with it in x/tools.

Credits: This is joint work of Josh Bleecher Snyder, Roger
Peppe, and myself (author of this CL). The code is based on
proposal golang/go#17108 which I had started, together with an
initial implementation. It was then reworked significantly by
Josh in golang.org/cl/55790, further refined by suggestions
from Roger (allocation reduction), and finally cleaned up a bit
by me (AST simplifications).

Fixes golang/go#17108.

Change-Id: I56d56b7f2bc5be2acfeb927f76aea7f264bb7b94
Reviewed-on: https://go-review.googlesource.com/77811
Run-TryBot: Robert Griesemer <gri@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Alan Donovan <adonovan@google.com>
This commit is contained in:
Robert Griesemer 2017-11-14 21:27:03 -08:00
parent 4de4a8d206
commit 6d70fb2e85
2 changed files with 725 additions and 0 deletions

477
go/ast/astutil/rewrite.go Normal file
View File

@ -0,0 +1,477 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package astutil
import (
"fmt"
"go/ast"
"reflect"
"sort"
)
// An ApplyFunc is invoked by Apply for each node n, even if n is nil,
// before and/or after the node's children, using a Cursor describing
// the current node and providing operations on it.
//
// The return value of ApplyFunc controls the syntax tree traversal.
// See Apply for details.
type ApplyFunc func(*Cursor) bool
// Apply traverses a syntax tree recursively, starting with root,
// and calling pre and post for each node as described below.
// Apply returns the syntax tree, possibly modified.
//
// If pre is not nil, it is called for each node before the node's
// children are traversed (pre-order). If pre returns false, no
// children are traversed, and post is not called for that node.
//
// If post is not nil, and a prior call of pre didn't return false,
// post is called for each node after its children are traversed
// (post-order). If post returns false, traversal is terminated and
// Apply returns immediately.
//
// Only fields that refer to AST nodes are considered children;
// i.e., token.Pos, Scopes, Objects, and fields of basic types
// (strings, etc.) are ignored.
//
// Children are traversed in the order in which they appear in the
// respective node's struct definition. A package's files are
// traversed in the filenames' alphabetical order.
//
func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
parent := &struct{ ast.Node }{root}
defer func() {
if r := recover(); r != nil && r != abort {
panic(r)
}
result = parent.Node
}()
a := &application{pre: pre, post: post}
a.apply(parent, "Node", nil, root)
return
}
var abort = new(int) // singleton, to signal termination of Apply
// A Cursor describes a node encountered during Apply.
// Information about the node and its parent is available
// from the Node, Parent, Name, and Index methods.
//
// If p is a variable of type and value of the current parent node
// c.Parent(), and f is the field identifier with name c.Name(),
// the following invariants hold:
//
// p.f == c.Node() if c.Index() < 0
// p.f[c.Index()] == c.Node() if c.Index() >= 0
//
// The methods Replace, Delete, InsertBefore, and InsertAfter
// can be used to change the AST without disrupting Apply.
type Cursor struct {
parent ast.Node
name string
iter *iterator // valid if non-nil
node ast.Node
}
// Node returns the current Node.
func (c *Cursor) Node() ast.Node { return c.node }
// Parent returns the parent of the current Node.
func (c *Cursor) Parent() ast.Node { return c.parent }
// Name returns the name of the parent Node field that contains the current Node.
// If the parent is a *ast.Package and the current Node is a *ast.File, Name returns
// the filename for the current Node.
func (c *Cursor) Name() string { return c.name }
// Index reports the index >= 0 of the current Node in the slice of Nodes that
// contains it, or a value < 0 if the current Node is not part of a slice.
// The index of the current node changes if InsertBefore is called while
// processing the current node.
func (c *Cursor) Index() int {
if c.iter != nil {
return c.iter.index
}
return -1
}
// field returns the current node's parent field value.
func (c *Cursor) field() reflect.Value {
return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
}
// Replace replaces the current Node with n.
// The replacement node is not walked by Apply.
func (c *Cursor) Replace(n ast.Node) {
if _, ok := c.node.(*ast.File); ok {
file, ok := n.(*ast.File)
if !ok {
panic("attempt to replace *ast.File with non-*ast.File")
}
c.parent.(*ast.Package).Files[c.name] = file
return
}
v := c.field()
if i := c.Index(); i >= 0 {
v = v.Index(i)
}
v.Set(reflect.ValueOf(n))
}
// Delete deletes the current Node from its containing slice.
// If the current Node is not part of a slice, Delete panics.
// As a special case, if the current node is a package file,
// Delete removes it from the package's Files map.
func (c *Cursor) Delete() {
if _, ok := c.node.(*ast.File); ok {
delete(c.parent.(*ast.Package).Files, c.name)
return
}
i := c.Index()
if i < 0 {
panic("Delete node not contained in slice")
}
v := c.field()
l := v.Len()
reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
v.SetLen(l - 1)
c.iter.step--
}
// InsertAfter inserts n after the current Node in its containing slice.
// If the current Node is not part of a slice, InsertAfter panics.
// Apply does not walk n.
func (c *Cursor) InsertAfter(n ast.Node) {
i := c.Index()
if i < 0 {
panic("InsertAfter node not contained in slice")
}
v := c.field()
v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
l := v.Len()
reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
v.Index(i + 1).Set(reflect.ValueOf(n))
c.iter.step++
}
// InsertBefore inserts n before the current Node in its containing slice.
// If the current Node is not part of a slice, InsertBefore panics.
// Apply will not walk n.
func (c *Cursor) InsertBefore(n ast.Node) {
i := c.Index()
if i < 0 {
panic("InsertBefore node not contained in slice")
}
v := c.field()
v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
l := v.Len()
reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
v.Index(i).Set(reflect.ValueOf(n))
c.iter.index++
}
// application carries all the shared data so we can pass it around cheaply.
type application struct {
pre, post ApplyFunc
cursor Cursor
iter iterator
}
func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
// convert typed nil into untyped nil
if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
n = nil
}
// avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
saved := a.cursor
a.cursor.parent = parent
a.cursor.name = name
a.cursor.iter = iter
a.cursor.node = n
if a.pre != nil && !a.pre(&a.cursor) {
a.cursor = saved
return
}
// walk children
// (the order of the cases matches the order of the corresponding node types in go/ast)
switch n := n.(type) {
case nil:
// nothing to do
// Comments and fields
case *ast.Comment:
// nothing to do
case *ast.CommentGroup:
if n != nil {
a.applyList(n, "List")
}
case *ast.Field:
a.apply(n, "Doc", nil, n.Doc)
a.applyList(n, "Names")
a.apply(n, "Type", nil, n.Type)
a.apply(n, "Tag", nil, n.Tag)
a.apply(n, "Comment", nil, n.Comment)
case *ast.FieldList:
a.applyList(n, "List")
// Expressions
case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
// nothing to do
case *ast.Ellipsis:
a.apply(n, "Elt", nil, n.Elt)
case *ast.FuncLit:
a.apply(n, "Type", nil, n.Type)
a.apply(n, "Body", nil, n.Body)
case *ast.CompositeLit:
a.apply(n, "Type", nil, n.Type)
a.applyList(n, "Elts")
case *ast.ParenExpr:
a.apply(n, "X", nil, n.X)
case *ast.SelectorExpr:
a.apply(n, "X", nil, n.X)
a.apply(n, "Sel", nil, n.Sel)
case *ast.IndexExpr:
a.apply(n, "X", nil, n.X)
a.apply(n, "Index", nil, n.Index)
case *ast.SliceExpr:
a.apply(n, "X", nil, n.X)
a.apply(n, "Low", nil, n.Low)
a.apply(n, "High", nil, n.High)
a.apply(n, "Max", nil, n.Max)
case *ast.TypeAssertExpr:
a.apply(n, "X", nil, n.X)
a.apply(n, "Type", nil, n.Type)
case *ast.CallExpr:
a.apply(n, "Fun", nil, n.Fun)
a.applyList(n, "Args")
case *ast.StarExpr:
a.apply(n, "X", nil, n.X)
case *ast.UnaryExpr:
a.apply(n, "X", nil, n.X)
case *ast.BinaryExpr:
a.apply(n, "X", nil, n.X)
a.apply(n, "Y", nil, n.Y)
case *ast.KeyValueExpr:
a.apply(n, "Key", nil, n.Key)
a.apply(n, "Value", nil, n.Value)
// Types
case *ast.ArrayType:
a.apply(n, "Len", nil, n.Len)
a.apply(n, "Elt", nil, n.Elt)
case *ast.StructType:
a.apply(n, "Fields", nil, n.Fields)
case *ast.FuncType:
a.apply(n, "Params", nil, n.Params)
a.apply(n, "Results", nil, n.Results)
case *ast.InterfaceType:
a.apply(n, "Methods", nil, n.Methods)
case *ast.MapType:
a.apply(n, "Key", nil, n.Key)
a.apply(n, "Value", nil, n.Value)
case *ast.ChanType:
a.apply(n, "Value", nil, n.Value)
// Statements
case *ast.BadStmt:
// nothing to do
case *ast.DeclStmt:
a.apply(n, "Decl", nil, n.Decl)
case *ast.EmptyStmt:
// nothing to do
case *ast.LabeledStmt:
a.apply(n, "Label", nil, n.Label)
a.apply(n, "Stmt", nil, n.Stmt)
case *ast.ExprStmt:
a.apply(n, "X", nil, n.X)
case *ast.SendStmt:
a.apply(n, "Chan", nil, n.Chan)
a.apply(n, "Value", nil, n.Value)
case *ast.IncDecStmt:
a.apply(n, "X", nil, n.X)
case *ast.AssignStmt:
a.applyList(n, "Lhs")
a.applyList(n, "Rhs")
case *ast.GoStmt:
a.apply(n, "Call", nil, n.Call)
case *ast.DeferStmt:
a.apply(n, "Call", nil, n.Call)
case *ast.ReturnStmt:
a.applyList(n, "Results")
case *ast.BranchStmt:
a.apply(n, "Label", nil, n.Label)
case *ast.BlockStmt:
a.applyList(n, "List")
case *ast.IfStmt:
a.apply(n, "Init", nil, n.Init)
a.apply(n, "Cond", nil, n.Cond)
a.apply(n, "Body", nil, n.Body)
a.apply(n, "Else", nil, n.Else)
case *ast.CaseClause:
a.applyList(n, "List")
a.applyList(n, "Body")
case *ast.SwitchStmt:
a.apply(n, "Init", nil, n.Init)
a.apply(n, "Tag", nil, n.Tag)
a.apply(n, "Body", nil, n.Body)
case *ast.TypeSwitchStmt:
a.apply(n, "Init", nil, n.Init)
a.apply(n, "Assign", nil, n.Assign)
a.apply(n, "Body", nil, n.Body)
case *ast.CommClause:
a.apply(n, "Comm", nil, n.Comm)
a.applyList(n, "Body")
case *ast.SelectStmt:
a.apply(n, "Body", nil, n.Body)
case *ast.ForStmt:
a.apply(n, "Init", nil, n.Init)
a.apply(n, "Cond", nil, n.Cond)
a.apply(n, "Post", nil, n.Post)
a.apply(n, "Body", nil, n.Body)
case *ast.RangeStmt:
a.apply(n, "Key", nil, n.Key)
a.apply(n, "Value", nil, n.Value)
a.apply(n, "X", nil, n.X)
a.apply(n, "Body", nil, n.Body)
// Declarations
case *ast.ImportSpec:
a.apply(n, "Doc", nil, n.Doc)
a.apply(n, "Name", nil, n.Name)
a.apply(n, "Path", nil, n.Path)
a.apply(n, "Comment", nil, n.Comment)
case *ast.ValueSpec:
a.apply(n, "Doc", nil, n.Doc)
a.applyList(n, "Names")
a.apply(n, "Type", nil, n.Type)
a.applyList(n, "Values")
a.apply(n, "Comment", nil, n.Comment)
case *ast.TypeSpec:
a.apply(n, "Doc", nil, n.Doc)
a.apply(n, "Name", nil, n.Name)
a.apply(n, "Type", nil, n.Type)
a.apply(n, "Comment", nil, n.Comment)
case *ast.BadDecl:
// nothing to do
case *ast.GenDecl:
a.apply(n, "Doc", nil, n.Doc)
a.applyList(n, "Specs")
case *ast.FuncDecl:
a.apply(n, "Doc", nil, n.Doc)
a.apply(n, "Recv", nil, n.Recv)
a.apply(n, "Name", nil, n.Name)
a.apply(n, "Type", nil, n.Type)
a.apply(n, "Body", nil, n.Body)
// Files and packages
case *ast.File:
a.apply(n, "Doc", nil, n.Doc)
a.apply(n, "Name", nil, n.Name)
a.applyList(n, "Decls")
// Don't walk n.Comments; they have either been walked already if
// they are Doc comments, or they can be easily walked explicitly.
case *ast.Package:
// collect and sort names for reproducible behavior
var names []string
for name := range n.Files {
names = append(names, name)
}
sort.Strings(names)
for _, name := range names {
a.apply(n, name, nil, n.Files[name])
}
default:
panic(fmt.Sprintf("Apply: unexpected node type %T", n))
}
if a.post != nil && !a.post(&a.cursor) {
panic(abort)
}
a.cursor = saved
}
// An iterator controls iteration over a slice of nodes.
type iterator struct {
index, step int
}
func (a *application) applyList(parent ast.Node, name string) {
// avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead
saved := a.iter
a.iter.index = 0
for {
// must reload parent.name each time, since cursor modifications might change it
v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
if a.iter.index >= v.Len() {
break
}
// element x may be nil in a bad AST - be cautious
var x ast.Node
if e := v.Index(a.iter.index); e.IsValid() {
x = e.Interface().(ast.Node)
}
a.iter.step = 1
a.apply(parent, name, &a.iter, x)
a.iter.index += a.iter.step
}
a.iter = saved
}

View File

@ -0,0 +1,248 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package astutil_test
import (
"bytes"
"go/ast"
"go/format"
"go/parser"
"go/token"
"testing"
"golang.org/x/tools/go/ast/astutil"
)
var rewriteTests = [...]struct {
name string
orig, want string
pre, post astutil.ApplyFunc
}{
{name: "nop", orig: "package p\n", want: "package p\n"},
{name: "replace",
orig: `package p
var x int
`,
want: `package p
var t T
`,
post: func(c *astutil.Cursor) bool {
if _, ok := c.Node().(*ast.ValueSpec); ok {
c.Replace(valspec("t", "T"))
return false
}
return true
},
},
{name: "set doc strings",
orig: `package p
const z = 0
type T struct{}
var x int
`,
want: `package p
// a foo is a foo
const z = 0
// a foo is a foo
type T struct{}
// a foo is a foo
var x int
`,
post: func(c *astutil.Cursor) bool {
if _, ok := c.Parent().(*ast.GenDecl); ok && c.Name() == "Doc" && c.Node() == nil {
c.Replace(&ast.CommentGroup{List: []*ast.Comment{{Text: "// a foo is a foo"}}})
}
return true
},
},
{name: "insert names",
orig: `package p
const a = 1
`,
want: `package p
const a, b, c = 1, 2, 3
`,
pre: func(c *astutil.Cursor) bool {
if _, ok := c.Parent().(*ast.ValueSpec); ok {
switch c.Name() {
case "Names":
c.InsertAfter(ast.NewIdent("c"))
c.InsertAfter(ast.NewIdent("b"))
case "Values":
c.InsertAfter(&ast.BasicLit{Kind: token.INT, Value: "3"})
c.InsertAfter(&ast.BasicLit{Kind: token.INT, Value: "2"})
}
}
return true
},
},
{name: "insert",
orig: `package p
var (
x int
y int
)
`,
want: `package p
var before1 int
var before2 int
var (
x int
y int
)
var after2 int
var after1 int
`,
pre: func(c *astutil.Cursor) bool {
if _, ok := c.Node().(*ast.GenDecl); ok {
c.InsertBefore(vardecl("before1", "int"))
c.InsertAfter(vardecl("after1", "int"))
c.InsertAfter(vardecl("after2", "int"))
c.InsertBefore(vardecl("before2", "int"))
}
return true
},
},
{name: "delete",
orig: `package p
var x int
var y int
var z int
`,
want: `package p
var y int
var z int
`,
pre: func(c *astutil.Cursor) bool {
n := c.Node()
if d, ok := n.(*ast.GenDecl); ok && d.Specs[0].(*ast.ValueSpec).Names[0].Name == "x" {
c.Delete()
}
return true
},
},
{name: "insertafter-delete",
orig: `package p
var x int
var y int
var z int
`,
want: `package p
var x1 int
var y int
var z int
`,
pre: func(c *astutil.Cursor) bool {
n := c.Node()
if d, ok := n.(*ast.GenDecl); ok && d.Specs[0].(*ast.ValueSpec).Names[0].Name == "x" {
c.InsertAfter(vardecl("x1", "int"))
c.Delete()
}
return true
},
},
{name: "delete-insertafter",
orig: `package p
var x int
var y int
var z int
`,
want: `package p
var y int
var x1 int
var z int
`,
pre: func(c *astutil.Cursor) bool {
n := c.Node()
if d, ok := n.(*ast.GenDecl); ok && d.Specs[0].(*ast.ValueSpec).Names[0].Name == "x" {
c.Delete()
// The cursor is now effectively atop the 'var y int' node.
c.InsertAfter(vardecl("x1", "int"))
}
return true
},
},
}
func valspec(name, typ string) *ast.ValueSpec {
return &ast.ValueSpec{Names: []*ast.Ident{ast.NewIdent(name)},
Type: ast.NewIdent(typ),
}
}
func vardecl(name, typ string) *ast.GenDecl {
return &ast.GenDecl{
Tok: token.VAR,
Specs: []ast.Spec{valspec(name, typ)},
}
}
func TestRewrite(t *testing.T) {
t.Run("*", func(t *testing.T) {
for _, test := range rewriteTests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, test.name, test.orig, parser.ParseComments)
if err != nil {
t.Fatal(err)
}
n := astutil.Apply(f, test.pre, test.post)
var buf bytes.Buffer
if err := format.Node(&buf, fset, n); err != nil {
t.Fatal(err)
}
got := buf.String()
if got != test.want {
t.Errorf("got:\n\n%s\nwant:\n\n%s\n", got, test.want)
}
})
}
})
}
var sink ast.Node
func BenchmarkRewrite(b *testing.B) {
for _, test := range rewriteTests {
b.Run(test.name, func(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, test.name, test.orig, parser.ParseComments)
if err != nil {
b.Fatal(err)
}
b.StartTimer()
sink = astutil.Apply(f, test.pre, test.post)
}
})
}
}