diff --git a/go/ssa/testmain.go b/go/ssa/testmain.go index 242207b003..223013d56e 100644 --- a/go/ssa/testmain.go +++ b/go/ssa/testmain.go @@ -17,6 +17,67 @@ import ( "code.google.com/p/go.tools/go/types" ) +// FindTests returns the list of packages that define at least one Test, +// Example or Benchmark function (as defined by "go test"), and the +// lists of all such functions. +// +func FindTests(pkgs []*Package) (testpkgs []*Package, tests, benchmarks, examples []*Function) { + if len(pkgs) == 0 { + return + } + prog := pkgs[0].Prog + + // The first two of these may be nil: if the program doesn't import "testing", + // it can't contain any tests, but it may yet contain Examples. + var testSig *types.Signature // func(*testing.T) + var benchmarkSig *types.Signature // func(*testing.B) + var exampleSig = types.NewSignature(nil, nil, nil, nil, false) // func() + + // Obtain the types from the parameters of testing.Main(). + if testingPkg := prog.ImportedPackage("testing"); testingPkg != nil { + params := testingPkg.Func("Main").Signature.Params() + testSig = funcField(params.At(1).Type()) + benchmarkSig = funcField(params.At(2).Type()) + } + + seen := make(map[*Package]bool) + for _, pkg := range pkgs { + if pkg.Prog != prog { + panic("wrong Program") + } + + // TODO(adonovan): use a stable order, e.g. lexical. + for _, mem := range pkg.Members { + if f, ok := mem.(*Function); ok && + ast.IsExported(f.Name()) && + strings.HasSuffix(prog.Fset.Position(f.Pos()).Filename, "_test.go") { + + switch { + case testSig != nil && isTestSig(f, "Test", testSig): + tests = append(tests, f) + case benchmarkSig != nil && isTestSig(f, "Benchmark", benchmarkSig): + benchmarks = append(benchmarks, f) + case isTestSig(f, "Example", exampleSig): + examples = append(examples, f) + default: + continue + } + + if !seen[pkg] { + seen[pkg] = true + testpkgs = append(testpkgs, pkg) + } + } + } + } + return +} + +// Like isTest, but checks the signature too. +func isTestSig(f *Function, prefix string, sig *types.Signature) bool { + return isTest(f.Name(), prefix) && types.Identical(f.Signature, sig) +} + // If non-nil, testMainStartBodyHook is called immediately after // startBody for main.init and main.main, making it easy for users to // add custom imports and initialization steps for proprietary build @@ -30,9 +91,11 @@ var testMainStartBodyHook func(*Function) // It returns nil if the program contains no tests. // func (prog *Program) CreateTestMainPackage(pkgs ...*Package) *Package { + pkgs, tests, benchmarks, examples := FindTests(pkgs) if len(pkgs) == 0 { return nil } + testmain := &Package{ Prog: prog, Members: make(map[string]Member), @@ -54,26 +117,12 @@ func (prog *Program) CreateTestMainPackage(pkgs ...*Package) *Package { testMainStartBodyHook(init) } - // TODO(adonovan): use lexical order. - var expfuncs []*Function // all exported functions of *_test.go in pkgs, unordered + // Initialize packages to test. for _, pkg := range pkgs { - if pkg.Prog != prog { - panic("wrong Program") - } - // Initialize package to test. var v Call v.Call.Value = pkg.init v.setType(types.NewTuple()) init.emit(&v) - - // Enumerate its possible tests/benchmarks. - for _, mem := range pkg.Members { - if f, ok := mem.(*Function); ok && - ast.IsExported(f.Name()) && - strings.HasSuffix(prog.Fset.Position(f.Pos()).Filename, "_test.go") { - expfuncs = append(expfuncs, f) - } - } } init.emit(new(Return)) init.finishBody() @@ -81,27 +130,6 @@ func (prog *Program) CreateTestMainPackage(pkgs ...*Package) *Package { testmain.Object.MarkComplete() testmain.Members[init.name] = init - testingPkg := prog.ImportedPackage("testing") - if testingPkg == nil { - // If the program doesn't import "testing", it can't - // contain any tests. - // TODO(adonovan): but it might contain Examples. - // Support them (by just calling them directly). - return nil - } - testingMain := testingPkg.Func("Main") - testingMainParams := testingMain.Signature.Params() - - // The generated code is as if compiled from this: - // - // func main() { - // match := func(_, _ string) (bool, error) { return true, nil } - // tests := []testing.InternalTest{{"TestFoo", TestFoo}, ...} - // benchmarks := []testing.InternalBenchmark{...} - // examples := []testing.InternalExample{...} - // testing.Main(match, tests, benchmarks, examples) - // } - main := &Function{ name: "main", Signature: new(types.Signature), @@ -110,41 +138,66 @@ func (prog *Program) CreateTestMainPackage(pkgs ...*Package) *Package { Pkg: testmain, } - matcher := &Function{ - name: "matcher", - Signature: testingMainParams.At(0).Type().(*types.Signature), - Synthetic: "test matcher predicate", - parent: main, - Pkg: testmain, - Prog: prog, - } - main.AnonFuncs = append(main.AnonFuncs, matcher) - matcher.startBody() - matcher.emit(&Return{Results: []Value{vTrue, nilConst(types.Universe.Lookup("error").Type())}}) - matcher.finishBody() - main.startBody() if testMainStartBodyHook != nil { testMainStartBodyHook(main) } - var c Call - c.Call.Value = testingMain + if testingPkg := prog.ImportedPackage("testing"); testingPkg != nil { + testingMain := testingPkg.Func("Main") + testingMainParams := testingMain.Signature.Params() - tests := testMainSlice(main, expfuncs, "Test", testingMainParams.At(1).Type()) - benchmarks := testMainSlice(main, expfuncs, "Benchmark", testingMainParams.At(2).Type()) - examples := testMainSlice(main, expfuncs, "Example", testingMainParams.At(3).Type()) - _, noTests := tests.(*Const) // i.e. nil slice - _, noBenchmarks := benchmarks.(*Const) - _, noExamples := examples.(*Const) - if noTests && noBenchmarks && noExamples { - return nil + // The generated code is as if compiled from this: + // + // func main() { + // match := func(_, _ string) (bool, error) { return true, nil } + // tests := []testing.InternalTest{{"TestFoo", TestFoo}, ...} + // benchmarks := []testing.InternalBenchmark{...} + // examples := []testing.InternalExample{...} + // testing.Main(match, tests, benchmarks, examples) + // } + + matcher := &Function{ + name: "matcher", + Signature: testingMainParams.At(0).Type().(*types.Signature), + Synthetic: "test matcher predicate", + parent: main, + Pkg: testmain, + Prog: prog, + } + main.AnonFuncs = append(main.AnonFuncs, matcher) + matcher.startBody() + matcher.emit(&Return{Results: []Value{vTrue, nilConst(types.Universe.Lookup("error").Type())}}) + matcher.finishBody() + + // Emit call: testing.Main(matcher, tests, benchmarks, examples). + var c Call + c.Call.Value = testingMain + c.Call.Args = []Value{ + matcher, + testMainSlice(main, tests, testingMainParams.At(1).Type()), + testMainSlice(main, benchmarks, testingMainParams.At(2).Type()), + testMainSlice(main, examples, testingMainParams.At(3).Type()), + } + emitTailCall(main, &c) + } else { + // The program does not import "testing", but FindTests + // returned non-nil, which must mean there were Examples + // but no Tests or Benchmarks. + // We'll simply call them from testmain.main; this will + // ensure they don't panic, but will not check any + // "Output:" comments. + for _, eg := range examples { + var c Call + c.Call.Value = eg + c.setType(types.NewTuple()) + main.emit(&c) + } + main.emit(&Return{}) + main.currentBlock = nil } - c.Call.Args = []Value{matcher, tests, benchmarks, examples} - // Emit: testing.Main(nil, tests, benchmarks, examples) - emitTailCall(main, &c) main.finishBody() testmain.Members["main"] = main @@ -166,27 +219,17 @@ func (prog *Program) CreateTestMainPackage(pkgs ...*Package) *Package { // testMainSlice emits to fn code to construct a slice of type slice // (one of []testing.Internal{Test,Benchmark,Example}) for all -// functions in expfuncs whose name starts with prefix (one of -// "Test", "Benchmark" or "Example") and whose type is appropriate. -// It returns the slice value. +// functions in testfuncs. It returns the slice value. // -func testMainSlice(fn *Function, expfuncs []*Function, prefix string, slice types.Type) Value { - tElem := slice.(*types.Slice).Elem() - tFunc := tElem.Underlying().(*types.Struct).Field(1).Type() - - var testfuncs []*Function - for _, f := range expfuncs { - if isTest(f.Name(), prefix) && types.Identical(f.Signature, tFunc) { - testfuncs = append(testfuncs, f) - } - } +func testMainSlice(fn *Function, testfuncs []*Function, slice types.Type) Value { if testfuncs == nil { return nilConst(slice) } + tElem := slice.(*types.Slice).Elem() tPtrString := types.NewPointer(tString) tPtrElem := types.NewPointer(tElem) - tPtrFunc := types.NewPointer(tFunc) + tPtrFunc := types.NewPointer(funcField(slice)) // Emit: array = new [n]testing.InternalTest tArray := types.NewArray(tElem, int64(len(testfuncs))) @@ -221,6 +264,12 @@ func testMainSlice(fn *Function, expfuncs []*Function, prefix string, slice type return fn.emit(sl) } +// Given the type of one of the three slice parameters of testing.Main, +// returns the function type. +func funcField(slice types.Type) *types.Signature { + return slice.(*types.Slice).Elem().Underlying().(*types.Struct).Field(1).Type().(*types.Signature) +} + // Plundered from $GOROOT/src/cmd/go/test.go // isTest tells whether name looks like a test (or benchmark, according to prefix). diff --git a/go/ssa/testmain_test.go b/go/ssa/testmain_test.go new file mode 100644 index 0000000000..3a9eacc760 --- /dev/null +++ b/go/ssa/testmain_test.go @@ -0,0 +1,122 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package ssa_test + +// Tests of FindTests. CreateTestMainPackage is tested via the interpreter. +// TODO(adonovan): test the 'pkgs' result from FindTests. + +import ( + "fmt" + "sort" + "testing" + + "code.google.com/p/go.tools/go/loader" + "code.google.com/p/go.tools/go/ssa" +) + +func create(t *testing.T, content string) []*ssa.Package { + var conf loader.Config + f, err := conf.ParseFile("foo_test.go", content) + if err != nil { + t.Fatal(err) + } + conf.CreateFromFiles("foo", f) + + iprog, err := conf.Load() + if err != nil { + t.Fatal(err) + } + + // We needn't call Build. + return ssa.Create(iprog, ssa.SanityCheckFunctions).AllPackages() +} + +func TestFindTests(t *testing.T) { + test := ` +package foo + +import "testing" + +type T int + +// Tests: +func Test(t *testing.T) {} +func TestA(t *testing.T) {} +func TestB(t *testing.T) {} + +// Not tests: +func testC(t *testing.T) {} +func TestD() {} +func testE(t *testing.T) int { return 0 } +func (T) Test(t *testing.T) {} + +// Benchmarks: +func Benchmark(*testing.B) {} +func BenchmarkA(b *testing.B) {} +func BenchmarkB(*testing.B) {} + +// Not benchmarks: +func benchmarkC(t *testing.T) {} +func BenchmarkD() {} +func benchmarkE(t *testing.T) int { return 0 } +func (T) Benchmark(t *testing.T) {} + +// Examples: +func Example() {} +func ExampleA() {} + +// Not examples: +func exampleC() {} +func ExampleD(t *testing.T) {} +func exampleE() int { return 0 } +func (T) Example() {} +` + pkgs := create(t, test) + _, tests, benchmarks, examples := ssa.FindTests(pkgs) + + sort.Sort(funcsByPos(tests)) + if got, want := fmt.Sprint(tests), "[foo.Test foo.TestA foo.TestB]"; got != want { + t.Errorf("FindTests.tests = %s, want %s", got, want) + } + + sort.Sort(funcsByPos(benchmarks)) + if got, want := fmt.Sprint(benchmarks), "[foo.Benchmark foo.BenchmarkA foo.BenchmarkB]"; got != want { + t.Errorf("FindTests.benchmarks = %s, want %s", got, want) + } + + sort.Sort(funcsByPos(examples)) + if got, want := fmt.Sprint(examples), "[foo.Example foo.ExampleA]"; got != want { + t.Errorf("FindTests examples = %s, want %s", got, want) + } +} + +func TestFindTestsTesting(t *testing.T) { + test := ` +package foo + +// foo does not import "testing", but defines Examples. + +func Example() {} +func ExampleA() {} +` + pkgs := create(t, test) + _, tests, benchmarks, examples := ssa.FindTests(pkgs) + if len(tests) > 0 { + t.Errorf("FindTests.tests = %s, want none", tests) + } + if len(benchmarks) > 0 { + t.Errorf("FindTests.benchmarks = %s, want none", benchmarks) + } + sort.Sort(funcsByPos(examples)) + if got, want := fmt.Sprint(examples), "[foo.Example foo.ExampleA]"; got != want { + t.Errorf("FindTests examples = %s, want %s", got, want) + } +} + +type funcsByPos []*ssa.Function + +func (p funcsByPos) Len() int { return len(p) } +func (p funcsByPos) Less(i, j int) bool { return p[i].Pos() < p[j].Pos() } +func (p funcsByPos) Swap(i, j int) { p[i], p[j] = p[j], p[i] }