diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index 592a56876..8a0328cc9 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -664,7 +664,7 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en return 0, nil } - ep.noteRecvActivity() + ep.noteRecvActivity(ipp) if stats := c.stats.Load(); stats != nil { stats.UpdateRxPhysical(ep.nodeAddr, ipp, dm.n) } diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index 751fc8ba8..b42d134b8 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -223,14 +223,26 @@ func (de *endpoint) initFakeUDPAddr() { // noteRecvActivity records receive activity on de, and invokes // Conn.noteRecvActivity no more than once every 10s. -func (de *endpoint) noteRecvActivity() { - if de.c.noteRecvActivity == nil { - return - } +func (de *endpoint) noteRecvActivity(ipp netip.AddrPort) { now := mono.Now() + + // TODO(raggi): this probably applies relatively equally well to disco + // managed endpoints, but that would be a less conservative change. + if de.isWireguardOnly { + de.mu.Lock() + de.bestAddr.AddrPort = ipp + de.bestAddrAt = now + de.trustBestAddrUntil = now.Add(5 * time.Second) + de.mu.Unlock() + } + elapsed := now.Sub(de.lastRecv.LoadAtomic()) if elapsed > 10*time.Second { de.lastRecv.StoreAtomic(now) + + if de.c.noteRecvActivity == nil { + return + } de.c.noteRecvActivity(de.publicKey) } } @@ -292,11 +304,23 @@ func (de *endpoint) addrForSendLocked(now mono.Time) (udpAddr, derpAddr netip.Ad // // de.mu must be held. 
func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.AddrPort, shouldPing bool) { + if len(de.endpointState) == 0 { + de.c.logf("magicsock: addrForSendWireguardLocked: [unexpected] no candidates available for endpoint") + return udpAddr, false + } + // lowestLatency is a high duration initially, so we // can be sure we're going to have a duration lower than this // for the first latency retrieved. lowestLatency := time.Hour + var oldestPing mono.Time for ipp, state := range de.endpointState { + if oldestPing.IsZero() { + oldestPing = state.lastPing + } else if state.lastPing.Before(oldestPing) { + oldestPing = state.lastPing + } + if latency, ok := state.latencyLocked(); ok { if latency < lowestLatency || latency == lowestLatency && ipp.Addr().Is6() { // If we have the same latency,IPv6 is prioritized. @@ -307,35 +331,25 @@ func (de *endpoint) addrForWireGuardSendLocked(now mono.Time) (udpAddr netip.Add } } } + needPing := len(de.endpointState) > 1 && now.Sub(oldestPing) > wireguardPingInterval - if udpAddr.IsValid() { - // Set trustBestAddrUntil to an hour, so we will - // continue to use this address for a long period of time. - de.bestAddr.AddrPort = udpAddr - de.trustBestAddrUntil = now.Add(1 * time.Hour) - return udpAddr, false - } + if !udpAddr.IsValid() { + candidates := xmaps.Keys(de.endpointState) - candidates := xmaps.Keys(de.endpointState) - if len(candidates) == 0 { - de.c.logf("magicsock: addrForSendWireguardLocked: [unexpected] no candidates available for endpoint") - return udpAddr, false + // Randomly select an address to use until we retrieve latency information + // and give it a short trustBestAddrUntil time so we avoid flapping between + // addresses while waiting on latency information to be populated. 
+ udpAddr = candidates[rand.Intn(len(candidates))] } - // Randomly select an address to use until we retrieve latency information - // and give it a short trustBestAddrUntil time so we avoid flapping between - // addresses while waiting on latency information to be populated. - udpAddr = candidates[rand.Intn(len(candidates))] de.bestAddr.AddrPort = udpAddr - if len(candidates) == 1 { - // if we only have one address that we can send data too, - // we should trust it for a longer period of time. - de.trustBestAddrUntil = now.Add(1 * time.Hour) - } else { - de.trustBestAddrUntil = now.Add(15 * time.Second) - } - - return udpAddr, len(candidates) > 1 + // Only extend trustBestAddrUntil by one second to avoid packet + // reordering and/or CPU usage from random selection during the first + // second. We should receive a response due to a WireGuard handshake in + // less than one second in good cases, in which case this will then be + // extended to 5 seconds by noteRecvActivity. + de.trustBestAddrUntil = now.Add(time.Second) + return udpAddr, needPing } // heartbeat is called every heartbeatInterval to keep the best UDP path alive, diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index c52df03f0..7cea3877c 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -1188,7 +1188,7 @@ func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) cache.gen = de.numStopAndReset() ep = de } - ep.noteRecvActivity() + ep.noteRecvActivity(ipp) if stats := c.stats.Load(); stats != nil { stats.UpdateRxPhysical(ep.nodeAddr, ipp, len(b)) } @@ -2605,6 +2605,11 @@ var ( // resetting the counter, as the first pings likely didn't through // the firewall) discoPingInterval = 5 * time.Second + + // wireguardPingInterval is the minimum time between pings to an endpoint. + // Pings are only sent if we have not observed bidirectional traffic with an + // endpoint in at least this duration. 
+ wireguardPingInterval = 5 * time.Second ) // indexSentinelDeleted is the temporary value that endpointState.index takes while diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 5bc68cf37..bd5a07624 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -1212,11 +1212,11 @@ func Test32bitAlignment(t *testing.T) { t.Fatalf("endpoint.lastRecv is not 8-byte aligned") } - de.noteRecvActivity() // verify this doesn't panic on 32-bit + de.noteRecvActivity(netip.AddrPort{}) // verify this doesn't panic on 32-bit if called != 1 { t.Fatal("expected call to noteRecvActivity") } - de.noteRecvActivity() + de.noteRecvActivity(netip.AddrPort{}) if called != 1 { t.Error("expected no second call to noteRecvActivity") } @@ -2678,6 +2678,7 @@ func newPingResponder(t *testing.T) *pingResponder { func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { testTime := mono.Now() + secondPingTime := testTime.Add(10 * time.Second) type endpointDetails struct { addrPort netip.AddrPort @@ -2685,16 +2686,79 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { } wgTests := []struct { - name string - noV4 bool - noV6 bool - sendWGPing bool - ep []endpointDetails - want netip.AddrPort + name string + sendInitialPing bool + validAddr bool + sendFollowUpPing bool + pingTime mono.Time + ep []endpointDetails + want netip.AddrPort }{ { - name: "choose lowest latency for useable IPv4 and IPv6", - sendWGPing: true, + name: "no endpoints", + sendInitialPing: false, + validAddr: false, + sendFollowUpPing: false, + pingTime: testTime, + ep: []endpointDetails{}, + want: netip.AddrPort{}, + }, + { + name: "singular endpoint does not request ping", + sendInitialPing: false, + validAddr: true, + sendFollowUpPing: false, + pingTime: testTime, + ep: []endpointDetails{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + }, + want: netip.MustParseAddrPort("1.1.1.1:111"), + 
}, + { + name: "ping sent within wireguardPingInterval should not request ping", + sendInitialPing: true, + validAddr: true, + sendFollowUpPing: false, + pingTime: testTime.Add(7 * time.Second), + ep: []endpointDetails{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + { + addrPort: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"), + latency: 2000 * time.Millisecond, + }, + }, + want: netip.MustParseAddrPort("1.1.1.1:111"), + }, + { + name: "ping sent outside of wireguardPingInterval should request ping", + sendInitialPing: true, + validAddr: true, + sendFollowUpPing: true, + pingTime: testTime.Add(3 * time.Second), + ep: []endpointDetails{ + { + addrPort: netip.MustParseAddrPort("1.1.1.1:111"), + latency: 100 * time.Millisecond, + }, + { + addrPort: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"), + latency: 150 * time.Millisecond, + }, + }, + want: netip.MustParseAddrPort("1.1.1.1:111"), + }, + { + name: "choose lowest latency for useable IPv4 and IPv6", + sendInitialPing: true, + validAddr: true, + sendFollowUpPing: false, + pingTime: secondPingTime, ep: []endpointDetails{ { addrPort: netip.MustParseAddrPort("1.1.1.1:111"), @@ -2708,8 +2772,11 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { want: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"), }, { - name: "choose IPv6 address when latency is the same for v4 and v6", - sendWGPing: true, + name: "choose IPv6 address when latency is the same for v4 and v6", + sendInitialPing: true, + validAddr: true, + sendFollowUpPing: false, + pingTime: secondPingTime, ep: []endpointDetails{ { addrPort: netip.MustParseAddrPort("1.1.1.1:111"), @@ -2725,52 +2792,57 @@ func TestAddrForSendLockedForWireGuardOnly(t *testing.T) { } for _, test := range wgTests { - endpoint := &endpoint{ - isWireguardOnly: true, - endpointState: map[netip.AddrPort]*endpointState{}, - c: &Conn{ - noV4: atomic.Bool{}, - 
noV6: atomic.Bool{}, - }, - } + t.Run(test.name, func(t *testing.T) { + endpoint := &endpoint{ + isWireguardOnly: true, + endpointState: map[netip.AddrPort]*endpointState{}, + c: &Conn{ + logf: t.Logf, + noV4: atomic.Bool{}, + noV6: atomic.Bool{}, + }, + } - for _, epd := range test.ep { - endpoint.endpointState[epd.addrPort] = &endpointState{} - } + for _, epd := range test.ep { + endpoint.endpointState[epd.addrPort] = &endpointState{} + } + udpAddr, _, shouldPing := endpoint.addrForSendLocked(testTime) + if udpAddr.IsValid() != test.validAddr { + t.Errorf("udpAddr validity is incorrect; got %v, want %v", udpAddr.IsValid(), test.validAddr) + } + if shouldPing != test.sendInitialPing { + t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendInitialPing) + } - udpAddr, _, shouldPing := endpoint.addrForSendLocked(testTime) - if !udpAddr.IsValid() { - t.Error("udpAddr returned is not valid") - } - if shouldPing != test.sendWGPing { - t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendWGPing) - } + // Update the endpointState to simulate a ping having been + // sent and a pong received. 
+ for _, epd := range test.ep { + state, ok := endpoint.endpointState[epd.addrPort] + if !ok { + t.Errorf("addr does not exist in endpoint state map") + } + state.lastPing = test.pingTime - for _, epd := range test.ep { - state, ok := endpoint.endpointState[epd.addrPort] - if !ok { - t.Errorf("addr does not exist in endpoint state map") + latency, ok := state.latencyLocked() + if ok { + t.Errorf("latency was set for %v: %v", epd.addrPort, latency) + } + state.recentPongs = append(state.recentPongs, pongReply{ + latency: epd.latency, + }) + state.recentPong = 0 } - latency, ok := state.latencyLocked() - if ok { - t.Errorf("latency was set for %v: %v", epd.addrPort, latency) + udpAddr, _, shouldPing = endpoint.addrForSendLocked(secondPingTime) + if udpAddr != test.want { + t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, test.want) } - state.recentPongs = append(state.recentPongs, pongReply{ - latency: epd.latency, - }) - state.recentPong = 0 - } - - udpAddr, _, shouldPing = endpoint.addrForSendLocked(testTime.Add(2 * time.Minute)) - if udpAddr != test.want { - t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, test.want) - } - if shouldPing { - t.Error("addrForSendLocked should not indicate ping is required") - } - if endpoint.bestAddr.AddrPort != test.want { - t.Errorf("bestAddr.AddrPort is not as expected: got %v, want %v", endpoint.bestAddr.AddrPort, test.want) - } + if shouldPing != test.sendFollowUpPing { + t.Errorf("addrForSendLocked did not indiciate correct ping state; got %v, want %v", shouldPing, test.sendFollowUpPing) + } + if endpoint.bestAddr.AddrPort != test.want { + t.Errorf("bestAddr.AddrPort is not as expected: got %v, want %v", endpoint.bestAddr.AddrPort, test.want) + } + }) } }