From 4324f312e933213e168eceaebd89b468201b5a11 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Wed, 5 Nov 2025 17:30:56 -0800 Subject: [PATCH] disco,ipn/ipnlocal,wgengine/magicsock: add graceful disco key rotation The client can now rotate a disco key gracefully, wherein it still accepts traffic from peers using the old disco key for a time, while informing them about the new key via a new KeyUpdate disco message. Updates #17756 Updates tailscale/corp#34037 Signed-off-by: James Tucker --- disco/disco.go | 32 ++ disco/disco_test.go | 9 + ipn/ipnlocal/local.go | 4 +- wgengine/magicsock/magicsock.go | 282 ++++++++++++- wgengine/magicsock/magicsock_test.go | 590 ++++++++++++++++++++++++++- 5 files changed, 903 insertions(+), 14 deletions(-) diff --git a/disco/disco.go b/disco/disco.go index f58bc1b8c..401d8a9b9 100644 --- a/disco/disco.go +++ b/disco/disco.go @@ -51,6 +51,7 @@ const ( TypeCallMeMaybeVia = MessageType(0x07) TypeAllocateUDPRelayEndpointRequest = MessageType(0x08) TypeAllocateUDPRelayEndpointResponse = MessageType(0x09) + TypeKeyUpdate = MessageType(0x0a) ) const v0 = byte(0) @@ -103,6 +104,8 @@ func Parse(p []byte) (Message, error) { return parseAllocateUDPRelayEndpointRequest(ver, p) case TypeAllocateUDPRelayEndpointResponse: return parseAllocateUDPRelayEndpointResponse(ver, p) + case TypeKeyUpdate: + return parseKeyUpdate(ver, p) default: return nil, fmt.Errorf("unknown message type 0x%02x", byte(t)) } @@ -278,6 +281,33 @@ func parsePong(ver uint8, p []byte) (m *Pong, err error) { return m, nil } +// KeyUpdate is a message sent during disco key rotation to notify a peer +// of our new disco public key. It is sent encrypted with the OLD shared key +// so that the peer can decrypt it before they learn about the new key from +// the control plane. +type KeyUpdate struct { + // NewDiscoKey is the sender's new disco public key. + NewDiscoKey key.DiscoPublic +} + +const keyUpdateLen = key.DiscoPublicRawLen + +func (m *KeyUpdate) AppendMarshal(b []byte) []byte { + ret, d := appendMsgHeader(b, TypeKeyUpdate, v0, keyUpdateLen) + m.NewDiscoKey.AppendTo(d[:0]) + return ret +} + +func parseKeyUpdate(ver uint8, p []byte) (*KeyUpdate, error) { + if len(p) < keyUpdateLen { + return nil, errShort + } + m := &KeyUpdate{ + NewDiscoKey: key.DiscoPublicFromRaw32(mem.B(p[:keyUpdateLen])), + } + return m, nil +} + // MessageSummary returns a short summary of m for logging purposes. func MessageSummary(m Message) string { switch m := m.(type) { @@ -299,6 +329,8 @@ func MessageSummary(m Message) string { return "allocate-udp-relay-endpoint-request" case *AllocateUDPRelayEndpointResponse: return "allocate-udp-relay-endpoint-response" + case *KeyUpdate: + return fmt.Sprintf("key-update new=%v", m.NewDiscoKey.ShortString()) default: return fmt.Sprintf("%#v", m) } diff --git a/disco/disco_test.go b/disco/disco_test.go index 71b68338a..fd9762bbf 100644 --- a/disco/disco_test.go +++ b/disco/disco_test.go @@ -38,6 +38,8 @@ func TestMarshalAndParse(t *testing.T) { }, } + testDiscoKey := key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 3: 3, 30: 30, 31: 31})) + tests := []struct { name string want string @@ -96,6 +98,13 @@ func TestMarshalAndParse(t *testing.T) { m: &CallMeMaybe{}, want: "03 00", }, + { + name: "key_update", + m: &KeyUpdate{ + NewDiscoKey: testDiscoKey, + }, + want: "0a 00 00 01 02 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f", + }, { name: "call_me_maybe_endpoints", m: &CallMeMaybe{ diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 0fbecf1c0..7524d71de 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -6691,7 +6691,9 @@ func (b *LocalBackend) DebugReSTUN() error { func (b *LocalBackend) DebugRotateDiscoKey() error { mc := b.MagicConn() - mc.RotateDiscoKey() + if err := mc.RotateDiscoKey(); err != nil { + return err + } newDiscoKey := mc.DiscoPublicKey() diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 8290de8ff..696760784 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -276,6 +276,16 @@ type Conn struct { // discoKey is the current disco private and public keypair for this conn. discoKey *key.DiscoKey + // discoKeyCreatedAt is when the current disco key was created. + // Used for both rate limiting rotations (ensuring keys are old enough to rotate) + // and for tracking when to cleanup the old key after the grace period. + discoKeyCreatedAt atomic.Pointer[time.Time] + + // oldDiscoKey is the previous disco private key, kept during the grace + // period after a rotation to allow peers to decrypt messages sent with + // the old key until they receive the new key from the control plane. + oldDiscoKey atomic.Pointer[key.DiscoPrivate] + // ============================================================ // mu guards all following fields; see userspaceEngine lock // ordering rules against the engine. For derphttp, mu must @@ -600,6 +610,8 @@ func newConn(logf logger.Logf) *Conn { cloudInfo: newCloudInfo(logf), } c.discoKey = key.NewDiscoKeyFromPrivate(discoPrivate) + now := time.Now() + c.discoKeyCreatedAt.Store(&now) c.bind = &connBind{Conn: c, closed: true} c.receiveBatchPool = sync.Pool{New: func() any { msgs := make([]ipv6.Message, c.bind.BatchSize()) @@ -1237,26 +1249,71 @@ func (c *Conn) DiscoPublicKey() key.DiscoPublic { // RotateDiscoKey generates a new discovery key pair and updates the connection // to use it. This invalidates all existing disco sessions and will cause peers // to re-establish discovery sessions with the new key. +// RotateDiscoKey rotates the discovery key gracefully. The old key is kept +// for a grace period (discoKeyRotationGracePeriod) to allow peers to transition +// to the new key. Active peers are notified directly via KeyUpdate messages. // -// This is primarily for debugging and testing purposes, a future enhancement -// should provide a mechanism for seamless rotation by supporting short term use -// of the old key. -func (c *Conn) RotateDiscoKey() { +// Returns an error if the current key is too new to rotate (less than +// minDiscoKeyAge old). +func (c *Conn) RotateDiscoKey() error { oldShort := c.discoKey.Short() + oldPrivate := c.discoKey.Private() + + if createdAt := c.discoKeyCreatedAt.Load(); createdAt != nil { + keyAge := time.Since(*createdAt) + if keyAge < minDiscoKeyAge { + return fmt.Errorf("disco key is only %v old, must be at least %v old to rotate", keyAge.Round(time.Second), minDiscoKeyAge) + } + } + newPrivate := key.NewDisco() c.mu.Lock() + + c.oldDiscoKey.Store(&oldPrivate) + now := time.Now() + c.discoKeyCreatedAt.Store(&now) + c.discoKey.Set(newPrivate) newShort := c.discoKey.Short() - c.discoInfo = make(map[key.DiscoPublic]*discoInfo) + + for peerDiscoKey, di := range c.discoInfo { + di.sharedKey = newPrivate.Shared(peerDiscoKey) + oldShared := oldPrivate.Shared(peerDiscoKey) + di.oldSharedKey = &oldShared + } + + cutoff := time.Now().Add(-5 * time.Minute) + var activePeers []key.DiscoPublic + for peerDiscoKey, di := range c.discoInfo { + if di.lastPingTime.After(cutoff) { + activePeers = append(activePeers, peerDiscoKey) + } + } + connCtx := c.connCtx c.mu.Unlock() - c.logf("magicsock: rotated disco key from %v to %v", oldShort, newShort) + c.logf("magicsock: rotated disco key from %v to %v, notifying %d peers", oldShort, newShort, len(activePeers)) + + // KeyUpdate messages are encrypted with the OLD shared key so peers can + // decrypt them before learning new key from control plane + go c.sendKeyUpdatesToPeers(activePeers, newPrivate.Public(), oldPrivate) + + // TODO(raggi): we should think carefully about and review if we even really + // want to do this. There may be little to no value in practice of dropping + // the old key - doing so increases the chances that we will fail to + // communicate with peers. If we were to introduce a regular disco key + // rotation schedule then old keys should phase out soon enough. + time.AfterFunc(discoKeyRotationGracePeriod, func() { + c.cleanupOldDiscoKey() + }) if connCtx != nil { c.ReSTUN("disco-key-rotation") } + + return nil } // determineEndpoints returns the machine's endpoint addresses. It does a STUN @@ -2220,6 +2277,14 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake sealedBox := msg[discoHeaderLen:] payload, ok := di.sharedKey.Open(sealedBox) + usedOldKey := false + if !ok && di.oldSharedKey != nil { + payload, ok = di.oldSharedKey.Open(sealedBox) + if ok { + usedOldKey = true + metricRecvDiscoWithOldKey.Add(1) + } + } if !ok { // This might have been intended for a previous // disco key. When we restart we get a new disco key @@ -2237,6 +2302,9 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake metricRecvDiscoBadKey.Add(1) return } + if usedOldKey && debugDisco() { + c.logf("magicsock: disco: decrypted message from %v using old key", sender.ShortString()) + } // Emit information about the disco frame into the pcap stream // if a capture hook is installed. @@ -2280,9 +2348,15 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake } switch dm := dm.(type) { + case *disco.KeyUpdate: + metricRecvDiscoKeyUpdate.Add(1) + c.handleKeyUpdateLocked(dm, sender, di, src) case *disco.Ping: metricRecvDiscoPing.Add(1) c.handlePingLocked(dm, src, di, derpNodeSrc) + if usedOldKey { + c.sendKeyUpdateToPeerLocked(sender, di) + } case *disco.Pong: metricRecvDiscoPong.Add(1) // There might be multiple nodes for the sender's DiscoKey. @@ -2667,11 +2741,180 @@ func (c *Conn) discoInfoForKnownPeerLocked(k key.DiscoPublic) *discoInfo { discoShort: k.ShortString(), sharedKey: c.discoKey.Private().Shared(k), } + if oldKey := c.oldDiscoKey.Load(); oldKey != nil { + oldShared := oldKey.Shared(k) + di.oldSharedKey = &oldShared + } c.discoInfo[k] = di } return di } +// handleKeyUpdateLocked processes a KeyUpdate message from a peer, updating +// their disco key and recomputing the shared key. +// +// c.mu must be held. +func (c *Conn) handleKeyUpdateLocked(m *disco.KeyUpdate, oldDiscoKey key.DiscoPublic, di *discoInfo, src epAddr) { + newDiscoKey := m.NewDiscoKey + if newDiscoKey.IsZero() { + c.logf("magicsock: disco: ignoring KeyUpdate with zero key from %v", oldDiscoKey.ShortString()) + return + } + + if newDiscoKey == oldDiscoKey { + // Same key, nothing to do + return + } + + c.logf("magicsock: disco: peer %v updated disco key from %v to %v", + di.discoKey.ShortString(), oldDiscoKey.ShortString(), newDiscoKey.ShortString()) + + delete(c.discoInfo, oldDiscoKey) + + newDi := &discoInfo{ + discoKey: newDiscoKey, + discoShort: newDiscoKey.ShortString(), + sharedKey: c.discoKey.Private().Shared(newDiscoKey), + lastPingFrom: di.lastPingFrom, + lastPingTime: di.lastPingTime, + } + + if oldKey := c.oldDiscoKey.Load(); oldKey != nil { + oldShared := oldKey.Shared(newDiscoKey) + newDi.oldSharedKey = &oldShared + } + + c.discoInfo[newDiscoKey] = newDi +} + +// sendKeyUpdateToPeerLocked sends a KeyUpdate message to a single peer. +// This is called when we receive a message from a peer using our old key, +// to accelerate their transition to our new key. +// +// c.mu must be held. +func (c *Conn) sendKeyUpdateToPeerLocked(peerDiscoKey key.DiscoPublic, di *discoInfo) { + if c.oldDiscoKey.Load() == nil { + return + } + + if time.Since(di.lastPingTime) < 10*time.Second { + return + } + + newKey := c.discoKey.Public() + c.mu.Unlock() + defer c.mu.Lock() + + if di.lastPingFrom.ap.IsValid() { + c.sendKeyUpdateToPeer(peerDiscoKey, di.lastPingFrom, newKey) + } +} + +// sendKeyUpdateToPeer sends a KeyUpdate message to a peer at the specified address. +func (c *Conn) sendKeyUpdateToPeer(peerDiscoKey key.DiscoPublic, dst epAddr, newKey key.DiscoPublic) { + oldKey := c.oldDiscoKey.Load() + if oldKey == nil { + return + } + + keyUpdate := &disco.KeyUpdate{NewDiscoKey: newKey} + cleartext := keyUpdate.AppendMarshal(nil) + oldShared := oldKey.Shared(peerDiscoKey) + sealed := oldShared.Seal(cleartext) + + pkt := make([]byte, 0, 512) + if dst.vni.IsSet() { + gh := packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: dst.vni, + Control: false, + } + pkt = append(pkt, make([]byte, packet.GeneveFixedHeaderLength)...) + if err := gh.Encode(pkt); err != nil { + return + } + } + pkt = append(pkt, disco.Magic...) + pkt = oldKey.Public().AppendTo(pkt) + pkt = append(pkt, sealed...) + + const isDisco = true + if sent, _ := c.sendAddr(dst.ap, key.NodePublic{}, pkt, isDisco, dst.vni.IsSet()); sent { + metricSentDiscoKeyUpdate.Add(1) + if debugDisco() { + c.dlogf("[v1] magicsock: disco: sent key-update to %v at %v", peerDiscoKey.ShortString(), dst) + } + } +} + +// sendKeyUpdatesToPeers sends KeyUpdate messages to a list of peers. +func (c *Conn) sendKeyUpdatesToPeers(peers []key.DiscoPublic, newKey key.DiscoPublic, oldKey key.DiscoPrivate) { + for _, peerDiscoKey := range peers { + c.mu.Lock() + di := c.discoInfo[peerDiscoKey] + if di == nil || !di.lastPingFrom.ap.IsValid() { + c.mu.Unlock() + continue + } + dst := di.lastPingFrom + c.mu.Unlock() + + keyUpdate := &disco.KeyUpdate{NewDiscoKey: newKey} + cleartext := keyUpdate.AppendMarshal(nil) + oldShared := oldKey.Shared(peerDiscoKey) + sealed := oldShared.Seal(cleartext) + + pkt := make([]byte, 0, 512) + if dst.vni.IsSet() { + gh := packet.GeneveHeader{ + Version: 0, + Protocol: packet.GeneveProtocolDisco, + VNI: dst.vni, + Control: false, + } + pkt = append(pkt, make([]byte, packet.GeneveFixedHeaderLength)...) + if err := gh.Encode(pkt); err != nil { + continue + } + } + pkt = append(pkt, disco.Magic...) + pkt = oldKey.Public().AppendTo(pkt) + pkt = append(pkt, sealed...) + + const isDisco = true + if sent, _ := c.sendAddr(dst.ap, key.NodePublic{}, pkt, isDisco, dst.vni.IsSet()); sent { + metricSentDiscoKeyUpdate.Add(1) + } + } +} + +// cleanupOldDiscoKey removes the old disco key after the grace period. +// The grace period is measured from when the current (new) key was created. +func (c *Conn) cleanupOldDiscoKey() { + c.mu.Lock() + defer c.mu.Unlock() + + createdAt := c.discoKeyCreatedAt.Load() + if createdAt == nil { + return + } + + if time.Since(*createdAt) < discoKeyRotationGracePeriod { + return + } + + c.oldDiscoKey.Store(nil) + + for _, di := range c.discoInfo { + di.oldSharedKey = nil + } + + if debugDisco() { + c.dlogf("[v1] magicsock: disco: cleaned up old key after grace period") + } +} + func (c *Conn) SetNetworkUp(up bool) { c.mu.Lock() defer c.mu.Unlock() @@ -3950,6 +4193,12 @@ type discoInfo struct { // Not modified once initialized. sharedKey key.DiscoShared + // oldSharedKey is the precomputed key using our old disco private key. + // This is set during rotation and allows us to decrypt messages from + // peers who haven't received our new key yet. + // Owned by [Conn.mu]. + oldSharedKey *key.DiscoShared + // Mutable fields follow, owned by [Conn.mu]. These are irrelevant when // discoInfo is a peer relay server disco key in the // [relayManager.discoInfoByServerDisco] map: @@ -3961,6 +4210,20 @@ type discoInfo struct { lastPingTime time.Time } +const ( + // discoKeyRotationGracePeriod is the duration for which we keep the old + // disco key after a rotation to allow peers to transition to the new key. + // This very large time window aims to provide substantial grace periods for + // new disco key propagation which could cover recovery from a wide array of + // network problems, while still expiring the old key on a schedule. + discoKeyRotationGracePeriod = 99 * time.Minute + + // minDiscoKeyAge is the minimum age a disco key must be before it can be + // rotated. This prevents accidentally rotating keys too frequently. It is + // not necessary to rotate disco keys on a high frequency schedule. + minDiscoKeyAge = 5 * time.Minute +) + var ( metricNumPeers = clientmetric.NewGauge("magicsock_netmap_num_peers") metricNumDERPConns = clientmetric.NewGauge("magicsock_num_derp_conns") @@ -4029,8 +4292,9 @@ var ( metricSentDiscoBindUDPRelayEndpoint = clientmetric.NewCounter("magicsock_disco_sent_bind_udp_relay_endpoint") metricSentDiscoBindUDPRelayEndpointAnswer = clientmetric.NewCounter("magicsock_disco_sent_bind_udp_relay_endpoint_answer") metricSentDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_sent_alloc_udp_relay_endpoint_request") - metricLocalDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_local_alloc_udp_relay_endpoint_request") metricSentDiscoAllocUDPRelayEndpointResponse = clientmetric.NewCounter("magicsock_disco_sent_alloc_udp_relay_endpoint_response") + metricSentDiscoKeyUpdate = clientmetric.NewCounter("magicsock_disco_sent_key_update") + metricLocalDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_local_alloc_udp_relay_endpoint_request") metricRecvDiscoBadPeer = clientmetric.NewCounter("magicsock_disco_recv_bad_peer") metricRecvDiscoBadKey = clientmetric.NewCounter("magicsock_disco_recv_bad_key") metricRecvDiscoBadParse = clientmetric.NewCounter("magicsock_disco_recv_bad_parse") @@ -4048,8 +4312,10 @@ var ( metricRecvDiscoBindUDPRelayEndpointChallenge = clientmetric.NewCounter("magicsock_disco_recv_bind_udp_relay_endpoint_challenge") metricRecvDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_request") metricRecvDiscoAllocUDPRelayEndpointRequestBadDisco = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_request_bad_disco") - metricRecvDiscoAllocUDPRelayEndpointResponseBadDisco = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_response_bad_disco") metricRecvDiscoAllocUDPRelayEndpointResponse = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_response") + metricRecvDiscoAllocUDPRelayEndpointResponseBadDisco = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_response_bad_disco") + metricRecvDiscoKeyUpdate = clientmetric.NewCounter("magicsock_disco_recv_key_update") + metricRecvDiscoWithOldKey = clientmetric.NewCounter("magicsock_disco_recv_with_old_key") metricLocalDiscoAllocUDPRelayEndpointResponse = clientmetric.NewCounter("magicsock_disco_local_alloc_udp_relay_endpoint_response") metricRecvDiscoDERPPeerNotHere = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_not_here") metricRecvDiscoDERPPeerGoneUnknown = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_gone_unknown") diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index a4c696203..369c16c47 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -396,6 +396,129 @@ func meshStacks(logf logger.Logf, mutateNetmap func(idx int, nm *netmap.NetworkM } } +// waitForPeers waits for all stacks to have the expected number of peers in their status. +func waitForPeers(t *testing.T, timeout time.Duration, stacks ...*magicStack) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + allReady := true + for _, s := range stacks { + if len(s.Status().Peer) != len(stacks)-1 { + allReady = false + break + } + } + if allReady { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("timeout waiting for peers to appear in status") +} + +// waitForDiscoInfo waits for conn to have discoInfo for the given peer disco key. +func waitForDiscoInfo(t *testing.T, conn *Conn, peerKey key.DiscoPublic, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + conn.mu.Lock() + hasInfo := conn.discoInfo[peerKey] != nil + conn.mu.Unlock() + if hasInfo { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("timeout waiting for discoInfo for peer %v", peerKey.ShortString()) +} + +// waitForKeyUpdate waits for KeyUpdate metrics to increase, indicating a KeyUpdate +// message was sent and received. +func waitForKeyUpdate(t *testing.T, sentBefore, recvBefore int64, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if metricSentDiscoKeyUpdate.Value() > sentBefore && + metricRecvDiscoKeyUpdate.Value() > recvBefore { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Errorf("timeout waiting for KeyUpdate: sent %d->%d, recv %d->%d", + sentBefore, metricSentDiscoKeyUpdate.Value(), + recvBefore, metricRecvDiscoKeyUpdate.Value()) +} + +// waitForDiscoKeyChange waits for conn to have discoInfo for newKey and not have +// discoInfo for oldKey, indicating the peer has processed a key rotation. +func waitForDiscoKeyChange(t *testing.T, conn *Conn, oldKey, newKey key.DiscoPublic, timeout time.Duration) { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + conn.mu.Lock() + hasNew := conn.discoInfo[newKey] != nil + hasOld := conn.discoInfo[oldKey] != nil + conn.mu.Unlock() + if hasNew && !hasOld { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Errorf("timeout waiting for disco key change from %v to %v", + oldKey.ShortString(), newKey.ShortString()) +} + +// discoPing triggers an immediate disco ping from src to dst, bypassing the +// heartbeat interval. This is useful for tests that need to establish disco +// communication quickly without waiting for the 3-second heartbeat. +// Returns when the ping completes or times out after 2 seconds. +func discoPing(t *testing.T, src, dst *magicStack) { + t.Helper() + + src.conn.mu.Lock() + var dstNode tailcfg.NodeView + for _, peer := range src.conn.peers.All() { + if peer.Key() == dst.Public() { + dstNode = peer + break + } + } + src.conn.mu.Unlock() + + if !dstNode.Valid() { + t.Fatalf("src doesn't have dst in peers") + } + + pingDone := make(chan struct{}) + res := &ipnstate.PingResult{} + src.conn.Ping(dstNode, res, 0, func(pr *ipnstate.PingResult) { + if pr.Err != "" { + t.Logf("disco ping completed with error: %v", pr.Err) + } + close(pingDone) + }) + + select { + case <-pingDone: + case <-time.After(2 * time.Second): + t.Fatalf("disco ping timed out") + } +} + +// ageDiscoInfoForTest sets the lastPingTime for all discoInfo entries to be +// older than the cutoff used in RotateDiscoKey (5 minutes). This prevents +// KeyUpdate messages from being sent during rotation, allowing tests to verify +// netmap-only propagation. +func ageDiscoInfoForTest(conn *Conn) { + conn.mu.Lock() + defer conn.mu.Unlock() + + oldTime := time.Now().Add(-10 * time.Minute) + for _, di := range conn.discoInfo { + di.lastPingTime = oldTime + } +} + func TestNewConn(t *testing.T) { tstest.PanicOnLog() tstest.ResourceCheck(t) @@ -4266,7 +4389,13 @@ func TestRotateDiscoKey(t *testing.T) { } c.mu.Unlock() - c.RotateDiscoKey() + // Advance the disco key creation time to bypass rate limiting + pastTime := time.Now().Add(-10 * time.Minute) + c.discoKeyCreatedAt.Store(&pastTime) + + if err := c.RotateDiscoKey(); err != nil { + t.Fatalf("RotateDiscoKey failed: %v", err) + } newPrivate, newPublic := c.discoKey.Pair() newShort := c.discoKey.Short() @@ -4286,9 +4415,93 @@ func TestRotateDiscoKey(t *testing.T) { } c.mu.Lock() - if len(c.discoInfo) != 0 { - t.Fatalf("expected discoInfo to be cleared, got %d entries", len(c.discoInfo)) + // After graceful rotation, discoInfo should be preserved and updated with new shared keys + if len(c.discoInfo) != 1 { + t.Fatalf("expected discoInfo to be preserved with 1 entry, got %d entries", len(c.discoInfo)) + } + for peerDiscoKey, di := range c.discoInfo { + if peerDiscoKey != testDiscoKey { + t.Fatalf("peer disco key changed unexpectedly") + } + expectedSharedKey := newPrivate.Shared(peerDiscoKey) + if !di.sharedKey.Equal(expectedSharedKey) { + t.Fatalf("shared key was not updated after rotation") + } + if di.oldSharedKey == nil { + t.Fatalf("oldSharedKey should be set after rotation") + } + expectedOldSharedKey := oldPrivate.Shared(peerDiscoKey) + if !di.oldSharedKey.Equal(expectedOldSharedKey) { + t.Fatalf("oldSharedKey is not correct") + } + } + c.mu.Unlock() +} + +func TestRotateDiscoKeyGraceful(t *testing.T) { + c := newConn(t.Logf) + + peerPrivate := key.NewDisco() + peerPublic := peerPrivate.Public() + + c.mu.Lock() + c.discoInfo[peerPublic] = &discoInfo{ + discoKey: peerPublic, + discoShort: peerPublic.ShortString(), + sharedKey: c.discoKey.Private().Shared(peerPublic), + } + oldSharedKey := c.discoInfo[peerPublic].sharedKey + c.mu.Unlock() + + pastTime := time.Now().Add(-10 * time.Minute) + c.discoKeyCreatedAt.Store(&pastTime) + + if err := c.RotateDiscoKey(); err != nil { + t.Fatalf("RotateDiscoKey failed: %v", err) + } + + c.mu.Lock() + di := c.discoInfo[peerPublic] + if di == nil { + t.Fatalf("peer discoInfo was removed during rotation") + } + + if di.sharedKey.Equal(oldSharedKey) { + t.Fatalf("shared key was not updated") + } + + if di.oldSharedKey == nil { + t.Fatalf("oldSharedKey should be set after rotation") + } + if !di.oldSharedKey.Equal(oldSharedKey) { + t.Fatalf("oldSharedKey doesn't match the previous shared key") + } + + testMessage := &disco.Ping{TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}} + cleartext := testMessage.AppendMarshal(nil) + + sealedNew := di.sharedKey.Seal(cleartext) + decryptedNew, ok := di.sharedKey.Open(sealedNew) + if !ok { + t.Fatalf("failed to decrypt message encrypted with new key") + } + if string(decryptedNew) != string(cleartext) { + t.Fatalf("decrypted message doesn't match original") + } + + sealedOld := di.oldSharedKey.Seal(cleartext) + _, ok = di.sharedKey.Open(sealedOld) + if ok { + t.Fatalf("shouldn't be able to decrypt old-key message with new key") } + decryptedOld, ok := di.oldSharedKey.Open(sealedOld) + if !ok { + t.Fatalf("failed to decrypt message encrypted with old key") + } + if string(decryptedOld) != string(cleartext) { + t.Fatalf("decrypted old message doesn't match original") + } + c.mu.Unlock() } @@ -4298,8 +4511,14 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) { keys := make([]key.DiscoPublic, 0, 5) keys = append(keys, c.discoKey.Public()) - for i := 0; i < 4; i++ { - c.RotateDiscoKey() + for i := range 4 { + // Advance the disco key creation time to bypass rate limiting + pastTime := time.Now().Add(-10 * time.Minute) + c.discoKeyCreatedAt.Store(&pastTime) + + if err := c.RotateDiscoKey(); err != nil { + t.Fatalf("rotation %d failed: %v", i+1, err) + } newKey := c.discoKey.Public() for j, oldKey := range keys { @@ -4311,3 +4530,364 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) { keys = append(keys, newKey) } } + +func TestRotateDiscoKeyViaKeyUpdateMessage(t *testing.T) { + tstest.PanicOnLog() + tstest.ResourceCheck(t) + + derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1)) + defer cleanup() + + m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap) + defer m1.Close() + m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap) + defer m2.Close() + + cleanupMesh := meshStacks(t.Logf, nil, m1, m2) + defer cleanupMesh() + + waitForPeers(t, 2*time.Second, m1, m2) + + discoPing(t, m1, m2) + waitForDiscoInfo(t, m2.conn, m1.conn.DiscoPublicKey(), 1*time.Second) + + // Start pinger to maintain active session during rotation + cleanup = newPinger(t, t.Logf, m1, m2) + defer cleanup() + + m1DiscoKeyBefore := m1.conn.DiscoPublicKey() + + sentBefore := metricSentDiscoKeyUpdate.Value() + recvBefore := metricRecvDiscoKeyUpdate.Value() + recvWithOldKeyBefore := metricRecvDiscoWithOldKey.Value() + + pastTime := time.Now().Add(-10 * time.Minute) + m1.conn.discoKeyCreatedAt.Store(&pastTime) + + t.Logf("rotating m1 disco key from %v", m1DiscoKeyBefore.ShortString()) + if err := m1.conn.RotateDiscoKey(); err != nil { + t.Fatalf("RotateDiscoKey failed: %v", err) + } + + m1DiscoKeyAfter := m1.conn.DiscoPublicKey() + if m1DiscoKeyAfter == m1DiscoKeyBefore { + t.Fatalf("m1 disco key didn't change after rotation") + } + t.Logf("m1 disco key rotated to %v", m1DiscoKeyAfter.ShortString()) + + // No epCh push. + + waitForKeyUpdate(t, sentBefore, recvBefore, 2*time.Second) + t.Logf("KeyUpdate sent and received (sent: %d->%d, recv: %d->%d)", + sentBefore, metricSentDiscoKeyUpdate.Value(), + recvBefore, metricRecvDiscoKeyUpdate.Value()) + + waitForDiscoKeyChange(t, m2.conn, m1DiscoKeyBefore, m1DiscoKeyAfter, 2*time.Second) + t.Logf("m2 discoInfo updated to new key") + + sentAfter := metricSentDiscoKeyUpdate.Value() + recvAfter := metricRecvDiscoKeyUpdate.Value() + if sentAfter <= sentBefore { + t.Errorf("KeyUpdate not sent: metric before=%d after=%d", sentBefore, sentAfter) + } + if recvAfter <= recvBefore { + t.Errorf("KeyUpdate not received: metric before=%d after=%d", recvBefore, recvAfter) + } + + m2.conn.mu.Lock() + m2DiscoInfoAfter := m2.conn.discoInfo[m1DiscoKeyAfter] + m2DiscoInfoOld := m2.conn.discoInfo[m1DiscoKeyBefore] + m2.conn.mu.Unlock() + + if m2DiscoInfoAfter == nil { + t.Errorf("m2 doesn't have discoInfo for m1's new key %v", m1DiscoKeyAfter.ShortString()) + } + if m2DiscoInfoOld != nil { + t.Errorf("m2 still has discoInfo for m1's old key %v (should have been replaced)", m1DiscoKeyBefore.ShortString()) + } + + if m1.conn.oldDiscoKey.Load() == nil { + t.Errorf("m1 didn't keep old disco key for grace period") + } + + s1 := m1.Status() + s2 := m2.Status() + if len(s1.Peer) != 1 || len(s2.Peer) != 1 { + t.Fatalf("peers lost track of each other after rotation: m1 peers=%d, m2 peers=%d", len(s1.Peer), len(s2.Peer)) + } + + recvWithOldKeyAfter := metricRecvDiscoWithOldKey.Value() + if recvWithOldKeyAfter > recvWithOldKeyBefore { + t.Logf("m1 received %d messages with old key during transition (expected during graceful rotation)", + recvWithOldKeyAfter-recvWithOldKeyBefore) + } + + t.Logf("disco key rotation via KeyUpdate message successful, active session maintained without control plane") +} + +func TestRotateDiscoKeyViaKeyUpdateDirectUDP(t *testing.T) { + tstest.PanicOnLog() + tstest.ResourceCheck(t) + + derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1)) + defer cleanup() + + m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap) + defer m1.Close() + m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap) + defer m2.Close() + + cleanupMesh := meshStacks(t.Logf, nil, m1, m2) + defer cleanupMesh() + + waitForPeers(t, 2*time.Second, m1, m2) + + cleanup = newPinger(t, t.Logf, m1, m2) + defer cleanup() + + mustDirect(t, t.Logf, m1, m2) + mustDirect(t, t.Logf, m2, m1) + t.Logf("direct UDP paths established") + + m1DiscoKeyBefore := m1.conn.DiscoPublicKey() + + sentKeyUpdateBefore := metricSentDiscoKeyUpdate.Value() + recvKeyUpdateBefore := metricRecvDiscoKeyUpdate.Value() + + pastTime := time.Now().Add(-10 * time.Minute) + m1.conn.discoKeyCreatedAt.Store(&pastTime) + + t.Logf("rotating m1 disco key from %v", m1DiscoKeyBefore.ShortString()) + if err := m1.conn.RotateDiscoKey(); err != nil { + t.Fatalf("RotateDiscoKey failed: %v", err) + } + + m1DiscoKeyAfter := m1.conn.DiscoPublicKey() + if m1DiscoKeyAfter == m1DiscoKeyBefore { + t.Fatalf("m1 disco key didn't change after rotation") + } + t.Logf("m1 disco key rotated to %v", m1DiscoKeyAfter.ShortString()) + + // No push to epCh + + waitForKeyUpdate(t, sentKeyUpdateBefore, recvKeyUpdateBefore, 2*time.Second) + + sentKeyUpdateAfter := metricSentDiscoKeyUpdate.Value() + recvKeyUpdateAfter := metricRecvDiscoKeyUpdate.Value() + + if sentKeyUpdateAfter <= sentKeyUpdateBefore { + t.Errorf("KeyUpdate not sent: before=%d after=%d", sentKeyUpdateBefore, sentKeyUpdateAfter) + } + if recvKeyUpdateAfter <= recvKeyUpdateBefore { + t.Errorf("KeyUpdate not received: before=%d after=%d", recvKeyUpdateBefore, recvKeyUpdateAfter) + } + + m1.conn.mu.Lock() + m1DiscoInfo := m1.conn.discoInfo[m2.conn.DiscoPublicKey()] + var lastPingFrom epAddr + if m1DiscoInfo != nil { + lastPingFrom = m1DiscoInfo.lastPingFrom + } + m1.conn.mu.Unlock() + + if lastPingFrom.ap.IsValid() && lastPingFrom.ap.Addr() != tailcfg.DerpMagicIPAddr { + t.Logf("KeyUpdate sent via direct UDP to %v (as expected)", lastPingFrom.ap) + } else if lastPingFrom.ap.Addr() == tailcfg.DerpMagicIPAddr { + t.Errorf("KeyUpdate sent via DERP, but expected direct UDP path") + } else { + t.Logf("Note: Could not verify path from lastPingFrom") + } + + m2.conn.mu.Lock() + hasNewKey := m2.conn.discoInfo[m1DiscoKeyAfter] != nil + hasOldKey := m2.conn.discoInfo[m1DiscoKeyBefore] != nil + m2.conn.mu.Unlock() + + if !hasNewKey { + t.Errorf("m2 doesn't have discoInfo for m1's new key after KeyUpdate") + } + if hasOldKey { + t.Errorf("m2 still has discoInfo for m1's old key (should have been replaced)") + } + + t.Logf("KeyUpdate via direct UDP successful") +} + +func TestRotateDiscoKeyViaKeyUpdateDERP(t *testing.T) { + tstest.PanicOnLog() + tstest.ResourceCheck(t) + + derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1)) + defer cleanup() + + m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap) + defer m1.Close() + m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap) + defer m2.Close() + + cleanupMesh := meshStacks(t.Logf, nil, m1, m2) + defer cleanupMesh() + + waitForPeers(t, 2*time.Second, m1, m2) + + m1DiscoKeyBefore := m1.conn.DiscoPublicKey() + + discoPing(t, m1, m2) + waitForDiscoInfo(t, m2.conn, m1DiscoKeyBefore, 1*time.Second) + + // Start pinger to maintain active session during rotation + cleanup = newPinger(t, t.Logf, m1, m2) + defer cleanup() + + sentUDPBefore := metricSentDiscoUDP.Value() + sentDERPBefore := metricSentDiscoDERP.Value() + sentKeyUpdateBefore := metricSentDiscoKeyUpdate.Value() + recvKeyUpdateBefore := metricRecvDiscoKeyUpdate.Value() + + pastTime := time.Now().Add(-10 * time.Minute) + m1.conn.discoKeyCreatedAt.Store(&pastTime) + + t.Logf("rotating m1 disco key from %v", m1DiscoKeyBefore.ShortString()) + if err := m1.conn.RotateDiscoKey(); err != nil { + t.Fatalf("RotateDiscoKey failed: %v", err) + } + + m1DiscoKeyAfter := m1.conn.DiscoPublicKey() + if m1DiscoKeyAfter == m1DiscoKeyBefore { + t.Fatalf("m1 disco key didn't change after rotation") + } + t.Logf("m1 disco key rotated to %v", m1DiscoKeyAfter.ShortString()) + + // No push to epCh + + waitForKeyUpdate(t, sentKeyUpdateBefore, recvKeyUpdateBefore, 2*time.Second) + + sentUDPAfter := metricSentDiscoUDP.Value() + sentDERPAfter := metricSentDiscoDERP.Value() + sentKeyUpdateAfter := metricSentDiscoKeyUpdate.Value() + recvKeyUpdateAfter := metricRecvDiscoKeyUpdate.Value() + + if sentKeyUpdateAfter <= sentKeyUpdateBefore { + t.Errorf("KeyUpdate not sent: before=%d after=%d", sentKeyUpdateBefore, sentKeyUpdateAfter) + } + if recvKeyUpdateAfter <= recvKeyUpdateBefore { + t.Errorf("KeyUpdate not received: before=%d after=%d", recvKeyUpdateBefore, recvKeyUpdateAfter) + } + + derpIncreased := sentDERPAfter > sentDERPBefore + udpIncreased := sentUDPAfter > sentUDPBefore + + t.Logf("Disco sends after rotation: UDP %d->%d, DERP %d->%d", + sentUDPBefore, sentUDPAfter, sentDERPBefore, sentDERPAfter) + + if derpIncreased { + t.Logf("KeyUpdate sent via DERP (as expected for DERP-only path)") + } else if udpIncreased { + t.Logf("KeyUpdate sent via UDP (direct path may have been established)") + } else { + t.Logf("Note: Could not determine path from metrics alone") + } + + m2.conn.mu.Lock() + hasNewKey := m2.conn.discoInfo[m1DiscoKeyAfter] != nil + hasOldKey := m2.conn.discoInfo[m1DiscoKeyBefore] != nil + m2.conn.mu.Unlock() + + if !hasNewKey { + t.Errorf("m2 doesn't have discoInfo for m1's new key after KeyUpdate") + } + if hasOldKey { + t.Errorf("m2 still has discoInfo for m1's old key (should have been replaced)") + } + + t.Logf("KeyUpdate via DERP successful") +} + +func TestRotateDiscoKeyViaNetmap(t *testing.T) { + tstest.PanicOnLog() + tstest.ResourceCheck(t) + + derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1)) + defer cleanup() + + m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap) + defer m1.Close() + m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap) + defer m2.Close() + + cleanupMesh := meshStacks(t.Logf, nil, m1, m2) + defer cleanupMesh() + + waitForPeers(t, 2*time.Second, m1, m2) + + cleanup = newPinger(t, t.Logf, m1, m2) + m1DiscoKeyBefore := m1.conn.DiscoPublicKey() + waitForDiscoInfo(t, m2.conn, m1DiscoKeyBefore, 1*time.Second) + cleanup() // Stop pinging - simulate idle session + + sentKeyUpdateBefore := metricSentDiscoKeyUpdate.Value() + + // Allow rotation by making key appear old enough + pastTime := time.Now().Add(-10 * time.Minute) + m1.conn.discoKeyCreatedAt.Store(&pastTime) + + // Age the disco info so m2 is not considered an "active peer" during rotation. + // This prevents KeyUpdate messages from being sent, ensuring we test pure netmap propagation. + ageDiscoInfoForTest(m1.conn) + + t.Logf("rotating m1 disco key from %v (no active session)", m1DiscoKeyBefore.ShortString()) + if err := m1.conn.RotateDiscoKey(); err != nil { + t.Fatalf("RotateDiscoKey failed: %v", err) + } + + m1DiscoKeyAfter := m1.conn.DiscoPublicKey() + if m1DiscoKeyAfter == m1DiscoKeyBefore { + t.Fatalf("m1 disco key didn't change after rotation") + } + t.Logf("m1 disco key rotated to %v", m1DiscoKeyAfter.ShortString()) + + m1.conn.mu.Lock() + m1.epCh <- m1.conn.lastEndpoints + m1.conn.mu.Unlock() + + t.Logf("waiting for netmap update to propagate") + time.Sleep(100 * time.Millisecond) // Give meshStacks time to process + + sentKeyUpdateAfter := metricSentDiscoKeyUpdate.Value() + + if sentKeyUpdateAfter > sentKeyUpdateBefore { + t.Errorf("KeyUpdate was sent (sent %d->%d) but should not have been - test is invalid", + sentKeyUpdateBefore, sentKeyUpdateAfter) + } + t.Logf("KeyUpdate was not sent (session was idle), testing pure netmap propagation") + + if m1.conn.oldDiscoKey.Load() == nil { + t.Errorf("m1 didn't keep old disco key for grace period") + } + + // Instead of using newPinger which waits for heartbeat (3s delay), trigger + // immediate disco ping to test netmap propagation. + t.Logf("triggering immediate disco ping from m2 to m1 (with new key)") + discoPing(t, m2, m1) + + waitForDiscoInfo(t, m2.conn, m1DiscoKeyAfter, 1*time.Second) + + // Now start the actual pinger to verify ongoing communication works + cleanup = newPinger(t, t.Logf, m1, m2) + defer cleanup() + + sentKeyUpdateFinal := metricSentDiscoKeyUpdate.Value() + if sentKeyUpdateFinal > sentKeyUpdateBefore { + t.Errorf("KeyUpdate was sent after pinging resumed (sent %d->%d) - test is invalid", + sentKeyUpdateBefore, sentKeyUpdateFinal) + } + t.Logf("Confirmed: KeyUpdate was never sent (before=%d, after=%d)", sentKeyUpdateBefore, sentKeyUpdateFinal) + + s1 := m1.Status() + s2 := m2.Status() + if len(s1.Peer) != 1 || len(s2.Peer) != 1 { + t.Fatalf("peers lost track of each other after rotation: m1 peers=%d, m2 peers=%d", len(s1.Peer), len(s2.Peer)) + } + + t.Logf("disco key rotation via netmap successful, communication established with new key") +}