diff --git a/net/packet/tsmp.go b/net/packet/tsmp.go index 8fad1d503..9881299b7 100644 --- a/net/packet/tsmp.go +++ b/net/packet/tsmp.go @@ -271,7 +271,7 @@ func (h TSMPPongReply) Marshal(buf []byte) error { // - 'a' (TSMPTypeDiscoAdvertisement) // - 32 disco key bytes type TSMPDiscoKeyAdvertisement struct { - Src, Dst netip.Addr + Src, Dst netip.Addr // Src and Dst are set from the parent IP Header when parsing. Key key.DiscoPublic } @@ -298,7 +298,7 @@ func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) { return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload)) } - return Generate(iph, payload), nil + return Generate(iph, payload[:]), nil } func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok bool) { @@ -310,6 +310,7 @@ func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok return } tka.Src = pp.Src.Addr() + tka.Dst = pp.Dst.Addr() tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33])) return tka, true diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 6e07c7a3d..fe1bc31b8 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -1126,8 +1126,10 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i return n, err } +// DiscoKeyAdvertisement is a TSMP message used for distributing disco keys. +// This struct is used an an event on the [eventbus.Bus]. type DiscoKeyAdvertisement struct { - Src netip.Addr + Src netip.Addr // Src field is populated by the IP header of the packet, not from the payload itself. Key key.DiscoPublic } diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index c7d0708df..3bc2ff447 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -986,7 +986,7 @@ func TestTSMPDisco(t *testing.T) { if tda.Src != src { t.Errorf("Src address did not match, expected %v, got %v", src, tda.Src) } - if !reflect.DeepEqual(tda.Key, discoKey.Public()) { + if tda.Key.Compare(discoKey.Public()) != 0 { t.Errorf("Key did not match, expected %q, got %q", discoKey.Public(), tda.Key) } }) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 064838a2d..b8a5f7da2 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -4104,6 +4104,11 @@ var ( metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff") metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff") metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff") + + // TSMP disco key exchange + metricTSMPDiscoKeyAdvertisementReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_received") + metricTSMPDiscoKeyAdvertisementApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_applied") + metricTSMPDiscoKeyAdvertisementUnchanged = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_unchanged") ) // newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided @@ -4264,3 +4269,40 @@ func (c *Conn) PeerRelays() set.Set[netip.Addr] { } return servers } + +// HandleDiscoKeyAdvertisement processes a TSMP disco key update. +// The update may be solicited (in response to a request) or unsolicited. +// node is the Tailscale tailcfg.NodeView of the peer that sent the update. +func (c *Conn) HandleDiscoKeyAdvertisement(node tailcfg.NodeView, update packet.TSMPDiscoKeyAdvertisement) { + discoKey := update.Key + c.logf("magicsock: received disco key update %v from %v", discoKey.ShortString(), node.StableID()) + metricTSMPDiscoKeyAdvertisementReceived.Add(1) + + c.mu.Lock() + defer c.mu.Unlock() + nodeKey := node.Key() + + ep, ok := c.peerMap.endpointForNodeKey(nodeKey) + if !ok { + c.logf("magicsock: endpoint not found for node %v", nodeKey.ShortString()) + return + } + + oldDiscoKey := key.DiscoPublic{} + if epDisco := ep.disco.Load(); epDisco != nil { + oldDiscoKey = epDisco.key + } + // If the key did not change, count it and return. + if oldDiscoKey.Compare(discoKey) == 0 { + metricTSMPDiscoKeyAdvertisementUnchanged.Add(1) + return + } + c.discoInfoForKnownPeerLocked(discoKey) + ep.disco.Store(&endpointDisco{ + key: discoKey, + short: discoKey.ShortString(), + }) + c.peerMap.upsertEndpoint(ep, oldDiscoKey) + c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString()) + metricTSMPDiscoKeyAdvertisementApplied.Add(1) +} diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 4e1024886..68ab4dfa0 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -64,6 +64,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/nettype" "tailscale.com/types/ptr" + "tailscale.com/types/views" "tailscale.com/util/cibuild" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" @@ -4302,3 +4303,47 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) { keys = append(keys, newKey) } } + +func TestReceiveTSMPDiscoKeyAdvertisement(t *testing.T) { + conn := newTestConn(t) + t.Cleanup(func() { conn.Close() }) + + peerKey := key.NewNode().Public() + ep := &endpoint{ + nodeID: 1, + publicKey: peerKey, + nodeAddr: netip.MustParseAddr("100.64.0.1"), + } + discoKey := key.NewDisco().Public() + ep.disco.Store(&endpointDisco{ + key: discoKey, + short: discoKey.ShortString(), + }) + ep.c = conn + conn.mu.Lock() + nodeView := (&tailcfg.Node{ + Key: ep.publicKey, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + }, + }).View() + conn.peers = views.SliceOf([]tailcfg.NodeView{nodeView}) + conn.mu.Unlock() + + conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) + + if ep.discoShort() != discoKey.ShortString() { + t.Errorf("Original disco key %s, does not match %s", discoKey.ShortString(), ep.discoShort()) + } + + newDiscoKey := key.NewDisco().Public() + tka := packet.TSMPDiscoKeyAdvertisement{ + Src: netip.MustParseAddr("100.64.0.1"), + Key: newDiscoKey, + } + conn.HandleDiscoKeyAdvertisement(nodeView, tka) + + if ep.disco.Load().short != newDiscoKey.ShortString() { + t.Errorf("New disco key %s, does not match %s", newDiscoKey.ShortString(), ep.disco.Load().short) + } +} diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 3db329a37..647923775 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -551,6 +551,23 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) } e.linkChangeQueue.Add(func() { e.linkChange(&cd) }) }) + eventbus.SubscribeFunc(ec, func(update tstun.DiscoKeyAdvertisement) { + e.logf("wgengine: got TSMP disco key advertisement from %v via eventbus", update.Src) + if e.magicConn == nil { + e.logf("wgengine: no magicConn") + return + } + + pkt := packet.TSMPDiscoKeyAdvertisement{ + Key: update.Key, + } + peer, ok := e.PeerForIP(update.Src) + if !ok { + e.logf("wgengine: no peer found for %v", update.Src) + return + } + e.magicConn.HandleDiscoKeyAdvertisement(peer.Node, pkt) + }) e.eventClient = ec e.logf("Engine created.") return e, nil