From 70d54df4f6bd63b0057d718c6fc3fffc0d94bbc1 Mon Sep 17 00:00:00 2001 From: Dan Scales Date: Fri, 12 Mar 2021 11:36:02 -0800 Subject: [PATCH] cmd/compile: getting more built-ins to work with generics For Builtin ops, we currently stay with using the old typechecker to transform the call to a more specific expression and possibly use more specific ops. However, for a bunch of the ops, we delay calling the old typechecker if any of the args have type params, for a variety of reasons. In the near future, we will start creating separate functions that do the same transformations as the old typechecker for calls, builtins, indexing, comparisons, etc. These functions can then be called at noder time for nodes with no type params, and at stenciling time for nodes with type params. Remove unnecessary calls to types1 typechecker for most kinds of statements (still need it for SendStmt, AssignStmt, ReturnStmt, and SelectStmt). In particular, we don't need it for RangeStmt, and this avoids some complaints by the types1 typechecker on generic code. Other small changes: - Fix check on whether to delay calling types1-typechecker on type conversions. Should check if HasTParam is true, rather than if the type is directly a TYPEPARAM. - Don't call types1-typechecker on an indexing operation if the left operand has a typeparam in its type and is not obviously a TMAP, TSLICE, or TARRAY. As above, we will eventually have to create a new function that can do the required transformations (for complicated cases) at noder time or stenciling time. - Copy n.BuiltinOp in subster.node() - The complex arithmetic example in absdiff.go now works. - Added new tests double.go and append.go - Added new example with a new() call in settable.go Change-Id: I8f377afb6126cab1826bd3c2732aa8cdf1f7e0b4 Reviewed-on: https://go-review.googlesource.com/c/go/+/301951 Run-TryBot: Dan Scales TryBot-Result: Go Bot Trust: Dan Scales Trust: Robert Griesemer Reviewed-by: Robert Griesemer --- src/cmd/compile/internal/noder/expr.go | 2 +- src/cmd/compile/internal/noder/helpers.go | 56 ++++++++--- src/cmd/compile/internal/noder/stencil.go | 23 ++++- src/cmd/compile/internal/noder/stmt.go | 29 ++++-- test/typeparam/absdiff.go | 35 ++++--- test/typeparam/append.go | 31 ++++++ test/typeparam/double.go | 50 ++++++++++ test/typeparam/settable.go | 112 +++++++++++++--------- 8 files changed, 250 insertions(+), 88 deletions(-) create mode 100644 test/typeparam/append.go create mode 100644 test/typeparam/double.go diff --git a/src/cmd/compile/internal/noder/expr.go b/src/cmd/compile/internal/noder/expr.go index 989ebf236e..1ca5552879 100644 --- a/src/cmd/compile/internal/noder/expr.go +++ b/src/cmd/compile/internal/noder/expr.go @@ -135,7 +135,7 @@ func (g *irgen) expr0(typ types2.Type, expr syntax.Expr) ir.Node { index := g.expr(expr.Index) if index.Op() != ir.OTYPE { // This is just a normal index expression - return Index(pos, g.expr(expr.X), index) + return Index(pos, g.typ(typ), g.expr(expr.X), index) } // This is generic function instantiation with a single type targs = []ir.Node{index} diff --git a/src/cmd/compile/internal/noder/helpers.go b/src/cmd/compile/internal/noder/helpers.go index cf7a3e22b3..1210d4b58c 100644 --- a/src/cmd/compile/internal/noder/helpers.go +++ b/src/cmd/compile/internal/noder/helpers.go @@ -83,11 +83,12 @@ func Binary(pos src.XPos, op ir.Op, x, y ir.Node) ir.Node { } func Call(pos src.XPos, typ *types.Type, fun ir.Node, args []ir.Node, dots bool) ir.Node { - // TODO(mdempsky): This should not be so difficult. + n := ir.NewCallExpr(pos, ir.OCALL, fun, args) + n.IsDDD = dots + if fun.Op() == ir.OTYPE { // Actually a type conversion, not a function call. - n := ir.NewCallExpr(pos, ir.OCALL, fun, args) - if fun.Type().Kind() == types.TTYPEPARAM { + if fun.Type().HasTParam() || args[0].Type().HasTParam() { // For type params, don't typecheck until we actually know // the type. return typed(typ, n) @@ -96,9 +97,34 @@ func Call(pos src.XPos, typ *types.Type, fun ir.Node, args []ir.Node, dots bool) } if fun, ok := fun.(*ir.Name); ok && fun.BuiltinOp != 0 { - // Call to a builtin function. - n := ir.NewCallExpr(pos, ir.OCALL, fun, args) - n.IsDDD = dots + // For Builtin ops, we currently stay with using the old + // typechecker to transform the call to a more specific expression + // and possibly use more specific ops. However, for a bunch of the + // ops, we delay doing the old typechecker if any of the args have + // type params, for a variety of reasons: + // + // OMAKE: hard to choose specific ops OMAKESLICE, etc. until arg type is known + // OREAL/OIMAG: can't determine type float32/float64 until arg type know + // OLEN/OCAP: old typechecker will complain if arg is not obviously a slice/array. + // OAPPEND: old typechecker will complain if arg is not obviously slice, etc. + // + // We will eventually break out the transforming functionality + // needed for builtin's, and call it here or during stenciling, as + // appropriate. + switch fun.BuiltinOp { + case ir.OMAKE, ir.OREAL, ir.OIMAG, ir.OLEN, ir.OCAP, ir.OAPPEND: + hasTParam := false + for _, arg := range args { + if arg.Type().HasTParam() { + hasTParam = true + break + } + } + if hasTParam { + return typed(typ, n) + } + } + switch fun.BuiltinOp { case ir.OCLOSE, ir.ODELETE, ir.OPANIC, ir.OPRINT, ir.OPRINTN: return typecheck.Stmt(n) @@ -124,9 +150,6 @@ func Call(pos src.XPos, typ *types.Type, fun ir.Node, args []ir.Node, dots bool) } } - n := ir.NewCallExpr(pos, ir.OCALL, fun, args) - n.IsDDD = dots - if fun.Op() == ir.OXDOT { if !fun.(*ir.SelectorExpr).X.Type().HasTParam() { base.FatalfAt(pos, "Expecting type param receiver in %v", fun) @@ -230,9 +253,18 @@ func method(typ *types.Type, index int) *types.Field { return types.ReceiverBaseType(typ).Methods().Index(index) } -func Index(pos src.XPos, x, index ir.Node) ir.Node { - // TODO(mdempsky): Avoid typecheck.Expr (which will call tcIndex) - return typecheck.Expr(ir.NewIndexExpr(pos, x, index)) +func Index(pos src.XPos, typ *types.Type, x, index ir.Node) ir.Node { + n := ir.NewIndexExpr(pos, x, index) + // TODO(danscales): Temporary fix. Need to separate out the + // transformations done by the old typechecker (in tcIndex()), to be + // called here or after stenciling. + if x.Type().HasTParam() && x.Type().Kind() != types.TMAP && + x.Type().Kind() != types.TSLICE && x.Type().Kind() != types.TARRAY { + // Old typechecker will complain if arg is not obviously a slice/array/map. + typed(typ, n) + return n + } + return typecheck.Expr(n) } func Slice(pos src.XPos, x, low, high, max ir.Node) ir.Node { diff --git a/src/cmd/compile/internal/noder/stencil.go b/src/cmd/compile/internal/noder/stencil.go index 64b3a942e2..55aee9b6ff 100644 --- a/src/cmd/compile/internal/noder/stencil.go +++ b/src/cmd/compile/internal/noder/stencil.go @@ -330,8 +330,13 @@ func (subst *subster) node(n ir.Node) ir.Node { m.SetIsClosureVar(true) } t := x.Type() - newt := subst.typ(t) - m.SetType(newt) + if t == nil { + assert(name.BuiltinOp != 0) + } else { + newt := subst.typ(t) + m.SetType(newt) + } + m.BuiltinOp = name.BuiltinOp m.Curfn = subst.newf m.Class = name.Class m.Func = name.Func @@ -396,11 +401,23 @@ func (subst *subster) node(n ir.Node) ir.Node { // that the OXDOT was resolved. call.SetTypecheck(0) typecheck.Call(call) + } else if name := call.X.Name(); name != nil { + switch name.BuiltinOp { + case ir.OMAKE, ir.OREAL, ir.OIMAG, ir.OLEN, ir.OCAP, ir.OAPPEND: + // Call old typechecker (to do any + // transformations) now that we know the + // type of the args. + m.SetTypecheck(0) + m = typecheck.Expr(m) + default: + base.FatalfAt(call.Pos(), "Unexpected builtin op") + } + } else if call.X.Op() != ir.OFUNCINST { // A call with an OFUNCINST will get typechecked // in stencil() once we have created & attached the // instantiation to be called. - base.FatalfAt(call.Pos(), "Expecting OCALLPART or OTYPE or OFUNCINST with CALL") + base.FatalfAt(call.Pos(), "Expecting OCALLPART or OTYPE or OFUNCINST or builtin with CALL") } } diff --git a/src/cmd/compile/internal/noder/stmt.go b/src/cmd/compile/internal/noder/stmt.go index 1775116f41..31c6bfe5c8 100644 --- a/src/cmd/compile/internal/noder/stmt.go +++ b/src/cmd/compile/internal/noder/stmt.go @@ -28,7 +28,11 @@ func (g *irgen) stmts(stmts []syntax.Stmt) []ir.Node { func (g *irgen) stmt(stmt syntax.Stmt) ir.Node { // TODO(mdempsky): Remove dependency on typecheck. - return typecheck.Stmt(g.stmt0(stmt)) + n := g.stmt0(stmt) + if n != nil { + n.SetTypecheck(1) + } + return n } func (g *irgen) stmt0(stmt syntax.Stmt) ir.Node { @@ -46,17 +50,20 @@ func (g *irgen) stmt0(stmt syntax.Stmt) ir.Node { } return x case *syntax.SendStmt: - return ir.NewSendStmt(g.pos(stmt), g.expr(stmt.Chan), g.expr(stmt.Value)) + n := ir.NewSendStmt(g.pos(stmt), g.expr(stmt.Chan), g.expr(stmt.Value)) + // Need to do the AssignConv() in tcSend(). + return typecheck.Stmt(n) case *syntax.DeclStmt: return ir.NewBlockStmt(g.pos(stmt), g.decls(stmt.DeclList)) case *syntax.AssignStmt: if stmt.Op != 0 && stmt.Op != syntax.Def { op := g.op(stmt.Op, binOps[:]) + // May need to insert ConvExpr nodes on the args in tcArith if stmt.Rhs == nil { - return IncDec(g.pos(stmt), op, g.expr(stmt.Lhs)) + return typecheck.Stmt(IncDec(g.pos(stmt), op, g.expr(stmt.Lhs))) } - return ir.NewAssignOpStmt(g.pos(stmt), op, g.expr(stmt.Lhs), g.expr(stmt.Rhs)) + return typecheck.Stmt(ir.NewAssignOpStmt(g.pos(stmt), op, g.expr(stmt.Lhs), g.expr(stmt.Rhs))) } names, lhs := g.assignList(stmt.Lhs, stmt.Op == syntax.Def) @@ -65,25 +72,31 @@ func (g *irgen) stmt0(stmt syntax.Stmt) ir.Node { if len(lhs) == 1 && len(rhs) == 1 { n := ir.NewAssignStmt(g.pos(stmt), lhs[0], rhs[0]) n.Def = initDefn(n, names) - return n + // Need to set Assigned in checkassign for maps + return typecheck.Stmt(n) } n := ir.NewAssignListStmt(g.pos(stmt), ir.OAS2, lhs, rhs) n.Def = initDefn(n, names) - return n + // Need to do tcAssignList(). + return typecheck.Stmt(n) case *syntax.BranchStmt: return ir.NewBranchStmt(g.pos(stmt), g.tokOp(int(stmt.Tok), branchOps[:]), g.name(stmt.Label)) case *syntax.CallStmt: return ir.NewGoDeferStmt(g.pos(stmt), g.tokOp(int(stmt.Tok), callOps[:]), g.expr(stmt.Call)) case *syntax.ReturnStmt: - return ir.NewReturnStmt(g.pos(stmt), g.exprList(stmt.Results)) + n := ir.NewReturnStmt(g.pos(stmt), g.exprList(stmt.Results)) + // Need to do typecheckaste() for multiple return values + return typecheck.Stmt(n) case *syntax.IfStmt: return g.ifStmt(stmt) case *syntax.ForStmt: return g.forStmt(stmt) case *syntax.SelectStmt: - return g.selectStmt(stmt) + n := g.selectStmt(stmt) + // Need to convert assignments to OSELRECV2 in tcSelect() + return typecheck.Stmt(n) case *syntax.SwitchStmt: return g.switchStmt(stmt) diff --git a/test/typeparam/absdiff.go b/test/typeparam/absdiff.go index 5dd58f14f7..1381d7c92c 100644 --- a/test/typeparam/absdiff.go +++ b/test/typeparam/absdiff.go @@ -8,7 +8,7 @@ package main import ( "fmt" - //"math" + "math" ) type Numeric interface { @@ -57,14 +57,14 @@ func (a orderedAbs[T]) Abs() orderedAbs[T] { // complexAbs is a helper type that defines an Abs method for // complex types. -// type complexAbs[T Complex] T +type complexAbs[T Complex] T -// func (a complexAbs[T]) Abs() complexAbs[T] { -// r := float64(real(a)) -// i := float64(imag(a)) -// d := math.Sqrt(r * r + i * i) -// return complexAbs[T](complex(d, 0)) -// } +func (a complexAbs[T]) Abs() complexAbs[T] { + r := float64(real(a)) + i := float64(imag(a)) + d := math.Sqrt(r * r + i * i) + return complexAbs[T](complex(d, 0)) +} // OrderedAbsDifference returns the absolute value of the difference // between a and b, where a and b are of an ordered type. @@ -74,9 +74,9 @@ func orderedAbsDifference[T orderedNumeric](a, b T) T { // ComplexAbsDifference returns the absolute value of the difference // between a and b, where a and b are of a complex type. -// func complexAbsDifference[T Complex](a, b T) T { -// return T(absDifference(complexAbs[T](a), complexAbs[T](b))) -// } +func complexAbsDifference[T Complex](a, b T) T { + return T(absDifference(complexAbs[T](a), complexAbs[T](b))) +} func main() { if got, want := orderedAbsDifference(1.0, -2.0), 3.0; got != want { @@ -89,11 +89,10 @@ func main() { panic(fmt.Sprintf("got = %v, want = %v", got, want)) } - // Still have to handle built-ins real/abs to make this work - // if got, want := complexAbsDifference(5.0 + 2.0i, 2.0 - 2.0i), 5; got != want { - // panic(fmt.Sprintf("got = %v, want = %v", got, want) - // } - // if got, want := complexAbsDifference(2.0 - 2.0i, 5.0 + 2.0i), 5; got != want { - // panic(fmt.Sprintf("got = %v, want = %v", got, want) - // } + if got, want := complexAbsDifference(5.0 + 2.0i, 2.0 - 2.0i), 5+0i; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } + if got, want := complexAbsDifference(2.0 - 2.0i, 5.0 + 2.0i), 5+0i; got != want { + panic(fmt.Sprintf("got = %v, want = %v", got, want)) + } } diff --git a/test/typeparam/append.go b/test/typeparam/append.go new file mode 100644 index 0000000000..8b9bc2039f --- /dev/null +++ b/test/typeparam/append.go @@ -0,0 +1,31 @@ +// run -gcflags=-G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +type Recv <-chan int + +type sliceOf[E any] interface { + type []E +} + +func _Append[S sliceOf[T], T any](s S, t ...T) S { + return append(s, t...) +} + +func main() { + recv := make(Recv) + a := _Append([]Recv{recv}, recv) + if len(a) != 2 || a[0] != recv || a[1] != recv { + panic(a) + } + + recv2 := make(chan<- int) + a2 := _Append([]chan<- int{recv2}, recv2) + if len(a2) != 2 || a2[0] != recv2 || a2[1] != recv2 { + panic(a) + } +} diff --git a/test/typeparam/double.go b/test/typeparam/double.go new file mode 100644 index 0000000000..1f7a26c7f4 --- /dev/null +++ b/test/typeparam/double.go @@ -0,0 +1,50 @@ +// run -gcflags=-G=3 + +// Copyright 2021 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "fmt" + "reflect" +) + +type Number interface { + type int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, uintptr, float32, float64 +} + +type MySlice []int + +type _SliceOf[E any] interface { + type []E +} + +func _DoubleElems[S _SliceOf[E], E Number](s S) S { + r := make(S, len(s)) + for i, v := range s { + r[i] = v + v + } + return r +} + +func main() { + arg := MySlice{1, 2, 3} + want := MySlice{2, 4, 6} + got := _DoubleElems[MySlice, int](arg) + if !reflect.DeepEqual(got, want) { + panic(fmt.Sprintf("got %s, want %s", got, want)) + } + + // constraint type inference + got = _DoubleElems[MySlice](arg) + if !reflect.DeepEqual(got, want) { + panic(fmt.Sprintf("got %s, want %s", got, want)) + } + + got = _DoubleElems(arg) + if !reflect.DeepEqual(got, want) { + panic(fmt.Sprintf("got %s, want %s", got, want)) + } +} diff --git a/test/typeparam/settable.go b/test/typeparam/settable.go index 29874fb189..588166da85 100644 --- a/test/typeparam/settable.go +++ b/test/typeparam/settable.go @@ -14,44 +14,58 @@ import ( // Various implementations of fromStrings(). type _Setter[B any] interface { - Set(string) + Set(string) type *B } // Takes two type parameters where PT = *T func fromStrings1[T any, PT _Setter[T]](s []string) []T { - result := make([]T, len(s)) - for i, v := range s { - // The type of &result[i] is *T which is in the type list - // of Setter, so we can convert it to PT. - p := PT(&result[i]) - // PT has a Set method. - p.Set(v) - } - return result + result := make([]T, len(s)) + for i, v := range s { + // The type of &result[i] is *T which is in the type list + // of Setter, so we can convert it to PT. + p := PT(&result[i]) + // PT has a Set method. + p.Set(v) + } + return result } +func fromStrings1a[T any, PT _Setter[T]](s []string) []PT { + result := make([]PT, len(s)) + for i, v := range s { + // The type new(T) is *T which is in the type list + // of Setter, so we can convert it to PT. + result[i] = PT(new(T)) + p := result[i] + // PT has a Set method. + p.Set(v) + } + return result +} + + // Takes one type parameter and a set function func fromStrings2[T any](s []string, set func(*T, string)) []T { - results := make([]T, len(s)) - for i, v := range s { - set(&results[i], v) - } - return results + results := make([]T, len(s)) + for i, v := range s { + set(&results[i], v) + } + return results } type _Setter2 interface { - Set(string) + Set(string) } // Takes only one type parameter, but causes a panic (see below) func fromStrings3[T _Setter2](s []string) []T { - results := make([]T, len(s)) - for i, v := range s { + results := make([]T, len(s)) + for i, v := range s { // Panics if T is a pointer type because receiver is T(nil). results[i].Set(v) - } - return results + } + return results } // Two concrete types with the appropriate Set method. @@ -59,11 +73,11 @@ func fromStrings3[T _Setter2](s []string) []T { type SettableInt int func (p *SettableInt) Set(s string) { - i, err := strconv.Atoi(s) - if err != nil { - panic(err) - } - *p = SettableInt(i) + i, err := strconv.Atoi(s) + if err != nil { + panic(err) + } + *p = SettableInt(i) } type SettableString struct { @@ -71,34 +85,40 @@ type SettableString struct { } func (x *SettableString) Set(s string) { - x.s = s + x.s = s } func main() { - s := fromStrings1[SettableInt, *SettableInt]([]string{"1"}) - if len(s) != 1 || s[0] != 1 { - panic(fmt.Sprintf("got %v, want %v", s, []int{1})) - } + s := fromStrings1[SettableInt, *SettableInt]([]string{"1"}) + if len(s) != 1 || s[0] != 1 { + panic(fmt.Sprintf("got %v, want %v", s, []int{1})) + } + + s2 := fromStrings1a[SettableInt, *SettableInt]([]string{"1"}) + if len(s2) != 1 || *s2[0] != 1 { + x := 1 + panic(fmt.Sprintf("got %v, want %v", s2, []*int{&x})) + } // Test out constraint type inference, which should determine that the second // type param is *SettableString. ps := fromStrings1[SettableString]([]string{"x", "y"}) - if len(ps) != 2 || ps[0] != (SettableString{"x"}) || ps[1] != (SettableString{"y"}) { - panic(s) - } + if len(ps) != 2 || ps[0] != (SettableString{"x"}) || ps[1] != (SettableString{"y"}) { + panic(s) + } - s = fromStrings2([]string{"1"}, func(p *SettableInt, s string) { p.Set(s) }) - if len(s) != 1 || s[0] != 1 { - panic(fmt.Sprintf("got %v, want %v", s, []int{1})) - } + s = fromStrings2([]string{"1"}, func(p *SettableInt, s string) { p.Set(s) }) + if len(s) != 1 || s[0] != 1 { + panic(fmt.Sprintf("got %v, want %v", s, []int{1})) + } - defer func() { - if recover() == nil { - panic("did not panic as expected") - } - }() - // This should type check but should panic at run time, - // because it will make a slice of *SettableInt and then call - // Set on a nil value. - fromStrings3[*SettableInt]([]string{"1"}) + defer func() { + if recover() == nil { + panic("did not panic as expected") + } + }() + // This should type check but should panic at run time, + // because it will make a slice of *SettableInt and then call + // Set on a nil value. + fromStrings3[*SettableInt]([]string{"1"}) }