diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index cbdfc0ac0..04ec092bc 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -1522,10 +1522,11 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool { func (c *Conn) handlePingLocked(dm *disco.Ping, de *discoEndpoint, src netaddr.IPPort) { c.logf("magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, de.discoShort, de.publicKey.ShortString(), src, dm.TxID[:6]) - // Remember this this route if not present. + // Remember this route if not present. c.setAddrToDiscoLocked(src, de.discoKey, nil) - go de.sendDiscoMessage(src, &disco.Pong{ + pongDst := src + go de.sendDiscoMessage(pongDst, &disco.Pong{ TxID: dm.TxID, Src: src, }) @@ -2623,8 +2624,10 @@ type discoEndpoint struct { // mu protects all following fields. mu sync.Mutex // Lock ordering: Conn.mu, then discoEndpoint.mu - lastSend time.Time // last time there was outgoing packets sent to this peer (from wireguard-go) - derpAddr netaddr.IPPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) + heartBeatTimer *time.Timer // nil when idle + lastSend time.Time // last time there were outgoing packets sent to this peer (from wireguard-go) + lastFullPing time.Time // last time we pinged all endpoints + derpAddr netaddr.IPPort // fallback/bootstrap path, if non-zero (non-zero for well-behaved clients) bestAddr netaddr.IPPort // best non-DERP path; zero if none bestAddrLatency time.Duration @@ -2635,6 +2638,18 @@ type discoEndpoint struct { } const ( + // sessionActiveTimeout is how long since the last activity we + // try to keep an established discoEndpoint peering alive. + sessionActiveTimeout = 2 * time.Minute + + // upgradeInterval is how often we try to upgrade to a better path + // even if we have some non-DERP route that works. + upgradeInterval = 1 * time.Minute + + // heartbeatInterval is how often pings to the best UDP address + // are sent. 
+ heartbeatInterval = 2 * time.Second + // discoPingInterval is the minimum time between pings // to an endpoint. (Except in the case of CallMeMaybe frames // resetting the counter, as the first pings likely didn't through @@ -2648,6 +2663,10 @@ const ( // trustUDPAddrDuration is how long we trust a UDP address as the exclusive // path (without using DERP) without having heard a Pong reply. trustUDPAddrDuration = 5 * time.Second + + // goodEnoughLatency is the latency at or under which we don't + // try to upgrade to a better path. + goodEnoughLatency = 5 * time.Millisecond ) // endpointState is some state and history for a specific endpoint of @@ -2734,15 +2753,75 @@ func (de *discoEndpoint) addrForSendLocked(now time.Time) (udpAddr, derpAddr net return } +// heartbeat is called every heartbeatInterval to keep the best UDP path alive, +// or kick off discovery of other paths. +func (de *discoEndpoint) heartbeat() { + de.mu.Lock() + defer de.mu.Unlock() + + de.heartBeatTimer = nil + + if de.lastSend.IsZero() { + // Shouldn't happen. + return + } + + if time.Since(de.lastSend) > sessionActiveTimeout { + // Session's idle. Stop heartbeating. + de.c.logf("magicsock: disco: ending heartbeats for idle session to %v (%v)", de.publicKey.ShortString(), de.discoShort) + return + } + + now := time.Now() + udpAddr, _ := de.addrForSendLocked(now) + if !udpAddr.IsZero() { + // We have a preferred path. Ping it every heartbeatInterval. + de.startPingLocked(udpAddr, now) + } + + if de.wantFullPingLocked(now) { + de.sendPingsLocked(now, true) + } + + de.heartBeatTimer = time.AfterFunc(heartbeatInterval, de.heartbeat) +} + +// wantFullPingLocked reports whether we should ping all of this peer's +// endpoints looking for a better path. +// +// de.mu must be held. 
+func (de *discoEndpoint) wantFullPingLocked(now time.Time) bool { + if de.bestAddr.IsZero() || de.lastFullPing.IsZero() { + return true + } + if now.After(de.trustBestAddrUntil) { + return true + } + if de.bestAddrLatency <= goodEnoughLatency { + return false + } + if now.Sub(de.lastFullPing) >= upgradeInterval { + return true + } + return false +} + +func (de *discoEndpoint) noteActiveLocked() { + de.lastSend = time.Now() + if de.heartBeatTimer == nil { + de.heartBeatTimer = time.AfterFunc(heartbeatInterval, de.heartbeat) + } +} + func (de *discoEndpoint) send(b []byte) error { now := time.Now() de.mu.Lock() - de.lastSend = now udpAddr, derpAddr := de.addrForSendLocked(now) if udpAddr.IsZero() || now.After(de.trustBestAddrUntil) { de.sendPingsLocked(now, true) } + de.noteActiveLocked() de.mu.Unlock() if udpAddr.IsZero() && derpAddr.IsZero() { @@ -2778,34 +2857,47 @@ func (de *discoEndpoint) removeSentPingLocked(txid stun.TxID, sp sentPing) { delete(de.sentPing, txid) } -// sendPing sends a ping with the provided txid to ep. -// The caller should've already been recorded the ping in sentPing -// and set up the timer. -func (de *discoEndpoint) sendPing(ep netaddr.IPPort, txid stun.TxID) { +// sendDiscoPing sends a ping with the provided txid to ep. +// +// The caller (startPingLocked) should've already recorded the ping in +// sentPing and set up the timer. +func (de *discoEndpoint) sendDiscoPing(ep netaddr.IPPort, txid stun.TxID) { sent, _ := de.sendDiscoMessage(ep, &disco.Ping{TxID: [12]byte(txid)}) if !sent { de.forgetPing(txid) } } +func (de *discoEndpoint) startPingLocked(ep netaddr.IPPort, now time.Time) { + st, ok := de.endpointState[ep] + if !ok { + // Shouldn't happen. But don't ping an endpoint that's + // not active for us. 
+ de.c.logf("magicsock: disco: [unexpected] attempt to ping no longer live endpoint %v", ep) + return + } + st.lastPing = now + + txid := stun.NewTxID() + de.sentPing[txid] = sentPing{ + to: ep, + at: now, + timer: time.AfterFunc(pingTimeoutDuration, func() { + de.c.logf("magicsock: disco: timeout waiting for pong %x from %v (%v, %v)", txid[:6], ep, de.publicKey.ShortString(), de.discoShort) + de.forgetPing(txid) + }), + } + go de.sendDiscoPing(ep, txid) +} + func (de *discoEndpoint) sendPingsLocked(now time.Time, sendCallMeMaybe bool) { + de.lastFullPing = now var sentAny bool for ep, st := range de.endpointState { ep := ep if !st.lastPing.IsZero() && now.Sub(st.lastPing) < discoPingInterval { continue } - st.lastPing = now - - txid := stun.NewTxID() - de.sentPing[txid] = sentPing{ - to: ep, - at: now, - timer: time.AfterFunc(pingTimeoutDuration, func() { - de.c.logf("magicsock: disco: timeout waiting for pong %x from %v (%v, %v)", txid[:6], ep, de.publicKey.ShortString(), de.discoShort) - de.forgetPing(txid) - }), - } firstPing := !sentAny sentAny = true @@ -2814,7 +2906,7 @@ func (de *discoEndpoint) sendPingsLocked(now time.Time, sendCallMeMaybe bool) { de.c.logf("magicsock: disco: send, starting discovery for %v (%v)", de.publicKey.ShortString(), de.discoShort) } - go de.sendPing(ep, txid) + de.startPingLocked(ep, now) } derpAddr := de.derpAddr if sentAny && sendCallMeMaybe && !derpAddr.IsZero() { @@ -3003,6 +3095,10 @@ func (de *discoEndpoint) cleanup() { for txid, sp := range de.sentPing { de.removeSentPingLocked(txid, sp) } + if de.heartBeatTimer != nil { + de.heartBeatTimer.Stop() + de.heartBeatTimer = nil + } } // ippCache is a cache of *net.UDPAddr => netaddr.IPPort mappings.