net,wgengine: add support for disco key exchange via TSMP

Updates tailscale/corp#34037

Signed-off-by: James Tucker <james@tailscale.com>
raggi/disco-key-tsmp2
James Tucker 4 weeks ago
parent 9eff8a4503
commit 5bfa8e97f6
No known key found for this signature in database

@ -15,12 +15,10 @@ import (
"fmt"
"net/netip"
"go4.org/mem"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
)
const minTSMPSize = 7 // the rejected body is 7 bytes
const minTSMPSize = 1 // minimum is 1 byte for the type field (e.g., disco key request 'd')
// TailscaleRejectedHeader is a TSMP message that says that one
// Tailscale node has rejected the connection from another. Unlike a
@ -75,8 +73,11 @@ const (
// TSMPTypePong is the type byte for a TailscalePongResponse.
TSMPTypePong TSMPType = 'o'
// TSMPTypeDiscoAdvertisement is the type byte for sending disco keys.
TSMPTypeDiscoAdvertisement TSMPType = 'a'
// TSMPTypeDiscoKeyRequest is the type byte for a disco key request.
TSMPTypeDiscoKeyRequest TSMPType = 'd'
// TSMPTypeDiscoKeyUpdate is the type byte for a disco key update.
TSMPTypeDiscoKeyUpdate TSMPType = 'D'
)
type TailscaleRejectReason byte
@ -265,52 +266,62 @@ func (h TSMPPongReply) Marshal(buf []byte) error {
return nil
}
// TSMPDiscoKeyAdvertisement is a TSMP message that's used for distributing Disco Keys.
// TSMPDiscoKeyRequest is a TSMP message that requests a peer's disco key.
//
// On the wire, after the IP header, it's currently 33 bytes:
// - 'a' (TSMPTypeDiscoAdvertisement)
// - 32 disco key bytes
type TSMPDiscoKeyAdvertisement struct {
Src, Dst netip.Addr
Key key.DiscoPublic
}
// On the wire, after the IP header, it's currently 1 byte:
// - 'd' (TSMPTypeDiscoKeyRequest)
type TSMPDiscoKeyRequest struct{}
func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) {
var iph Header
if ka.Src.Is4() {
iph = IP4Header{
IPProto: ipproto.TSMP,
Src: ka.Src,
Dst: ka.Dst,
}
} else {
iph = IP6Header{
IPProto: ipproto.TSMP,
Src: ka.Src,
Dst: ka.Dst,
}
func (pp *Parsed) AsTSMPDiscoKeyRequest() (h TSMPDiscoKeyRequest, ok bool) {
if pp.IPProto != ipproto.TSMP {
return
}
payload := make([]byte, 0, 33)
payload = append(payload, byte(TSMPTypeDiscoAdvertisement))
payload = ka.Key.AppendTo(payload)
if len(payload) != 33 {
// Mostly to safeguard against ourselves changing this in the future.
return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload))
p := pp.Payload()
if len(p) < 1 || p[0] != byte(TSMPTypeDiscoKeyRequest) {
return
}
return h, true
}
return Generate(iph, payload), nil
// TSMPDiscoKeyUpdate is a TSMP message that contains a disco public key.
// It may be sent in response to a request, or unsolicited when a node
// believes its peer may have stale disco key information.
//
// On the wire, after the IP header, it's currently 33 bytes:
// - 'D' (TSMPTypeDiscoKeyUpdate)
// - 32 bytes disco public key
type TSMPDiscoKeyUpdate struct {
// IPHeader is the IP header that Marshal writes ahead of the TSMP payload.
// AsTSMPDiscoKeyUpdate leaves it unset when parsing a received packet.
IPHeader Header
DiscoKey [32]byte // raw disco public key bytes
}
func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok bool) {
// AsTSMPDiscoKeyUpdate returns pp as a TSMPDiscoKeyUpdate and whether it is one.
// The update.IPHeader field is not populated.
func (pp *Parsed) AsTSMPDiscoKeyUpdate() (update TSMPDiscoKeyUpdate, ok bool) {
if pp.IPProto != ipproto.TSMP {
return
}
p := pp.Payload()
if len(p) < 33 || p[0] != byte(TSMPTypeDiscoAdvertisement) {
if len(p) < 33 || p[0] != byte(TSMPTypeDiscoKeyUpdate) {
return
}
tka.Src = pp.Src.Addr()
tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33]))
copy(update.DiscoKey[:], p[1:33])
return update, true
}
return tka, true
// Len returns the number of bytes Marshal writes: the length of the IP
// header followed by the 33-byte TSMP payload.
func (h TSMPDiscoKeyUpdate) Len() int {
	// One type byte plus the 32-byte raw disco public key.
	const payloadLen = 33
	hdrLen := h.IPHeader.Len()
	return hdrLen + payloadLen
}
// Marshal writes the IP header followed by the TSMP disco key update payload
// into buf. It returns errSmallBuffer if buf is shorter than h.Len(), or the
// error reported by the IP header's own Marshal.
func (h TSMPDiscoKeyUpdate) Marshal(buf []byte) error {
	if h.Len() > len(buf) {
		return errSmallBuffer
	}
	err := h.IPHeader.Marshal(buf)
	if err != nil {
		return err
	}
	// The TSMP payload starts right after the IP header: the type byte,
	// then the 32 key bytes.
	payload := buf[h.IPHeader.Len():]
	payload[0] = byte(TSMPTypeDiscoKeyUpdate)
	copy(payload[1:33], h.DiscoKey[:])
	return nil
}

@ -4,14 +4,8 @@
package packet
import (
"bytes"
"encoding/hex"
"net/netip"
"slices"
"testing"
"go4.org/mem"
"tailscale.com/types/key"
)
func TestTailscaleRejectedHeader(t *testing.T) {
@ -78,61 +72,168 @@ func TestTailscaleRejectedHeader(t *testing.T) {
}
}
func TestTSMPDiscoKeyAdvertisementMarshal(t *testing.T) {
var (
// IPv4: Ver(4)Len(5), TOS, Len(53), ID, Flags, TTL(64), Proto(99), Cksum
headerV4, _ = hex.DecodeString("45000035000000004063705d")
// IPv6: Ver(6)TCFlow, Len(33), NextHdr(99), HopLim(64)
headerV6, _ = hex.DecodeString("6000000000216340")
packetType = []byte{'a'}
testKey = bytes.Repeat([]byte{'a'}, 32)
// IPs
srcV4 = netip.MustParseAddr("1.2.3.4")
dstV4 = netip.MustParseAddr("4.3.2.1")
srcV6 = netip.MustParseAddr("2001:db8::1")
dstV6 = netip.MustParseAddr("2001:db8::2")
)
join := func(parts ...[]byte) []byte {
return bytes.Join(parts, nil)
func TestTSMPDiscoKeyRequest(t *testing.T) {
t.Run("Manual", func(t *testing.T) {
var payload [1]byte
payload[0] = byte(TSMPTypeDiscoKeyRequest)
var p Parsed
p.IPProto = TSMP
p.dataofs = 40 // simulate after IP header
buf := make([]byte, 40+1)
copy(buf[40:], payload[:])
p.b = buf
p.length = len(buf)
_, ok := p.AsTSMPDiscoKeyRequest()
if !ok {
t.Fatal("failed to parse TSMP disco key request")
}
})
t.Run("RoundTripIPv4", func(t *testing.T) {
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
iph := IP4Header{
IPProto: TSMP,
Src: src,
Dst: dst,
}
var payload [1]byte
payload[0] = byte(TSMPTypeDiscoKeyRequest)
pkt := Generate(iph, payload[:])
t.Logf("Generated packet: %d bytes, hex: %x", len(pkt), pkt)
// Manually check what decode4 would see
if len(pkt) >= 4 {
declaredLen := int(uint16(pkt[2])<<8 | uint16(pkt[3]))
t.Logf("Packet buffer length: %d, IP header declares length: %d", len(pkt), declaredLen)
t.Logf("Protocol byte at [9]: 0x%02x = %d", pkt[9], pkt[9])
}
var p Parsed
p.Decode(pkt)
t.Logf("Decoded: IPVersion=%d IPProto=%v Src=%v Dst=%v length=%d dataofs=%d",
p.IPVersion, p.IPProto, p.Src, p.Dst, p.length, p.dataofs)
if p.IPVersion != 4 {
t.Errorf("IPVersion = %d, want 4", p.IPVersion)
}
if p.IPProto != TSMP {
t.Errorf("IPProto = %v, want TSMP", p.IPProto)
}
if p.Src.Addr() != src {
t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
}
if p.Dst.Addr() != dst {
t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
}
_, ok := p.AsTSMPDiscoKeyRequest()
if !ok {
t.Fatal("failed to parse TSMP disco key request from generated packet")
}
})
t.Run("RoundTripIPv6", func(t *testing.T) {
src := netip.MustParseAddr("2001:db8::1")
dst := netip.MustParseAddr("2001:db8::2")
iph := IP6Header{
IPProto: TSMP,
Src: src,
Dst: dst,
}
var payload [1]byte
payload[0] = byte(TSMPTypeDiscoKeyRequest)
pkt := Generate(iph, payload[:])
t.Logf("Generated packet: %d bytes", len(pkt))
var p Parsed
p.Decode(pkt)
if p.IPVersion != 6 {
t.Errorf("IPVersion = %d, want 6", p.IPVersion)
}
if p.IPProto != TSMP {
t.Errorf("IPProto = %v, want TSMP", p.IPProto)
}
if p.Src.Addr() != src {
t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
}
if p.Dst.Addr() != dst {
t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
}
_, ok := p.AsTSMPDiscoKeyRequest()
if !ok {
t.Fatal("failed to parse TSMP disco key request from generated packet")
}
})
}
func TestTSMPDiscoKeyUpdate(t *testing.T) {
var discoKey [32]byte
for i := range discoKey {
discoKey[i] = byte(i + 10)
}
tests := []struct {
name string
tka TSMPDiscoKeyAdvertisement
want []byte
}{
{
name: "v4Header",
tka: TSMPDiscoKeyAdvertisement{
Src: srcV4,
Dst: dstV4,
Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
t.Run("IPv4", func(t *testing.T) {
update := TSMPDiscoKeyUpdate{
IPHeader: IP4Header{
IPProto: TSMP,
Src: netip.MustParseAddr("1.2.3.4"),
Dst: netip.MustParseAddr("5.6.7.8"),
},
want: join(headerV4, srcV4.AsSlice(), dstV4.AsSlice(), packetType, testKey),
},
{
name: "v6Header",
tka: TSMPDiscoKeyAdvertisement{
Src: srcV6,
Dst: dstV6,
Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
DiscoKey: discoKey,
}
pkt := make([]byte, update.Len())
if err := update.Marshal(pkt); err != nil {
t.Fatal(err)
}
var p Parsed
p.Decode(pkt)
parsed, ok := p.AsTSMPDiscoKeyUpdate()
if !ok {
t.Fatal("failed to parse TSMP disco key update")
}
if parsed.DiscoKey != discoKey {
t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
}
})
t.Run("IPv6", func(t *testing.T) {
update := TSMPDiscoKeyUpdate{
IPHeader: IP6Header{
IPProto: TSMP,
Src: netip.MustParseAddr("2001:db8::1"),
Dst: netip.MustParseAddr("2001:db8::2"),
},
want: join(headerV6, srcV6.AsSlice(), dstV6.AsSlice(), packetType, testKey),
},
}
DiscoKey: discoKey,
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.tka.Marshal()
if err != nil {
t.Errorf("error mashalling TSMPDiscoAdvertisement: %s", err)
}
if !slices.Equal(got, tt.want) {
t.Errorf("error mashalling TSMPDiscoAdvertisement, expected: \n%x, \ngot:\n%x", tt.want, got)
}
})
}
pkt := make([]byte, update.Len())
if err := update.Marshal(pkt); err != nil {
t.Fatal(err)
}
var p Parsed
p.Decode(pkt)
parsed, ok := p.AsTSMPDiscoKeyUpdate()
if !ok {
t.Fatal("failed to parse TSMP disco key update")
}
if parsed.DiscoKey != discoKey {
t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
}
})
}

@ -194,6 +194,10 @@ type Wrapper struct {
// false otherwise.
OnICMPEchoResponseReceived func(*packet.Parsed) bool
// GetDiscoPublicKey, if non-nil, returns the local node's disco public key.
// This is called when responding to TSMP disco key requests.
GetDiscoPublicKey func() key.DiscoPublic
// PeerAPIPort, if non-nil, returns the peerapi port that's
// running for the given IP address.
PeerAPIPort func(netip.Addr) (port uint16, ok bool)
@ -211,8 +215,8 @@ type Wrapper struct {
metrics *metrics
eventClient *eventbus.Client
discoKeyAdvertisementPub *eventbus.Publisher[DiscoKeyAdvertisement]
eventClient *eventbus.Client
discoKeyUpdatePub *eventbus.Publisher[DiscoKeyUpdate]
}
type metrics struct {
@ -227,6 +231,12 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
}
}
// DiscoKeyUpdate is published on the event bus when a TSMP disco key update is received.
type DiscoKeyUpdate struct {
// SrcIP is the source address of the packet that carried the update.
SrcIP netip.Addr
// Key is the raw 32-byte disco public key taken from the update payload.
Key [32]byte
}
// tunInjectedRead is an injected packet pretending to be a tun.Read().
type tunInjectedRead struct {
// Only one of packet or data should be set, and are read in that order of
@ -288,7 +298,7 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry,
}
w.eventClient = bus.Client("net.tstun")
w.discoKeyAdvertisementPub = eventbus.Publish[DiscoKeyAdvertisement](w.eventClient)
w.discoKeyUpdatePub = eventbus.Publish[DiscoKeyUpdate](w.eventClient)
w.vectorBuffer = make([][]byte, tdev.BatchSize())
for i := range w.vectorBuffer {
@ -1126,11 +1136,6 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i
return n, err
}
type DiscoKeyAdvertisement struct {
Src netip.Addr
Key key.DiscoPublic
}
func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) {
if captHook != nil {
captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta)
@ -1141,16 +1146,21 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa
t.noteActivity()
t.injectOutboundPong(p, pingReq)
return filter.DropSilently, gro
} else if discoKeyAdvert, ok := p.AsTSMPDiscoAdvertisement(); ok {
t.discoKeyAdvertisementPub.Publish(DiscoKeyAdvertisement{
Src: discoKeyAdvert.Src,
Key: discoKeyAdvert.Key,
})
return filter.DropSilently, gro
} else if data, ok := p.AsTSMPPong(); ok {
if f := t.OnTSMPPongReceived; f != nil {
f(data)
}
} else if _, ok := p.AsTSMPDiscoKeyRequest(); ok {
t.noteActivity()
t.injectOutboundDiscoKeyUpdate(p)
return filter.DropSilently, gro
} else if discoKeyUpdate, ok := p.AsTSMPDiscoKeyUpdate(); ok {
// Publish to eventbus for subscribers
t.discoKeyUpdatePub.Publish(DiscoKeyUpdate{
SrcIP: p.Src.Addr(),
Key: discoKeyUpdate.DiscoKey,
})
return filter.DropSilently, gro
}
}
@ -1459,6 +1469,36 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque
t.InjectOutbound(packet.Generate(pong, nil))
}
// injectOutboundDiscoKeyUpdate answers the packet pp with a TSMP disco key
// update carrying the key returned by t.GetDiscoPublicKey. The reply's IP
// header is pp's own header turned into a response. Nothing is sent when
// GetDiscoPublicKey is nil, when it returns a zero key, or when pp is
// neither IPv4 nor IPv6.
func (t *Wrapper) injectOutboundDiscoKeyUpdate(pp *packet.Parsed) {
	getKey := t.GetDiscoPublicKey
	if getKey == nil {
		return
	}
	pub := getKey()
	if pub.IsZero() {
		return
	}

	reply := packet.TSMPDiscoKeyUpdate{DiscoKey: pub.Raw32()}
	if pp.IPVersion == 4 {
		hdr := pp.IP4Header()
		hdr.ToResponse()
		reply.IPHeader = hdr
	} else if pp.IPVersion == 6 {
		hdr := pp.IP6Header()
		hdr.ToResponse()
		reply.IPHeader = hdr
	} else {
		return
	}

	t.InjectOutbound(packet.Generate(reply, nil))
}
// InjectOutbound makes the Wrapper device behave as if a packet
// with the given contents was sent to the network.
// It does not block, but takes ownership of the packet.

@ -966,28 +966,57 @@ func TestCaptureHook(t *testing.T) {
}
func TestTSMPDisco(t *testing.T) {
t.Run("IPv6DiscoAdvert", func(t *testing.T) {
t.Run("DiscoKeyRequest", func(t *testing.T) {
src := netip.MustParseAddr("2001:db8::1")
dst := netip.MustParseAddr("2001:db8::2")
discoKey := key.NewDisco()
buf, _ := (&packet.TSMPDiscoKeyAdvertisement{
Src: src,
Dst: dst,
Key: discoKey.Public(),
}).Marshal()
iph := packet.IP6Header{
IPProto: ipproto.TSMP,
Src: src,
Dst: dst,
}
var payload [1]byte
payload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
buf := packet.Generate(iph, payload[:])
var p packet.Parsed
p.Decode(buf)
tda, ok := p.AsTSMPDiscoAdvertisement()
_, ok := p.AsTSMPDiscoKeyRequest()
if !ok {
t.Error("Unable to parse message as TSMPDiscoAdversitement")
t.Error("Unable to parse message as TSMPDiscoKeyRequest")
}
if tda.Src != src {
t.Errorf("Src address did not match, expected %v, got %v", src, tda.Src)
})
t.Run("DiscoKeyUpdate", func(t *testing.T) {
src := netip.MustParseAddr("2001:db8::1")
dst := netip.MustParseAddr("2001:db8::2")
discoKey := key.NewDisco()
update := packet.TSMPDiscoKeyUpdate{
IPHeader: packet.IP6Header{
IPProto: ipproto.TSMP,
Src: src,
Dst: dst,
},
DiscoKey: discoKey.Public().Raw32(),
}
buf := make([]byte, update.Len())
if err := update.Marshal(buf); err != nil {
t.Fatal(err)
}
var p packet.Parsed
p.Decode(buf)
parsed, ok := p.AsTSMPDiscoKeyUpdate()
if !ok {
t.Error("Unable to parse message as TSMPDiscoKeyUpdate")
}
if !reflect.DeepEqual(tda.Key, discoKey.Public()) {
t.Errorf("Key did not match, expected %q, got %q", discoKey.Public(), tda.Key)
if parsed.DiscoKey != discoKey.Public().Raw32() {
t.Errorf("Key did not match, expected %v, got %v", discoKey.Public().Raw32(), parsed.DiscoKey)
}
})
}

@ -721,6 +721,16 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, dm.n, true)
}
// Request disco key from peer via TSMP if we receive a WireGuard handshake
// over DERP without recent disco success. This handles the "WireGuard-first"
// case where WireGuard establishes a tunnel via DERP before disco succeeds
// (e.g., control plane unreachable or stale disco keys).
if looksLikeWireGuardHandshake(b[:n]) && n > 0 {
c.mu.Lock()
c.requestDiscoKeyViaTSMPLocked(dm.src, ep)
c.mu.Unlock()
}
c.metrics.inboundPacketsDERPTotal.Add(1)
c.metrics.inboundBytesDERPTotal.Add(int64(n))
return n, ep

@ -178,9 +178,10 @@ type Conn struct {
// A publisher for synchronization points to ensure correct ordering of
// config changes between magicsock and wireguard.
syncPub *eventbus.Publisher[syncPoint]
allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq]
portUpdatePub *eventbus.Publisher[router.PortUpdate]
syncPub *eventbus.Publisher[syncPoint]
allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq]
portUpdatePub *eventbus.Publisher[router.PortUpdate]
tsmpDiscoKeyRequestPub *eventbus.Publisher[TSMPDiscoKeyRequest]
// pconn4 and pconn6 are the underlying UDP sockets used to
// send/receive packets for wireguard and other magicsock
@ -572,6 +573,14 @@ type UDPRelayAllocReq struct {
Message *disco.AllocateUDPRelayEndpointRequest
}
// TSMPDiscoKeyRequest is published on the event bus when magicsock needs to
// send a TSMP disco key request to a peer. Subscribers should inject the
// TSMP packet into the tunnel device.
type TSMPDiscoKeyRequest struct {
// DstIP is the address of the peer the request should be sent to.
DstIP netip.Addr
// MetricSent is a counter that the subscriber increments by one for a
// request it sends.
MetricSent *clientmetric.Metric
}
// UDPRelayAllocResp represents a [*disco.AllocateUDPRelayEndpointResponse]
// that is yet to be transmitted over DERP (or delivered locally if
// ReqRxFromNodeKey is self). This is signaled over an [eventbus.Bus] from
@ -691,6 +700,7 @@ func NewConn(opts Options) (*Conn, error) {
c.syncPub = eventbus.Publish[syncPoint](ec)
c.allocRelayEndpointPub = eventbus.Publish[UDPRelayAllocReq](ec)
c.portUpdatePub = eventbus.Publish[router.PortUpdate](ec)
c.tsmpDiscoKeyRequestPub = eventbus.Publish[TSMPDiscoKeyRequest](ec)
eventbus.SubscribeFunc(ec, c.onPortMapChanged)
eventbus.SubscribeFunc(ec, c.onFilterUpdate)
eventbus.SubscribeFunc(ec, c.onNodeViewsUpdate)
@ -1800,6 +1810,15 @@ func looksLikeInitiationMsg(b []byte) bool {
binary.LittleEndian.Uint32(b) == device.MessageInitiationType
}
// looksLikeWireGuardHandshake reports whether b has both the message type and
// the exact length of a WireGuard handshake initiation or handshake response.
// The type is read as a little-endian uint32 from the first four bytes.
func looksLikeWireGuardHandshake(b []byte) bool {
	if len(b) < 4 {
		return false
	}
	switch binary.LittleEndian.Uint32(b) {
	case device.MessageInitiationType:
		return len(b) == device.MessageInitiationSize
	case device.MessageResponseType:
		return len(b) == device.MessageResponseSize
	}
	return false
}
// receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6.
//
// size is the length of 'b' to report up to wireguard-go (only relevant if
@ -4104,6 +4123,12 @@ var (
metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff")
metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff")
metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff")
// TSMP disco key exchange
metricTSMPDiscoKeyRequestSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_sent")
metricTSMPDiscoKeyUpdateReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_received")
metricTSMPDiscoKeyUpdateApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_applied")
metricTSMPDiscoKeyUpdateUnknown = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_unknown_peer")
)
// newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided
@ -4242,6 +4267,81 @@ func (le *lazyEndpoint) FromPeer(peerPublicKey [32]byte) {
// See http://go/corp/29422 & http://go/corp/30042
le.c.peerMap.setNodeKeyForEpAddr(le.src, pubKey)
le.c.logf("magicsock: lazyEndpoint.FromPeer(%v) setting epAddr(%v) in peerMap for node(%v)", pubKey.ShortString(), le.src, ep.nodeAddr)
le.c.requestDiscoKeyViaTSMPLocked(pubKey, ep)
}
// requestDiscoKeyViaTSMPLocked publishes a TSMPDiscoKeyRequest for ep's node
// address, unless ep has no valid node address or the discoInfo recorded for
// ep's current disco key shows a ping within discoPingInterval.
//
// The publish runs on its own goroutine (presumably so the caller does not
// block on the event bus while holding c.mu — confirm). nodeKey is currently
// unused.
//
// c.mu must be held.
func (c *Conn) requestDiscoKeyViaTSMPLocked(nodeKey key.NodePublic, ep *endpoint) {
	dst := ep.nodeAddr
	if !dst.IsValid() {
		return
	}

	if d := ep.disco.Load(); d != nil {
		if info := c.discoInfo[d.key]; info != nil && time.Since(info.lastPingTime) < discoPingInterval {
			// Disco pinged recently; no need to ask for the key.
			return
		}
	}

	req := TSMPDiscoKeyRequest{DstIP: dst, MetricSent: metricTSMPDiscoKeyRequestSent}
	go c.tsmpDiscoKeyRequestPub.Publish(req)
}
// HandleDiscoKeyUpdate processes a TSMP disco key update.
// The update may be solicited (in response to a request) or unsolicited.
// srcIP is the Tailscale IP of the peer that sent the update.
//
// The peer is identified by finding the node in c.peers that has srcIP among
// its addresses. The update is dropped if:
//   - the advertised key is the zero key,
//   - no peer owns srcIP,
//   - the peer has no endpoint in the peer map, or
//   - the advertised key equals the key the endpoint already has.
//
// Otherwise the endpoint's disco key is replaced and the peer map is
// re-indexed under the new key.
func (c *Conn) HandleDiscoKeyUpdate(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
	discoKey := key.DiscoPublicFromRaw32(mem.B(update.DiscoKey[:]))
	c.logf("magicsock: received disco key update %v from %v", discoKey.ShortString(), srcIP)
	metricTSMPDiscoKeyUpdateReceived.Add(1)

	// The key comes straight off the wire. The sender refuses to advertise a
	// zero key, but the receiver must not rely on that: installing the zero
	// key would leave the endpoint with an unusable disco identity.
	if discoKey.IsZero() {
		c.logf("magicsock: ignoring zero disco key update from %v", srcIP)
		return
	}

	c.mu.Lock()
	defer c.mu.Unlock()

	var nodeKey key.NodePublic
	var found bool
peers:
	for _, peer := range c.peers.All() {
		for _, addr := range peer.Addresses().All() {
			if addr.Addr() == srcIP {
				nodeKey = peer.Key()
				found = true
				break peers
			}
		}
	}
	if !found {
		c.logf("magicsock: disco key update from unknown peer %v", srcIP)
		metricTSMPDiscoKeyUpdateUnknown.Add(1)
		return
	}

	ep, ok := c.peerMap.endpointForNodeKey(nodeKey)
	if !ok {
		c.logf("magicsock: endpoint not found for node %v", nodeKey.ShortString())
		return
	}

	var oldDiscoKey key.DiscoPublic
	if epDisco := ep.disco.Load(); epDisco != nil {
		oldDiscoKey = epDisco.key
	}
	// A repeated advertisement of the key we already hold is not a change:
	// leave the endpoint's existing disco state alone and do not count it as
	// applied.
	if oldDiscoKey == discoKey {
		return
	}

	c.discoInfoForKnownPeerLocked(discoKey)
	ep.disco.Store(&endpointDisco{
		key:   discoKey,
		short: discoKey.ShortString(),
	})
	c.peerMap.upsertEndpoint(ep, oldDiscoKey)
	c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString())
	metricTSMPDiscoKeyUpdateApplied.Add(1)
}
// PeerRelays returns the current set of candidate peer relays.

@ -64,6 +64,7 @@ import (
"tailscale.com/types/netmap"
"tailscale.com/types/nettype"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
"tailscale.com/util/cibuild"
"tailscale.com/util/clientmetric"
"tailscale.com/util/eventbus"
@ -4302,3 +4303,66 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) {
keys = append(keys, newKey)
}
}
// TestSendTSMPDiscoKeyRequest checks that calling lazyEndpoint.FromPeer for a
// known peer endpoint results in a TSMPDiscoKeyRequest being published on the
// event bus within one second.
func TestSendTSMPDiscoKeyRequest(t *testing.T) {
ep := &endpoint{
nodeID: 1,
publicKey: key.NewNode().Public(),
nodeAddr: netip.MustParseAddr("100.64.0.1"),
}
discoKey := key.NewDisco().Public()
ep.disco.Store(&endpointDisco{
key: discoKey,
short: discoKey.ShortString(),
})
// Wire a bare Conn to a test bus and give it the publisher under test.
bus := eventbustest.NewBus(t)
conn := newConn(t.Logf)
conn.eventBus = bus
conn.eventClient = bus.Client("magicsock.Conn.test")
conn.tsmpDiscoKeyRequestPub = eventbus.Publish[TSMPDiscoKeyRequest](conn.eventClient)
ep.c = conn
// Buffered so the subscriber callback never blocks on the signal.
tsmpRequestCalled := make(chan struct{}, 1)
var capturedIP netip.Addr
ec := bus.Client("test")
defer ec.Close()
eventbus.SubscribeFunc(ec, func(req TSMPDiscoKeyRequest) {
// capturedIP is written before the channel send and read only after the
// receive below.
capturedIP = req.DstIP
if req.MetricSent != nil {
req.MetricSent.Add(1)
}
tsmpRequestCalled <- struct{}{}
})
conn.mu.Lock()
conn.peers = views.SliceOf([]tailcfg.NodeView{
(&tailcfg.Node{
Key: ep.publicKey,
Addresses: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
},
}).View(),
})
conn.mu.Unlock()
// FromPeer takes the raw 32-byte node public key.
var pubKey [32]byte
copy(pubKey[:], ep.publicKey.AppendTo(nil))
conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{})
le := &lazyEndpoint{
c: conn,
src: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")},
}
le.FromPeer(pubKey)
select {
case <-tsmpRequestCalled:
if !capturedIP.IsValid() {
t.Error("TSMP request sent with invalid IP")
}
t.Logf("TSMP disco key request sent to %v", capturedIP)
case <-time.After(time.Second):
t.Error("TSMP disco key request was not sent")
}
}

@ -54,6 +54,7 @@ import (
"tailscale.com/util/execqueue"
"tailscale.com/util/mak"
"tailscale.com/util/set"
"tailscale.com/util/singleflight"
"tailscale.com/util/testenv"
"tailscale.com/util/usermetric"
"tailscale.com/version"
@ -469,6 +470,13 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
return true
}
e.tundev.GetDiscoPublicKey = func() key.DiscoPublic {
if e.magicConn == nil {
return key.DiscoPublic{}
}
return e.magicConn.DiscoPublicKey()
}
// wgdev takes ownership of tundev, will close it when closed.
e.logf("Creating WireGuard device...")
e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger)
@ -549,6 +557,36 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
}
e.linkChangeQueue.Add(func() { e.linkChange(&cd) })
})
eventbus.SubscribeFunc(ec, func(update tstun.DiscoKeyUpdate) {
e.logf("wgengine: got TSMP disco key update from %v via eventbus", update.SrcIP)
if e.magicConn != nil {
pkt := packet.TSMPDiscoKeyUpdate{
DiscoKey: update.Key,
}
e.magicConn.HandleDiscoKeyUpdate(update.SrcIP, pkt)
}
})
var tsmpRequestGroup singleflight.Group[netip.Addr, struct{}]
eventbus.SubscribeFunc(ec, func(req magicsock.TSMPDiscoKeyRequest) {
	// Run off the subscriber callback so the sleep below does not delay it;
	// tsmpRequestGroup coalesces concurrent requests for the same DstIP.
	go tsmpRequestGroup.Do(req.DstIP, func() (struct{}, error) {
		// DiscoKeyRequests are triggered by an incoming WireGuard handshake
		// initiation arriving before a disco ping, which is a likely
		// indicator that disco pings failed due to a lack of key
		// synchronization. If the requests are sent immediately, before the
		// handshake state is accepted in the WireGuard client state
		// machine, this starts a new session, and the two peer state
		// machines conflict, causing loss and additional delays. Delaying
		// the send avoids this, so coalesce duplicate sends, and delay them
		// by a short time to avoid the state machine conflict.
		time.Sleep(time.Millisecond)
		e.logf("wgengine: sending TSMP disco key request to %v", req.DstIP)
		if err := e.sendTSMPDiscoKeyRequest(req.DstIP); err != nil {
			// A failed send must not be counted as sent.
			e.logf("wgengine: failed to send TSMP disco key request: %v", err)
			return struct{}{}, err
		}
		// MetricSent is a pointer supplied by the publisher; guard against
		// a request published without one.
		if req.MetricSent != nil {
			req.MetricSent.Add(1)
		}
		return struct{}{}, nil
	})
})
e.eventClient = ec
e.logf("Engine created.")
return e, nil
@ -1436,7 +1474,6 @@ func (e *userspaceEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size in
e.magicConn.Ping(peer, res, size, cb)
case "TSMP":
e.sendTSMPPing(ip, peer, res, cb)
e.sendTSMPDiscoAdvertisement(ip)
case "ICMP":
e.sendICMPEchoRequest(ip, peer, res, cb)
}
@ -1557,29 +1594,6 @@ func (e *userspaceEngine) sendTSMPPing(ip netip.Addr, peer tailcfg.NodeView, res
e.tundev.InjectOutbound(tsmpPing)
}
func (e *userspaceEngine) sendTSMPDiscoAdvertisement(ip netip.Addr) {
srcIP, err := e.mySelfIPMatchingFamily(ip)
if err != nil {
e.logf("getting matching node: %s", err)
return
}
tdka := packet.TSMPDiscoKeyAdvertisement{
Src: srcIP,
Dst: ip,
Key: e.magicConn.DiscoPublicKey(),
}
payload, err := tdka.Marshal()
if err != nil {
e.logf("error generating TSMP Advertisement: %s", err)
metricTSMPDiscoKeyAdvertisementError.Add(1)
} else if err := e.tundev.InjectOutbound(payload); err != nil {
e.logf("error sending TSMP Advertisement: %s", err)
metricTSMPDiscoKeyAdvertisementError.Add(1)
} else {
metricTSMPDiscoKeyAdvertisementSent.Add(1)
}
}
func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPPongReply)) {
e.mu.Lock()
defer e.mu.Unlock()
@ -1593,6 +1607,35 @@ func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPP
}
}
// sendTSMPDiscoKeyRequest injects a one-byte TSMP disco key request addressed
// to ip into the tunnel device. The source address is this node's address of
// the same family as ip. It returns the error from that lookup or from the
// injection.
func (e *userspaceEngine) sendTSMPDiscoKeyRequest(ip netip.Addr) error {
	srcIP, err := e.mySelfIPMatchingFamily(ip)
	if err != nil {
		return err
	}

	// The whole TSMP payload is just the request type byte.
	body := []byte{byte(packet.TSMPTypeDiscoKeyRequest)}

	if srcIP.Is4() {
		hdr := packet.IP4Header{
			IPProto: ipproto.TSMP,
			Src:     srcIP,
			Dst:     ip,
		}
		return e.tundev.InjectOutbound(packet.Generate(hdr, body))
	}
	hdr := packet.IP6Header{
		IPProto: ipproto.TSMP,
		Src:     srcIP,
		Dst:     ip,
	}
	return e.tundev.InjectOutbound(packet.Generate(hdr, body))
}
func (e *userspaceEngine) setICMPEchoResponseCallback(idSeq uint32, cb func()) {
e.mu.Lock()
defer e.mu.Unlock()
@ -1746,9 +1789,6 @@ var (
metricNumMajorChanges = clientmetric.NewCounter("wgengine_major_changes")
metricNumMinorChanges = clientmetric.NewCounter("wgengine_minor_changes")
metricTSMPDiscoKeyAdvertisementSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_sent")
metricTSMPDiscoKeyAdvertisementError = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_error")
)
func (e *userspaceEngine) InstallCaptureHook(cb packet.CaptureCallback) {

@ -325,7 +325,7 @@ func TestUserspaceEnginePeerMTUReconfig(t *testing.T) {
}
}
func TestTSMPKeyAdvertisement(t *testing.T) {
func TestTSMPDiscoKeyRequest(t *testing.T) {
var knobs controlknobs.Knobs
bus := eventbustest.NewBus(t)
@ -369,13 +369,12 @@ func TestTSMPKeyAdvertisement(t *testing.T) {
t.Fatal(err)
}
addr := netip.MustParseAddr("100.100.99.1")
previousValue := metricTSMPDiscoKeyAdvertisementSent.Value()
ue.sendTSMPDiscoAdvertisement(addr)
if val := metricTSMPDiscoKeyAdvertisementSent.Value(); val <= previousValue {
errs := metricTSMPDiscoKeyAdvertisementError.Value()
t.Errorf("Expected 1 disco key advert, got %d, errors %d", val, errs)
peerAddr := netip.MustParseAddr("100.100.99.1")
err = ue.sendTSMPDiscoKeyRequest(peerAddr)
if err != nil {
t.Fatalf("sendTSMPDiscoKeyRequest failed: %v", err)
}
// Remove config to have the engine shut down more consistently
err = ue.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{})
if err != nil {

Loading…
Cancel
Save