diff --git a/src/cmd/compile/internal/gc/swt.go b/src/cmd/compile/internal/gc/swt.go index 07d324c593..848eca8915 100644 --- a/src/cmd/compile/internal/gc/swt.go +++ b/src/cmd/compile/internal/gc/swt.go @@ -17,14 +17,11 @@ const ( ) const ( - caseKindDefault = iota // default: - // expression switch - caseKindExprConst // case 5: - caseKindExprVar // case x: + caseKindExprConst = iota // case 5: + caseKindExprVar // case x: // type switch - caseKindTypeNil // case nil: caseKindTypeConst // case time.Time: (concrete type, has type hash) caseKindTypeVar // case io.Reader: (interface type) ) @@ -52,6 +49,13 @@ type caseClause struct { typ uint8 // type of case } +// caseClauses are all the case clauses in a switch statement. +type caseClauses struct { + list []caseClause // general cases + defjmp *Node // OGOTO for default case or OBREAK if no default case present + niljmp *Node // OGOTO for nil type case in a type switch +} + // typecheckswitch typechecks a switch statement. func typecheckswitch(n *Node) { lno := lineno @@ -248,16 +252,10 @@ func (s *exprSwitch) walk(sw *Node) { typecheckslice(cas, Etop) } - // enumerate the cases, and lop off the default case - cc := caseClauses(sw, s.kind) + // Enumerate the cases and prepare the default case. + clauses := genCaseClauses(sw, s.kind) sw.List.Set(nil) - var def *Node - if len(cc) > 0 && cc[0].typ == caseKindDefault { - def = cc[0].node.Right - cc = cc[1:] - } else { - def = Nod(OBREAK, nil, nil) - } + cc := clauses.list // handle the cases in order for len(cc) > 0 { @@ -283,14 +281,14 @@ func (s *exprSwitch) walk(sw *Node) { // handle default case if nerrors == 0 { - cas = append(cas, def) + cas = append(cas, clauses.defjmp) sw.Nbody.Set(append(cas, sw.Nbody.Slice()...)) walkstmtlist(sw.Nbody.Slice()) } } // walkCases generates an AST implementing the cases in cc. -func (s *exprSwitch) walkCases(cc []*caseClause) *Node { +func (s *exprSwitch) walkCases(cc []caseClause) *Node { if len(cc) < binarySearchMin { // linear search var cas []*Node @@ -422,27 +420,35 @@ func casebody(sw *Node, typeswvar *Node) { lineno = lno } -// caseClauses generates a slice of caseClauses +// genCaseClauses generates the caseClauses value // corresponding to the clauses in the switch statement sw. // Kind is the kind of switch statement. -func caseClauses(sw *Node, kind int) []*caseClause { - var cc []*caseClause +func genCaseClauses(sw *Node, kind int) caseClauses { + var cc caseClauses for _, n := range sw.List.Slice() { - c := new(caseClause) - cc = append(cc, c) - c.ordinal = len(cc) - c.node = n - if n.Left == nil { - c.typ = caseKindDefault + // default case + if cc.defjmp != nil { + Fatalf("duplicate default case not detected during typechecking") + } + cc.defjmp = n.Right continue } + if kind == switchKindType && n.Left.Op == OLITERAL { + // nil case in type switch + if cc.niljmp != nil { + Fatalf("duplicate nil case not detected during typechecking") + } + cc.niljmp = n.Right + continue + } + + // general case + c := caseClause{node: n, ordinal: len(cc.list)} if kind == switchKindType { // type switch switch { - case n.Left.Op == OLITERAL: - c.typ = caseKindTypeNil case n.Left.Type.IsInterface(): c.typ = caseKindTypeVar default: @@ -458,22 +464,24 @@ func caseClauses(sw *Node, kind int) []*caseClause { c.typ = caseKindExprVar } } + cc.list = append(cc.list, c) } - if cc == nil { - return nil + if cc.defjmp == nil { + cc.defjmp = Nod(OBREAK, nil, nil) + } + + if cc.list == nil { + return cc } // sort by value and diagnose duplicate cases if kind == switchKindType { // type switch - sort.Sort(caseClauseByType(cc)) - for i, c1 := range cc { - if c1.typ == caseKindTypeNil || c1.typ == caseKindDefault { - break - } - for _, c2 := range cc[i+1:] { - if c2.typ == caseKindTypeNil || c2.typ == caseKindDefault || c1.hash != c2.hash { + sort.Sort(caseClauseByType(cc.list)) + for i, c1 := range cc.list { + for _, c2 := range cc.list[i+1:] { + if c1.hash != c2.hash { break } if Eqtype(c1.node.Left.Type, c2.node.Left.Type) { @@ -483,12 +491,12 @@ func caseClauses(sw *Node, kind int) []*caseClause { } } else { // expression switch - sort.Sort(caseClauseByExpr(cc)) - for i, c1 := range cc { - if i+1 == len(cc) { + sort.Sort(caseClauseByExpr(cc.list)) + for i, c1 := range cc.list { + if i+1 == len(cc.list) { break } - c2 := cc[i+1] + c2 := cc.list[i+1] if exprcmp(c1, c2) != 0 { continue } @@ -498,7 +506,7 @@ func caseClauses(sw *Node, kind int) []*caseClause { } // put list back in processing order - sort.Sort(caseClauseByOrd(cc)) + sort.Sort(caseClauseByOrd(cc.list)) return cc } @@ -545,20 +553,9 @@ func (s *typeSwitch) walk(sw *Node) { // set up labels and jumps casebody(sw, s.facename) - cc := caseClauses(sw, switchKindType) + clauses := genCaseClauses(sw, switchKindType) sw.List.Set(nil) - var def *Node - if len(cc) > 0 && cc[0].typ == caseKindDefault { - def = cc[0].node.Right - cc = cc[1:] - } else { - def = Nod(OBREAK, nil, nil) - } - var typenil *Node - if len(cc) > 0 && cc[0].typ == caseKindTypeNil { - typenil = cc[0].node.Right - cc = cc[1:] - } + def := clauses.defjmp // For empty interfaces, do: // if e._type == nil { @@ -573,9 +570,9 @@ func (s *typeSwitch) walk(sw *Node) { // Check for nil first. i := Nod(OIF, nil, nil) i.Left = Nod(OEQ, typ, nodnil()) - if typenil != nil { + if clauses.niljmp != nil { // Do explicit nil case right here. - i.Nbody.Set1(typenil) + i.Nbody.Set1(clauses.niljmp) } else { // Jump to default case. lbl := autolabel(".s") @@ -602,6 +599,8 @@ func (s *typeSwitch) walk(sw *Node) { a = typecheck(a, Etop) cas = append(cas, a) + cc := clauses.list + // insert type equality check into each case block for _, c := range cc { n := c.node @@ -696,7 +695,7 @@ func (s *typeSwitch) typeone(t *Node) *Node { } // walkCases generates an AST implementing the cases in cc. -func (s *typeSwitch) walkCases(cc []*caseClause) *Node { +func (s *typeSwitch) walkCases(cc []caseClause) *Node { if len(cc) < binarySearchMin { var cas []*Node for _, c := range cc { @@ -723,31 +722,13 @@ func (s *typeSwitch) walkCases(cc []*caseClause) *Node { return a } -type caseClauseByOrd []*caseClause +type caseClauseByOrd []caseClause -func (x caseClauseByOrd) Len() int { return len(x) } -func (x caseClauseByOrd) Swap(i, j int) { x[i], x[j] = x[j], x[i] } -func (x caseClauseByOrd) Less(i, j int) bool { - c1, c2 := x[i], x[j] - switch { - // sort default first - case c1.typ == caseKindDefault: - return true - case c2.typ == caseKindDefault: - return false +func (x caseClauseByOrd) Len() int { return len(x) } +func (x caseClauseByOrd) Swap(i, j int) { x[i], x[j] = x[j], x[i] } +func (x caseClauseByOrd) Less(i, j int) bool { return x[i].ordinal < x[j].ordinal } - // sort nil second - case c1.typ == caseKindTypeNil: - return true - case c2.typ == caseKindTypeNil: - return false - } - - // sort by ordinal - return c1.ordinal < c2.ordinal -} - -type caseClauseByExpr []*caseClause +type caseClauseByExpr []caseClause func (x caseClauseByExpr) Len() int { return len(x) } func (x caseClauseByExpr) Swap(i, j int) { x[i], x[j] = x[j], x[i] } @@ -755,7 +736,7 @@ func (x caseClauseByExpr) Less(i, j int) bool { return exprcmp(x[i], x[j]) < 0 } -func exprcmp(c1, c2 *caseClause) int { +func exprcmp(c1, c2 caseClause) int { // sort non-constants last if c1.typ != caseKindExprConst { return +1 @@ -814,7 +795,7 @@ func exprcmp(c1, c2 *caseClause) int { return 0 } -type caseClauseByType []*caseClause +type caseClauseByType []caseClause func (x caseClauseByType) Len() int { return len(x) } func (x caseClauseByType) Swap(i, j int) { x[i], x[j] = x[j], x[i] } diff --git a/src/cmd/compile/internal/gc/swt_test.go b/src/cmd/compile/internal/gc/swt_test.go index c1ee8955cf..4e5dbcae50 100644 --- a/src/cmd/compile/internal/gc/swt_test.go +++ b/src/cmd/compile/internal/gc/swt_test.go @@ -134,7 +134,7 @@ func TestExprcmp(t *testing.T) { }, } for i, d := range testdata { - got := exprcmp(&d.a, &d.b) + got := exprcmp(d.a, d.b) if d.want != got { t.Errorf("%d: exprcmp(a, b) = %d; want %d", i, got, d.want) t.Logf("\ta = caseClause{node: %#v, typ: %#v}", d.a.node, d.a.typ)