1
0
mirror of https://github.com/golang/go synced 2024-10-01 09:28:37 -06:00

go.tools/go/types: handle interface types recurring via method signatures

Fixes golang/go#5090.

R=adonovan, gri, r, mtj
CC=golang-dev
https://golang.org/cl/14795044
This commit is contained in:
Robert Griesemer 2013-10-23 13:46:40 -07:00
parent c8f4d650c8
commit 93ef310aab
18 changed files with 576 additions and 150 deletions

View File

@ -48,6 +48,10 @@ var tests = [][]string{
{"testdata/errors.src"}, {"testdata/errors.src"},
{"testdata/importdecl0a.src", "testdata/importdecl0b.src"}, {"testdata/importdecl0a.src", "testdata/importdecl0b.src"},
{"testdata/cycles.src"}, {"testdata/cycles.src"},
{"testdata/cycles1.src"},
{"testdata/cycles2.src"},
{"testdata/cycles3.src"},
{"testdata/cycles4.src"},
{"testdata/decls0.src"}, {"testdata/decls0.src"},
{"testdata/decls1.src"}, {"testdata/decls1.src"},
{"testdata/decls2a.src", "testdata/decls2b.src"}, {"testdata/decls2a.src", "testdata/decls2b.src"},

View File

@ -14,7 +14,6 @@ import (
"strings" "strings"
) )
// TODO(gri) eventually assert should disappear.
func assert(p bool) { func assert(p bool) {
if !p { if !p {
panic("assertion failed") panic("assertion failed")
@ -33,7 +32,7 @@ func (check *checker) formatMsg(format string, args []interface{}) string {
case operand: case operand:
panic("internal error: should always pass *operand") panic("internal error: should always pass *operand")
case token.Pos: case token.Pos:
args[i] = check.fset.Position(a) args[i] = check.fset.Position(a).String()
case ast.Expr: case ast.Expr:
args[i] = exprString(a) args[i] = exprString(a)
} }
@ -270,6 +269,17 @@ func writeType(buf *bytes.Buffer, typ Type) {
fmt.Fprintf(buf, "<type of %s>", t.name) fmt.Fprintf(buf, "<type of %s>", t.name)
case *Interface: case *Interface:
// We write the source-level methods and embedded types rather
// than the actual method set since resolved method signatures
// may have non-printable cycles if parameters have anonymous
// interface types that (directly or indirectly) embed the
// current interface. For instance, consider the result type
// of m:
//
// type T interface{
// m() interface{ T }
// }
//
buf.WriteString("interface{") buf.WriteString("interface{")
for i, m := range t.methods { for i, m := range t.methods {
if i > 0 { if i > 0 {
@ -278,6 +288,12 @@ func writeType(buf *bytes.Buffer, typ Type) {
buf.WriteString(m.name) buf.WriteString(m.name)
writeSignature(buf, m.typ.(*Signature)) writeSignature(buf, m.typ.(*Signature))
} }
for i, typ := range t.types {
if i > 0 || len(t.methods) > 0 {
buf.WriteString("; ")
}
writeType(buf, typ)
}
buf.WriteByte('}') buf.WriteByte('}')
case *Map: case *Map:

View File

@ -58,7 +58,7 @@ func TestEvalBasic(t *testing.T) {
} }
func TestEvalComposite(t *testing.T) { func TestEvalComposite(t *testing.T) {
for _, test := range testTypes { for _, test := range independentTestTypes {
testEval(t, nil, nil, test.src, nil, test.str, "") testEval(t, nil, nil, test.src, nil, test.str, "")
} }
} }

View File

@ -16,6 +16,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"sort"
"strconv" "strconv"
"strings" "strings"
"text/scanner" "text/scanner"
@ -633,7 +634,9 @@ func (p *gcParser) parseInterfaceType() Type {
} }
p.expect('}') p.expect('}')
sort.Sort(byUniqueMethodName(methods))
typ.methods = methods typ.methods = methods
typ.allMethods = methods // ok to share underlying array since we are not changing methods
return typ return typ
} }

View File

@ -177,7 +177,8 @@ func lookupFieldOrMethod(T Type, pkg *Package, name string) (obj Object, index [
case *Interface: case *Interface:
// look for a matching method // look for a matching method
if i, m := lookupMethod(t.methods, pkg, name); m != nil { // TODO(gri) t.allMethods is sorted - use binary search
if i, m := lookupMethod(t.allMethods, pkg, name); m != nil {
assert(m.typ != nil) assert(m.typ != nil)
index = concat(e.index, i) index = concat(e.index, i)
if obj != nil || e.multiples { if obj != nil || e.multiples {
@ -252,8 +253,9 @@ func MissingMethod(V Type, T *Interface, static bool) (method *Func, wrongType b
// TODO(gri) Consider using method sets here. Might be more efficient. // TODO(gri) Consider using method sets here. Might be more efficient.
if ityp, _ := V.Underlying().(*Interface); ityp != nil { if ityp, _ := V.Underlying().(*Interface); ityp != nil {
for _, m := range T.methods { // TODO(gri) allMethods is sorted - can do this more efficiently
_, obj := lookupMethod(ityp.methods, m.pkg, m.name) for _, m := range T.allMethods {
_, obj := lookupMethod(ityp.allMethods, m.pkg, m.name)
switch { switch {
case obj == nil: case obj == nil:
if static { if static {
@ -267,7 +269,7 @@ func MissingMethod(V Type, T *Interface, static bool) (method *Func, wrongType b
} }
// A concrete type implements T if it implements all methods of T. // A concrete type implements T if it implements all methods of T.
for _, m := range T.methods { for _, m := range T.allMethods {
obj, _, indirect := lookupFieldOrMethod(V, m.pkg, m.name) obj, _, indirect := lookupFieldOrMethod(V, m.pkg, m.name)
if obj == nil { if obj == nil {
return m, false return m, false

View File

@ -168,7 +168,7 @@ func NewMethodSet(T Type) *MethodSet {
} }
case *Interface: case *Interface:
mset = mset.add(t.methods, e.index, true, e.multiples) mset = mset.add(t.allMethods, e.index, true, e.multiples)
} }
} }

View File

@ -6,6 +6,8 @@
package types package types
import "sort"
func isNamed(typ Type) bool { func isNamed(typ Type) bool {
if _, ok := typ.(*Basic); ok { if _, ok := typ.(*Basic); ok {
return ok return ok
@ -105,8 +107,22 @@ func hasNil(typ Type) bool {
return false return false
} }
// IsIdentical returns true if x and y are identical. // IsIdentical reports whether x and y are identical.
func IsIdentical(x, y Type) bool { func IsIdentical(x, y Type) bool {
return isIdenticalInternal(x, y, nil)
}
// An ifacePair is a node in a stack of interface type pairs compared for identity.
type ifacePair struct {
x, y *Interface
prev *ifacePair
}
func (p *ifacePair) identical(q *ifacePair) bool {
return p.x == q.x && p.y == q.y || p.x == q.y && p.y == q.x
}
func isIdenticalInternal(x, y Type, p *ifacePair) bool {
if x == y { if x == y {
return true return true
} }
@ -124,13 +140,13 @@ func IsIdentical(x, y Type) bool {
// Two array types are identical if they have identical element types // Two array types are identical if they have identical element types
// and the same array length. // and the same array length.
if y, ok := y.(*Array); ok { if y, ok := y.(*Array); ok {
return x.len == y.len && IsIdentical(x.elt, y.elt) return x.len == y.len && isIdenticalInternal(x.elt, y.elt, p)
} }
case *Slice: case *Slice:
// Two slice types are identical if they have identical element types. // Two slice types are identical if they have identical element types.
if y, ok := y.(*Slice); ok { if y, ok := y.(*Slice); ok {
return IsIdentical(x.elt, y.elt) return isIdenticalInternal(x.elt, y.elt, p)
} }
case *Struct: case *Struct:
@ -145,7 +161,7 @@ func IsIdentical(x, y Type) bool {
if f.anonymous != g.anonymous || if f.anonymous != g.anonymous ||
x.Tag(i) != y.Tag(i) || x.Tag(i) != y.Tag(i) ||
!f.sameId(g.pkg, g.name) || !f.sameId(g.pkg, g.name) ||
!IsIdentical(f.typ, g.typ) { !isIdenticalInternal(f.typ, g.typ, p) {
return false return false
} }
} }
@ -156,14 +172,24 @@ func IsIdentical(x, y Type) bool {
case *Pointer: case *Pointer:
// Two pointer types are identical if they have identical base types. // Two pointer types are identical if they have identical base types.
if y, ok := y.(*Pointer); ok { if y, ok := y.(*Pointer); ok {
return IsIdentical(x.base, y.base) return isIdenticalInternal(x.base, y.base, p)
} }
case *Tuple: case *Tuple:
// Two tuples types are identical if they have the same number of elements // Two tuples types are identical if they have the same number of elements
// and corresponding elements have identical types. // and corresponding elements have identical types.
if y, ok := y.(*Tuple); ok { if y, ok := y.(*Tuple); ok {
return identicalTuples(x, y) if x.Len() == y.Len() {
if x != nil {
for i, v := range x.vars {
w := y.vars[i]
if !isIdenticalInternal(v.typ, w.typ, p) {
return false
}
}
}
return true
}
} }
case *Signature: case *Signature:
@ -173,8 +199,8 @@ func IsIdentical(x, y Type) bool {
// names are not required to match. // names are not required to match.
if y, ok := y.(*Signature); ok { if y, ok := y.(*Signature); ok {
return x.isVariadic == y.isVariadic && return x.isVariadic == y.isVariadic &&
identicalTuples(x.params, y.params) && isIdenticalInternal(x.params, y.params, p) &&
identicalTuples(x.results, y.results) isIdenticalInternal(x.results, y.results, p)
} }
case *Interface: case *Interface:
@ -182,20 +208,63 @@ func IsIdentical(x, y Type) bool {
// the same names and identical function types. Lower-case method names from // the same names and identical function types. Lower-case method names from
// different packages are always different. The order of the methods is irrelevant. // different packages are always different. The order of the methods is irrelevant.
if y, ok := y.(*Interface); ok { if y, ok := y.(*Interface); ok {
return identicalMethods(x.methods, y.methods) a := x.allMethods
b := y.allMethods
if len(a) == len(b) {
// Interface types are the only types where cycles can occur
// that are not "terminated" via named types; and such cycles
// can only be created via method parameter types that are
// anonymous interfaces (directly or indirectly) embedding
// the current interface. Example:
//
// type T interface {
// m() interface{T}
// }
//
// If two such (differently named) interfaces are compared,
// endless recursion occurs if the cycle is not detected.
//
// If x and y were compared before, they must be equal
// (if they were not, the recursion would have stopped);
// search the ifacePair stack for the same pair.
//
// This is a quadratic algorithm, but in practice these stacks
// are extremely short (bounded by the nesting depth of interface
// type declarations that recur via parameter types, an extremely
// rare occurrence). An alternative implementation might use a
// "visited" map, but that is probably less efficient overall.
q := &ifacePair{x, y, p}
for p != nil {
if p.identical(q) {
return true // same pair was compared before
}
p = p.prev
}
if debug {
assert(sort.IsSorted(byUniqueMethodName(a)))
assert(sort.IsSorted(byUniqueMethodName(b)))
}
for i, f := range a {
g := b[i]
if f.Id() != g.Id() || !isIdenticalInternal(f.typ, g.typ, q) {
return false
}
}
return true
}
} }
case *Map: case *Map:
// Two map types are identical if they have identical key and value types. // Two map types are identical if they have identical key and value types.
if y, ok := y.(*Map); ok { if y, ok := y.(*Map); ok {
return IsIdentical(x.key, y.key) && IsIdentical(x.elt, y.elt) return isIdenticalInternal(x.key, y.key, p) && isIdenticalInternal(x.elt, y.elt, p)
} }
case *Chan: case *Chan:
// Two channel types are identical if they have identical value types // Two channel types are identical if they have identical value types
// and the same direction. // and the same direction.
if y, ok := y.(*Chan); ok { if y, ok := y.(*Chan); ok {
return x.dir == y.dir && IsIdentical(x.elt, y.elt) return x.dir == y.dir && isIdenticalInternal(x.elt, y.elt, p)
} }
case *Named: case *Named:
@ -204,53 +273,14 @@ func IsIdentical(x, y Type) bool {
if y, ok := y.(*Named); ok { if y, ok := y.(*Named); ok {
return x.obj == y.obj return x.obj == y.obj
} }
default:
unreachable()
} }
return false return false
} }
// identicalTuples returns true if both tuples a and b have the
// same length and corresponding elements have identical types.
func identicalTuples(a, b *Tuple) bool {
if a.Len() != b.Len() {
return false
}
if a != nil {
for i, x := range a.vars {
y := b.vars[i]
if !IsIdentical(x.typ, y.typ) {
return false
}
}
}
return true
}
// identicalMethods returns true if both slices a and b have the
// same length and corresponding entries have identical types.
// TODO(gri) make this more efficient (e.g., sort them on completion)
func identicalMethods(a, b []*Func) bool {
if len(a) != len(b) {
return false
}
m := make(map[string]*Func)
for _, x := range a {
key := x.Id()
assert(m[key] == nil) // method list must not have duplicate entries
m[key] = x
}
for _, y := range b {
key := y.Id()
if x := m[key]; x == nil || !IsIdentical(x.typ, y.typ) {
return false
}
}
return true
}
// defaultType returns the default "typed" type for an "untyped" type; // defaultType returns the default "typed" type for an "untyped" type;
// it returns the incoming type for all other types. The default type // it returns the incoming type for all other types. The default type
// for untyped nil is untyped nil. // for untyped nil is untyped nil.

View File

@ -23,7 +23,7 @@ func (check *checker) reportAltDecl(obj Object) {
} }
} }
func (check *checker) declareObj(scope *Scope, id *ast.Ident, obj Object) { func (check *checker) declare(scope *Scope, id *ast.Ident, obj Object) {
if alt := scope.Insert(obj); alt != nil { if alt := scope.Insert(obj); alt != nil {
check.errorf(obj.Pos(), "%s redeclared in this block", obj.Name()) check.errorf(obj.Pos(), "%s redeclared in this block", obj.Name())
check.reportAltDecl(alt) check.reportAltDecl(alt)
@ -34,17 +34,6 @@ func (check *checker) declareObj(scope *Scope, id *ast.Ident, obj Object) {
} }
} }
func (check *checker) declareFld(oset *objset, id *ast.Ident, obj Object) {
if alt := oset.insert(obj); alt != nil {
check.errorf(obj.Pos(), "%s redeclared", obj.Name())
check.reportAltDecl(alt)
return
}
if id != nil {
check.recordObject(id, obj)
}
}
// A declInfo describes a package-level const, type, var, or func declaration. // A declInfo describes a package-level const, type, var, or func declaration.
type declInfo struct { type declInfo struct {
file *Scope // scope of file containing this declaration file *Scope // scope of file containing this declaration
@ -135,7 +124,7 @@ func (check *checker) resolveFiles(files []*ast.File) {
return return
} }
check.declareObj(pkg.scope, ident, obj) check.declare(pkg.scope, ident, obj)
objList = append(objList, obj) objList = append(objList, obj)
objMap[obj] = declInfo{fileScope, typ, init, nil} objMap[obj] = declInfo{fileScope, typ, init, nil}
} }
@ -227,7 +216,7 @@ func (check *checker) resolveFiles(files []*ast.File) {
if obj.IsExported() { if obj.IsExported() {
// Note: This will change each imported object's scope! // Note: This will change each imported object's scope!
// May be an issue for type aliases. // May be an issue for type aliases.
check.declareObj(fileScope, nil, obj) check.declare(fileScope, nil, obj)
check.recordImplicit(s, obj) check.recordImplicit(s, obj)
} }
} }
@ -241,7 +230,7 @@ func (check *checker) resolveFiles(files []*ast.File) {
posSet[imp] = s.Pos() posSet[imp] = s.Pos()
} else { } else {
// declare imported package object in file scope // declare imported package object in file scope
check.declareObj(fileScope, nil, obj) check.declare(fileScope, nil, obj)
} }
case *ast.ValueSpec: case *ast.ValueSpec:
@ -326,7 +315,7 @@ func (check *checker) resolveFiles(files []*ast.File) {
// ok to continue // ok to continue
} }
} else { } else {
check.declareObj(pkg.scope, d.Name, obj) check.declare(pkg.scope, d.Name, obj)
} }
} else { } else {
// Associate method with receiver base type name, if possible. // Associate method with receiver base type name, if possible.
@ -603,7 +592,7 @@ func (check *checker) typeDecl(obj *TypeName, typ ast.Expr, def *Named, cycleOk
// C A // C A
// ) // )
// //
// When we declare obj = C, typ is the identifier A which is incomplete. // When we declare object C, typ is the identifier A which is incomplete.
u := check.typ(typ, named, cycleOk) u := check.typ(typ, named, cycleOk)
// Determine the unnamed underlying type. // Determine the unnamed underlying type.
@ -723,7 +712,7 @@ func (check *checker) declStmt(decl ast.Decl) {
check.arityMatch(s, last) check.arityMatch(s, last)
for i, name := range s.Names { for i, name := range s.Names {
check.declareObj(check.topScope, name, lhs[i]) check.declare(check.topScope, name, lhs[i])
} }
case token.VAR: case token.VAR:
@ -756,7 +745,7 @@ func (check *checker) declStmt(decl ast.Decl) {
check.arityMatch(s, nil) check.arityMatch(s, nil)
for i, name := range s.Names { for i, name := range s.Names {
check.declareObj(check.topScope, name, lhs[i]) check.declare(check.topScope, name, lhs[i])
} }
default: default:
@ -765,7 +754,7 @@ func (check *checker) declStmt(decl ast.Decl) {
case *ast.TypeSpec: case *ast.TypeSpec:
obj := NewTypeName(s.Name.Pos(), pkg, s.Name.Name, nil) obj := NewTypeName(s.Name.Pos(), pkg, s.Name.Name, nil)
check.declareObj(check.topScope, s.Name, obj) check.declare(check.topScope, s.Name, obj)
check.typeDecl(obj, s.Type, nil, false) check.typeDecl(obj, s.Type, nil, false)
default: default:

View File

@ -131,7 +131,6 @@ func TestStdfixed(t *testing.T) {
"bug200.go", // TODO(gri) complete duplicate checking in expr switches "bug200.go", // TODO(gri) complete duplicate checking in expr switches
"bug223.go", "bug413.go", "bug459.go", // TODO(gri) complete initialization checks "bug223.go", "bug413.go", "bug459.go", // TODO(gri) complete initialization checks
"bug248.go", "bug302.go", "bug369.go", // complex test instructions - ignore "bug248.go", "bug302.go", "bug369.go", // complex test instructions - ignore
"bug250.go", // TODO(gri) fix recursive interfaces
"issue3924.go", // TODO(gri) && and || produce bool result (not untyped bool) "issue3924.go", // TODO(gri) && and || produce bool result (not untyped bool)
"issue4847.go", // TODO(gri) initialization cycle error not found "issue4847.go", // TODO(gri) initialization cycle error not found
) )

View File

@ -441,7 +441,7 @@ func (check *checker) stmt(ctxt stmtContext, s ast.Stmt) {
// one; i.e., if any one of them is 'used', all of them are 'used'. // one; i.e., if any one of them is 'used', all of them are 'used'.
// Collect them for later analysis. // Collect them for later analysis.
lhsVars = append(lhsVars, obj) lhsVars = append(lhsVars, obj)
check.declareObj(check.topScope, nil, obj) check.declare(check.topScope, nil, obj)
check.recordImplicit(clause, obj) check.recordImplicit(clause, obj)
} }
check.stmtList(inner, clause.Body) check.stmtList(inner, clause.Body)
@ -592,7 +592,7 @@ func (check *checker) stmt(ctxt stmtContext, s ast.Stmt) {
// declare variables // declare variables
for i, ident := range idents { for i, ident := range idents {
check.declareObj(check.topScope, ident, vars[i]) check.declare(check.topScope, ident, vars[i])
} }
} else { } else {
// ordinary assignment // ordinary assignment

77
go/types/testdata/cycles1.src vendored Normal file
View File

@ -0,0 +1,77 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package p
type (
A interface {
a() interface {
ABC1
}
}
B interface {
b() interface {
ABC2
}
}
C interface {
c() interface {
ABC3
}
}
AB interface {
A
B
}
BC interface {
B
C
}
ABC1 interface {
A
B
C
}
ABC2 interface {
AB
C
}
ABC3 interface {
A
BC
}
)
var (
x1 ABC1
x2 ABC2
x3 ABC3
)
func _() {
// all types have the same method set
x1 = x2
x2 = x1
x1 = x3
x3 = x1
x2 = x3
x3 = x2
// all methods return the same type again
x1 = x1.a()
x1 = x1.b()
x1 = x1.c()
x2 = x2.a()
x2 = x2.b()
x2 = x2.c()
x3 = x3.a()
x3 = x3.b()
x3 = x3.c()
}

66
go/types/testdata/cycles2.src vendored Normal file
View File

@ -0,0 +1,66 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package p
// Test case for issue 5090
type t interface {
f(u)
}
type u interface {
t
}
func _() {
var t t
var u u
t.f(t)
t.f(u)
u.f(t)
u.f(u)
}
// Test case for issue 6589.
type A interface {
a() interface {
AB
}
}
type B interface {
a() interface {
AB
}
}
type AB interface {
a() interface {
A
B /* ERROR a redeclared */
}
b() interface {
A
B /* ERROR a redeclared */
}
}
var x AB
var y interface {
A
B /* ERROR a redeclared */
}
var _ = x /* ERROR cannot compare */ == y
// Test case for issue 6638.
type T /* ERROR cycle */ interface {
m() [T /* ERROR no field or method */ (nil).m()[0]]int
}

60
go/types/testdata/cycles3.src vendored Normal file
View File

@ -0,0 +1,60 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package p
import "unsafe"
var (
_ A = A(nil).a().b().c().d().e().f()
_ A = A(nil).b().c().d().e().f()
_ A = A(nil).c().d().e().f()
_ A = A(nil).d().e().f()
_ A = A(nil).e().f()
_ A = A(nil).f()
_ A = A(nil)
)
type (
A interface {
a() B
B
}
B interface {
b() C
C
}
C interface {
c() D
D
}
D interface {
d() E
E
}
E interface {
e() F
F
}
F interface {
f() A
}
)
type (
U /* ERROR illegal cycle */ interface {
V
}
V interface {
v() [unsafe.Sizeof(u)]int
}
)
var u U

68
go/types/testdata/cycles4.src vendored Normal file
View File

@ -0,0 +1,68 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package p
// Check that all methods of T are collected before
// determining the result type of m (which embeds
// all methods of T).
type T interface {
m() interface {T}
E
}
var _ = T.m(nil).m().e()
type E interface {
e() int
}
// Check that unresolved forward chains are followed
// (see also comment in resolver.go, checker.typeDecl).
var _ = C.m(nil).m().e()
type A B
type B interface {
m() interface{C}
E
}
type C A
// Check that interface type comparison for identity
// does not recur endlessly.
type T1 interface {
m() interface{T1}
}
type T2 interface {
m() interface{T2}
}
func _(x T1, y T2) {
// Checking for assignability of interfaces must check
// if all methods of x are present in y, and that they
// have identical signatures. The signatures recur via
// the result type, which is an interface that embeds
// a single method m that refers to the very interface
// that contains it. This requires cycle detection in
// identity checks for interface types.
x = y
}
type T3 interface {
m() interface{T4}
}
type T4 interface {
m() interface{T3}
}
func _(x T1, y T3) {
x = y
}

View File

@ -4,7 +4,10 @@
package types package types
import "go/ast" import (
"go/ast"
"sort"
)
// TODO(gri) Revisit factory functions - make sure they have all relevant parameters. // TODO(gri) Revisit factory functions - make sure they have all relevant parameters.
@ -238,20 +241,25 @@ func (s *Signature) IsVariadic() bool { return s.isVariadic }
// An Interface represents an interface type. // An Interface represents an interface type.
type Interface struct { type Interface struct {
methods []*Func // methods declared with or embedded in this interface methods []*Func // explicitly declared methods
mset cachedMethodSet // method set for interface, lazily initialized types []*Named // explicitly embedded types
allMethods []*Func // ordered list of methods declared with or embedded in this interface (TODO(gri): replace with mset)
mset cachedMethodSet // method set for interface, lazily initialized
} }
// NewInterface returns a new interface for the given methods. // NewInterface returns a new interface for the given methods.
func NewInterface(methods []*Func) *Interface { func NewInterface(methods []*Func) *Interface {
return &Interface{methods: methods} // TODO(gri) should provide receiver to all methods
sort.Sort(byUniqueMethodName(methods))
return &Interface{methods: methods, allMethods: methods}
} }
// NumMethods returns the number of methods of interface t. // NumMethods returns the number of methods of interface t.
func (t *Interface) NumMethods() int { return len(t.methods) } func (t *Interface) NumMethods() int { return len(t.allMethods) }
// Method returns the i'th method of interface t for 0 <= i < t.NumMethods(). // Method returns the i'th method of interface t for 0 <= i < t.NumMethods().
func (t *Interface) Method(i int) *Func { return t.methods[i] } func (t *Interface) Method(i int) *Func { return t.allMethods[i] }
// A Map represents a map type. // A Map represents a map type.
type Map struct { type Map struct {

View File

@ -33,7 +33,8 @@ func dup(s string) testEntry {
return testEntry{s, s} return testEntry{s, s}
} }
var testTypes = []testEntry{ // types that don't depend on any other type declarations
var independentTestTypes = []testEntry{
// basic types // basic types
dup("int"), dup("int"),
dup("float32"), dup("float32"),
@ -89,8 +90,7 @@ var testTypes = []testEntry{
// interfaces // interfaces
dup("interface{}"), dup("interface{}"),
dup("interface{m()}"), dup("interface{m()}"),
dup(`interface{m(int) float32; String() string}`), dup(`interface{String() string; m(int) float32}`),
// TODO(gri) add test for interface w/ anonymous field
// maps // maps
dup("map[string]int"), dup("map[string]int"),
@ -102,9 +102,21 @@ var testTypes = []testEntry{
dup("<-chan []func() int"), dup("<-chan []func() int"),
} }
// types that depend on other type declarations (src in TestTypes)
var dependentTestTypes = []testEntry{
// interfaces
dup(`interface{io.Reader; io.Writer}`),
dup(`interface{m() int; io.Writer}`),
{`interface{m() interface{T}}`, `interface{m() interface{p.T}}`},
}
func TestTypes(t *testing.T) { func TestTypes(t *testing.T) {
for _, test := range testTypes { var tests []testEntry
src := "package p; type T " + test.src tests = append(tests, independentTestTypes...)
tests = append(tests, dependentTestTypes...)
for _, test := range tests {
src := `package p; import "io"; type _ io.Writer; type T ` + test.src
pkg, err := makePkg(t, src) pkg, err := makePkg(t, src)
if err != nil { if err != nil {
t.Errorf("%s: %s", src, err) t.Errorf("%s: %s", src, err)

View File

@ -9,6 +9,7 @@ package types
import ( import (
"go/ast" "go/ast"
"go/token" "go/token"
"sort"
"strconv" "strconv"
"code.google.com/p/go.tools/go/exact" "code.google.com/p/go.tools/go/exact"
@ -67,8 +68,9 @@ func (check *checker) ident(x *operand, e *ast.Ident, def *Named, cycleOk bool)
} }
x.val = check.iota x.val = check.iota
} else { } else {
x.val = obj.val // may be nil if we don't know the constant value x.val = obj.val
} }
assert(x.val != nil)
x.mode = constant x.mode = constant
case *TypeName: case *TypeName:
@ -166,6 +168,7 @@ func (check *checker) funcType(recv *ast.FieldList, ftyp *ast.FuncType, def *Nam
if T.obj.pkg != check.pkg { if T.obj.pkg != check.pkg {
err = "type not defined in this package" err = "type not defined in this package"
} else { } else {
// TODO(gri) This is not correct if the underlying type is unknown yet.
switch u := T.underlying.(type) { switch u := T.underlying.(type) {
case *Basic: case *Basic:
// unsafe.Pointer is treated like a regular pointer // unsafe.Pointer is treated like a regular pointer
@ -297,15 +300,7 @@ func (check *checker) typInternal(e ast.Expr, def *Named, cycleOk bool) Type {
return check.funcType(nil, e, def) return check.funcType(nil, e, def)
case *ast.InterfaceType: case *ast.InterfaceType:
typ := new(Interface) return check.interfaceType(e, def, cycleOk)
var recv Type = typ
if def != nil {
def.underlying = typ
recv = def // use named receiver type if available
}
typ.methods = check.collectMethods(recv, e.Methods, cycleOk)
return typ
case *ast.MapType: case *ast.MapType:
typ := new(Map) typ := new(Map)
@ -319,6 +314,7 @@ func (check *checker) typInternal(e ast.Expr, def *Named, cycleOk bool) Type {
// spec: "The comparison operators == and != must be fully defined // spec: "The comparison operators == and != must be fully defined
// for operands of the key type; thus the key type must not be a // for operands of the key type; thus the key type must not be a
// function, map, or slice." // function, map, or slice."
// TODO(gri) if the key type is not fully defined yet, this test will be incorrect
if !isComparable(typ.key) { if !isComparable(typ.key) {
check.errorf(e.Key.Pos(), "invalid map key type %s", typ.key) check.errorf(e.Key.Pos(), "invalid map key type %s", typ.key)
// ok to continue // ok to continue
@ -385,13 +381,13 @@ func (check *checker) collectParams(scope *Scope, list *ast.FieldList, variadicO
} }
} }
typ := check.typ(ftype, nil, true) typ := check.typ(ftype, nil, true)
// the parser ensures that f.Tag is nil and we don't // The parser ensures that f.Tag is nil and we don't
// care if a constructed AST contains a non-nil tag // care if a constructed AST contains a non-nil tag.
if len(field.Names) > 0 { if len(field.Names) > 0 {
// named parameter // named parameter
for _, name := range field.Names { for _, name := range field.Names {
par := NewParam(name.Pos(), check.pkg, name.Name, typ) par := NewParam(name.Pos(), check.pkg, name.Name, typ)
check.declareObj(scope, name, par) check.declare(scope, name, par)
params = append(params, par) params = append(params, par)
} }
} else { } else {
@ -411,57 +407,152 @@ func (check *checker) collectParams(scope *Scope, list *ast.FieldList, variadicO
return return
} }
func (check *checker) collectMethods(recv Type, list *ast.FieldList, cycleOk bool) (methods []*Func) { func (check *checker) declareInSet(oset *objset, pos token.Pos, id *ast.Ident, obj Object) bool {
if list == nil { if alt := oset.insert(obj); alt != nil {
return nil check.errorf(pos, "%s redeclared", obj.Name())
check.reportAltDecl(alt)
return false
}
if id != nil {
check.recordObject(id, obj)
}
return true
}
func (check *checker) interfaceType(ityp *ast.InterfaceType, def *Named, cycleOk bool) *Interface {
iface := new(Interface)
if def != nil {
def.underlying = iface
} }
var mset objset // empty interface: common case
if ityp.Methods == nil {
return iface
}
for _, f := range list.List { // The parser ensures that field tags are nil and we don't
typ := check.typ(f.Type, nil, cycleOk) // care if a constructed AST contains non-nil tags.
// the parser ensures that f.Tag is nil and we don't
// care if a constructed AST contains a non-nil tag // Phase 1: Collect explicitly declared methods, the corresponding
// signature (AST) expressions, and the list of embedded
// type (AST) expressions. Do not resolve signatures or
// embedded types yet to avoid cycles referring to this
// interface.
var (
mset objset
signatures []ast.Expr // list of corresponding method signatures
embedded []ast.Expr // list of embedded types
)
for _, f := range ityp.Methods.List {
if len(f.Names) > 0 { if len(f.Names) > 0 {
// methods (the parser ensures that there's only one // The parser ensures that there's only one method
// and we don't care if a constructed AST has more) // and we don't care if a constructed AST has more.
sig, _ := typ.(*Signature) name := f.Names[0]
if sig == nil { pos := name.Pos()
check.invalidAST(f.Type.Pos(), "%s is not a method signature", typ) // Don't type-check signature yet - use an
continue // empty signature now and update it later.
} m := NewFunc(pos, check.pkg, name.Name, new(Signature))
sig.recv = NewVar(token.NoPos, check.pkg, "", recv) if check.declareInSet(&mset, pos, name, m) {
for _, name := range f.Names { iface.methods = append(iface.methods, m)
m := NewFunc(name.Pos(), check.pkg, name.Name, sig) iface.allMethods = append(iface.allMethods, m)
check.declareFld(&mset, name, m) signatures = append(signatures, f.Type)
methods = append(methods, m)
} }
} else { } else {
// embedded interface // embedded type
switch t := typ.Underlying().(type) { embedded = append(embedded, f.Type)
case nil: }
// The underlying type is in the process of being defined }
// but we need it in order to complete this type. For now
// complain with an "unimplemented" error. This requires // Phase 2: Resolve embedded interfaces. Because an interface must not
// a bit more work. // embed itself (directly or indirectly), each embedded interface
// TODO(gri) finish this. // can be fully resolved without depending on any method of this
check.errorf(f.Type.Pos(), "reference to incomplete type %s - unimplemented", f.Type) // interface (if there is a cycle or another error, the embedded
case *Interface: // type resolves to an invalid type and is ignored).
for _, m := range t.methods { // In particular, the list of methods for each embedded interface
check.declareFld(&mset, nil, m) // must be complete (it cannot depend on this interface), and so
methods = append(methods, m) // those methods can be added to the list of all methods of this
} // interface.
default:
if t != Typ[Invalid] { for _, e := range embedded {
check.errorf(f.Type.Pos(), "%s is not an interface type", typ) pos := e.Pos()
} typ := check.typ(e, nil, cycleOk)
if typ == Typ[Invalid] {
continue
}
named, _ := typ.(*Named)
if named == nil {
check.invalidAST(pos, "%s is not named type", typ)
continue
}
// determine underlying (possibly incomplete) type
// by following its forward chain
// TODO(gri) should this be part of Underlying()?
u := named.underlying
for {
n, _ := u.(*Named)
if n == nil {
break
}
u = n.underlying
}
if u == Typ[Invalid] {
continue
}
embed, _ := u.(*Interface)
if embed == nil {
check.errorf(pos, "%s is not an interface", named)
continue
}
iface.types = append(iface.types, named)
// collect embedded methods
for _, m := range embed.allMethods {
if check.declareInSet(&mset, pos, nil, m) {
iface.allMethods = append(iface.allMethods, m)
} }
} }
} }
return // Phase 3: At this point all methods have been collected for this interface.
// It is now safe to type-check the signatures of all explicitly
// declared methods, even if they refer to this interface via a cycle
// and embed the methods of this interface in a parameter of interface
// type.
// determine receiver type
var recv Type = iface
if def != nil {
def.underlying = iface
recv = def // use named receiver type if available
}
for i, m := range iface.methods {
expr := signatures[i]
typ := check.typ(expr, nil, true)
if typ == Typ[Invalid] {
continue // keep method with empty method signature
}
sig, _ := typ.(*Signature)
if sig == nil {
check.invalidAST(expr.Pos(), "%s is not a method signature", typ)
continue // keep method with empty method signature
}
sig.recv = NewVar(m.pos, check.pkg, "", recv)
*m.typ.(*Signature) = *sig // update signature (don't replace it!)
}
sort.Sort(byUniqueMethodName(iface.allMethods))
return iface
} }
// byUniqueMethodName method lists can be sorted by their unique method names.
type byUniqueMethodName []*Func
func (a byUniqueMethodName) Len() int { return len(a) }
func (a byUniqueMethodName) Less(i, j int) bool { return a[i].Id() < a[j].Id() }
func (a byUniqueMethodName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (check *checker) tag(t *ast.BasicLit) string { func (check *checker) tag(t *ast.BasicLit) string {
if t != nil { if t != nil {
if t.Kind == token.STRING { if t.Kind == token.STRING {
@ -492,8 +583,9 @@ func (check *checker) collectFields(list *ast.FieldList, cycleOk bool) (fields [
} }
fld := NewField(pos, check.pkg, name, typ, anonymous) fld := NewField(pos, check.pkg, name, typ, anonymous)
check.declareFld(&fset, ident, fld) if check.declareInSet(&fset, pos, ident, fld) {
fields = append(fields, fld) fields = append(fields, fld)
}
} }
for _, f := range list.List { for _, f := range list.List {

View File

@ -67,7 +67,7 @@ func defPredeclaredTypes() {
res := NewVar(token.NoPos, nil, "", Typ[String]) res := NewVar(token.NoPos, nil, "", Typ[String])
sig := &Signature{results: NewTuple(res)} sig := &Signature{results: NewTuple(res)}
err := NewFunc(token.NoPos, nil, "Error", sig) err := NewFunc(token.NoPos, nil, "Error", sig)
typ := &Named{underlying: &Interface{methods: []*Func{err}}, complete: true} typ := &Named{underlying: NewInterface([]*Func{err}), complete: true}
sig.recv = NewVar(token.NoPos, nil, "", typ) sig.recv = NewVar(token.NoPos, nil, "", typ)
def(NewTypeName(token.NoPos, nil, "error", typ)) def(NewTypeName(token.NoPos, nil, "error", typ))
} }