diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go
index 8374193da..f315e88b6 100644
--- a/wgengine/magicsock/magicsock.go
+++ b/wgengine/magicsock/magicsock.go
@@ -64,6 +64,7 @@ import (
 	"tailscale.com/util/clientmetric"
 	"tailscale.com/util/mak"
 	"tailscale.com/util/ringbuffer"
+	"tailscale.com/util/set"
 	"tailscale.com/util/sysresources"
 	"tailscale.com/util/uniq"
 	"tailscale.com/version"
@@ -419,6 +420,10 @@ type Conn struct {
 	// when endpoints are refreshed.
 	onEndpointRefreshed map[*endpoint]func()
 
+	// endpointTracker tracks the set of cached endpoints that we advertise
+	// for a period of time before withdrawing them.
+	endpointTracker endpointTracker
+
 	// peerSet is the set of peers that are currently configured in
 	// WireGuard. These are not used to filter inbound or outbound
 	// traffic at all, but only to track what state can be cleaned up
@@ -1196,6 +1201,22 @@ func (c *Conn) determineEndpoints(ctx context.Context) ([]tailcfg.Endpoint, erro
 
 	c.ignoreSTUNPackets()
 
+	// Update our set of endpoints by adding any endpoints that we
+	// previously found but haven't expired yet. This also updates the
+	// cache with the set of endpoints discovered in this function.
+	//
+	// NOTE: we do this here and not below so that we don't cache local
+	// endpoints; we know that the local endpoints we discover are all
+	// possible local endpoints since we determine them by looking at the
+	// set of addresses on our local interfaces.
+	//
+	// TODO(andrew): If we pull in any cached endpoints, we should probably
+	// do something to ensure we're propagating the removal of those cached
+	// endpoints if they do actually time out without being rediscovered.
+	// For now, though, rely on a minor LinkChange event causing this to
+	// re-run.
+	eps = c.endpointTracker.update(time.Now(), eps)
+
 	if localAddr := c.pconn4.LocalAddr(); localAddr.IP.IsUnspecified() {
 		ips, loopback, err := interfaces.LocalAddresses()
 		if err != nil {
@@ -4148,6 +4169,11 @@ const (
 	// STUN-derived endpoint valid for. UDP NAT mappings typically
 	// expire at 30 seconds, so this is a few seconds shy of that.
 	endpointsFreshEnoughDuration = 27 * time.Second
+
+	// endpointTrackerLifetime is how long we continue advertising an
+	// endpoint after we last see it. This is intentionally chosen to be
+	// slightly longer than a full netcheck period.
+	endpointTrackerLifetime = 5*time.Minute + 10*time.Second
 )
 
 // Constants that are variable for testing.
@@ -5105,6 +5131,79 @@ func (s derpAddrFamSelector) PreferIPv6() bool {
 	return false
 }
 
+type endpointTrackerEntry struct {
+	endpoint tailcfg.Endpoint
+	until    time.Time
+}
+
+type endpointTracker struct {
+	mu    sync.Mutex
+	cache map[netip.AddrPort]endpointTrackerEntry
+}
+
+func (et *endpointTracker) update(now time.Time, eps []tailcfg.Endpoint) (epsPlusCached []tailcfg.Endpoint) {
+	epsPlusCached = eps
+
+	var inputEps set.Slice[netip.AddrPort]
+	for _, ep := range eps {
+		inputEps.Add(ep.Addr)
+	}
+
+	et.mu.Lock()
+	defer et.mu.Unlock()
+
+	// Add entries to the return array that aren't already there.
+	for k, ep := range et.cache {
+		// If the endpoint was in the input list, or has expired, skip it.
+		if inputEps.Contains(k) {
+			continue
+		} else if now.After(ep.until) {
+			continue
+		}
+
+		// We haven't seen this endpoint; add to the return array
+		epsPlusCached = append(epsPlusCached, ep.endpoint)
+	}
+
+	// Add entries from the original input array into the cache, and/or
+	// extend the lifetime of entries that are already in the cache.
+	until := now.Add(endpointTrackerLifetime)
+	for _, ep := range eps {
+		et.addLocked(now, ep, until)
+	}
+
+	// Remove everything that has now expired.
+	et.removeExpiredLocked(now)
+	return epsPlusCached
+}
+
+// addLocked adds the provided endpoint to the cache, or extends the
+// lifetime of an existing entry, so that it expires at the provided time.
+//
+// et.mu must be held.
+func (et *endpointTracker) addLocked(now time.Time, ep tailcfg.Endpoint, until time.Time) {
+	// If we already have an entry for this endpoint, update the timeout on
+	// it; otherwise, add it.
+	entry, found := et.cache[ep.Addr]
+	if found {
+		entry.until = until
+	} else {
+		entry = endpointTrackerEntry{ep, until}
+	}
+	mak.Set(&et.cache, ep.Addr, entry)
+}
+
+// removeExpiredLocked removes all expired entries from the cache.
+//
+// et.mu must be held.
+func (et *endpointTracker) removeExpiredLocked(now time.Time) {
+	for k, ep := range et.cache {
+		if now.After(ep.until) {
+			delete(et.cache, k)
+		}
+	}
+}
+
 var (
 	metricNumPeers     = clientmetric.NewGauge("magicsock_netmap_num_peers")
 	metricNumDERPConns = clientmetric.NewGauge("magicsock_num_derp_conns")
diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go
index 6e844f26c..c050cfc20 100644
--- a/wgengine/magicsock/magicsock_test.go
+++ b/wgengine/magicsock/magicsock_test.go
@@ -18,6 +18,7 @@ import (
 	"net/http/httptest"
 	"net/netip"
 	"os"
+	"reflect"
 	"runtime"
 	"strconv"
 	"strings"
@@ -31,6 +32,7 @@ import (
 	"github.com/tailscale/wireguard-go/tun/tuntest"
 	"go4.org/mem"
 	"golang.org/x/exp/maps"
+	"golang.org/x/exp/slices"
 	"golang.org/x/net/ipv6"
 	"tailscale.com/cmd/testwrapper/flakytest"
 	"tailscale.com/derp"
@@ -390,6 +392,7 @@ collectEndpoints:
 	for {
 		select {
 		case ep := <-epCh:
+			t.Logf("TestNewConn: got endpoint: %v", ep)
 			endpoints = append(endpoints, ep)
 			if strings.HasSuffix(ep, suffix) {
 				break collectEndpoints
@@ -2280,3 +2283,113 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) {
 		t.Fatal("no packet after 1s")
 	}
 }
+
+func TestEndpointTracker(t *testing.T) {
+	local := tailcfg.Endpoint{
+		Addr: netip.MustParseAddrPort("192.168.1.1:12345"),
+		Type: tailcfg.EndpointLocal,
+	}
+
+	stun4_1 := tailcfg.Endpoint{
+		Addr: netip.MustParseAddrPort("1.2.3.4:12345"),
+		Type: tailcfg.EndpointSTUN,
+	}
+	stun4_2 := tailcfg.Endpoint{
+		Addr: netip.MustParseAddrPort("5.6.7.8:12345"),
+		Type: tailcfg.EndpointSTUN,
+	}
+
+	stun6_1 := tailcfg.Endpoint{
+		Addr: netip.MustParseAddrPort("[2a09:8280:1::1111]:12345"),
+		Type: tailcfg.EndpointSTUN,
+	}
+	stun6_2 := tailcfg.Endpoint{
+		Addr: netip.MustParseAddrPort("[2a09:8280:1::2222]:12345"),
+		Type: tailcfg.EndpointSTUN,
+	}
+
+	start := time.Unix(1681503440, 0)
+
+	steps := []struct {
+		name string
+		now  time.Time
+		eps  []tailcfg.Endpoint
+		want []tailcfg.Endpoint
+	}{
+		{
+			name: "initial endpoints",
+			now:  start,
+			eps:  []tailcfg.Endpoint{local, stun4_1, stun6_1},
+			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+		},
+		{
+			name: "no change",
+			now:  start.Add(1 * time.Minute),
+			eps:  []tailcfg.Endpoint{local, stun4_1, stun6_1},
+			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+		},
+		{
+			name: "missing stun4",
+			now:  start.Add(2 * time.Minute),
+			eps:  []tailcfg.Endpoint{local, stun6_1},
+			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+		},
+		{
+			name: "missing stun6",
+			now:  start.Add(3 * time.Minute),
+			eps:  []tailcfg.Endpoint{local, stun4_1},
+			want: []tailcfg.Endpoint{local, stun4_1, stun6_1},
+		},
+		{
+			name: "multiple STUN addresses within timeout",
+			now:  start.Add(4 * time.Minute),
+			eps:  []tailcfg.Endpoint{local, stun4_2, stun6_2},
+			want: []tailcfg.Endpoint{local, stun4_1, stun4_2, stun6_1, stun6_2},
+		},
+		{
+			name: "endpoint extended",
+			now:  start.Add(3*time.Minute + endpointTrackerLifetime - 1),
+			eps:  []tailcfg.Endpoint{local},
+			want: []tailcfg.Endpoint{
+				local, stun4_2, stun6_2,
+				// stun4_1 had its lifetime extended by the
+				// "missing stun6" step above to that step's
+				// time plus endpointTrackerLifetime, while
+				// stun6_1 expired a minute earlier; stun4_1
+				// should thus still be in the returned list.
+				stun4_1,
+			},
+		},
+		{
+			name: "after timeout",
+			now:  start.Add(4*time.Minute + endpointTrackerLifetime + 1),
+			eps:  []tailcfg.Endpoint{local, stun4_2, stun6_2},
+			want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
+		},
+		{
+			name: "after timeout still caches",
+			now:  start.Add(4*time.Minute + endpointTrackerLifetime + time.Minute),
+			eps:  []tailcfg.Endpoint{local},
+			want: []tailcfg.Endpoint{local, stun4_2, stun6_2},
+		},
+	}
+
+	var et endpointTracker
+	for _, tt := range steps {
+		t.Logf("STEP: %s", tt.name)
+
+		got := et.update(tt.now, tt.eps)
+
+		// Sort both slices for comparison.
+		slices.SortFunc(got, func(a, b tailcfg.Endpoint) bool {
+			return a.Addr.String() < b.Addr.String()
+		})
+		slices.SortFunc(tt.want, func(a, b tailcfg.Endpoint) bool {
+			return a.Addr.String() < b.Addr.String()
+		})
+
+		if !reflect.DeepEqual(got, tt.want) {
+			t.Errorf("endpoints mismatch\ngot: %+v\nwant: %+v", got, tt.want)
+		}
+	}
+}
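
Illustrative sketch (not part of the patch above): a minimal example of how the new endpointTracker.update method is intended to behave, assuming the snippet lives in the magicsock package since endpointTracker and endpointTrackerLifetime are unexported. The exampleEndpointTracker helper name and the sample address are hypothetical.

package magicsock

import (
	"fmt"
	"net/netip"
	"time"

	"tailscale.com/tailcfg"
)

// exampleEndpointTracker is a hypothetical helper demonstrating the
// endpointTracker behavior added in this patch.
func exampleEndpointTracker() {
	stun := tailcfg.Endpoint{
		Addr: netip.MustParseAddrPort("203.0.113.7:41641"),
		Type: tailcfg.EndpointSTUN,
	}

	var et endpointTracker
	now := time.Now()

	// First call: the endpoint is cached and returned unchanged.
	fmt.Println(et.update(now, []tailcfg.Endpoint{stun}))

	// A minute later the endpoint is missing from the input (for example,
	// a transient STUN failure), but the cached copy is still returned
	// because its entry has not expired yet.
	fmt.Println(et.update(now.Add(time.Minute), nil))

	// Once endpointTrackerLifetime passes without the endpoint being
	// rediscovered, it is dropped from both the cache and the result.
	fmt.Println(et.update(now.Add(endpointTrackerLifetime+time.Second), nil))
}

This mirrors what TestEndpointTracker exercises: an endpoint stops being advertised only after endpointTrackerLifetime (5m10s, slightly longer than a full netcheck period) elapses without it being rediscovered.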