diff --git a/syncs/watchdog.go b/syncs/watchdog.go new file mode 100644 index 000000000..36dcb0758 --- /dev/null +++ b/syncs/watchdog.go @@ -0,0 +1,95 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package syncs + +import ( + "context" + "sync" + "time" +) + +// Watch monitors mu for contention. +// On first call, and at every tick, Watch locks and unlocks mu. +// (Tick should be large to avoid adding contention to mu.) +// Max is the maximum length of time Watch will wait to acquire the lock. +// The time required to lock mu is sent on the returned channel. +// Watch exits when ctx is done, and closes the returned channel. +func Watch(ctx context.Context, mu sync.Locker, tick, max time.Duration) chan time.Duration { + // Set up the return channel. + c := make(chan time.Duration) + var ( + closemu sync.Mutex + closed bool + ) + sendc := func(d time.Duration) { + closemu.Lock() + defer closemu.Unlock() + if closed { + // Drop values written after c is closed. + return + } + c <- d + } + closec := func() { + closemu.Lock() + defer closemu.Unlock() + close(c) + closed = true + } + + // check locks the mutex and writes how long it took to c. + // check returns ~immediately. + check := func() { + // Start a race between two goroutines. + // One locks the mutex; the other times out. + // Ensure that only one of the two gets to write its result. + // Since the common case is that locking the mutex is fast, + // let the timeout goroutine exit early when that happens. + var sendonce sync.Once + done := make(chan bool) + go func() { + start := time.Now() + mu.Lock() + mu.Unlock() //lint:ignore SA2001 ignore the empty critical section + elapsed := time.Since(start) + if elapsed > max { + elapsed = max + } + close(done) + sendonce.Do(func() { sendc(elapsed) }) + }() + go func() { + select { + case <-time.After(max): + // the other goroutine may not have sent a value + sendonce.Do(func() { sendc(max) }) + case <-done: + // the other goroutine sent a value + } + }() + } + + // Check once at startup. + // This is mainly to make testing easier. + check() + + // Start the watchdog goroutine. + // It checks the mutex every tick, until ctx is done. + go func() { + ticker := time.NewTicker(tick) + for { + select { + case <-ctx.Done(): + closec() + ticker.Stop() + return + case <-ticker.C: + check() + } + } + }() + + return c +} diff --git a/syncs/watchdog_test.go b/syncs/watchdog_test.go new file mode 100644 index 000000000..b5cc3452e --- /dev/null +++ b/syncs/watchdog_test.go @@ -0,0 +1,71 @@ +// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package syncs + +import ( + "context" + "sync" + "testing" + "time" +) + +// Time-based tests are fundamentally flaky. +// We use exaggerated durations in the hopes of minimizing such issues. + +func TestWatchUncontended(t *testing.T) { + mu := new(sync.Mutex) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Once an hour, and now, check whether we can lock mu in under an hour. + tick := time.Hour + max := time.Hour + c := Watch(ctx, mu, tick, max) + d := <-c + if d == max { + t.Errorf("uncontended mutex did not lock in under %v", max) + } +} + +func TestWatchContended(t *testing.T) { + mu := new(sync.Mutex) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + // Every hour, and now, check whether we can lock mu in under a millisecond, + // which is enough time for an uncontended mutex by several orders of magnitude. + tick := time.Hour + max := time.Millisecond + mu.Lock() + defer mu.Unlock() + c := Watch(ctx, mu, tick, max) + d := <-c + if d != max { + t.Errorf("contended mutex locked in under %v", max) + } +} + +func TestWatchMultipleValues(t *testing.T) { + mu := new(sync.Mutex) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // not necessary, but keep vet happy + // Check the mutex every millisecond. + // The goal is to see that we get a sufficient number of values out of the channel. + tick := time.Millisecond + max := time.Millisecond + c := Watch(ctx, mu, tick, max) + start := time.Now() + n := 0 + for d := range c { + n++ + if d == max { + t.Errorf("uncontended mutex did not lock in under %v", max) + } + if n == 10 { + cancel() + } + } + if elapsed := time.Since(start); elapsed > 100*time.Millisecond { + t.Errorf("expected 1 event per millisecond, got only %v events in %v", n, elapsed) + } +}