diff --git a/src/cmd/compile/internal/types2/api.go b/src/cmd/compile/internal/types2/api.go index 367cb8f700e..4ea3989c395 100644 --- a/src/cmd/compile/internal/types2/api.go +++ b/src/cmd/compile/internal/types2/api.go @@ -440,8 +440,16 @@ func ConvertibleTo(V, T Type) bool { // Implements reports whether type V implements interface T. func Implements(V Type, T *Interface) bool { - f, _ := MissingMethod(V, T, true) - return f == nil + if T.Empty() { + // All types (even Typ[Invalid]) implement the empty interface. + return true + } + // Checker.implements suppresses errors for invalid types, so we need special + // handling here. + if V.Underlying() == Typ[Invalid] { + return false + } + return (*Checker)(nil).implements(V, T, nil) == nil } // Identical reports whether x and y are identical types. diff --git a/src/cmd/compile/internal/types2/api_test.go b/src/cmd/compile/internal/types2/api_test.go index 9436a4ed970..4227397df9f 100644 --- a/src/cmd/compile/internal/types2/api_test.go +++ b/src/cmd/compile/internal/types2/api_test.go @@ -2102,3 +2102,106 @@ func TestInstanceIdentity(t *testing.T) { t.Errorf("mismatching types: a.A: %s, b.B: %s", a.Type(), b.Type()) } } + +func TestImplements(t *testing.T) { + const src = ` +package p + +type EmptyIface interface{} + +type I interface { + m() +} + +type C interface { + m() + ~int +} + +type Integer interface{ + int8 | int16 | int32 | int64 +} + +type EmptyTypeSet interface{ + Integer + ~string +} + +type N1 int +func (N1) m() {} + +type N2 int +func (*N2) m() {} + +type N3 int +func (N3) m(int) {} + +type N4 string +func (N4) m() + +type Bad Bad // invalid type +` + + f, err := parseSrc("p.go", src) + if err != nil { + t.Fatal(err) + } + conf := Config{Error: func(error) {}} + pkg, _ := conf.Check(f.PkgName.Value, []*syntax.File{f}, nil) + + scope := pkg.Scope() + var ( + EmptyIface = scope.Lookup("EmptyIface").Type().Underlying().(*Interface) + I = scope.Lookup("I").Type().(*Named) + II = I.Underlying().(*Interface) + C = scope.Lookup("C").Type().(*Named) + CI = C.Underlying().(*Interface) + Integer = scope.Lookup("Integer").Type().Underlying().(*Interface) + EmptyTypeSet = scope.Lookup("EmptyTypeSet").Type().Underlying().(*Interface) + N1 = scope.Lookup("N1").Type() + N1p = NewPointer(N1) + N2 = scope.Lookup("N2").Type() + N2p = NewPointer(N2) + N3 = scope.Lookup("N3").Type() + N4 = scope.Lookup("N4").Type() + Bad = scope.Lookup("Bad").Type() + ) + + tests := []struct { + t Type + i *Interface + want bool + }{ + {I, II, true}, + {I, CI, false}, + {C, II, true}, + {C, CI, true}, + {Typ[Int8], Integer, true}, + {Typ[Int64], Integer, true}, + {Typ[String], Integer, false}, + {EmptyTypeSet, II, true}, + {EmptyTypeSet, EmptyTypeSet, true}, + {Typ[Int], EmptyTypeSet, false}, + {N1, II, true}, + {N1, CI, true}, + {N1p, II, true}, + {N1p, CI, false}, + {N2, II, false}, + {N2, CI, false}, + {N2p, II, true}, + {N2p, CI, false}, + {N3, II, false}, + {N3, CI, false}, + {N4, II, true}, + {N4, CI, false}, + {Bad, II, false}, + {Bad, CI, false}, + {Bad, EmptyIface, true}, + } + + for _, test := range tests { + if got := Implements(test.t, test.i); got != test.want { + t.Errorf("Implements(%s, %s) = %t, want %t", test.t, test.i, got, test.want) + } + } +} diff --git a/src/go/types/api.go b/src/go/types/api.go index c115d07b418..51d58c49aab 100644 --- a/src/go/types/api.go +++ b/src/go/types/api.go @@ -436,8 +436,16 @@ func ConvertibleTo(V, T Type) bool { // Implements reports whether type V implements interface T. func Implements(V Type, T *Interface) bool { - f, _ := MissingMethod(V, T, true) - return f == nil + if T.Empty() { + // All types (even Typ[Invalid]) implement the empty interface. + return true + } + // Checker.implements suppresses errors for invalid types, so we need special + // handling here. + if V.Underlying() == Typ[Invalid] { + return false + } + return (*Checker)(nil).implements(V, T, nil) == nil } // Identical reports whether x and y are identical types. diff --git a/src/go/types/api_test.go b/src/go/types/api_test.go index c8fda5521ad..7b7baa76042 100644 --- a/src/go/types/api_test.go +++ b/src/go/types/api_test.go @@ -2093,3 +2093,107 @@ func TestInstanceIdentity(t *testing.T) { t.Errorf("mismatching types: a.A: %s, b.B: %s", a.Type(), b.Type()) } } + +func TestImplements(t *testing.T) { + const src = ` +package p + +type EmptyIface interface{} + +type I interface { + m() +} + +type C interface { + m() + ~int +} + +type Integer interface{ + int8 | int16 | int32 | int64 +} + +type EmptyTypeSet interface{ + Integer + ~string +} + +type N1 int +func (N1) m() {} + +type N2 int +func (*N2) m() {} + +type N3 int +func (N3) m(int) {} + +type N4 string +func (N4) m() + +type Bad Bad // invalid type +` + + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, "p.go", src, 0) + if err != nil { + t.Fatal(err) + } + conf := Config{Error: func(error) {}} + pkg, _ := conf.Check(f.Name.Name, fset, []*ast.File{f}, nil) + + scope := pkg.Scope() + var ( + EmptyIface = scope.Lookup("EmptyIface").Type().Underlying().(*Interface) + I = scope.Lookup("I").Type().(*Named) + II = I.Underlying().(*Interface) + C = scope.Lookup("C").Type().(*Named) + CI = C.Underlying().(*Interface) + Integer = scope.Lookup("Integer").Type().Underlying().(*Interface) + EmptyTypeSet = scope.Lookup("EmptyTypeSet").Type().Underlying().(*Interface) + N1 = scope.Lookup("N1").Type() + N1p = NewPointer(N1) + N2 = scope.Lookup("N2").Type() + N2p = NewPointer(N2) + N3 = scope.Lookup("N3").Type() + N4 = scope.Lookup("N4").Type() + Bad = scope.Lookup("Bad").Type() + ) + + tests := []struct { + t Type + i *Interface + want bool + }{ + {I, II, true}, + {I, CI, false}, + {C, II, true}, + {C, CI, true}, + {Typ[Int8], Integer, true}, + {Typ[Int64], Integer, true}, + {Typ[String], Integer, false}, + {EmptyTypeSet, II, true}, + {EmptyTypeSet, EmptyTypeSet, true}, + {Typ[Int], EmptyTypeSet, false}, + {N1, II, true}, + {N1, CI, true}, + {N1p, II, true}, + {N1p, CI, false}, + {N2, II, false}, + {N2, CI, false}, + {N2p, II, true}, + {N2p, CI, false}, + {N3, II, false}, + {N3, CI, false}, + {N4, II, true}, + {N4, CI, false}, + {Bad, II, false}, + {Bad, CI, false}, + {Bad, EmptyIface, true}, + } + + for _, test := range tests { + if got := Implements(test.t, test.i); got != test.want { + t.Errorf("Implements(%s, %s) = %t, want %t", test.t, test.i, got, test.want) + } + } +}