1
0
mirror of https://github.com/golang/go synced 2024-11-21 23:54:40 -07:00

go/ast: fix ast.Walk

- change Walk signature to use an ast.Node instead of interface{}
- add Pos functions to a couple of ast types to make them proper nodes
- explicit nil checks where a node can be nil; incorrect ASTs cause Walk to crash

For now ast.Walk is exercised extensively as part of godoc's indexer;
so we have some confidence in its correctness. But this needs a test,
eventually.

Fixes #1326.

R=rsc, r
CC=golang-dev
https://golang.org/cl/3481043
This commit is contained in:
Robert Griesemer 2010-12-09 10:22:01 -08:00
parent e2da3b6498
commit e1d6b3c98d
6 changed files with 208 additions and 145 deletions

View File

@ -509,7 +509,7 @@ func (x *Indexer) visitSpec(spec ast.Spec, isVarDecl bool) {
} }
func (x *Indexer) Visit(node interface{}) ast.Visitor { func (x *Indexer) Visit(node ast.Node) ast.Visitor {
// TODO(gri): methods in interface types are categorized as VarDecl // TODO(gri): methods in interface types are categorized as VarDecl
switch n := node.(type) { switch n := node.(type) {
case nil: case nil:

View File

@ -12,7 +12,7 @@ import (
type simplifier struct{} type simplifier struct{}
func (s *simplifier) Visit(node interface{}) ast.Visitor { func (s *simplifier) Visit(node ast.Node) ast.Visitor {
switch n := node.(type) { switch n := node.(type) {
case *ast.CompositeLit: case *ast.CompositeLit:
// array, slice, and map composite literals may be simplified // array, slice, and map composite literals may be simplified
@ -61,7 +61,7 @@ func (s *simplifier) Visit(node interface{}) ast.Visitor {
} }
func simplify(node interface{}) { func simplify(node ast.Node) {
var s simplifier var s simplifier
ast.Walk(&s, node) ast.Walk(&s, node)
} }

View File

@ -139,7 +139,7 @@ func (f *File) checkFile(name string, file *ast.File) {
} }
// Visit implements the ast.Visitor interface. // Visit implements the ast.Visitor interface.
func (f *File) Visit(node interface{}) ast.Visitor { func (f *File) Visit(node ast.Node) ast.Visitor {
// TODO: could return nil for nodes that cannot contain a CallExpr - // TODO: could return nil for nodes that cannot contain a CallExpr -
// will shortcut traversal. Worthwhile? // will shortcut traversal. Worthwhile?
switch n := node.(type) { switch n := node.(type) {

View File

@ -70,19 +70,20 @@ type Comment struct {
} }
func (c *Comment) Pos() token.Pos { func (c *Comment) Pos() token.Pos { return c.Slash }
return c.Slash
}
// A CommentGroup represents a sequence of comments // A CommentGroup represents a sequence of comments
// with no other tokens and no empty lines between. // with no other tokens and no empty lines between.
// //
type CommentGroup struct { type CommentGroup struct {
List []*Comment List []*Comment // len(List) > 0
} }
func (g *CommentGroup) Pos() token.Pos { return g.List[0].Pos() }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
// Expressions and types // Expressions and types
@ -115,6 +116,9 @@ type FieldList struct {
} }
func (list *FieldList) Pos() token.Pos { return list.Opening }
// NumFields returns the number of (named and anonymous fields) in a FieldList. // NumFields returns the number of (named and anonymous fields) in a FieldList.
func (f *FieldList) NumFields() int { func (f *FieldList) NumFields() int {
n := 0 n := 0
@ -294,7 +298,7 @@ type (
FuncType struct { FuncType struct {
Func token.Pos // position of "func" keyword Func token.Pos // position of "func" keyword
Params *FieldList // (incoming) parameters Params *FieldList // (incoming) parameters
Results *FieldList // (outgoing) results Results *FieldList // (outgoing) results; or nil
} }
// An InterfaceType node represents an interface type. // An InterfaceType node represents an interface type.
@ -494,7 +498,7 @@ type (
BranchStmt struct { BranchStmt struct {
TokPos token.Pos // position of Tok TokPos token.Pos // position of Tok
Tok token.Token // keyword token (BREAK, CONTINUE, GOTO, FALLTHROUGH) Tok token.Token // keyword token (BREAK, CONTINUE, GOTO, FALLTHROUGH)
Label *Ident Label *Ident // label name; or nil
} }
// A BlockStmt node represents a braced statement list. // A BlockStmt node represents a braced statement list.
@ -507,10 +511,10 @@ type (
// An IfStmt node represents an if statement. // An IfStmt node represents an if statement.
IfStmt struct { IfStmt struct {
If token.Pos // position of "if" keyword If token.Pos // position of "if" keyword
Init Stmt Init Stmt // initalization statement; or nil
Cond Expr Cond Expr // condition; or nil
Body *BlockStmt Body *BlockStmt
Else Stmt Else Stmt // else branch; or nil
} }
// A CaseClause represents a case of an expression switch statement. // A CaseClause represents a case of an expression switch statement.
@ -523,9 +527,9 @@ type (
// A SwitchStmt node represents an expression switch statement. // A SwitchStmt node represents an expression switch statement.
SwitchStmt struct { SwitchStmt struct {
Switch token.Pos // position of "switch" keyword Switch token.Pos // position of "switch" keyword
Init Stmt Init Stmt // initalization statement; or nil
Tag Expr Tag Expr // tag expression; or nil
Body *BlockStmt // CaseClauses only Body *BlockStmt // CaseClauses only
} }
@ -539,8 +543,8 @@ type (
// An TypeSwitchStmt node represents a type switch statement. // An TypeSwitchStmt node represents a type switch statement.
TypeSwitchStmt struct { TypeSwitchStmt struct {
Switch token.Pos // position of "switch" keyword Switch token.Pos // position of "switch" keyword
Init Stmt Init Stmt // initalization statement; or nil
Assign Stmt // x := y.(type) Assign Stmt // x := y.(type)
Body *BlockStmt // TypeCaseClauses only Body *BlockStmt // TypeCaseClauses only
} }
@ -563,9 +567,9 @@ type (
// A ForStmt represents a for statement. // A ForStmt represents a for statement.
ForStmt struct { ForStmt struct {
For token.Pos // position of "for" keyword For token.Pos // position of "for" keyword
Init Stmt Init Stmt // initalization statement; or nil
Cond Expr Cond Expr // condition; or nil
Post Stmt Post Stmt // post iteration statement; or nil
Body *BlockStmt Body *BlockStmt
} }
@ -780,3 +784,13 @@ type Package struct {
Scope *Scope // package scope; or nil Scope *Scope // package scope; or nil
Files map[string]*File // Go source files by filename Files map[string]*File // Go source files by filename
} }
func (p *Package) Pos() (pos token.Pos) {
// get the position of the package clause of the first file, if any
for _, f := range p.Files {
pos = f.Pos()
break
}
return
}

View File

@ -10,51 +10,57 @@ import "fmt"
// If the result visitor w is not nil, Walk visits each of the children // If the result visitor w is not nil, Walk visits each of the children
// of node with the visitor w, followed by a call of w.Visit(nil). // of node with the visitor w, followed by a call of w.Visit(nil).
type Visitor interface { type Visitor interface {
Visit(node interface{}) (w Visitor) Visit(node Node) (w Visitor)
} }
func walkIdent(v Visitor, x *Ident) { // Helper functions for common node lists. They may be empty.
if x != nil {
func walkIdentList(v Visitor, list []*Ident) {
for _, x := range list {
Walk(v, x) Walk(v, x)
} }
} }
func walkCommentGroup(v Visitor, g *CommentGroup) { func walkExprList(v Visitor, list []Expr) {
if g != nil { for _, x := range list {
Walk(v, g) Walk(v, x)
} }
} }
func walkBlockStmt(v Visitor, b *BlockStmt) { func walkStmtList(v Visitor, list []Stmt) {
if b != nil { for _, x := range list {
Walk(v, b) Walk(v, x)
} }
} }
// Walk traverses an AST in depth-first order: If node != nil, it func walkDeclList(v Visitor, list []Decl) {
// invokes v.Visit(node). If the visitor w returned by v.Visit(node) is for _, x := range list {
// not nil, Walk visits each of the children of node with the visitor w, Walk(v, x)
// followed by a call of w.Visit(nil).
//
// Walk may be called with any of the named ast node types. It also
// accepts arguments of type []*Field, []*Ident, []Expr, []Stmt and []Decl;
// the respective children are the slice elements.
//
func Walk(v Visitor, node interface{}) {
if node == nil {
return
} }
}
// TODO(gri): Investigate if providing a closure to Walk leads to
// simpler use (and may help eliminate Inspect in turn).
// Walk traverses an AST in depth-first order: It starts by calling
// v.Visit(node); node must not be nil. If the visitor w returned by
// v.Visit(node) is not nil, Walk is invoked recursively with visitor
// w for each of the non-nil children of node, followed by a call of
// w.Visit(nil).
//
func Walk(v Visitor, node Node) {
if v = v.Visit(node); v == nil { if v = v.Visit(node); v == nil {
return return
} }
// walk children // walk children
// (the order of the cases matches the order // (the order of the cases matches the order
// of the corresponding declaration in ast.go) // of the corresponding node types in ast.go)
switch n := node.(type) { switch n := node.(type) {
// Comments and fields // Comments and fields
case *Comment: case *Comment:
@ -66,11 +72,17 @@ func Walk(v Visitor, node interface{}) {
} }
case *Field: case *Field:
walkCommentGroup(v, n.Doc) if n.Doc != nil {
Walk(v, n.Names) Walk(v, n.Doc)
}
walkIdentList(v, n.Names)
Walk(v, n.Type) Walk(v, n.Type)
Walk(v, n.Tag) if n.Tag != nil {
walkCommentGroup(v, n.Comment) Walk(v, n.Tag)
}
if n.Comment != nil {
Walk(v, n.Comment)
}
case *FieldList: case *FieldList:
for _, f := range n.List { for _, f := range n.List {
@ -82,21 +94,21 @@ func Walk(v Visitor, node interface{}) {
// nothing to do // nothing to do
case *FuncLit: case *FuncLit:
if n != nil { Walk(v, n.Type)
Walk(v, n.Type) Walk(v, n.Body)
}
walkBlockStmt(v, n.Body)
case *CompositeLit: case *CompositeLit:
Walk(v, n.Type) if n.Type != nil {
Walk(v, n.Elts) Walk(v, n.Type)
}
walkExprList(v, n.Elts)
case *ParenExpr: case *ParenExpr:
Walk(v, n.X) Walk(v, n.X)
case *SelectorExpr: case *SelectorExpr:
Walk(v, n.X) Walk(v, n.X)
walkIdent(v, n.Sel) Walk(v, n.Sel)
case *IndexExpr: case *IndexExpr:
Walk(v, n.X) Walk(v, n.X)
@ -104,16 +116,22 @@ func Walk(v Visitor, node interface{}) {
case *SliceExpr: case *SliceExpr:
Walk(v, n.X) Walk(v, n.X)
Walk(v, n.Index) if n.Index != nil {
Walk(v, n.End) Walk(v, n.Index)
}
if n.End != nil {
Walk(v, n.End)
}
case *TypeAssertExpr: case *TypeAssertExpr:
Walk(v, n.X) Walk(v, n.X)
Walk(v, n.Type) if n.Type != nil {
Walk(v, n.Type)
}
case *CallExpr: case *CallExpr:
Walk(v, n.Fun) Walk(v, n.Fun)
Walk(v, n.Args) walkExprList(v, n.Args)
case *StarExpr: case *StarExpr:
Walk(v, n.X) Walk(v, n.X)
@ -131,7 +149,9 @@ func Walk(v Visitor, node interface{}) {
// Types // Types
case *ArrayType: case *ArrayType:
Walk(v, n.Len) if n.Len != nil {
Walk(v, n.Len)
}
Walk(v, n.Elt) Walk(v, n.Elt)
case *StructType: case *StructType:
@ -164,7 +184,7 @@ func Walk(v Visitor, node interface{}) {
// nothing to do // nothing to do
case *LabeledStmt: case *LabeledStmt:
walkIdent(v, n.Label) Walk(v, n.Label)
Walk(v, n.Stmt) Walk(v, n.Stmt)
case *ExprStmt: case *ExprStmt:
@ -174,148 +194,177 @@ func Walk(v Visitor, node interface{}) {
Walk(v, n.X) Walk(v, n.X)
case *AssignStmt: case *AssignStmt:
Walk(v, n.Lhs) walkExprList(v, n.Lhs)
Walk(v, n.Rhs) walkExprList(v, n.Rhs)
case *GoStmt: case *GoStmt:
if n.Call != nil { Walk(v, n.Call)
Walk(v, n.Call)
}
case *DeferStmt: case *DeferStmt:
if n.Call != nil { Walk(v, n.Call)
Walk(v, n.Call)
}
case *ReturnStmt: case *ReturnStmt:
Walk(v, n.Results) walkExprList(v, n.Results)
case *BranchStmt: case *BranchStmt:
walkIdent(v, n.Label) if n.Label != nil {
Walk(v, n.Label)
}
case *BlockStmt: case *BlockStmt:
Walk(v, n.List) walkStmtList(v, n.List)
case *IfStmt: case *IfStmt:
Walk(v, n.Init) if n.Init != nil {
Walk(v, n.Cond) Walk(v, n.Init)
walkBlockStmt(v, n.Body) }
Walk(v, n.Else) if n.Cond != nil {
Walk(v, n.Cond)
}
Walk(v, n.Body)
if n.Else != nil {
Walk(v, n.Else)
}
case *CaseClause: case *CaseClause:
Walk(v, n.Values) walkExprList(v, n.Values)
Walk(v, n.Body) walkStmtList(v, n.Body)
case *SwitchStmt: case *SwitchStmt:
Walk(v, n.Init) if n.Init != nil {
Walk(v, n.Tag) Walk(v, n.Init)
walkBlockStmt(v, n.Body) }
if n.Tag != nil {
Walk(v, n.Tag)
}
Walk(v, n.Body)
case *TypeCaseClause: case *TypeCaseClause:
Walk(v, n.Types) for _, x := range n.Types {
Walk(v, n.Body) Walk(v, x)
}
walkStmtList(v, n.Body)
case *TypeSwitchStmt: case *TypeSwitchStmt:
Walk(v, n.Init) if n.Init != nil {
Walk(v, n.Init)
}
Walk(v, n.Assign) Walk(v, n.Assign)
walkBlockStmt(v, n.Body)
case *CommClause:
Walk(v, n.Lhs)
Walk(v, n.Rhs)
Walk(v, n.Body) Walk(v, n.Body)
case *CommClause:
if n.Lhs != nil {
Walk(v, n.Lhs)
}
if n.Rhs != nil {
Walk(v, n.Rhs)
}
walkStmtList(v, n.Body)
case *SelectStmt: case *SelectStmt:
walkBlockStmt(v, n.Body) Walk(v, n.Body)
case *ForStmt: case *ForStmt:
Walk(v, n.Init) if n.Init != nil {
Walk(v, n.Cond) Walk(v, n.Init)
Walk(v, n.Post) }
walkBlockStmt(v, n.Body) if n.Cond != nil {
Walk(v, n.Cond)
}
if n.Post != nil {
Walk(v, n.Post)
}
Walk(v, n.Body)
case *RangeStmt: case *RangeStmt:
Walk(v, n.Key) Walk(v, n.Key)
Walk(v, n.Value) if n.Value != nil {
Walk(v, n.Value)
}
Walk(v, n.X) Walk(v, n.X)
walkBlockStmt(v, n.Body) Walk(v, n.Body)
// Declarations // Declarations
case *ImportSpec: case *ImportSpec:
walkCommentGroup(v, n.Doc) if n.Doc != nil {
walkIdent(v, n.Name) Walk(v, n.Doc)
}
if n.Name != nil {
Walk(v, n.Name)
}
Walk(v, n.Path) Walk(v, n.Path)
walkCommentGroup(v, n.Comment) if n.Comment != nil {
Walk(v, n.Comment)
}
case *ValueSpec: case *ValueSpec:
walkCommentGroup(v, n.Doc) if n.Doc != nil {
Walk(v, n.Names) Walk(v, n.Doc)
Walk(v, n.Type) }
Walk(v, n.Values) walkIdentList(v, n.Names)
walkCommentGroup(v, n.Comment) if n.Type != nil {
Walk(v, n.Type)
}
walkExprList(v, n.Values)
if n.Comment != nil {
Walk(v, n.Comment)
}
case *TypeSpec: case *TypeSpec:
walkCommentGroup(v, n.Doc) if n.Doc != nil {
walkIdent(v, n.Name) Walk(v, n.Doc)
}
Walk(v, n.Name)
Walk(v, n.Type) Walk(v, n.Type)
walkCommentGroup(v, n.Comment) if n.Comment != nil {
Walk(v, n.Comment)
}
case *BadDecl: case *BadDecl:
// nothing to do // nothing to do
case *GenDecl: case *GenDecl:
walkCommentGroup(v, n.Doc) if n.Doc != nil {
Walk(v, n.Doc)
}
for _, s := range n.Specs { for _, s := range n.Specs {
Walk(v, s) Walk(v, s)
} }
case *FuncDecl: case *FuncDecl:
walkCommentGroup(v, n.Doc) if n.Doc != nil {
Walk(v, n.Doc)
}
if n.Recv != nil { if n.Recv != nil {
Walk(v, n.Recv) Walk(v, n.Recv)
} }
walkIdent(v, n.Name) Walk(v, n.Name)
if n.Type != nil { Walk(v, n.Type)
Walk(v, n.Type) if n.Body != nil {
Walk(v, n.Body)
} }
walkBlockStmt(v, n.Body)
// Files and packages // Files and packages
case *File: case *File:
walkCommentGroup(v, n.Doc) if n.Doc != nil {
walkIdent(v, n.Name) Walk(v, n.Doc)
Walk(v, n.Decls) }
Walk(v, n.Name)
walkDeclList(v, n.Decls)
for _, g := range n.Comments { for _, g := range n.Comments {
Walk(v, g) Walk(v, g)
} }
// don't walk n.Comments - they have been
// visited already through the individual
// nodes
case *Package: case *Package:
for _, f := range n.Files { for _, f := range n.Files {
Walk(v, f) Walk(v, f)
} }
case []*Ident:
for _, x := range n {
Walk(v, x)
}
case []Expr:
for _, x := range n {
Walk(v, x)
}
case []Stmt:
for _, x := range n {
Walk(v, x)
}
case []Decl:
for _, x := range n {
Walk(v, x)
}
default: default:
fmt.Printf("ast.Walk: unexpected type %T", n) fmt.Printf("ast.Walk: unexpected node type %T", n)
panic("ast.Walk") panic("ast.Walk")
} }
@ -323,20 +372,20 @@ func Walk(v Visitor, node interface{}) {
} }
type inspector func(node interface{}) bool type inspector func(Node) bool
func (f inspector) Visit(node interface{}) Visitor { func (f inspector) Visit(node Node) Visitor {
if node != nil && f(node) { if f(node) {
return f return f
} }
return nil return nil
} }
// Inspect traverses an AST in depth-first order: If node != nil, it // Inspect traverses an AST in depth-first order: It starts by calling
// invokes f(node). If f returns true, inspect invokes f for all the // f(node); node must not be nil. If f returns true, Inspect invokes f
// non-nil children of node, recursively. // for all the non-nil children of node, recursively.
// //
func Inspect(ast interface{}, f func(node interface{}) bool) { func Inspect(node Node, f func(Node) bool) {
Walk(inspector(f), ast) Walk(inspector(f), node)
} }

View File

@ -994,7 +994,7 @@ func stripParens(x ast.Expr) ast.Expr {
// parentheses must not be stripped if there are any // parentheses must not be stripped if there are any
// unparenthesized composite literals starting with // unparenthesized composite literals starting with
// a type name // a type name
ast.Inspect(px.X, func(node interface{}) bool { ast.Inspect(px.X, func(node ast.Node) bool {
switch x := node.(type) { switch x := node.(type) {
case *ast.ParenExpr: case *ast.ParenExpr:
// parentheses protect enclosed composite literals // parentheses protect enclosed composite literals