diff --git a/syncs/syncs.go b/syncs/syncs.go index 66ca69b60..c3f729a90 100644 --- a/syncs/syncs.go +++ b/syncs/syncs.go @@ -23,11 +23,18 @@ func initClosedChan() <-chan struct{} { return ch } -// AtomicValue is the generic version of atomic.Value. +// AtomicValue is the generic version of [atomic.Value]. type AtomicValue[T any] struct { v atomic.Value } +// wrappedValue is used to wrap a value T in a concrete type, +// otherwise atomic.Value.Store may panic due to mismatching types in interfaces. +// This wrapping is not necessary for non-interface kinds of T, +// but there is no harm in wrapping anyways. +// See https://cs.opensource.google/go/go/+/refs/tags/go1.22.2:src/sync/atomic/value.go;l=78 +type wrappedValue[T any] struct{ v T } + // Load returns the value set by the most recent Store. // It returns the zero value for T if the value is empty. func (v *AtomicValue[T]) Load() T { @@ -40,7 +47,7 @@ func (v *AtomicValue[T]) Load() T { func (v *AtomicValue[T]) LoadOk() (_ T, ok bool) { x := v.v.Load() if x != nil { - return x.(T), true + return x.(wrappedValue[T]).v, true } var zero T return zero, false @@ -48,22 +55,22 @@ func (v *AtomicValue[T]) LoadOk() (_ T, ok bool) { // Store sets the value of the Value to x. func (v *AtomicValue[T]) Store(x T) { - v.v.Store(x) + v.v.Store(wrappedValue[T]{x}) } // Swap stores new into Value and returns the previous value. // It returns the zero value for T if the value is empty. func (v *AtomicValue[T]) Swap(x T) (old T) { - oldV := v.v.Swap(x) + oldV := v.v.Swap(wrappedValue[T]{x}) if oldV != nil { - return oldV.(T) + return oldV.(wrappedValue[T]).v } return old } // CompareAndSwap executes the compare-and-swap operation for the Value. func (v *AtomicValue[T]) CompareAndSwap(oldV, newV T) (swapped bool) { - return v.v.CompareAndSwap(oldV, newV) + return v.v.CompareAndSwap(wrappedValue[T]{oldV}, wrappedValue[T]{newV}) } // WaitGroupChan is like a sync.WaitGroup, but has a chan that closes diff --git a/syncs/syncs_test.go b/syncs/syncs_test.go index a9ed67a41..424d51794 100644 --- a/syncs/syncs_test.go +++ b/syncs/syncs_test.go @@ -5,12 +5,67 @@ package syncs import ( "context" + "io" + "os" "sync" "testing" "github.com/google/go-cmp/cmp" ) +func TestAtomicValue(t *testing.T) { + { + // Always wrapping should not allocate for simple values + // because wrappedValue[T] has the same memory layout as T. + var v AtomicValue[bool] + bools := []bool{true, false} + if n := int(testing.AllocsPerRun(1000, func() { + for _, b := range bools { + v.Store(b) + } + })); n != 0 { + t.Errorf("AllocsPerRun = %d, want 0", n) + } + } + + { + var v AtomicValue[int] + got, gotOk := v.LoadOk() + if got != 0 || gotOk { + t.Fatalf("LoadOk = (%v, %v), want (0, false)", got, gotOk) + } + v.Store(1) + got, gotOk = v.LoadOk() + if got != 1 || !gotOk { + t.Fatalf("LoadOk = (%v, %v), want (1, true)", got, gotOk) + } + } + + { + var v AtomicValue[error] + got, gotOk := v.LoadOk() + if got != nil || gotOk { + t.Fatalf("LoadOk = (%v, %v), want (nil, false)", got, gotOk) + } + v.Store(io.EOF) + got, gotOk = v.LoadOk() + if got != io.EOF || !gotOk { + t.Fatalf("LoadOk = (%v, %v), want (EOF, true)", got, gotOk) + } + err := &os.PathError{} + v.Store(err) + got, gotOk = v.LoadOk() + if got != err || !gotOk { + t.Fatalf("LoadOk = (%v, %v), want (%v, true)", got, gotOk, err) + } + v.Store(nil) + got, gotOk = v.LoadOk() + if got != nil || !gotOk { + t.Fatalf("LoadOk = (%v, %v), want (nil, true)", got, gotOk) + } + } +} + func TestWaitGroupChan(t *testing.T) { wg := NewWaitGroupChan()