diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index a20495608..0a685d61f 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -292,6 +292,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/tailcfg from tailscale.com/client/tailscale/apitype+ 💣 tailscale.com/tempfork/device from tailscale.com/net/tstun/table LD tailscale.com/tempfork/gliderlabs/ssh from tailscale.com/ssh/tailssh + tailscale.com/tempfork/heap from tailscale.com/wgengine/magicsock tailscale.com/tka from tailscale.com/ipn/ipnlocal+ W tailscale.com/tsconst from tailscale.com/net/interfaces tailscale.com/tsd from tailscale.com/cmd/tailscaled+ @@ -411,6 +412,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de golang.org/x/time/rate from gvisor.dev/gvisor/pkg/tcpip/stack+ bufio from compress/flate+ bytes from bufio+ + cmp from slices compress/flate from compress/gzip+ compress/gzip from golang.org/x/net/http2+ W compress/zlib from debug/pe @@ -495,6 +497,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de runtime/debug from github.com/klauspost/compress/zstd+ runtime/pprof from tailscale.com/log/logheap+ runtime/trace from net/http/pprof + slices from tailscale.com/wgengine/magicsock sort from compress/flate+ strconv from compress/flate+ strings from bufio+ diff --git a/wgengine/magicsock/endpoint_tracker.go b/wgengine/magicsock/endpoint_tracker.go new file mode 100644 index 000000000..5caddd1a0 --- /dev/null +++ b/wgengine/magicsock/endpoint_tracker.go @@ -0,0 +1,248 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/netip" + "slices" + "sync" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/tempfork/heap" + "tailscale.com/util/mak" + "tailscale.com/util/set" +) + +const ( + // endpointTrackerLifetime is how long we continue advertising an + // endpoint after we last see it. This is intentionally chosen to be + // slightly longer than a full netcheck period. + endpointTrackerLifetime = 5*time.Minute + 10*time.Second + + // endpointTrackerMaxPerAddr is how many cached addresses we track for + // a given netip.Addr. This allows e.g. restricting the number of STUN + // endpoints we cache (which usually have the same netip.Addr but + // different ports). + // + // The value of 6 is chosen because we can advertise up to 3 endpoints + // based on the STUN IP: + // 1. The STUN endpoint itself (EndpointSTUN) + // 2. The STUN IP with the local Tailscale port (EndpointSTUN4LocalPort) + // 3. The STUN IP with a portmapped port (EndpointPortmapped) + // + // Storing 6 endpoints in the cache means we can store up to 2 previous + // sets of endpoints. + endpointTrackerMaxPerAddr = 6 +) + +// endpointTrackerEntry is an entry in an endpointHeap that stores the state of +// a given cached endpoint. +type endpointTrackerEntry struct { + // endpoint is the cached endpoint. + endpoint tailcfg.Endpoint + // until is the time until which this endpoint is being cached. + until time.Time + // index is the index within the containing endpointHeap. + index int +} + +// endpointHeap is an ordered heap of endpointTrackerEntry structs, ordered in +// ascending order by the 'until' expiry time (i.e. oldest first). +type endpointHeap []*endpointTrackerEntry + +var _ heap.Interface[*endpointTrackerEntry] = (*endpointHeap)(nil) + +// Len implements heap.Interface. 
+func (eh endpointHeap) Len() int { return len(eh) }
+
+// Less implements heap.Interface.
+func (eh endpointHeap) Less(i, j int) bool {
+	// We want to store items so that the lowest item in the heap is the
+	// oldest, so that heap.Pop()-ing from the endpointHeap will remove the
+	// oldest entry.
+	return eh[i].until.Before(eh[j].until)
+}
+
+// Swap implements heap.Interface.
+func (eh endpointHeap) Swap(i, j int) {
+	eh[i], eh[j] = eh[j], eh[i]
+	eh[i].index = i
+	eh[j].index = j
+}
+
+// Push implements heap.Interface.
+func (eh *endpointHeap) Push(item *endpointTrackerEntry) {
+	n := len(*eh)
+	item.index = n
+	*eh = append(*eh, item)
+}
+
+// Pop implements heap.Interface.
+func (eh *endpointHeap) Pop() *endpointTrackerEntry {
+	old := *eh
+	n := len(old)
+	item := old[n-1]
+	old[n-1] = nil  // avoid memory leak
+	item.index = -1 // for safety
+	*eh = old[0 : n-1]
+	return item
+}
+
+// Min returns a pointer to the minimum element in the heap, without removing
+// it. Since this is a min-heap ordered by the 'until' field, this returns the
+// chronologically "earliest" element in the heap.
+//
+// Len() must be non-zero.
+func (eh endpointHeap) Min() *endpointTrackerEntry {
+	return eh[0]
+}
+
+// endpointTracker caches endpoints that are advertised to peers. This allows
+// peers to still reach this node if there's a temporary endpoint flap; rather
+// than withdrawing an endpoint and then re-advertising it the next time we run
+// a netcheck, we keep advertising the endpoint until it's not present for a
+// defined timeout.
+//
+// See tailscale/tailscale#7877 for more information.
+type endpointTracker struct {
+	mu        sync.Mutex
+	endpoints map[netip.Addr]*endpointHeap
+}
+
+// update takes as input the current set of discovered endpoints and the
+// current time, and returns the set of endpoints plus any previously cached
+// and non-expired endpoints that should be advertised to peers.
+func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) {
+	var inputEps set.Slice[netip.AddrPort]
+	for _, ep := range eps {
+		inputEps.Add(ep.Addr)
+	}
+
+	et.mu.Lock()
+	defer et.mu.Unlock()
+
+	// Extend endpoints that already exist in the cache. We do this before
+	// we remove expired endpoints, below, so we don't remove something
+	// that would otherwise have survived by extending.
+	until := now.Add(endpointTrackerLifetime)
+	for _, ep := range eps {
+		et.extendLocked(ep, until)
+	}
+
+	// Now that we've extended existing endpoints, remove everything that
+	// has expired.
+	et.removeExpiredLocked(now)
+
+	// Add entries from the input set of endpoints into the cache; we do
+	// this after removing expired ones so that we can store as many as
+	// possible, with space freed by the entries removed after expiry.
+	for _, ep := range eps {
+		et.addLocked(now, ep, until)
+	}
+
+	// Finally, add any cached entries to the return slice that aren't
+	// already there.
+	epsPlusCached = eps
+	for _, epHeap := range et.endpoints {
+		for _, ep := range *epHeap {
+			// If the endpoint was in the input list, or has expired, skip it.
+			if inputEps.Contains(ep.endpoint.Addr) {
+				continue
+			} else if now.After(ep.until) {
+				// Defense-in-depth; should never happen since
+				// we removed expired entries above, but ignore
+				// it anyway.
+				continue
+			}
+
+			// We haven't seen this endpoint; add it to the return slice.
+			epsPlusCached = append(epsPlusCached, ep.endpoint)
+		}
+	}
+
+	return epsPlusCached
+}
+
+// extendLocked will update the expiry time of the provided endpoint in the
+// cache, if it is present. If it is not present, nothing will be done.
+//
+// et.mu must be held.
+func (et *endpointTracker) extendLocked(ep tailcfg.Endpoint, until time.Time) {
+	key := ep.Addr.Addr()
+	epHeap, found := et.endpoints[key]
+	if !found {
+		return
+	}
+
+	// Find the entry for this exact endpoint; this loop is quick since we
+	// bound the number of items in the heap.
+	//
+	// TODO(andrew): this means we iterate over the entire heap once per
+	// endpoint; even if the heap is small, if we have a lot of input
+	// endpoints this can be expensive?
+	for i, entry := range *epHeap {
+		if entry.endpoint == ep {
+			entry.until = until
+			heap.Fix(epHeap, i)
+			return
+		}
+	}
+}
+
+// addLocked will store the provided endpoint in the cache for a fixed period
+// of time, ensuring that the size of the endpoint cache remains below the
+// maximum.
+//
+// et.mu must be held.
+func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) {
+	key := ep.Addr.Addr()
+
+	// Create or get the heap for this endpoint's addr.
+	epHeap := et.endpoints[key]
+	if epHeap == nil {
+		epHeap = new(endpointHeap)
+		mak.Set(&et.endpoints, key, epHeap)
+	}
+
+	// Find the entry for this exact endpoint; this loop is quick
+	// since we bound the number of items in the heap.
+	found := slices.ContainsFunc(*epHeap, func(v *endpointTrackerEntry) bool {
+		return v.endpoint == ep
+	})
+	if !found {
+		// Add the endpoint to the heap; either the endpoint is new, or
+		// the heap was newly-created and thus empty.
+		heap.Push(epHeap, &endpointTrackerEntry{endpoint: ep, until: until})
+	}
+
+	// Now that we've added everything, pop from our heap until we're below
+	// the limit. This is a min-heap, so popping removes the lowest (and
+	// thus oldest) endpoint.
+	for epHeap.Len() > endpointTrackerMaxPerAddr {
+		heap.Pop(epHeap)
+	}
+}
+
+// removeExpiredLocked will remove all expired entries from the cache.
+//
+// et.mu must be held.
+func (et *endpointTracker) removeExpiredLocked(now time.Time) {
+	for k, epHeap := range et.endpoints {
+		// The minimum element is the oldest/earliest endpoint; repeatedly
+		// pop from the heap while it's in the past.
+		for epHeap.Len() > 0 {
+			minElem := epHeap.Min()
+			if now.After(minElem.until) {
+				heap.Pop(epHeap)
+			} else {
+				break
+			}
+		}
+
+		if epHeap.Len() == 0 {
+			// Free up space in the map by removing the empty heap.
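+			// (A later addLocked for this Addr will re-create
+			// the heap on demand.)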
+ delete(et.endpoints, k) + } + } +} diff --git a/wgengine/magicsock/endpoint_tracker_test.go b/wgengine/magicsock/endpoint_tracker_test.go new file mode 100644 index 000000000..b6a2699c1 --- /dev/null +++ b/wgengine/magicsock/endpoint_tracker_test.go @@ -0,0 +1,187 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/netip" + "reflect" + "slices" + "strings" + "testing" + "time" + + "tailscale.com/tailcfg" +) + +func TestEndpointTracker(t *testing.T) { + local := tailcfg.Endpoint{ + Addr: netip.MustParseAddrPort("192.168.1.1:12345"), + Type: tailcfg.EndpointLocal, + } + + stun4_1 := tailcfg.Endpoint{ + Addr: netip.MustParseAddrPort("1.2.3.4:12345"), + Type: tailcfg.EndpointSTUN, + } + stun4_2 := tailcfg.Endpoint{ + Addr: netip.MustParseAddrPort("5.6.7.8:12345"), + Type: tailcfg.EndpointSTUN, + } + + stun6_1 := tailcfg.Endpoint{ + Addr: netip.MustParseAddrPort("[2a09:8280:1::1111]:12345"), + Type: tailcfg.EndpointSTUN, + } + stun6_2 := tailcfg.Endpoint{ + Addr: netip.MustParseAddrPort("[2a09:8280:1::2222]:12345"), + Type: tailcfg.EndpointSTUN, + } + + start := time.Unix(1681503440, 0) + + steps := []struct { + name string + now time.Time + eps []tailcfg.Endpoint + want []tailcfg.Endpoint + }{ + { + name: "initial endpoints", + now: start, + eps: []tailcfg.Endpoint{local, stun4_1, stun6_1}, + want: []tailcfg.Endpoint{local, stun4_1, stun6_1}, + }, + { + name: "no change", + now: start.Add(1 * time.Minute), + eps: []tailcfg.Endpoint{local, stun4_1, stun6_1}, + want: []tailcfg.Endpoint{local, stun4_1, stun6_1}, + }, + { + name: "missing stun4", + now: start.Add(2 * time.Minute), + eps: []tailcfg.Endpoint{local, stun6_1}, + want: []tailcfg.Endpoint{local, stun4_1, stun6_1}, + }, + { + name: "missing stun6", + now: start.Add(3 * time.Minute), + eps: []tailcfg.Endpoint{local, stun4_1}, + want: []tailcfg.Endpoint{local, stun4_1, stun6_1}, + }, + { + name: "multiple STUN addresses within timeout", + now: start.Add(4 * time.Minute), + eps: []tailcfg.Endpoint{local, stun4_2, stun6_2}, + want: []tailcfg.Endpoint{local, stun4_1, stun4_2, stun6_1, stun6_2}, + }, + { + name: "endpoint extended", + now: start.Add(3*time.Minute + endpointTrackerLifetime - 1), + eps: []tailcfg.Endpoint{local}, + want: []tailcfg.Endpoint{ + local, stun4_2, stun6_2, + // stun4_1 had its lifetime extended by the + // "missing stun6" test above to that start + // time plus the lifetime, while stun6 should + // have expired a minute sooner. It should thus + // be in this returned list. 
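+				// (stun6_1, by contrast, has expired and is
+				// deliberately absent.)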
+ stun4_1, + }, + }, + { + name: "after timeout", + now: start.Add(4*time.Minute + endpointTrackerLifetime + 1), + eps: []tailcfg.Endpoint{local, stun4_2, stun6_2}, + want: []tailcfg.Endpoint{local, stun4_2, stun6_2}, + }, + { + name: "after timeout still caches", + now: start.Add(4*time.Minute + endpointTrackerLifetime + time.Minute), + eps: []tailcfg.Endpoint{local}, + want: []tailcfg.Endpoint{local, stun4_2, stun6_2}, + }, + } + + var et endpointTracker + for _, tt := range steps { + t.Logf("STEP: %s", tt.name) + + got := et.update(tt.now, tt.eps) + + // Sort both arrays for comparison + slices.SortFunc(got, func(a, b tailcfg.Endpoint) int { + return strings.Compare(a.Addr.String(), b.Addr.String()) + }) + slices.SortFunc(tt.want, func(a, b tailcfg.Endpoint) int { + return strings.Compare(a.Addr.String(), b.Addr.String()) + }) + + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, tt.want) + } + } +} + +func TestEndpointTrackerMaxNum(t *testing.T) { + start := time.Unix(1681503440, 0) + + var allEndpoints []tailcfg.Endpoint // all created endpoints + mkEp := func(i int) tailcfg.Endpoint { + ep := tailcfg.Endpoint{ + Addr: netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), uint16(i)), + Type: tailcfg.EndpointSTUN, + } + allEndpoints = append(allEndpoints, ep) + return ep + } + + var et endpointTracker + + // Add more endpoints to the list than our limit + for i := 0; i <= endpointTrackerMaxPerAddr; i++ { + et.update(start.Add(time.Duration(i)*time.Second), []tailcfg.Endpoint{mkEp(10000 + i)}) + } + + // Now add two more, slightly later + got := et.update(start.Add(1*time.Minute), []tailcfg.Endpoint{ + mkEp(10100), + mkEp(10101), + }) + + // We expect to get the last N endpoints per our per-Addr limit, since + // all of the endpoints have the same netip.Addr. The first endpoint(s) + // that we added were dropped because we had more than the limit for + // this Addr. + want := allEndpoints[len(allEndpoints)-endpointTrackerMaxPerAddr:] + + compareEndpoints := func(got, want []tailcfg.Endpoint) { + t.Helper() + slices.SortFunc(want, func(a, b tailcfg.Endpoint) int { + return strings.Compare(a.Addr.String(), b.Addr.String()) + }) + slices.SortFunc(got, func(a, b tailcfg.Endpoint) int { + return strings.Compare(a.Addr.String(), b.Addr.String()) + }) + if !reflect.DeepEqual(got, want) { + t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, want) + } + } + compareEndpoints(got, want) + + // However, if we have more than our limit of endpoints passed in to + // the endpointTracker, we will return all of them (even if they're for + // the same address). + var inputEps []tailcfg.Endpoint + for i := 0; i < endpointTrackerMaxPerAddr+5; i++ { + inputEps = append(inputEps, tailcfg.Endpoint{ + Addr: netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 10200+uint16(i)), + Type: tailcfg.EndpointSTUN, + }) + } + + want = inputEps + got = et.update(start.Add(2*time.Minute), inputEps) + compareEndpoints(got, want) +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index e552b8826..3f63547ae 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -53,7 +53,6 @@ import ( "tailscale.com/util/clientmetric" "tailscale.com/util/mak" "tailscale.com/util/ringbuffer" - "tailscale.com/util/set" "tailscale.com/util/uniq" "tailscale.com/version" "tailscale.com/wgengine/capture" @@ -2594,11 +2593,6 @@ const ( // STUN-derived endpoint valid for. 
UDP NAT mappings typically // expire at 30 seconds, so this is a few seconds shy of that. endpointsFreshEnoughDuration = 27 * time.Second - - // endpointTrackerLifetime is how long we continue advertising an - // endpoint after we last see it. This is intentionally chosen to be - // slightly longer than a full netcheck period. - endpointTrackerLifetime = 5*time.Minute + 10*time.Second ) // Constants that are variable for testing. @@ -2683,79 +2677,6 @@ type discoInfo struct { lastPingTime time.Time } -type endpointTrackerEntry struct { - endpoint tailcfg.Endpoint - until time.Time -} - -type endpointTracker struct { - mu sync.Mutex - cache map[netip.AddrPort]endpointTrackerEntry -} - -func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) { - epsPlusCached = eps - - var inputEps set.Slice[netip.AddrPort] - for _, ep := range eps { - inputEps.Add(ep.Addr) - } - - et.mu.Lock() - defer et.mu.Unlock() - - // Add entries to the return array that aren't already there. - for k, ep := range et.cache { - // If the endpoint was in the input list, or has expired, skip it. - if inputEps.Contains(k) { - continue - } else if now.After(ep.until) { - continue - } - - // We haven't seen this endpoint; add to the return array - epsPlusCached = append(epsPlusCached, ep.endpoint) - } - - // Add entries from the original input array into the cache, and/or - // extend the lifetime of entries that are already in the cache. - until := now.Add(endpointTrackerLifetime) - for _, ep := range eps { - et.addLocked(now, ep, until) - } - - // Remove everything that has now expired. - et.removeExpiredLocked(now) - return epsPlusCached -} - -// add will store the provided endpoint(s) in the cache for a fixed period of -// time, and remove any entries in the cache that have expired. -// -// et.mu must be held. -func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) { - // If we already have an entry for this endpoint, update the timeout on - // it; otherwise, add it. 
- entry, found := et.cache[ep.Addr] - if found { - entry.until = until - } else { - entry = endpointTrackerEntry{ep, until} - } - mak.Set(&et.cache, ep.Addr, entry) -} - -// removeExpired will remove all expired entries from the cache -// -// et.mu must be held -func (et *endpointTracker) removeExpiredLocked(now time.Time) { - for k, ep := range et.cache { - if now.After(ep.until) { - delete(et.cache, k) - } - } -} - var ( metricNumPeers = clientmetric.NewGauge("magicsock_netmap_num_peers") metricNumDERPConns = clientmetric.NewGauge("magicsock_num_derp_conns") diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index b6bfef107..4c4153bf3 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -18,7 +18,6 @@ import ( "net/http/httptest" "net/netip" "os" - "reflect" "runtime" "strconv" "strings" @@ -33,7 +32,6 @@ import ( "github.com/tailscale/wireguard-go/tun/tuntest" "go4.org/mem" "golang.org/x/exp/maps" - "golang.org/x/exp/slices" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -2341,116 +2339,6 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) { } } -func TestEndpointTracker(t *testing.T) { - local := tailcfg.Endpoint{ - Addr: netip.MustParseAddrPort("192.168.1.1:12345"), - Type: tailcfg.EndpointLocal, - } - - stun4_1 := tailcfg.Endpoint{ - Addr: netip.MustParseAddrPort("1.2.3.4:12345"), - Type: tailcfg.EndpointSTUN, - } - stun4_2 := tailcfg.Endpoint{ - Addr: netip.MustParseAddrPort("5.6.7.8:12345"), - Type: tailcfg.EndpointSTUN, - } - - stun6_1 := tailcfg.Endpoint{ - Addr: netip.MustParseAddrPort("[2a09:8280:1::1111]:12345"), - Type: tailcfg.EndpointSTUN, - } - stun6_2 := tailcfg.Endpoint{ - Addr: netip.MustParseAddrPort("[2a09:8280:1::2222]:12345"), - Type: tailcfg.EndpointSTUN, - } - - start := time.Unix(1681503440, 0) - - steps := []struct { - name string - now time.Time - eps []tailcfg.Endpoint - want []tailcfg.Endpoint - }{ - { - name: "initial endpoints", - now: start, - eps: []tailcfg.Endpoint{local, stun4_1, stun6_1}, - want: []tailcfg.Endpoint{local, stun4_1, stun6_1}, - }, - { - name: "no change", - now: start.Add(1 * time.Minute), - eps: []tailcfg.Endpoint{local, stun4_1, stun6_1}, - want: []tailcfg.Endpoint{local, stun4_1, stun6_1}, - }, - { - name: "missing stun4", - now: start.Add(2 * time.Minute), - eps: []tailcfg.Endpoint{local, stun6_1}, - want: []tailcfg.Endpoint{local, stun4_1, stun6_1}, - }, - { - name: "missing stun6", - now: start.Add(3 * time.Minute), - eps: []tailcfg.Endpoint{local, stun4_1}, - want: []tailcfg.Endpoint{local, stun4_1, stun6_1}, - }, - { - name: "multiple STUN addresses within timeout", - now: start.Add(4 * time.Minute), - eps: []tailcfg.Endpoint{local, stun4_2, stun6_2}, - want: []tailcfg.Endpoint{local, stun4_1, stun4_2, stun6_1, stun6_2}, - }, - { - name: "endpoint extended", - now: start.Add(3*time.Minute + endpointTrackerLifetime - 1), - eps: []tailcfg.Endpoint{local}, - want: []tailcfg.Endpoint{ - local, stun4_2, stun6_2, - // stun4_1 had its lifetime extended by the - // "missing stun6" test above to that start - // time plus the lifetime, while stun6 should - // have expired a minute sooner. It should thus - // be in this returned list. 
- stun4_1, - }, - }, - { - name: "after timeout", - now: start.Add(4*time.Minute + endpointTrackerLifetime + 1), - eps: []tailcfg.Endpoint{local, stun4_2, stun6_2}, - want: []tailcfg.Endpoint{local, stun4_2, stun6_2}, - }, - { - name: "after timeout still caches", - now: start.Add(4*time.Minute + endpointTrackerLifetime + time.Minute), - eps: []tailcfg.Endpoint{local}, - want: []tailcfg.Endpoint{local, stun4_2, stun6_2}, - }, - } - - var et endpointTracker - for _, tt := range steps { - t.Logf("STEP: %s", tt.name) - - got := et.update(tt.now, tt.eps) - - // Sort both arrays for comparison - slices.SortFunc(got, func(a, b tailcfg.Endpoint) int { - return strings.Compare(a.Addr.String(), b.Addr.String()) - }) - slices.SortFunc(tt.want, func(a, b tailcfg.Endpoint) int { - return strings.Compare(a.Addr.String(), b.Addr.String()) - }) - - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, tt.want) - } - } -} - // applyNetworkMap is a test helper that sets the network map and // configures WG. func applyNetworkMap(t *testing.T, m *magicStack, nm *netmap.NetworkMap) {
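
Reviewer note, not part of the patch: the core trick above is a per-Addr min-heap keyed on expiry time, so both expiry (removeExpiredLocked) and cache-size enforcement (addLocked) reduce to popping the minimum element. The sketch below reproduces that pattern using only the standard library's container/heap, since the patch's generic tailscale.com/tempfork/heap isn't importable outside the tree; the names entry, expiryHeap, and maxPerAddr are illustrative, and the real entries also carry a heap index so heap.Fix can re-sort after extendLocked, which this sketch omits.

package main

import (
	"container/heap"
	"fmt"
	"time"
)

// entry mirrors endpointTrackerEntry: a cached value plus its expiry deadline.
type entry struct {
	addr  string
	until time.Time
}

// expiryHeap is a min-heap ordered by 'until' (oldest first), like endpointHeap.
type expiryHeap []entry

func (h expiryHeap) Len() int           { return len(h) }
func (h expiryHeap) Less(i, j int) bool { return h[i].until.Before(h[j].until) }
func (h expiryHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }
func (h *expiryHeap) Push(x any)        { *h = append(*h, x.(entry)) }
func (h *expiryHeap) Pop() any {
	old := *h
	it := old[len(old)-1]
	*h = old[:len(old)-1]
	return it
}

func main() {
	now := time.Unix(1681503440, 0) // same fixed epoch as the tests
	h := &expiryHeap{}
	heap.Push(h, entry{"1.2.3.4:100", now.Add(1 * time.Minute)})
	heap.Push(h, entry{"1.2.3.4:101", now.Add(5 * time.Minute)})
	heap.Push(h, entry{"1.2.3.4:102", now.Add(3 * time.Minute)})

	// Expiry, as in removeExpiredLocked: peek at the minimum element
	// and pop while it's in the past.
	now = now.Add(2 * time.Minute)
	for h.Len() > 0 && now.After((*h)[0].until) {
		fmt.Println("expired:", heap.Pop(h).(entry).addr) // 1.2.3.4:100
	}

	// Size cap, as in addLocked: popping the minimum always evicts the
	// entry closest to expiry.
	const maxPerAddr = 1
	for h.Len() > maxPerAddr {
		fmt.Println("evicted:", heap.Pop(h).(entry).addr) // 1.2.3.4:102
	}
	fmt.Println("kept:", (*h)[0].addr) // 1.2.3.4:101
}

Running it prints "expired: 1.2.3.4:100", "evicted: 1.2.3.4:102", then "kept: 1.2.3.4:101", matching the pop-the-minimum semantics the new tests exercise.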