diff --git a/util/lru/lru.go b/util/lru/lru.go new file mode 100644 index 000000000..639e7c91b --- /dev/null +++ b/util/lru/lru.go @@ -0,0 +1,110 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package lru contains a typed Least-Recently-Used cache. +package lru + +import ( + "container/list" +) + +// Cache is container type keyed by K, storing V, optionally evicting the least +// recently used items if a maximum size is exceeded. +// +// The zero value is valid to use. +// +// It is not safe for concurrent access. +// +// The current implementation is just the traditional LRU linked list; a future +// implementation may be more advanced to avoid pathological cases. +type Cache[K comparable, V any] struct { + // MaxEntries is the maximum number of cache entries before + // an item is evicted. Zero means no limit. + MaxEntries int + + ll *list.List + m map[K]*list.Element // of *entry[K,V] +} + +// entry is the element type for the container/list.Element. +type entry[K comparable, V any] struct { + key K + value V +} + +// Set adds or replaces a value to the cache, set or updating its associated +// value. +// +// If MaxEntries is non-zero and the length of the cache is greater +// after any addition, the least recently used value is evicted. +func (c *Cache[K, V]) Set(key K, value V) { + if c.m == nil { + c.m = make(map[K]*list.Element) + c.ll = list.New() + } + if ee, ok := c.m[key]; ok { + c.ll.MoveToFront(ee) + ee.Value.(*entry[K, V]).value = value + return + } + ele := c.ll.PushFront(&entry[K, V]{key, value}) + c.m[key] = ele + if c.MaxEntries != 0 && c.Len() > c.MaxEntries { + c.DeleteOldest() + } +} + +// Get looks up a key's value from the cache, returning either +// the value or the zero value if it not present. +// +// If found, key is moved to the front of the LRU. +func (c *Cache[K, V]) Get(key K) V { + v, _ := c.GetOk(key) + return v +} + +// Contains reports whether c contains key. +// +// If found, key is moved to the front of the LRU. +func (c *Cache[K, V]) Contains(key K) bool { + _, ok := c.GetOk(key) + return ok +} + +// GetOk looks up a key's value from the cache, also reporting +// whether it was present. +// +// If found, key is moved to the front of the LRU. +func (c *Cache[K, V]) GetOk(key K) (value V, ok bool) { + if ele, hit := c.m[key]; hit { + c.ll.MoveToFront(ele) + return ele.Value.(*entry[K, V]).value, true + } + var zero V + return zero, false +} + +// Delete removes the provided key from the cache if it was present. +func (c *Cache[K, V]) Delete(key K) { + if e, ok := c.m[key]; ok { + c.deleteElement(e) + } +} + +// DeleteOldest removes the item from the cache that was least recently +// accessed. It is a no-op if the cache is empty. +func (c *Cache[K, V]) DeleteOldest() { + if c.ll != nil { + if e := c.ll.Back(); e != nil { + c.deleteElement(e) + } + } +} + +func (c *Cache[K, V]) deleteElement(e *list.Element) { + c.ll.Remove(e) + delete(c.m, e.Value.(*entry[K, V]).key) +} + +// Len returns the number of items in the cache. +func (c *Cache[K, V]) Len() int { return len(c.m) } diff --git a/util/lru/lru_test.go b/util/lru/lru_test.go new file mode 100644 index 000000000..315c86218 --- /dev/null +++ b/util/lru/lru_test.go @@ -0,0 +1,42 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package lru + +import "testing" + +func TestLRU(t *testing.T) { + var c Cache[int, string] + c.Set(1, "one") + c.Set(2, "two") + if g, w := c.Get(1), "one"; g != w { + t.Errorf("got %q; want %q", g, w) + } + if g, w := c.Get(2), "two"; g != w { + t.Errorf("got %q; want %q", g, w) + } + c.DeleteOldest() + if g, w := c.Get(1), ""; g != w { + t.Errorf("got %q; want %q", g, w) + } + if g, w := c.Len(), 1; g != w { + t.Errorf("Len = %d; want %d", g, w) + } + c.MaxEntries = 2 + c.Set(1, "one") + c.Set(2, "two") + c.Set(3, "three") + if c.Contains(1) { + t.Errorf("contains 1; should not") + } + if !c.Contains(2) { + t.Errorf("doesn't contain 2; should") + } + if !c.Contains(3) { + t.Errorf("doesn't contain 3; should") + } + c.Delete(3) + if c.Contains(3) { + t.Errorf("contains 3; should not") + } +}