diff --git a/go/types/api.go b/go/types/api.go index 2a206d099ae..a733581da06 100644 --- a/go/types/api.go +++ b/go/types/api.go @@ -188,5 +188,3 @@ func Implements(V Type, T *Interface, static bool) bool { f, _ := MissingMethod(V, T, static) return f == nil } - -// BUG(gri): Interface vs non-interface comparisons are not correctly implemented. diff --git a/go/types/errors.go b/go/types/errors.go index 55a6c96cbe1..7c0271d99d7 100644 --- a/go/types/errors.go +++ b/go/types/errors.go @@ -24,7 +24,7 @@ func unreachable() { panic("unreachable") } -func (check *checker) formatMsg(format string, args []interface{}) string { +func (check *checker) sprintf(format string, args ...interface{}) string { for i, arg := range args { switch a := arg.(type) { case nil: @@ -44,13 +44,13 @@ func (check *checker) trace(pos token.Pos, format string, args ...interface{}) { fmt.Printf("%s:\t%s%s\n", check.fset.Position(pos), strings.Repeat(". ", check.indent), - check.formatMsg(format, args), + check.sprintf(format, args...), ) } // dump is only needed for debugging func (check *checker) dump(format string, args ...interface{}) { - fmt.Println(check.formatMsg(format, args)) + fmt.Println(check.sprintf(format, args...)) } func (check *checker) err(err error) { @@ -65,7 +65,7 @@ func (check *checker) err(err error) { } func (check *checker) errorf(pos token.Pos, format string, args ...interface{}) { - check.err(fmt.Errorf("%s: %s", check.fset.Position(pos), check.formatMsg(format, args))) + check.err(fmt.Errorf("%s: %s", check.fset.Position(pos), check.sprintf(format, args...))) } func (check *checker) invalidAST(pos token.Pos, format string, args ...interface{}) { diff --git a/go/types/expr.go b/go/types/expr.go index 6c9e97956a6..cfedd205544 100644 --- a/go/types/expr.go +++ b/go/types/expr.go @@ -560,24 +560,34 @@ Error: } func (check *checker) comparison(x, y *operand, op token.Token) { - // TODO(gri) deal with interface vs non-interface comparison - - valid := false + // spec: "In any comparison, the first operand must be assignable + // to the type of the second operand, or vice versa." + err := "" if x.isAssignableTo(check.conf, y.typ) || y.isAssignableTo(check.conf, x.typ) { + defined := false switch op { case token.EQL, token.NEQ: - valid = isComparable(x.typ) || - x.isNil() && hasNil(y.typ) || - y.isNil() && hasNil(x.typ) + // spec: "The equality operators == and != apply to operands that are comparable." + defined = isComparable(x.typ) || x.isNil() && hasNil(y.typ) || y.isNil() && hasNil(x.typ) case token.LSS, token.LEQ, token.GTR, token.GEQ: - valid = isOrdered(x.typ) + // spec: The ordering operators <, <=, >, and >= apply to operands that are ordered." + defined = isOrdered(x.typ) default: unreachable() } + if !defined { + typ := x.typ + if x.isNil() { + typ = y.typ + } + err = check.sprintf("operator %s not defined for %s", op, typ) + } + } else { + err = check.sprintf("mismatched types %s and %s", x.typ, y.typ) } - if !valid { - check.invalidOp(x.pos(), "cannot compare %s %s %s", x, op, y) + if err != "" { + check.errorf(x.pos(), "cannot compare %s %s %s (%s)", x.expr, op, y.expr, err) x.mode = invalid return } diff --git a/go/types/predicates.go b/go/types/predicates.go index cde1ec4f8da..67267321658 100644 --- a/go/types/predicates.go +++ b/go/types/predicates.go @@ -81,7 +81,6 @@ func isComparable(typ Type) bool { case *Basic: return t.kind != Invalid && t.kind != UntypedNil case *Pointer, *Interface, *Chan: - // assumes types are equal for pointers and channels return true case *Struct: for _, f := range t.fields { diff --git a/go/types/testdata/const0.src b/go/types/testdata/const0.src index 19e45898e07..4581fe97913 100644 --- a/go/types/testdata/const0.src +++ b/go/types/testdata/const0.src @@ -100,7 +100,7 @@ const ( tb0 bool = false tb1 bool = true tb2 mybool = 2 < 1 - tb3 mybool = ti1 /* ERROR "cannot compare" */ == tf1 + tb3 mybool = ti1 /* ERROR "mismatched types" */ == tf1 // integer values ti0 int8 = ui0 diff --git a/go/types/testdata/expr2.src b/go/types/testdata/expr2.src index 85bc2fedfee..31dc5f021c0 100644 --- a/go/types/testdata/expr2.src +++ b/go/types/testdata/expr2.src @@ -21,4 +21,227 @@ func _bool() { // corner cases var ( v0 = nil /* ERROR "cannot compare" */ == nil -) \ No newline at end of file +) + +func arrays() { + // basics + var a, b [10]int + _ = a == b + _ = a != b + _ = a /* ERROR < not defined */ < b + _ = a == nil /* ERROR cannot convert */ + + type C [10]int + var c C + _ = a == c + + type D [10]int + var d D + _ = c /* ERROR mismatched types */ == d + + var e [10]func() int + _ = e /* ERROR == not defined */ == e +} + +func structs() { + // basics + var s, t struct { + x int + a [10]float32 + _ bool + } + _ = s == t + _ = s != t + _ = s /* ERROR < not defined */ < t + _ = s == nil /* ERROR cannot convert */ + + type S struct { + x int + a [10]float32 + _ bool + } + type T struct { + x int + a [10]float32 + _ bool + } + var ss S + var tt T + _ = s == ss + _ = ss /* ERROR mismatched types */ == tt + + var u struct { + x int + a [10]map[string]int + } + _ = u /* ERROR cannot compare */ == u +} + +func pointers() { + // nil + _ = nil /* ERROR == not defined */ == nil + _ = nil /* ERROR != not defined */ != nil + _ = nil /* ERROR < not defined */ < nil + _ = nil /* ERROR <= not defined */ <= nil + _ = nil /* ERROR > not defined */ > nil + _ = nil /* ERROR >= not defined */ >= nil + + // basics + var p, q *int + _ = p == q + _ = p != q + + _ = p == nil + _ = p != nil + _ = nil == q + _ = nil != q + + _ = p /* ERROR < not defined */ < q + _ = p /* ERROR <= not defined */ <= q + _ = p /* ERROR > not defined */ > q + _ = p /* ERROR >= not defined */ >= q + + // various element types + type ( + S1 struct{} + S2 struct{} + P1 *S1 + P2 *S2 + ) + var ( + ps1 *S1 + ps2 *S2 + p1 P1 + p2 P2 + ) + _ = ps1 == ps1 + _ = ps1 /* ERROR mismatched types */ == ps2 + _ = ps2 /* ERROR mismatched types */ == ps1 + + _ = p1 == p1 + _ = p1 /* ERROR mismatched types */ == p2 + + _ = p1 == ps1 +} + +func channels() { + // basics + var c, d chan int + _ = c == d + _ = c != d + _ = c == nil + _ = c /* ERROR < not defined */ < d + + // various element types (named types) + type ( + C1 chan int + C1r <-chan int + C1s chan<- int + C2 chan float32 + ) + var ( + c1 C1 + c1r C1r + c1s C1s + c1a chan int + c2 C2 + ) + _ = c1 == c1 + _ = c1 /* ERROR mismatched types */ == c1r + _ = c1 /* ERROR mismatched types */ == c1s + _ = c1r /* ERROR mismatched types */ == c1s + _ = c1 == c1a + _ = c1a == c1 + _ = c1 /* ERROR mismatched types */ == c2 + _ = c1a /* ERROR mismatched types */ == c2 + + // various element types (unnamed types) + var ( + d1 chan int + d1r <-chan int + d1s chan<- int + d1a chan<- int + d2 chan float32 + ) + _ = d1 == d1 + _ = d1 == d1r + _ = d1 == d1s + _ = d1r /* ERROR mismatched types */ == d1s + _ = d1 == d1a + _ = d1a == d1 + _ = d1 /* ERROR mismatched types */ == d2 + _ = d1a /* ERROR mismatched types */ == d2 +} + +// for interfaces test +type S1 struct{} +type S11 struct{} +type S2 struct{} +func (*S1) m() int +func (*S11) m() int +func (*S11) n() +func (*S2) m() float32 + +func interfaces() { + // basics + var i, j interface{ m() int } + _ = i == j + _ = i != j + _ = i == nil + _ = i /* ERROR < not defined */ < j + + // various interfaces + var ii interface { m() int; n() } + var k interface { m() float32 } + _ = i == ii + _ = i /* ERROR mismatched types */ == k + + // interfaces vs values + var s1 S1 + var s11 S11 + var s2 S2 + + _ = i == 0 /* ERROR cannot convert */ + _ = i /* ERROR mismatched types */ == s1 + _ = i == &s1 + _ = i == &s11 + + _ = i /* ERROR mismatched types */ == s2 + _ = i /* ERROR mismatched types */ == &s2 +} + +func slices() { + // basics + var s []int + _ = s == nil + _ = s != nil + _ = s /* ERROR < not defined */ < nil + + // slices are not otherwise comparable + _ = s /* ERROR == not defined */ == s + _ = s /* ERROR < not defined */ < s +} + +func maps() { + // basics + var m map[string]int + _ = m == nil + _ = m != nil + _ = m /* ERROR < not defined */ < nil + + // maps are not otherwise comparable + _ = m /* ERROR == not defined */ == m + _ = m /* ERROR < not defined */ < m +} + +func funcs() { + // basics + var f func(int) float32 + _ = f == nil + _ = f != nil + _ = f /* ERROR < not defined */ < nil + + // funcs are not otherwise comparable + _ = f /* ERROR == not defined */ == f + _ = f /* ERROR < not defined */ < f +}