diff --git a/go/analysis/passes/loopclosure/loopclosure.go b/go/analysis/passes/loopclosure/loopclosure.go new file mode 100644 index 00000000000..e71630bb9e9 --- /dev/null +++ b/go/analysis/passes/loopclosure/loopclosure.go @@ -0,0 +1,126 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package loopclosure + +import ( + "go/ast" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" + "golang.org/x/tools/go/ast/inspector" +) + +// TODO(adonovan): also report an error for the following structure, +// which is often used to ensure that deferred calls do not accumulate +// in a loop: +// +// for i, x := range c { +// func() { +// ...reference to i or x... +// }() +// } + +var Analyzer = &analysis.Analyzer{ + Name: "loopclosure", + Doc: `check references to loop variables from within nested functions + +This analyzer checks for references to loop variables from within a +function literal inside the loop body. It checks only instances where +the function literal is called in a defer or go statement that is the +last statement in the loop body, as otherwise we would need whole +program analysis. + +For example: + + for i, v := range s { + go func() { + println(i, v) // not what you might expect + }() + } + +See: https://golang.org/doc/go_faq.html#closures_and_goroutines`, + Requires: []*analysis.Analyzer{inspect.Analyzer}, + Run: run, +} + +func run(pass *analysis.Pass) (interface{}, error) { + inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector) + + nodeFilter := []ast.Node{ + (*ast.RangeStmt)(nil), + (*ast.ForStmt)(nil), + } + inspect.Preorder(nodeFilter, func(n ast.Node) { + // Find the variables updated by the loop statement. + var vars []*ast.Ident + addVar := func(expr ast.Expr) { + if id, ok := expr.(*ast.Ident); ok { + vars = append(vars, id) + } + } + var body *ast.BlockStmt + switch n := n.(type) { + case *ast.RangeStmt: + body = n.Body + addVar(n.Key) + addVar(n.Value) + case *ast.ForStmt: + body = n.Body + switch post := n.Post.(type) { + case *ast.AssignStmt: + // e.g. for p = head; p != nil; p = p.next + for _, lhs := range post.Lhs { + addVar(lhs) + } + case *ast.IncDecStmt: + // e.g. for i := 0; i < n; i++ + addVar(post.X) + } + } + if vars == nil { + return + } + + // Inspect a go or defer statement + // if it's the last one in the loop body. + // (We give up if there are following statements, + // because it's hard to prove go isn't followed by wait, + // or defer by return.) + if len(body.List) == 0 { + return + } + var last *ast.CallExpr + switch s := body.List[len(body.List)-1].(type) { + case *ast.GoStmt: + last = s.Call + case *ast.DeferStmt: + last = s.Call + default: + return + } + lit, ok := last.Fun.(*ast.FuncLit) + if !ok { + return + } + ast.Inspect(lit.Body, func(n ast.Node) bool { + id, ok := n.(*ast.Ident) + if !ok || id.Obj == nil { + return true + } + if pass.TypesInfo.Types[id].Type == nil { + // Not referring to a variable (e.g. struct field name) + return true + } + for _, v := range vars { + if v.Obj == id.Obj { + pass.Reportf(id.Pos(), "loop variable %s captured by func literal", + id.Name) + } + } + return true + }) + }) + return nil, nil +} diff --git a/go/analysis/passes/loopclosure/loopclosure_test.go b/go/analysis/passes/loopclosure/loopclosure_test.go new file mode 100644 index 00000000000..8253ab72231 --- /dev/null +++ b/go/analysis/passes/loopclosure/loopclosure_test.go @@ -0,0 +1,13 @@ +package loopclosure_test + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" + "golang.org/x/tools/go/analysis/passes/loopclosure" +) + +func Test(t *testing.T) { + testdata := analysistest.TestData() + analysistest.Run(t, testdata, loopclosure.Analyzer, "a") +} diff --git a/go/analysis/passes/vet/testdata/rangeloop.go b/go/analysis/passes/loopclosure/testdata/src/a/a.go similarity index 58% rename from go/analysis/passes/vet/testdata/rangeloop.go rename to go/analysis/passes/loopclosure/testdata/src/a/a.go index cd3b4cbc452..e1f7bad5817 100644 --- a/go/analysis/passes/vet/testdata/rangeloop.go +++ b/go/analysis/passes/loopclosure/testdata/src/a/a.go @@ -2,32 +2,32 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// This file contains tests for the rangeloop checker. +// This file contains tests for the loopclosure checker. package testdata -func RangeLoopTests() { +func _() { var s []int for i, v := range s { go func() { - println(i) // ERROR "loop variable i captured by func literal" - println(v) // ERROR "loop variable v captured by func literal" + println(i) // want "loop variable i captured by func literal" + println(v) // want "loop variable v captured by func literal" }() } for i, v := range s { defer func() { - println(i) // ERROR "loop variable i captured by func literal" - println(v) // ERROR "loop variable v captured by func literal" + println(i) // want "loop variable i captured by func literal" + println(v) // want "loop variable v captured by func literal" }() } for i := range s { go func() { - println(i) // ERROR "loop variable i captured by func literal" + println(i) // want "loop variable i captured by func literal" }() } for _, v := range s { go func() { - println(v) // ERROR "loop variable v captured by func literal" + println(v) // want "loop variable v captured by func literal" }() } for i, v := range s { @@ -53,7 +53,7 @@ func RangeLoopTests() { var f int for x[0], f = range s { go func() { - _ = f // ERROR "loop variable f captured by func literal" + _ = f // want "loop variable f captured by func literal" }() } type T struct { @@ -62,19 +62,19 @@ func RangeLoopTests() { for _, v := range s { go func() { _ = T{v: 1} - _ = []int{v: 1} // ERROR "loop variable v captured by func literal" + _ = map[int]int{v: 1} // want "loop variable v captured by func literal" }() } // ordinary for-loops for i := 0; i < 10; i++ { go func() { - print(i) // ERROR "loop variable i captured by func literal" + print(i) // want "loop variable i captured by func literal" }() } for i, j := 0, 1; i < 100; i, j = j, i+j { go func() { - print(j) // ERROR "loop variable j captured by func literal" + print(j) // want "loop variable j captured by func literal" }() } type cons struct { @@ -82,9 +82,9 @@ func RangeLoopTests() { cdr *cons } var head *cons - for p := head; p != nil; p = p.next { + for p := head; p != nil; p = p.cdr { go func() { - print(p.car) // ERROR "loop variable p captured by func literal" + print(p.car) // want "loop variable p captured by func literal" }() } } diff --git a/go/analysis/passes/vet/rangeloop.go b/go/analysis/passes/vet/rangeloop.go deleted file mode 100644 index 19642df2744..00000000000 --- a/go/analysis/passes/vet/rangeloop.go +++ /dev/null @@ -1,107 +0,0 @@ -// +build ignore - -// Copyright 2012 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -/* -This file contains the code to check range loop variables bound inside function -literals that are deferred or launched in new goroutines. We only check -instances where the defer or go statement is the last statement in the loop -body, as otherwise we would need whole program analysis. - -For example: - - for i, v := range s { - go func() { - println(i, v) // not what you might expect - }() - } - -See: https://golang.org/doc/go_faq.html#closures_and_goroutines -*/ - -package main - -import "go/ast" - -func init() { - register("rangeloops", - "check that loop variables are used correctly", - checkLoop, - rangeStmt, forStmt) -} - -// checkLoop walks the body of the provided loop statement, checking whether -// its index or value variables are used unsafely inside goroutines or deferred -// function literals. -func checkLoop(f *File, node ast.Node) { - // Find the variables updated by the loop statement. - var vars []*ast.Ident - addVar := func(expr ast.Expr) { - if id, ok := expr.(*ast.Ident); ok { - vars = append(vars, id) - } - } - var body *ast.BlockStmt - switch n := node.(type) { - case *ast.RangeStmt: - body = n.Body - addVar(n.Key) - addVar(n.Value) - case *ast.ForStmt: - body = n.Body - switch post := n.Post.(type) { - case *ast.AssignStmt: - // e.g. for p = head; p != nil; p = p.next - for _, lhs := range post.Lhs { - addVar(lhs) - } - case *ast.IncDecStmt: - // e.g. for i := 0; i < n; i++ - addVar(post.X) - } - } - if vars == nil { - return - } - - // Inspect a go or defer statement - // if it's the last one in the loop body. - // (We give up if there are following statements, - // because it's hard to prove go isn't followed by wait, - // or defer by return.) - if len(body.List) == 0 { - return - } - var last *ast.CallExpr - switch s := body.List[len(body.List)-1].(type) { - case *ast.GoStmt: - last = s.Call - case *ast.DeferStmt: - last = s.Call - default: - return - } - lit, ok := last.Fun.(*ast.FuncLit) - if !ok { - return - } - ast.Inspect(lit.Body, func(n ast.Node) bool { - id, ok := n.(*ast.Ident) - if !ok || id.Obj == nil { - return true - } - if f.pkg.types[id].Type == nil { - // Not referring to a variable (e.g. struct field name) - return true - } - for _, v := range vars { - if v.Obj == id.Obj { - f.Badf(id.Pos(), "loop variable %s captured by func literal", - id.Name) - } - } - return true - }) -}