1
0
mirror of https://github.com/golang/go synced 2024-10-04 13:21:22 -06:00
go/src/cmd/gc/swt.c
Ken Thompson 820f42d977 binary search for constant case statements.
R=r
OCL=25890
CL=25890
2009-03-07 17:33:42 -08:00

682 lines
11 KiB
C

// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
#include "go.h"
/*
 * classification of a switch statement.
 * walkswitch and typeswitch pick one of these and pass it
 * as the `arg' parameter to walkcases, sw0..sw3 and prepsw.
 */
enum
{
// test is not a bool constant; walkswitch's default
Snorm = 0,
// test is a bool constant that walkswitch classified as true
Strue,
// test is a bool constant that walkswitch classified as false
Sfalse,
// test is an OTYPESW node; handled by typeswitch
Stype,
};
// forward declaration: prepsw calls binarysw, which is defined later in this file
Node* binarysw(Node *t, Iter *save, Node *name);
/*
 * walkcases pass 0: walk and validate one node.
 *
 * c is the switch test or a case expression, place is unused,
 * and arg is the switch classification (Snorm, Strue, Sfalse, Stype).
 * always returns T; problems are reported with yyerror.
 *
 * three shapes are distinguished by c->op:
 *	OTYPESW	accepted only when arg is Stype
 *	OAS	the right side must be a receive from a channel,
 *		an index of a map, or a type assertion on an
 *		interface; accepted only when arg is Strue
 *	other	walked as an rvalue; rejected when arg is Stype
 */
Type*
sw0(Node *c, Type *place, int arg)
{
	Node *r;

	if(c == N)
		return T;

	switch(c->op) {
	default:
		if(arg == Stype) {
			yyerror("inappropriate case for a type switch");
			return T;
		}
		walktype(c, Erv);
		return T;

	case OTYPESW:
		if(arg != Stype)
			yyerror("inappropriate type case");
		return T;

	case OAS:
		break;
	}

	walktype(c->left, Elv);

	// c has already been dereferenced, so the pointer that
	// needs a nil check before reading r->op is the right side.
	r = c->right;
	if(r == N)
		return T;

	switch(r->op) {
	default:
		goto bad;

	case ORECV:
		// <-chan
		walktype(r->left, Erv);
		if(!istype(r->left->type, TCHAN))
			goto bad;
		break;

	case OINDEX:
		// map[e]
		walktype(r->left, Erv);
		if(!istype(r->left->type, TMAP))
			goto bad;
		break;

	case ODOTTYPE:
		// interface.(type)
		walktype(r->left, Erv);
		if(!istype(r->left->type, TINTER))
			goto bad;
		break;
	}

	// the assignment case is tested as a boolean
	c->type = types[TBOOL];
	if(arg != Strue)
		goto bad;
	return T;

bad:
	yyerror("inappropriate assignment in a case statement");
	return T;
}
/*
 * walkcases pass 1: return the first type.
 * once place is set it is passed through unchanged;
 * until then the type of the current node is proposed.
 */
Type*
sw1(Node *c, Type *place, int arg)
{
	return place != T ? place : c->type;
}
/*
 * walkcases pass 2: return a suitable type.
 * neither the node nor the running type is consulted;
 * the answer is always types[TINT].
 */
Type*
sw2(Node *c, Type *place, int arg)
{
	Type *fallback;

	fallback = types[TINT];	// botch
	return fallback;
}
/*
 * walkcases pass 3: check that the switch type
 * is compatible with every case.
 *
 * with no type chosen yet (place == T) the node's own type is
 * returned.  otherwise a node without a type is given place,
 * the node is passed through convlit, and a mismatch is
 * reported with badtype.  place is returned either way.
 */
Type*
sw3(Node *c, Type *place, int arg)
{
	if(place == T)
		return c->type;

	if(c->type == T)
		c->type = place;
	convlit(c, place);

	if(ascompat(place, c->type))
		return place;

	badtype(OSWITCH, place, c->type);
	return place;
}
/*
 * over all cases, call the parameter function.
 * four passes of these are used to allocate
 * types to cases and switch.
 *
 * call is applied first to the switch test with place T, then
 * to the left side of every OCASE that has one (a case with a
 * nil left side is skipped); each result is fed to the next
 * call as place.
 *
 * returns the final place, or T when the case list is empty.
 * lineno is set from the nodes being visited while walking and
 * is put back to the value returned by the first setlineno on
 * every return path.
 */
Type*
walkcases(Node *sw, Type*(*call)(Node*, Type*, int arg), int arg)
{
	Iter save;
	Node *n;
	Type *place;
	int32 lno;

	lno = setlineno(sw);
	place = call(sw->ntest, T, arg);

	n = listfirst(&save, &sw->nbody->left);
	if(n == N || n->op == OEMPTY) {
		// no cases: lineno still has to be put back
		lineno = lno;
		return T;
	}

	for(; n != N; n = listnext(&save)) {
		if(n->op != OCASE)
			fatal("walkcases: not case %O\n", n->op);

		if(n->left != N) {
			setlineno(n->left);
			place = call(n->left, place, arg);
		}
	}

	lineno = lno;
	return place;
}
/*
 * return a name node for a fresh label.
 * a static counter is incremented on every call and
 * formatted with "%.6d" into namebuf to form the name.
 */
Node*
newlabel()
{
	static int seq;

	snprint(namebuf, sizeof(namebuf), "%.6d", ++seq);
	return newname(lookup(namebuf));
}
/*
 * build separate list of statements and cases
 * make labels between cases and statements
 * deal with fallthrough, break, unreachable statements
 *
 * on return sw->nbody is nod(OLIST, cases, statements):
 *	- every OXCASE becomes an OCASE whose right side is an
 *	  OGOTO to a new label; a multi-valued case (OLIST chain)
 *	  is expanded into one OCASE per value, all sharing the
 *	  same OGOTO.
 *	- the statement list receives an OLABEL for each case and
 *	  an OBREAK before it, unless the preceding statement was
 *	  OXFALL, which is turned into OFALL instead.
 *	- the default case (nil left side) is appended last to
 *	  the case list.
 *
 * lineno is set from each node while it is processed and is
 * put back to the value saved on entry before returning.
 */
void
casebody(Node *sw)
{
	Iter save;
	Node *os, *oc, *t, *c;
	Node *cas, *stat, *def;
	Node *go, *br;
	int32 lno;

	lno = setlineno(sw);
	t = listfirst(&save, &sw->nbody);
	if(t == N || t->op == OEMPTY) {
		sw->nbody = nod(OLIST, N, N);
		lineno = lno;
		return;
	}

	cas = N;	// cases
	stat = N;	// statements
	def = N;	// defaults
	os = N;		// last statement
	oc = N;		// last case
	br = nod(OBREAK, N, N);

	for(; t != N; t = listnext(&save)) {
		// the result is dropped on purpose: lno must keep
		// the value saved on entry so the final restore is right
		setlineno(t);

		switch(t->op) {
		case OXCASE:
			t->op = OCASE;
			if(oc == N && os != N)
				yyerror("first switch statement must be a case");

			// botch - shouldnt fall thru declaration
			if(os != N && os->op == OXFALL)
				os->op = OFALL;
			else
				stat = list(stat, br);

			go = nod(OGOTO, newlabel(), N);

			c = t->left;
			if(c == N) {
				if(def != N)
					yyerror("more than one default case");

				// reuse original default case
				t->right = go;
				def = t;
			}

			// expand multi-valued cases
			for(; c!=N; c=c->right) {
				if(c->op != OLIST) {
					// reuse original case
					t->left = c;
					t->right = go;
					cas = list(cas, t);
					break;
				}
				cas = list(cas, nod(OCASE, c->left, go));
			}
			stat = list(stat, nod(OLABEL, go->left, N));
			oc = t;
			os = N;
			break;

		default:
			stat = list(stat, t);
			os = t;
			break;
		}
	}

	if(oc == N && os != N)
		yyerror("first switch statement must be a case");

	stat = list(stat, br);
	cas = list(cas, def);
	sw->nbody = nod(OLIST, rev(cas), rev(stat));
//dump("case", sw->nbody->left);
//dump("stat", sw->nbody->right);
	lineno = lno;
}
/*
 * rebuild case statements into if .. goto
 *
 * replaces sw->nbody->left (the case list) with a list of
 * statements.  unless arg is Strue or Sfalse, the list starts
 * with an assignment of the switch test to a temporary that
 * the comparisons then use.  each case contributes:
 *	nil left side	its goto, unconditionally (default)
 *	OAS left side	"v,ok = rhs" followed by an OIF on ok
 *			(negated when arg is not Strue)
 *	otherwise	an OIF chosen by arg, or the tree
 *			returned by binarysw
 */
void
prepsw(Node *sw, int arg)
{
	Iter save;
	Node *tmp, *ok, *out;
	Node *cs, *br;

//dump("prepsw before", sw->nbody->left);

	out = N;
	tmp = N;
	ok = N;
	if(arg != Strue && arg != Sfalse) {
		tmp = nod(OXXX, N, N);
		tempname(tmp, sw->ntest->type);
		out = nod(OAS, tmp, sw->ntest);
	}

	for(cs = listfirst(&save, &sw->nbody->left); cs != N; cs = listnext(&save)) {
		if(cs->left == N) {
			out = list(out, cs->right);		// goto default
			continue;
		}

		if(cs->left->op == OAS) {
			// the bool temporary is created on first use
			if(ok == N) {
				ok = nod(OXXX, N, N);
				tempname(ok, types[TBOOL]);
			}
//dump("oas", cs);
			cs->left->left = nod(OLIST, cs->left->left, ok);
			out = list(out, cs->left);		// v,bool = rhs

			br = nod(OIF, N, N);
			br->nbody = cs->right;			// then goto l
			br->ntest = ok;
			if(arg != Strue)
				br->ntest = nod(ONOT, ok, N);
			out = list(out, br);			// if bool goto l
			continue;
		}

		if(arg == Strue) {
			br = nod(OIF, N, N);
			br->ntest = cs->left;			// if val
			br->nbody = cs->right;			// then goto l
		} else if(arg == Sfalse) {
			br = nod(OIF, N, N);
			br->ntest = nod(ONOT, cs->left, N);	// if !val
			br->nbody = cs->right;			// then goto l
		} else {
			// not bool const.
			// binarysw may advance save past the cases it takes.
			br = binarysw(cs, &save, tmp);
			if(br == N) {
				br = nod(OIF, N, N);
				br->ntest = nod(OEQ, tmp, cs->left);	// if name == val
				br->nbody = cs->right;			// then goto l
			}
		}
		out = list(out, br);
	}

	sw->nbody->left = rev(out);
//dump("prepsw after", sw->nbody->left);
}
/*
 * convert switch of the form
 *	switch v := i.(type) { case t1: ..; case t2: ..; }
 * into if statements
 *
 * the right side of the switch test is walked and must be an
 * interface; otherwise an error is reported and nothing is
 * rewritten.  the interface value is copied into a temporary,
 * and each OTYPESW case becomes
 *	var,ok = temp.(type)
 *	if ok { goto l }
 * a case with a nil left side contributes its goto directly;
 * cases of any other op are skipped.
 */
void
typeswitch(Node *sw)
{
	Iter save;
	Node *iface, *ok, *out;
	Node *cs, *lhs, *rhs, *br;

//dump("typeswitch", sw);

	walktype(sw->ntest->right, Erv);
	if(!istype(sw->ntest->right->type, TINTER)) {
		yyerror("type switch must be on an interface");
		return;
	}
	walkcases(sw, sw0, Stype);

	/*
	 * predeclare variables for the interface var
	 * and the boolean var
	 */
	iface = nod(OXXX, N, N);
	tempname(iface, sw->ntest->right->type);
	out = nod(OAS, iface, sw->ntest->right);

	ok = nod(OXXX, N, N);
	tempname(ok, types[TBOOL]);

	for(cs = listfirst(&save, &sw->nbody->left); cs != N; cs = listnext(&save)) {
		if(cs->left == N) {
			out = list(out, cs->right);	// goto default
			continue;
		}
		if(cs->left->op != OTYPESW)
			continue;

		// var,bool = interface.(type)
		lhs = nod(OLIST, cs->left->left, ok);
		rhs = nod(ODOTTYPE, iface, N);
		rhs->type = cs->left->left->type;
		out = list(out, nod(OAS, lhs, rhs));

		// if bool { goto l }
		br = nod(OIF, N, N);
		br->ntest = ok;
		br->nbody = cs->right;
		out = list(out, br);
	}

	sw->nbody->left = rev(out);
	walkstate(sw->nbody);
//dump("done", sw->nbody->left);
}
/*
 * walk a switch statement and rewrite its body
 * into if/goto form.
 */
void
walkswitch(Node *sw)
{
	Type *t;
	int arg;

//dump("walkswitch", sw);

	/*
	 * reorder the body into (OLIST, cases, statements)
	 * cases have OGOTO into statements.
	 * both have inserted OBREAK statements
	 */
	walkstate(sw->ninit);
	if(sw->ntest == N)
		sw->ntest = booltrue;	// a missing test is replaced by booltrue
	casebody(sw);

	/*
	 * classify the switch test
	 *	Strue or Sfalse if the test is a bool constant
	 *		this allows cases to be map/chan/interface assignments
	 *		as well as (boolean) expressions
	 *	Stype if the test is v := interface.(type)
	 *		this forces all cases to be types
	 *	Snorm otherwise
	 *		all cases are expressions
	 */
	if(sw->ntest->op == OTYPESW) {
		typeswitch(sw);
		return;
	}

	// NOTE(review): the bool constant is read through val.u.xval;
	// confirm against the declaration of Val that this is the
	// intended member for a bool literal.
	if(whatis(sw->ntest) != Wlitbool)
		arg = Snorm;
	else if(sw->ntest->val.u.xval == 0)
		arg = Sfalse;
	else
		arg = Strue;

	/*
	 * init statement is nothing important
	 */
	walktype(sw->ntest, Erv);
//print("after walkwalks\n");

	/*
	 * pass 0,1,2,3
	 * walk the cases as appropriate for switch type
	 *
	 * passes 1 and 2 run only while no type is known;
	 * if none is found the switch is left as it is.
	 */
	walkcases(sw, sw0, arg);
	t = sw->ntest->type;
	if(t == T) {
		t = walkcases(sw, sw1, arg);
		if(t == T) {
			t = walkcases(sw, sw2, arg);
			if(t == T)
				return;
		}
	}
	walkcases(sw, sw3, arg);
	convlit(sw->ntest, t);
//print("after walkcases\n");

	/*
	 * convert the switch into OIF statements
	 */
	prepsw(sw, arg);
	walkstate(sw->nbody);
//print("normal done\n");
}
/*
 * binary search on cases
 */
enum
{
// threshold on the length of a run of constant cases:
// binarysw returns N for a shorter run, and constsw
// compares a shorter run sequentially instead of splitting it
Ncase = 4, // needed to binary search
};
// one element of the singly linked list of constant cases
// that binarysw builds, csort sorts and constsw walks
typedef struct Case Case;
struct Case
{
Node* node; // points at case statement
Case* link; // next element of the list; C marks the end
};
// C is the nil Case pointer, used as the end-of-list marker
#define C ((Case*)nil)
/*
 * report whether case node t has a left side that whatis
 * classifies as a float, integer or string literal.
 * a nil t or a nil left side gives 0.
 */
int
iscaseconst(Node *t)
{
	int w;

	if(t == N || t->left == N)
		return 0;

	w = whatis(t->left);
	return w == Wlitfloat || w == Wlitint || w == Wlitstr;
}
/*
 * count the consecutive constant cases starting at t.
 * the iterator is received by value, so stepping it here
 * does not consume anything from the caller's iterator.
 */
int
countcase(Node *t, Iter save)
{
	int n;

	n = 0;
	while(iscaseconst(t)) {
		n++;
		t = listnext(&save);
	}
	return n;
}
/*
 * merge sort of a Case list linked through link.
 * f compares two elements; the left-hand element is taken
 * only when f reports strictly less than zero, otherwise the
 * right-hand one goes first.  returns the head of the sorted,
 * C-terminated list.
 */
Case*
csort(Case *l, int(*f)(Case*, Case*))
{
	Case *mid, *probe, *lo, *hi, *head;
	Case **tailp;

	// zero or one element is already sorted
	if(l == C || l->link == C)
		return l;

	// probe moves two links for each one that mid moves,
	// leaving mid on the last element of the first half
	mid = l;
	for(probe = l->link; probe != C && probe->link != C; probe = probe->link->link)
		mid = mid->link;

	// cut the list after mid and sort each half
	hi = mid->link;
	mid->link = C;
	lo = csort(l, f);
	hi = csort(hi, f);

	// merge; tailp always addresses the link to fill in next
	head = C;
	tailp = &head;
	while(lo != C && hi != C) {
		if((*f)(lo, hi) < 0) {
			*tailp = lo;
			tailp = &lo->link;
			lo = lo->link;
		} else {
			*tailp = hi;
			tailp = &hi->link;
			hi = hi->link;
		}
	}

	// whichever half is left over is already linked and terminated
	*tailp = (lo != C) ? lo : hi;
	return head;
}
/*
 * comparison function handed to csort: compare the constants
 * on the left side of two case nodes.
 * both must be literals of the same kind (float, integer or
 * string); anything else is reported with fatal.
 */
int
casecmp(Case *c1, Case *c2)
{
	Node *a, *b;
	int kind;

	a = c1->node->left;
	b = c2->node->left;

	kind = whatis(a);
	if(kind != whatis(b))
		fatal("casecmp1");

	if(kind == Wlitfloat)
		return mpcmpfltflt(a->val.u.fval, b->val.u.fval);
	if(kind == Wlitint)
		return mpcmpfixfix(a->val.u.xval, b->val.u.xval);
	if(kind == Wlitstr)
		return cmpslit(a, b);

	fatal("casecmp2");
	return 0;
}
/*
 * build the test tree for ncase cases starting at c0,
 * comparing against name.
 *
 * fewer than Ncase cases become a list of "if name == val"
 * nodes.  otherwise the element ncase>>1 links in is the
 * center: "if name <= center" selects the first half, center
 * included, and the else branch takes the rest.
 */
Node*
constsw(Case *c0, int ncase, Node *name)
{
	Node *out, *br;
	Case *mid;
	int i, half;

	if(ncase >= Ncase) {
		// find center and recur
		half = ncase>>1;
		mid = c0;
		for(i=0; i<half; i++)
			mid = mid->link;

		br = nod(OIF, N, N);
		br->ntest = nod(OLE, name, mid->node->left);
		br->nbody = constsw(c0, half+1, name);			// include center
		br->nelse = constsw(mid->link, ncase-half-1, name);	// exclude center
		return br;
	}

	// small number do sequentially
	out = N;
	for(; ncase > 0; ncase--, c0 = c0->link) {
		br = nod(OIF, N, N);
		br->ntest = nod(OEQ, name, c0->node->left);
		br->nbody = c0->node->right;
		out = list(out, br);
	}
	return rev(out);
}
/*
 * try to turn the run of constant cases starting at t
 * into a binary search on name.
 *
 * returns N, leaving *save untouched, when the run is shorter
 * than Ncase.  otherwise the cases are collected, sorted with
 * csort/casecmp and handed to constsw.  *save is advanced
 * once per case after the first, which leaves it positioned
 * on the last case of the run.
 */
Node*
binarysw(Node *t, Iter *save, Node *name)
{
	Case *head, *elem;
	int i, ncase;
	Node *a;

	ncase = countcase(t, *save);
	if(ncase < Ncase)
		return N;

	head = C;
	for(i=0; i<ncase; i++) {
		// t already is the first case; step only for later ones
		if(i > 0)
			t = listnext(save);

		elem = mal(sizeof(*elem));
		elem->link = head;
		elem->node = t;
		head = elem;
	}

	head = csort(head, casecmp);
	a = constsw(head, ncase, name);
//dump("bin", a);
	return a;
}