mirror of https://github.com/tailscale/tailscale/
wgengine/magicsock: only cache N most recent endpoints per-Addr
If a node is flapping or otherwise generating lots of STUN endpoints, we can end up caching a ton of useless values and sending them to peers. Instead, let's apply a fixed per-Addr limit of endpoints that we cache, so that we're only sending peers up to the N most recent. Updates tailscale/corp#13890 Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: I8079a05b44220c46da55016c0e5fc96dd2135ef8pull/8903/head
parent 9c4364e0b7
commit 95d776bd8c
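
For orientation, here is a minimal, hypothetical sketch of how the tracker added by this change is meant to be used. The advertiseEndpoints wrapper below is not part of the commit; endpointTracker, update, and endpointTrackerMaxPerAddr are defined in the files that follow.

package magicsock

import (
	"time"

	"tailscale.com/tailcfg"
)

// advertiseEndpoints is a hypothetical caller, shown only to illustrate the
// intended call pattern of the tracker below.
func advertiseEndpoints(et *endpointTracker, discovered []tailcfg.Endpoint) []tailcfg.Endpoint {
	// update returns the freshly discovered endpoints plus any cached,
	// not-yet-expired endpoints. The cache itself keeps at most
	// endpointTrackerMaxPerAddr entries per netip.Addr, dropping the
	// oldest first -- the per-Addr cap this commit adds.
	return et.update(time.Now(), discovered)
}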
@ -0,0 +1,248 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package magicsock

import (
	"net/netip"
	"slices"
	"sync"
	"time"

	"tailscale.com/tailcfg"
	"tailscale.com/tempfork/heap"
	"tailscale.com/util/mak"
	"tailscale.com/util/set"
)

const (
	// endpointTrackerLifetime is how long we continue advertising an
	// endpoint after we last see it. This is intentionally chosen to be
	// slightly longer than a full netcheck period.
	endpointTrackerLifetime = 5*time.Minute + 10*time.Second

	// endpointTrackerMaxPerAddr is how many cached addresses we track for
	// a given netip.Addr. This allows e.g. restricting the number of STUN
	// endpoints we cache (which usually have the same netip.Addr but
	// different ports).
	//
	// The value of 6 is chosen because we can advertise up to 3 endpoints
	// based on the STUN IP:
	//   1. The STUN endpoint itself (EndpointSTUN)
	//   2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort)
	//   3. The STUN IP with a portmapped port (EndpointPortmapped)
	//
	// Storing 6 endpoints in the cache means we can store up to 2 previous
	// sets of endpoints.
	endpointTrackerMaxPerAddr = 6
)

// endpointTrackerEntry is an entry in an endpointHeap that stores the state of
// a given cached endpoint.
type endpointTrackerEntry struct {
	// endpoint is the cached endpoint.
	endpoint tailcfg.Endpoint
	// until is the time until which this endpoint is being cached.
	until time.Time
	// index is the index within the containing endpointHeap.
	index int
}

// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in
// ascending order by the 'until' expiry time (i.e. oldest first).
type endpointHeap []*endpointTrackerEntry

var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil)

// Len implements heap.Interface.
func (eh endpointHeap) Len() int { return len(eh) }

// Less implements heap.Interface.
func (eh endpointHeap) Less(i, j int) bool {
	// We want to store items so that the lowest item in the heap is the
	// oldest, so that heap.Pop()-ing from the endpointHeap will remove the
	// oldest entry.
	return eh[i].until.Before(eh[j].until)
}

// Swap implements heap.Interface.
func (eh endpointHeap) Swap(i, j int) {
	eh[i], eh[j] = eh[j], eh[i]
	eh[i].index = i
	eh[j].index = j
}

// Push implements heap.Interface.
func (eh *endpointHeap) Push(item *endpointTrackerEntry) {
	n := len(*eh)
	item.index = n
	*eh = append(*eh, item)
}

// Pop implements heap.Interface.
func (eh *endpointHeap) Pop() *endpointTrackerEntry {
	old := *eh
	n := len(old)
	item := old[n-1]
	old[n-1] = nil  // avoid memory leak
	item.index = -1 // for safety
	*eh = old[0 : n-1]
	return item
}

// Min returns a pointer to the minimum element in the heap, without removing
// it. Since this is a min-heap ordered by the 'until' field, this returns the
// chronologically "earliest" element in the heap.
//
// Len() must be non-zero.
func (eh endpointHeap) Min() *endpointTrackerEntry {
	return eh[0]
}

// endpointTracker caches endpoints that are advertised to peers. This allows
// peers to still reach this node if there's a temporary endpoint flap; rather
// than withdrawing an endpoint and then re-advertising it the next time we run
// a netcheck, we keep advertising the endpoint until it's not present for a
// defined timeout.
//
// See tailscale/tailscale#7877 for more information.
type endpointTracker struct {
	mu        sync.Mutex
	endpoints map[netip.Addr]*endpointHeap
}

// update takes as input the current set of discovered endpoints and the
// current time, and returns the set of endpoints plus any previously-cached
// and non-expired endpoints that should be advertised to peers.
func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) {
	var inputEps set.Slice[netip.AddrPort]
	for _, ep := range eps {
		inputEps.Add(ep.Addr)
	}

	et.mu.Lock()
	defer et.mu.Unlock()

	// Extend endpoints that already exist in the cache. We do this before
	// we remove expired endpoints, below, so we don't remove something
	// that would otherwise have survived by extending.
	until := now.Add(endpointTrackerLifetime)
	for _, ep := range eps {
		et.extendLocked(ep, until)
	}

	// Now that we've extended existing endpoints, remove everything that
	// has expired.
	et.removeExpiredLocked(now)

	// Add entries from the input set of endpoints into the cache; we do
	// this after removing expired ones so that we can store as many as
	// possible, with space freed by the entries removed after expiry.
	for _, ep := range eps {
		et.addLocked(now, ep, until)
	}

	// Finally, add entries to the return array that aren't already there.
	epsPlusCached = eps
	for _, heap := range et.endpoints {
		for _, ep := range *heap {
			// If the endpoint was in the input list, or has expired, skip it.
			if inputEps.Contains(ep.endpoint.Addr) {
				continue
			} else if now.After(ep.until) {
				// Defense-in-depth; should never happen since
				// we removed expired entries above, but ignore
				// it anyway.
				continue
			}

			// We haven't seen this endpoint; add it to the return array.
			epsPlusCached = append(epsPlusCached, ep.endpoint)
		}
	}

	return epsPlusCached
}

// extendLocked will update the expiry time of the provided endpoint in the
// cache, if it is present. If it is not present, nothing will be done.
//
// et.mu must be held.
func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) {
	key := ep.Addr.Addr()
	epHeap, found := et.endpoints[key]
	if !found {
		return
	}

	// Find the entry for this exact address; this loop is quick since we
	// bound the number of items in the heap.
	//
	// TODO(andrew): this means we iterate over the entire heap once per
	// endpoint; even if the heap is small, if we have a lot of input
	// endpoints this can be expensive?
	for i, entry := range *epHeap {
		if entry.endpoint == ep {
			entry.until = until
			heap.Fix(epHeap, i)
			return
		}
	}
}

// addLocked will store the provided endpoint(s) in the cache for a fixed
// period of time, ensuring that the size of the endpoint cache remains below
// the maximum.
//
// et.mu must be held.
func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) {
	key := ep.Addr.Addr()

	// Create or get the heap for this endpoint's addr.
	epHeap := et.endpoints[key]
	if epHeap == nil {
		epHeap = new(endpointHeap)
		mak.Set(&et.endpoints, key, epHeap)
	}

	// Find the entry for this exact address; this loop is quick
	// since we bound the number of items in the heap.
	found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool {
		return v.endpoint == ep
	})
	if !found {
		// Add address to heap; either the endpoint is new, or the heap
		// was newly-created and thus empty.
		heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until})
	}

	// Now that we've added everything, pop from our heap until we're below
	// the limit. This is a min-heap, so popping removes the lowest (and
	// thus oldest) endpoint.
	for epHeap.Len() > endpointTrackerMaxPerAddr {
		heap.Pop(epHeap)
	}
}

// removeExpiredLocked removes all expired entries from the cache.
//
// et.mu must be held.
func (et *endpointTracker) removeExpiredLocked(now time.Time) {
	for k, epHeap := range et.endpoints {
		// The minimum element is the oldest/earliest endpoint; repeatedly
		// pop from the heap while it's in the past.
		for epHeap.Len() > 0 {
			minElem := epHeap.Min()
			if now.After(minElem.until) {
				heap.Pop(epHeap)
			} else {
				break
			}
		}

		if epHeap.Len() == 0 {
			// Free up space in the map by removing the empty heap.
			delete(et.endpoints, k)
		}
	}
}
@ -0,0 +1,187 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause

package magicsock

import (
	"net/netip"
	"reflect"
	"slices"
	"strings"
	"testing"
	"time"

	"tailscale.com/tailcfg"
)

func TestEndpointTracker(t *testing.T) {
	local := tailcfg.Endpoint{
		Addr: netip.MustParseAddrPort("192.168.1.1:12345"),
		Type: tailcfg.EndpointLocal,
	}

	stun4_1 := tailcfg.Endpoint{
		Addr: netip.MustParseAddrPort("1.2.3.4:12345"),
		Type: tailcfg.EndpointSTUN,
	}
	stun4_2 := tailcfg.Endpoint{
		Addr: netip.MustParseAddrPort("5.6.7.8:12345"),
		Type: tailcfg.EndpointSTUN,
	}

	stun6_1 := tailcfg.Endpoint{
		Addr: netip.MustParseAddrPort("[2a09:8280:1::1111]:12345"),
		Type: tailcfg.EndpointSTUN,
	}
	stun6_2 := tailcfg.Endpoint{
		Addr: netip.MustParseAddrPort("[2a09:8280:1::2222]:12345"),
		Type: tailcfg.EndpointSTUN,
	}

	start := time.Unix(1681503440, 0)

	steps := []struct {
		name string
		now  time.Time
		eps  []tailcfg.Endpoint
		want []tailcfg.Endpoint
	}{
		{
			name: "initial endpoints",
			now:  start,
			eps:  []tailcfg.Endpoint{local, stun4_1, stun6_1},
			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
		},
		{
			name: "no change",
			now:  start.Add(1 * time.Minute),
			eps:  []tailcfg.Endpoint{local, stun4_1, stun6_1},
			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
		},
		{
			name: "missing stun4",
			now:  start.Add(2 * time.Minute),
			eps:  []tailcfg.Endpoint{local, stun6_1},
			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
		},
		{
			name: "missing stun6",
			now:  start.Add(3 * time.Minute),
			eps:  []tailcfg.Endpoint{local, stun4_1},
			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
		},
		{
			name: "multiple STUN addresses within timeout",
			now:  start.Add(4 * time.Minute),
			eps:  []tailcfg.Endpoint{local, stun4_2, stun6_2},
			want: []tailcfg.Endpoint{local, stun4_1, stun4_2, stun6_1, stun6_2},
		},
		{
			name: "endpoint extended",
			now:  start.Add(3*time.Minute + endpointTrackerLifetime - 1),
			eps:  []tailcfg.Endpoint{local},
			want: []tailcfg.Endpoint{
				local, stun4_2, stun6_2,
				// stun4_1 had its lifetime extended by the
				// "missing stun6" test above to that start
				// time plus the lifetime, while stun6 should
				// have expired a minute sooner. It should thus
				// be in this returned list.
				stun4_1,
			},
		},
		{
			name: "after timeout",
			now:  start.Add(4*time.Minute + endpointTrackerLifetime + 1),
			eps:  []tailcfg.Endpoint{local, stun4_2, stun6_2},
			want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
		},
		{
			name: "after timeout still caches",
			now:  start.Add(4*time.Minute + endpointTrackerLifetime + time.Minute),
			eps:  []tailcfg.Endpoint{local},
			want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
		},
	}

	var et endpointTracker
	for _, tt := range steps {
		t.Logf("STEP: %s", tt.name)

		got := et.update(tt.now, tt.eps)

		// Sort both slices for comparison.
		slices.SortFunc(got, func(a, b tailcfg.Endpoint) int {
			return strings.Compare(a.Addr.String(), b.Addr.String())
		})
		slices.SortFunc(tt.want, func(a, b tailcfg.Endpoint) int {
			return strings.Compare(a.Addr.String(), b.Addr.String())
		})

		if !reflect.DeepEqual(got, tt.want) {
			t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, tt.want)
		}
	}
}

func TestEndpointTrackerMaxNum(t *testing.T) {
	start := time.Unix(1681503440, 0)

	var allEndpoints []tailcfg.Endpoint // all created endpoints
	mkEp := func(i int) tailcfg.Endpoint {
		ep := tailcfg.Endpoint{
			Addr: netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), uint16(i)),
			Type: tailcfg.EndpointSTUN,
		}
		allEndpoints = append(allEndpoints, ep)
		return ep
	}

	var et endpointTracker

	// Add more endpoints to the list than our limit.
	for i := 0; i <= endpointTrackerMaxPerAddr; i++ {
		et.update(start.Add(time.Duration(i)*time.Second), []tailcfg.Endpoint{mkEp(10000 + i)})
	}

	// Now add two more, slightly later.
	got := et.update(start.Add(1*time.Minute), []tailcfg.Endpoint{
		mkEp(10100),
		mkEp(10101),
	})

	// We expect to get the last N endpoints per our per-Addr limit, since
	// all of the endpoints have the same netip.Addr. The first endpoint(s)
	// that we added were dropped because we had more than the limit for
	// this Addr.
	want := allEndpoints[len(allEndpoints)-endpointTrackerMaxPerAddr:]

	compareEndpoints := func(got, want []tailcfg.Endpoint) {
		t.Helper()
		slices.SortFunc(want, func(a, b tailcfg.Endpoint) int {
			return strings.Compare(a.Addr.String(), b.Addr.String())
		})
		slices.SortFunc(got, func(a, b tailcfg.Endpoint) int {
			return strings.Compare(a.Addr.String(), b.Addr.String())
		})
		if !reflect.DeepEqual(got, want) {
			t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, want)
		}
	}
	compareEndpoints(got, want)

	// However, if we have more than our limit of endpoints passed in to
	// the endpointTracker, we will return all of them (even if they're for
	// the same address).
	var inputEps []tailcfg.Endpoint
	for i := 0; i < endpointTrackerMaxPerAddr+5; i++ {
		inputEps = append(inputEps, tailcfg.Endpoint{
			Addr: netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 10200+uint16(i)),
			Type: tailcfg.EndpointSTUN,
		})
	}

	want = inputEps
	got = et.update(start.Add(2*time.Minute), inputEps)
	compareEndpoints(got, want)
}