diff --git a/go/types/stdlib_test.go b/go/types/stdlib_test.go index 1a09030fa9c..11d8ba92985 100644 --- a/go/types/stdlib_test.go +++ b/go/types/stdlib_test.go @@ -119,14 +119,12 @@ func TestStdtest(t *testing.T) { testTestDir(t, filepath.Join(runtime.GOROOT(), "test"), "cmplxdivide.go", // also needs file cmplxdivide1.go - ignore "mapnan.go", "sigchld.go", // don't work on Windows; testTestDir should consult build tags - "typeswitch2.go", // TODO(gri) implement duplicate checking in type switches ) } func TestStdfixed(t *testing.T) { testTestDir(t, filepath.Join(runtime.GOROOT(), "test", "fixedbugs"), "bug165.go", // TODO(gri) isComparable not working for incomplete struct type - "bug200.go", // TODO(gri) complete duplicate checking in type switches "bug223.go", "bug413.go", "bug459.go", // TODO(gri) complete initialization checks "bug248.go", "bug302.go", "bug369.go", // complex test instructions - ignore "issue3924.go", // TODO(gri) && and || produce bool result (not untyped bool) diff --git a/go/types/stmt.go b/go/types/stmt.go index c94fd4bae2a..f11086e14d7 100644 --- a/go/types/stmt.go +++ b/go/types/stmt.go @@ -121,6 +121,31 @@ func (check *checker) caseValues(x operand /* copy argument (not *operand!) */, } } +func (check *checker) caseTypes(x *operand, xtyp *Interface, types []ast.Expr, seen map[Type]token.Pos) (T Type) { +L: + for _, e := range types { + T = check.typOrNil(e) + if T == Typ[Invalid] { + continue + } + // complain about duplicate types + // TODO(gri) use a type hash to avoid quadratic algorithm + for t, pos := range seen { + if T == nil && t == nil || T != nil && t != nil && IsIdentical(T, t) { + // talk about "case" rather than "type" because of nil case + check.errorf(e.Pos(), "duplicate case in type switch") + check.errorf(pos, "previous case %s", T) + continue L + } + } + seen[T] = e.Pos() + if T != nil { + check.typeAssertion(e.Pos(), x, xtyp, T) + } + } + return +} + // stmt typechecks statement s. func (check *checker) stmt(ctxt stmtContext, s ast.Stmt) { // statements cannot use iota in general @@ -396,20 +421,16 @@ func (check *checker) stmt(ctxt stmtContext, s ast.Stmt) { check.multipleDefaults(s.Body.List) - var lhsVars []*Var // set of implicitly declared lhs variables + var lhsVars []*Var // set of implicitly declared lhs variables + seen := make(map[Type]token.Pos) // map of seen types to positions for _, s := range s.Body.List { clause, _ := s.(*ast.CaseClause) if clause == nil { - continue // error reported before + check.invalidAST(s.Pos(), "incorrect type switch case") + continue } // Check each type in this type switch case. - var T Type - for _, expr := range clause.List { - T = check.typOrNil(expr) - if T != nil && T != Typ[Invalid] { - check.typeAssertion(expr.Pos(), &x, xtyp, T) - } - } + T := check.caseTypes(&x, xtyp, clause.List, seen) check.openScope(clause) // If lhs exists, declare a corresponding variable in the case-local scope if necessary. if lhs != nil { diff --git a/go/types/testdata/stmt0.src b/go/types/testdata/stmt0.src index 5fa39997dbb..f7be5d90cdc 100644 --- a/go/types/testdata/stmt0.src +++ b/go/types/testdata/stmt0.src @@ -528,11 +528,33 @@ type B interface { b() } type C interface { a(int) } func typeswitch2() { - switch A(nil).(type) { - case A: - case B: - case C /* ERROR "cannot have dynamic type" */: - } + switch A(nil).(type) { + case A: + case B: + case C /* ERROR "cannot have dynamic type" */: + } +} + +func typeswitch3(x interface{}) { + switch x.(type) { + case int /* ERROR previous case int */ : + case float64: + case int /* ERROR duplicate case */ : + } + + switch x.(type) { + case nil /* ERROR previous case */ /* ERROR previous case */ : + case int: + case nil /* ERROR duplicate case */ , nil /* ERROR duplicate case */ : + } + + type F func(int) + switch x.(type) { + case nil: + case int, func /* ERROR previous case */ (int): + case float32, func /* ERROR duplicate case */ (x int): + case F: + } } func rangeloops1() {