diff --git a/src/internal/sync/hashtriemap.go b/src/internal/sync/hashtriemap.go index 4a7ae07166f..73c8bba1e36 100644 --- a/src/internal/sync/hashtriemap.go +++ b/src/internal/sync/hashtriemap.go @@ -21,7 +21,7 @@ import ( type HashTrieMap[K comparable, V any] struct { inited atomic.Uint32 initMu Mutex - root *indirect[K, V] + root atomic.Pointer[indirect[K, V]] keyHash hashFunc valEqual equalFunc seed uintptr @@ -47,7 +47,7 @@ func (ht *HashTrieMap[K, V]) initSlow() { // equal function for the value, if any. var m map[K]V mapType := abi.TypeOf(m).MapType() - ht.root = newIndirectNode[K, V](nil) + ht.root.Store(newIndirectNode[K, V](nil)) ht.keyHash = mapType.Hasher ht.valEqual = mapType.Elem.Equal ht.seed = uintptr(runtime_rand()) @@ -65,7 +65,7 @@ func (ht *HashTrieMap[K, V]) Load(key K) (value V, ok bool) { ht.init() hash := ht.keyHash(abi.NoEscape(unsafe.Pointer(&key)), ht.seed) - i := ht.root + i := ht.root.Load() hashShift := 8 * goarch.PtrSize for hashShift != 0 { hashShift -= nChildrenLog2 @@ -94,7 +94,7 @@ func (ht *HashTrieMap[K, V]) LoadOrStore(key K, value V) (result V, loaded bool) var n *node[K, V] for { // Find the key or a candidate location for insertion. - i = ht.root + i = ht.root.Load() hashShift = 8 * goarch.PtrSize haveInsertPoint := false for hashShift != 0 { @@ -211,7 +211,7 @@ func (ht *HashTrieMap[K, V]) Swap(key K, new V) (previous V, loaded bool) { var n *node[K, V] for { // Find the key or a candidate location for insertion. - i = ht.root + i = ht.root.Load() hashShift = 8 * goarch.PtrSize haveInsertPoint := false for hashShift != 0 { @@ -426,7 +426,7 @@ func (ht *HashTrieMap[K, V]) CompareAndDelete(key K, old V) (deleted bool) { func (ht *HashTrieMap[K, V]) find(key K, hash uintptr, valEqual equalFunc, value V) (i *indirect[K, V], hashShift uint, slot *atomic.Pointer[node[K, V]], n *node[K, V]) { for { // Find the key or return if it's not there. - i = ht.root + i = ht.root.Load() hashShift = 8 * goarch.PtrSize found := false for hashShift != 0 { @@ -470,15 +470,18 @@ func (ht *HashTrieMap[K, V]) find(key K, hash uintptr, valEqual equalFunc, value } } -// All returns an iter.Seq2 that produces all key-value pairs in the map. -// The enumeration does not represent any consistent snapshot of the map, -// but is guaranteed to visit each unique key-value pair only once. It is -// safe to operate on the tree during iteration. No particular enumeration -// order is guaranteed. +// All returns an iterator over each key and value present in the map. +// +// The iterator does not necessarily correspond to any consistent snapshot of the +// HashTrieMap's contents: no key will be visited more than once, but if the value +// for any key is stored or deleted concurrently (including by yield), the iterator +// may reflect any mapping for that key from any point during iteration. The iterator +// does not block other methods on the receiver; even yield itself may call any +// method on the HashTrieMap. func (ht *HashTrieMap[K, V]) All() func(yield func(K, V) bool) { ht.init() return func(yield func(key K, value V) bool) { - ht.iter(ht.root, yield) + ht.iter(ht.root.Load(), yield) } } @@ -505,6 +508,15 @@ func (ht *HashTrieMap[K, V]) iter(i *indirect[K, V], yield func(key K, value V) return true } +// Clear deletes all the entries, resulting in an empty HashTrieMap. +func (ht *HashTrieMap[K, V]) Clear() { + ht.init() + + // It's sufficient to just drop the root on the floor, but the root + // must always be non-nil. + ht.root.Store(newIndirectNode[K, V](nil)) +} + const ( // 16 children. This seems to be the sweet spot for // load performance: any smaller and we lose out on diff --git a/src/internal/sync/hashtriemap_test.go b/src/internal/sync/hashtriemap_test.go index 12e3ee6091d..5476add8803 100644 --- a/src/internal/sync/hashtriemap_test.go +++ b/src/internal/sync/hashtriemap_test.go @@ -66,6 +66,57 @@ func testHashTrieMap(t *testing.T, newMap func() *isync.HashTrieMap[string, int] return true }) }) + t.Run("Clear", func(t *testing.T) { + t.Run("Simple", func(t *testing.T) { + m := newMap() + + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + expectPresent(t, s, i)(m.Load(s)) + expectLoaded(t, s, i)(m.LoadOrStore(s, 0)) + } + m.Clear() + for _, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + } + }) + t.Run("Concurrent", func(t *testing.T) { + m := newMap() + + // Load up the map. + for i, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + expectStored(t, s, i)(m.LoadOrStore(s, i)) + } + gmp := runtime.GOMAXPROCS(-1) + var wg sync.WaitGroup + for i := range gmp { + wg.Add(1) + go func(id int) { + defer wg.Done() + + for _, s := range testData { + // Try a couple things to interfere with the clear. + expectNotDeleted(t, s, math.MaxInt)(m.CompareAndDelete(s, math.MaxInt)) + m.CompareAndSwap(s, i, i+1) // May succeed or fail; we don't care. + } + }(i) + } + + // Concurrently clear the map. + runtime.Gosched() + m.Clear() + + // Wait for workers to finish. + wg.Wait() + + // It should all be empty now. + for _, s := range testData { + expectMissing(t, s, 0)(m.Load(s)) + } + }) + }) t.Run("CompareAndDelete", func(t *testing.T) { t.Run("All", func(t *testing.T) { m := newMap()