types/lazy: add (*SyncValue[T]).SetForTest method

It is sometimes necessary to change a global lazy.SyncValue for the duration of a test. This PR adds a (*SyncValue[T]).SetForTest method to facilitate that.

Updates #12687

Signed-off-by: Nick Khyl <nickk@tailscale.com>
marwan/offunc
Nick Khyl 4 months ago committed by Nick Khyl
parent d500a92926
commit 5d09649b0b

@ -154,3 +154,34 @@ func SyncFuncErr[T any](fill func() (T, error)) func() (T, error) {
return v, err return v, err
} }
} }
// TB is a subset of testing.TB that we use to set up test helpers.
// It's defined here to avoid pulling in the testing package.
type TB interface {
Helper()
Cleanup(func())
}
// SetForTest sets z's value and error.
// It's used in tests only and reverts z's state back when tb and all its
// subtests complete.
// It is not safe for concurrent use and must not be called concurrently with
// any SyncValue methods, including another call to itself.
func (z *SyncValue[T]) SetForTest(tb TB, val T, err error) {
tb.Helper()
z.once.Do(func() {})
oldErr, oldVal := z.err.Load(), z.v
z.v = val
if err != nil {
z.err.Store(ptr.To(err))
} else {
z.err.Store(nilErrPtr)
}
tb.Cleanup(func() {
z.v = oldVal
z.err.Store(oldErr)
})
}

@ -8,6 +8,8 @@ import (
"fmt" "fmt"
"sync" "sync"
"testing" "testing"
"tailscale.com/types/opt"
) )
func TestSyncValue(t *testing.T) { func TestSyncValue(t *testing.T) {
@ -147,6 +149,196 @@ func TestSyncValueConcurrent(t *testing.T) {
wg.Wait() wg.Wait()
} }
func TestSyncValueSetForTest(t *testing.T) {
testErr := errors.New("boom")
tests := []struct {
name string
initValue opt.Value[int]
initErr opt.Value[error]
setForTestValue int
setForTestErr error
getValue int
getErr opt.Value[error]
wantValue int
wantErr error
routines int
}{
{
name: "GetOk",
setForTestValue: 42,
getValue: 8,
wantValue: 42,
},
{
name: "GetOk/WithInit",
initValue: opt.ValueOf(4),
setForTestValue: 42,
getValue: 8,
wantValue: 42,
},
{
name: "GetOk/WithInitErr",
initValue: opt.ValueOf(4),
initErr: opt.ValueOf(errors.New("blast")),
setForTestValue: 42,
getValue: 8,
wantValue: 42,
},
{
name: "GetErr",
setForTestValue: 42,
setForTestErr: testErr,
getValue: 8,
getErr: opt.ValueOf(errors.New("ka-boom")),
wantValue: 42,
wantErr: testErr,
},
{
name: "GetErr/NilError",
setForTestValue: 42,
setForTestErr: nil,
getValue: 8,
getErr: opt.ValueOf(errors.New("ka-boom")),
wantValue: 42,
wantErr: nil,
},
{
name: "GetErr/WithInitErr",
initValue: opt.ValueOf(4),
initErr: opt.ValueOf(errors.New("blast")),
setForTestValue: 42,
setForTestErr: testErr,
getValue: 8,
getErr: opt.ValueOf(errors.New("ka-boom")),
wantValue: 42,
wantErr: testErr,
},
{
name: "Concurrent/GetOk",
setForTestValue: 42,
getValue: 8,
wantValue: 42,
routines: 10000,
},
{
name: "Concurrent/GetOk/WithInitErr",
initValue: opt.ValueOf(4),
initErr: opt.ValueOf(errors.New("blast")),
setForTestValue: 42,
getValue: 8,
wantValue: 42,
routines: 10000,
},
{
name: "Concurrent/GetErr",
setForTestValue: 42,
setForTestErr: testErr,
getValue: 8,
getErr: opt.ValueOf(errors.New("ka-boom")),
wantValue: 42,
wantErr: testErr,
routines: 10000,
},
{
name: "Concurrent/GetErr/WithInitErr",
initValue: opt.ValueOf(4),
initErr: opt.ValueOf(errors.New("blast")),
setForTestValue: 42,
setForTestErr: testErr,
getValue: 8,
getErr: opt.ValueOf(errors.New("ka-boom")),
wantValue: 42,
wantErr: testErr,
routines: 10000,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var v SyncValue[int]
// Initialize the sync value with the specified value and/or error,
// if required by the test.
if initValue, ok := tt.initValue.GetOk(); ok {
var wantInitErr, gotInitErr error
var wantInitValue, gotInitValue int
wantInitValue = initValue
if initErr, ok := tt.initErr.GetOk(); ok {
wantInitErr = initErr
gotInitValue, gotInitErr = v.GetErr(func() (int, error) { return initValue, initErr })
} else {
gotInitValue = v.Get(func() int { return initValue })
}
if gotInitErr != wantInitErr {
t.Fatalf("InitErr: got %v; want %v", gotInitErr, wantInitErr)
}
if gotInitValue != wantInitValue {
t.Fatalf("InitValue: got %v; want %v", gotInitValue, wantInitValue)
}
// Verify that SetForTest reverted the error and the value during the test cleanup.
t.Cleanup(func() {
wantCleanupValue, wantCleanupErr := wantInitValue, wantInitErr
gotCleanupValue, gotCleanupErr, ok := v.PeekErr()
if !ok {
t.Fatal("SyncValue is not set after cleanup")
}
if gotCleanupErr != wantCleanupErr {
t.Fatalf("CleanupErr: got %v; want %v", gotCleanupErr, wantCleanupErr)
}
if gotCleanupValue != wantCleanupValue {
t.Fatalf("CleanupValue: got %v; want %v", gotCleanupValue, wantCleanupValue)
}
})
}
// Set the test value and/or error.
v.SetForTest(t, tt.setForTestValue, tt.setForTestErr)
// Verify that the value and/or error have been set.
// This will run on either the current goroutine
// or concurrently depending on the tt.routines value.
checkSyncValue := func() {
var gotValue int
var gotErr error
if getErr, ok := tt.getErr.GetOk(); ok {
gotValue, gotErr = v.GetErr(func() (int, error) { return tt.getValue, getErr })
} else {
gotValue = v.Get(func() int { return tt.getValue })
}
if gotErr != tt.wantErr {
t.Errorf("Err: got %v; want %v", gotErr, tt.wantErr)
}
if gotValue != tt.wantValue {
t.Errorf("Value: got %v; want %v", gotValue, tt.wantValue)
}
}
switch tt.routines {
case 0:
checkSyncValue()
default:
var wg sync.WaitGroup
wg.Add(tt.routines)
start := make(chan struct{})
for range tt.routines {
go func() {
defer wg.Done()
// Every goroutine waits for the go signal, so that more of them
// have a chance to race on the initial Get than with sequential
// goroutine starts.
<-start
checkSyncValue()
}()
}
close(start)
wg.Wait()
}
})
}
}
func TestSyncFunc(t *testing.T) { func TestSyncFunc(t *testing.T) {
f := SyncFunc(fortyTwo) f := SyncFunc(fortyTwo)

Loading…
Cancel
Save