diff --git a/src/cmd/compile/internal/escape/escape.go b/src/cmd/compile/internal/escape/escape.go index 31d157b1651..d8f0111d2de 100644 --- a/src/cmd/compile/internal/escape/escape.go +++ b/src/cmd/compile/internal/escape/escape.go @@ -369,7 +369,6 @@ func (e *escape) stmt(n ir.Node) { var ks []hole for _, cas := range n.Cases { // cases - cas := cas.(*ir.CaseStmt) if typesw && n.Tag.(*ir.TypeSwitchGuard).Tag != nil { cv := cas.Var k := e.dcl(cv) // type switch variables have no ODCL. @@ -391,7 +390,6 @@ func (e *escape) stmt(n ir.Node) { case ir.OSELECT: n := n.(*ir.SelectStmt) for _, cas := range n.Cases { - cas := cas.(*ir.CaseStmt) e.stmt(cas.Comm) e.block(cas.Body) } diff --git a/src/cmd/compile/internal/ir/mknode.go b/src/cmd/compile/internal/ir/mknode.go index f5dacee622f..edf3ee501c1 100644 --- a/src/cmd/compile/internal/ir/mknode.go +++ b/src/cmd/compile/internal/ir/mknode.go @@ -37,6 +37,7 @@ func main() { nodeType := lookup("Node") ntypeType := lookup("Ntype") nodesType := lookup("Nodes") + slicePtrCaseStmtType := types.NewSlice(types.NewPointer(lookup("CaseStmt"))) ptrFieldType := types.NewPointer(lookup("Field")) slicePtrFieldType := types.NewSlice(ptrFieldType) ptrIdentType := types.NewPointer(lookup("Ident")) @@ -76,6 +77,8 @@ func main() { switch { case is(nodesType): fmt.Fprintf(&buf, "c.%s = c.%s.Copy()\n", name, name) + case is(slicePtrCaseStmtType): + fmt.Fprintf(&buf, "c.%s = copyCases(c.%s)\n", name, name) case is(ptrFieldType): fmt.Fprintf(&buf, "if c.%s != nil { c.%s = c.%s.copy() }\n", name, name, name) case is(slicePtrFieldType): @@ -94,6 +97,8 @@ func main() { fmt.Fprintf(&buf, "err = maybeDo(n.%s, err, do)\n", name) case is(nodesType): fmt.Fprintf(&buf, "err = maybeDoList(n.%s, err, do)\n", name) + case is(slicePtrCaseStmtType): + fmt.Fprintf(&buf, "err = maybeDoCases(n.%s, err, do)\n", name) case is(ptrFieldType): fmt.Fprintf(&buf, "err = maybeDoField(n.%s, err, do)\n", name) case is(slicePtrFieldType): @@ -113,6 +118,8 @@ func main() { fmt.Fprintf(&buf, "n.%s = toNtype(maybeEdit(n.%s, edit))\n", name, name) case is(nodesType): fmt.Fprintf(&buf, "editList(n.%s, edit)\n", name) + case is(slicePtrCaseStmtType): + fmt.Fprintf(&buf, "editCases(n.%s, edit)\n", name) case is(ptrFieldType): fmt.Fprintf(&buf, "editField(n.%s, edit)\n", name) case is(slicePtrFieldType): diff --git a/src/cmd/compile/internal/ir/node_gen.go b/src/cmd/compile/internal/ir/node_gen.go index ecb39563c46..041855bbe93 100644 --- a/src/cmd/compile/internal/ir/node_gen.go +++ b/src/cmd/compile/internal/ir/node_gen.go @@ -781,20 +781,20 @@ func (n *SelectStmt) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) } func (n *SelectStmt) copy() Node { c := *n c.init = c.init.Copy() - c.Cases = c.Cases.Copy() + c.Cases = copyCases(c.Cases) c.Compiled = c.Compiled.Copy() return &c } func (n *SelectStmt) doChildren(do func(Node) error) error { var err error err = maybeDoList(n.init, err, do) - err = maybeDoList(n.Cases, err, do) + err = maybeDoCases(n.Cases, err, do) err = maybeDoList(n.Compiled, err, do) return err } func (n *SelectStmt) editChildren(edit func(Node) Node) { editList(n.init, edit) - editList(n.Cases, edit) + editCases(n.Cases, edit) editList(n.Compiled, edit) } @@ -945,7 +945,7 @@ func (n *SwitchStmt) Format(s fmt.State, verb rune) { FmtNode(n, s, verb) } func (n *SwitchStmt) copy() Node { c := *n c.init = c.init.Copy() - c.Cases = c.Cases.Copy() + c.Cases = copyCases(c.Cases) c.Compiled = c.Compiled.Copy() return &c } @@ -953,14 +953,14 @@ func (n *SwitchStmt) doChildren(do func(Node) error) error { var err error err = maybeDoList(n.init, err, do) err = maybeDo(n.Tag, err, do) - err = maybeDoList(n.Cases, err, do) + err = maybeDoCases(n.Cases, err, do) err = maybeDoList(n.Compiled, err, do) return err } func (n *SwitchStmt) editChildren(edit func(Node) Node) { editList(n.init, edit) n.Tag = maybeEdit(n.Tag, edit) - editList(n.Cases, edit) + editCases(n.Cases, edit) editList(n.Compiled, edit) } diff --git a/src/cmd/compile/internal/ir/stmt.go b/src/cmd/compile/internal/ir/stmt.go index cfda6fd234d..ce775a8529e 100644 --- a/src/cmd/compile/internal/ir/stmt.go +++ b/src/cmd/compile/internal/ir/stmt.go @@ -191,6 +191,37 @@ func NewCaseStmt(pos src.XPos, list, body []Node) *CaseStmt { return n } +func copyCases(list []*CaseStmt) []*CaseStmt { + if list == nil { + return nil + } + c := make([]*CaseStmt, len(list)) + copy(c, list) + return c +} + +func maybeDoCases(list []*CaseStmt, err error, do func(Node) error) error { + if err != nil { + return err + } + for _, x := range list { + if x != nil { + if err := do(x); err != nil { + return err + } + } + } + return nil +} + +func editCases(list []*CaseStmt, edit func(Node) Node) { + for i, x := range list { + if x != nil { + list[i] = edit(x).(*CaseStmt) + } + } +} + // A ForStmt is a non-range for loop: for Init; Cond; Post { Body } // Op can be OFOR or OFORUNTIL (!Cond). type ForStmt struct { @@ -334,18 +365,18 @@ func (n *ReturnStmt) SetOrig(x Node) { n.orig = x } type SelectStmt struct { miniStmt Label *types.Sym - Cases Nodes + Cases []*CaseStmt HasBreak bool // TODO(rsc): Instead of recording here, replace with a block? Compiled Nodes // compiled form, after walkswitch } -func NewSelectStmt(pos src.XPos, cases []Node) *SelectStmt { +func NewSelectStmt(pos src.XPos, cases []*CaseStmt) *SelectStmt { n := &SelectStmt{} n.pos = pos n.op = OSELECT - n.Cases.Set(cases) + n.Cases = cases return n } @@ -367,7 +398,7 @@ func NewSendStmt(pos src.XPos, ch, value Node) *SendStmt { type SwitchStmt struct { miniStmt Tag Node - Cases Nodes // list of *CaseStmt + Cases []*CaseStmt Label *types.Sym HasBreak bool @@ -375,11 +406,11 @@ type SwitchStmt struct { Compiled Nodes // compiled form, after walkswitch } -func NewSwitchStmt(pos src.XPos, tag Node, cases []Node) *SwitchStmt { +func NewSwitchStmt(pos src.XPos, tag Node, cases []*CaseStmt) *SwitchStmt { n := &SwitchStmt{Tag: tag} n.pos = pos n.op = OSWITCH - n.Cases.Set(cases) + n.Cases = cases return n } diff --git a/src/cmd/compile/internal/ir/visit.go b/src/cmd/compile/internal/ir/visit.go index a1c345968f7..8839e1664d3 100644 --- a/src/cmd/compile/internal/ir/visit.go +++ b/src/cmd/compile/internal/ir/visit.go @@ -217,10 +217,9 @@ func EditChildren(n Node, edit func(Node) Node) { // Note that editList only calls edit on the nodes in the list, not their children. // If x's children should be processed, edit(x) must call EditChildren(x, edit) itself. func editList(list Nodes, edit func(Node) Node) { - s := list for i, x := range list { if x != nil { - s[i] = edit(x) + list[i] = edit(x) } } } diff --git a/src/cmd/compile/internal/noder/noder.go b/src/cmd/compile/internal/noder/noder.go index ad66b6c8509..b974448338f 100644 --- a/src/cmd/compile/internal/noder/noder.go +++ b/src/cmd/compile/internal/noder/noder.go @@ -1202,14 +1202,14 @@ func (p *noder) switchStmt(stmt *syntax.SwitchStmt) ir.Node { if l := n.Tag; l != nil && l.Op() == ir.OTYPESW { tswitch = l.(*ir.TypeSwitchGuard) } - n.Cases.Set(p.caseClauses(stmt.Body, tswitch, stmt.Rbrace)) + n.Cases = p.caseClauses(stmt.Body, tswitch, stmt.Rbrace) p.closeScope(stmt.Rbrace) return n } -func (p *noder) caseClauses(clauses []*syntax.CaseClause, tswitch *ir.TypeSwitchGuard, rbrace syntax.Pos) []ir.Node { - nodes := make([]ir.Node, 0, len(clauses)) +func (p *noder) caseClauses(clauses []*syntax.CaseClause, tswitch *ir.TypeSwitchGuard, rbrace syntax.Pos) []*ir.CaseStmt { + nodes := make([]*ir.CaseStmt, 0, len(clauses)) for i, clause := range clauses { p.setlineno(clause) if i > 0 { @@ -1266,8 +1266,8 @@ func (p *noder) simpleStmt(stmt syntax.SimpleStmt) []ir.Node { return []ir.Node{p.stmt(stmt)} } -func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []ir.Node { - nodes := make([]ir.Node, len(clauses)) +func (p *noder) commClauses(clauses []*syntax.CommClause, rbrace syntax.Pos) []*ir.CaseStmt { + nodes := make([]*ir.CaseStmt, len(clauses)) for i, clause := range clauses { p.setlineno(clause) if i > 0 { diff --git a/src/cmd/compile/internal/typecheck/iexport.go b/src/cmd/compile/internal/typecheck/iexport.go index 0c813a71ef4..19437a069e1 100644 --- a/src/cmd/compile/internal/typecheck/iexport.go +++ b/src/cmd/compile/internal/typecheck/iexport.go @@ -1181,10 +1181,9 @@ func isNamedTypeSwitch(x ir.Node) bool { return ok && guard.Tag != nil } -func (w *exportWriter) caseList(cases []ir.Node, namedTypeSwitch bool) { +func (w *exportWriter) caseList(cases []*ir.CaseStmt, namedTypeSwitch bool) { w.uint64(uint64(len(cases))) for _, cas := range cases { - cas := cas.(*ir.CaseStmt) w.pos(cas.Pos()) w.stmtList(cas.List) if namedTypeSwitch { diff --git a/src/cmd/compile/internal/typecheck/iimport.go b/src/cmd/compile/internal/typecheck/iimport.go index 8285c418e9f..fd8314b6621 100644 --- a/src/cmd/compile/internal/typecheck/iimport.go +++ b/src/cmd/compile/internal/typecheck/iimport.go @@ -767,10 +767,10 @@ func (r *importReader) stmtList() []ir.Node { return list } -func (r *importReader) caseList(switchExpr ir.Node) []ir.Node { +func (r *importReader) caseList(switchExpr ir.Node) []*ir.CaseStmt { namedTypeSwitch := isNamedTypeSwitch(switchExpr) - cases := make([]ir.Node, r.uint64()) + cases := make([]*ir.CaseStmt, r.uint64()) for i := range cases { cas := ir.NewCaseStmt(r.pos(), nil, nil) cas.List.Set(r.stmtList()) diff --git a/src/cmd/compile/internal/typecheck/stmt.go b/src/cmd/compile/internal/typecheck/stmt.go index 7e74b730bc1..03c3e399eb4 100644 --- a/src/cmd/compile/internal/typecheck/stmt.go +++ b/src/cmd/compile/internal/typecheck/stmt.go @@ -364,8 +364,6 @@ func tcSelect(sel *ir.SelectStmt) { lno := ir.SetPos(sel) Stmts(sel.Init()) for _, ncase := range sel.Cases { - ncase := ncase.(*ir.CaseStmt) - if len(ncase.List) == 0 { // default if def != nil { @@ -508,7 +506,6 @@ func tcSwitchExpr(n *ir.SwitchStmt) { var defCase ir.Node var cs constSet for _, ncase := range n.Cases { - ncase := ncase.(*ir.CaseStmt) ls := ncase.List if len(ls) == 0 { // default: if defCase != nil { @@ -577,7 +574,6 @@ func tcSwitchType(n *ir.SwitchStmt) { var defCase, nilCase ir.Node var ts typeSet for _, ncase := range n.Cases { - ncase := ncase.(*ir.CaseStmt) ls := ncase.List if len(ls) == 0 { // default: if defCase != nil { diff --git a/src/cmd/compile/internal/typecheck/typecheck.go b/src/cmd/compile/internal/typecheck/typecheck.go index b779f9ceb0f..dabfee3bf9c 100644 --- a/src/cmd/compile/internal/typecheck/typecheck.go +++ b/src/cmd/compile/internal/typecheck/typecheck.go @@ -2103,7 +2103,6 @@ func isTermNode(n ir.Node) bool { } def := false for _, cas := range n.Cases { - cas := cas.(*ir.CaseStmt) if !isTermNodes(cas.Body) { return false } @@ -2119,7 +2118,6 @@ func isTermNode(n ir.Node) bool { return false } for _, cas := range n.Cases { - cas := cas.(*ir.CaseStmt) if !isTermNodes(cas.Body) { return false } @@ -2218,9 +2216,6 @@ func deadcodeslice(nn *ir.Nodes) { case ir.OBLOCK: n := n.(*ir.BlockStmt) deadcodeslice(&n.List) - case ir.OCASE: - n := n.(*ir.CaseStmt) - deadcodeslice(&n.Body) case ir.OFOR: n := n.(*ir.ForStmt) deadcodeslice(&n.Body) @@ -2233,10 +2228,14 @@ func deadcodeslice(nn *ir.Nodes) { deadcodeslice(&n.Body) case ir.OSELECT: n := n.(*ir.SelectStmt) - deadcodeslice(&n.Cases) + for _, cas := range n.Cases { + deadcodeslice(&cas.Body) + } case ir.OSWITCH: n := n.(*ir.SwitchStmt) - deadcodeslice(&n.Cases) + for _, cas := range n.Cases { + deadcodeslice(&cas.Body) + } } if cut { diff --git a/src/cmd/compile/internal/walk/order.go b/src/cmd/compile/internal/walk/order.go index 1e41cfc6aaf..ebbd4675701 100644 --- a/src/cmd/compile/internal/walk/order.go +++ b/src/cmd/compile/internal/walk/order.go @@ -914,7 +914,6 @@ func (o *orderState) stmt(n ir.Node) { n := n.(*ir.SelectStmt) t := o.markTemp() for _, ncas := range n.Cases { - ncas := ncas.(*ir.CaseStmt) r := ncas.Comm ir.SetPos(ncas) @@ -996,7 +995,6 @@ func (o *orderState) stmt(n ir.Node) { // Also insert any ninit queued during the previous loop. // (The temporary cleaning must follow that ninit work.) for _, cas := range n.Cases { - cas := cas.(*ir.CaseStmt) orderBlock(&cas.Body, o.free) cas.Body.Prepend(o.cleanTempNoPop(t)...) @@ -1036,13 +1034,12 @@ func (o *orderState) stmt(n ir.Node) { n := n.(*ir.SwitchStmt) if base.Debug.Libfuzzer != 0 && !hasDefaultCase(n) { // Add empty "default:" case for instrumentation. - n.Cases.Append(ir.NewCaseStmt(base.Pos, nil, nil)) + n.Cases = append(n.Cases, ir.NewCaseStmt(base.Pos, nil, nil)) } t := o.markTemp() n.Tag = o.expr(n.Tag, nil) for _, ncas := range n.Cases { - ncas := ncas.(*ir.CaseStmt) o.exprListInPlace(ncas.List) orderBlock(&ncas.Body, o.free) } @@ -1056,7 +1053,6 @@ func (o *orderState) stmt(n ir.Node) { func hasDefaultCase(n *ir.SwitchStmt) bool { for _, ncas := range n.Cases { - ncas := ncas.(*ir.CaseStmt) if len(ncas.List) == 0 { return true } diff --git a/src/cmd/compile/internal/walk/select.go b/src/cmd/compile/internal/walk/select.go index 5e03732169f..0b7e7e99fbf 100644 --- a/src/cmd/compile/internal/walk/select.go +++ b/src/cmd/compile/internal/walk/select.go @@ -21,7 +21,7 @@ func walkSelect(sel *ir.SelectStmt) { sel.PtrInit().Set(nil) init = append(init, walkSelectCases(sel.Cases)...) - sel.Cases = ir.Nodes{} + sel.Cases = nil sel.Compiled.Set(init) walkStmtList(sel.Compiled) @@ -29,7 +29,7 @@ func walkSelect(sel *ir.SelectStmt) { base.Pos = lno } -func walkSelectCases(cases ir.Nodes) []ir.Node { +func walkSelectCases(cases []*ir.CaseStmt) []ir.Node { ncas := len(cases) sellineno := base.Pos @@ -40,7 +40,7 @@ func walkSelectCases(cases ir.Nodes) []ir.Node { // optimization: one-case select: single op. if ncas == 1 { - cas := cases[0].(*ir.CaseStmt) + cas := cases[0] ir.SetPos(cas) l := cas.Init() if cas.Comm != nil { // not default: @@ -75,7 +75,6 @@ func walkSelectCases(cases ir.Nodes) []ir.Node { // this rewrite is used by both the general code and the next optimization. var dflt *ir.CaseStmt for _, cas := range cases { - cas := cas.(*ir.CaseStmt) ir.SetPos(cas) n := cas.Comm if n == nil { @@ -99,9 +98,9 @@ func walkSelectCases(cases ir.Nodes) []ir.Node { // optimization: two-case select but one is default: single non-blocking op. if ncas == 2 && dflt != nil { - cas := cases[0].(*ir.CaseStmt) + cas := cases[0] if cas == dflt { - cas = cases[1].(*ir.CaseStmt) + cas = cases[1] } n := cas.Comm @@ -170,7 +169,6 @@ func walkSelectCases(cases ir.Nodes) []ir.Node { // register cases for _, cas := range cases { - cas := cas.(*ir.CaseStmt) ir.SetPos(cas) init = append(init, cas.Init()...) diff --git a/src/cmd/compile/internal/walk/switch.go b/src/cmd/compile/internal/walk/switch.go index 141d2e5e053..de0b471b34e 100644 --- a/src/cmd/compile/internal/walk/switch.go +++ b/src/cmd/compile/internal/walk/switch.go @@ -71,7 +71,6 @@ func walkSwitchExpr(sw *ir.SwitchStmt) { var defaultGoto ir.Node var body ir.Nodes for _, ncase := range sw.Cases { - ncase := ncase.(*ir.CaseStmt) label := typecheck.AutoLabel(".s") jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label) @@ -96,7 +95,7 @@ func walkSwitchExpr(sw *ir.SwitchStmt) { body.Append(br) } } - sw.Cases.Set(nil) + sw.Cases = nil if defaultGoto == nil { br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil) @@ -259,7 +258,6 @@ func allCaseExprsAreSideEffectFree(sw *ir.SwitchStmt) bool { // enough. for _, ncase := range sw.Cases { - ncase := ncase.(*ir.CaseStmt) for _, v := range ncase.List { if v.Op() != ir.OLITERAL { return false @@ -325,7 +323,6 @@ func walkSwitchType(sw *ir.SwitchStmt) { var defaultGoto, nilGoto ir.Node var body ir.Nodes for _, ncase := range sw.Cases { - ncase := ncase.(*ir.CaseStmt) caseVar := ncase.Var // For single-type cases with an interface type, @@ -384,7 +381,7 @@ func walkSwitchType(sw *ir.SwitchStmt) { body.Append(ncase.Body...) body.Append(br) } - sw.Cases.Set(nil) + sw.Cases = nil if defaultGoto == nil { defaultGoto = br