diff --git a/go/types/stmt.go b/go/types/stmt.go index 05c8d27197..5f8e7e1cba 100644 --- a/go/types/stmt.go +++ b/go/types/stmt.go @@ -311,11 +311,14 @@ func (check *checker) stmt(ctxt stmtContext, s ast.Stmt) { case *ast.BlockStmt: check.openScope(s) + defer check.closeScope() + check.stmtList(inner, s.List) - check.closeScope() case *ast.IfStmt: check.openScope(s) + defer check.closeScope() + check.initStmt(s.Init) var x operand check.expr(&x, s.Cond) @@ -326,11 +329,12 @@ func (check *checker) stmt(ctxt stmtContext, s ast.Stmt) { if s.Else != nil { check.stmt(inner, s.Else) } - check.closeScope() case *ast.SwitchStmt: inner |= inBreakable check.openScope(s) + defer check.closeScope() + check.initStmt(s.Init) var x operand tag := s.Tag @@ -361,12 +365,12 @@ func (check *checker) stmt(ctxt stmtContext, s ast.Stmt) { check.stmtList(inner, clause.Body) check.closeScope() } - check.closeScope() case *ast.TypeSwitchStmt: inner |= inBreakable check.openScope(s) defer check.closeScope() + check.initStmt(s.Init) // A type switch guard must be of the form: @@ -473,17 +477,46 @@ func (check *checker) stmt(ctxt stmtContext, s ast.Stmt) { if clause == nil { continue // error reported before } - check.openScope(clause) - if s := clause.Comm; s != nil { - check.stmt(inner, s) // TODO(gri) check correctness of c.Comm (must be Send/RecvStmt) + + // clause.Comm must be a SendStmt, RecvStmt, or default case + valid := false + var rhs ast.Expr // rhs of RecvStmt, or nil + switch s := clause.Comm.(type) { + case nil, *ast.SendStmt: + valid = true + case *ast.AssignStmt: + if len(s.Rhs) == 1 { + rhs = s.Rhs[0] + } + case *ast.ExprStmt: + rhs = s.X + } + + // if present, rhs must be a receive operation + if rhs != nil { + if x, _ := unparen(rhs).(*ast.UnaryExpr); x != nil && x.Op == token.ARROW { + valid = true + } + } + + if !valid { + check.errorf(clause.Comm.Pos(), "select case must be send or receive (possibly with assignment)") + continue + } + + check.openScope(s) + defer check.closeScope() + if clause.Comm != nil { + check.stmt(inner, clause.Comm) } check.stmtList(inner, clause.Body) - check.closeScope() } case *ast.ForStmt: inner |= inBreakable | inContinuable check.openScope(s) + defer check.closeScope() + check.initStmt(s.Init) if s.Cond != nil { var x operand @@ -494,7 +527,6 @@ func (check *checker) stmt(ctxt stmtContext, s ast.Stmt) { } check.initStmt(s.Post) check.stmt(inner, s.Body) - check.closeScope() case *ast.RangeStmt: inner |= inBreakable | inContinuable diff --git a/go/types/testdata/stmt0.src b/go/types/testdata/stmt0.src index deae9d9461..0921d9f855 100644 --- a/go/types/testdata/stmt0.src +++ b/go/types/testdata/stmt0.src @@ -186,20 +186,30 @@ func selects() { var ( ch chan int sc chan <- bool - x int ) select { case <-ch: - ch <- x + case (<-ch): + case t := <-ch: + _ = t + case t := (<-ch): + _ = t case t, ok := <-ch: - x = t - _ = ok + _, _ = t, ok + case t, ok := (<-ch): + _, _ = t, ok case <-sc /* ERROR "cannot receive from send-only channel" */ : } select { default: default /* ERROR "multiple defaults" */ : } + select { + case a, b := <-ch: + _, b = a, b + case x /* ERROR send or receive */ : + case a /* ERROR send or receive */ := ch: + } } func gos() {