diff --git a/src/cmd/internal/gc/swt.go b/src/cmd/internal/gc/swt.go index 81eb56c3a6e..64ab47e6275 100644 --- a/src/cmd/internal/gc/swt.go +++ b/src/cmd/internal/gc/swt.go @@ -7,257 +7,323 @@ package gc import ( "cmd/internal/obj" "fmt" + "sort" + "strconv" ) const ( - Snorm = 0 + iota - Strue - Sfalse - Stype - Tdefault - Texprconst - Texprvar - Ttypenil - Ttypeconst - Ttypevar - Ncase = 4 + // expression switch + switchKindExpr = iota // switch a {...} or switch 5 {...} + switchKindTrue // switch true {...} or switch {...} + switchKindFalse // switch false {...} + + // type switch + switchKindType // switch a.(type) {...} ) -type Case struct { - node *Node - hash uint32 - type_ uint8 - diag uint8 - ordinal uint16 - link *Case +const ( + caseKindDefault = iota // default: + + // expression switch + caseKindExprConst // case 5: + caseKindExprVar // case x: + + // type switch + caseKindTypeNil // case nil: + caseKindTypeConst // case time.Time: (concrete type, has type hash) + caseKindTypeVar // case io.Reader: (interface type) +) + +const binarySearchMin = 4 // minimum number of cases for binary search + +// An exprSwitch walks an expression switch. +type exprSwitch struct { + exprname *Node // node for the expression being switched on + kind int // kind of switch statement (switchKind*) } -var C *Case +// A typeSwitch walks a type switch. +type typeSwitch struct { + hashname *Node // node for the hash of the type of the variable being switched on + facename *Node // node for the concrete type of the variable being switched on + okname *Node // boolean node used for comma-ok type assertions +} -func dumpcase(c0 *Case) { - for c := c0; c != nil; c = c.link { - switch c.type_ { - case Tdefault: - fmt.Printf("case-default\n") - fmt.Printf("\tord=%d\n", c.ordinal) +// A caseClause is a single case clause in a switch statement. +type caseClause struct { + node *Node // points at case statement + ordinal int // position in switch + hash uint32 // hash of a type switch + typ uint8 // type of case +} - case Texprconst: - fmt.Printf("case-exprconst\n") - fmt.Printf("\tord=%d\n", c.ordinal) +// typecheckswitch typechecks a switch statement. +func typecheckswitch(n *Node) { + lno := int(lineno) + typechecklist(n.Ninit, Etop) - case Texprvar: - fmt.Printf("case-exprvar\n") - fmt.Printf("\tord=%d\n", c.ordinal) - fmt.Printf("\top=%v\n", Oconv(int(c.node.Left.Op), 0)) + var nilonly string + var top int + var t *Type - case Ttypenil: - fmt.Printf("case-typenil\n") - fmt.Printf("\tord=%d\n", c.ordinal) - - case Ttypeconst: - fmt.Printf("case-typeconst\n") - fmt.Printf("\tord=%d\n", c.ordinal) - fmt.Printf("\thash=%x\n", c.hash) - - case Ttypevar: - fmt.Printf("case-typevar\n") - fmt.Printf("\tord=%d\n", c.ordinal) - - default: - fmt.Printf("case-???\n") - fmt.Printf("\tord=%d\n", c.ordinal) - fmt.Printf("\top=%v\n", Oconv(int(c.node.Left.Op), 0)) - fmt.Printf("\thash=%x\n", c.hash) + if n.Ntest != nil && n.Ntest.Op == OTYPESW { + // type switch + top = Etype + typecheck(&n.Ntest.Right, Erv) + t = n.Ntest.Right.Type + if t != nil && t.Etype != TINTER { + Yyerror("cannot type switch on non-interface value %v", Nconv(n.Ntest.Right, obj.FmtLong)) } - } - - fmt.Printf("\n") -} - -func ordlcmp(c1 *Case, c2 *Case) int { - // sort default first - if c1.type_ == Tdefault { - return -1 - } - if c2.type_ == Tdefault { - return +1 - } - - // sort nil second - if c1.type_ == Ttypenil { - return -1 - } - if c2.type_ == Ttypenil { - return +1 - } - - // sort by ordinal - if c1.ordinal > c2.ordinal { - return +1 - } - if c1.ordinal < c2.ordinal { - return -1 - } - return 0 -} - -func exprcmp(c1 *Case, c2 *Case) int { - // sort non-constants last - if c1.type_ != Texprconst { - return +1 - } - if c2.type_ != Texprconst { - return -1 - } - - n1 := c1.node.Left - n2 := c2.node.Left - - // sort by type (for switches on interface) - ct := int(n1.Val.Ctype) - - if ct != int(n2.Val.Ctype) { - return ct - int(n2.Val.Ctype) - } - if !Eqtype(n1.Type, n2.Type) { - if n1.Type.Vargen > n2.Type.Vargen { - return +1 - } else { - return -1 - } - } - - // sort by constant value - n := 0 - - switch ct { - case CTFLT: - n = mpcmpfltflt(n1.Val.U.Fval, n2.Val.U.Fval) - - case CTINT, - CTRUNE: - n = Mpcmpfixfix(n1.Val.U.Xval, n2.Val.U.Xval) - - case CTSTR: - n = cmpslit(n1, n2) - } - - return n -} - -func typecmp(c1 *Case, c2 *Case) int { - // sort non-constants last - if c1.type_ != Ttypeconst { - return +1 - } - if c2.type_ != Ttypeconst { - return -1 - } - - // sort by hash code - if c1.hash > c2.hash { - return +1 - } - if c1.hash < c2.hash { - return -1 - } - - // sort by ordinal so duplicate error - // happens on later case. - if c1.ordinal > c2.ordinal { - return +1 - } - if c1.ordinal < c2.ordinal { - return -1 - } - return 0 -} - -func csort(l *Case, f func(*Case, *Case) int) *Case { - if l == nil || l.link == nil { - return l - } - - l1 := l - l2 := l - for { - l2 = l2.link - if l2 == nil { - break - } - l2 = l2.link - if l2 == nil { - break - } - l1 = l1.link - } - - l2 = l1.link - l1.link = nil - l1 = csort(l, f) - l2 = csort(l2, f) - - /* set up lead element */ - if f(l1, l2) < 0 { - l = l1 - l1 = l1.link } else { - l = l2 - l2 = l2.link - } - - le := l - - for { - if l1 == nil { - for l2 != nil { - le.link = l2 - le = l2 - l2 = l2.link - } - - le.link = nil - break - } - - if l2 == nil { - for l1 != nil { - le.link = l1 - le = l1 - l1 = l1.link - } - - break - } - - if f(l1, l2) < 0 { - le.link = l1 - le = l1 - l1 = l1.link + // expression switch + top = Erv + if n.Ntest != nil { + typecheck(&n.Ntest, Erv) + defaultlit(&n.Ntest, nil) + t = n.Ntest.Type } else { - le.link = l2 - le = l2 - l2 = l2.link + t = Types[TBOOL] + } + if t != nil { + var badtype *Type + switch { + case okforeq[t.Etype] == 0: + Yyerror("cannot switch on %v", Nconv(n.Ntest, obj.FmtLong)) + case t.Etype == TARRAY && !Isfixedarray(t): + nilonly = "slice" + case t.Etype == TARRAY && Isfixedarray(t) && algtype1(t, nil) == ANOEQ: + Yyerror("cannot switch on %v", Nconv(n.Ntest, obj.FmtLong)) + case t.Etype == TSTRUCT && algtype1(t, &badtype) == ANOEQ: + Yyerror("cannot switch on %v (struct containing %v cannot be compared)", Nconv(n.Ntest, obj.FmtLong), Tconv(badtype, 0)) + case t.Etype == TFUNC: + nilonly = "func" + case t.Etype == TMAP: + nilonly = "map" + } } } - le.link = nil - return l + n.Type = t + + var def *Node + var ll *NodeList + for l := n.List; l != nil; l = l.Next { + ncase := l.N + setlineno(n) + if ncase.List == nil { + // default + if def != nil { + Yyerror("multiple defaults in switch (first at %v)", def.Line()) + } else { + def = ncase + } + } else { + for ll = ncase.List; ll != nil; ll = ll.Next { + setlineno(ll.N) + typecheck(&ll.N, Erv|Etype) + if ll.N.Type == nil || t == nil { + continue + } + setlineno(ncase) + switch top { + // expression switch + case Erv: + defaultlit(&ll.N, t) + switch { + case ll.N.Op == OTYPE: + Yyerror("type %v is not an expression", Tconv(ll.N.Type, 0)) + case ll.N.Type != nil && assignop(ll.N.Type, t, nil) == 0 && assignop(t, ll.N.Type, nil) == 0: + if n.Ntest != nil { + Yyerror("invalid case %v in switch on %v (mismatched types %v and %v)", Nconv(ll.N, 0), Nconv(n.Ntest, 0), Tconv(ll.N.Type, 0), Tconv(t, 0)) + } else { + Yyerror("invalid case %v in switch (mismatched types %v and bool)", Nconv(ll.N, 0), Tconv(ll.N.Type, 0)) + } + case nilonly != "" && !Isconst(ll.N, CTNIL): + Yyerror("invalid case %v in switch (can only compare %s %v to nil)", Nconv(ll.N, 0), nilonly, Nconv(n.Ntest, 0)) + } + + // type switch + case Etype: + var missing, have *Type + var ptr int + switch { + case ll.N.Op == OLITERAL && Istype(ll.N.Type, TNIL): + case ll.N.Op != OTYPE && ll.N.Type != nil: // should this be ||? + Yyerror("%v is not a type", Nconv(ll.N, obj.FmtLong)) + // reset to original type + ll.N = n.Ntest.Right + case ll.N.Type.Etype != TINTER && t.Etype == TINTER && !implements(ll.N.Type, t, &missing, &have, &ptr): + if have != nil && missing.Broke == 0 && have.Broke == 0 { + Yyerror("impossible type switch case: %v cannot have dynamic type %v"+" (wrong type for %v method)\n\thave %v%v\n\twant %v%v", Nconv(n.Ntest.Right, obj.FmtLong), Tconv(ll.N.Type, 0), Sconv(missing.Sym, 0), Sconv(have.Sym, 0), Tconv(have.Type, obj.FmtShort), Sconv(missing.Sym, 0), Tconv(missing.Type, obj.FmtShort)) + } else if missing.Broke == 0 { + Yyerror("impossible type switch case: %v cannot have dynamic type %v"+" (missing %v method)", Nconv(n.Ntest.Right, obj.FmtLong), Tconv(ll.N.Type, 0), Sconv(missing.Sym, 0)) + } + } + } + } + } + + if top == Etype && n.Type != nil { + ll = ncase.List + nvar := ncase.Nname + if nvar != nil { + if ll != nil && ll.Next == nil && ll.N.Type != nil && !Istype(ll.N.Type, TNIL) { + // single entry type switch + nvar.Ntype = typenod(ll.N.Type) + } else { + // multiple entry type switch or default + nvar.Ntype = typenod(n.Type) + } + + typecheck(&nvar, Erv|Easgn) + ncase.Nname = nvar + } + } + + typechecklist(ncase.Nbody, Etop) + } + + lineno = int32(lno) } -var newlabel_swt_label int +// walkswitch walks a switch statement. +func walkswitch(sw *Node) { + // convert switch {...} to switch true {...} + if sw.Ntest == nil { + sw.Ntest = Nodbool(true) + typecheck(&sw.Ntest, Erv) + } -func newlabel_swt() *Node { - newlabel_swt_label++ - namebuf = fmt.Sprintf("%.6d", newlabel_swt_label) - return newname(Lookup(namebuf)) + if sw.Ntest.Op == OTYPESW { + var s typeSwitch + s.walk(sw) + } else { + var s exprSwitch + s.walk(sw) + } + + // Discard old AST elements. They can confuse racewalk. + sw.Ntest = nil + sw.List = nil } -/* - * build separate list of statements and cases - * make labels between cases and statements - * deal with fallthrough, break, unreachable statements - */ +// walk generates an AST implementing sw. +// sw is an expression switch. +// The AST is generally of the form of a linear +// search using if..goto, although binary search +// is used with long runs of constants. +func (s *exprSwitch) walk(sw *Node) { + casebody(sw, nil) + + s.kind = switchKindExpr + if Isconst(sw.Ntest, CTBOOL) { + s.kind = switchKindTrue + if sw.Ntest.Val.U.Bval == 0 { + s.kind = switchKindFalse + } + } + + walkexpr(&sw.Ntest, &sw.Ninit) + t := sw.Type + if t == nil { + return + } + + // convert the switch into OIF statements + var cas *NodeList + if s.kind == switchKindTrue || s.kind == switchKindFalse { + s.exprname = Nodbool(s.kind == switchKindTrue) + } else if consttype(sw.Ntest) >= 0 { + // leave constants to enable dead code elimination (issue 9608) + s.exprname = sw.Ntest + } else { + s.exprname = temp(sw.Ntest.Type) + cas = list1(Nod(OAS, s.exprname, sw.Ntest)) + typechecklist(cas, Etop) + } + + // enumerate the cases, and lop off the default case + cc := caseClauses(sw, s.kind) + var def *Node + if len(cc) > 0 && cc[0].typ == caseKindDefault { + def = cc[0].node.Right + cc = cc[1:] + } else { + def = Nod(OBREAK, nil, nil) + } + + // handle the cases in order + for len(cc) > 0 { + // deal with expressions one at a time + if okforcmp[t.Etype] == 0 || cc[0].typ != caseKindExprConst { + a := s.walkCases(cc[:1]) + cas = list(cas, a) + cc = cc[1:] + continue + } + + // do binary search on runs of constants + var run int + for run = 1; run < len(cc) && cc[run].typ == caseKindExprConst; run++ { + } + + // sort and compile constants + sort.Sort(caseClauseByExpr(cc[:run])) + a := s.walkCases(cc[:run]) + cas = list(cas, a) + cc = cc[run:] + } + + // handle default case + if nerrors == 0 { + cas = list(cas, def) + sw.Nbody = concat(cas, sw.Nbody) + sw.List = nil + walkstmtlist(sw.Nbody) + } +} + +// walkCases generates an AST implementing the cases in cc. +func (s *exprSwitch) walkCases(cc []*caseClause) *Node { + if len(cc) < binarySearchMin { + // linear search + var cas *NodeList + for _, c := range cc { + n := c.node + lno := int(setlineno(n)) + + a := Nod(OIF, nil, nil) + if (s.kind != switchKindTrue && s.kind != switchKindFalse) || assignop(n.Left.Type, s.exprname.Type, nil) == OCONVIFACE || assignop(s.exprname.Type, n.Left.Type, nil) == OCONVIFACE { + a.Ntest = Nod(OEQ, s.exprname, n.Left) // if name == val + typecheck(&a.Ntest, Erv) + } else if s.kind == switchKindTrue { + a.Ntest = n.Left // if val + } else { + // s.kind == switchKindFalse + a.Ntest = Nod(ONOT, n.Left, nil) // if !val + typecheck(&a.Ntest, Erv) + } + a.Nbody = list1(n.Right) // goto l + + cas = list(cas, a) + lineno = int32(lno) + } + return liststmt(cas) + } + + // find the middle and recur + half := len(cc) / 2 + a := Nod(OIF, nil, nil) + a.Ntest = Nod(OLE, s.exprname, cc[half-1].node.Left) + typecheck(&a.Ntest, Erv) + a.Nbody = list1(s.walkCases(cc[:half])) + a.Nelse = list1(s.walkCases(cc[half:])) + return a +} + +// casebody builds separate lists of statements and cases. +// It makes labels between cases and statements +// and deals with fallthrough, break, and unreachable statements. func casebody(sw *Node, typeswvar *Node) { if sw.List == nil { return @@ -265,67 +331,54 @@ func casebody(sw *Node, typeswvar *Node) { lno := setlineno(sw) - cas := (*NodeList)(nil) // cases - stat := (*NodeList)(nil) // statements - def := (*Node)(nil) // defaults + var cas *NodeList // cases + var stat *NodeList // statements + var def *Node // defaults br := Nod(OBREAK, nil, nil) - var c *Node - var go_ *Node - var needvar bool - var lc *NodeList - var last *Node - var n *Node for l := sw.List; l != nil; l = l.Next { - n = l.N + n := l.N setlineno(n) if n.Op != OXCASE { Fatal("casebody %v", Oconv(int(n.Op), 0)) } n.Op = OCASE - needvar = count(n.List) != 1 || n.List.N.Op == OLITERAL + needvar := count(n.List) != 1 || n.List.N.Op == OLITERAL - go_ = Nod(OGOTO, newlabel_swt(), nil) + jmp := Nod(OGOTO, newCaseLabel(), nil) if n.List == nil { if def != nil { Yyerror("more than one default case") } - // reuse original default case - n.Right = go_ - + n.Right = jmp def = n } if n.List != nil && n.List.Next == nil { - // one case - reuse OCASE node. - c = n.List.N - - n.Left = c - n.Right = go_ + // one case -- reuse OCASE node + n.Left = n.List.N + n.Right = jmp n.List = nil cas = list(cas, n) } else { // expand multi-valued cases - for lc = n.List; lc != nil; lc = lc.Next { - c = lc.N - cas = list(cas, Nod(OCASE, c, go_)) + for lc := n.List; lc != nil; lc = lc.Next { + cas = list(cas, Nod(OCASE, lc.N, jmp)) } } - stat = list(stat, Nod(OLABEL, go_.Left, nil)) + stat = list(stat, Nod(OLABEL, jmp.Left, nil)) if typeswvar != nil && needvar && n.Nname != nil { l := list1(Nod(ODCL, n.Nname, nil)) l = list(l, Nod(OAS, n.Nname, typeswvar)) typechecklist(l, Etop) stat = concat(stat, l) } - stat = concat(stat, n.Nbody) // botch - shouldn't fall thru declaration - last = stat.End.N - + last := stat.End.N if last.Xoffset == n.Xoffset && last.Op == OXFALL { if typeswvar != nil { setlineno(last) @@ -353,326 +406,103 @@ func casebody(sw *Node, typeswvar *Node) { lineno = lno } -func mkcaselist(sw *Node, arg int) *Case { - var n *Node - var c1 *Case +// nSwitchLabel is the number of switch labels generated. +// This should be per-function, but it is a global counter for now. +var nSwitchLabel int - c := (*Case)(nil) - ord := 0 +func newCaseLabel() *Node { + label := strconv.Itoa(nSwitchLabel) + nSwitchLabel++ + return newname(Lookup(label)) +} +// caseClauses generates a slice of caseClauses +// corresponding to the clauses in the switch statement sw. +// Kind is the kind of switch statement. +func caseClauses(sw *Node, kind int) []*caseClause { + var cc []*caseClause for l := sw.List; l != nil; l = l.Next { - n = l.N - c1 = new(Case) - c1.link = c - c = c1 - - ord++ - if int(uint16(ord)) != ord { - Fatal("too many cases in switch") - } - c.ordinal = uint16(ord) + n := l.N + c := new(caseClause) + cc = append(cc, c) + c.ordinal = len(cc) c.node = n if n.Left == nil { - c.type_ = Tdefault + c.typ = caseKindDefault continue } - switch arg { - case Stype: - c.hash = 0 - if n.Left.Op == OLITERAL { - c.type_ = Ttypenil - continue + if kind == switchKindType { + // type switch + switch { + case n.Left.Op == OLITERAL: + c.typ = caseKindTypeNil + case Istype(n.Left.Type, TINTER): + c.typ = caseKindTypeVar + default: + c.typ = caseKindTypeConst + c.hash = typehash(n.Left.Type) } - - if Istype(n.Left.Type, TINTER) { - c.type_ = Ttypevar - continue - } - - c.hash = typehash(n.Left.Type) - c.type_ = Ttypeconst - continue - - case Snorm, - Strue, - Sfalse: - c.type_ = Texprvar - c.hash = typehash(n.Left.Type) + } else { + // expression switch switch consttype(n.Left) { - case CTFLT, - CTINT, - CTRUNE, - CTSTR: - c.type_ = Texprconst + case CTFLT, CTINT, CTRUNE, CTSTR: + c.typ = caseKindExprConst + default: + c.typ = caseKindExprVar } - - continue } } - if c == nil { + if cc == nil { return nil } // sort by value and diagnose duplicate cases - switch arg { - case Stype: - c = csort(c, typecmp) - var c2 *Case - for c1 := c; c1 != nil; c1 = c1.link { - for c2 = c1.link; c2 != nil && c2.hash == c1.hash; c2 = c2.link { - if c1.type_ == Ttypenil || c1.type_ == Tdefault { + if kind == switchKindType { + // type switch + sort.Sort(caseClauseByType(cc)) + for i, c1 := range cc { + if c1.typ == caseKindTypeNil || c1.typ == caseKindDefault { + break + } + for _, c2 := range cc[i+1:] { + if c2.typ == caseKindTypeNil || c2.typ == caseKindDefault || c1.hash != c2.hash { break } - if c2.type_ == Ttypenil || c2.type_ == Tdefault { - break + if Eqtype(c1.node.Left.Type, c2.node.Left.Type) { + yyerrorl(int(c2.node.Lineno), "duplicate case %v in type switch\n\tprevious case at %v", Tconv(c2.node.Left.Type, 0), c1.node.Line()) } - if !Eqtype(c1.node.Left.Type, c2.node.Left.Type) { - continue - } - yyerrorl(int(c2.node.Lineno), "duplicate case %v in type switch\n\tprevious case at %v", Tconv(c2.node.Left.Type, 0), c1.node.Line()) } } - - case Snorm, - Strue, - Sfalse: - c = csort(c, exprcmp) - for c1 := c; c1.link != nil; c1 = c1.link { - if exprcmp(c1, c1.link) != 0 { + } else { + // expression switch + sort.Sort(caseClauseByExpr(cc)) + for i, c1 := range cc { + if i+1 == len(cc) { + break + } + c2 := cc[i+1] + if exprcmp(c1, c2) != 0 { continue } - setlineno(c1.link.node) + setlineno(c2.node) Yyerror("duplicate case %v in switch\n\tprevious case at %v", Nconv(c1.node.Left, 0), c1.node.Line()) } } // put list back in processing order - c = csort(c, ordlcmp) - - return c + sort.Sort(caseClauseByOrd(cc)) + return cc } -var exprname *Node - -func exprbsw(c0 *Case, ncase int, arg int) *Node { - cas := (*NodeList)(nil) - if ncase < Ncase { - var a *Node - var n *Node - var lno int - for i := 0; i < ncase; i++ { - n = c0.node - lno = int(setlineno(n)) - - if (arg != Strue && arg != Sfalse) || assignop(n.Left.Type, exprname.Type, nil) == OCONVIFACE || assignop(exprname.Type, n.Left.Type, nil) == OCONVIFACE { - a = Nod(OIF, nil, nil) - a.Ntest = Nod(OEQ, exprname, n.Left) // if name == val - typecheck(&a.Ntest, Erv) - a.Nbody = list1(n.Right) // then goto l - } else if arg == Strue { - a = Nod(OIF, nil, nil) - a.Ntest = n.Left // if val - a.Nbody = list1(n.Right) // then goto l // arg == Sfalse - } else { - a = Nod(OIF, nil, nil) - a.Ntest = Nod(ONOT, n.Left, nil) // if !val - typecheck(&a.Ntest, Erv) - a.Nbody = list1(n.Right) // then goto l - } - - cas = list(cas, a) - c0 = c0.link - lineno = int32(lno) - } - - return liststmt(cas) - } - - // find the middle and recur - c := c0 - - half := ncase >> 1 - for i := 1; i < half; i++ { - c = c.link - } - a := Nod(OIF, nil, nil) - a.Ntest = Nod(OLE, exprname, c.node.Left) - typecheck(&a.Ntest, Erv) - a.Nbody = list1(exprbsw(c0, half, arg)) - a.Nelse = list1(exprbsw(c.link, ncase-half, arg)) - return a -} - -/* - * normal (expression) switch. - * rebuild case statements into if .. goto - */ -func exprswitch(sw *Node) { - casebody(sw, nil) - - arg := Snorm - if Isconst(sw.Ntest, CTBOOL) { - arg = Strue - if sw.Ntest.Val.U.Bval == 0 { - arg = Sfalse - } - } - - walkexpr(&sw.Ntest, &sw.Ninit) - t := sw.Type - if t == nil { - return - } - - /* - * convert the switch into OIF statements - */ - exprname = nil - - cas := (*NodeList)(nil) - if arg == Strue || arg == Sfalse { - exprname = Nodbool(arg == Strue) - } else if consttype(sw.Ntest) >= 0 { - // leave constants to enable dead code elimination (issue 9608) - exprname = sw.Ntest - } else { - exprname = temp(sw.Ntest.Type) - cas = list1(Nod(OAS, exprname, sw.Ntest)) - typechecklist(cas, Etop) - } - - c0 := mkcaselist(sw, arg) - var def *Node - if c0 != nil && c0.type_ == Tdefault { - def = c0.node.Right - c0 = c0.link - } else { - def = Nod(OBREAK, nil, nil) - } - - var c *Case - var a *Node - var ncase int - var c1 *Case -loop: - if c0 == nil { - cas = list(cas, def) - sw.Nbody = concat(cas, sw.Nbody) - sw.List = nil - walkstmtlist(sw.Nbody) - return - } - - // deal with the variables one-at-a-time - if okforcmp[t.Etype] == 0 || c0.type_ != Texprconst { - a = exprbsw(c0, 1, arg) - cas = list(cas, a) - c0 = c0.link - goto loop - } - - // do binary search on run of constants - ncase = 1 - - for c = c0; c.link != nil; c = c.link { - if c.link.type_ != Texprconst { - break - } - ncase++ - } - - // break the chain at the count - c1 = c.link - - c.link = nil - - // sort and compile constants - c0 = csort(c0, exprcmp) - - a = exprbsw(c0, ncase, arg) - cas = list(cas, a) - - c0 = c1 - goto loop -} - -var hashname *Node - -var facename *Node - -var boolname *Node - -func typeone(t *Node) *Node { - var_ := t.Nname - init := (*NodeList)(nil) - if var_ == nil { - typecheck(&nblank, Erv|Easgn) - var_ = nblank - } else { - init = list1(Nod(ODCL, var_, nil)) - } - - a := Nod(OAS2, nil, nil) - a.List = list(list1(var_), boolname) // var,bool = - b := Nod(ODOTTYPE, facename, nil) - b.Type = t.Left.Type // interface.(type) - a.Rlist = list1(b) - typecheck(&a, Etop) - init = list(init, a) - - b = Nod(OIF, nil, nil) - b.Ntest = boolname - b.Nbody = list1(t.Right) // if bool { goto l } - a = liststmt(list(init, b)) - return a -} - -func typebsw(c0 *Case, ncase int) *Node { - cas := (*NodeList)(nil) - - if ncase < Ncase { - var n *Node - var a *Node - for i := 0; i < ncase; i++ { - n = c0.node - if c0.type_ != Ttypeconst { - Fatal("typebsw") - } - a = Nod(OIF, nil, nil) - a.Ntest = Nod(OEQ, hashname, Nodintconst(int64(c0.hash))) - typecheck(&a.Ntest, Erv) - a.Nbody = list1(n.Right) - cas = list(cas, a) - c0 = c0.link - } - - return liststmt(cas) - } - - // find the middle and recur - c := c0 - - half := ncase >> 1 - for i := 1; i < half; i++ { - c = c.link - } - a := Nod(OIF, nil, nil) - a.Ntest = Nod(OLE, hashname, Nodintconst(int64(c.hash))) - typecheck(&a.Ntest, Erv) - a.Nbody = list1(typebsw(c0, half)) - a.Nelse = list1(typebsw(c.link, ncase-half)) - return a -} - -/* - * convert switch of the form - * switch v := i.(type) { case t1: ..; case t2: ..; } - * into if statements - */ -func typeswitch(sw *Node) { +// walk generates an AST that implements sw, +// where sw is a type switch. +// The AST is generally of the form of a linear +// search using if..goto, although binary search +// is used with long runs of concrete types. +func (s *typeSwitch) walk(sw *Node) { if sw.Ntest == nil { return } @@ -688,26 +518,25 @@ func typeswitch(sw *Node) { return } - cas := (*NodeList)(nil) + var cas *NodeList - /* - * predeclare temporary variables - * and the boolean var - */ - facename = temp(sw.Ntest.Right.Type) + // predeclare temporary variables and the boolean var + s.facename = temp(sw.Ntest.Right.Type) - a := Nod(OAS, facename, sw.Ntest.Right) + a := Nod(OAS, s.facename, sw.Ntest.Right) typecheck(&a, Etop) cas = list(cas, a) - casebody(sw, facename) + s.okname = temp(Types[TBOOL]) + typecheck(&s.okname, Erv) - boolname = temp(Types[TBOOL]) - typecheck(&boolname, Erv) + s.hashname = temp(Types[TUINT32]) + typecheck(&s.hashname, Erv) - hashname = temp(Types[TUINT32]) - typecheck(&hashname, Erv) + // set up labels and jumps + casebody(sw, s.facename) + // calculate type hash t := sw.Ntest.Right.Type if isnilinter(t) { a = syslook("efacethash", 1) @@ -716,98 +545,81 @@ func typeswitch(sw *Node) { } argtype(a, t) a = Nod(OCALL, a, nil) - a.List = list1(facename) - a = Nod(OAS, hashname, a) + a.List = list1(s.facename) + a = Nod(OAS, s.hashname, a) typecheck(&a, Etop) cas = list(cas, a) - c0 := mkcaselist(sw, Stype) + cc := caseClauses(sw, switchKindType) var def *Node - if c0 != nil && c0.type_ == Tdefault { - def = c0.node.Right - c0 = c0.link + if len(cc) > 0 && cc[0].typ == caseKindDefault { + def = cc[0].node.Right + cc = cc[1:] } else { def = Nod(OBREAK, nil, nil) } - /* - * insert if statement into each case block - */ - var v Val - var n *Node - for c := c0; c != nil; c = c.link { - n = c.node - switch c.type_ { - case Ttypenil: + // insert type equality check into each case block + for _, c := range cc { + n := c.node + switch c.typ { + case caseKindTypeNil: + var v Val v.Ctype = CTNIL a = Nod(OIF, nil, nil) - a.Ntest = Nod(OEQ, facename, nodlit(v)) + a.Ntest = Nod(OEQ, s.facename, nodlit(v)) typecheck(&a.Ntest, Erv) a.Nbody = list1(n.Right) // if i==nil { goto l } n.Right = a - case Ttypevar, - Ttypeconst: - n.Right = typeone(n) + case caseKindTypeVar, caseKindTypeConst: + n.Right = s.typeone(n) } } - /* - * generate list of if statements, binary search for constant sequences - */ - var ncase int - var c1 *Case - var hash *NodeList - var c *Case - for c0 != nil { - if c0.type_ != Ttypeconst { - n = c0.node + // generate list of if statements, binary search for constant sequences + for len(cc) > 0 { + if cc[0].typ != caseKindTypeConst { + n := cc[0].node cas = list(cas, n.Right) - c0 = c0.link + cc = cc[1:] continue } // identify run of constants - c = c0 - c1 = c - - for c.link != nil && c.link.type_ == Ttypeconst { - c = c.link + var run int + for run = 1; run < len(cc) && cc[run].typ == caseKindTypeConst; run++ { } - c0 = c.link - c.link = nil // sort by hash - c1 = csort(c1, typecmp) + sort.Sort(caseClauseByType(cc[:run])) // for debugging: linear search if false { - for c = c1; c != nil; c = c.link { - n = c.node + for i := 0; i < run; i++ { + n := cc[i].node cas = list(cas, n.Right) } - continue } // combine adjacent cases with the same hash - ncase = 0 - - for c = c1; c != nil; c = c.link { + ncase := 0 + for i := 0; i < run; i++ { ncase++ - hash = list1(c.node.Right) - for c.link != nil && c.link.hash == c.hash { - hash = list(hash, c.link.node.Right) - c.link = c.link.link + hash := list1(cc[i].node.Right) + for j := i + 1; j < run && cc[i].hash == cc[j].hash; j++ { + hash = list(hash, cc[j].node.Right) } - - c.node.Right = liststmt(hash) + cc[i].node.Right = liststmt(hash) } // binary search among cases to narrow by hash - cas = list(cas, typebsw(c1, ncase)) + cas = list(cas, s.walkCases(cc[:ncase])) + cc = cc[ncase:] } + // handle default case if nerrors == 0 { cas = list(cas, def) sw.Nbody = concat(cas, sw.Nbody) @@ -816,161 +628,192 @@ func typeswitch(sw *Node) { } } -func walkswitch(sw *Node) { - /* - * reorder the body into (OLIST, cases, statements) - * cases have OGOTO into statements. - * both have inserted OBREAK statements - */ - if sw.Ntest == nil { - sw.Ntest = Nodbool(true) - typecheck(&sw.Ntest, Erv) - } - - if sw.Ntest.Op == OTYPESW { - typeswitch(sw) - - //dump("sw", sw); - return - } - - exprswitch(sw) - - // Discard old AST elements after a walk. They can confuse racewealk. - sw.Ntest = nil - - sw.List = nil -} - -/* - * type check switch statement - */ -func typecheckswitch(n *Node) { - var top int - var t *Type - - lno := int(lineno) - typechecklist(n.Ninit, Etop) - nilonly := "" - - if n.Ntest != nil && n.Ntest.Op == OTYPESW { - // type switch - top = Etype - - typecheck(&n.Ntest.Right, Erv) - t = n.Ntest.Right.Type - if t != nil && t.Etype != TINTER { - Yyerror("cannot type switch on non-interface value %v", Nconv(n.Ntest.Right, obj.FmtLong)) - } +// typeone generates an AST that jumps to the +// case body if the variable is of type t. +func (s *typeSwitch) typeone(t *Node) *Node { + name := t.Nname + var init *NodeList + if name == nil { + typecheck(&nblank, Erv|Easgn) + name = nblank } else { - // value switch - top = Erv - - if n.Ntest != nil { - typecheck(&n.Ntest, Erv) - defaultlit(&n.Ntest, nil) - t = n.Ntest.Type - } else { - t = Types[TBOOL] - } - if t != nil { - var badtype *Type - if okforeq[t.Etype] == 0 { - Yyerror("cannot switch on %v", Nconv(n.Ntest, obj.FmtLong)) - } else if t.Etype == TARRAY && !Isfixedarray(t) { - nilonly = "slice" - } else if t.Etype == TARRAY && Isfixedarray(t) && algtype1(t, nil) == ANOEQ { - Yyerror("cannot switch on %v", Nconv(n.Ntest, obj.FmtLong)) - } else if t.Etype == TSTRUCT && algtype1(t, &badtype) == ANOEQ { - Yyerror("cannot switch on %v (struct containing %v cannot be compared)", Nconv(n.Ntest, obj.FmtLong), Tconv(badtype, 0)) - } else if t.Etype == TFUNC { - nilonly = "func" - } else if t.Etype == TMAP { - nilonly = "map" - } - } + init = list1(Nod(ODCL, name, nil)) } - n.Type = t + a := Nod(OAS2, nil, nil) + a.List = list(list1(name), s.okname) // name, ok = + b := Nod(ODOTTYPE, s.facename, nil) + b.Type = t.Left.Type // interface.(type) + a.Rlist = list1(b) + typecheck(&a, Etop) + init = list(init, a) - def := (*Node)(nil) - var ptr int - var have *Type - var nvar *Node - var ll *NodeList - var missing *Type - var ncase *Node - for l := n.List; l != nil; l = l.Next { - ncase = l.N - setlineno(n) - if ncase.List == nil { - // default - if def != nil { - Yyerror("multiple defaults in switch (first at %v)", def.Line()) - } else { - def = ncase - } - } else { - for ll = ncase.List; ll != nil; ll = ll.Next { - setlineno(ll.N) - typecheck(&ll.N, Erv|Etype) - if ll.N.Type == nil || t == nil { - continue - } - setlineno(ncase) - switch top { - case Erv: // expression switch - defaultlit(&ll.N, t) + c := Nod(OIF, nil, nil) + c.Ntest = s.okname + c.Nbody = list1(t.Right) // if ok { goto l } - if ll.N.Op == OTYPE { - Yyerror("type %v is not an expression", Tconv(ll.N.Type, 0)) - } else if ll.N.Type != nil && assignop(ll.N.Type, t, nil) == 0 && assignop(t, ll.N.Type, nil) == 0 { - if n.Ntest != nil { - Yyerror("invalid case %v in switch on %v (mismatched types %v and %v)", Nconv(ll.N, 0), Nconv(n.Ntest, 0), Tconv(ll.N.Type, 0), Tconv(t, 0)) - } else { - Yyerror("invalid case %v in switch (mismatched types %v and bool)", Nconv(ll.N, 0), Tconv(ll.N.Type, 0)) - } - } else if nilonly != "" && !Isconst(ll.N, CTNIL) { - Yyerror("invalid case %v in switch (can only compare %s %v to nil)", Nconv(ll.N, 0), nilonly, Nconv(n.Ntest, 0)) - } - - case Etype: // type switch - if ll.N.Op == OLITERAL && Istype(ll.N.Type, TNIL) { - } else if ll.N.Op != OTYPE && ll.N.Type != nil { // should this be ||? - Yyerror("%v is not a type", Nconv(ll.N, obj.FmtLong)) - - // reset to original type - ll.N = n.Ntest.Right - } else if ll.N.Type.Etype != TINTER && t.Etype == TINTER && !implements(ll.N.Type, t, &missing, &have, &ptr) { - if have != nil && missing.Broke == 0 && have.Broke == 0 { - Yyerror("impossible type switch case: %v cannot have dynamic type %v"+" (wrong type for %v method)\n\thave %v%v\n\twant %v%v", Nconv(n.Ntest.Right, obj.FmtLong), Tconv(ll.N.Type, 0), Sconv(missing.Sym, 0), Sconv(have.Sym, 0), Tconv(have.Type, obj.FmtShort), Sconv(missing.Sym, 0), Tconv(missing.Type, obj.FmtShort)) - } else if missing.Broke == 0 { - Yyerror("impossible type switch case: %v cannot have dynamic type %v"+" (missing %v method)", Nconv(n.Ntest.Right, obj.FmtLong), Tconv(ll.N.Type, 0), Sconv(missing.Sym, 0)) - } - } - } - } - } - - if top == Etype && n.Type != nil { - ll = ncase.List - nvar = ncase.Nname - if nvar != nil { - if ll != nil && ll.Next == nil && ll.N.Type != nil && !Istype(ll.N.Type, TNIL) { - // single entry type switch - nvar.Ntype = typenod(ll.N.Type) - } else { - // multiple entry type switch or default - nvar.Ntype = typenod(n.Type) - } - - typecheck(&nvar, Erv|Easgn) - ncase.Nname = nvar - } - } - - typechecklist(ncase.Nbody, Etop) - } - - lineno = int32(lno) + return liststmt(list(init, c)) +} + +// walkCases generates an AST implementing the cases in cc. +func (s *typeSwitch) walkCases(cc []*caseClause) *Node { + if len(cc) < binarySearchMin { + var cas *NodeList + for _, c := range cc { + n := c.node + if c.typ != caseKindTypeConst { + Fatal("typeSwitch walkCases") + } + a := Nod(OIF, nil, nil) + a.Ntest = Nod(OEQ, s.hashname, Nodintconst(int64(c.hash))) + typecheck(&a.Ntest, Erv) + a.Nbody = list1(n.Right) + cas = list(cas, a) + } + return liststmt(cas) + } + + // find the middle and recur + half := len(cc) / 2 + a := Nod(OIF, nil, nil) + a.Ntest = Nod(OLE, s.hashname, Nodintconst(int64(cc[half-1].hash))) + typecheck(&a.Ntest, Erv) + a.Nbody = list1(s.walkCases(cc[:half])) + a.Nelse = list1(s.walkCases(cc[half:])) + return a +} + +type caseClauseByOrd []*caseClause + +func (x caseClauseByOrd) Len() int { return len(x) } +func (x caseClauseByOrd) Swap(i, j int) { x[i], x[j] = x[j], x[i] } +func (x caseClauseByOrd) Less(i, j int) bool { + c1, c2 := x[i], x[j] + switch { + // sort default first + case c1.typ == caseKindDefault: + return true + case c2.typ == caseKindDefault: + return false + + // sort nil second + case c1.typ == caseKindTypeNil: + return true + case c2.typ == caseKindTypeNil: + return false + } + + // sort by ordinal + return c1.ordinal < c2.ordinal +} + +type caseClauseByExpr []*caseClause + +func (x caseClauseByExpr) Len() int { return len(x) } +func (x caseClauseByExpr) Swap(i, j int) { x[i], x[j] = x[j], x[i] } +func (x caseClauseByExpr) Less(i, j int) bool { + return exprcmp(x[i], x[j]) < 0 +} + +func exprcmp(c1, c2 *caseClause) int { + // sort non-constants last + if c1.typ != caseKindExprConst { + return +1 + } + if c2.typ != caseKindExprConst { + return -1 + } + + n1 := c1.node.Left + n2 := c2.node.Left + + // sort by type (for switches on interface) + ct := int(n1.Val.Ctype) + if ct > int(n2.Val.Ctype) { + return +1 + } + if ct < int(n2.Val.Ctype) { + return -1 + } + if !Eqtype(n1.Type, n2.Type) { + if n1.Type.Vargen > n2.Type.Vargen { + return +1 + } else { + return -1 + } + } + + // sort by constant value to enable binary search + switch ct { + case CTFLT: + return mpcmpfltflt(n1.Val.U.Fval, n2.Val.U.Fval) + case CTINT, CTRUNE: + return Mpcmpfixfix(n1.Val.U.Xval, n2.Val.U.Xval) + case CTSTR: + return cmpslit(n1, n2) + } + + return 0 +} + +type caseClauseByType []*caseClause + +func (x caseClauseByType) Len() int { return len(x) } +func (x caseClauseByType) Swap(i, j int) { x[i], x[j] = x[j], x[i] } +func (x caseClauseByType) Less(i, j int) bool { + c1, c2 := x[i], x[j] + switch { + // sort non-constants last + case c1.typ != caseKindTypeConst: + return false + case c2.typ != caseKindTypeConst: + return true + + // sort by hash code + case c1.hash != c2.hash: + return c1.hash < c2.hash + } + + // sort by ordinal + return c1.ordinal < c2.ordinal +} + +func dumpcase(cc []*caseClause) { + for _, c := range cc { + switch c.typ { + case caseKindDefault: + fmt.Printf("case-default\n") + fmt.Printf("\tord=%d\n", c.ordinal) + + case caseKindExprConst: + fmt.Printf("case-exprconst\n") + fmt.Printf("\tord=%d\n", c.ordinal) + + case caseKindExprVar: + fmt.Printf("case-exprvar\n") + fmt.Printf("\tord=%d\n", c.ordinal) + fmt.Printf("\top=%v\n", Oconv(int(c.node.Left.Op), 0)) + + case caseKindTypeNil: + fmt.Printf("case-typenil\n") + fmt.Printf("\tord=%d\n", c.ordinal) + + case caseKindTypeConst: + fmt.Printf("case-typeconst\n") + fmt.Printf("\tord=%d\n", c.ordinal) + fmt.Printf("\thash=%x\n", c.hash) + + case caseKindTypeVar: + fmt.Printf("case-typevar\n") + fmt.Printf("\tord=%d\n", c.ordinal) + + default: + fmt.Printf("case-???\n") + fmt.Printf("\tord=%d\n", c.ordinal) + fmt.Printf("\top=%v\n", Oconv(int(c.node.Left.Op), 0)) + fmt.Printf("\thash=%x\n", c.hash) + } + } + + fmt.Printf("\n") }