diff --git a/src/cmd/vet/Makefile b/src/cmd/vet/Makefile index 2a35d1ae37f..d90b5f9d54b 100644 --- a/src/cmd/vet/Makefile +++ b/src/cmd/vet/Makefile @@ -4,4 +4,4 @@ test testshort: go build - ../../../test/errchk ./vet -printfuncs='Warn:1,Warnf:1' print.go + ../../../test/errchk ./vet -printfuncs='Warn:1,Warnf:1' print.go rangeloop.go diff --git a/src/cmd/vet/main.go b/src/cmd/vet/main.go index d2a7c6e55b1..76a4896bfa0 100644 --- a/src/cmd/vet/main.go +++ b/src/cmd/vet/main.go @@ -30,6 +30,7 @@ var ( vetPrintf = flag.Bool("printf", false, "check printf-like invocations") vetStructTags = flag.Bool("structtags", false, "check that struct field tags have canonical format") vetUntaggedLiteral = flag.Bool("composites", false, "check that composite literals used type-tagged elements") + vetRangeLoops = flag.Bool("rangeloops", false, "check that range loop variables are used correctly") ) // setExit sets the value for os.Exit when it is called, later. It @@ -60,7 +61,7 @@ func main() { flag.Parse() // If a check is named explicitly, turn off the 'all' flag. - if *vetMethods || *vetPrintf || *vetStructTags || *vetUntaggedLiteral { + if *vetMethods || *vetPrintf || *vetStructTags || *vetUntaggedLiteral || *vetRangeLoops { *vetAll = false } @@ -197,6 +198,8 @@ func (f *File) Visit(node ast.Node) ast.Visitor { f.walkMethodDecl(n) case *ast.InterfaceType: f.walkInterfaceType(n) + case *ast.RangeStmt: + f.walkRangeStmt(n) } return f } @@ -206,6 +209,16 @@ func (f *File) walkCall(call *ast.CallExpr, name string) { f.checkFmtPrintfCall(call, name) } +// walkCallExpr walks a call expression. +func (f *File) walkCallExpr(call *ast.CallExpr) { + switch x := call.Fun.(type) { + case *ast.Ident: + f.walkCall(call, x.Name) + case *ast.SelectorExpr: + f.walkCall(call, x.Sel.Name) + } +} + // walkCompositeLit walks a composite literal. func (f *File) walkCompositeLit(c *ast.CompositeLit) { f.checkUntaggedLiteral(c) @@ -242,12 +255,7 @@ func (f *File) walkInterfaceType(t *ast.InterfaceType) { } } -// walkCallExpr walks a call expression. -func (f *File) walkCallExpr(call *ast.CallExpr) { - switch x := call.Fun.(type) { - case *ast.Ident: - f.walkCall(call, x.Name) - case *ast.SelectorExpr: - f.walkCall(call, x.Sel.Name) - } +// walkRangeStmt walks a range statment. +func (f *File) walkRangeStmt(n *ast.RangeStmt) { + checkRangeLoop(f, n) } diff --git a/src/cmd/vet/rangeloop.go b/src/cmd/vet/rangeloop.go new file mode 100644 index 00000000000..2fdb0b62160 --- /dev/null +++ b/src/cmd/vet/rangeloop.go @@ -0,0 +1,104 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +/* +This file contains the code to check range loop variables bound inside function +literals that are deferred or launched in new goroutines. We only check +instances where the defer or go statement is the last statement in the loop +body, as otherwise we would need whole program analysis. + +For example: + + for i, v := range s { + go func() { + println(i, v) // not what you might expect + }() + } + +See: http://golang.org/doc/go_faq.html#closures_and_goroutines +*/ + +package main + +import "go/ast" + +// checkRangeLoop walks the body of the provided range statement, checking if +// its index or value variables are used unsafely inside goroutines or deferred +// function literals. +func checkRangeLoop(f *File, n *ast.RangeStmt) { + if !*vetRangeLoops && !*vetAll { + return + } + key, _ := n.Key.(*ast.Ident) + val, _ := n.Value.(*ast.Ident) + if key == nil && val == nil { + return + } + sl := n.Body.List + if len(sl) == 0 { + return + } + var last *ast.CallExpr + switch s := sl[len(sl)-1].(type) { + case *ast.GoStmt: + last = s.Call + case *ast.DeferStmt: + last = s.Call + default: + return + } + lit, ok := last.Fun.(*ast.FuncLit) + if !ok { + return + } + ast.Inspect(lit.Body, func(n ast.Node) bool { + if n, ok := n.(*ast.Ident); ok && n.Obj != nil && (n.Obj == key.Obj || n.Obj == val.Obj) { + f.Warn(n.Pos(), "range variable", n.Name, "enclosed by function") + } + return true + }) +} + +func BadRangeLoopsUsedInTests() { + var s []int + for i, v := range s { + go func() { + println(i) // ERROR "range variable i enclosed by function" + println(v) // ERROR "range variable v enclosed by function" + }() + } + for i, v := range s { + defer func() { + println(i) // ERROR "range variable i enclosed by function" + println(v) // ERROR "range variable v enclosed by function" + }() + } + for i := range s { + go func() { + println(i) // ERROR "range variable i enclosed by function" + }() + } + for _, v := range s { + go func() { + println(v) // ERROR "range variable v enclosed by function" + }() + } + for i, v := range s { + go func() { + println(i, v) + }() + println("unfortunately, we don't catch the error above because of this statement") + } + for i, v := range s { + go func(i, v int) { + println(i, v) + }(i, v) + } + for i, v := range s { + i, v := i, v + go func() { + println(i, v) + }() + } +}