From c09c95ef67d5fe9ff127cf2102f189e47e41b119 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Mon, 3 Nov 2025 16:41:37 -0800 Subject: [PATCH] types/key,wgengine/magicsock,control/controlclient,ipn: add debug disco key rotation Adds the ability to rotate discovery keys on running clients, needed for testing upcoming disco key distribution changes. Introduces key.DiscoKey, an atomic container for a disco private key, public key, and the public key's ShortString, replacing the prior separate atomic fields. magicsock.Conn has a new RotateDiscoKey method, and access to this is provided via localapi and a CLI debug command. Note that this implementation is primarily for testing as it stands, and regular use should likely introduce an additional mechanism that allows the old key to be used for some time, to provide a seamless key rotation rather than one that invalidates all sessions. Updates tailscale/corp#34037 Signed-off-by: James Tucker --- cmd/tailscale/cli/debug.go | 6 +++ control/controlclient/auto.go | 7 +++ control/controlclient/client.go | 6 +++ control/controlclient/direct.go | 16 ++++-- control/controlclient/direct_test.go | 26 +++++++++ ipn/ipnlocal/local.go | 24 +++++++++ ipn/ipnlocal/state_test.go | 5 ++ ipn/localapi/debug.go | 20 +++++++ wgengine/magicsock/disco_atomic.go | 58 ++++++++++++++++++++ wgengine/magicsock/disco_atomic_test.go | 70 +++++++++++++++++++++++++ wgengine/magicsock/endpoint.go | 4 +- wgengine/magicsock/endpoint_test.go | 25 +++++++-- wgengine/magicsock/magicsock.go | 63 ++++++++++++++-------- wgengine/magicsock/magicsock_test.go | 70 +++++++++++++++++++++++++ wgengine/magicsock/relaymanager.go | 4 +- wgengine/magicsock/relaymanager_test.go | 8 ++- 16 files changed, 375 insertions(+), 37 deletions(-) create mode 100644 wgengine/magicsock/disco_atomic.go create mode 100644 wgengine/magicsock/disco_atomic_test.go diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index ffed51a63..2facd66ae 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -182,6 +182,12 @@ func debugCmd() *ffcli.Command { Exec: localAPIAction("rebind"), ShortHelp: "Force a magicsock rebind", }, + { + Name: "rotate-disco-key", + ShortUsage: "tailscale debug rotate-disco-key", + Exec: localAPIAction("rotate-disco-key"), + ShortHelp: "Rotate the discovery key", + }, { Name: "derp-set-on-demand", ShortUsage: "tailscale debug derp-set-on-demand", diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index 3cbfe8581..336a8d491 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -767,6 +767,13 @@ func (c *Auto) UpdateEndpoints(endpoints []tailcfg.Endpoint) { } } +// SetDiscoPublicKey sets the client's Disco public to key and sends the change +// to the control server. +func (c *Auto) SetDiscoPublicKey(key key.DiscoPublic) { + c.direct.SetDiscoPublicKey(key) + c.updateControl() +} + func (c *Auto) Shutdown() { c.mu.Lock() if c.closed { diff --git a/control/controlclient/client.go b/control/controlclient/client.go index d0aa129ae..41b39622b 100644 --- a/control/controlclient/client.go +++ b/control/controlclient/client.go @@ -12,6 +12,7 @@ import ( "context" "tailscale.com/tailcfg" + "tailscale.com/types/key" ) // LoginFlags is a bitmask of options to change the behavior of Client.Login @@ -80,7 +81,12 @@ type Client interface { // TODO: a server-side change would let us simply upload this // in a separate http request. It has nothing to do with the rest of // the state machine. + // Note: the auto client uploads the new endpoints to control immediately. UpdateEndpoints(endpoints []tailcfg.Endpoint) + // SetDiscoPublicKey updates the disco public key that will be sent in + // future map requests. This should be called after rotating the discovery key. + // Note: the auto client uploads the new key to control immediately. + SetDiscoPublicKey(key.DiscoPublic) // ClientID returns the ClientID of a client. This ID is meant to // distinguish one client from another. ClientID() int64 diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 62bbb3586..006a801ef 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -74,7 +74,6 @@ type Direct struct { logf logger.Logf netMon *netmon.Monitor // non-nil health *health.Tracker - discoPubKey key.DiscoPublic busClient *eventbus.Client clientVersionPub *eventbus.Publisher[tailcfg.ClientVersion] autoUpdatePub *eventbus.Publisher[AutoUpdate] @@ -95,6 +94,7 @@ type Direct struct { mu syncs.Mutex // mutex guards the following fields serverLegacyKey key.MachinePublic // original ("legacy") nacl crypto_box-based public key; only used for signRegisterRequest on Windows now serverNoiseKey key.MachinePublic + discoPubKey key.DiscoPublic // protected by mu; can be updated via [SetDiscoPublicKey] sfGroup singleflight.Group[struct{}, *ts2021.Client] // protects noiseClient creation. noiseClient *ts2021.Client // also protected by mu @@ -316,7 +316,6 @@ func NewDirect(opts Options) (*Direct, error) { logf: opts.Logf, persist: opts.Persist.View(), authKey: opts.AuthKey, - discoPubKey: opts.DiscoPublicKey, debugFlags: opts.DebugFlags, netMon: netMon, health: opts.HealthTracker, @@ -329,6 +328,7 @@ func NewDirect(opts Options) (*Direct, error) { dnsCache: dnsCache, dialPlan: opts.DialPlan, } + c.discoPubKey = opts.DiscoPublicKey c.closedCtx, c.closeCtx = context.WithCancel(context.Background()) c.controlClientID = nextControlClientID.Add(1) @@ -853,6 +853,14 @@ func (c *Direct) SendUpdate(ctx context.Context) error { return c.sendMapRequest(ctx, false, nil) } +// SetDiscoPublicKey updates the disco public key in local state. +// It does not implicitly trigger [SendUpdate]; callers should arrange for that. +func (c *Direct) SetDiscoPublicKey(key key.DiscoPublic) { + c.mu.Lock() + defer c.mu.Unlock() + c.discoPubKey = key +} + // ClientID returns the controlClientID of the controlClient. func (c *Direct) ClientID() int64 { return c.controlClientID @@ -902,6 +910,7 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap persist := c.persist serverURL := c.serverURL serverNoiseKey := c.serverNoiseKey + discoKey := c.discoPubKey hi := c.hostInfoLocked() backendLogID := hi.BackendLogID connectionHandleForTest := c.connectionHandleForTest @@ -945,11 +954,12 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap } nodeKey := persist.PublicNodeKey() + request := &tailcfg.MapRequest{ Version: tailcfg.CurrentCapabilityVersion, KeepAlive: true, NodeKey: nodeKey, - DiscoKey: c.discoPubKey, + DiscoKey: discoKey, Endpoints: eps, EndpointTypes: epTypes, Stream: isStreaming, diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index dd93dc7b3..4329fc878 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -20,6 +20,32 @@ import ( "tailscale.com/util/eventbus/eventbustest" ) +func TestSetDiscoPublicKey(t *testing.T) { + initialKey := key.NewDisco().Public() + + c := &Direct{ + discoPubKey: initialKey, + } + + c.mu.Lock() + if c.discoPubKey != initialKey { + t.Fatalf("initial disco key mismatch: got %v, want %v", c.discoPubKey, initialKey) + } + c.mu.Unlock() + + newKey := key.NewDisco().Public() + c.SetDiscoPublicKey(newKey) + + c.mu.Lock() + if c.discoPubKey != newKey { + t.Fatalf("disco key not updated: got %v, want %v", c.discoPubKey, newKey) + } + if c.discoPubKey == initialKey { + t.Fatal("disco key should have changed") + } + c.mu.Unlock() +} + func TestNewDirect(t *testing.T) { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 7eb673e6d..0ff299399 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -6620,6 +6620,30 @@ func (b *LocalBackend) DebugReSTUN() error { return nil } +func (b *LocalBackend) DebugRotateDiscoKey() error { + if !buildfeatures.HasDebug { + return nil + } + + mc := b.MagicConn() + mc.RotateDiscoKey() + + newDiscoKey := mc.DiscoPublicKey() + + if tunWrap, ok := b.sys.Tun.GetOK(); ok { + tunWrap.SetDiscoKey(newDiscoKey) + } + + b.mu.Lock() + cc := b.cc + b.mu.Unlock() + if cc != nil { + cc.SetDiscoPublicKey(newDiscoKey) + } + + return nil +} + func (b *LocalBackend) DebugPeerRelayServers() set.Set[netip.Addr] { return b.MagicConn().PeerRelays() } diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index b7325e957..152b375b0 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -316,6 +316,11 @@ func (cc *mockControl) UpdateEndpoints(endpoints []tailcfg.Endpoint) { cc.called("UpdateEndpoints") } +func (cc *mockControl) SetDiscoPublicKey(key key.DiscoPublic) { + cc.logf("SetDiscoPublicKey: %v", key) + cc.called("SetDiscoPublicKey") +} + func (cc *mockControl) ClientID() int64 { return cc.controlClientID } diff --git a/ipn/localapi/debug.go b/ipn/localapi/debug.go index 8aca7f009..ae9cb01e0 100644 --- a/ipn/localapi/debug.go +++ b/ipn/localapi/debug.go @@ -31,6 +31,7 @@ import ( func init() { Register("component-debug-logging", (*Handler).serveComponentDebugLogging) Register("debug", (*Handler).serveDebug) + Register("debug-rotate-disco-key", (*Handler).serveDebugRotateDiscoKey) Register("dev-set-state-store", (*Handler).serveDevSetStateStore) Register("debug-bus-events", (*Handler).serveDebugBusEvents) Register("debug-bus-graph", (*Handler).serveEventBusGraph) @@ -232,6 +233,8 @@ func (h *Handler) serveDebug(w http.ResponseWriter, r *http.Request) { if err == nil { return } + case "rotate-disco-key": + err = h.b.DebugRotateDiscoKey() case "": err = fmt.Errorf("missing parameter 'action'") default: @@ -473,3 +476,20 @@ func (h *Handler) serveDebugOptionalFeatures(w http.ResponseWriter, r *http.Requ w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(of) } + +func (h *Handler) serveDebugRotateDiscoKey(w http.ResponseWriter, r *http.Request) { + if !h.PermitWrite { + http.Error(w, "debug access denied", http.StatusForbidden) + return + } + if r.Method != httpm.POST { + http.Error(w, "POST required", http.StatusMethodNotAllowed) + return + } + if err := h.b.DebugRotateDiscoKey(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "text/plain") + io.WriteString(w, "done\n") +} diff --git a/wgengine/magicsock/disco_atomic.go b/wgengine/magicsock/disco_atomic.go new file mode 100644 index 000000000..5b765fbc2 --- /dev/null +++ b/wgengine/magicsock/disco_atomic.go @@ -0,0 +1,58 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "sync/atomic" + + "tailscale.com/types/key" +) + +type discoKeyPair struct { + private key.DiscoPrivate + public key.DiscoPublic + short string // public.ShortString() +} + +// discoAtomic is an atomic container for a disco private key, public key, and +// the public key's ShortString. The private and public keys are always kept +// synchronized. +// +// The zero value is not ready for use. Use [Set] to provide a usable value. +type discoAtomic struct { + pair atomic.Pointer[discoKeyPair] +} + +// Pair returns the private and public keys together atomically. +// Code that needs both the private and public keys synchronized should +// use Pair instead of calling Private and Public separately. +func (dk *discoAtomic) Pair() (key.DiscoPrivate, key.DiscoPublic) { + p := dk.pair.Load() + return p.private, p.public +} + +// Private returns the private key. +func (dk *discoAtomic) Private() key.DiscoPrivate { + return dk.pair.Load().private +} + +// Public returns the public key. +func (dk *discoAtomic) Public() key.DiscoPublic { + return dk.pair.Load().public +} + +// Short returns the short string of the public key (see [DiscoPublic.ShortString]). +func (dk *discoAtomic) Short() string { + return dk.pair.Load().short +} + +// Set updates the private key (and the cached public key and short string). +func (dk *discoAtomic) Set(private key.DiscoPrivate) { + public := private.Public() + dk.pair.Store(&discoKeyPair{ + private: private, + public: public, + short: public.ShortString(), + }) +} diff --git a/wgengine/magicsock/disco_atomic_test.go b/wgengine/magicsock/disco_atomic_test.go new file mode 100644 index 000000000..a1de9b843 --- /dev/null +++ b/wgengine/magicsock/disco_atomic_test.go @@ -0,0 +1,70 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "testing" + + "tailscale.com/types/key" +) + +func TestDiscoAtomic(t *testing.T) { + var dk discoAtomic + dk.Set(key.NewDisco()) + + private := dk.Private() + public := dk.Public() + short := dk.Short() + + if private.IsZero() { + t.Fatal("DiscoKey private key should not be zero") + } + if public.IsZero() { + t.Fatal("DiscoKey public key should not be zero") + } + if short == "" { + t.Fatal("DiscoKey short string should not be empty") + } + + if public != private.Public() { + t.Fatal("DiscoKey public key doesn't match private key") + } + if short != public.ShortString() { + t.Fatal("DiscoKey short string doesn't match public key") + } + + gotPrivate, gotPublic := dk.Pair() + if !gotPrivate.Equal(private) { + t.Fatal("Pair() returned different private key") + } + if gotPublic != public { + t.Fatal("Pair() returned different public key") + } +} + +func TestDiscoAtomicSet(t *testing.T) { + var dk discoAtomic + dk.Set(key.NewDisco()) + oldPrivate := dk.Private() + oldPublic := dk.Public() + + newPrivate := key.NewDisco() + dk.Set(newPrivate) + + currentPrivate := dk.Private() + currentPublic := dk.Public() + + if currentPrivate.Equal(oldPrivate) { + t.Fatal("DiscoKey private key should have changed after Set") + } + if currentPublic == oldPublic { + t.Fatal("DiscoKey public key should have changed after Set") + } + if !currentPrivate.Equal(newPrivate) { + t.Fatal("DiscoKey private key doesn't match the set key") + } + if currentPublic != newPrivate.Public() { + t.Fatal("DiscoKey public key doesn't match derived from set private key") + } +} diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index c2e5dcca3..eda589e14 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -697,7 +697,7 @@ func (de *endpoint) maybeProbeUDPLifetimeLocked() (afterInactivityFor time.Durat // shuffling probing probability where the local node ends up with a large // key value lexicographically relative to the other nodes it tends to // communicate with. If de's disco key changes, the cycle will reset. - if de.c.discoPublic.Compare(epDisco.key) >= 0 { + if de.c.discoAtomic.Public().Compare(epDisco.key) >= 0 { // lower disco pub key node probes higher return afterInactivityFor, false } @@ -1739,7 +1739,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src epAdd } if sp.purpose != pingHeartbeat && sp.purpose != pingHeartbeatForUDPLifetime { - de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pktlen=%v pong.src=%v%v", de.c.discoShort, de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), pktLen, m.Src, logger.ArgWriter(func(bw *bufio.Writer) { + de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pktlen=%v pong.src=%v%v", de.c.discoAtomic.Short(), de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), pktLen, m.Src, logger.ArgWriter(func(bw *bufio.Writer) { if sp.to != src { fmt.Fprintf(bw, " ping.to=%v", sp.to) } diff --git a/wgengine/magicsock/endpoint_test.go b/wgengine/magicsock/endpoint_test.go index df1c93406..f1dab924f 100644 --- a/wgengine/magicsock/endpoint_test.go +++ b/wgengine/magicsock/endpoint_test.go @@ -146,15 +146,22 @@ func TestProbeUDPLifetimeConfig_Valid(t *testing.T) { } func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { + var lowerPriv, higherPriv key.DiscoPrivate var lower, higher key.DiscoPublic - a := key.NewDisco().Public() - b := key.NewDisco().Public() + privA := key.NewDisco() + privB := key.NewDisco() + a := privA.Public() + b := privB.Public() if a.String() < b.String() { lower = a higher = b + lowerPriv = privA + higherPriv = privB } else { lower = b higher = a + lowerPriv = privB + higherPriv = privA } addr := addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort("1.1.1.1:1")}} newProbeUDPLifetime := func() *probeUDPLifetime { @@ -281,10 +288,18 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + c := &Conn{} + if tt.localDisco.IsZero() { + c.discoAtomic.Set(key.NewDisco()) + } else if tt.localDisco.Compare(lower) == 0 { + c.discoAtomic.Set(lowerPriv) + } else if tt.localDisco.Compare(higher) == 0 { + c.discoAtomic.Set(higherPriv) + } else { + t.Fatalf("unexpected localDisco value") + } de := &endpoint{ - c: &Conn{ - discoPublic: tt.localDisco, - }, + c: c, bestAddr: tt.bestAddr, } if tt.remoteDisco != nil { diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index f610d6adb..064838a2d 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -273,14 +273,8 @@ type Conn struct { // channel operations and goroutine creation. hasPeerRelayServers atomic.Bool - // discoPrivate is the private naclbox key used for active - // discovery traffic. It is always present, and immutable. - discoPrivate key.DiscoPrivate - // public of discoPrivate. It is always present and immutable. - discoPublic key.DiscoPublic - // ShortString of discoPublic (to save logging work later). It is always - // present and immutable. - discoShort string + // discoAtomic is the current disco private and public keypair for this conn. + discoAtomic discoAtomic // ============================================================ // mu guards all following fields; see userspaceEngine lock @@ -603,11 +597,9 @@ func newConn(logf logger.Logf) *Conn { peerLastDerp: make(map[key.NodePublic]int), peerMap: newPeerMap(), discoInfo: make(map[key.DiscoPublic]*discoInfo), - discoPrivate: discoPrivate, - discoPublic: discoPrivate.Public(), cloudInfo: newCloudInfo(logf), } - c.discoShort = c.discoPublic.ShortString() + c.discoAtomic.Set(discoPrivate) c.bind = &connBind{Conn: c, closed: true} c.receiveBatchPool = sync.Pool{New: func() any { msgs := make([]ipv6.Message, c.bind.BatchSize()) @@ -635,7 +627,7 @@ func (c *Conn) onUDPRelayAllocResp(allocResp UDPRelayAllocResp) { // now versus taking a network round-trip through DERP. selfNodeKey := c.publicKeyAtomic.Load() if selfNodeKey.Compare(allocResp.ReqRxFromNodeKey) == 0 && - allocResp.ReqRxFromDiscoKey.Compare(c.discoPublic) == 0 { + allocResp.ReqRxFromDiscoKey.Compare(c.discoAtomic.Public()) == 0 { c.relayManager.handleRxDiscoMsg(c, allocResp.Message, selfNodeKey, allocResp.ReqRxFromDiscoKey, epAddr{}) metricLocalDiscoAllocUDPRelayEndpointResponse.Add(1) } @@ -765,7 +757,7 @@ func NewConn(opts Options) (*Conn, error) { c.logf("[v1] couldn't create raw v6 disco listener, using regular listener instead: %v", err) } - c.logf("magicsock: disco key = %v", c.discoShort) + c.logf("magicsock: disco key = %v", c.discoAtomic.Short()) return c, nil } @@ -1244,7 +1236,32 @@ func (c *Conn) GetEndpointChanges(peer tailcfg.NodeView) ([]EndpointChange, erro // DiscoPublicKey returns the discovery public key. func (c *Conn) DiscoPublicKey() key.DiscoPublic { - return c.discoPublic + return c.discoAtomic.Public() +} + +// RotateDiscoKey generates a new discovery key pair and updates the connection +// to use it. This invalidates all existing disco sessions and will cause peers +// to re-establish discovery sessions with the new key. +// +// This is primarily for debugging and testing purposes, a future enhancement +// should provide a mechanism for seamless rotation by supporting short term use +// of the old key. +func (c *Conn) RotateDiscoKey() { + oldShort := c.discoAtomic.Short() + newPrivate := key.NewDisco() + + c.mu.Lock() + c.discoAtomic.Set(newPrivate) + newShort := c.discoAtomic.Short() + c.discoInfo = make(map[key.DiscoPublic]*discoInfo) + connCtx := c.connCtx + c.mu.Unlock() + + c.logf("magicsock: rotated disco key from %v to %v", oldShort, newShort) + + if connCtx != nil { + c.ReSTUN("disco-key-rotation") + } } // determineEndpoints returns the machine's endpoint addresses. It does a STUN @@ -1914,7 +1931,7 @@ func (c *Conn) sendDiscoAllocateUDPRelayEndpointRequest(dst epAddr, dstKey key.N if isDERP && dstKey.Compare(selfNodeKey) == 0 { c.allocRelayEndpointPub.Publish(UDPRelayAllocReq{ RxFromNodeKey: selfNodeKey, - RxFromDiscoKey: c.discoPublic, + RxFromDiscoKey: c.discoAtomic.Public(), Message: allocReq, }) metricLocalDiscoAllocUDPRelayEndpointRequest.Add(1) @@ -1985,7 +2002,7 @@ func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key. } } pkt = append(pkt, disco.Magic...) - pkt = c.discoPublic.AppendTo(pkt) + pkt = c.discoAtomic.Public().AppendTo(pkt) if isDERP { metricSendDiscoDERP.Add(1) @@ -2003,7 +2020,7 @@ func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key. if !dstKey.IsZero() { node = dstKey.ShortString() } - c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoShort, dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt)) + c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoAtomic.Short(), dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt)) } if isDERP { metricSentDiscoDERP.Add(1) @@ -2352,13 +2369,13 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake } if isVia { c.dlogf("[v1] magicsock: disco: %v<-%v via %v (%v, %v) got call-me-maybe-via, %d endpoints", - c.discoShort, epDisco.short, via.ServerDisco.ShortString(), + c.discoAtomic.Short(), epDisco.short, via.ServerDisco.ShortString(), ep.publicKey.ShortString(), derpStr(src.String()), len(via.AddrPorts)) c.relayManager.handleCallMeMaybeVia(ep, lastBest, lastBestIsTrusted, via) } else { c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", - c.discoShort, epDisco.short, + c.discoAtomic.Short(), epDisco.short, ep.publicKey.ShortString(), derpStr(src.String()), len(cmm.MyNumber)) go ep.handleCallMeMaybe(cmm) @@ -2404,7 +2421,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake if isResp { c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s, %d endpoints", - c.discoShort, epDisco.short, + c.discoAtomic.Short(), epDisco.short, ep.publicKey.ShortString(), derpStr(src.String()), msgType, len(resp.AddrPorts)) @@ -2418,7 +2435,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake return } else { c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s disco[0]=%v disco[1]=%v", - c.discoShort, epDisco.short, + c.discoAtomic.Short(), epDisco.short, ep.publicKey.ShortString(), derpStr(src.String()), msgType, req.ClientDisco[0].ShortString(), req.ClientDisco[1].ShortString()) @@ -2583,7 +2600,7 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src epAddr, di *discoInfo, derpN if numNodes > 1 { pingNodeSrcStr = "[one-of-multi]" } - c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x padding=%v", c.discoShort, di.discoShort, pingNodeSrcStr, src, dm.TxID[:6], dm.Padding) + c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x padding=%v", c.discoAtomic.Short(), di.discoShort, pingNodeSrcStr, src, dm.TxID[:6], dm.Padding) } ipDst := src @@ -2656,7 +2673,7 @@ func (c *Conn) discoInfoForKnownPeerLocked(k key.DiscoPublic) *discoInfo { di = &discoInfo{ discoKey: k, discoShort: k.ShortString(), - sharedKey: c.discoPrivate.Shared(k), + sharedKey: c.discoAtomic.Private().Shared(k), } c.discoInfo[k] = di } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 2a20b3cf6..7ae422906 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -4235,3 +4235,73 @@ func Test_lazyEndpoint_FromPeer(t *testing.T) { }) } } + +func TestRotateDiscoKey(t *testing.T) { + c := newConn(t.Logf) + + oldPrivate, oldPublic := c.discoAtomic.Pair() + oldShort := c.discoAtomic.Short() + + if oldPublic != oldPrivate.Public() { + t.Fatalf("old public key doesn't match old private key") + } + if oldShort != oldPublic.ShortString() { + t.Fatalf("old short string doesn't match old public key") + } + + testDiscoKey := key.NewDisco().Public() + c.mu.Lock() + c.discoInfo[testDiscoKey] = &discoInfo{ + discoKey: testDiscoKey, + discoShort: testDiscoKey.ShortString(), + } + if len(c.discoInfo) != 1 { + t.Fatalf("expected 1 discoInfo entry, got %d", len(c.discoInfo)) + } + c.mu.Unlock() + + c.RotateDiscoKey() + + newPrivate, newPublic := c.discoAtomic.Pair() + newShort := c.discoAtomic.Short() + + if newPublic.Compare(oldPublic) == 0 { + t.Fatalf("disco key didn't change after rotation") + } + if newShort == oldShort { + t.Fatalf("short string didn't change after rotation") + } + + if newPublic != newPrivate.Public() { + t.Fatalf("new public key doesn't match new private key") + } + if newShort != newPublic.ShortString() { + t.Fatalf("new short string doesn't match new public key") + } + + c.mu.Lock() + if len(c.discoInfo) != 0 { + t.Fatalf("expected discoInfo to be cleared, got %d entries", len(c.discoInfo)) + } + c.mu.Unlock() +} + +func TestRotateDiscoKeyMultipleTimes(t *testing.T) { + c := newConn(t.Logf) + + keys := make([]key.DiscoPublic, 0, 5) + keys = append(keys, c.discoAtomic.Public()) + + for i := 0; i < 4; i++ { + c.RotateDiscoKey() + newKey := c.discoAtomic.Public() + + for j, oldKey := range keys { + if newKey.Compare(oldKey) == 0 { + t.Fatalf("rotation %d produced same key as rotation %d", i+1, j) + } + } + + keys = append(keys, newKey) + } +} diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go index 2f93f1085..69831a4df 100644 --- a/wgengine/magicsock/relaymanager.go +++ b/wgengine/magicsock/relaymanager.go @@ -361,7 +361,7 @@ func (r *relayManager) ensureDiscoInfoFor(work *relayHandshakeWork) { di.di = &discoInfo{ discoKey: work.se.ServerDisco, discoShort: work.se.ServerDisco.ShortString(), - sharedKey: work.wlb.ep.c.discoPrivate.Shared(work.se.ServerDisco), + sharedKey: work.wlb.ep.c.discoAtomic.Private().Shared(work.se.ServerDisco), } } } @@ -1031,7 +1031,7 @@ func (r *relayManager) allocateAllServersRunLoop(wlb endpointWithLastBest) { if remoteDisco == nil { return } - discoKeys := key.NewSortedPairOfDiscoPublic(wlb.ep.c.discoPublic, remoteDisco.key) + discoKeys := key.NewSortedPairOfDiscoPublic(wlb.ep.c.discoAtomic.Public(), remoteDisco.key) for _, v := range r.serversByNodeKey { byDiscoKeys, ok := r.allocWorkByDiscoKeysByServerNodeKey[v.nodeKey] if !ok { diff --git a/wgengine/magicsock/relaymanager_test.go b/wgengine/magicsock/relaymanager_test.go index d40081839..e8fddfd91 100644 --- a/wgengine/magicsock/relaymanager_test.go +++ b/wgengine/magicsock/relaymanager_test.go @@ -22,11 +22,15 @@ func TestRelayManagerInitAndIdle(t *testing.T) { <-rm.runLoopStoppedCh rm = relayManager{} - rm.handleCallMeMaybeVia(&endpoint{c: &Conn{discoPrivate: key.NewDisco()}}, addrQuality{}, false, &disco.CallMeMaybeVia{UDPRelayEndpoint: disco.UDPRelayEndpoint{ServerDisco: key.NewDisco().Public()}}) + c1 := &Conn{} + c1.discoAtomic.Set(key.NewDisco()) + rm.handleCallMeMaybeVia(&endpoint{c: c1}, addrQuality{}, false, &disco.CallMeMaybeVia{UDPRelayEndpoint: disco.UDPRelayEndpoint{ServerDisco: key.NewDisco().Public()}}) <-rm.runLoopStoppedCh rm = relayManager{} - rm.handleRxDiscoMsg(&Conn{discoPrivate: key.NewDisco()}, &disco.BindUDPRelayEndpointChallenge{}, key.NodePublic{}, key.DiscoPublic{}, epAddr{}) + c2 := &Conn{} + c2.discoAtomic.Set(key.NewDisco()) + rm.handleRxDiscoMsg(c2, &disco.BindUDPRelayEndpointChallenge{}, key.NodePublic{}, key.DiscoPublic{}, epAddr{}) <-rm.runLoopStoppedCh rm = relayManager{}