diff --git a/src/go/types/api_test.go b/src/go/types/api_test.go index 3e10be59852..c9127f366aa 100644 --- a/src/go/types/api_test.go +++ b/src/go/types/api_test.go @@ -28,8 +28,12 @@ import ( // If source begins with "package generic_" and type parameters are enabled, // generic code is permitted. func pkgFor(path, source string, info *Info) (*Package, error) { - fset := token.NewFileSet() mode := modeForSource(source) + return pkgForMode(path, source, info, mode) +} + +func pkgForMode(path, source string, info *Info, mode parser.Mode) (*Package, error) { + fset := token.NewFileSet() f, err := parser.ParseFile(fset, path, source, mode) if err != nil { return nil, err @@ -1938,8 +1942,8 @@ func f(x T) T { return foo.F(x) } func TestInstantiate(t *testing.T) { // eventually we like more tests but this is a start - const src = genericPkg + "p; type T[P any] *T[P]" - pkg, err := pkgFor(".", src, nil) + const src = "package p; type T[P any] *T[P]" + pkg, err := pkgForMode(".", src, nil, 0) if err != nil { t.Fatal(err) } @@ -1976,8 +1980,8 @@ func TestInstantiateErrors(t *testing.T) { } for _, test := range tests { - src := genericPkg + "p; " + test.src - pkg, err := pkgFor(".", src, nil) + src := "package p; " + test.src + pkg, err := pkgForMode(".", src, nil, 0) if err != nil { t.Fatal(err) } diff --git a/src/go/types/context.go b/src/go/types/context.go index 1f102f0b8b7..ff4bf89f3cb 100644 --- a/src/go/types/context.go +++ b/src/go/types/context.go @@ -7,6 +7,7 @@ package types import ( "bytes" "fmt" + "strconv" "strings" "sync" ) @@ -17,10 +18,10 @@ import ( // // It is safe for concurrent use. type Context struct { - mu sync.Mutex - typeMap map[string][]ctxtEntry // type hash -> instances entries - nextID int // next unique ID - seen map[*Named]int // assigned unique IDs + mu sync.Mutex + typeMap map[string][]ctxtEntry // type hash -> instances entries + nextID int // next unique ID + originIDs map[Type]int // origin type -> unique ID } type ctxtEntry struct { @@ -32,23 +33,25 @@ type ctxtEntry struct { // NewContext creates a new Context. func NewContext() *Context { return &Context{ - typeMap: make(map[string][]ctxtEntry), - seen: make(map[*Named]int), + typeMap: make(map[string][]ctxtEntry), + originIDs: make(map[Type]int), } } -// typeHash returns a string representation of typ instantiated with targs, -// which can be used as an exact type hash: types that are identical produce -// identical string representations. If targs is not empty, typ is printed as -// if it were instantiated with targs. The result is guaranteed to not contain -// blanks (" "). -func (ctxt *Context) typeHash(typ Type, targs []Type) string { +// instanceHash returns a string representation of typ instantiated with targs. +// The hash should be a perfect hash, though out of caution the type checker +// does not assume this. The result is guaranteed to not contain blanks. +func (ctxt *Context) instanceHash(orig Type, targs []Type) string { assert(ctxt != nil) - assert(typ != nil) + assert(orig != nil) var buf bytes.Buffer h := newTypeHasher(&buf, ctxt) - h.typ(typ) + h.string(strconv.Itoa(ctxt.getID(orig))) + // Because we've already written the unique origin ID this call to h.typ is + // unnecessary, but we leave it for hash readability. It can be removed later + // if performance is an issue. + h.typ(orig) if len(targs) > 0 { // TODO(rfindley): consider asserting on isGeneric(typ) here, if and when // isGeneric handles *Signature types. @@ -106,14 +109,14 @@ func (ctxt *Context) update(h string, orig Type, targs []Type, inst Type) Type { return inst } -// idForType returns a unique ID for the pointer n. -func (ctxt *Context) idForType(n *Named) int { +// getID returns a unique ID for the type t. +func (ctxt *Context) getID(t Type) int { ctxt.mu.Lock() defer ctxt.mu.Unlock() - id, ok := ctxt.seen[n] + id, ok := ctxt.originIDs[t] if !ok { id = ctxt.nextID - ctxt.seen[n] = id + ctxt.originIDs[t] = id ctxt.nextID++ } return id diff --git a/src/go/types/instantiate.go b/src/go/types/instantiate.go index 62d9e184018..ec646e1a5cb 100644 --- a/src/go/types/instantiate.go +++ b/src/go/types/instantiate.go @@ -52,30 +52,26 @@ func Instantiate(ctxt *Context, typ Type, targs []Type, validate bool) (Type, er // instance creates a type or function instance using the given original type // typ and arguments targs. For Named types the resulting instance will be // unexpanded. -func (check *Checker) instance(pos token.Pos, orig Type, targs []Type, ctxt *Context) Type { +func (check *Checker) instance(pos token.Pos, orig Type, targs []Type, ctxt *Context) (res Type) { + var h string + if ctxt != nil { + h = ctxt.instanceHash(orig, targs) + // typ may already have been instantiated with identical type arguments. In + // that case, re-use the existing instance. + if inst := ctxt.lookup(h, orig, targs); inst != nil { + return inst + } + } + switch orig := orig.(type) { case *Named: - var h string - if ctxt != nil { - h = ctxt.typeHash(orig, targs) - // typ may already have been instantiated with identical type arguments. In - // that case, re-use the existing instance. - if inst := ctxt.lookup(h, orig, targs); inst != nil { - return inst - } - } tname := NewTypeName(pos, orig.obj.pkg, orig.obj.name, nil) named := check.newNamed(tname, orig, nil, nil, nil) // underlying, tparams, and methods are set when named is resolved named.targs = NewTypeList(targs) named.resolver = func(ctxt *Context, n *Named) (*TypeParamList, Type, []*Func) { return expandNamed(ctxt, n, pos) } - if ctxt != nil { - // It's possible that we've lost a race to add named to the context. - // In this case, use whichever instance is recorded in the context. - named = ctxt.update(h, orig, targs, named).(*Named) - } - return named + res = named case *Signature: tparams := orig.TypeParams() @@ -96,10 +92,19 @@ func (check *Checker) instance(pos token.Pos, orig Type, targs []Type, ctxt *Con // After instantiating a generic signature, it is not generic // anymore; we need to set tparams to nil. sig.tparams = nil - return sig + res = sig + default: + // only types and functions can be generic + panic(fmt.Sprintf("%v: cannot instantiate %v", pos, orig)) } - // only types and functions can be generic - panic(fmt.Sprintf("%v: cannot instantiate %v", pos, orig)) + + if ctxt != nil { + // It's possible that we've lost a race to add named to the context. + // In this case, use whichever instance is recorded in the context. + res = ctxt.update(h, orig, targs, res) + } + + return res } // validateTArgLen verifies that the length of targs and tparams matches, diff --git a/src/go/types/instantiate_test.go b/src/go/types/instantiate_test.go index 832c8222242..a4ed581e35b 100644 --- a/src/go/types/instantiate_test.go +++ b/src/go/types/instantiate_test.go @@ -11,40 +11,106 @@ import ( ) func TestInstantiateEquality(t *testing.T) { - const src = genericPkg + "p; type T[P any] int" - - pkg, err := pkgFor(".", src, nil) - if err != nil { - t.Fatal(err) + tests := []struct { + src string + name1 string + targs1 []Type + name2 string + targs2 []Type + wantEqual bool + }{ + { + "package basictype; type T[P any] int", + "T", []Type{Typ[Int]}, + "T", []Type{Typ[Int]}, + true, + }, + { + "package differenttypeargs; type T[P any] int", + "T", []Type{Typ[Int]}, + "T", []Type{Typ[String]}, + false, + }, + { + "package typeslice; type T[P any] int", + "T", []Type{NewSlice(Typ[Int])}, + "T", []Type{NewSlice(Typ[Int])}, + true, + }, + { + "package basicfunc; func F[P any]() {}", + "F", []Type{Typ[Int]}, + "F", []Type{Typ[Int]}, + true, + }, + { + "package funcslice; func F[P any]() {}", + "F", []Type{NewSlice(Typ[Int])}, + "F", []Type{NewSlice(Typ[Int])}, + true, + }, + { + "package funcwithparams; func F[P any](x string) float64 { return 0 }", + "F", []Type{Typ[Int]}, + "F", []Type{Typ[Int]}, + true, + }, + { + "package differentfuncargs; func F[P any](x string) float64 { return 0 }", + "F", []Type{Typ[Int]}, + "F", []Type{Typ[String]}, + false, + }, + { + "package funcequality; func F1[P any](x int) {}; func F2[Q any](x int) {}", + "F1", []Type{Typ[Int]}, + "F2", []Type{Typ[Int]}, + false, + }, + { + "package funcsymmetry; func F1[P any](x P) {}; func F2[Q any](x Q) {}", + "F1", []Type{Typ[Int]}, + "F2", []Type{Typ[Int]}, + false, + }, } - T := pkg.Scope().Lookup("T").Type().(*Named) + for _, test := range tests { + pkg, err := pkgForMode(".", test.src, nil, 0) + if err != nil { + t.Fatal(err) + } - // Instantiating the same type twice should result in pointer-equivalent - // instances. - ctxt := NewContext() - res1, err := Instantiate(ctxt, T, []Type{Typ[Int]}, false) - if err != nil { - t.Fatal(err) - } - res2, err := Instantiate(ctxt, T, []Type{Typ[Int]}, false) - if err != nil { - t.Fatal(err) - } + t.Run(pkg.Name(), func(t *testing.T) { + ctxt := NewContext() - if res1 != res2 { - t.Errorf("first instance (%s) not pointer-equivalent to second instance (%s)", res1, res2) + T1 := pkg.Scope().Lookup(test.name1).Type() + res1, err := Instantiate(ctxt, T1, test.targs1, false) + if err != nil { + t.Fatal(err) + } + + T2 := pkg.Scope().Lookup(test.name2).Type() + res2, err := Instantiate(ctxt, T2, test.targs2, false) + if err != nil { + t.Fatal(err) + } + + if gotEqual := res1 == res2; gotEqual != test.wantEqual { + t.Errorf("%s == %s: %t, want %t", res1, res2, gotEqual, test.wantEqual) + } + }) } } func TestInstantiateNonEquality(t *testing.T) { - const src = genericPkg + "p; type T[P any] int" + const src = "package p; type T[P any] int" - pkg1, err := pkgFor(".", src, nil) + pkg1, err := pkgForMode(".", src, nil, 0) if err != nil { t.Fatal(err) } - pkg2, err := pkgFor(".", src, nil) + pkg2, err := pkgForMode(".", src, nil, 0) if err != nil { t.Fatal(err) } @@ -73,7 +139,7 @@ func TestInstantiateNonEquality(t *testing.T) { } func TestMethodInstantiation(t *testing.T) { - const prefix = genericPkg + `p + const prefix = `package p type T[P any] struct{} @@ -95,7 +161,7 @@ var X T[int] for _, test := range tests { src := prefix + test.decl - pkg, err := pkgFor(".", src, nil) + pkg, err := pkgForMode(".", src, nil, 0) if err != nil { t.Fatal(err) } @@ -112,7 +178,7 @@ var X T[int] } func TestImmutableSignatures(t *testing.T) { - const src = genericPkg + `p + const src = `package p type T[P any] struct{} @@ -120,7 +186,7 @@ func (T[P]) m() {} var _ T[int] ` - pkg, err := pkgFor(".", src, nil) + pkg, err := pkgForMode(".", src, nil, 0) if err != nil { t.Fatal(err) } diff --git a/src/go/types/named.go b/src/go/types/named.go index ed3c426a123..06b6d4692b5 100644 --- a/src/go/types/named.go +++ b/src/go/types/named.go @@ -253,7 +253,7 @@ func expandNamed(ctxt *Context, n *Named, instPos token.Pos) (tparams *TypeParam if n.orig.tparams.Len() == n.targs.Len() { // We must always have a context, to avoid infinite recursion. ctxt = check.bestContext(ctxt) - h := ctxt.typeHash(n.orig, n.targs.list()) + h := ctxt.instanceHash(n.orig, n.targs.list()) // ensure that an instance is recorded for h to avoid infinite recursion. ctxt.update(h, n.orig, n.TypeArgs().list(), n) diff --git a/src/go/types/subst.go b/src/go/types/subst.go index 0e3eafdaf18..3ff81a06b62 100644 --- a/src/go/types/subst.go +++ b/src/go/types/subst.go @@ -207,7 +207,7 @@ func (subst *subster) typ(typ Type) Type { } // before creating a new named type, check if we have this one already - h := subst.ctxt.typeHash(t.orig, newTArgs) + h := subst.ctxt.instanceHash(t.orig, newTArgs) dump(">>> new type hash: %s", h) if named := subst.ctxt.lookup(h, t.orig, newTArgs); named != nil { dump(">>> found %s", named) diff --git a/src/go/types/typestring.go b/src/go/types/typestring.go index c448d254587..9192b0423bb 100644 --- a/src/go/types/typestring.go +++ b/src/go/types/typestring.go @@ -291,7 +291,7 @@ func (w *typeWriter) typ(typ Type) { // nothing. func (w *typeWriter) typePrefix(t *Named) { if w.ctxt != nil { - w.string(strconv.Itoa(w.ctxt.idForType(t))) + w.string(strconv.Itoa(w.ctxt.getID(t))) } } diff --git a/src/go/types/typexpr.go b/src/go/types/typexpr.go index 048bc95e153..09d14719857 100644 --- a/src/go/types/typexpr.go +++ b/src/go/types/typexpr.go @@ -409,7 +409,7 @@ func (check *Checker) instantiatedType(x ast.Expr, targsx []ast.Expr, def *Named } // create the instance - h := check.conf.Context.typeHash(orig, targs) + h := check.conf.Context.instanceHash(orig, targs) // targs may be incomplete, and require inference. In any case we should de-duplicate. inst, _ := check.conf.Context.lookup(h, orig, targs).(*Named) // If inst is non-nil, we can't just return here. Inst may have been