diff --git a/net/flowtrack/flowtrack.go b/net/flowtrack/flowtrack.go new file mode 100644 index 000000000..8d490d854 --- /dev/null +++ b/net/flowtrack/flowtrack.go @@ -0,0 +1,99 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// Original implementation (from same author) from which this was derived was: +// https://github.com/golang/groupcache/blob/5b532d6fd5efaf7fa130d4e859a2fde0fc3a9e1b/lru/lru.go +// ... which was Apache licensed: +// https://github.com/golang/groupcache/blob/master/LICENSE + +// Package flowtrack contains types for tracking TCP/UDP flows by 4-tuples. +package flowtrack + +import ( + "container/list" + + "inet.af/netaddr" +) + +// Tuple is a 4-tuple of source and destination IP and port. +type Tuple struct { + Src netaddr.IPPort + Dst netaddr.IPPort +} + +// Cache is an LRU cache keyed by Tuple. +// +// The zero value is valid to use. +// +// It is not safe for concurrent access. +type Cache struct { + // MaxEntries is the maximum number of cache entries before + // an item is evicted. Zero means no limit. + MaxEntries int + + ll *list.List + m map[Tuple]*list.Element // of *entry +} + +// entry is the container/list element type. +type entry struct { + key Tuple + value interface{} +} + +// Add adds a value to the cache, set or updating its assoicated +// value. +// +// If MaxEntries is non-zero and the length of the cache is greater +// after any addition, the least recently used value is evicted. +func (c *Cache) Add(key Tuple, value interface{}) { + if c.m == nil { + c.m = make(map[Tuple]*list.Element) + c.ll = list.New() + } + if ee, ok := c.m[key]; ok { + c.ll.MoveToFront(ee) + ee.Value.(*entry).value = value + return + } + ele := c.ll.PushFront(&entry{key, value}) + c.m[key] = ele + if c.MaxEntries != 0 && c.Len() > c.MaxEntries { + c.RemoveOldest() + } +} + +// Get looks up a key's value from the cache, also reporting +// whether it was present. +func (c *Cache) Get(key Tuple) (value interface{}, ok bool) { + if ele, hit := c.m[key]; hit { + c.ll.MoveToFront(ele) + return ele.Value.(*entry).value, true + } + return nil, false +} + +// Remove removes the provided key from the cache if it was present. +func (c *Cache) Remove(key Tuple) { + if ele, hit := c.m[key]; hit { + c.removeElement(ele) + } +} + +// RemoveOldest removes the oldest item from the cache, if any. +func (c *Cache) RemoveOldest() { + if c.ll != nil { + if ele := c.ll.Back(); ele != nil { + c.removeElement(ele) + } + } +} + +func (c *Cache) removeElement(e *list.Element) { + c.ll.Remove(e) + delete(c.m, e.Value.(*entry).key) +} + +// Len returns the number of items in the cache. +func (c *Cache) Len() int { return len(c.m) } diff --git a/net/flowtrack/flowtrack_test.go b/net/flowtrack/flowtrack_test.go new file mode 100644 index 000000000..4c473c717 --- /dev/null +++ b/net/flowtrack/flowtrack_test.go @@ -0,0 +1,82 @@ +// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flowtrack + +import ( + "testing" + + "inet.af/netaddr" +) + +func TestCache(t *testing.T) { + c := &Cache{MaxEntries: 2} + + k1 := Tuple{Src: netaddr.MustParseIPPort("1.1.1.1:1"), Dst: netaddr.MustParseIPPort("1.1.1.1:1")} + k2 := Tuple{Src: netaddr.MustParseIPPort("1.1.1.1:1"), Dst: netaddr.MustParseIPPort("2.2.2.2:2")} + k3 := Tuple{Src: netaddr.MustParseIPPort("1.1.1.1:1"), Dst: netaddr.MustParseIPPort("3.3.3.3:3")} + k4 := Tuple{Src: netaddr.MustParseIPPort("1.1.1.1:1"), Dst: netaddr.MustParseIPPort("4.4.4.4:4")} + + wantLen := func(want int) { + t.Helper() + if got := c.Len(); got != want { + t.Fatalf("Len = %d; want %d", got, want) + } + } + wantVal := func(key Tuple, want interface{}) { + t.Helper() + got, ok := c.Get(key) + if !ok { + t.Fatalf("Get(%q) failed; want value %v", key, want) + } + if got != want { + t.Fatalf("Get(%q) = %v; want %v", key, got, want) + } + } + wantMissing := func(key Tuple) { + t.Helper() + if got, ok := c.Get(key); ok { + t.Fatalf("Get(%q) = %v; want absent from cache", key, got) + } + } + + wantLen(0) + c.RemoveOldest() // shouldn't panic + c.Remove(k4) // shouldn't panic + + c.Add(k1, 1) + wantLen(1) + c.Add(k2, 2) + wantLen(2) + c.Add(k3, 3) + wantLen(2) // hit the max + + wantMissing(k1) + c.Remove(k1) + wantLen(2) // no change; k1 should've been the deleted one per LRU + + wantVal(k3, 3) + + wantVal(k2, 2) + c.Remove(k2) + wantLen(1) + wantMissing(k2) + + c.Add(k3, 30) + wantVal(k3, 30) + wantLen(1) + + allocs := int(testing.AllocsPerRun(1000, func() { + got, ok := c.Get(k3) + if !ok { + t.Fatal("missing k3") + } + if got != 30 { + t.Fatalf("got = %d; want 30", got) + } + })) + if allocs != 0 { + t.Errorf("allocs = %v; want 0", allocs) + } +} diff --git a/wgengine/filter/filter.go b/wgengine/filter/filter.go index f35578e15..a0bdbf3af 100644 --- a/wgengine/filter/filter.go +++ b/wgengine/filter/filter.go @@ -10,9 +10,9 @@ import ( "sync" "time" - "github.com/golang/groupcache/lru" "golang.org/x/time/rate" "inet.af/netaddr" + "tailscale.com/net/flowtrack" "tailscale.com/net/packet" "tailscale.com/types/logger" ) @@ -41,17 +41,10 @@ type Filter struct { state *filterState } -// tuple is a 4-tuple of source and destination IP and port. It's used -// as a lookup key in filterState. -type tuple struct { - Src netaddr.IPPort - Dst netaddr.IPPort -} - // filterState is a state cache of past seen packets. type filterState struct { mu sync.Mutex - lru *lru.Cache // of tuple + lru *flowtrack.Cache // from flowtrack.Tuple -> nil } // lruMax is the size of the LRU cache in filterState. @@ -141,7 +134,7 @@ func New(matches []Match, localNets []netaddr.IPPrefix, shareStateWith *Filter, state = shareStateWith.state } else { state = &filterState{ - lru: lru.New(lruMax), + lru: &flowtrack.Cache{MaxEntries: lruMax}, } } f := &Filter{ @@ -334,7 +327,7 @@ func (f *Filter) runIn4(q *packet.Parsed) (r Response, why string) { return Accept, "tcp ok" } case packet.UDP: - t := tuple{q.Src, q.Dst} + t := flowtrack.Tuple{Src: q.Src, Dst: q.Dst} f.state.mu.Lock() _, ok := f.state.lru.Get(t) @@ -389,7 +382,7 @@ func (f *Filter) runIn6(q *packet.Parsed) (r Response, why string) { return Accept, "tcp ok" } case packet.UDP: - t := tuple{q.Src, q.Dst} + t := flowtrack.Tuple{Src: q.Src, Dst: q.Dst} f.state.mu.Lock() _, ok := f.state.lru.Get(t) @@ -413,10 +406,10 @@ func (f *Filter) runOut(q *packet.Parsed) (r Response, why string) { return Accept, "ok out" } - t := tuple{q.Dst, q.Src} - var ti interface{} = t // allocate once, rather than twice inside mutex + tuple := flowtrack.Tuple{Src: q.Dst, Dst: q.Src} // src/dst reversed + f.state.mu.Lock() - f.state.lru.Add(ti, ti) + f.state.lru.Add(tuple, nil) f.state.mu.Unlock() return Accept, "ok out" }