1
0
mirror of https://github.com/golang/go synced 2024-11-11 22:20:22 -07:00

cmd/compile/internal/types2: combine all type inference in a single function

Rather than splitting up type inference into function argument
and constraint type inference, provide a single Checker.infer
that accepts type parameters, type arguments, value parameters,
and value arguments, if any. Checker.infer returns the completed
list of type arguments, or nil.

Updated (and simplified) call sites.

Change-Id: I9200a44b9c4ab7f2d21eed824abfffaab68ff766
Reviewed-on: https://go-review.googlesource.com/c/go/+/306170
Trust: Robert Griesemer <gri@golang.org>
Reviewed-by: Robert Findley <rfindley@google.com>
This commit is contained in:
Robert Griesemer 2021-03-30 17:16:53 -07:00
parent 8462169b5a
commit bce85b7011
2 changed files with 174 additions and 95 deletions

View File

@ -24,52 +24,36 @@ func (check *Checker) funcInst(x *operand, inst *syntax.IndexExpr) {
}
assert(len(targs) == len(xlist))
// check number of type arguments
n := len(targs)
// check number of type arguments (got) vs number of type parameters (want)
sig := x.typ.(*Signature)
if !check.conf.InferFromConstraints && n != len(sig.tparams) || n > len(sig.tparams) {
check.errorf(xlist[n-1], "got %d type arguments but want %d", n, len(sig.tparams))
got, want := len(targs), len(sig.tparams)
if !check.conf.InferFromConstraints && got != want || got > want {
check.errorf(xlist[got-1], "got %d type arguments but want %d", got, want)
x.mode = invalid
x.expr = inst
return
}
// determine argument positions (for error reporting)
poslist := make([]syntax.Pos, n)
for i, x := range xlist {
poslist[i] = syntax.StartPos(x)
}
// if we don't have enough type arguments, use constraint type inference
var inferred bool
if n < len(sig.tparams) {
var failed int
targs, failed = check.inferB(sig.tparams, targs)
// if we don't have enough type arguments, try type inference
inferred := false
if got < want {
targs = check.infer(inst.Pos(), sig.tparams, targs, nil, nil, true)
if targs == nil {
// error was already reported
x.mode = invalid
x.expr = inst
return
}
if failed >= 0 {
// at least one type argument couldn't be inferred
assert(targs[failed] == nil)
tpar := sig.tparams[failed]
check.errorf(inst, "cannot infer %s (%s) (%s)", tpar.name, tpar.pos, targs)
x.mode = invalid
x.expr = inst
return
}
// all type arguments were inferred successfully
if debug {
for _, targ := range targs {
assert(targ != nil)
}
}
n = len(targs)
got = len(targs)
inferred = true
}
assert(n == len(sig.tparams))
assert(got == want)
// determine argument positions (for error reporting)
poslist := make([]syntax.Pos, len(xlist))
for i, x := range xlist {
poslist[i] = syntax.StartPos(x)
}
// instantiate function signature
res := check.instantiate(x.Pos(), sig, targs, poslist).(*Signature)
@ -301,35 +285,10 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, args []*o
if len(sig.tparams) > 0 {
// TODO(gri) provide position information for targs so we can feed
// it to the instantiate call for better error reporting
targs, failed := check.infer(sig.tparams, sigParams, args)
targs := check.infer(call.Pos(), sig.tparams, nil, sigParams, args, true)
if targs == nil {
return // error already reported
}
if failed >= 0 {
// Some type arguments couldn't be inferred. Use
// bounds type inference to try to make progress.
if check.conf.InferFromConstraints {
targs, failed = check.inferB(sig.tparams, targs)
if targs == nil {
return // error already reported
}
}
if failed >= 0 {
// at least one type argument couldn't be inferred
assert(targs[failed] == nil)
tpar := sig.tparams[failed]
// TODO(gri) here we'd like to use the position of the call's ')'
check.errorf(call.Pos(), "cannot infer %s (%s) (%s)", tpar.name, tpar.pos, targs)
return
}
}
// all type arguments were inferred successfully
if debug {
for _, targ := range targs {
assert(targ != nil)
}
}
//check.dump("### inferred targs = %s", targs)
// compute result signature
rsig = check.instantiate(call.Pos(), sig, targs, nil).(*Signature)
@ -548,14 +507,14 @@ func (check *Checker) selector(x *operand, e *syntax.SelectorExpr) {
recv = recv.(*Pointer).base
}
}
// Disable reporting of errors during inference below. If we're unable to infer
// the receiver type arguments here, the receiver must be be otherwise invalid
// and an error has been reported elsewhere.
arg := operand{mode: variable, expr: x.expr, typ: recv}
targs, failed := check.infer(sig.rparams, NewTuple(sig.recv), []*operand{&arg})
targs := check.infer(m.pos, sig.rparams, nil, NewTuple(sig.recv), []*operand{&arg}, false /* no error reporting */)
//check.dump("### inferred targs = %s", targs)
if failed >= 0 {
if targs == nil {
// We may reach here if there were other errors (see issue #40056).
// check.infer will report a follow-up error.
// TODO(gri) avoid the follow-up error as it is confusing
// (there's no inference in the source code)
goto Error
}
// Don't modify m. Instead - for now - make a copy of m and use that instead.

View File

@ -2,29 +2,113 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// This file implements type parameter inference given
// a list of concrete arguments and a parameter list.
// This file implements type parameter inference.
package types2
import "bytes"
import (
"bytes"
"cmd/compile/internal/syntax"
)
// infer returns the list of actual type arguments for the given list of type parameters tparams
// by inferring them from the actual arguments args for the parameters params. If type inference
// is impossible because unification fails, an error is reported and the resulting types list is
// nil, and index is 0. Otherwise, types is the list of inferred type arguments, and index is
// the index of the first type argument in that list that couldn't be inferred (and thus is nil).
// If all type arguments were inferred successfully, index is < 0.
func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand) (types []Type, index int) {
// infer attempts to infer the complete set of type arguments for generic function instantiation/call
// based on the given type parameters tparams, type arguments targs, function parameters params, and
// function arguments args, if any. There must be at least one type parameter, no more type arguments
// than type parameters, and params and args must match in number (incl. zero).
// If successful, infer returns the complete list of type arguments, one for each type parameter.
// Otherwise the result is nil and appropriate errors will be reported unless report is set to false.
//
// Inference proceeds in 3 steps:
//
// 1) Start with given type arguments.
// 2) Infer type arguments from typed function arguments.
// 3) Infer type arguments from untyped function arguments.
//
// Constraint type inference is used after each step to expand the set of type arguments.
//
func (check *Checker) infer(pos syntax.Pos, tparams []*TypeName, targs []Type, params *Tuple, args []*operand, report bool) (result []Type) {
if debug {
defer func() {
assert(result == nil || len(result) == len(tparams))
for _, targ := range result {
assert(targ != nil)
}
//check.dump("### inferred targs = %s", result)
}()
}
// There must be at least one type parameter, and no more type arguments than type parameters.
n := len(tparams)
assert(n > 0 && len(targs) <= n)
// Function parameters and arguments must match in number.
assert(params.Len() == len(args))
// --- 0 ---
// If we already have all type arguments, we're done.
if len(targs) == n {
return targs
}
// len(targs) < n
// --- 1 ---
// Explicitly provided type arguments take precedence over any inferred types;
// and types inferred via constraint type inference take precedence over types
// inferred from function arguments.
// If we have type arguments, see how far we get with constraint type inference.
if len(targs) > 0 && check.conf.InferFromConstraints {
var index int
targs, index = check.inferB(tparams, targs, report)
if targs == nil || index < 0 {
return targs
}
}
// Continue with the type arguments we have now. Avoid matching generic
// parameters that already have type arguments against function arguments:
// It may fail because matching uses type identity while parameter passing
// uses assignment rules. Instantiate the parameter list with the type
// arguments we have, and continue with that parameter list.
// First, make sure we have a "full" list of type arguments, so of which
// may be nil (unknown).
if len(targs) < n {
targs2 := make([]Type, n)
copy(targs2, targs)
targs = targs2
}
// len(targs) == n
// Substitute type arguments for their respective type parameters in params,
// if any. Note that nil targs entries are ignored by check.subst.
// TODO(gri) Can we avoid this (we're setting known type argumemts below,
// but that doesn't impact the isParameterized check for now).
if params.Len() > 0 {
smap := makeSubstMap(tparams, targs)
params = check.subst(nopos, params, smap).(*Tuple)
}
// --- 2 ---
// Unify parameter and argument types for generic parameters with typed arguments
// and collect the indices of generic parameters with untyped arguments.
// Terminology: generic parameter = function parameter with a type-parameterized type
u := newUnifier(check, false)
u.x.init(tparams)
// Set the type arguments which we know already.
for i, targ := range targs {
if targ != nil {
u.x.set(i, targ)
}
}
errorf := func(kind string, tpar, targ Type, arg *operand) {
if !report {
return
}
// provide a better error message if we can
targs, failed := u.x.types()
if failed == 0 {
targs, index := u.x.types()
if index == 0 {
// The first type parameter couldn't be inferred.
// If none of them could be inferred, don't try
// to provide the inferred type in the error msg.
@ -49,16 +133,13 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
}
}
// Terminology: generic parameter = function parameter with a type-parameterized type
// 1st pass: Unify parameter and argument types for generic parameters with typed arguments
// and collect the indices of generic parameters with untyped arguments.
// indices of the generic parameters with untyped arguments - save for later
var indices []int
for i, arg := range args {
par := params.At(i)
// If we permit bidirectional unification, this conditional code needs to be
// executed even if par.typ is not parameterized since the argument may be a
// generic function (for which we want to infer // its type arguments).
// generic function (for which we want to infer its type arguments).
if isParameterized(tparams, par.typ) {
if arg.mode == invalid {
// An error was reported earlier. Ignore this targ
@ -72,7 +153,7 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
// the respective type parameters of targ.
if !u.unify(par.typ, targ) {
errorf("type", par.typ, targ, arg)
return nil, 0
return nil
}
} else {
indices = append(indices, i)
@ -80,12 +161,27 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
}
}
// Some generic parameters with untyped arguments may have been given a type
// indirectly through another generic parameter with a typed argument; we can
// ignore those now. (This only means that we know the types for those generic
// parameters; it doesn't mean untyped arguments can be passed safely. We still
// need to verify that assignment of those arguments is valid when we check
// function parameter passing external to infer.)
// If we've got all type arguments, we're done.
var index int
targs, index = u.x.types()
if index < 0 {
return targs
}
// See how far we get with constraint type inference.
// Note that even if we don't have any type arguments, constraint type inference
// may produce results for constraints that explicitly specify a type.
if check.conf.InferFromConstraints {
targs, index = check.inferB(tparams, targs, report)
if targs == nil || index < 0 {
return targs
}
}
// --- 3 ---
// Use any untyped arguments to infer additional type arguments.
// Some generic parameters with untyped arguments may have been given
// a type by now, we can ignore them.
j := 0
for _, i := range indices {
par := params.At(i)
@ -94,14 +190,15 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
// only parameter type it can possibly match against is a *TypeParam.
// Thus, only keep the indices of generic parameters that are not of
// composite types and which don't have a type inferred yet.
if tpar, _ := par.typ.(*TypeParam); tpar != nil && u.x.at(tpar.index) == nil {
if tpar, _ := par.typ.(*TypeParam); tpar != nil && targs[tpar.index] == nil {
indices[j] = i
j++
}
}
indices = indices[:j]
// 2nd pass: Unify parameter and default argument types for remaining generic parameters.
// Unify parameter and default argument types for remaining generic parameters.
// TODO(gri) Rather than iterating again, combine this code with the loop above.
for _, i := range indices {
par := params.At(i)
arg := args[i]
@ -111,11 +208,31 @@ func (check *Checker) infer(tparams []*TypeName, params *Tuple, args []*operand)
// nil by making sure all default argument types are typed.
if isTyped(targ) && !u.unify(par.typ, targ) {
errorf("default type", par.typ, targ, arg)
return nil, 0
return nil
}
}
return u.x.types()
// If we've got all type arguments, we're done.
targs, index = u.x.types()
if index < 0 {
return targs
}
// Again, follow up with constraint type inference.
if check.conf.InferFromConstraints {
targs, index = check.inferB(tparams, targs, report)
if targs == nil || index < 0 {
return targs
}
}
// At least one type argument couldn't be inferred.
assert(targs != nil && index >= 0 && targs[index] == nil)
tpar := tparams[index]
if report {
check.errorf(pos, "cannot infer %s (%s) (%s)", tpar.name, tpar.pos, targs)
}
return nil
}
// typeNamesString produces a string containing all the
@ -265,12 +382,13 @@ func (w *tpWalker) isParameterizedList(list []Type) bool {
// inferB returns the list of actual type arguments inferred from the type parameters'
// bounds and an initial set of type arguments. If type inference is impossible because
// unification fails, an error is reported, the resulting types list is nil, and index is 0.
// unification fails, an error is reported if report is set to true, the resulting types
// list is nil, and index is 0.
// Otherwise, types is the list of inferred type arguments, and index is the index of the
// first type argument in that list that couldn't be inferred (and thus is nil). If all
// type arguments where inferred successfully, index is < 0. The number of type arguments
// type arguments were inferred successfully, index is < 0. The number of type arguments
// provided may be less than the number of type parameters, but there must be at least one.
func (check *Checker) inferB(tparams []*TypeName, targs []Type) (types []Type, index int) {
func (check *Checker) inferB(tparams []*TypeName, targs []Type, report bool) (types []Type, index int) {
assert(len(tparams) >= len(targs) && len(targs) > 0)
// Setup bidirectional unification between those structural bounds
@ -292,7 +410,9 @@ func (check *Checker) inferB(tparams []*TypeName, targs []Type) (types []Type, i
sbound := check.structuralType(typ.bound)
if sbound != nil {
if !u.unify(typ, sbound) {
check.errorf(tpar.pos, "%s does not match %s", tpar, sbound)
if report {
check.errorf(tpar, "%s does not match %s", tpar, sbound)
}
return nil, 0
}
}