diff --git a/src/cmd/compile/internal/walk/stmt.go b/src/cmd/compile/internal/walk/stmt.go index 46a621c2ba..0c851506cb 100644 --- a/src/cmd/compile/internal/walk/stmt.go +++ b/src/cmd/compile/internal/walk/stmt.go @@ -253,15 +253,22 @@ func wrapCall(n *ir.CallExpr, init *ir.Nodes) ir.Node { } } + wrapArgs := n.Args + // If there's a receiver argument, it needs to be passed through the wrapper too. + if n.Op() == ir.OCALLMETH || n.Op() == ir.OCALLINTER { + recv := n.X.(*ir.SelectorExpr).X + wrapArgs = append([]ir.Node{recv}, wrapArgs...) + } + // origArgs keeps track of what argument is uintptr-unsafe/unsafe-uintptr conversion. - origArgs := make([]ir.Node, len(n.Args)) + origArgs := make([]ir.Node, len(wrapArgs)) var funcArgs []*ir.Field - for i, arg := range n.Args { + for i, arg := range wrapArgs { s := typecheck.LookupNum("a", i) if !isBuiltinCall && arg.Op() == ir.OCONVNOP && arg.Type().IsUintptr() && arg.(*ir.ConvExpr).X.Type().IsUnsafePtr() { origArgs[i] = arg arg = arg.(*ir.ConvExpr).X - n.Args[i] = arg + wrapArgs[i] = arg } funcArgs = append(funcArgs, ir.NewField(base.Pos, s, nil, arg.Type())) } @@ -278,6 +285,12 @@ func wrapCall(n *ir.CallExpr, init *ir.Nodes) ir.Node { } args[i] = ir.NewConvExpr(base.Pos, origArg.Op(), origArg.Type(), args[i]) } + if n.Op() == ir.OCALLMETH || n.Op() == ir.OCALLINTER { + // Move wrapped receiver argument back to its appropriate place. + recv := typecheck.Expr(args[0]) + n.X.(*ir.SelectorExpr).X = recv + args = args[1:] + } call := ir.NewCallExpr(base.Pos, n.Op(), n.X, args) if !isBuiltinCall { call.SetOp(ir.OCALL) @@ -291,6 +304,6 @@ func wrapCall(n *ir.CallExpr, init *ir.Nodes) ir.Node { typecheck.Stmts(fn.Body) typecheck.Target.Decls = append(typecheck.Target.Decls, fn) - call = ir.NewCallExpr(base.Pos, ir.OCALL, fn.Nname, n.Args) + call = ir.NewCallExpr(base.Pos, ir.OCALL, fn.Nname, wrapArgs) return walkExpr(typecheck.Stmt(call), init) } diff --git a/test/fixedbugs/issue24491a.go b/test/fixedbugs/issue24491a.go index 8accf8c0a3..d30b65b233 100644 --- a/test/fixedbugs/issue24491a.go +++ b/test/fixedbugs/issue24491a.go @@ -48,6 +48,14 @@ func f() int { return test("return", uintptr(setup()), uintptr(setup()), uintptr(setup()), uintptr(setup())) } +type S struct{} + +//go:noinline +//go:uintptrescapes +func (S) test(s string, p, q uintptr, rest ...uintptr) int { + return test(s, p, q, rest...) +} + func main() { test("normal", uintptr(setup()), uintptr(setup()), uintptr(setup()), uintptr(setup())) <-done @@ -60,6 +68,29 @@ func main() { }() <-done + func() { + for { + defer test("defer in for loop", uintptr(setup()), uintptr(setup()), uintptr(setup()), uintptr(setup())) + break + } + }() + + <-done + func() { + s := &S{} + defer s.test("method call", uintptr(setup()), uintptr(setup()), uintptr(setup()), uintptr(setup())) + }() + <-done + + func() { + s := &S{} + for { + defer s.test("defer method loop", uintptr(setup()), uintptr(setup()), uintptr(setup()), uintptr(setup())) + break + } + }() + <-done + f() <-done }