diff --git a/util/cache/cache_test.go b/util/cache/cache_test.go
new file mode 100644
index 000000000..a6683e12d
--- /dev/null
+++ b/util/cache/cache_test.go
@@ -0,0 +1,199 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package cache
+
+import (
+	"errors"
+	"testing"
+	"time"
+)
+
+var startTime = time.Date(2023, time.March, 1, 0, 0, 0, 0, time.UTC)
+
+func TestSingleCache(t *testing.T) {
+	testTime := startTime
+	timeNow := func() time.Time { return testTime }
+	c := &Single[string, int]{
+		timeNow: timeNow,
+	}
+
+	t.Run("NoServeExpired", func(t *testing.T) {
+		testCacheImpl(t, c, &testTime, false)
+	})
+
+	t.Run("ServeExpired", func(t *testing.T) {
+		c.Empty()
+		c.ServeExpired = true
+		testTime = startTime
+		testCacheImpl(t, c, &testTime, true)
+	})
+}
+
+func TestLocking(t *testing.T) {
+	testTime := startTime
+	timeNow := func() time.Time { return testTime }
+	c := NewLocking(&Single[string, int]{
+		timeNow: timeNow,
+	})
+
+	// Just verify that the inner cache's behaviour hasn't changed.
+	testCacheImpl(t, c, &testTime, false)
+}
+
+func testCacheImpl(t *testing.T, c Cache[string, int], testTime *time.Time, serveExpired bool) {
+	var fillTime time.Time
+	t.Run("InitialFill", func(t *testing.T) {
+		fillTime = testTime.Add(time.Hour)
+		val, err := c.Get("key", func() (int, time.Time, error) {
+			return 123, fillTime, nil
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if val != 123 {
+			t.Fatalf("got val=%d; want 123", val)
+		}
+	})
+
+	// Fetching again won't call our fill function
+	t.Run("SecondFetch", func(t *testing.T) {
+		*testTime = fillTime.Add(-1 * time.Second)
+		called := false
+		val, err := c.Get("key", func() (int, time.Time, error) {
+			called = true
+			return -1, fillTime, nil
+		})
+		if called {
+			t.Fatal("wanted no call to fill function")
+		}
+		if err != nil {
+			t.Fatal(err)
+		}
+		if val != 123 {
+			t.Fatalf("got val=%d; want 123", val)
+		}
+	})
+
+	// Fetching after the expiry time will re-fill
+	t.Run("ReFill", func(t *testing.T) {
+		*testTime = fillTime.Add(1)
+		fillTime = fillTime.Add(time.Hour)
+		val, err := c.Get("key", func() (int, time.Time, error) {
+			return 999, fillTime, nil
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if val != 999 {
+			t.Fatalf("got val=%d; want 999", val)
+		}
+	})
+
+	// An error on fetch will serve the expired value.
+	t.Run("FetchError", func(t *testing.T) {
+		if !serveExpired {
+			t.Skipf("not testing ServeExpired")
+		}
+
+		*testTime = fillTime.Add(time.Hour + 1)
+		val, err := c.Get("key", func() (int, time.Time, error) {
+			return 0, time.Time{}, errors.New("some error")
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if val != 999 {
+			t.Fatalf("got val=%d; want 999", val)
+		}
+	})
+
+	// Fetching a different key re-fills
+	t.Run("DifferentKey", func(t *testing.T) {
+		*testTime = fillTime.Add(time.Hour + 1)
+
+		var calls int
+		val, err := c.Get("key1", func() (int, time.Time, error) {
+			calls++
+			return 123, fillTime, nil
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if val != 123 {
+			t.Fatalf("got val=%d; want 123", val)
+		}
+		if calls != 1 {
+			t.Errorf("got %d, want 1 call", calls)
+		}
+
+		val, err = c.Get("key2", func() (int, time.Time, error) {
+			calls++
+			return 456, fillTime, nil
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if val != 456 {
+			t.Fatalf("got val=%d; want 456", val)
+		}
+		if calls != 2 {
+			t.Errorf("got %d, want 2 calls", calls)
+		}
+	})
+
+	// Calling Forget with the wrong key does nothing, and with the correct
+	// key will drop the cache.
+	t.Run("Forget", func(t *testing.T) {
+		// Add some time so that previously-cached values don't matter.
+		fillTime = testTime.Add(2 * time.Hour)
+		*testTime = fillTime.Add(-1 * time.Second)
+
+		const key = "key"
+
+		var calls int
+		val, err := c.Get(key, func() (int, time.Time, error) {
+			calls++
+			return 123, fillTime, nil
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if val != 123 {
+			t.Fatalf("got val=%d; want 123", val)
+		}
+		if calls != 1 {
+			t.Errorf("got %d, want 1 call", calls)
+		}
+
+		// Forgetting the wrong key does nothing
+		c.Forget("other")
+		val, err = c.Get(key, func() (int, time.Time, error) {
+			t.Fatal("should not be called")
+			panic("unreachable")
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if val != 123 {
+			t.Fatalf("got val=%d; want 123", val)
+		}
+
+		// Forgetting the correct key re-fills
+		c.Forget(key)
+
+		val, err = c.Get(key, func() (int, time.Time, error) {
+			calls++
+			return 456, fillTime, nil
+		})
+		if err != nil {
+			t.Fatal(err)
+		}
+		if val != 456 {
+			t.Fatalf("got val=%d; want 456", val)
+		}
+		if calls != 2 {
+			t.Errorf("got %d, want 2 calls", calls)
+		}
+	})
+}
diff --git a/util/cache/interface.go b/util/cache/interface.go
new file mode 100644
index 000000000..0db87ba0e
--- /dev/null
+++ b/util/cache/interface.go
@@ -0,0 +1,40 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+// Package cache contains an interface for a cache around a typed value, and
+// various cache implementations that implement that interface.
+package cache
+
+import "time"
+
+// Cache is the interface for the cache types in this package.
+//
+// Functions in this interface take a key parameter, but it is valid for a
+// cache type to hold a single value associated with a key, and simply drop the
+// cached value if provided with a different key.
+//
+// It is valid for Cache implementations to be concurrency-safe or not, and
+// each implementation should document this. If you need a concurrency-safe
+// cache, an existing cache can be wrapped with a lock using NewLocking(inner).
+//
+// K and V should be types that can be successfully passed to json.Marshal.
+type Cache[K comparable, V any] interface {
+	// Get should return a previously-cached value or call the provided
+	// FillFunc to obtain a new one. The provided key can be used either to
+	// allow multiple cached values, or to drop the cache if the key
+	// changes; either is valid.
+	Get(K, FillFunc[V]) (V, error)
+
+	// Forget should remove the given key from the cache, if it is present.
+	// If it is not present, nothing should be done.
+	Forget(K)
+
+	// Empty should empty the cache such that the next call to Get will
+	// call the provided FillFunc for all possible keys.
+	Empty()
+}
+
+// FillFunc is the signature of a function for filling a cache. It should
+// return the value to be cached and the time until which that value is
+// valid, or an error.
+type FillFunc[T any] func() (T, time.Time, error)
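
A minimal usage sketch for the interface above, assuming this package is imported as tailscale.com/util/cache; the slowLookup helper and the one-minute TTL are invented for the illustration. The FillFunc wraps the expensive lookup and chooses its own expiry, and callers only ever go through Cache.Get.

package main

import (
	"fmt"
	"time"

	"tailscale.com/util/cache" // assumed import path for this package
)

// slowLookup stands in for whatever expensive operation is being cached.
func slowLookup(key string) (string, error) {
	return "value-for-" + key, nil
}

func main() {
	// *Single satisfies Cache; any other implementation could be swapped in.
	var c cache.Cache[string, string] = &cache.Single[string, string]{}

	// The FillFunc returns the value, the time until which it stays valid,
	// and any error from the underlying lookup.
	fill := func() (string, time.Time, error) {
		v, err := slowLookup("config")
		if err != nil {
			return "", time.Time{}, err
		}
		return v, time.Now().Add(time.Minute), nil
	}

	// The first call invokes fill; later calls within the minute return the
	// cached value without calling it again.
	v, err := c.Get("config", fill)
	fmt.Println(v, err)
}

Because Get takes the FillFunc on every call, the caller decides the expiry at fill time rather than the cache being configured with a fixed TTL.
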
diff --git a/util/cache/locking.go b/util/cache/locking.go
new file mode 100644
index 000000000..85e44b360
--- /dev/null
+++ b/util/cache/locking.go
@@ -0,0 +1,43 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package cache
+
+import "sync"
+
+// Locking wraps an inner Cache implementation with a mutex, making it
+// safe for concurrent use. All methods are serialized on the same mutex.
+type Locking[K comparable, V any, C Cache[K, V]] struct {
+	sync.Mutex
+	inner C
+}
+
+// NewLocking creates a new Locking cache wrapping inner.
+func NewLocking[K comparable, V any, C Cache[K, V]](inner C) *Locking[K, V, C] {
+	return &Locking[K, V, C]{inner: inner}
+}
+
+// Get implements Cache.
+//
+// The cache's mutex is held for the entire duration of this function,
+// including while the FillFunc is being called. This function is not
+// reentrant; attempting to call Get from a FillFunc will deadlock.
+func (c *Locking[K, V, C]) Get(key K, f FillFunc[V]) (V, error) {
+	c.Lock()
+	defer c.Unlock()
+	return c.inner.Get(key, f)
+}
+
+// Forget implements Cache.
+func (c *Locking[K, V, C]) Forget(key K) {
+	c.Lock()
+	defer c.Unlock()
+	c.inner.Forget(key)
+}
+
+// Empty implements Cache.
+func (c *Locking[K, V, C]) Empty() {
+	c.Lock()
+	defer c.Unlock()
+	c.inner.Empty()
+}
diff --git a/util/cache/none.go b/util/cache/none.go
new file mode 100644
index 000000000..e3e53e0b4
--- /dev/null
+++ b/util/cache/none.go
@@ -0,0 +1,21 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package cache
+
+// None provides no caching and always calls the provided FillFunc.
+//
+// It is safe for concurrent use if the underlying FillFunc is.
+type None[K comparable, V any] struct{}
+
+// Get always calls the provided FillFunc and returns what it does.
+func (c None[K, V]) Get(_ K, f FillFunc[V]) (V, error) {
+	v, _, e := f()
+	return v, e
+}
+
+// Forget implements Cache.
+func (c None[K, V]) Forget(K) {}
+
+// Empty implements Cache.
+func (c None[K, V]) Empty() {}
diff --git a/util/cache/single.go b/util/cache/single.go
new file mode 100644
index 000000000..5b378cc15
--- /dev/null
+++ b/util/cache/single.go
@@ -0,0 +1,79 @@
+// Copyright (c) Tailscale Inc & AUTHORS
+// SPDX-License-Identifier: BSD-3-Clause
+
+package cache
+
+import (
+	"time"
+)
+
+// Single is a simple in-memory cache that stores a single value until a
+// defined time before it is re-fetched. It also supports returning a
+// previously-expired value if refreshing the value in the cache fails.
+//
+// Single is not safe for concurrent use.
+type Single[K comparable, V any] struct {
+	key       K
+	val       V
+	goodUntil time.Time
+	timeNow   func() time.Time // for tests
+
+	// ServeExpired indicates that if an error occurs when filling the
+	// cache, an expired value can be returned instead of an error.
+	//
+	// This value should only be set when this struct is created.
+	ServeExpired bool
+}
+
+// Get will return the cached value, if any, or fill the cache by calling f and
+// return the corresponding value. If f returns an error and c.ServeExpired is
+// true, then a previously-expired value can be returned with no error.
+func (c *Single[K, V]) Get(key K, f FillFunc[V]) (V, error) {
+	var now time.Time
+	if c.timeNow != nil {
+		now = c.timeNow()
+	} else {
+		now = time.Now()
+	}
+
+	if c.key == key && now.Before(c.goodUntil) {
+		return c.val, nil
+	}
+
+	// Re-fill cached entry
+	val, until, err := f()
+	if err == nil {
+		c.key = key
+		c.val = val
+		c.goodUntil = until
+		return val, nil
+	}
+
+	// Never serve an expired entry for the wrong key.
+	if c.key == key && c.ServeExpired && !c.goodUntil.IsZero() {
+		return c.val, nil
+	}
+
+	var zero V
+	return zero, err
+}
+
+// Forget implements Cache.
+func (c *Single[K, V]) Forget(key K) {
+	if c.key != key {
+		return
+	}
+
+	c.Empty()
+}
+
+// Empty implements Cache.
+func (c *Single[K, V]) Empty() {
+	c.goodUntil = time.Time{}
+
+	var zeroKey K
+	c.key = zeroKey
+
+	var zeroVal V
+	c.val = zeroVal
+}
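
To round out the section, a hedged sketch of how the implementations are meant to compose, again assuming a tailscale.com/util/cache import path; loadValue, failingFill, and the 30-second TTL are invented for the example.

package main

import (
	"errors"
	"fmt"
	"time"

	"tailscale.com/util/cache" // assumed import path for this package
)

// loadValue stands in for the real fill operation.
func loadValue() (int, time.Time, error) {
	return 42, time.Now().Add(30 * time.Second), nil
}

func main() {
	// A Single cache that keeps serving the last value if a refresh fails,
	// wrapped with NewLocking so it can be shared between goroutines.
	c := cache.NewLocking[string, int](&cache.Single[string, int]{
		ServeExpired: true,
	})

	v, err := c.Get("answer", loadValue)
	fmt.Println(v, err) // 42 <nil>: filled on first use

	// Within the 30-second window this fill is not called at all; once the
	// entry expires, a failing refresh still returns the stale 42 (with a
	// nil error) because ServeExpired is set.
	failingFill := func() (int, time.Time, error) {
		return 0, time.Time{}, errors.New("upstream unavailable")
	}
	v, err = c.Get("answer", failingFill)
	fmt.Println(v, err) // 42 <nil>

	// None drops in wherever caching should be disabled entirely: every Get
	// calls the FillFunc.
	var nc cache.Cache[string, int] = cache.None[string, int]{}
	v, err = nc.Get("answer", loadValue)
	fmt.Println(v, err)
}

Note that because Locking holds its mutex across the FillFunc, a fill that calls back into the same cache would deadlock, as the Get doc comment in locking.go warns.
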