diff --git a/src/pkg/go/ast/filter.go b/src/pkg/go/ast/filter.go index 4c96e71c037..d4b580e003a 100644 --- a/src/pkg/go/ast/filter.go +++ b/src/pkg/go/ast/filter.go @@ -20,24 +20,6 @@ func identListExports(list []*Ident) []*Ident { return list[0:j] } -// fieldName assumes that x is the type of an anonymous field and -// returns the corresponding field name. If x is not an acceptable -// anonymous field, the result is nil. -// -func fieldName(x Expr) *Ident { - switch t := x.(type) { - case *Ident: - return t - case *SelectorExpr: - if _, ok := t.X.(*Ident); ok { - return t.Sel - } - case *StarExpr: - return fieldName(t.X) - } - return nil -} - func fieldListExports(fields *FieldList) (removedFields bool) { if fields == nil { return @@ -203,6 +185,24 @@ func filterIdentList(list []*Ident, f Filter) []*Ident { return list[0:j] } +// fieldName assumes that x is the type of an anonymous field and +// returns the corresponding field name. If x is not an acceptable +// anonymous field, the result is nil. +// +func fieldName(x Expr) *Ident { + switch t := x.(type) { + case *Ident: + return t + case *SelectorExpr: + if _, ok := t.X.(*Ident); ok { + return t.Sel + } + case *StarExpr: + return fieldName(t.X) + } + return nil +} + func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) { if fields == nil { return false @@ -224,6 +224,7 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) { keepField = len(f.Names) > 0 } if keepField { + filterType(f.Type, filter) list[j] = f j++ } @@ -235,26 +236,71 @@ func filterFieldList(fields *FieldList, filter Filter) (removedFields bool) { return } +func filterParamList(fields *FieldList, filter Filter) bool { + if fields == nil { + return false + } + var b bool + for _, f := range fields.List { + if filterType(f.Type, filter) { + b = true + } + } + return b +} + +func filterType(typ Expr, f Filter) bool { + switch t := typ.(type) { + case *Ident: + return f(t.Name) + case *ParenExpr: + return filterType(t.X, f) + case *ArrayType: + return filterType(t.Elt, f) + case *StructType: + if filterFieldList(t.Fields, f) { + t.Incomplete = true + } + return len(t.Fields.List) > 0 + case *FuncType: + b1 := filterParamList(t.Params, f) + b2 := filterParamList(t.Results, f) + return b1 || b2 + case *InterfaceType: + if filterFieldList(t.Methods, f) { + t.Incomplete = true + } + return len(t.Methods.List) > 0 + case *MapType: + b1 := filterType(t.Key, f) + b2 := filterType(t.Value, f) + return b1 || b2 + case *ChanType: + return filterType(t.Value, f) + } + return false +} + func filterSpec(spec Spec, f Filter) bool { switch s := spec.(type) { case *ValueSpec: s.Names = filterIdentList(s.Names, f) - return len(s.Names) > 0 - case *TypeSpec: - if f(s.Name.Name) { + if len(s.Names) > 0 { + filterType(s.Type, f) return true } - switch t := s.Type.(type) { - case *StructType: - if filterFieldList(t.Fields, f) { - t.Incomplete = true - } - return len(t.Fields.List) > 0 - case *InterfaceType: - if filterFieldList(t.Methods, f) { - t.Incomplete = true - } - return len(t.Methods.List) > 0 + case *TypeSpec: + if f(s.Name.Name) { + filterType(s.Type, f) + return true + } + if f != IsExported { + // For general filtering (not just exports), + // filter type even if name is not filtered + // out. + // If the type contains filtered elements, + // keep the declaration. + return filterType(s.Type, f) } } return false @@ -284,6 +330,7 @@ func FilterDecl(decl Decl, f Filter) bool { d.Specs = filterSpecList(d.Specs, f) return len(d.Specs) > 0 case *FuncDecl: + d.Body = nil // strip body return f(d.Name.Name) } return false @@ -292,13 +339,16 @@ func FilterDecl(decl Decl, f Filter) bool { // FilterFile trims the AST for a Go file in place by removing all // names from top-level declarations (including struct field and // interface method names, but not from parameter lists) that don't -// pass through the filter f. If the declaration is empty afterwards, -// the declaration is removed from the AST. -// The File.comments list is not changed. +// pass through the filter f. Function bodies are set to nil. +// If the declaration is empty afterwards, the declaration is +// removed from the AST. The File.comments list is not changed. // // FilterFile returns true if there are any top-level declarations // left after filtering; it returns false otherwise. // +// To trim an AST such that only exported nodes remain, call +// FilterFile with IsExported as filter function. +// func FilterFile(src *File, f Filter) bool { j := 0 for _, d := range src.Decls { @@ -314,14 +364,17 @@ func FilterFile(src *File, f Filter) bool { // FilterPackage trims the AST for a Go package in place by removing all // names from top-level declarations (including struct field and // interface method names, but not from parameter lists) that don't -// pass through the filter f. If the declaration is empty afterwards, -// the declaration is removed from the AST. -// The pkg.Files list is not changed, so that file names and top-level -// package comments don't get lost. +// pass through the filter f. Function bodies are set to nil. +// If the declaration is empty afterwards, the declaration is +// removed from the AST. The pkg.Files list is not changed, so +// that file names and top-level package comments don't get lost. // // FilterPackage returns true if there are any top-level declarations // left after filtering; it returns false otherwise. // +// To trim an AST such that only exported nodes remain, call +// FilterPackage with IsExported as filter function. +// func FilterPackage(pkg *Package, f Filter) bool { hasDecls := false for _, src := range pkg.Files { @@ -467,17 +520,16 @@ func MergePackageFiles(pkg *Package, mode MergeMode) *File { seen := make(map[string]bool) for _, f := range pkg.Files { for _, imp := range f.Imports { - path := imp.Path.Value - if !seen[path] { - //TODO: consider handling cases where: + if path := imp.Path.Value; !seen[path] { + // TODO: consider handling cases where: // - 2 imports exist with the same import path but // have different local names (one should probably // keep both of them) // - 2 imports exist but only one has a comment // - 2 imports exist and they both have (possibly // different) comments - seen[path] = true imports = append(imports, imp) + seen[path] = true } } } diff --git a/src/pkg/go/parser/filter_test.go b/src/pkg/go/parser/filter_test.go new file mode 100644 index 00000000000..f1672a9dd8d --- /dev/null +++ b/src/pkg/go/parser/filter_test.go @@ -0,0 +1,91 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// A test that ensures that ast.FileExports(file) and +// ast.FilterFile(file, ast.IsExported) produce the +// same trimmed AST given the same input file for all +// files under runtime.GOROOT(). +// +// The test is here because it requires parsing, and the +// parser imports AST already (avoid import cycle). + +package parser + +import ( + "bytes" + "go/ast" + "go/printer" + "go/token" + "os" + "path/filepath" + "runtime" + "strings" + "testing" +) + +// For a short test run, limit the number of files to a few. +// Set to a large value to test all files under GOROOT. +const maxFiles = 10 + +type visitor struct { + t *testing.T + n int +} + +func (v *visitor) VisitDir(path string, f *os.FileInfo) bool { + return true +} + +func str(f *ast.File, fset *token.FileSet) string { + var buf bytes.Buffer + printer.Fprint(&buf, fset, f) + return buf.String() +} + +func (v *visitor) VisitFile(path string, f *os.FileInfo) { + // exclude files that clearly don't make it + if !f.IsRegular() || len(f.Name) > 0 && f.Name[0] == '.' || !strings.HasSuffix(f.Name, ".go") { + return + } + + // should we stop early for quick test runs? + if v.n <= 0 { + return + } + v.n-- + + fset := token.NewFileSet() + + // get two ASTs f1, f2 of identical structure; + // parsing twice is easiest + f1, err := ParseFile(fset, path, nil, ParseComments) + if err != nil { + v.t.Logf("parse error (1): %s", err) + return + } + + f2, err := ParseFile(fset, path, nil, ParseComments) + if err != nil { + v.t.Logf("parse error (2): %s", err) + return + } + + b1 := ast.FileExports(f1) + b2 := ast.FilterFile(f2, ast.IsExported) + if b1 != b2 { + v.t.Errorf("filtering failed (a): %s", path) + return + } + + s1 := str(f1, fset) + s2 := str(f2, fset) + if s1 != s2 { + v.t.Errorf("filtering failed (b): %s", path) + return + } +} + +func TestFilter(t *testing.T) { + filepath.Walk(runtime.GOROOT(), &visitor{t, maxFiles}, nil) +}