From cafd9a2bec3c36a22578a901f97c781c5a74a42f Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Wed, 28 Jun 2023 08:06:21 -0700 Subject: [PATCH] syncs: add ShardedMap.Mutate To let callers do atomic/CAS-like operations. Updates tailscale/corp#7355 Signed-off-by: Brad Fitzpatrick --- syncs/shardedmap.go | 29 ++++++++++++++++++++++++++++- syncs/shardedmap_test.go | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/syncs/shardedmap.go b/syncs/shardedmap.go index 00ce3aafa..12edf5bfc 100644 --- a/syncs/shardedmap.go +++ b/syncs/shardedmap.go @@ -59,9 +59,36 @@ func (m *ShardedMap[K, V]) Get(key K) (value V) { return } +// Mutate atomically mutates m[k] by calling mutator. +// +// The mutator function is called with the old value (or its zero value) and +// whether it existed in the map and it returns the new value and whether it +// should be set in the map (true) or deleted from the map (false). +// +// It returns the change in size of the map as a result of the mutation, one of +// -1 (delete), 0 (change), or 1 (addition). +func (m *ShardedMap[K, V]) Mutate(key K, mutator func(oldValue V, oldValueExisted bool) (newValue V, keep bool)) (sizeDelta int) { + shard := m.shard(key) + shard.mu.Lock() + defer shard.mu.Unlock() + oldV, oldOK := shard.m[key] + newV, newOK := mutator(oldV, oldOK) + if newOK { + shard.m[key] = newV + if oldOK { + return 0 + } + return 1 + } + delete(shard.m, key) + if oldOK { + return -1 + } + return 0 +} + // Set sets m[key] = value. // -// It reports whether the map grew in size (that is, whether key was not already // present in m). func (m *ShardedMap[K, V]) Set(key K, value V) (grew bool) { shard := m.shard(key) diff --git a/syncs/shardedmap_test.go b/syncs/shardedmap_test.go index b09a268d7..993ffdff8 100644 --- a/syncs/shardedmap_test.go +++ b/syncs/shardedmap_test.go @@ -41,4 +41,41 @@ func TestShardedMap(t *testing.T) { if g, w := m.Len(), 0; g != w { t.Errorf("got Len %v; want %v", g, w) } + + // Mutation adding an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if ok { + t.Fatal("was okay") + } + return "ONE", true + }); v != 1 { + t.Errorf("Mutate = %v; want 1", v) + } + if g, w := m.Get(1), "ONE"; g != w { + t.Errorf("got %q; want %q", g, w) + } + // Mutation changing an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if !ok { + t.Fatal("wasn't okay") + } + return was + "-" + was, true + }); v != 0 { + t.Errorf("Mutate = %v; want 0", v) + } + if g, w := m.Get(1), "ONE-ONE"; g != w { + t.Errorf("got %q; want %q", g, w) + } + // Mutation removing an entry. + if v := m.Mutate(1, func(was string, ok bool) (string, bool) { + if !ok { + t.Fatal("wasn't okay") + } + return "", false + }); v != -1 { + t.Errorf("Mutate = %v; want -1", v) + } + if g, w := m.Get(1), ""; g != w { + t.Errorf("got %q; want %q", g, w) + } }