1
0
mirror of https://github.com/golang/go synced 2024-11-26 17:16:54 -07:00

sync/atomic: add (*Value).Swap and (*Value).CompareAndSwap

The functions SwapPointer and CompareAndSwapPointer can be used to
interact with unsafe.Pointer, however generally it is prefered to work
with Value, due to its safer interface. As such, they have been added
along with glue logic to maintain invariants Value guarantees.

To meet these guarantees, the current implementation duplicates much of
the Store function. Some of this is due to inexperience with concurrency
and desire for correctness, but the lack of generic programming
functionality does not help.

Fixes #39351

Change-Id: I1aa394b1e70944736ac1e19de49fe861e1e46fba
Reviewed-on: https://go-review.googlesource.com/c/go/+/241678
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Keith Randall <khr@golang.org>
This commit is contained in:
Colin Arnott 2020-07-09 07:06:46 +00:00 committed by Keith Randall
parent 496d7c6914
commit 2422c5eae5
2 changed files with 257 additions and 11 deletions

View File

@ -25,7 +25,7 @@ type ifaceWords struct {
// Load returns the value set by the most recent Store. // Load returns the value set by the most recent Store.
// It returns nil if there has been no call to Store for this Value. // It returns nil if there has been no call to Store for this Value.
func (v *Value) Load() (x interface{}) { func (v *Value) Load() (val interface{}) {
vp := (*ifaceWords)(unsafe.Pointer(v)) vp := (*ifaceWords)(unsafe.Pointer(v))
typ := LoadPointer(&vp.typ) typ := LoadPointer(&vp.typ)
if typ == nil || uintptr(typ) == ^uintptr(0) { if typ == nil || uintptr(typ) == ^uintptr(0) {
@ -33,21 +33,21 @@ func (v *Value) Load() (x interface{}) {
return nil return nil
} }
data := LoadPointer(&vp.data) data := LoadPointer(&vp.data)
xp := (*ifaceWords)(unsafe.Pointer(&x)) vlp := (*ifaceWords)(unsafe.Pointer(&val))
xp.typ = typ vlp.typ = typ
xp.data = data vlp.data = data
return return
} }
// Store sets the value of the Value to x. // Store sets the value of the Value to x.
// All calls to Store for a given Value must use values of the same concrete type. // All calls to Store for a given Value must use values of the same concrete type.
// Store of an inconsistent type panics, as does Store(nil). // Store of an inconsistent type panics, as does Store(nil).
func (v *Value) Store(x interface{}) { func (v *Value) Store(val interface{}) {
if x == nil { if val == nil {
panic("sync/atomic: store of nil value into Value") panic("sync/atomic: store of nil value into Value")
} }
vp := (*ifaceWords)(unsafe.Pointer(v)) vp := (*ifaceWords)(unsafe.Pointer(v))
xp := (*ifaceWords)(unsafe.Pointer(&x)) vlp := (*ifaceWords)(unsafe.Pointer(&val))
for { for {
typ := LoadPointer(&vp.typ) typ := LoadPointer(&vp.typ)
if typ == nil { if typ == nil {
@ -61,8 +61,8 @@ func (v *Value) Store(x interface{}) {
continue continue
} }
// Complete first store. // Complete first store.
StorePointer(&vp.data, xp.data) StorePointer(&vp.data, vlp.data)
StorePointer(&vp.typ, xp.typ) StorePointer(&vp.typ, vlp.typ)
runtime_procUnpin() runtime_procUnpin()
return return
} }
@ -73,14 +73,121 @@ func (v *Value) Store(x interface{}) {
continue continue
} }
// First store completed. Check type and overwrite data. // First store completed. Check type and overwrite data.
if typ != xp.typ { if typ != vlp.typ {
panic("sync/atomic: store of inconsistently typed value into Value") panic("sync/atomic: store of inconsistently typed value into Value")
} }
StorePointer(&vp.data, xp.data) StorePointer(&vp.data, vlp.data)
return return
} }
} }
// Swap stores new into Value and returns the previous value. It returns nil if
// the Value is empty.
//
// All calls to Swap for a given Value must use values of the same concrete
// type. Swap of an inconsistent type panics, as does Swap(nil).
func (v *Value) Swap(new interface{}) (old interface{}) {
if new == nil {
panic("sync/atomic: swap of nil value into Value")
}
vp := (*ifaceWords)(unsafe.Pointer(v))
np := (*ifaceWords)(unsafe.Pointer(&new))
for {
typ := LoadPointer(&vp.typ)
if typ == nil {
// Attempt to start first store.
// Disable preemption so that other goroutines can use
// active spin wait to wait for completion; and so that
// GC does not see the fake type accidentally.
runtime_procPin()
if !CompareAndSwapPointer(&vp.typ, nil, unsafe.Pointer(^uintptr(0))) {
runtime_procUnpin()
continue
}
// Complete first store.
StorePointer(&vp.data, np.data)
StorePointer(&vp.typ, np.typ)
runtime_procUnpin()
return nil
}
if uintptr(typ) == ^uintptr(0) {
// First store in progress. Wait.
// Since we disable preemption around the first store,
// we can wait with active spinning.
continue
}
// First store completed. Check type and overwrite data.
if typ != np.typ {
panic("sync/atomic: swap of inconsistently typed value into Value")
}
op := (*ifaceWords)(unsafe.Pointer(&old))
op.typ, op.data = np.typ, SwapPointer(&vp.data, np.data)
return old
}
}
// CompareAndSwapPointer executes the compare-and-swap operation for the Value.
//
// All calls to CompareAndSwap for a given Value must use values of the same
// concrete type. CompareAndSwap of an inconsistent type panics, as does
// CompareAndSwap(old, nil).
func (v *Value) CompareAndSwap(old, new interface{}) (swapped bool) {
if new == nil {
panic("sync/atomic: compare and swap of nil value into Value")
}
vp := (*ifaceWords)(unsafe.Pointer(v))
np := (*ifaceWords)(unsafe.Pointer(&new))
op := (*ifaceWords)(unsafe.Pointer(&old))
if op.typ != nil && np.typ != op.typ {
panic("sync/atomic: compare and swap of inconsistently typed values")
}
for {
typ := LoadPointer(&vp.typ)
if typ == nil {
if old != nil {
return false
}
// Attempt to start first store.
// Disable preemption so that other goroutines can use
// active spin wait to wait for completion; and so that
// GC does not see the fake type accidentally.
runtime_procPin()
if !CompareAndSwapPointer(&vp.typ, nil, unsafe.Pointer(^uintptr(0))) {
runtime_procUnpin()
continue
}
// Complete first store.
StorePointer(&vp.data, np.data)
StorePointer(&vp.typ, np.typ)
runtime_procUnpin()
return true
}
if uintptr(typ) == ^uintptr(0) {
// First store in progress. Wait.
// Since we disable preemption around the first store,
// we can wait with active spinning.
continue
}
// First store completed. Check type and overwrite data.
if typ != np.typ {
panic("sync/atomic: compare and swap of inconsistently typed value into Value")
}
// Compare old and current via runtime equality check.
// This allows value types to be compared, something
// not offered by the package functions.
// CompareAndSwapPointer below only ensures vp.data
// has not changed since LoadPointer.
data := LoadPointer(&vp.data)
var i interface{}
(*ifaceWords)(unsafe.Pointer(&i)).typ = typ
(*ifaceWords)(unsafe.Pointer(&i)).data = data
if i != old {
return false
}
return CompareAndSwapPointer(&vp.data, data, np.data)
}
}
// Disable/enable preemption, implemented in runtime. // Disable/enable preemption, implemented in runtime.
func runtime_procPin() func runtime_procPin()
func runtime_procUnpin() func runtime_procUnpin()

View File

@ -7,6 +7,9 @@ package atomic_test
import ( import (
"math/rand" "math/rand"
"runtime" "runtime"
"strconv"
"sync"
"sync/atomic"
. "sync/atomic" . "sync/atomic"
"testing" "testing"
) )
@ -133,3 +136,139 @@ func BenchmarkValueRead(b *testing.B) {
} }
}) })
} }
var Value_SwapTests = []struct {
init interface{}
new interface{}
want interface{}
err interface{}
}{
{init: nil, new: nil, err: "sync/atomic: swap of nil value into Value"},
{init: nil, new: true, want: nil, err: nil},
{init: true, new: "", err: "sync/atomic: swap of inconsistently typed value into Value"},
{init: true, new: false, want: true, err: nil},
}
func TestValue_Swap(t *testing.T) {
for i, tt := range Value_SwapTests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
var v Value
if tt.init != nil {
v.Store(tt.init)
}
defer func() {
err := recover()
switch {
case tt.err == nil && err != nil:
t.Errorf("should not panic, got %v", err)
case tt.err != nil && err == nil:
t.Errorf("should panic %v, got <nil>", tt.err)
}
}()
if got := v.Swap(tt.new); got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
if got := v.Load(); got != tt.new {
t.Errorf("got %v, want %v", got, tt.new)
}
})
}
}
func TestValueSwapConcurrent(t *testing.T) {
var v Value
var count uint64
var g sync.WaitGroup
var m, n uint64 = 10000, 10000
if testing.Short() {
m = 1000
n = 1000
}
for i := uint64(0); i < m*n; i += n {
i := i
g.Add(1)
go func() {
var c uint64
for new := i; new < i+n; new++ {
if old := v.Swap(new); old != nil {
c += old.(uint64)
}
}
atomic.AddUint64(&count, c)
g.Done()
}()
}
g.Wait()
if want, got := (m*n-1)*(m*n)/2, count+v.Load().(uint64); got != want {
t.Errorf("sum from 0 to %d was %d, want %v", m*n-1, got, want)
}
}
var heapA, heapB = struct{ uint }{0}, struct{ uint }{0}
var Value_CompareAndSwapTests = []struct {
init interface{}
new interface{}
old interface{}
want bool
err interface{}
}{
{init: nil, new: nil, old: nil, err: "sync/atomic: compare and swap of nil value into Value"},
{init: nil, new: true, old: "", err: "sync/atomic: compare and swap of inconsistently typed values into Value"},
{init: nil, new: true, old: true, want: false, err: nil},
{init: nil, new: true, old: nil, want: true, err: nil},
{init: true, new: "", err: "sync/atomic: compare and swap of inconsistently typed value into Value"},
{init: true, new: true, old: false, want: false, err: nil},
{init: true, new: true, old: true, want: true, err: nil},
{init: heapA, new: struct{ uint }{1}, old: heapB, want: true, err: nil},
}
func TestValue_CompareAndSwap(t *testing.T) {
for i, tt := range Value_CompareAndSwapTests {
t.Run(strconv.Itoa(i), func(t *testing.T) {
var v Value
if tt.init != nil {
v.Store(tt.init)
}
defer func() {
err := recover()
switch {
case tt.err == nil && err != nil:
t.Errorf("got %v, wanted no panic", err)
case tt.err != nil && err == nil:
t.Errorf("did not panic, want %v", tt.err)
}
}()
if got := v.CompareAndSwap(tt.old, tt.new); got != tt.want {
t.Errorf("got %v, want %v", got, tt.want)
}
})
}
}
func TestValueCompareAndSwapConcurrent(t *testing.T) {
var v Value
var w sync.WaitGroup
v.Store(0)
m, n := 1000, 100
if testing.Short() {
m = 100
n = 100
}
for i := 0; i < m; i++ {
i := i
w.Add(1)
go func() {
for j := i; j < m*n; runtime.Gosched() {
if v.CompareAndSwap(j, j+1) {
j += m
}
}
w.Done()
}()
}
w.Wait()
if stop := v.Load().(int); stop != m*n {
t.Errorf("did not get to %v, stopped at %v", m*n, stop)
}
}