diff --git a/src/cmd/gc/swt.c b/src/cmd/gc/swt.c index fb3e4b095bb..5f014e9a9f5 100644 --- a/src/cmd/gc/swt.c +++ b/src/cmd/gc/swt.c @@ -11,6 +11,7 @@ enum Sfalse, Stype, }; +Node* binarysw(Node *t, Iter *save, Node *name); /* * walktype @@ -263,6 +264,7 @@ prepsw(Node *sw, int arg) Iter save; Node *name, *bool, *cas; Node *t, *a; +//dump("prepsw before", sw->nbody->left); cas = N; name = N; @@ -279,7 +281,7 @@ prepsw(Node *sw, int arg) loop: if(t == N) { sw->nbody->left = rev(cas); -//dump("case", sw->nbody->left); +//dump("prepsw after", sw->nbody->left); return; } @@ -309,21 +311,29 @@ loop: goto loop; } - a = nod(OIF, N, N); - a->nbody = t->right; // then goto l switch(arg) { default: // not bool const + a = binarysw(t, &save, name); + if(a != N) + break; + + a = nod(OIF, N, N); a->ntest = nod(OEQ, name, t->left); // if name == val + a->nbody = t->right; // then goto l break; case Strue: + a = nod(OIF, N, N); a->ntest = t->left; // if val + a->nbody = t->right; // then goto l break; case Sfalse: + a = nod(OIF, N, N); a->ntest = nod(ONOT, t->left, N); // if !val + a->nbody = t->right; // then goto l break; } cas = list(cas, a); @@ -470,3 +480,202 @@ walkswitch(Node *sw) walkstate(sw->nbody); //print("normal done\n"); } + +/* + * binary search on cases + */ +enum +{ + Ncase = 4, // needed to binary search +}; + +typedef struct Case Case; +struct Case +{ + Node* node; // points at case statement + Case* link; // linked list to link +}; +#define C ((Case*)nil) + +int +iscaseconst(Node *t) +{ + if(t == N || t->left == N) + return 0; + switch(whatis(t->left)) { + case Wlitfloat: + case Wlitint: + case Wlitstr: + return 1; + } + return 0; +} + +int +countcase(Node *t, Iter save) +{ + int n; + + // note that the iter is by value, + // so cases are not really consumed + for(n=0;; n++) { + if(!iscaseconst(t)) + return n; + t = listnext(&save); + } +} + +Case* +csort(Case *l, int(*f)(Case*, Case*)) +{ + Case *l1, *l2, *le; + + if(l == C || l->link == C) + return l; + + l1 = l; + l2 = l; + for(;;) { + l2 = l2->link; + if(l2 == C) + break; + l2 = l2->link; + if(l2 == C) + break; + l1 = l1->link; + } + + l2 = l1->link; + l1->link = C; + l1 = csort(l, f); + l2 = csort(l2, f); + + /* set up lead element */ + if((*f)(l1, l2) < 0) { + l = l1; + l1 = l1->link; + } else { + l = l2; + l2 = l2->link; + } + le = l; + + for(;;) { + if(l1 == C) { + while(l2) { + le->link = l2; + le = l2; + l2 = l2->link; + } + le->link = C; + break; + } + if(l2 == C) { + while(l1) { + le->link = l1; + le = l1; + l1 = l1->link; + } + break; + } + if((*f)(l1, l2) < 0) { + le->link = l1; + le = l1; + l1 = l1->link; + } else { + le->link = l2; + le = l2; + l2 = l2->link; + } + } + le->link = C; + return l; +} + +int +casecmp(Case *c1, Case *c2) +{ + int w; + + w = whatis(c1->node->left); + if(w != whatis(c2->node->left)) + fatal("casecmp1"); + + switch(w) { + case Wlitfloat: + return mpcmpfltflt(c1->node->left->val.u.fval, c2->node->left->val.u.fval); + case Wlitint: + return mpcmpfixfix(c1->node->left->val.u.xval, c2->node->left->val.u.xval); + case Wlitstr: + return cmpslit(c1->node->left, c2->node->left); + } + + fatal("casecmp2"); + return 0; +} + +Node* +constsw(Case *c0, int ncase, Node *name) +{ + Node *cas, *a; + Case *c; + int i, n; + + // small number do sequentially + if(ncase < Ncase) { + cas = N; + for(i=0; intest = nod(OEQ, name, c0->node->left); + a->nbody = c0->node->right; + cas = list(cas, a); + c0 = c0->link; + } + return rev(cas); + } + + // find center and recur + c = c0; + n = ncase>>1; + for(i=0; ilink; + + a = nod(OIF, N, N); + a->ntest = nod(OLE, name, c->node->left); + a->nbody = constsw(c0, n+1, name); // include center + a->nelse = constsw(c->link, ncase-n-1, name); // exclude center + return a; +} + +Node* +binarysw(Node *t, Iter *save, Node *name) +{ + Case *c, *c1; + int i, ncase; + Node *a; + + ncase = countcase(t, *save); + if(ncase < Ncase) + return N; + + c = C; + for(i=1; ilink = c; + c1->node = t; + c = c1; + + t = listnext(save); + } + + // last one shouldnt consume the iter + c1 = mal(sizeof(*c1)); + c1->link = c; + c1->node = t; + c = c1; + + c = csort(c, casecmp); + a = constsw(c, ncase, name); +//dump("bin", a); + return a; +}