From adf7bbf902e6410139d7517d23ace8c10a91bd00 Mon Sep 17 00:00:00 2001 From: James Tucker Date: Mon, 3 Nov 2025 14:53:11 -0800 Subject: [PATCH] net,wgengine: add support for disco key exchnage via TSMP Updates tailscale/corp#34037 Signed-off-by: James Tucker --- net/packet/tsmp.go | 68 +++++- net/packet/tsmp_test.go | 168 +++++++++++++++ net/tstun/wrap.go | 47 +++++ wgengine/magicsock/derp.go | 10 + wgengine/magicsock/magicsock.go | 142 ++++++++++++- wgengine/magicsock/magicsock_test.go | 56 +++++ wgengine/magicsock/tsmp_disco_test.go | 291 ++++++++++++++++++++++++++ wgengine/userspace.go | 48 +++++ 8 files changed, 818 insertions(+), 12 deletions(-) create mode 100644 wgengine/magicsock/tsmp_disco_test.go diff --git a/net/packet/tsmp.go b/net/packet/tsmp.go index 0ea321e84..b9b805b1a 100644 --- a/net/packet/tsmp.go +++ b/net/packet/tsmp.go @@ -18,7 +18,7 @@ import ( "tailscale.com/types/ipproto" ) -const minTSMPSize = 7 // the rejected body is 7 bytes +const minTSMPSize = 1 // minimum is 1 byte for the type field (e.g., disco key request 'd') // TailscaleRejectedHeader is a TSMP message that says that one // Tailscale node has rejected the connection from another. Unlike a @@ -72,6 +72,12 @@ const ( // TSMPTypePong is the type byte for a TailscalePongResponse. TSMPTypePong TSMPType = 'o' + + // TSMPTypeDiscoKeyRequest is the type byte for a disco key request. + TSMPTypeDiscoKeyRequest TSMPType = 'd' + + // TSMPTypeDiscoKeyUpdate is the type byte for a disco key update. + TSMPTypeDiscoKeyUpdate TSMPType = 'D' ) type TailscaleRejectReason byte @@ -259,3 +265,63 @@ func (h TSMPPongReply) Marshal(buf []byte) error { binary.BigEndian.PutUint16(buf[9:11], h.PeerAPIPort) return nil } + +// TSMPDiscoKeyRequest is a TSMP message that requests a peer's disco key. +// +// On the wire, after the IP header, it's currently 1 byte: +// - 'd' (TSMPTypeDiscoKeyRequest) +type TSMPDiscoKeyRequest struct{} + +func (pp *Parsed) AsTSMPDiscoKeyRequest() (h TSMPDiscoKeyRequest, ok bool) { + if pp.IPProto != ipproto.TSMP { + return + } + p := pp.Payload() + if len(p) < 1 || p[0] != byte(TSMPTypeDiscoKeyRequest) { + return + } + return h, true +} + +// TSMPDiscoKeyUpdate is a TSMP message that contains a disco public key. +// It may be sent in response to a request, or unsolicited when a node +// believes its peer may have stale disco key information. +// +// On the wire, after the IP header, it's currently 33 bytes: +// - 'D' (TSMPTypeDiscoKeyUpdate) +// - 32 bytes disco public key +type TSMPDiscoKeyUpdate struct { + IPHeader Header + DiscoKey [32]byte // raw disco public key bytes +} + +// AsTSMPDiscoKeyUpdate returns pp as a TSMPDiscoKeyUpdate and whether it is one. +// The update.IPHeader field is not populated. +func (pp *Parsed) AsTSMPDiscoKeyUpdate() (update TSMPDiscoKeyUpdate, ok bool) { + if pp.IPProto != ipproto.TSMP { + return + } + p := pp.Payload() + if len(p) < 33 || p[0] != byte(TSMPTypeDiscoKeyUpdate) { + return + } + copy(update.DiscoKey[:], p[1:33]) + return update, true +} + +func (h TSMPDiscoKeyUpdate) Len() int { + return h.IPHeader.Len() + 33 +} + +func (h TSMPDiscoKeyUpdate) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if err := h.IPHeader.Marshal(buf); err != nil { + return err + } + buf = buf[h.IPHeader.Len():] + buf[0] = byte(TSMPTypeDiscoKeyUpdate) + copy(buf[1:33], h.DiscoKey[:]) + return nil +} diff --git a/net/packet/tsmp_test.go b/net/packet/tsmp_test.go index e261e6a41..a8cd3cad5 100644 --- a/net/packet/tsmp_test.go +++ b/net/packet/tsmp_test.go @@ -71,3 +71,171 @@ func TestTailscaleRejectedHeader(t *testing.T) { } } } + +func TestTSMPDiscoKeyRequest(t *testing.T) { + t.Run("Manual", func(t *testing.T) { + var payload [1]byte + payload[0] = byte(TSMPTypeDiscoKeyRequest) + + var p Parsed + p.IPProto = TSMP + p.dataofs = 40 // simulate after IP header + buf := make([]byte, 40+1) + copy(buf[40:], payload[:]) + p.b = buf + p.length = len(buf) + + _, ok := p.AsTSMPDiscoKeyRequest() + if !ok { + t.Fatal("failed to parse TSMP disco key request") + } + }) + + t.Run("RoundTripIPv4", func(t *testing.T) { + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + iph := IP4Header{ + IPProto: TSMP, + Src: src, + Dst: dst, + } + + var payload [1]byte + payload[0] = byte(TSMPTypeDiscoKeyRequest) + + pkt := Generate(iph, payload[:]) + t.Logf("Generated packet: %d bytes, hex: %x", len(pkt), pkt) + + // Manually check what decode4 would see + if len(pkt) >= 4 { + declaredLen := int(uint16(pkt[2])<<8 | uint16(pkt[3])) + t.Logf("Packet buffer length: %d, IP header declares length: %d", len(pkt), declaredLen) + t.Logf("Protocol byte at [9]: 0x%02x = %d", pkt[9], pkt[9]) + } + + var p Parsed + p.Decode(pkt) + t.Logf("Decoded: IPVersion=%d IPProto=%v Src=%v Dst=%v length=%d dataofs=%d", + p.IPVersion, p.IPProto, p.Src, p.Dst, p.length, p.dataofs) + + if p.IPVersion != 4 { + t.Errorf("IPVersion = %d, want 4", p.IPVersion) + } + if p.IPProto != TSMP { + t.Errorf("IPProto = %v, want TSMP", p.IPProto) + } + if p.Src.Addr() != src { + t.Errorf("Src = %v, want %v", p.Src.Addr(), src) + } + if p.Dst.Addr() != dst { + t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst) + } + + _, ok := p.AsTSMPDiscoKeyRequest() + if !ok { + t.Fatal("failed to parse TSMP disco key request from generated packet") + } + }) + + t.Run("RoundTripIPv6", func(t *testing.T) { + src := netip.MustParseAddr("2001:db8::1") + dst := netip.MustParseAddr("2001:db8::2") + + iph := IP6Header{ + IPProto: TSMP, + Src: src, + Dst: dst, + } + + var payload [1]byte + payload[0] = byte(TSMPTypeDiscoKeyRequest) + + pkt := Generate(iph, payload[:]) + t.Logf("Generated packet: %d bytes", len(pkt)) + + var p Parsed + p.Decode(pkt) + + if p.IPVersion != 6 { + t.Errorf("IPVersion = %d, want 6", p.IPVersion) + } + if p.IPProto != TSMP { + t.Errorf("IPProto = %v, want TSMP", p.IPProto) + } + if p.Src.Addr() != src { + t.Errorf("Src = %v, want %v", p.Src.Addr(), src) + } + if p.Dst.Addr() != dst { + t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst) + } + + _, ok := p.AsTSMPDiscoKeyRequest() + if !ok { + t.Fatal("failed to parse TSMP disco key request from generated packet") + } + }) +} + +func TestTSMPDiscoKeyUpdate(t *testing.T) { + var discoKey [32]byte + for i := range discoKey { + discoKey[i] = byte(i + 10) + } + + // Test IPv4 + t.Run("IPv4", func(t *testing.T) { + update := TSMPDiscoKeyUpdate{ + IPHeader: IP4Header{ + IPProto: TSMP, + Src: netip.MustParseAddr("1.2.3.4"), + Dst: netip.MustParseAddr("5.6.7.8"), + }, + DiscoKey: discoKey, + } + + pkt := make([]byte, update.Len()) + if err := update.Marshal(pkt); err != nil { + t.Fatal(err) + } + + var p Parsed + p.Decode(pkt) + + parsed, ok := p.AsTSMPDiscoKeyUpdate() + if !ok { + t.Fatal("failed to parse TSMP disco key update") + } + if parsed.DiscoKey != discoKey { + t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey) + } + }) + + // Test IPv6 + t.Run("IPv6", func(t *testing.T) { + update := TSMPDiscoKeyUpdate{ + IPHeader: IP6Header{ + IPProto: TSMP, + Src: netip.MustParseAddr("2001:db8::1"), + Dst: netip.MustParseAddr("2001:db8::2"), + }, + DiscoKey: discoKey, + } + + pkt := make([]byte, update.Len()) + if err := update.Marshal(pkt); err != nil { + t.Fatal(err) + } + + var p Parsed + p.Decode(pkt) + + parsed, ok := p.AsTSMPDiscoKeyUpdate() + if !ok { + t.Fatal("failed to parse TSMP disco key update") + } + if parsed.DiscoKey != discoKey { + t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey) + } + }) +} diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index db4f689bf..234dc4941 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -188,11 +188,19 @@ type Wrapper struct { // OnTSMPPongReceived, if non-nil, is called whenever a TSMP pong arrives. OnTSMPPongReceived func(packet.TSMPPongReply) + // OnTSMPDiscoKeyReceived, if non-nil, is called whenever a TSMP disco key update arrives. + // The srcIP parameter identifies the peer that sent the update. + OnTSMPDiscoKeyReceived func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) + // OnICMPEchoResponseReceived, if non-nil, is called whenever a ICMP echo response // arrives. If the packet is to be handled internally this returns true, // false otherwise. OnICMPEchoResponseReceived func(*packet.Parsed) bool + // GetDiscoPublicKey, if non-nil, returns the local node's disco public key. + // This is called when responding to TSMP disco key requests. + GetDiscoPublicKey func() key.DiscoPublic + // PeerAPIPort, if non-nil, returns the peerapi port that's // running for the given IP address. PeerAPIPort func(netip.Addr) (port uint16, ok bool) @@ -1132,6 +1140,15 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa if f := t.OnTSMPPongReceived; f != nil { f(data) } + } else if _, ok := p.AsTSMPDiscoKeyRequest(); ok { + t.noteActivity() + t.injectOutboundDiscoKeyUpdate(p) + return filter.DropSilently, gro + } else if discoKeyUpdate, ok := p.AsTSMPDiscoKeyUpdate(); ok { + if f := t.OnTSMPDiscoKeyReceived; f != nil { + f(p.Src.Addr(), discoKeyUpdate) + } + return filter.DropSilently, gro } } @@ -1440,6 +1457,36 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque t.InjectOutbound(packet.Generate(pong, nil)) } +func (t *Wrapper) injectOutboundDiscoKeyUpdate(pp *packet.Parsed) { + if t.GetDiscoPublicKey == nil { + return + } + + discoKey := t.GetDiscoPublicKey() + if discoKey.IsZero() { + return + } + + update := packet.TSMPDiscoKeyUpdate{ + DiscoKey: discoKey.Raw32(), + } + + switch pp.IPVersion { + case 4: + h4 := pp.IP4Header() + h4.ToResponse() + update.IPHeader = h4 + case 6: + h6 := pp.IP6Header() + h6.ToResponse() + update.IPHeader = h6 + default: + return + } + + t.InjectOutbound(packet.Generate(update, nil)) +} + // InjectOutbound makes the Wrapper device behave as if a packet // with the given contents was sent to the network. // It does not block, but takes ownership of the packet. diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index 37a4f1a64..8896d4009 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -721,6 +721,16 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, dm.n, true) } + // Request disco key from peer via TSMP if we receive a WireGuard handshake + // over DERP without recent disco success. This handles the "WireGuard-first" + // case where WireGuard establishes a tunnel via DERP before disco succeeds + // (e.g., control plane unreachable or stale disco keys). + // We only trigger on data packets (not handshake packets) because the tunnel + // must be fully established before we can send TSMP requests through it. + if looksLikeWireGuardHandshake(b[:n]) && n > 0 { + go c.requestDiscoKeyViaTSMP(dm.src, ep) + } + c.metrics.inboundPacketsDERPTotal.Add(1) c.metrics.inboundBytesDERPTotal.Add(int64(n)) return n, ep diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 064838a2d..ae5ea3daa 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -155,17 +155,18 @@ type Conn struct { // This block mirrors the contents and field order of the Options // struct. Initialized once at construction, then constant. - eventBus *eventbus.Bus - eventClient *eventbus.Client - logf logger.Logf - epFunc func([]tailcfg.Endpoint) - derpActiveFunc func() - idleFunc func() time.Duration // nil means unknown - testOnlyPacketListener nettype.PacketListener - noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity - netMon *netmon.Monitor // must be non-nil - health *health.Tracker // or nil - controlKnobs *controlknobs.Knobs // or nil + eventBus *eventbus.Bus + eventClient *eventbus.Client + logf logger.Logf + epFunc func([]tailcfg.Endpoint) + derpActiveFunc func() + idleFunc func() time.Duration // nil means unknown + testOnlyPacketListener nettype.PacketListener + noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity + sendTSMPDiscoKeyRequest func(netip.Addr) error // or nil, sends TSMP disco key request to peer + netMon *netmon.Monitor // must be non-nil + health *health.Tracker // or nil + controlKnobs *controlknobs.Knobs // or nil // ================================================================ // No locking required to access these fields, either because @@ -1800,6 +1801,15 @@ func looksLikeInitiationMsg(b []byte) bool { binary.LittleEndian.Uint32(b) == device.MessageInitiationType } +func looksLikeWireGuardHandshake(b []byte) bool { + if len(b) < 4 { + return false + } + msgType := binary.LittleEndian.Uint32(b) + return (len(b) == device.MessageInitiationSize && msgType == device.MessageInitiationType) || + (len(b) == device.MessageResponseSize && msgType == device.MessageResponseType) +} + // receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6. // // size is the length of 'b' to report up to wireguard-go (only relevant if @@ -2857,6 +2867,14 @@ func (c *Conn) SetSilentDisco(v bool) { }) } +// SetSendTSMPDiscoKeyRequest sets the callback function to send TSMP disco key requests. +// This is provided by the engine/tundev to inject TSMP packets. +func (c *Conn) SetSendTSMPDiscoKeyRequest(fn func(netip.Addr) error) { + c.mu.Lock() + defer c.mu.Unlock() + c.sendTSMPDiscoKeyRequest = fn +} + // SilentDisco returns true if silent disco is enabled, otherwise false. func (c *Conn) SilentDisco() bool { c.mu.Lock() @@ -4104,6 +4122,13 @@ var ( metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff") metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff") metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff") + + // TSMP disco key exchange + metricTSMPDiscoKeyRequestSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_sent") + metricTSMPDiscoKeyRequestError = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_error") + metricTSMPDiscoKeyUpdateReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_received") + metricTSMPDiscoKeyUpdateApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_applied") + metricTSMPDiscoKeyUpdateUnknown = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_unknown_peer") ) // newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided @@ -4242,6 +4267,101 @@ func (le *lazyEndpoint) FromPeer(peerPublicKey [32]byte) { // See http://go/corp/29422 & http://go/corp/30042 le.c.peerMap.setNodeKeyForEpAddr(le.src, pubKey) le.c.logf("magicsock: lazyEndpoint.FromPeer(%v) setting epAddr(%v) in peerMap for node(%v)", pubKey.ShortString(), le.src, ep.nodeAddr) + + // Request disco key from peer via TSMP if we establish a tunnel + // without a recent disco ping. This handles cases where WireGuard + // establishes a tunnel before disco succeeds (e.g., control plane + // unreachable or stale disco keys). + go le.c.requestDiscoKeyViaTSMP(pubKey, ep) +} + +// requestDiscoKeyViaTSMP sends a TSMP disco key request to a peer if there +// hasn't been a recent disco ping. +func (c *Conn) requestDiscoKeyViaTSMP(nodeKey key.NodePublic, ep *endpoint) { + if c.sendTSMPDiscoKeyRequest == nil { + return + } + if !ep.nodeAddr.IsValid() { + return + } + + epDisco := ep.disco.Load() + if epDisco != nil { + c.mu.Lock() + di := c.discoInfo[epDisco.key] + recentDiscoPing := di != nil && time.Since(di.lastPingTime) < discoPingInterval + c.mu.Unlock() + + if recentDiscoPing { + return + } + } + // YUCK. once again goroutines fight back - we need to deterministically + // schedule this _after_ the wireguard handshake response or else we trigger + // the wireguard handshake race problem. Maybe it's ok though, as we should + // really be singleflighting this, and perhaps we just use a singleflight + // with a short cork. + time.Sleep(time.Millisecond) + + c.logf("magicsock: sending TSMP disco key request to %v (%v)", nodeKey.ShortString(), ep.nodeAddr) + if err := c.sendTSMPDiscoKeyRequest(ep.nodeAddr); err != nil { + c.logf("magicsock: failed to send TSMP disco key request: %v", err) + metricTSMPDiscoKeyRequestError.Add(1) + return + } + metricTSMPDiscoKeyRequestSent.Add(1) +} + +// HandleDiscoKeyUpdate processes a TSMP disco key update. +// The update may be solicited (in response to a request) or unsolicited. +// srcIP is the Tailscale IP of the peer that sent the update. +func (c *Conn) HandleDiscoKeyUpdate(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) { + discoKey := key.DiscoPublicFromRaw32(mem.B(update.DiscoKey[:])) + c.logf("magicsock: received disco key update %v from %v", discoKey.ShortString(), srcIP) + metricTSMPDiscoKeyUpdateReceived.Add(1) + + c.mu.Lock() + defer c.mu.Unlock() + + var nodeKey key.NodePublic + var found bool + for _, peer := range c.peers.All() { + for _, addr := range peer.Addresses().All() { + if addr.Addr() == srcIP { + nodeKey = peer.Key() + found = true + break + } + } + if found { + break + } + } + + if !found { + c.logf("magicsock: disco key update from unknown peer %v", srcIP) + metricTSMPDiscoKeyUpdateUnknown.Add(1) + return + } + + ep, ok := c.peerMap.endpointForNodeKey(nodeKey) + if !ok { + c.logf("magicsock: endpoint not found for node %v", nodeKey.ShortString()) + return + } + + oldDiscoKey := key.DiscoPublic{} + if epDisco := ep.disco.Load(); epDisco != nil { + oldDiscoKey = epDisco.key + } + c.discoInfoForKnownPeerLocked(discoKey) + ep.disco.Store(&endpointDisco{ + key: discoKey, + short: discoKey.ShortString(), + }) + c.peerMap.upsertEndpoint(ep, oldDiscoKey) + c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString()) + metricTSMPDiscoKeyUpdateApplied.Add(1) } // PeerRelays returns the current set of candidate peer relays. diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 7ae422906..59c8cfb9f 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -64,6 +64,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/nettype" "tailscale.com/types/ptr" + "tailscale.com/types/views" "tailscale.com/util/cibuild" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" @@ -4305,3 +4306,58 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) { keys = append(keys, newKey) } } + +func TestSendTSMPDiscoKeyRequest(t *testing.T) { + ep := &endpoint{ + nodeID: 1, + publicKey: key.NewNode().Public(), + nodeAddr: netip.MustParseAddr("100.64.0.1"), + } + discoKey := key.NewDisco().Public() + ep.disco.Store(&endpointDisco{ + key: discoKey, + short: discoKey.ShortString(), + }) + conn := newConn(t.Logf) + ep.c = conn + + tsmpRequestCalled := make(chan struct{}, 1) + var capturedIP netip.Addr + conn.sendTSMPDiscoKeyRequest = func(ip netip.Addr) error { + capturedIP = ip + tsmpRequestCalled <- struct{}{} + return nil + } + + conn.mu.Lock() + conn.peers = views.SliceOf([]tailcfg.NodeView{ + (&tailcfg.Node{ + Key: ep.publicKey, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + }, + }).View(), + }) + conn.mu.Unlock() + + var pubKey [32]byte + copy(pubKey[:], ep.publicKey.AppendTo(nil)) + conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) + + le := &lazyEndpoint{ + c: conn, + src: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")}, + } + + le.FromPeer(pubKey) + + select { + case <-tsmpRequestCalled: + if !capturedIP.IsValid() { + t.Error("TSMP request sent with invalid IP") + } + t.Logf("TSMP disco key request sent to %v", capturedIP) + case <-time.After(time.Second): + t.Error("TSMP disco key request was not sent") + } +} diff --git a/wgengine/magicsock/tsmp_disco_test.go b/wgengine/magicsock/tsmp_disco_test.go new file mode 100644 index 000000000..7a89e27ec --- /dev/null +++ b/wgengine/magicsock/tsmp_disco_test.go @@ -0,0 +1,291 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +package magicsock + +import ( + "net/netip" + "testing" + "time" + + "github.com/tailscale/wireguard-go/tun/tuntest" + "tailscale.com/net/netaddr" + "tailscale.com/net/packet" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/types/ipproto" + "tailscale.com/types/key" + "tailscale.com/types/netmap" + "tailscale.com/util/set" + "tailscale.com/wgengine/wgcfg/nmcfg" +) + +func TestTSMPDiscoKeyExchange(t *testing.T) { + tstest.ResourceCheck(t) + + // Set up DERP and STUN servers + derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1)) + defer cleanup() + + // Create two magicsock peers + m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap) + defer m1.Close() + m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap) + defer m2.Close() + + // Wire up TSMP hooks to enable disco key exchange + // This mimics what userspaceEngine does in wgengine/userspace.go + + // Hook 0: GetDiscoPublicKey - allows TSMP replies to include current disco key + m1.tsTun.GetDiscoPublicKey = m1.conn.DiscoPublicKey + m2.tsTun.GetDiscoPublicKey = m2.conn.DiscoPublicKey + + // Hook 1: OnTSMPDiscoKeyReceived - handle incoming TSMP disco key updates + m1.tsTun.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) { + t.Logf("m1: received TSMP disco key update from %v", srcIP) + m1.conn.HandleDiscoKeyUpdate(srcIP, update) + } + m2.tsTun.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) { + t.Logf("m2: received TSMP disco key update from %v", srcIP) + m2.conn.HandleDiscoKeyUpdate(srcIP, update) + } + + sendTSMPDiscoKeyRequest := func(dstIP netip.Addr) error { + var srcIP netip.Addr + var stack *magicStack + + switch dstIP { + case m1.IP(): + srcIP = m2.IP() + stack = m2 + t.Logf("m2: sending disco key request to m1") + case m2.IP(): + srcIP = m1.IP() + stack = m1 + t.Logf("m1: sending disco key request to m2") + } + + // equivalent to the implementation in userspace.Engine + iph := packet.IP4Header{ + IPProto: ipproto.TSMP, + Src: srcIP, + Dst: dstIP, + } + + var tsmpPayload [1]byte + tsmpPayload[0] = byte(packet.TSMPTypeDiscoKeyRequest) + + tsmpRequest := packet.Generate(iph, tsmpPayload[:]) + return stack.tsTun.InjectOutbound(tsmpRequest) + } + + // Hook 2: SetSendTSMPDiscoKeyRequest - send TSMP disco key requests + m1.conn.SetSendTSMPDiscoKeyRequest(sendTSMPDiscoKeyRequest) + m2.conn.SetSendTSMPDiscoKeyRequest(sendTSMPDiscoKeyRequest) + + // Get initial disco keys + disco1Original := m1.conn.DiscoPublicKey() + disco2 := m2.conn.DiscoPublicKey() + + t.Logf("m1: node=%v disco=%v", m1.Public().ShortString(), disco1Original.ShortString()) + t.Logf("m2: node=%v disco=%v", m2.Public().ShortString(), disco2.ShortString()) + + // Wait for initial endpoints + var eps1, eps2 []tailcfg.Endpoint + select { + case eps1 = <-m1.epCh: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for m1 endpoints") + } + select { + case eps2 = <-m2.epCh: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for m2 endpoints") + } + + // Build initial network maps and establish connection + nm1 := &netmap.NetworkMap{ + NodeKey: m1.Public(), + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)}, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 2, + Key: m2.Public(), + DiscoKey: disco2, + Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)}, + AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)}, + Endpoints: epFromTyped(eps2), + HomeDERP: 1, + }).View(), + }, + } + + nm2 := &netmap.NetworkMap{ + NodeKey: m2.Public(), + SelfNode: (&tailcfg.Node{ + Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)}, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + ID: 1, + Key: m1.Public(), + DiscoKey: disco1Original, + Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)}, + AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)}, + Endpoints: epFromTyped(eps1), + HomeDERP: 1, + }).View(), + }, + } + + cfg1, err := nmcfg.WGCfg(m1.privateKey, nm1, t.Logf, 0, "") + if err != nil { + t.Fatal(err) + } + cfg2, err := nmcfg.WGCfg(m2.privateKey, nm2, t.Logf, 0, "") + if err != nil { + t.Fatal(err) + } + + nv1 := NodeViewsUpdate{ + SelfNode: nm1.SelfNode, + Peers: nm1.Peers, + } + m1.conn.onNodeViewsUpdate(nv1) + + peerSet1 := set.Set[key.NodePublic]{} + peerSet1.Add(m2.Public()) + m1.conn.UpdatePeers(peerSet1) + + nv2 := NodeViewsUpdate{ + SelfNode: nm2.SelfNode, + Peers: nm2.Peers, + } + m2.conn.onNodeViewsUpdate(nv2) + + peerSet2 := set.Set[key.NodePublic]{} + peerSet2.Add(m1.Public()) + m2.conn.UpdatePeers(peerSet2) + + if err := m1.Reconfig(cfg1); err != nil { + t.Fatal(err) + } + if err := m2.Reconfig(cfg2); err != nil { + t.Fatal(err) + } + + t.Logf("=== INITIAL CONFIGURATION COMPLETE ===") + + // Start goroutines to drain TUN inbound channels so TSMP packets can be received + drainTun := func(name string, stack *magicStack) { + go func() { + for { + select { + case <-t.Context().Done(): + return + case pkt := <-stack.tun.Inbound: + var p packet.Parsed + p.Decode(pkt) + if p.IPProto == ipproto.TSMP { + t.Logf("%s: received TSMP packet on TUN inbound: %d bytes", name, len(pkt)) + } else if p.IPProto == ipproto.ICMPv4 { + t.Logf("%s: received ICMPv4 packet on TUN inbound: %d bytes", name, len(pkt)) + } else { + t.Logf("%s: received packet on TUN inbound: %d bytes, proto=%v", name, len(pkt), p.IPProto) + } + } + } + }() + } + drainTun("m1", m1) + drainTun("m2", m2) + + initialRequestsSent := metricTSMPDiscoKeyRequestSent.Value() + initialUpdatesReceived := metricTSMPDiscoKeyUpdateReceived.Value() + initialUpdatesApplied := metricTSMPDiscoKeyUpdateApplied.Value() + + t.Logf("Initial metrics: requests_sent=%d updates_received=%d updates_applied=%d", + initialRequestsSent, initialUpdatesReceived, initialUpdatesApplied) + + t.Logf("=== ROTATING m1's DISCO KEY ===") + m1.conn.RotateDiscoKey() + disco1New := m1.conn.DiscoPublicKey() + + if disco1Original.Compare(disco1New) == 0 { + t.Fatal("disco key failed to rotate") + } + t.Logf("Rotated: %v -> %v", disco1Original.ShortString(), disco1New.ShortString()) + + t.Logf("=== SENDING PACKETS TO TRIGGER TSMP EXCHANGE ===") + + ping1to2 := tuntest.Ping(netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("100.64.0.1")) + + // Send packets from m2 to m1 only - this will trigger m1's handshake initiation + // and when m2 receives the encrypted packet, it should trigger FromPeer -> TSMP + select { + case m1.tun.Outbound <- ping1to2: + default: + } + + for { + time.Sleep(time.Millisecond) + // Check if m2 has learned m1's new disco key + st := m2.Status() + if ps, ok := st.Peer[m1.Public()]; ok && ps.CurAddr != "" { + t.Logf("Connection established after disco key rotation") + t.Logf("m2 -> m1 via %v", ps.CurAddr) + t.Logf("Disco key rotation: %v -> %v", disco1Original.ShortString(), disco1New.ShortString()) + + // Verify TSMP metrics incremented + finalRequestsSent := metricTSMPDiscoKeyRequestSent.Value() + finalUpdatesReceived := metricTSMPDiscoKeyUpdateReceived.Value() + finalUpdatesApplied := metricTSMPDiscoKeyUpdateApplied.Value() + + t.Logf("Final metrics: requests_sent=%d updates_received=%d updates_applied=%d", + finalRequestsSent, finalUpdatesReceived, finalUpdatesApplied) + + // Check that at least one TSMP request was sent + if finalRequestsSent <= initialRequestsSent { + t.Errorf("Expected TSMP disco key request to be sent, but metric did not increment: %d -> %d", + initialRequestsSent, finalRequestsSent) + } else { + t.Logf("✓ TSMP disco key request sent (metric: %d -> %d)", + initialRequestsSent, finalRequestsSent) + } + + // Check that at least one TSMP update was received + if finalUpdatesReceived <= initialUpdatesReceived { + t.Errorf("Expected TSMP disco key update to be received, but metric did not increment: %d -> %d", + initialUpdatesReceived, finalUpdatesReceived) + } else { + t.Logf("✓ TSMP disco key update received (metric: %d -> %d)", + initialUpdatesReceived, finalUpdatesReceived) + } + + // Check that at least one TSMP update was applied + if finalUpdatesApplied <= initialUpdatesApplied { + t.Errorf("Expected TSMP disco key update to be applied, but metric did not increment: %d -> %d", + initialUpdatesApplied, finalUpdatesApplied) + } else { + t.Logf("✓ TSMP disco key update applied (metric: %d -> %d)", + initialUpdatesApplied, finalUpdatesApplied) + } + + // Verify error counter didn't increment + requestErrors := metricTSMPDiscoKeyRequestError.Value() + if requestErrors > 0 { + t.Logf("Warning: TSMP disco key request errors: %d", requestErrors) + } + + unknownPeers := metricTSMPDiscoKeyUpdateUnknown.Value() + if unknownPeers > 0 { + t.Logf("Warning: TSMP disco key updates from unknown peers: %d", unknownPeers) + } + + t.Logf("TSMP disco key exchange infrastructure is functional") + return + } + } +} diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 8ad771fc5..53d3bb0dc 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -466,6 +466,25 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) return true } + e.tundev.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) { + e.logf("wgengine: got TSMP disco key update from %v", srcIP) + if e.magicConn != nil { + e.magicConn.HandleDiscoKeyUpdate(srcIP, update) + } + } + + e.tundev.GetDiscoPublicKey = func() key.DiscoPublic { + if e.magicConn == nil { + return key.DiscoPublic{} + } + return e.magicConn.DiscoPublicKey() + } + + // Wire up TSMP disco key request sending to magicsock + if e.magicConn != nil { + e.magicConn.SetSendTSMPDiscoKeyRequest(e.sendTSMPDiscoKeyRequest) + } + // wgdev takes ownership of tundev, will close it when closed. e.logf("Creating WireGuard device...") e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger) @@ -1563,6 +1582,35 @@ func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPP } } +// sendTSMPDiscoKeyRequest sends a TSMP disco key request to the given peer IP. +func (e *userspaceEngine) sendTSMPDiscoKeyRequest(ip netip.Addr) error { + srcIP, err := e.mySelfIPMatchingFamily(ip) + if err != nil { + return err + } + + var iph packet.Header + if srcIP.Is4() { + iph = packet.IP4Header{ + IPProto: ipproto.TSMP, + Src: srcIP, + Dst: ip, + } + } else { + iph = packet.IP6Header{ + IPProto: ipproto.TSMP, + Src: srcIP, + Dst: ip, + } + } + + var tsmpPayload [1]byte + tsmpPayload[0] = byte(packet.TSMPTypeDiscoKeyRequest) + + tsmpRequest := packet.Generate(iph, tsmpPayload[:]) + return e.tundev.InjectOutbound(tsmpRequest) +} + func (e *userspaceEngine) setICMPEchoResponseCallback(idSeq uint32, cb func()) { e.mu.Lock() defer e.mu.Unlock()