diff --git a/net/flowtrack/flowtrack.go b/net/flowtrack/flowtrack.go index 9e96a42be..41b3bb9ab 100644 --- a/net/flowtrack/flowtrack.go +++ b/net/flowtrack/flowtrack.go @@ -34,7 +34,7 @@ func (t Tuple) String() string { // The zero value is valid to use. // // It is not safe for concurrent access. -type Cache struct { +type Cache[Value any] struct { // MaxEntries is the maximum number of cache entries before // an item is evicted. Zero means no limit. MaxEntries int @@ -44,9 +44,9 @@ type Cache struct { } // entry is the container/list element type. -type entry struct { +type entry[Value any] struct { key Tuple - value any + value Value } // Add adds a value to the cache, set or updating its associated @@ -54,17 +54,17 @@ type entry struct { // // If MaxEntries is non-zero and the length of the cache is greater // after any addition, the least recently used value is evicted. -func (c *Cache) Add(key Tuple, value any) { +func (c *Cache[Value]) Add(key Tuple, value Value) { if c.m == nil { c.m = make(map[Tuple]*list.Element) c.ll = list.New() } if ee, ok := c.m[key]; ok { c.ll.MoveToFront(ee) - ee.Value.(*entry).value = value + ee.Value.(*entry[Value]).value = value return } - ele := c.ll.PushFront(&entry{key, value}) + ele := c.ll.PushFront(&entry[Value]{key, value}) c.m[key] = ele if c.MaxEntries != 0 && c.Len() > c.MaxEntries { c.RemoveOldest() @@ -73,23 +73,23 @@ func (c *Cache) Add(key Tuple, value any) { // Get looks up a key's value from the cache, also reporting // whether it was present. -func (c *Cache) Get(key Tuple) (value any, ok bool) { +func (c *Cache[Value]) Get(key Tuple) (value *Value, ok bool) { if ele, hit := c.m[key]; hit { c.ll.MoveToFront(ele) - return ele.Value.(*entry).value, true + return &ele.Value.(*entry[Value]).value, true } return nil, false } // Remove removes the provided key from the cache if it was present. -func (c *Cache) Remove(key Tuple) { +func (c *Cache[Value]) Remove(key Tuple) { if ele, hit := c.m[key]; hit { c.removeElement(ele) } } // RemoveOldest removes the oldest item from the cache, if any. -func (c *Cache) RemoveOldest() { +func (c *Cache[Value]) RemoveOldest() { if c.ll != nil { if ele := c.ll.Back(); ele != nil { c.removeElement(ele) @@ -97,10 +97,10 @@ func (c *Cache) RemoveOldest() { } } -func (c *Cache) removeElement(e *list.Element) { +func (c *Cache[Value]) removeElement(e *list.Element) { c.ll.Remove(e) - delete(c.m, e.Value.(*entry).key) + delete(c.m, e.Value.(*entry[Value]).key) } // Len returns the number of items in the cache. -func (c *Cache) Len() int { return len(c.m) } +func (c *Cache[Value]) Len() int { return len(c.m) } diff --git a/net/flowtrack/flowtrack_test.go b/net/flowtrack/flowtrack_test.go index c48ded72b..cb71546f6 100644 --- a/net/flowtrack/flowtrack_test.go +++ b/net/flowtrack/flowtrack_test.go @@ -12,7 +12,7 @@ import ( ) func TestCache(t *testing.T) { - c := &Cache{MaxEntries: 2} + c := &Cache[int]{MaxEntries: 2} k1 := Tuple{Src: netip.MustParseAddrPort("1.1.1.1:1"), Dst: netip.MustParseAddrPort("1.1.1.1:1")} k2 := Tuple{Src: netip.MustParseAddrPort("1.1.1.1:1"), Dst: netip.MustParseAddrPort("2.2.2.2:2")} @@ -25,13 +25,13 @@ func TestCache(t *testing.T) { t.Fatalf("Len = %d; want %d", got, want) } } - wantVal := func(key Tuple, want any) { + wantVal := func(key Tuple, want int) { t.Helper() got, ok := c.Get(key) if !ok { t.Fatalf("Get(%q) failed; want value %v", key, want) } - if got != want { + if *got != want { t.Fatalf("Get(%q) = %v; want %v", key, got, want) } } @@ -73,7 +73,7 @@ func TestCache(t *testing.T) { if !ok { t.Fatal("missing k3") } - if got != 30 { + if *got != 30 { t.Fatalf("got = %d; want 30", got) } }) diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index f99253967..d34af75ef 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -59,7 +59,7 @@ type Filter struct { // filterState is a state cache of past seen packets. type filterState struct { mu sync.Mutex - lru *flowtrack.Cache // from flowtrack.Tuple -> nil + lru *flowtrack.Cache[struct{}] // from flowtrack.Tuple -> struct{} } // lruMax is the size of the LRU cache in filterState. @@ -176,7 +176,7 @@ func New(matches []Match, localNets *netipx.IPSet, logIPs *netipx.IPSet, shareSt state = shareStateWith.state } else { state = &filterState{ - lru: &flowtrack.Cache{MaxEntries: lruMax}, + lru: &flowtrack.Cache[struct{}]{MaxEntries: lruMax}, } } f := &Filter{ @@ -517,7 +517,7 @@ func (f *Filter) runOut(q *packet.Parsed) (r Response, why string) { Src: q.Dst, Dst: q.Src, // src/dst reversed } f.state.mu.Lock() - f.state.lru.Add(tuple, nil) + f.state.lru.Add(tuple, struct{}{}) f.state.mu.Unlock() } return Accept, "ok out"