disco,ipn/ipnlocal,wgengine/magicsock: add graceful disco key rotation

The client can now rotate a disco key gracefully, wherein it still
accepts traffic from peers using the old disco key for a time, while
informing them about the new key via a new KeyUpdate disco message.

Updates #17756
Updates tailscale/corp#34037

Signed-off-by: James Tucker <james@tailscale.com>
raggi/disco-key-rotate-graceful
James Tucker 1 month ago
parent 7988303d30
commit 4324f312e9
No known key found for this signature in database

@ -51,6 +51,7 @@ const (
TypeCallMeMaybeVia = MessageType(0x07)
TypeAllocateUDPRelayEndpointRequest = MessageType(0x08)
TypeAllocateUDPRelayEndpointResponse = MessageType(0x09)
TypeKeyUpdate = MessageType(0x0a)
)
const v0 = byte(0)
@ -103,6 +104,8 @@ func Parse(p []byte) (Message, error) {
return parseAllocateUDPRelayEndpointRequest(ver, p)
case TypeAllocateUDPRelayEndpointResponse:
return parseAllocateUDPRelayEndpointResponse(ver, p)
case TypeKeyUpdate:
return parseKeyUpdate(ver, p)
default:
return nil, fmt.Errorf("unknown message type 0x%02x", byte(t))
}
@ -278,6 +281,33 @@ func parsePong(ver uint8, p []byte) (m *Pong, err error) {
return m, nil
}
// KeyUpdate is a message sent during disco key rotation to notify a peer
// of our new disco public key. It is sent encrypted with the OLD shared key
// so that the peer can decrypt it before they learn about the new key from
// the control plane.
//
// On the wire the payload that follows the message header is the raw bytes
// of NewDiscoKey (keyUpdateLen bytes).
type KeyUpdate struct {
// NewDiscoKey is the sender's new disco public key, i.e. the key the
// receiver should use for the sender from now on.
NewDiscoKey key.DiscoPublic
}
// keyUpdateLen is the size in bytes of a KeyUpdate payload (excluding the
// message header): exactly one raw disco public key.
const keyUpdateLen = key.DiscoPublicRawLen
// AppendMarshal appends the wire encoding of m (message header followed by
// the raw new disco public key) to b and returns the extended slice.
func (m *KeyUpdate) AppendMarshal(b []byte) []byte {
	out, payload := appendMsgHeader(b, TypeKeyUpdate, v0, keyUpdateLen)
	// The key is written in place into the payload region handed back by
	// appendMsgHeader, so the slice returned by AppendTo is not needed.
	m.NewDiscoKey.AppendTo(payload[:0])
	return out
}
// parseKeyUpdate decodes the payload of a KeyUpdate message. It returns
// errShort when p holds fewer than keyUpdateLen bytes; any bytes beyond
// keyUpdateLen are ignored. ver is currently not consulted.
func parseKeyUpdate(ver uint8, p []byte) (*KeyUpdate, error) {
	if len(p) < keyUpdateLen {
		return nil, errShort
	}
	raw := mem.B(p[:keyUpdateLen])
	return &KeyUpdate{NewDiscoKey: key.DiscoPublicFromRaw32(raw)}, nil
}
// MessageSummary returns a short summary of m for logging purposes.
func MessageSummary(m Message) string {
switch m := m.(type) {
@ -299,6 +329,8 @@ func MessageSummary(m Message) string {
return "allocate-udp-relay-endpoint-request"
case *AllocateUDPRelayEndpointResponse:
return "allocate-udp-relay-endpoint-response"
case *KeyUpdate:
return fmt.Sprintf("key-update new=%v", m.NewDiscoKey.ShortString())
default:
return fmt.Sprintf("%#v", m)
}

@ -38,6 +38,8 @@ func TestMarshalAndParse(t *testing.T) {
},
}
testDiscoKey := key.DiscoPublicFromRaw32(mem.B([]byte{1: 1, 2: 2, 3: 3, 30: 30, 31: 31}))
tests := []struct {
name string
want string
@ -96,6 +98,13 @@ func TestMarshalAndParse(t *testing.T) {
m: &CallMeMaybe{},
want: "03 00",
},
{
name: "key_update",
m: &KeyUpdate{
NewDiscoKey: testDiscoKey,
},
want: "0a 00 00 01 02 03 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 1e 1f",
},
{
name: "call_me_maybe_endpoints",
m: &CallMeMaybe{

@ -6691,7 +6691,9 @@ func (b *LocalBackend) DebugReSTUN() error {
func (b *LocalBackend) DebugRotateDiscoKey() error {
mc := b.MagicConn()
mc.RotateDiscoKey()
if err := mc.RotateDiscoKey(); err != nil {
return err
}
newDiscoKey := mc.DiscoPublicKey()

@ -276,6 +276,16 @@ type Conn struct {
// discoKey is the current disco private and public keypair for this conn.
discoKey *key.DiscoKey
// discoKeyCreatedAt is when the current disco key was created.
// Used for both rate limiting rotations (ensuring keys are old enough to rotate)
// and for tracking when to cleanup the old key after the grace period.
discoKeyCreatedAt atomic.Pointer[time.Time]
// oldDiscoKey is the previous disco private key, kept during the grace
// period after a rotation to allow peers to decrypt messages sent with
// the old key until they receive the new key from the control plane.
oldDiscoKey atomic.Pointer[key.DiscoPrivate]
// ============================================================
// mu guards all following fields; see userspaceEngine lock
// ordering rules against the engine. For derphttp, mu must
@ -600,6 +610,8 @@ func newConn(logf logger.Logf) *Conn {
cloudInfo: newCloudInfo(logf),
}
c.discoKey = key.NewDiscoKeyFromPrivate(discoPrivate)
now := time.Now()
c.discoKeyCreatedAt.Store(&now)
c.bind = &connBind{Conn: c, closed: true}
c.receiveBatchPool = sync.Pool{New: func() any {
msgs := make([]ipv6.Message, c.bind.BatchSize())
@ -1237,26 +1249,71 @@ func (c *Conn) DiscoPublicKey() key.DiscoPublic {
// RotateDiscoKey generates a new discovery key pair and updates the connection
// to use it. This invalidates all existing disco sessions and will cause peers
// to re-establish discovery sessions with the new key.
// RotateDiscoKey rotates the discovery key gracefully. The old key is kept
// for a grace period (discoKeyRotationGracePeriod) to allow peers to transition
// to the new key. Active peers are notified directly via KeyUpdate messages.
//
// This is primarily for debugging and testing purposes, a future enhancement
// should provide a mechanism for seamless rotation by supporting short term use
// of the old key.
func (c *Conn) RotateDiscoKey() {
// Returns an error if the current key is too new to rotate (less than
// minDiscoKeyAge old).
func (c *Conn) RotateDiscoKey() error {
oldShort := c.discoKey.Short()
oldPrivate := c.discoKey.Private()
if createdAt := c.discoKeyCreatedAt.Load(); createdAt != nil {
keyAge := time.Since(*createdAt)
if keyAge < minDiscoKeyAge {
return fmt.Errorf("disco key is only %v old, must be at least %v old to rotate", keyAge.Round(time.Second), minDiscoKeyAge)
}
}
newPrivate := key.NewDisco()
c.mu.Lock()
c.oldDiscoKey.Store(&oldPrivate)
now := time.Now()
c.discoKeyCreatedAt.Store(&now)
c.discoKey.Set(newPrivate)
newShort := c.discoKey.Short()
c.discoInfo = make(map[key.DiscoPublic]*discoInfo)
for peerDiscoKey, di := range c.discoInfo {
di.sharedKey = newPrivate.Shared(peerDiscoKey)
oldShared := oldPrivate.Shared(peerDiscoKey)
di.oldSharedKey = &oldShared
}
cutoff := time.Now().Add(-5 * time.Minute)
var activePeers []key.DiscoPublic
for peerDiscoKey, di := range c.discoInfo {
if di.lastPingTime.After(cutoff) {
activePeers = append(activePeers, peerDiscoKey)
}
}
connCtx := c.connCtx
c.mu.Unlock()
c.logf("magicsock: rotated disco key from %v to %v", oldShort, newShort)
c.logf("magicsock: rotated disco key from %v to %v, notifying %d peers", oldShort, newShort, len(activePeers))
// KeyUpdate messages are encrypted with the OLD shared key so peers can
// decrypt them before learning new key from control plane
go c.sendKeyUpdatesToPeers(activePeers, newPrivate.Public(), oldPrivate)
// TODO(raggi): we should think carefully about and review if we even really
// want to do this. There may be little to no value in practice of dropping
// the old key - doing so increases the chances that we will fail to
// communicate with peers. If we were to introduce a regular disco key
// rotation schedule then old keys should phase out soon enough.
time.AfterFunc(discoKeyRotationGracePeriod, func() {
c.cleanupOldDiscoKey()
})
if connCtx != nil {
c.ReSTUN("disco-key-rotation")
}
return nil
}
// determineEndpoints returns the machine's endpoint addresses. It does a STUN
@ -2220,6 +2277,14 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
sealedBox := msg[discoHeaderLen:]
payload, ok := di.sharedKey.Open(sealedBox)
usedOldKey := false
if !ok && di.oldSharedKey != nil {
payload, ok = di.oldSharedKey.Open(sealedBox)
if ok {
usedOldKey = true
metricRecvDiscoWithOldKey.Add(1)
}
}
if !ok {
// This might have been intended for a previous
// disco key. When we restart we get a new disco key
@ -2237,6 +2302,9 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
metricRecvDiscoBadKey.Add(1)
return
}
if usedOldKey && debugDisco() {
c.logf("magicsock: disco: decrypted message from %v using old key", sender.ShortString())
}
// Emit information about the disco frame into the pcap stream
// if a capture hook is installed.
@ -2280,9 +2348,15 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
}
switch dm := dm.(type) {
case *disco.KeyUpdate:
metricRecvDiscoKeyUpdate.Add(1)
c.handleKeyUpdateLocked(dm, sender, di, src)
case *disco.Ping:
metricRecvDiscoPing.Add(1)
c.handlePingLocked(dm, src, di, derpNodeSrc)
if usedOldKey {
c.sendKeyUpdateToPeerLocked(sender, di)
}
case *disco.Pong:
metricRecvDiscoPong.Add(1)
// There might be multiple nodes for the sender's DiscoKey.
@ -2667,11 +2741,180 @@ func (c *Conn) discoInfoForKnownPeerLocked(k key.DiscoPublic) *discoInfo {
discoShort: k.ShortString(),
sharedKey: c.discoKey.Private().Shared(k),
}
if oldKey := c.oldDiscoKey.Load(); oldKey != nil {
oldShared := oldKey.Shared(k)
di.oldSharedKey = &oldShared
}
c.discoInfo[k] = di
}
return di
}
// handleKeyUpdateLocked applies a KeyUpdate received from a peer: the
// discoInfo entry keyed by the peer's previous disco key is replaced by a
// fresh entry keyed by the announced key, with shared keys recomputed and
// the last-ping bookkeeping carried over from di.
//
// A zero key is logged and ignored; a key equal to oldDiscoKey is a no-op.
// src is currently unused.
//
// c.mu must be held.
func (c *Conn) handleKeyUpdateLocked(m *disco.KeyUpdate, oldDiscoKey key.DiscoPublic, di *discoInfo, src epAddr) {
	announced := m.NewDiscoKey
	switch {
	case announced.IsZero():
		c.logf("magicsock: disco: ignoring KeyUpdate with zero key from %v", oldDiscoKey.ShortString())
		return
	case announced == oldDiscoKey:
		// The peer re-announced the key we already have.
		return
	}

	c.logf("magicsock: disco: peer %v updated disco key from %v to %v",
		di.discoKey.ShortString(), oldDiscoKey.ShortString(), announced.ShortString())

	replacement := &discoInfo{
		discoKey:     announced,
		discoShort:   announced.ShortString(),
		sharedKey:    c.discoKey.Private().Shared(announced),
		lastPingFrom: di.lastPingFrom,
		lastPingTime: di.lastPingTime,
	}
	// If we are ourselves inside a rotation grace period, the peer may still
	// seal messages for our previous key, so precompute that shared key too.
	if ourPrev := c.oldDiscoKey.Load(); ourPrev != nil {
		prevShared := ourPrev.Shared(announced)
		replacement.oldSharedKey = &prevShared
	}

	delete(c.discoInfo, oldDiscoKey)
	c.discoInfo[announced] = replacement
}
// sendKeyUpdateToPeerLocked sends a KeyUpdate message to a single peer.
// This is called when we receive a message from a peer using our old key,
// to accelerate their transition to our new key.
//
// Nothing is sent when no old key is retained, when di.lastPingTime is less
// than 10 seconds in the past, or when the peer has no valid lastPingFrom
// address.
//
// c.mu must be held. It is released for the duration of the network send and
// re-acquired before returning, so callers must not assume state guarded by
// c.mu is unchanged across this call.
func (c *Conn) sendKeyUpdateToPeerLocked(peerDiscoKey key.DiscoPublic, di *discoInfo) {
	if c.oldDiscoKey.Load() == nil {
		return
	}

	// NOTE(review): the visible caller invokes this right after
	// handlePingLocked; if that call refreshes di.lastPingTime, this check
	// would suppress every send — confirm the intended throttle field.
	if time.Since(di.lastPingTime) < 10*time.Second {
		return
	}

	// Snapshot everything owned by c.mu before dropping the lock; reading
	// di.lastPingFrom after Unlock would race with concurrent updates.
	dst := di.lastPingFrom
	if !dst.ap.IsValid() {
		return
	}
	newKey := c.discoKey.Public()

	c.mu.Unlock()
	defer c.mu.Lock()

	c.sendKeyUpdateToPeer(peerDiscoKey, dst, newKey)
}
// sendKeyUpdateToPeer builds a KeyUpdate announcing newKey, seals it with the
// shared key derived from our retained OLD private key and peerDiscoKey, and
// sends it to dst. The packet carries the old public key as the sender key
// and is prefixed with a Geneve header when dst has a VNI set.
//
// It is a no-op when no old key is retained, and it silently gives up if the
// Geneve header cannot be encoded.
//
// NOTE(review): a zero key.NodePublic is passed to sendAddr; confirm that
// this is sufficient for every destination type dst.ap can take.
func (c *Conn) sendKeyUpdateToPeer(peerDiscoKey key.DiscoPublic, dst epAddr, newKey key.DiscoPublic) {
	prevPrivate := c.oldDiscoKey.Load()
	if prevPrivate == nil {
		return
	}

	msg := disco.KeyUpdate{NewDiscoKey: newKey}
	plaintext := msg.AppendMarshal(nil)
	shared := prevPrivate.Shared(peerDiscoKey)
	box := shared.Seal(plaintext)

	viaGeneve := dst.vni.IsSet()
	buf := make([]byte, 0, 512)
	if viaGeneve {
		hdr := packet.GeneveHeader{
			Version:  0,
			Protocol: packet.GeneveProtocolDisco,
			VNI:      dst.vni,
			Control:  false,
		}
		// Reserve room for the fixed header, then encode into it.
		buf = append(buf, make([]byte, packet.GeneveFixedHeaderLength)...)
		if err := hdr.Encode(buf); err != nil {
			return
		}
	}
	buf = append(buf, disco.Magic...)
	buf = prevPrivate.Public().AppendTo(buf)
	buf = append(buf, box...)

	const isDisco = true
	sent, _ := c.sendAddr(dst.ap, key.NodePublic{}, buf, isDisco, dst.vni.IsSet())
	if !sent {
		return
	}
	metricSentDiscoKeyUpdate.Add(1)
	if debugDisco() {
		c.dlogf("[v1] magicsock: disco: sent key-update to %v at %v", peerDiscoKey.ShortString(), dst)
	}
}
// sendKeyUpdatesToPeers sends a KeyUpdate announcing newKey to each peer in
// peers. Every message is sealed with the shared key derived from oldKey and
// the peer's disco key, and carries oldKey's public key as the sender key,
// so that peers which only know our previous key can open it.
//
// Peers without a discoInfo entry or without a valid lastPingFrom address are
// skipped. c.mu is taken briefly per peer to read that address and is not
// held while sending.
func (c *Conn) sendKeyUpdatesToPeers(peers []key.DiscoPublic, newKey key.DiscoPublic, oldKey key.DiscoPrivate) {
	if len(peers) == 0 {
		return
	}

	// The cleartext and our old public key are identical for every peer;
	// compute them once instead of once per iteration.
	cleartext := (&disco.KeyUpdate{NewDiscoKey: newKey}).AppendMarshal(nil)
	oldPublic := oldKey.Public()

	for _, peerDiscoKey := range peers {
		// Copy the destination out while holding c.mu; discoInfo's mutable
		// fields are owned by that lock.
		var dst epAddr
		c.mu.Lock()
		if di := c.discoInfo[peerDiscoKey]; di != nil {
			dst = di.lastPingFrom
		}
		c.mu.Unlock()
		if !dst.ap.IsValid() {
			continue
		}

		pkt := make([]byte, 0, 512)
		if dst.vni.IsSet() {
			gh := packet.GeneveHeader{
				Version:  0,
				Protocol: packet.GeneveProtocolDisco,
				VNI:      dst.vni,
				Control:  false,
			}
			pkt = append(pkt, make([]byte, packet.GeneveFixedHeaderLength)...)
			if err := gh.Encode(pkt); err != nil {
				// Best effort: skip this peer without paying for the seal.
				continue
			}
		}

		oldShared := oldKey.Shared(peerDiscoKey)
		sealed := oldShared.Seal(cleartext)

		pkt = append(pkt, disco.Magic...)
		pkt = oldPublic.AppendTo(pkt)
		pkt = append(pkt, sealed...)

		const isDisco = true
		if sent, _ := c.sendAddr(dst.ap, key.NodePublic{}, pkt, isDisco, dst.vni.IsSet()); sent {
			metricSentDiscoKeyUpdate.Add(1)
		}
	}
}
// cleanupOldDiscoKey drops the retained old disco private key and every
// per-peer oldSharedKey once the grace period has elapsed.
//
// The grace period is measured from the creation time of the current key, so
// a timer armed by an earlier rotation does nothing if a later rotation has
// since replaced the key.
func (c *Conn) cleanupOldDiscoKey() {
	c.mu.Lock()
	defer c.mu.Unlock()

	born := c.discoKeyCreatedAt.Load()
	if born == nil || time.Since(*born) < discoKeyRotationGracePeriod {
		return
	}

	c.oldDiscoKey.Store(nil)
	for k := range c.discoInfo {
		c.discoInfo[k].oldSharedKey = nil
	}

	if debugDisco() {
		c.dlogf("[v1] magicsock: disco: cleaned up old key after grace period")
	}
}
func (c *Conn) SetNetworkUp(up bool) {
c.mu.Lock()
defer c.mu.Unlock()
@ -3950,6 +4193,12 @@ type discoInfo struct {
// Not modified once initialized.
sharedKey key.DiscoShared
// oldSharedKey is the precomputed key using our old disco private key.
// This is set during rotation and allows us to decrypt messages from
// peers who haven't received our new key yet.
// Owned by [Conn.mu].
oldSharedKey *key.DiscoShared
// Mutable fields follow, owned by [Conn.mu]. These are irrelevant when
// discoInfo is a peer relay server disco key in the
// [relayManager.discoInfoByServerDisco] map:
@ -3961,6 +4210,20 @@ type discoInfo struct {
lastPingTime time.Time
}
// Tunables for graceful disco key rotation.
const (
// discoKeyRotationGracePeriod is how long the previous disco key is kept
// after a rotation so that peers can transition to the new key. The
// window is deliberately large: it is meant to outlast a wide range of
// network problems that could delay propagation of the new key, while
// still expiring the old key on a schedule.
discoKeyRotationGracePeriod = 99 * time.Minute
// minDiscoKeyAge is the minimum age a disco key must reach before it can
// be rotated. This prevents accidentally rotating keys too frequently; it
// is not necessary to rotate disco keys on a high-frequency schedule.
minDiscoKeyAge = 5 * time.Minute
)
var (
metricNumPeers = clientmetric.NewGauge("magicsock_netmap_num_peers")
metricNumDERPConns = clientmetric.NewGauge("magicsock_num_derp_conns")
@ -4029,8 +4292,9 @@ var (
metricSentDiscoBindUDPRelayEndpoint = clientmetric.NewCounter("magicsock_disco_sent_bind_udp_relay_endpoint")
metricSentDiscoBindUDPRelayEndpointAnswer = clientmetric.NewCounter("magicsock_disco_sent_bind_udp_relay_endpoint_answer")
metricSentDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_sent_alloc_udp_relay_endpoint_request")
metricLocalDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_local_alloc_udp_relay_endpoint_request")
metricSentDiscoAllocUDPRelayEndpointResponse = clientmetric.NewCounter("magicsock_disco_sent_alloc_udp_relay_endpoint_response")
metricSentDiscoKeyUpdate = clientmetric.NewCounter("magicsock_disco_sent_key_update")
metricLocalDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_local_alloc_udp_relay_endpoint_request")
metricRecvDiscoBadPeer = clientmetric.NewCounter("magicsock_disco_recv_bad_peer")
metricRecvDiscoBadKey = clientmetric.NewCounter("magicsock_disco_recv_bad_key")
metricRecvDiscoBadParse = clientmetric.NewCounter("magicsock_disco_recv_bad_parse")
@ -4048,8 +4312,10 @@ var (
metricRecvDiscoBindUDPRelayEndpointChallenge = clientmetric.NewCounter("magicsock_disco_recv_bind_udp_relay_endpoint_challenge")
metricRecvDiscoAllocUDPRelayEndpointRequest = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_request")
metricRecvDiscoAllocUDPRelayEndpointRequestBadDisco = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_request_bad_disco")
metricRecvDiscoAllocUDPRelayEndpointResponseBadDisco = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_response_bad_disco")
metricRecvDiscoAllocUDPRelayEndpointResponse = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_response")
metricRecvDiscoAllocUDPRelayEndpointResponseBadDisco = clientmetric.NewCounter("magicsock_disco_recv_alloc_udp_relay_endpoint_response_bad_disco")
metricRecvDiscoKeyUpdate = clientmetric.NewCounter("magicsock_disco_recv_key_update")
metricRecvDiscoWithOldKey = clientmetric.NewCounter("magicsock_disco_recv_with_old_key")
metricLocalDiscoAllocUDPRelayEndpointResponse = clientmetric.NewCounter("magicsock_disco_local_alloc_udp_relay_endpoint_response")
metricRecvDiscoDERPPeerNotHere = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_not_here")
metricRecvDiscoDERPPeerGoneUnknown = clientmetric.NewCounter("magicsock_disco_recv_derp_peer_gone_unknown")

@ -396,6 +396,129 @@ func meshStacks(logf logger.Logf, mutateNetmap func(idx int, nm *netmap.NetworkM
}
}
// waitForPeers polls every 5ms until each stack's status lists exactly
// len(stacks)-1 peers (i.e. every other stack), failing the test if that
// does not happen within timeout.
func waitForPeers(t *testing.T, timeout time.Duration, stacks ...*magicStack) {
	t.Helper()

	everyoneSeesEveryone := func() bool {
		for _, stack := range stacks {
			if len(stack.Status().Peer) != len(stacks)-1 {
				return false
			}
		}
		return true
	}

	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(5 * time.Millisecond) {
		if everyoneSeesEveryone() {
			return
		}
	}
	t.Fatalf("timeout waiting for peers to appear in status")
}
// waitForDiscoInfo polls every 5ms until conn has a discoInfo entry for
// peerKey, failing the test if none appears within timeout.
func waitForDiscoInfo(t *testing.T, conn *Conn, peerKey key.DiscoPublic, timeout time.Duration) {
	t.Helper()

	known := func() bool {
		conn.mu.Lock()
		defer conn.mu.Unlock()
		return conn.discoInfo[peerKey] != nil
	}

	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(5 * time.Millisecond) {
		if known() {
			return
		}
	}
	t.Fatalf("timeout waiting for discoInfo for peer %v", peerKey.ShortString())
}
// waitForKeyUpdate polls every 10ms until both the sent and the received
// KeyUpdate counters have moved past the supplied baselines, which indicates
// that at least one KeyUpdate was sent and one was received.
//
// On timeout it reports a non-fatal test error, so the caller keeps running.
func waitForKeyUpdate(t *testing.T, sentBefore, recvBefore int64, timeout time.Duration) {
	t.Helper()

	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) {
		sentAdvanced := metricSentDiscoKeyUpdate.Value() > sentBefore
		if sentAdvanced && metricRecvDiscoKeyUpdate.Value() > recvBefore {
			return
		}
	}
	t.Errorf("timeout waiting for KeyUpdate: sent %d->%d, recv %d->%d",
		sentBefore, metricSentDiscoKeyUpdate.Value(),
		recvBefore, metricRecvDiscoKeyUpdate.Value())
}
// waitForDiscoKeyChange polls every 10ms until conn holds a discoInfo entry
// for newKey and no longer holds one for oldKey, i.e. until conn has
// processed the peer's key rotation.
//
// On timeout it reports a non-fatal test error, so the caller keeps running.
func waitForDiscoKeyChange(t *testing.T, conn *Conn, oldKey, newKey key.DiscoPublic, timeout time.Duration) {
	t.Helper()

	switched := func() bool {
		conn.mu.Lock()
		defer conn.mu.Unlock()
		haveNew := conn.discoInfo[newKey] != nil
		haveOld := conn.discoInfo[oldKey] != nil
		return haveNew && !haveOld
	}

	for deadline := time.Now().Add(timeout); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) {
		if switched() {
			return
		}
	}
	t.Errorf("timeout waiting for disco key change from %v to %v",
		oldKey.ShortString(), newKey.ShortString())
}
// discoPing makes src ping dst right away instead of waiting for the regular
// heartbeat, so tests can establish disco state quickly.
//
// It fails the test if src has no peer entry for dst, or if the ping callback
// has not fired within 2 seconds. A ping that completes with an error is only
// logged.
func discoPing(t *testing.T, src, dst *magicStack) {
	t.Helper()

	// Locate dst among src's peers while holding src's lock.
	findTarget := func() tailcfg.NodeView {
		src.conn.mu.Lock()
		defer src.conn.mu.Unlock()
		for _, candidate := range src.conn.peers.All() {
			if candidate.Key() == dst.Public() {
				return candidate
			}
		}
		return tailcfg.NodeView{}
	}

	target := findTarget()
	if !target.Valid() {
		t.Fatalf("src doesn't have dst in peers")
	}

	finished := make(chan struct{})
	onResult := func(pr *ipnstate.PingResult) {
		if pr.Err != "" {
			t.Logf("disco ping completed with error: %v", pr.Err)
		}
		close(finished)
	}
	src.conn.Ping(target, &ipnstate.PingResult{}, 0, onResult)

	select {
	case <-time.After(2 * time.Second):
		t.Fatalf("disco ping timed out")
	case <-finished:
	}
}
// ageDiscoInfoForTest backdates lastPingTime on every discoInfo entry of conn
// to ten minutes ago. That is older than the five-minute activity cutoff used
// by RotateDiscoKey, so a subsequent rotation sends no KeyUpdate messages and
// tests can exercise netmap-only propagation.
func ageDiscoInfoForTest(conn *Conn) {
	stale := time.Now().Add(-10 * time.Minute)

	conn.mu.Lock()
	defer conn.mu.Unlock()
	for k := range conn.discoInfo {
		conn.discoInfo[k].lastPingTime = stale
	}
}
func TestNewConn(t *testing.T) {
tstest.PanicOnLog()
tstest.ResourceCheck(t)
@ -4266,7 +4389,13 @@ func TestRotateDiscoKey(t *testing.T) {
}
c.mu.Unlock()
c.RotateDiscoKey()
// Advance the disco key creation time to bypass rate limiting
pastTime := time.Now().Add(-10 * time.Minute)
c.discoKeyCreatedAt.Store(&pastTime)
if err := c.RotateDiscoKey(); err != nil {
t.Fatalf("RotateDiscoKey failed: %v", err)
}
newPrivate, newPublic := c.discoKey.Pair()
newShort := c.discoKey.Short()
@ -4286,9 +4415,93 @@ func TestRotateDiscoKey(t *testing.T) {
}
c.mu.Lock()
if len(c.discoInfo) != 0 {
t.Fatalf("expected discoInfo to be cleared, got %d entries", len(c.discoInfo))
// After graceful rotation, discoInfo should be preserved and updated with new shared keys
if len(c.discoInfo) != 1 {
t.Fatalf("expected discoInfo to be preserved with 1 entry, got %d entries", len(c.discoInfo))
}
for peerDiscoKey, di := range c.discoInfo {
if peerDiscoKey != testDiscoKey {
t.Fatalf("peer disco key changed unexpectedly")
}
expectedSharedKey := newPrivate.Shared(peerDiscoKey)
if !di.sharedKey.Equal(expectedSharedKey) {
t.Fatalf("shared key was not updated after rotation")
}
if di.oldSharedKey == nil {
t.Fatalf("oldSharedKey should be set after rotation")
}
expectedOldSharedKey := oldPrivate.Shared(peerDiscoKey)
if !di.oldSharedKey.Equal(expectedOldSharedKey) {
t.Fatalf("oldSharedKey is not correct")
}
}
c.mu.Unlock()
}
// TestRotateDiscoKeyGraceful checks that a rotation keeps an existing peer's
// discoInfo entry, installs a new shared key on it, retains the pre-rotation
// shared key as oldSharedKey, and that each of the two keys only opens
// messages sealed with itself.
func TestRotateDiscoKeyGraceful(t *testing.T) {
	conn := newConn(t.Logf)
	peerPublic := key.NewDisco().Public()

	// Seed one known peer and remember the shared key in use before rotation.
	conn.mu.Lock()
	seeded := &discoInfo{
		discoKey:   peerPublic,
		discoShort: peerPublic.ShortString(),
		sharedKey:  conn.discoKey.Private().Shared(peerPublic),
	}
	conn.discoInfo[peerPublic] = seeded
	preRotationShared := seeded.sharedKey
	conn.mu.Unlock()

	// Pretend the current key is old enough to pass the rotation rate limit.
	tenMinutesAgo := time.Now().Add(-10 * time.Minute)
	conn.discoKeyCreatedAt.Store(&tenMinutesAgo)

	if err := conn.RotateDiscoKey(); err != nil {
		t.Fatalf("RotateDiscoKey failed: %v", err)
	}

	conn.mu.Lock()
	info := conn.discoInfo[peerPublic]
	if info == nil {
		t.Fatalf("peer discoInfo was removed during rotation")
	}
	if info.sharedKey.Equal(preRotationShared) {
		t.Fatalf("shared key was not updated")
	}
	if info.oldSharedKey == nil {
		t.Fatalf("oldSharedKey should be set after rotation")
	}
	if !info.oldSharedKey.Equal(preRotationShared) {
		t.Fatalf("oldSharedKey doesn't match the previous shared key")
	}

	ping := &disco.Ping{TxID: [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}}
	plain := ping.AppendMarshal(nil)

	// Round trip through the new key.
	boxNew := info.sharedKey.Seal(plain)
	gotNew, ok := info.sharedKey.Open(boxNew)
	if !ok {
		t.Fatalf("failed to decrypt message encrypted with new key")
	}
	if string(gotNew) != string(plain) {
		t.Fatalf("decrypted message doesn't match original")
	}

	// A message sealed with the old key must be rejected by the new key but
	// accepted by the retained old key.
	boxOld := info.oldSharedKey.Seal(plain)
	if _, ok = info.sharedKey.Open(boxOld); ok {
		t.Fatalf("shouldn't be able to decrypt old-key message with new key")
	}
	gotOld, ok := info.oldSharedKey.Open(boxOld)
	if !ok {
		t.Fatalf("failed to decrypt message encrypted with old key")
	}
	if string(gotOld) != string(plain) {
		t.Fatalf("decrypted old message doesn't match original")
	}
	conn.mu.Unlock()
}
@ -4298,8 +4511,14 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) {
keys := make([]key.DiscoPublic, 0, 5)
keys = append(keys, c.discoKey.Public())
for i := 0; i < 4; i++ {
c.RotateDiscoKey()
for i := range 4 {
// Advance the disco key creation time to bypass rate limiting
pastTime := time.Now().Add(-10 * time.Minute)
c.discoKeyCreatedAt.Store(&pastTime)
if err := c.RotateDiscoKey(); err != nil {
t.Fatalf("rotation %d failed: %v", i+1, err)
}
newKey := c.discoKey.Public()
for j, oldKey := range keys {
@ -4311,3 +4530,364 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) {
keys = append(keys, newKey)
}
}
// TestRotateDiscoKeyViaKeyUpdateMessage rotates m1's disco key while m1 and
// m2 have an active session and verifies that m2 learns the new key from a
// KeyUpdate disco message alone: the test never pushes an updated netmap.
//
// The sent/received KeyUpdate counters it compares are package-level metrics,
// so the assertions are relative to values sampled just before the rotation.
func TestRotateDiscoKeyViaKeyUpdateMessage(t *testing.T) {
tstest.PanicOnLog()
tstest.ResourceCheck(t)
derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1))
defer cleanup()
m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
defer m1.Close()
m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
defer m2.Close()
cleanupMesh := meshStacks(t.Logf, nil, m1, m2)
defer cleanupMesh()
waitForPeers(t, 2*time.Second, m1, m2)
// Force an immediate disco exchange so m2 has discoInfo for m1 before the
// rotation happens.
discoPing(t, m1, m2)
waitForDiscoInfo(t, m2.conn, m1.conn.DiscoPublicKey(), 1*time.Second)
// Start pinger to maintain active session during rotation
cleanup = newPinger(t, t.Logf, m1, m2)
defer cleanup()
// Baselines for the package-level counters checked after the rotation.
m1DiscoKeyBefore := m1.conn.DiscoPublicKey()
sentBefore := metricSentDiscoKeyUpdate.Value()
recvBefore := metricRecvDiscoKeyUpdate.Value()
recvWithOldKeyBefore := metricRecvDiscoWithOldKey.Value()
// Backdate the key creation time so the rotation is not rejected by the
// minDiscoKeyAge rate limit.
pastTime := time.Now().Add(-10 * time.Minute)
m1.conn.discoKeyCreatedAt.Store(&pastTime)
t.Logf("rotating m1 disco key from %v", m1DiscoKeyBefore.ShortString())
if err := m1.conn.RotateDiscoKey(); err != nil {
t.Fatalf("RotateDiscoKey failed: %v", err)
}
m1DiscoKeyAfter := m1.conn.DiscoPublicKey()
if m1DiscoKeyAfter == m1DiscoKeyBefore {
t.Fatalf("m1 disco key didn't change after rotation")
}
t.Logf("m1 disco key rotated to %v", m1DiscoKeyAfter.ShortString())
// Deliberately no push to epCh here: m2 must learn the new key from the
// KeyUpdate message rather than from a netmap update.
waitForKeyUpdate(t, sentBefore, recvBefore, 2*time.Second)
t.Logf("KeyUpdate sent and received (sent: %d->%d, recv: %d->%d)",
sentBefore, metricSentDiscoKeyUpdate.Value(),
recvBefore, metricRecvDiscoKeyUpdate.Value())
waitForDiscoKeyChange(t, m2.conn, m1DiscoKeyBefore, m1DiscoKeyAfter, 2*time.Second)
t.Logf("m2 discoInfo updated to new key")
sentAfter := metricSentDiscoKeyUpdate.Value()
recvAfter := metricRecvDiscoKeyUpdate.Value()
if sentAfter <= sentBefore {
t.Errorf("KeyUpdate not sent: metric before=%d after=%d", sentBefore, sentAfter)
}
if recvAfter <= recvBefore {
t.Errorf("KeyUpdate not received: metric before=%d after=%d", recvBefore, recvAfter)
}
// m2 must now track m1 under the new key only.
m2.conn.mu.Lock()
m2DiscoInfoAfter := m2.conn.discoInfo[m1DiscoKeyAfter]
m2DiscoInfoOld := m2.conn.discoInfo[m1DiscoKeyBefore]
m2.conn.mu.Unlock()
if m2DiscoInfoAfter == nil {
t.Errorf("m2 doesn't have discoInfo for m1's new key %v", m1DiscoKeyAfter.ShortString())
}
if m2DiscoInfoOld != nil {
t.Errorf("m2 still has discoInfo for m1's old key %v (should have been replaced)", m1DiscoKeyBefore.ShortString())
}
// m1 must still hold its previous key: the grace period has not elapsed.
if m1.conn.oldDiscoKey.Load() == nil {
t.Errorf("m1 didn't keep old disco key for grace period")
}
s1 := m1.Status()
s2 := m2.Status()
if len(s1.Peer) != 1 || len(s2.Peer) != 1 {
t.Fatalf("peers lost track of each other after rotation: m1 peers=%d, m2 peers=%d", len(s1.Peer), len(s2.Peer))
}
// Informational only: traffic sealed with the old key is expected while
// the peer is still transitioning.
recvWithOldKeyAfter := metricRecvDiscoWithOldKey.Value()
if recvWithOldKeyAfter > recvWithOldKeyBefore {
t.Logf("m1 received %d messages with old key during transition (expected during graceful rotation)",
recvWithOldKeyAfter-recvWithOldKeyBefore)
}
t.Logf("disco key rotation via KeyUpdate message successful, active session maintained without control plane")
}
// TestRotateDiscoKeyViaKeyUpdateDirectUDP rotates m1's disco key after direct
// paths between m1 and m2 have been established and verifies that m2 picks up
// the new key from a KeyUpdate message. As a proxy for the path the message
// took, it inspects the lastPingFrom address m1 recorded for m2 and fails if
// that address is the DERP magic IP.
func TestRotateDiscoKeyViaKeyUpdateDirectUDP(t *testing.T) {
tstest.PanicOnLog()
tstest.ResourceCheck(t)
derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1))
defer cleanup()
m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
defer m1.Close()
m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
defer m2.Close()
cleanupMesh := meshStacks(t.Logf, nil, m1, m2)
defer cleanupMesh()
waitForPeers(t, 2*time.Second, m1, m2)
cleanup = newPinger(t, t.Logf, m1, m2)
defer cleanup()
// Require a direct path in both directions before rotating.
mustDirect(t, t.Logf, m1, m2)
mustDirect(t, t.Logf, m2, m1)
t.Logf("direct UDP paths established")
// Baselines for the package-level counters checked after the rotation.
m1DiscoKeyBefore := m1.conn.DiscoPublicKey()
sentKeyUpdateBefore := metricSentDiscoKeyUpdate.Value()
recvKeyUpdateBefore := metricRecvDiscoKeyUpdate.Value()
// Backdate the key creation time so the rotation is not rejected by the
// minDiscoKeyAge rate limit.
pastTime := time.Now().Add(-10 * time.Minute)
m1.conn.discoKeyCreatedAt.Store(&pastTime)
t.Logf("rotating m1 disco key from %v", m1DiscoKeyBefore.ShortString())
if err := m1.conn.RotateDiscoKey(); err != nil {
t.Fatalf("RotateDiscoKey failed: %v", err)
}
m1DiscoKeyAfter := m1.conn.DiscoPublicKey()
if m1DiscoKeyAfter == m1DiscoKeyBefore {
t.Fatalf("m1 disco key didn't change after rotation")
}
t.Logf("m1 disco key rotated to %v", m1DiscoKeyAfter.ShortString())
// Deliberately no push to epCh here: m2 must learn the new key from the
// KeyUpdate message rather than from a netmap update.
waitForKeyUpdate(t, sentKeyUpdateBefore, recvKeyUpdateBefore, 2*time.Second)
sentKeyUpdateAfter := metricSentDiscoKeyUpdate.Value()
recvKeyUpdateAfter := metricRecvDiscoKeyUpdate.Value()
if sentKeyUpdateAfter <= sentKeyUpdateBefore {
t.Errorf("KeyUpdate not sent: before=%d after=%d", sentKeyUpdateBefore, sentKeyUpdateAfter)
}
if recvKeyUpdateAfter <= recvKeyUpdateBefore {
t.Errorf("KeyUpdate not received: before=%d after=%d", recvKeyUpdateBefore, recvKeyUpdateAfter)
}
// Read the address m1 last saw a ping from m2 on; it is used below as an
// indication of which path the KeyUpdate was sent over.
m1.conn.mu.Lock()
m1DiscoInfo := m1.conn.discoInfo[m2.conn.DiscoPublicKey()]
var lastPingFrom epAddr
if m1DiscoInfo != nil {
lastPingFrom = m1DiscoInfo.lastPingFrom
}
m1.conn.mu.Unlock()
if lastPingFrom.ap.IsValid() && lastPingFrom.ap.Addr() != tailcfg.DerpMagicIPAddr {
t.Logf("KeyUpdate sent via direct UDP to %v (as expected)", lastPingFrom.ap)
} else if lastPingFrom.ap.Addr() == tailcfg.DerpMagicIPAddr {
t.Errorf("KeyUpdate sent via DERP, but expected direct UDP path")
} else {
t.Logf("Note: Could not verify path from lastPingFrom")
}
// m2 must now track m1 under the new key only.
m2.conn.mu.Lock()
hasNewKey := m2.conn.discoInfo[m1DiscoKeyAfter] != nil
hasOldKey := m2.conn.discoInfo[m1DiscoKeyBefore] != nil
m2.conn.mu.Unlock()
if !hasNewKey {
t.Errorf("m2 doesn't have discoInfo for m1's new key after KeyUpdate")
}
if hasOldKey {
t.Errorf("m2 still has discoInfo for m1's old key (should have been replaced)")
}
t.Logf("KeyUpdate via direct UDP successful")
}
// TestRotateDiscoKeyViaKeyUpdateDERP rotates m1's disco key shortly after the
// first disco exchange with m2 and verifies that m2 picks up the new key from
// a KeyUpdate message. Unlike the direct-UDP variant it does not wait for a
// direct path first.
//
// The transport actually used is only logged (derived from the UDP/DERP disco
// send counters); it is not asserted.
func TestRotateDiscoKeyViaKeyUpdateDERP(t *testing.T) {
tstest.PanicOnLog()
tstest.ResourceCheck(t)
derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1))
defer cleanup()
m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
defer m1.Close()
m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
defer m2.Close()
cleanupMesh := meshStacks(t.Logf, nil, m1, m2)
defer cleanupMesh()
waitForPeers(t, 2*time.Second, m1, m2)
m1DiscoKeyBefore := m1.conn.DiscoPublicKey()
// Force an immediate disco exchange so m2 has discoInfo for m1 before the
// rotation happens.
discoPing(t, m1, m2)
waitForDiscoInfo(t, m2.conn, m1DiscoKeyBefore, 1*time.Second)
// Start pinger to maintain active session during rotation
cleanup = newPinger(t, t.Logf, m1, m2)
defer cleanup()
// Baselines for the package-level counters checked after the rotation.
sentUDPBefore := metricSentDiscoUDP.Value()
sentDERPBefore := metricSentDiscoDERP.Value()
sentKeyUpdateBefore := metricSentDiscoKeyUpdate.Value()
recvKeyUpdateBefore := metricRecvDiscoKeyUpdate.Value()
// Backdate the key creation time so the rotation is not rejected by the
// minDiscoKeyAge rate limit.
pastTime := time.Now().Add(-10 * time.Minute)
m1.conn.discoKeyCreatedAt.Store(&pastTime)
t.Logf("rotating m1 disco key from %v", m1DiscoKeyBefore.ShortString())
if err := m1.conn.RotateDiscoKey(); err != nil {
t.Fatalf("RotateDiscoKey failed: %v", err)
}
m1DiscoKeyAfter := m1.conn.DiscoPublicKey()
if m1DiscoKeyAfter == m1DiscoKeyBefore {
t.Fatalf("m1 disco key didn't change after rotation")
}
t.Logf("m1 disco key rotated to %v", m1DiscoKeyAfter.ShortString())
// Deliberately no push to epCh here: m2 must learn the new key from the
// KeyUpdate message rather than from a netmap update.
waitForKeyUpdate(t, sentKeyUpdateBefore, recvKeyUpdateBefore, 2*time.Second)
sentUDPAfter := metricSentDiscoUDP.Value()
sentDERPAfter := metricSentDiscoDERP.Value()
sentKeyUpdateAfter := metricSentDiscoKeyUpdate.Value()
recvKeyUpdateAfter := metricRecvDiscoKeyUpdate.Value()
if sentKeyUpdateAfter <= sentKeyUpdateBefore {
t.Errorf("KeyUpdate not sent: before=%d after=%d", sentKeyUpdateBefore, sentKeyUpdateAfter)
}
if recvKeyUpdateAfter <= recvKeyUpdateBefore {
t.Errorf("KeyUpdate not received: before=%d after=%d", recvKeyUpdateBefore, recvKeyUpdateAfter)
}
// Informational only: these counters cover all disco sends (the pinger
// keeps running), so they cannot prove which path the KeyUpdate took.
derpIncreased := sentDERPAfter > sentDERPBefore
udpIncreased := sentUDPAfter > sentUDPBefore
t.Logf("Disco sends after rotation: UDP %d->%d, DERP %d->%d",
sentUDPBefore, sentUDPAfter, sentDERPBefore, sentDERPAfter)
if derpIncreased {
t.Logf("KeyUpdate sent via DERP (as expected for DERP-only path)")
} else if udpIncreased {
t.Logf("KeyUpdate sent via UDP (direct path may have been established)")
} else {
t.Logf("Note: Could not determine path from metrics alone")
}
// m2 must now track m1 under the new key only.
m2.conn.mu.Lock()
hasNewKey := m2.conn.discoInfo[m1DiscoKeyAfter] != nil
hasOldKey := m2.conn.discoInfo[m1DiscoKeyBefore] != nil
m2.conn.mu.Unlock()
if !hasNewKey {
t.Errorf("m2 doesn't have discoInfo for m1's new key after KeyUpdate")
}
if hasOldKey {
t.Errorf("m2 still has discoInfo for m1's old key (should have been replaced)")
}
t.Logf("KeyUpdate via DERP successful")
}
// TestRotateDiscoKeyViaNetmap rotates m1's disco key while the session with
// m2 is idle, so that no KeyUpdate disco message is sent, and checks that m2
// still learns the new key after m1's endpoints are pushed to m1.epCh and a
// fresh disco ping is exchanged.
//
// It also checks that m1 retains its old disco key after rotating and that
// both stacks still report exactly one peer at the end.
func TestRotateDiscoKeyViaNetmap(t *testing.T) {
	tstest.PanicOnLog()
	tstest.ResourceCheck(t)

	derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1))
	defer cleanup()

	m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
	defer m1.Close()
	m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
	defer m2.Close()

	cleanupMesh := meshStacks(t.Logf, nil, m1, m2)
	defer cleanupMesh()

	waitForPeers(t, 2*time.Second, m1, m2)

	// Run a pinger just long enough for m2 to learn m1's current disco key.
	// The guarded defer stops it if waitForDiscoInfo aborts the test before
	// the explicit stop below is reached.
	stopInitialPinger := newPinger(t, t.Logf, m1, m2)
	initialPingerStopped := false
	defer func() {
		if !initialPingerStopped {
			stopInitialPinger()
		}
	}()
	m1DiscoKeyBefore := m1.conn.DiscoPublicKey()
	waitForDiscoInfo(t, m2.conn, m1DiscoKeyBefore, 1*time.Second)
	stopInitialPinger() // Stop pinging - simulate idle session
	initialPingerStopped = true

	sentKeyUpdateBefore := metricSentDiscoKeyUpdate.Value()

	// Allow rotation by making key appear old enough
	pastTime := time.Now().Add(-10 * time.Minute)
	m1.conn.discoKeyCreatedAt.Store(&pastTime)

	// Age the disco info so m2 is not considered an "active peer" during rotation.
	// This prevents KeyUpdate messages from being sent, ensuring we test pure netmap propagation.
	ageDiscoInfoForTest(m1.conn)

	t.Logf("rotating m1 disco key from %v (no active session)", m1DiscoKeyBefore.ShortString())
	if err := m1.conn.RotateDiscoKey(); err != nil {
		t.Fatalf("RotateDiscoKey failed: %v", err)
	}
	m1DiscoKeyAfter := m1.conn.DiscoPublicKey()
	if m1DiscoKeyAfter == m1DiscoKeyBefore {
		t.Fatalf("m1 disco key didn't change after rotation")
	}
	t.Logf("m1 disco key rotated to %v", m1DiscoKeyAfter.ShortString())

	// Read lastEndpoints under conn.mu, but send after releasing it: a
	// channel send can block, and nothing else should have to wait on
	// conn.mu while it does.
	m1.conn.mu.Lock()
	eps := m1.conn.lastEndpoints
	m1.conn.mu.Unlock()
	m1.epCh <- eps

	t.Logf("waiting for netmap update to propagate")
	time.Sleep(100 * time.Millisecond) // Give meshStacks time to process

	// The scenario is only meaningful if rotation did not emit a KeyUpdate.
	sentKeyUpdateAfter := metricSentDiscoKeyUpdate.Value()
	if sentKeyUpdateAfter > sentKeyUpdateBefore {
		t.Errorf("KeyUpdate was sent (sent %d->%d) but should not have been - test is invalid",
			sentKeyUpdateBefore, sentKeyUpdateAfter)
	} else {
		t.Logf("KeyUpdate was not sent (session was idle), testing pure netmap propagation")
	}

	if m1.conn.oldDiscoKey.Load() == nil {
		t.Errorf("m1 didn't keep old disco key for grace period")
	}

	// Instead of using newPinger which waits for heartbeat (3s delay), trigger
	// immediate disco ping to test netmap propagation.
	t.Logf("triggering immediate disco ping from m2 to m1 (with new key)")
	discoPing(t, m2, m1)
	waitForDiscoInfo(t, m2.conn, m1DiscoKeyAfter, 1*time.Second)

	// Now start the actual pinger to verify ongoing communication works
	stopFinalPinger := newPinger(t, t.Logf, m1, m2)
	defer stopFinalPinger()

	sentKeyUpdateFinal := metricSentDiscoKeyUpdate.Value()
	if sentKeyUpdateFinal > sentKeyUpdateBefore {
		t.Errorf("KeyUpdate was sent after pinging resumed (sent %d->%d) - test is invalid",
			sentKeyUpdateBefore, sentKeyUpdateFinal)
	} else {
		t.Logf("Confirmed: KeyUpdate was never sent (before=%d, after=%d)", sentKeyUpdateBefore, sentKeyUpdateFinal)
	}

	s1 := m1.Status()
	s2 := m2.Status()
	if len(s1.Peer) != 1 || len(s2.Peer) != 1 {
		t.Fatalf("peers lost track of each other after rotation: m1 peers=%d, m2 peers=%d", len(s1.Peer), len(s2.Peer))
	}

	// Only claim success when none of the checks above reported an error.
	if !t.Failed() {
		t.Logf("disco key rotation via netmap successful, communication established with new key")
	}
}

Loading…
Cancel
Save