From 66bf59712e6ee1ea84ce88ff35cea78e525ac5a7 Mon Sep 17 00:00:00 2001 From: Alan Donovan Date: Mon, 28 Jan 2013 18:06:14 -0500 Subject: [PATCH] exp/ssa: (#2 of 5): core utilities This CL includes the implementation of Literal, all the Value.String and Instruction.String methods, the sanity checker, and other misc utilities. R=gri, iant, iant CC=golang-dev https://golang.org/cl/7199052 --- src/pkg/exp/ssa/literal.go | 137 +++++++++++++ src/pkg/exp/ssa/print.go | 383 +++++++++++++++++++++++++++++++++++++ src/pkg/exp/ssa/sanity.go | 263 +++++++++++++++++++++++++ src/pkg/exp/ssa/ssa.go | 21 +- src/pkg/exp/ssa/util.go | 172 +++++++++++++++++ 5 files changed, 961 insertions(+), 15 deletions(-) create mode 100644 src/pkg/exp/ssa/literal.go create mode 100644 src/pkg/exp/ssa/print.go create mode 100644 src/pkg/exp/ssa/sanity.go create mode 100644 src/pkg/exp/ssa/util.go diff --git a/src/pkg/exp/ssa/literal.go b/src/pkg/exp/ssa/literal.go new file mode 100644 index 0000000000..fa26c47e92 --- /dev/null +++ b/src/pkg/exp/ssa/literal.go @@ -0,0 +1,137 @@ +package ssa + +// This file defines the Literal SSA value type. + +import ( + "fmt" + "go/types" + "math/big" + "strconv" +) + +// newLiteral returns a new literal of the specified value and type. +// val must be valid according to the specification of Literal.Value. +// +func newLiteral(val interface{}, typ types.Type) *Literal { + // This constructor exists to provide a single place to + // insert logging/assertions during debugging. + return &Literal{typ, val} +} + +// intLiteral returns an untyped integer literal that evaluates to i. +func intLiteral(i int64) *Literal { + return newLiteral(i, types.Typ[types.UntypedInt]) +} + +// nilLiteral returns a nil literal of the specified (reference) type. +func nilLiteral(typ types.Type) *Literal { + return newLiteral(types.NilType{}, typ) +} + +func (l *Literal) Name() string { + var s string + switch x := l.Value.(type) { + case bool: + s = fmt.Sprintf("%v", l.Value) + case int64: + s = fmt.Sprintf("%d", l.Value) + case *big.Int: + s = x.String() + case *big.Rat: + s = x.FloatString(20) + case string: + if len(x) > 20 { + x = x[:17] + "..." // abbreviate + } + s = strconv.Quote(x) + case types.Complex: + r := x.Re.FloatString(20) + i := x.Im.FloatString(20) + s = fmt.Sprintf("%s+%si", r, i) + case types.NilType: + s = "nil" + default: + panic(fmt.Sprintf("unexpected literal value: %T", x)) + } + return s + ":" + l.Type_.String() +} + +func (l *Literal) Type() types.Type { + return l.Type_ +} + +// IsNil returns true if this literal represents a typed or untyped nil value. +func (l *Literal) IsNil() bool { + _, ok := l.Value.(types.NilType) + return ok +} + +// Int64 returns the numeric value of this literal truncated to fit +// a signed 64-bit integer. +// +func (l *Literal) Int64() int64 { + switch x := l.Value.(type) { + case int64: + return x + case *big.Int: + return x.Int64() + case *big.Rat: + // TODO(adonovan): fix: is this the right rounding mode? + var q big.Int + return q.Quo(x.Num(), x.Denom()).Int64() + } + panic(fmt.Sprintf("unexpected literal value: %T", l.Value)) +} + +// Uint64 returns the numeric value of this literal truncated to fit +// an unsigned 64-bit integer. +// +func (l *Literal) Uint64() uint64 { + switch x := l.Value.(type) { + case int64: + if x < 0 { + return 0 + } + return uint64(x) + case *big.Int: + return x.Uint64() + case *big.Rat: + // TODO(adonovan): fix: is this right? + var q big.Int + return q.Quo(x.Num(), x.Denom()).Uint64() + } + panic(fmt.Sprintf("unexpected literal value: %T", l.Value)) +} + +// Float64 returns the numeric value of this literal truncated to fit +// a float64. +// +func (l *Literal) Float64() float64 { + switch x := l.Value.(type) { + case int64: + return float64(x) + case *big.Int: + var r big.Rat + f, _ := r.SetInt(x).Float64() + return f + case *big.Rat: + f, _ := x.Float64() + return f + } + panic(fmt.Sprintf("unexpected literal value: %T", l.Value)) +} + +// Complex128 returns the complex value of this literal truncated to +// fit a complex128. +// +func (l *Literal) Complex128() complex128 { + switch x := l.Value.(type) { + case int64, *big.Int, *big.Rat: + return complex(l.Float64(), 0) + case types.Complex: + re64, _ := x.Re.Float64() + im64, _ := x.Im.Float64() + return complex(re64, im64) + } + panic(fmt.Sprintf("unexpected literal value: %T", l.Value)) +} diff --git a/src/pkg/exp/ssa/print.go b/src/pkg/exp/ssa/print.go new file mode 100644 index 0000000000..b8708b6ede --- /dev/null +++ b/src/pkg/exp/ssa/print.go @@ -0,0 +1,383 @@ +package ssa + +// This file implements the String() methods for all Value and +// Instruction types. + +import ( + "bytes" + "fmt" + "go/ast" + "go/types" +) + +func (id Id) String() string { + if id.Pkg == nil { + return id.Name + } + return fmt.Sprintf("%s/%s", id.Pkg.Path, id.Name) +} + +// relName returns the name of v relative to i. +// In most cases, this is identical to v.Name(), but for cross-package +// references to Functions (including methods) and Globals, the +// package-qualified FullName is used instead. +// +func relName(v Value, i Instruction) string { + switch v := v.(type) { + case *Global: + if v.Pkg == i.Block().Func.Pkg { + return v.Name() + } + return v.FullName() + case *Function: + if v.Pkg == nil || v.Pkg == i.Block().Func.Pkg { + return v.Name() + } + return v.FullName() + } + return v.Name() +} + +// Value.String() +// +// This method is provided only for debugging. +// It never appears in disassembly, which uses Value.Name(). + +func (v *Literal) String() string { + return fmt.Sprintf("literal %s rep=%T", v.Name(), v.Value) +} + +func (v *Parameter) String() string { + return fmt.Sprintf("parameter %s : %s", v.Name(), v.Type()) +} + +func (v *Capture) String() string { + return fmt.Sprintf("capture %s : %s", v.Name(), v.Type()) +} + +func (v *Global) String() string { + return fmt.Sprintf("global %s : %s", v.Name(), v.Type()) +} + +func (v *Builtin) String() string { + return fmt.Sprintf("builtin %s : %s", v.Name(), v.Type()) +} + +func (r *Function) String() string { + return fmt.Sprintf("function %s : %s", r.Name(), r.Type()) +} + +// FullName returns the name of this function qualified by the +// package name, unless it is anonymous or synthetic. +// +// TODO(adonovan): move to func.go when it's submitted. +// +func (f *Function) FullName() string { + if f.Enclosing != nil || f.Pkg == nil { + return f.Name_ // anonymous or synthetic + } + return fmt.Sprintf("%s.%s", f.Pkg.ImportPath, f.Name_) +} + +// FullName returns g's package-qualified name. +func (g *Global) FullName() string { + return fmt.Sprintf("%s.%s", g.Pkg.ImportPath, g.Name_) +} + +// Instruction.String() + +func (v *Alloc) String() string { + op := "local" + if v.Heap { + op = "new" + } + return fmt.Sprintf("%s %s", op, indirectType(v.Type())) +} + +func (v *Phi) String() string { + var b bytes.Buffer + b.WriteString("phi [") + for i, edge := range v.Edges { + if i > 0 { + b.WriteString(", ") + } + // Be robust against malformed CFG. + blockname := "?" + if v.Block_ != nil && i < len(v.Block_.Preds) { + blockname = v.Block_.Preds[i].Name + } + b.WriteString(blockname) + b.WriteString(": ") + b.WriteString(relName(edge, v)) + } + b.WriteString("]") + return b.String() +} + +func printCall(v *CallCommon, prefix string, instr Instruction) string { + var b bytes.Buffer + b.WriteString(prefix) + if v.Func != nil { + b.WriteString(relName(v.Func, instr)) + } else { + name := underlyingType(v.Recv.Type()).(*types.Interface).Methods[v.Method].Name + fmt.Fprintf(&b, "invoke %s.%s [#%d]", relName(v.Recv, instr), name, v.Method) + } + b.WriteString("(") + for i, arg := range v.Args { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(relName(arg, instr)) + } + if v.HasEllipsis { + b.WriteString("...") + } + b.WriteString(")") + return b.String() +} + +func (v *Call) String() string { + return printCall(&v.CallCommon, "", v) +} + +func (v *BinOp) String() string { + return fmt.Sprintf("%s %s %s", relName(v.X, v), v.Op.String(), relName(v.Y, v)) +} + +func (v *UnOp) String() string { + return fmt.Sprintf("%s%s%s", v.Op, relName(v.X, v), commaOk(v.CommaOk)) +} + +func (v *Conv) String() string { + return fmt.Sprintf("convert %s <- %s (%s)", v.Type(), v.X.Type(), relName(v.X, v)) +} + +func (v *ChangeInterface) String() string { + return fmt.Sprintf("change interface %s <- %s (%s)", v.Type(), v.X.Type(), relName(v.X, v)) +} + +func (v *MakeInterface) String() string { + return fmt.Sprintf("make interface %s <- %s (%s)", v.Type(), v.X.Type(), relName(v.X, v)) +} + +func (v *MakeClosure) String() string { + var b bytes.Buffer + fmt.Fprintf(&b, "make closure %s", relName(v.Fn, v)) + if v.Bindings != nil { + b.WriteString(" [") + for i, c := range v.Bindings { + if i > 0 { + b.WriteString(", ") + } + b.WriteString(relName(c, v)) + } + b.WriteString("]") + } + return b.String() +} + +func (v *MakeSlice) String() string { + var b bytes.Buffer + b.WriteString("make slice ") + b.WriteString(v.Type().String()) + b.WriteString(" ") + b.WriteString(relName(v.Len, v)) + b.WriteString(" ") + b.WriteString(relName(v.Cap, v)) + return b.String() +} + +func (v *Slice) String() string { + var b bytes.Buffer + b.WriteString("slice ") + b.WriteString(relName(v.X, v)) + b.WriteString("[") + if v.Low != nil { + b.WriteString(relName(v.Low, v)) + } + b.WriteString(":") + if v.High != nil { + b.WriteString(relName(v.High, v)) + } + b.WriteString("]") + return b.String() +} + +func (v *MakeMap) String() string { + res := "" + if v.Reserve != nil { + res = relName(v.Reserve, v) + } + return fmt.Sprintf("make %s %s", v.Type(), res) +} + +func (v *MakeChan) String() string { + return fmt.Sprintf("make %s %s", v.Type(), relName(v.Size, v)) +} + +func (v *FieldAddr) String() string { + fields := underlyingType(indirectType(v.X.Type())).(*types.Struct).Fields + // Be robust against a bad index. + name := "?" + if v.Field >= 0 && v.Field < len(fields) { + name = fields[v.Field].Name + } + return fmt.Sprintf("&%s.%s [#%d]", relName(v.X, v), name, v.Field) +} + +func (v *Field) String() string { + fields := underlyingType(v.X.Type()).(*types.Struct).Fields + // Be robust against a bad index. + name := "?" + if v.Field >= 0 && v.Field < len(fields) { + name = fields[v.Field].Name + } + return fmt.Sprintf("%s.%s [#%d]", relName(v.X, v), name, v.Field) +} + +func (v *IndexAddr) String() string { + return fmt.Sprintf("&%s[%s]", relName(v.X, v), relName(v.Index, v)) +} + +func (v *Index) String() string { + return fmt.Sprintf("%s[%s]", relName(v.X, v), relName(v.Index, v)) +} + +func (v *Lookup) String() string { + return fmt.Sprintf("%s[%s]%s", relName(v.X, v), relName(v.Index, v), commaOk(v.CommaOk)) +} + +func (v *Range) String() string { + return "range " + relName(v.X, v) +} + +func (v *Next) String() string { + return "next " + relName(v.Iter, v) +} + +func (v *TypeAssert) String() string { + return fmt.Sprintf("typeassert%s %s.(%s)", commaOk(v.CommaOk), relName(v.X, v), v.AssertedType) +} + +func (v *Extract) String() string { + return fmt.Sprintf("extract %s #%d", relName(v.Tuple, v), v.Index) +} + +func (s *Jump) String() string { + // Be robust against malformed CFG. + blockname := "?" + if s.Block_ != nil && len(s.Block_.Succs) == 1 { + blockname = s.Block_.Succs[0].Name + } + return fmt.Sprintf("jump %s", blockname) +} + +func (s *If) String() string { + // Be robust against malformed CFG. + tblockname, fblockname := "?", "?" + if s.Block_ != nil && len(s.Block_.Succs) == 2 { + tblockname = s.Block_.Succs[0].Name + fblockname = s.Block_.Succs[1].Name + } + return fmt.Sprintf("if %s goto %s else %s", relName(s.Cond, s), tblockname, fblockname) +} + +func (s *Go) String() string { + return printCall(&s.CallCommon, "go ", s) +} + +func (s *Ret) String() string { + var b bytes.Buffer + b.WriteString("ret") + for i, r := range s.Results { + if i == 0 { + b.WriteString(" ") + } else { + b.WriteString(", ") + } + b.WriteString(relName(r, s)) + } + return b.String() +} + +func (s *Send) String() string { + return fmt.Sprintf("send %s <- %s", relName(s.Chan, s), relName(s.X, s)) +} + +func (s *Defer) String() string { + return printCall(&s.CallCommon, "defer ", s) +} + +func (s *Select) String() string { + var b bytes.Buffer + for i, st := range s.States { + if i > 0 { + b.WriteString(", ") + } + if st.Dir == ast.RECV { + b.WriteString("<-") + b.WriteString(relName(st.Chan, s)) + } else { + b.WriteString(relName(st.Chan, s)) + b.WriteString("<-") + b.WriteString(relName(st.Send, s)) + } + } + non := "" + if !s.Blocking { + non = "non" + } + return fmt.Sprintf("select %sblocking [%s]", non, b.String()) +} + +func (s *Store) String() string { + return fmt.Sprintf("*%s = %s", relName(s.Addr, s), relName(s.Val, s)) +} + +func (s *MapUpdate) String() string { + return fmt.Sprintf("%s[%s] = %s", relName(s.Map, s), relName(s.Key, s), relName(s.Value, s)) +} + +func (p *Package) String() string { + // TODO(adonovan): prettify output. + var b bytes.Buffer + fmt.Fprintf(&b, "Package %s at %s:\n", p.ImportPath, p.Prog.Files.File(p.Pos).Name()) + + // TODO(adonovan): make order deterministic. + maxname := 0 + for name := range p.Members { + if l := len(name); l > maxname { + maxname = l + } + } + + for name, mem := range p.Members { + switch mem := mem.(type) { + case *Literal: + fmt.Fprintf(&b, " const %-*s %s\n", maxname, name, mem.Name()) + + case *Function: + fmt.Fprintf(&b, " func %-*s %s\n", maxname, name, mem.Type()) + + case *Type: + fmt.Fprintf(&b, " type %-*s %s\n", maxname, name, mem.NamedType.Underlying) + // TODO(adonovan): make order deterministic. + for name, method := range mem.Methods { + fmt.Fprintf(&b, " method %s %s\n", name, method.Signature) + } + + case *Global: + fmt.Fprintf(&b, " var %-*s %s\n", maxname, name, mem.Type()) + + } + } + return b.String() +} + +func commaOk(x bool) string { + if x { + return ",ok" + } + return "" +} diff --git a/src/pkg/exp/ssa/sanity.go b/src/pkg/exp/ssa/sanity.go new file mode 100644 index 0000000000..bbb30cfcf4 --- /dev/null +++ b/src/pkg/exp/ssa/sanity.go @@ -0,0 +1,263 @@ +package ssa + +// An optional pass for sanity checking invariants of the SSA representation. +// Currently it checks CFG invariants but little at the instruction level. + +import ( + "bytes" + "fmt" + "io" + "os" +) + +type sanity struct { + reporter io.Writer + fn *Function + block *BasicBlock + insane bool +} + +// SanityCheck performs integrity checking of the SSA representation +// of the function fn and returns true if it was valid. Diagnostics +// are written to reporter if non-nil, os.Stderr otherwise. Some +// diagnostics are only warnings and do not imply a negative result. +// +// Sanity checking is intended to facilitate the debugging of code +// transformation passes. +// +func SanityCheck(fn *Function, reporter io.Writer) bool { + if reporter == nil { + reporter = os.Stderr + } + return (&sanity{reporter: reporter}).checkFunction(fn) +} + +// MustSanityCheck is like SanityCheck but panics instead of returning +// a negative result. +// +func MustSanityCheck(fn *Function, reporter io.Writer) { + if !SanityCheck(fn, reporter) { + panic("SanityCheck failed") + } +} + +// blockNames returns the names of the specified blocks as a +// human-readable string. +// +func blockNames(blocks []*BasicBlock) string { + var buf bytes.Buffer + for i, b := range blocks { + if i > 0 { + io.WriteString(&buf, ", ") + } + io.WriteString(&buf, b.Name) + } + return buf.String() +} + +func (s *sanity) diagnostic(prefix, format string, args ...interface{}) { + fmt.Fprintf(s.reporter, "%s: function %s", prefix, s.fn.FullName()) + if s.block != nil { + fmt.Fprintf(s.reporter, ", block %s", s.block.Name) + } + io.WriteString(s.reporter, ": ") + fmt.Fprintf(s.reporter, format, args...) + io.WriteString(s.reporter, "\n") +} + +func (s *sanity) errorf(format string, args ...interface{}) { + s.insane = true + s.diagnostic("Error", format, args...) +} + +func (s *sanity) warnf(format string, args ...interface{}) { + s.diagnostic("Warning", format, args...) +} + +// findDuplicate returns an arbitrary basic block that appeared more +// than once in blocks, or nil if all were unique. +func findDuplicate(blocks []*BasicBlock) *BasicBlock { + if len(blocks) < 2 { + return nil + } + if blocks[0] == blocks[1] { + return blocks[0] + } + // Slow path: + m := make(map[*BasicBlock]bool) + for _, b := range blocks { + if m[b] { + return b + } + m[b] = true + } + return nil +} + +func (s *sanity) checkInstr(idx int, instr Instruction) { + switch instr := instr.(type) { + case *If, *Jump, *Ret: + s.errorf("control flow instruction not at end of block") + case *Phi: + if idx == 0 { + // It suffices to apply this check to just the first phi node. + if dup := findDuplicate(s.block.Preds); dup != nil { + s.errorf("phi node in block with duplicate predecessor %s", dup.Name) + } + } else { + prev := s.block.Instrs[idx-1] + if _, ok := prev.(*Phi); !ok { + s.errorf("Phi instruction follows a non-Phi: %T", prev) + } + } + if ne, np := len(instr.Edges), len(s.block.Preds); ne != np { + s.errorf("phi node has %d edges but %d predecessors", ne, np) + } + + case *Alloc: + case *Call: + case *BinOp: + case *UnOp: + case *MakeClosure: + case *MakeChan: + case *MakeMap: + case *MakeSlice: + case *Slice: + case *Field: + case *FieldAddr: + case *IndexAddr: + case *Index: + case *Select: + case *Range: + case *TypeAssert: + case *Extract: + case *Go: + case *Defer: + case *Send: + case *Store: + case *MapUpdate: + case *Next: + case *Lookup: + case *Conv: + case *ChangeInterface: + case *MakeInterface: + // TODO(adonovan): implement checks. + default: + panic(fmt.Sprintf("Unknown instruction type: %T", instr)) + } +} + +func (s *sanity) checkFinalInstr(idx int, instr Instruction) { + switch instr.(type) { + case *If: + if nsuccs := len(s.block.Succs); nsuccs != 2 { + s.errorf("If-terminated block has %d successors; expected 2", nsuccs) + return + } + if s.block.Succs[0] == s.block.Succs[1] { + s.errorf("If-instruction has same True, False target blocks: %s", s.block.Succs[0].Name) + return + } + + case *Jump: + if nsuccs := len(s.block.Succs); nsuccs != 1 { + s.errorf("Jump-terminated block has %d successors; expected 1", nsuccs) + return + } + + case *Ret: + if nsuccs := len(s.block.Succs); nsuccs != 0 { + s.errorf("Ret-terminated block has %d successors; expected none", nsuccs) + return + } + // TODO(adonovan): check number and types of results + + default: + s.errorf("non-control flow instruction at end of block") + } +} + +func (s *sanity) checkBlock(b *BasicBlock, isEntry bool) { + s.block = b + + // Check all blocks are reachable. + // (The entry block is always implicitly reachable.) + if !isEntry && len(b.Preds) == 0 { + s.warnf("unreachable block") + if b.Instrs == nil { + // Since this block is about to be pruned, + // tolerating transient problems in it + // simplifies other optimisations. + return + } + } + + // Check predecessor and successor relations are dual. + for _, a := range b.Preds { + found := false + for _, bb := range a.Succs { + if bb == b { + found = true + break + } + } + if !found { + s.errorf("expected successor edge in predecessor %s; found only: %s", a.Name, blockNames(a.Succs)) + } + } + for _, c := range b.Succs { + found := false + for _, bb := range c.Preds { + if bb == b { + found = true + break + } + } + if !found { + s.errorf("expected predecessor edge in successor %s; found only: %s", c.Name, blockNames(c.Preds)) + } + } + + // Check each instruction is sane. + n := len(b.Instrs) + if n == 0 { + s.errorf("basic block contains no instructions") + } + for j, instr := range b.Instrs { + if b2 := instr.Block(); b2 == nil { + s.errorf("nil Block() for instruction at index %d", j) + continue + } else if b2 != b { + s.errorf("wrong Block() (%s) for instruction at index %d ", b2.Name, j) + continue + } + if j < n-1 { + s.checkInstr(j, instr) + } else { + s.checkFinalInstr(j, instr) + } + } +} + +func (s *sanity) checkFunction(fn *Function) bool { + // TODO(adonovan): check Function invariants: + // - check owning Package (if any) contains this function. + // - check params match signature + // - check locals are all !Heap + // - check transient fields are nil + // - check block labels are unique (warning) + s.fn = fn + if fn.Prog == nil { + s.errorf("nil Prog") + } + for i, b := range fn.Blocks { + if b == nil { + s.warnf("nil *BasicBlock at f.Blocks[%d]", i) + continue + } + s.checkBlock(b, i == 0) + } + s.block = nil + s.fn = nil + return !s.insane +} diff --git a/src/pkg/exp/ssa/ssa.go b/src/pkg/exp/ssa/ssa.go index 904280ae20..8e503dc35b 100644 --- a/src/pkg/exp/ssa/ssa.go +++ b/src/pkg/exp/ssa/ssa.go @@ -222,12 +222,12 @@ type Function struct { // The following fields are set transiently during building, // then cleared. - currentBlock *BasicBlock // where to emit code - objects map[types.Object]Value // addresses of local variables - results []*Alloc // tuple of named results - // syntax *funcSyntax // abstract syntax trees for Go source functions - // targets *targets // linked stack of branch targets - // lblocks map[*ast.Object]*lblock // labelled blocks + currentBlock *BasicBlock // where to emit code + objects map[types.Object]Value // addresses of local variables + results []*Alloc // tuple of named results + syntax *funcSyntax // abstract syntax trees for Go source functions + targets *targets // linked stack of branch targets + lblocks map[*ast.Object]*lblock // labelled blocks } // An SSA basic block. @@ -984,18 +984,9 @@ func (v *Capture) Name() string { return v.Outer.Name() } func (v *Global) Type() types.Type { return v.Type_ } func (v *Global) Name() string { return v.Name_ } -func (v *Global) String() string { return v.Name_ } // placeholder func (v *Function) Name() string { return v.Name_ } func (v *Function) Type() types.Type { return v.Signature } -func (v *Function) String() string { return v.Name_ } // placeholder - -// FullName returns v's package-qualified name. -func (v *Global) FullName() string { return fmt.Sprintf("%s.%s", v.Pkg.ImportPath, v.Name_) } - -func (v *Literal) Name() string { return "Literal" } // placeholder -func (v *Literal) String() string { return "Literal" } // placeholder -func (v *Literal) Type() types.Type { return v.Type_ } // placeholder func (v *Parameter) Type() types.Type { return v.Type_ } func (v *Parameter) Name() string { return v.Name_ } diff --git a/src/pkg/exp/ssa/util.go b/src/pkg/exp/ssa/util.go new file mode 100644 index 0000000000..0d2ebde268 --- /dev/null +++ b/src/pkg/exp/ssa/util.go @@ -0,0 +1,172 @@ +package ssa + +// This file defines a number of miscellaneous utility functions. + +import ( + "fmt" + "go/ast" + "go/types" +) + +func unreachable() { + panic("unreachable") +} + +//// AST utilities + +// noparens returns e with any enclosing parentheses stripped. +func noparens(e ast.Expr) ast.Expr { + for { + p, ok := e.(*ast.ParenExpr) + if !ok { + break + } + e = p.X + } + return e +} + +// isBlankIdent returns true iff e is an Ident with name "_". +// They have no associated types.Object, and thus no type. +// +// TODO(gri): consider making typechecker not treat them differently. +// It's one less thing for clients like us to worry about. +// +func isBlankIdent(e ast.Expr) bool { + id, ok := e.(*ast.Ident) + return ok && id.Name == "_" +} + +//// Type utilities. Some of these belong in go/types. + +// underlyingType returns the underlying type of typ. +// TODO(gri): this is a copy of go/types.underlying; export that function. +// +func underlyingType(typ types.Type) types.Type { + if typ, ok := typ.(*types.NamedType); ok { + return typ.Underlying // underlying types are never NamedTypes + } + if typ == nil { + panic("underlyingType(nil)") + } + return typ +} + +// isPointer returns true for types whose underlying type is a pointer. +func isPointer(typ types.Type) bool { + if nt, ok := typ.(*types.NamedType); ok { + typ = nt.Underlying + } + _, ok := typ.(*types.Pointer) + return ok +} + +// pointer(typ) returns the type that is a pointer to typ. +func pointer(typ types.Type) *types.Pointer { + return &types.Pointer{Base: typ} +} + +// indirect(typ) assumes that typ is a pointer type, +// or named alias thereof, and returns its base type. +// Panic ensures if it is not a pointer. +// +func indirectType(ptr types.Type) types.Type { + if v, ok := underlyingType(ptr).(*types.Pointer); ok { + return v.Base + } + // When debugging it is convenient to comment out this line + // and let it continue to print the (illegal) SSA form. + panic("indirect() of non-pointer type: " + ptr.String()) + return nil +} + +// deref returns a pointer's base type; otherwise it returns typ. +func deref(typ types.Type) types.Type { + if typ, ok := underlyingType(typ).(*types.Pointer); ok { + return typ.Base + } + return typ +} + +// methodIndex returns the method (and its index) named id within the +// method table methods of named or interface type typ. If not found, +// panic ensues. +// +func methodIndex(typ types.Type, methods []*types.Method, id Id) (i int, m *types.Method) { + for i, m = range methods { + if IdFromQualifiedName(m.QualifiedName) == id { + return + } + } + panic(fmt.Sprint("method not found: ", id, " in interface ", typ)) +} + +// objKind returns the syntactic category of the named entity denoted by obj. +func objKind(obj types.Object) ast.ObjKind { + switch obj.(type) { + case *types.Package: + return ast.Pkg + case *types.TypeName: + return ast.Typ + case *types.Const: + return ast.Con + case *types.Var: + return ast.Var + case *types.Func: + return ast.Fun + } + panic(fmt.Sprintf("unexpected Object type: %T", obj)) +} + +// DefaultType returns the default "typed" type for an "untyped" type; +// it returns the incoming type for all other types. If there is no +// corresponding untyped type, the result is types.Typ[types.Invalid]. +// +// Exported to exp/ssa/interp. +// +// TODO(gri): this is a copy of go/types.defaultType; export that function. +// +func DefaultType(typ types.Type) types.Type { + if t, ok := typ.(*types.Basic); ok { + k := types.Invalid + switch t.Kind { + // case UntypedNil: + // There is no default type for nil. For a good error message, + // catch this case before calling this function. + case types.UntypedBool: + k = types.Bool + case types.UntypedInt: + k = types.Int + case types.UntypedRune: + k = types.Rune + case types.UntypedFloat: + k = types.Float64 + case types.UntypedComplex: + k = types.Complex128 + case types.UntypedString: + k = types.String + } + typ = types.Typ[k] + } + return typ +} + +// makeId returns the Id (name, pkg) if the name is exported or +// (name, nil) otherwise. +// +func makeId(name string, pkg *types.Package) (id Id) { + id.Name = name + if !ast.IsExported(name) { + id.Pkg = pkg + } + return +} + +// IdFromQualifiedName returns the Id (qn.Name, qn.Pkg) if qn is an +// exported name or (qn.Name, nil) otherwise. +// +// Exported to exp/ssa/interp. +// +func IdFromQualifiedName(qn types.QualifiedName) Id { + return makeId(qn.Name, qn.Pkg) +}