From 80b138f0df7c53b30d9a29a47dc1fc981ee6813b Mon Sep 17 00:00:00 2001
From: Andrew Dunham
Date: Fri, 14 Apr 2023 16:37:10 -0400
Subject: [PATCH] wgengine/magicsock: keep advertising endpoints after we
 stop discovering them

Previously, when updating endpoints we would immediately stop advertising
any endpoint that wasn't discovered during determineEndpoints. This could
result in, for example, a case where we performed an incremental netcheck,
didn't get any of our three STUN packets back, and then dropped our STUN
endpoint from the set of advertised endpoints... which would result in
clients falling back to a DERP connection until the next call to
determineEndpoints.

Instead, let's cache endpoints that we've discovered and continue
reporting them to clients until a timeout expires. In the above case,
where we temporarily don't have a discovered STUN endpoint, we would
continue reporting the old value, then re-discover the STUN endpoint and
continue reporting it as normal, so clients never see a withdrawal.

Updates tailscale/coral#108

Signed-off-by: Andrew Dunham
Change-Id: I42de72e7418ab328a6c732bdefc74549708cf8b9
---
 wgengine/magicsock/magicsock.go      |  99 +++++++++++++++++++++++
 wgengine/magicsock/magicsock_test.go | 113 +++++++++++++++++++++++++++
 2 files changed, 212 insertions(+)

diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go
index 8374193da..f315e88b6 100644
--- a/wgengine/magicsock/magicsock.go
+++ b/wgengine/magicsock/magicsock.go
@@ -64,6 +64,7 @@ import (
 	"tailscale.com/util/clientmetric"
 	"tailscale.com/util/mak"
 	"tailscale.com/util/ringbuffer"
+	"tailscale.com/util/set"
 	"tailscale.com/util/sysresources"
 	"tailscale.com/util/uniq"
 	"tailscale.com/version"
@@ -419,6 +420,10 @@ type Conn struct {
 	// when endpoints are refreshed.
 	onEndpointRefreshed map[*endpoint]func()
 
+	// endpointTracker tracks the set of cached endpoints that we advertise
+	// for a period of time before withdrawing them.
+	endpointTracker endpointTracker
+
 	// peerSet is the set of peers that are currently configured in
 	// WireGuard. These are not used to filter inbound or outbound
 	// traffic at all, but only to track what state can be cleaned up
@@ -1196,6 +1201,22 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, error) {
 	c.ignoreSTUNPackets()
 
+	// Update our set of endpoints by adding any endpoints that we
+	// previously found but haven't expired yet. This also updates the
+	// cache with the set of endpoints discovered in this function.
+	//
+	// NOTE: we do this here and not below so that we don't cache local
+	// endpoints; we know that the local endpoints we discover are all
+	// possible local endpoints, since we determine them by looking at the
+	// set of addresses on our local interfaces.
+	//
+	// TODO(andrew): If we pull in any cached endpoints, we should probably
+	// do something to ensure we're propagating the removal of those cached
+	// endpoints if they do actually time out without being rediscovered.
+	// For now, though, rely on a minor LinkChange event causing this to
+	// re-run.
+	eps = c.endpointTracker.update(time.Now(), eps)
+
 	if localAddr := c.pconn4.LocalAddr(); localAddr.IP.IsUnspecified() {
 		ips, loopback, err := interfaces.LocalAddresses()
 		if err != nil {
@@ -4148,6 +4169,11 @@ const (
 	// STUN-derived endpoint valid for. UDP NAT mappings typically
 	// expire at 30 seconds, so this is a few seconds shy of that.
 	endpointsFreshEnoughDuration = 27 * time.Second
+
+	// endpointTrackerLifetime is how long we continue advertising an
+	// endpoint after we last see it. This is intentionally chosen to be
+	// slightly longer than a full netcheck period.
+	endpointTrackerLifetime = 5*time.Minute + 10*time.Second
 )
 
 // Constants that are variable for testing.
@@ -5105,6 +5131,79 @@ func (s derpAddrFamSelector) PreferIPv6() bool { return false }
 
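+// endpointTrackerEntry is an entry in endpointTracker's cache: a single
+// discovered endpoint, plus the time until which we keep advertising it.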
+type endpointTrackerEntry struct {
+	endpoint tailcfg.Endpoint
+	until    time.Time
+}
+
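+// endpointTracker caches endpoints that we have recently discovered, so
+// that we keep advertising each one for endpointTrackerLifetime after it
+// was last seen, rather than withdrawing it immediately.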
+type endpointTracker struct {
+	mu    sync.Mutex
+	cache map[netip.AddrPort]endpointTrackerEntry
+}
+
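+// update merges the set of endpoints discovered at time "now" with any
+// previously-cached endpoints that have not yet expired, records the
+// discovered endpoints in the cache, and returns the merged set.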
+func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) {
+	epsPlusCached = eps
+
+	var inputEps set.Slice[netip.AddrPort]
+	for _, ep := range eps {
+		inputEps.Add(ep.Addr)
+	}
+
+	et.mu.Lock()
+	defer et.mu.Unlock()
+
+	// Add entries to the return array that aren't already there.
+	for k, ep := range et.cache {
+		// If the endpoint was in the input list, or has expired, skip it.
+		if inputEps.Contains(k) {
+			continue
+		} else if now.After(ep.until) {
+			continue
+		}
+
+		// We haven't seen this endpoint; add it to the return array.
+		epsPlusCached = append(epsPlusCached, ep.endpoint)
+	}
+
+	// Add entries from the original input array into the cache, and/or
+	// extend the lifetime of entries that are already in the cache.
+	until := now.Add(endpointTrackerLifetime)
+	for _, ep := range eps {
+		et.addLocked(now, ep, until)
+	}
+
+	// Remove everything that has now expired.
+	et.removeExpiredLocked(now)
+	return epsPlusCached
+}
+
+// addLocked stores the provided endpoint in the cache until the provided
+// expiry time, extending the expiry of an entry that is already present.
+//
+// et.mu must be held.
+func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) {
+	// If we already have an entry for this endpoint, update the timeout on
+	// it; otherwise, add it.
+	entry, found := et.cache[ep.Addr]
+	if found {
+		entry.until = until
+	} else {
+		entry = endpointTrackerEntry{ep, until}
+	}
+	mak.Set(&et.cache, ep.Addr, entry)
+}
+
+// removeExpiredLocked removes all expired entries from the cache.
+//
+// et.mu must be held.
+func (et *endpointTracker) removeExpiredLocked(now time.Time) {
+	for k, ep := range et.cache {
+		if now.After(ep.until) {
+			delete(et.cache, k)
+		}
+	}
+}
+
 var (
 	metricNumPeers     = clientmetric.NewGauge("magicsock_netmap_num_peers")
 	metricNumDERPConns = clientmetric.NewGauge("magicsock_num_derp_conns")
diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go
index 6e844f26c..c050cfc20 100644
--- a/wgengine/magicsock/magicsock_test.go
+++ b/wgengine/magicsock/magicsock_test.go
@@ -18,6 +18,7 @@ import (
 	"net/http/httptest"
 	"net/netip"
 	"os"
+	"reflect"
 	"runtime"
 	"strconv"
 	"strings"
@@ -31,6 +32,7 @@ import (
 	"github.com/tailscale/wireguard-go/tun/tuntest"
 	"go4.org/mem"
 	"golang.org/x/exp/maps"
+	"golang.org/x/exp/slices"
 	"golang.org/x/net/ipv6"
 	"tailscale.com/cmd/testwrapper/flakytest"
 	"tailscale.com/derp"
@@ -390,6 +392,7 @@ collectEndpoints:
 	for {
 		select {
 		case ep := <-epCh:
+			t.Logf("TestNewConn: got endpoint: %v", ep)
 			endpoints = append(endpoints, ep)
 			if strings.HasSuffix(ep, suffix) {
 				break collectEndpoints
@@ -2280,3 +2283,113 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) {
 		t.Fatal("no packet after 1s")
 	}
 }
+
+func TestEndpointTracker(t *testing.T) {
+	local := tailcfg.Endpoint{
+		Addr: netip.MustParseAddrPort("192.168.1.1:12345"),
+		Type: tailcfg.EndpointLocal,
+	}
+
+	stun4_1 := tailcfg.Endpoint{
+		Addr: netip.MustParseAddrPort("1.2.3.4:12345"),
+		Type: tailcfg.EndpointSTUN,
+	}
+	stun4_2 := tailcfg.Endpoint{
+		Addr: netip.MustParseAddrPort("5.6.7.8:12345"),
+		Type: tailcfg.EndpointSTUN,
+	}
+
+	stun6_1 := tailcfg.Endpoint{
+		Addr: netip.MustParseAddrPort("[2a09:8280:1::1111]:12345"),
+		Type: tailcfg.EndpointSTUN,
+	}
+	stun6_2 := tailcfg.Endpoint{
+		Addr: netip.MustParseAddrPort("[2a09:8280:1::2222]:12345"),
+		Type: tailcfg.EndpointSTUN,
+	}
+
+	start := time.Unix(1681503440, 0)
+
+	steps := []struct {
+		name string
+		now  time.Time
+		eps  []tailcfg.Endpoint
+		want []tailcfg.Endpoint
+	}{
+		{
+			name: "initial endpoints",
+			now:  start,
+			eps:  []tailcfg.Endpoint{local, stun4_1, stun6_1},
+			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+		},
+		{
+			name: "no change",
+			now:  start.Add(1 * time.Minute),
+			eps:  []tailcfg.Endpoint{local, stun4_1, stun6_1},
+			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+		},
+		{
+			name: "missing stun4",
+			now:  start.Add(2 * time.Minute),
+			eps:  []tailcfg.Endpoint{local, stun6_1},
+			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+		},
+		{
+			name: "missing stun6",
+			now:  start.Add(3 * time.Minute),
+			eps:  []tailcfg.Endpoint{local, stun4_1},
+			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+		},
+		{
+			name: "multiple STUN addresses within timeout",
+			now:  start.Add(4 * time.Minute),
+			eps:  []tailcfg.Endpoint{local, stun4_2, stun6_2},
+			want: []tailcfg.Endpoint{local, stun4_1, stun4_2, stun6_1, stun6_2},
+		},
+		{
+			name: "endpoint extended",
+			now:  start.Add(3*time.Minute + endpointTrackerLifetime - 1),
+			eps:  []tailcfg.Endpoint{local},
+			want: []tailcfg.Endpoint{
+				local, stun4_2, stun6_2,
+				// stun4_1 was last seen by the "missing stun6"
+				// step above, which extended its lifetime to
+				// that step's time plus the tracker lifetime;
+				// stun6_1 was last seen a minute earlier and
+				// has therefore already expired. Only stun4_1
+				// should still be in this returned list.
+				stun4_1,
+			},
+		},
+		{
+			name: "after timeout",
+			now:  start.Add(4*time.Minute + endpointTrackerLifetime + 1),
+			eps:  []tailcfg.Endpoint{local, stun4_2, stun6_2},
+			want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
+		},
+		{
+			name: "after timeout still caches",
+			now:  start.Add(4*time.Minute + endpointTrackerLifetime + time.Minute),
+			eps:  []tailcfg.Endpoint{local},
+			want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
+		},
+	}
+
+	var et endpointTracker
+	for _, tt := range steps {
+		t.Logf("STEP: %s", tt.name)
+
+		got := et.update(tt.now, tt.eps)
+
+		// Sort both slices for comparison.
+		slices.SortFunc(got, func(a, b tailcfg.Endpoint) bool {
+			return a.Addr.String() < b.Addr.String()
+		})
+		slices.SortFunc(tt.want, func(a, b tailcfg.Endpoint) bool {
+			return a.Addr.String() < b.Addr.String()
+		})
+
+		if !reflect.DeepEqual(got, tt.want) {
+			t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, tt.want)
+		}
+	}
+}