diff --git a/src/cmd/compile/internal/syntax/branches.go b/src/cmd/compile/internal/syntax/branches.go index 60790974265..3d7ffed3742 100644 --- a/src/cmd/compile/internal/syntax/branches.go +++ b/src/cmd/compile/internal/syntax/branches.go @@ -6,12 +6,10 @@ package syntax import "fmt" -// TODO(gri) consider making this part of the parser code - // checkBranches checks correct use of labels and branch -// statements (break, continue, goto) in a function body. +// statements (break, continue, fallthrough, goto) in a function body. // It catches: -// - misplaced breaks and continues +// - misplaced breaks, continues, and fallthroughs // - bad labeled breaks and continues // - invalid, unused, duplicate, and missing labels // - gotos jumping over variable declarations and into blocks @@ -123,6 +121,7 @@ func (ls *labelScope) enclosingTarget(b *block, name string) *LabeledStmt { type targets struct { breaks Stmt // *ForStmt, *SwitchStmt, *SelectStmt, or nil continues *ForStmt // or nil + caseIndex int // case index of immediately enclosing switch statement, or < 0 } // blockBranches processes a block's body starting at start and returns the @@ -163,7 +162,10 @@ func (ls *labelScope) blockBranches(parent *block, ctxt targets, lstmt *LabeledS fwdGotos = append(fwdGotos, ls.blockBranches(b, ctxt, lstmt, start, body)...) } - for _, stmt := range body { + // A fallthrough statement counts as last statement in a statement + // list even if there are trailing empty statements; remove them. + stmtList := trimTrailingEmptyStmts(body) + for stmtIndex, stmt := range stmtList { lstmt = nil L: switch s := stmt.(type) { @@ -222,7 +224,20 @@ func (ls *labelScope) blockBranches(parent *block, ctxt targets, lstmt *LabeledS ls.err(s.Pos(), "continue is not in a loop") } case _Fallthrough: - // nothing to do + msg := "fallthrough statement out of place" + if t, _ := ctxt.breaks.(*SwitchStmt); t != nil { + if _, ok := t.Tag.(*TypeSwitchGuard); ok { + msg = "cannot fallthrough in type switch" + } else if ctxt.caseIndex < 0 || stmtIndex+1 < len(stmtList) { + // fallthrough nested in a block or not the last statement + // use msg as is + } else if ctxt.caseIndex+1 == len(t.Body) { + msg = "cannot fallthrough final case in switch" + } else { + break // fallthrough ok + } + } + ls.err(s.Pos(), msg) case _Goto: fallthrough // should always have a label default: @@ -282,25 +297,29 @@ func (ls *labelScope) blockBranches(parent *block, ctxt targets, lstmt *LabeledS } case *BlockStmt: - innerBlock(ctxt, s.Pos(), s.List) + inner := targets{ctxt.breaks, ctxt.continues, -1} + innerBlock(inner, s.Pos(), s.List) case *IfStmt: - innerBlock(ctxt, s.Then.Pos(), s.Then.List) + inner := targets{ctxt.breaks, ctxt.continues, -1} + innerBlock(inner, s.Then.Pos(), s.Then.List) if s.Else != nil { - innerBlock(ctxt, s.Else.Pos(), []Stmt{s.Else}) + innerBlock(inner, s.Else.Pos(), []Stmt{s.Else}) } case *ForStmt: - innerBlock(targets{s, s}, s.Body.Pos(), s.Body.List) + inner := targets{s, s, -1} + innerBlock(inner, s.Body.Pos(), s.Body.List) case *SwitchStmt: - inner := targets{s, ctxt.continues} - for _, cc := range s.Body { + inner := targets{s, ctxt.continues, -1} + for i, cc := range s.Body { + inner.caseIndex = i innerBlock(inner, cc.Pos(), cc.Body) } case *SelectStmt: - inner := targets{s, ctxt.continues} + inner := targets{s, ctxt.continues, -1} for _, cc := range s.Body { innerBlock(inner, cc.Pos(), cc.Body) } @@ -309,3 +328,12 @@ func (ls *labelScope) blockBranches(parent *block, ctxt targets, lstmt *LabeledS return fwdGotos } + +func trimTrailingEmptyStmts(list []Stmt) []Stmt { + for i := len(list); i > 0; i-- { + if _, ok := list[i-1].(*EmptyStmt); !ok { + return list[:i] + } + } + return nil +} diff --git a/src/cmd/compile/internal/syntax/error_test.go b/src/cmd/compile/internal/syntax/error_test.go index 724ca0eb986..2f70b5278e7 100644 --- a/src/cmd/compile/internal/syntax/error_test.go +++ b/src/cmd/compile/internal/syntax/error_test.go @@ -162,7 +162,7 @@ func testSyntaxErrors(t *testing.T, filename string) { } else { t.Errorf("%s:%s: unexpected error: %s", filename, orig, e.Msg) } - }, nil, 0) + }, nil, CheckBranches) if *print { fmt.Println() diff --git a/src/cmd/compile/internal/syntax/testdata/fallthrough.go b/src/cmd/compile/internal/syntax/testdata/fallthrough.go new file mode 100644 index 00000000000..851da81ea04 --- /dev/null +++ b/src/cmd/compile/internal/syntax/testdata/fallthrough.go @@ -0,0 +1,55 @@ +// Copyright 2022 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fallthroughs + +func _() { + var x int + switch x { + case 0: + fallthrough + + case 1: + fallthrough // ERROR fallthrough statement out of place + { + } + + case 2: + { + fallthrough // ERROR fallthrough statement out of place + } + + case 3: + for { + fallthrough // ERROR fallthrough statement out of place + } + + case 4: + fallthrough // trailing empty statements are ok + ; + ; + + case 5: + fallthrough + + default: + fallthrough // ERROR cannot fallthrough final case in switch + } + + fallthrough // ERROR fallthrough statement out of place + + if true { + fallthrough // ERROR fallthrough statement out of place + } + + for { + fallthrough // ERROR fallthrough statement out of place + } + + var t any + switch t.(type) { + case int: + fallthrough // ERROR cannot fallthrough in type switch + } +}