From d62866ef793872779c9011161e51b9c805fcb73d Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Fri, 27 Aug 2021 20:07:00 +0700 Subject: [PATCH] cmd/compile: move checkptr alignment to SSA generation This is followup of CL 343972, moving the checkptr alignment instrumentation during SSA generation instead of walk. Change-Id: I29b2953e4eb8631277fe2e0f44b9d987dd7a69f9 Reviewed-on: https://go-review.googlesource.com/c/go/+/345430 Trust: Cuong Manh Le Run-TryBot: Cuong Manh Le TryBot-Result: Go Bot Reviewed-by: Matthew Dempsky --- src/cmd/compile/internal/ir/expr.go | 9 ++-- src/cmd/compile/internal/ir/symtab.go | 55 ++++++++++++------------ src/cmd/compile/internal/ssagen/ssa.go | 51 +++++++++++++++++++--- src/cmd/compile/internal/walk/convert.go | 38 ---------------- src/cmd/compile/internal/walk/expr.go | 13 +----- 5 files changed, 77 insertions(+), 89 deletions(-) diff --git a/src/cmd/compile/internal/ir/expr.go b/src/cmd/compile/internal/ir/expr.go index baf01174098..f526d987a7b 100644 --- a/src/cmd/compile/internal/ir/expr.go +++ b/src/cmd/compile/internal/ir/expr.go @@ -570,11 +570,10 @@ func (*SelectorExpr) CanBeNtype() {} // A SliceExpr is a slice expression X[Low:High] or X[Low:High:Max]. type SliceExpr struct { miniExpr - X Node - Low Node - High Node - Max Node - CheckPtrCall *CallExpr `mknode:"-"` + X Node + Low Node + High Node + Max Node } func NewSliceExpr(pos src.XPos, op Op, x, low, high, max Node) *SliceExpr { diff --git a/src/cmd/compile/internal/ir/symtab.go b/src/cmd/compile/internal/ir/symtab.go index 1e8261810f4..1435e4313e0 100644 --- a/src/cmd/compile/internal/ir/symtab.go +++ b/src/cmd/compile/internal/ir/symtab.go @@ -11,33 +11,34 @@ import ( // Syms holds known symbols. var Syms struct { - AssertE2I *obj.LSym - AssertE2I2 *obj.LSym - AssertI2I *obj.LSym - AssertI2I2 *obj.LSym - Deferproc *obj.LSym - DeferprocStack *obj.LSym - Deferreturn *obj.LSym - Duffcopy *obj.LSym - Duffzero *obj.LSym - GCWriteBarrier *obj.LSym - Goschedguarded *obj.LSym - Growslice *obj.LSym - Msanread *obj.LSym - Msanwrite *obj.LSym - Msanmove *obj.LSym - Newobject *obj.LSym - Newproc *obj.LSym - Panicdivide *obj.LSym - Panicshift *obj.LSym - PanicdottypeE *obj.LSym - PanicdottypeI *obj.LSym - Panicnildottype *obj.LSym - Panicoverflow *obj.LSym - Raceread *obj.LSym - Racereadrange *obj.LSym - Racewrite *obj.LSym - Racewriterange *obj.LSym + AssertE2I *obj.LSym + AssertE2I2 *obj.LSym + AssertI2I *obj.LSym + AssertI2I2 *obj.LSym + CheckPtrAlignment *obj.LSym + Deferproc *obj.LSym + DeferprocStack *obj.LSym + Deferreturn *obj.LSym + Duffcopy *obj.LSym + Duffzero *obj.LSym + GCWriteBarrier *obj.LSym + Goschedguarded *obj.LSym + Growslice *obj.LSym + Msanread *obj.LSym + Msanwrite *obj.LSym + Msanmove *obj.LSym + Newobject *obj.LSym + Newproc *obj.LSym + Panicdivide *obj.LSym + Panicshift *obj.LSym + PanicdottypeE *obj.LSym + PanicdottypeI *obj.LSym + Panicnildottype *obj.LSym + Panicoverflow *obj.LSym + Raceread *obj.LSym + Racereadrange *obj.LSym + Racewrite *obj.LSym + Racewriterange *obj.LSym // Wasm SigPanic *obj.LSym Staticuint64s *obj.LSym diff --git a/src/cmd/compile/internal/ssagen/ssa.go b/src/cmd/compile/internal/ssagen/ssa.go index dd19a254f8d..11bca89fd85 100644 --- a/src/cmd/compile/internal/ssagen/ssa.go +++ b/src/cmd/compile/internal/ssagen/ssa.go @@ -96,6 +96,7 @@ func InitConfig() { ir.Syms.AssertE2I2 = typecheck.LookupRuntimeFunc("assertE2I2") ir.Syms.AssertI2I = typecheck.LookupRuntimeFunc("assertI2I") ir.Syms.AssertI2I2 = typecheck.LookupRuntimeFunc("assertI2I2") + ir.Syms.CheckPtrAlignment = typecheck.LookupRuntimeFunc("checkptrAlignment") ir.Syms.Deferproc = typecheck.LookupRuntimeFunc("deferproc") ir.Syms.DeferprocStack = typecheck.LookupRuntimeFunc("deferprocStack") ir.Syms.Deferreturn = typecheck.LookupRuntimeFunc("deferreturn") @@ -366,6 +367,7 @@ func buildssa(fn *ir.Func, worker int) *ssa.Func { if fn.Pragma&ir.CgoUnsafeArgs != 0 { s.cgoUnsafeArgs = true } + s.checkPtrEnabled = ir.ShouldCheckPtr(fn, 1) fe := ssafn{ curfn: fn, @@ -709,6 +711,31 @@ func (s *state) newObject(typ *types.Type) *ssa.Value { return s.rtcall(ir.Syms.Newobject, true, []*types.Type{types.NewPtr(typ)}, s.reflectType(typ))[0] } +func (s *state) checkPtrAlignment(n *ir.ConvExpr, v *ssa.Value, count *ssa.Value) { + if !n.Type().IsPtr() { + s.Fatalf("expected pointer type: %v", n.Type()) + } + elem := n.Type().Elem() + if count != nil { + if !elem.IsArray() { + s.Fatalf("expected array type: %v", elem) + } + elem = elem.Elem() + } + size := elem.Size() + // Casting from larger type to smaller one is ok, so for smallest type, do nothing. + if elem.Alignment() == 1 && (size == 0 || size == 1 || count == nil) { + return + } + if count == nil { + count = s.constInt(types.Types[types.TUINTPTR], 1) + } + if count.Type.Size() != s.config.PtrSize { + s.Fatalf("expected count fit to an uintptr size, have: %d, want: %d", count.Type.Size(), s.config.PtrSize) + } + s.rtcall(ir.Syms.CheckPtrAlignment, true, nil, v, s.reflectType(elem), count) +} + // reflectType returns an SSA value representing a pointer to typ's // reflection type descriptor. func (s *state) reflectType(typ *types.Type) *ssa.Value { @@ -861,10 +888,11 @@ type state struct { // Used to deduplicate panic calls. panics map[funcLine]*ssa.Block - cgoUnsafeArgs bool - hasdefer bool // whether the function contains a defer statement - softFloat bool - hasOpenDefers bool // whether we are doing open-coded defers + cgoUnsafeArgs bool + hasdefer bool // whether the function contains a defer statement + softFloat bool + hasOpenDefers bool // whether we are doing open-coded defers + checkPtrEnabled bool // whether to insert checkptr instrumentation // If doing open-coded defers, list of info about the defer calls in // scanning order. Hence, at exit we should run these defers in reverse @@ -2494,6 +2522,10 @@ func (s *state) conv(n ir.Node, v *ssa.Value, ft, tt *types.Type) *ssa.Value { // expr converts the expression n to ssa, adds it to s and returns the ssa result. func (s *state) expr(n ir.Node) *ssa.Value { + return s.exprCheckPtr(n, true) +} + +func (s *state) exprCheckPtr(n ir.Node, checkPtrOK bool) *ssa.Value { if ir.HasUniquePos(n) { // ONAMEs and named OLITERALs have the line number // of the decl, not the use. See issue 14742. @@ -2641,6 +2673,9 @@ func (s *state) expr(n ir.Node) *ssa.Value { // unsafe.Pointer <--> *T if to.IsUnsafePtr() && from.IsPtrShaped() || from.IsUnsafePtr() && to.IsPtrShaped() { + if s.checkPtrEnabled && checkPtrOK && to.IsPtr() && from.IsUnsafePtr() { + s.checkPtrAlignment(n, v, nil) + } return v } @@ -3081,7 +3116,8 @@ func (s *state) expr(n ir.Node) *ssa.Value { case ir.OSLICE, ir.OSLICEARR, ir.OSLICE3, ir.OSLICE3ARR: n := n.(*ir.SliceExpr) - v := s.expr(n.X) + check := s.checkPtrEnabled && n.Op() == ir.OSLICE3ARR && n.X.Op() == ir.OCONVNOP && n.X.(*ir.ConvExpr).X.Type().IsUnsafePtr() + v := s.exprCheckPtr(n.X, !check) var i, j, k *ssa.Value if n.Low != nil { i = s.expr(n.Low) @@ -3093,8 +3129,9 @@ func (s *state) expr(n ir.Node) *ssa.Value { k = s.expr(n.Max) } p, l, c := s.slice(v, i, j, k, n.Bounded()) - if n.CheckPtrCall != nil { - s.stmt(n.CheckPtrCall) + if check { + // Emit checkptr instrumentation after bound check to prevent false positive, see #46938. + s.checkPtrAlignment(n.X.(*ir.ConvExpr), v, s.conv(n.Max, k, k.Type, types.Types[types.TUINTPTR])) } return s.newValue3(ssa.OpSliceMake, n.Type(), p, l, c) diff --git a/src/cmd/compile/internal/walk/convert.go b/src/cmd/compile/internal/walk/convert.go index d701d545de8..5d69fc38683 100644 --- a/src/cmd/compile/internal/walk/convert.go +++ b/src/cmd/compile/internal/walk/convert.go @@ -25,9 +25,6 @@ func walkConv(n *ir.ConvExpr, init *ir.Nodes) ir.Node { return n.X } if n.Op() == ir.OCONVNOP && ir.ShouldCheckPtr(ir.CurFunc, 1) { - if n.Type().IsPtr() && n.X.Type().IsUnsafePtr() { // unsafe.Pointer to *T - return walkCheckPtrAlignment(n, init, nil) - } if n.Type().IsUnsafePtr() && n.X.Type().IsUintptr() { // uintptr to unsafe.Pointer return walkCheckPtrArithmetic(n, init) } @@ -414,41 +411,6 @@ func byteindex(n ir.Node) ir.Node { return n } -func walkCheckPtrAlignment(n *ir.ConvExpr, init *ir.Nodes, se *ir.SliceExpr) ir.Node { - if !n.Type().IsPtr() { - base.Fatalf("expected pointer type: %v", n.Type()) - } - elem := n.Type().Elem() - var count ir.Node - if se != nil { - count = se.Max - } - if count != nil { - if !elem.IsArray() { - base.Fatalf("expected array type: %v", elem) - } - elem = elem.Elem() - } - - size := elem.Size() - if elem.Alignment() == 1 && (size == 0 || size == 1 && count == nil) { - return n - } - - if count == nil { - count = ir.NewInt(1) - } - - n.X = cheapExpr(n.X, init) - checkPtrCall := mkcall("checkptrAlignment", nil, init, typecheck.ConvNop(n.X, types.Types[types.TUNSAFEPTR]), reflectdata.TypePtr(elem), typecheck.Conv(count, types.Types[types.TUINTPTR])) - if se != nil { - se.CheckPtrCall = checkPtrCall - } else { - init.Append(checkPtrCall) - } - return n -} - func walkCheckPtrArithmetic(n *ir.ConvExpr, init *ir.Nodes) ir.Node { // Calling cheapExpr(n, init) below leads to a recursive call to // walkExpr, which leads us back here again. Use n.Checkptr to diff --git a/src/cmd/compile/internal/walk/expr.go b/src/cmd/compile/internal/walk/expr.go index ed2d68539dc..e5bf6cf0b5e 100644 --- a/src/cmd/compile/internal/walk/expr.go +++ b/src/cmd/compile/internal/walk/expr.go @@ -807,15 +807,7 @@ func walkSend(n *ir.SendStmt, init *ir.Nodes) ir.Node { // walkSlice walks an OSLICE, OSLICEARR, OSLICESTR, OSLICE3, or OSLICE3ARR node. func walkSlice(n *ir.SliceExpr, init *ir.Nodes) ir.Node { - - checkSlice := ir.ShouldCheckPtr(ir.CurFunc, 1) && n.Op() == ir.OSLICE3ARR && n.X.Op() == ir.OCONVNOP && n.X.(*ir.ConvExpr).X.Type().IsUnsafePtr() - if checkSlice { - conv := n.X.(*ir.ConvExpr) - conv.X = walkExpr(conv.X, init) - } else { - n.X = walkExpr(n.X, init) - } - + n.X = walkExpr(n.X, init) n.Low = walkExpr(n.Low, init) if n.Low != nil && ir.IsZero(n.Low) { // Reduce x[0:j] to x[:j] and x[0:j:k] to x[:j:k]. @@ -823,9 +815,6 @@ func walkSlice(n *ir.SliceExpr, init *ir.Nodes) ir.Node { } n.High = walkExpr(n.High, init) n.Max = walkExpr(n.Max, init) - if checkSlice { - n.X = walkCheckPtrAlignment(n.X.(*ir.ConvExpr), init, n) - } if n.Op().IsSlice3() { if n.Max != nil && n.Max.Op() == ir.OCAP && ir.SameSafeExpr(n.X, n.Max.(*ir.UnaryExpr).X) {