net,wgengine: add support for disco key exchange via TSMP

Updates tailscale/corp#34037

Signed-off-by: James Tucker <james@tailscale.com>
raggi/disco-key-tsmp2
James Tucker 1 month ago
parent 9eff8a4503
commit 5bfa8e97f6
No known key found for this signature in database

@ -15,12 +15,10 @@ import (
"fmt" "fmt"
"net/netip" "net/netip"
"go4.org/mem"
"tailscale.com/types/ipproto" "tailscale.com/types/ipproto"
"tailscale.com/types/key"
) )
const minTSMPSize = 7 // the rejected body is 7 bytes const minTSMPSize = 1 // minimum is 1 byte for the type field (e.g., disco key request 'd')
// TailscaleRejectedHeader is a TSMP message that says that one // TailscaleRejectedHeader is a TSMP message that says that one
// Tailscale node has rejected the connection from another. Unlike a // Tailscale node has rejected the connection from another. Unlike a
@ -75,8 +73,11 @@ const (
// TSMPTypePong is the type byte for a TailscalePongResponse. // TSMPTypePong is the type byte for a TailscalePongResponse.
TSMPTypePong TSMPType = 'o' TSMPTypePong TSMPType = 'o'
// TSPMTypeDiscoAdvertisement is the type byte for sending disco keys // TSMPTypeDiscoKeyRequest is the type byte for a disco key request.
TSMPTypeDiscoAdvertisement TSMPType = 'a' TSMPTypeDiscoKeyRequest TSMPType = 'd'
// TSMPTypeDiscoKeyUpdate is the type byte for a disco key update.
TSMPTypeDiscoKeyUpdate TSMPType = 'D'
) )
type TailscaleRejectReason byte type TailscaleRejectReason byte
@ -265,52 +266,62 @@ func (h TSMPPongReply) Marshal(buf []byte) error {
return nil return nil
} }
// TSMPDiscoKeyAdvertisement is a TSMP message that's used for distributing Disco Keys. // TSMPDiscoKeyRequest is a TSMP message that requests a peer's disco key.
// //
// On the wire, after the IP header, it's currently 33 bytes: // On the wire, after the IP header, it's currently 1 byte:
// - 'a' (TSMPTypeDiscoAdvertisement) // - 'd' (TSMPTypeDiscoKeyRequest)
// - 32 disco key bytes type TSMPDiscoKeyRequest struct{}
type TSMPDiscoKeyAdvertisement struct {
Src, Dst netip.Addr
Key key.DiscoPublic
}
func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) { func (pp *Parsed) AsTSMPDiscoKeyRequest() (h TSMPDiscoKeyRequest, ok bool) {
var iph Header if pp.IPProto != ipproto.TSMP {
if ka.Src.Is4() { return
iph = IP4Header{
IPProto: ipproto.TSMP,
Src: ka.Src,
Dst: ka.Dst,
}
} else {
iph = IP6Header{
IPProto: ipproto.TSMP,
Src: ka.Src,
Dst: ka.Dst,
}
} }
payload := make([]byte, 0, 33) p := pp.Payload()
payload = append(payload, byte(TSMPTypeDiscoAdvertisement)) if len(p) < 1 || p[0] != byte(TSMPTypeDiscoKeyRequest) {
payload = ka.Key.AppendTo(payload) return
if len(payload) != 33 {
// Mostly to safeguard against ourselves changing this in the future.
return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload))
} }
return h, true
}
return Generate(iph, payload), nil // TSMPDiscoKeyUpdate is a TSMP message that contains a disco public key.
// It may be sent in response to a request, or unsolicited when a node
// believes its peer may have stale disco key information.
//
// On the wire, after the IP header, it's currently 33 bytes:
// - 'D' (TSMPTypeDiscoKeyUpdate)
// - 32 bytes disco public key
type TSMPDiscoKeyUpdate struct {
IPHeader Header
DiscoKey [32]byte // raw disco public key bytes
} }
func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok bool) { // AsTSMPDiscoKeyUpdate returns pp as a TSMPDiscoKeyUpdate and whether it is one.
// The update.IPHeader field is not populated.
func (pp *Parsed) AsTSMPDiscoKeyUpdate() (update TSMPDiscoKeyUpdate, ok bool) {
if pp.IPProto != ipproto.TSMP { if pp.IPProto != ipproto.TSMP {
return return
} }
p := pp.Payload() p := pp.Payload()
if len(p) < 33 || p[0] != byte(TSMPTypeDiscoAdvertisement) { if len(p) < 33 || p[0] != byte(TSMPTypeDiscoKeyUpdate) {
return return
} }
tka.Src = pp.Src.Addr() copy(update.DiscoKey[:], p[1:33])
tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33])) return update, true
}
return tka, true func (h TSMPDiscoKeyUpdate) Len() int {
return h.IPHeader.Len() + 33
}
func (h TSMPDiscoKeyUpdate) Marshal(buf []byte) error {
if len(buf) < h.Len() {
return errSmallBuffer
}
if err := h.IPHeader.Marshal(buf); err != nil {
return err
}
buf = buf[h.IPHeader.Len():]
buf[0] = byte(TSMPTypeDiscoKeyUpdate)
copy(buf[1:33], h.DiscoKey[:])
return nil
} }

@ -4,14 +4,8 @@
package packet package packet
import ( import (
"bytes"
"encoding/hex"
"net/netip" "net/netip"
"slices"
"testing" "testing"
"go4.org/mem"
"tailscale.com/types/key"
) )
func TestTailscaleRejectedHeader(t *testing.T) { func TestTailscaleRejectedHeader(t *testing.T) {
@ -78,61 +72,168 @@ func TestTailscaleRejectedHeader(t *testing.T) {
} }
} }
func TestTSMPDiscoKeyAdvertisementMarshal(t *testing.T) { func TestTSMPDiscoKeyRequest(t *testing.T) {
var ( t.Run("Manual", func(t *testing.T) {
// IPv4: Ver(4)Len(5), TOS, Len(53), ID, Flags, TTL(64), Proto(99), Cksum var payload [1]byte
headerV4, _ = hex.DecodeString("45000035000000004063705d") payload[0] = byte(TSMPTypeDiscoKeyRequest)
// IPv6: Ver(6)TCFlow, Len(33), NextHdr(99), HopLim(64)
headerV6, _ = hex.DecodeString("6000000000216340") var p Parsed
p.IPProto = TSMP
packetType = []byte{'a'} p.dataofs = 40 // simulate after IP header
testKey = bytes.Repeat([]byte{'a'}, 32) buf := make([]byte, 40+1)
copy(buf[40:], payload[:])
// IPs p.b = buf
srcV4 = netip.MustParseAddr("1.2.3.4") p.length = len(buf)
dstV4 = netip.MustParseAddr("4.3.2.1")
srcV6 = netip.MustParseAddr("2001:db8::1") _, ok := p.AsTSMPDiscoKeyRequest()
dstV6 = netip.MustParseAddr("2001:db8::2") if !ok {
) t.Fatal("failed to parse TSMP disco key request")
}
join := func(parts ...[]byte) []byte { })
return bytes.Join(parts, nil)
t.Run("RoundTripIPv4", func(t *testing.T) {
src := netip.MustParseAddr("100.64.0.1")
dst := netip.MustParseAddr("100.64.0.2")
iph := IP4Header{
IPProto: TSMP,
Src: src,
Dst: dst,
}
var payload [1]byte
payload[0] = byte(TSMPTypeDiscoKeyRequest)
pkt := Generate(iph, payload[:])
t.Logf("Generated packet: %d bytes, hex: %x", len(pkt), pkt)
// Manually check what decode4 would see
if len(pkt) >= 4 {
declaredLen := int(uint16(pkt[2])<<8 | uint16(pkt[3]))
t.Logf("Packet buffer length: %d, IP header declares length: %d", len(pkt), declaredLen)
t.Logf("Protocol byte at [9]: 0x%02x = %d", pkt[9], pkt[9])
}
var p Parsed
p.Decode(pkt)
t.Logf("Decoded: IPVersion=%d IPProto=%v Src=%v Dst=%v length=%d dataofs=%d",
p.IPVersion, p.IPProto, p.Src, p.Dst, p.length, p.dataofs)
if p.IPVersion != 4 {
t.Errorf("IPVersion = %d, want 4", p.IPVersion)
}
if p.IPProto != TSMP {
t.Errorf("IPProto = %v, want TSMP", p.IPProto)
}
if p.Src.Addr() != src {
t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
}
if p.Dst.Addr() != dst {
t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
}
_, ok := p.AsTSMPDiscoKeyRequest()
if !ok {
t.Fatal("failed to parse TSMP disco key request from generated packet")
}
})
t.Run("RoundTripIPv6", func(t *testing.T) {
src := netip.MustParseAddr("2001:db8::1")
dst := netip.MustParseAddr("2001:db8::2")
iph := IP6Header{
IPProto: TSMP,
Src: src,
Dst: dst,
}
var payload [1]byte
payload[0] = byte(TSMPTypeDiscoKeyRequest)
pkt := Generate(iph, payload[:])
t.Logf("Generated packet: %d bytes", len(pkt))
var p Parsed
p.Decode(pkt)
if p.IPVersion != 6 {
t.Errorf("IPVersion = %d, want 6", p.IPVersion)
}
if p.IPProto != TSMP {
t.Errorf("IPProto = %v, want TSMP", p.IPProto)
}
if p.Src.Addr() != src {
t.Errorf("Src = %v, want %v", p.Src.Addr(), src)
}
if p.Dst.Addr() != dst {
t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst)
}
_, ok := p.AsTSMPDiscoKeyRequest()
if !ok {
t.Fatal("failed to parse TSMP disco key request from generated packet")
}
})
}
func TestTSMPDiscoKeyUpdate(t *testing.T) {
var discoKey [32]byte
for i := range discoKey {
discoKey[i] = byte(i + 10)
} }
tests := []struct { t.Run("IPv4", func(t *testing.T) {
name string update := TSMPDiscoKeyUpdate{
tka TSMPDiscoKeyAdvertisement IPHeader: IP4Header{
want []byte IPProto: TSMP,
}{ Src: netip.MustParseAddr("1.2.3.4"),
{ Dst: netip.MustParseAddr("5.6.7.8"),
name: "v4Header",
tka: TSMPDiscoKeyAdvertisement{
Src: srcV4,
Dst: dstV4,
Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
}, },
want: join(headerV4, srcV4.AsSlice(), dstV4.AsSlice(), packetType, testKey), DiscoKey: discoKey,
}, }
{
name: "v6Header", pkt := make([]byte, update.Len())
tka: TSMPDiscoKeyAdvertisement{ if err := update.Marshal(pkt); err != nil {
Src: srcV6, t.Fatal(err)
Dst: dstV6, }
Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
var p Parsed
p.Decode(pkt)
parsed, ok := p.AsTSMPDiscoKeyUpdate()
if !ok {
t.Fatal("failed to parse TSMP disco key update")
}
if parsed.DiscoKey != discoKey {
t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
}
})
t.Run("IPv6", func(t *testing.T) {
update := TSMPDiscoKeyUpdate{
IPHeader: IP6Header{
IPProto: TSMP,
Src: netip.MustParseAddr("2001:db8::1"),
Dst: netip.MustParseAddr("2001:db8::2"),
}, },
want: join(headerV6, srcV6.AsSlice(), dstV6.AsSlice(), packetType, testKey), DiscoKey: discoKey,
}, }
}
for _, tt := range tests { pkt := make([]byte, update.Len())
t.Run(tt.name, func(t *testing.T) { if err := update.Marshal(pkt); err != nil {
got, err := tt.tka.Marshal() t.Fatal(err)
if err != nil { }
t.Errorf("error mashalling TSMPDiscoAdvertisement: %s", err)
} var p Parsed
if !slices.Equal(got, tt.want) { p.Decode(pkt)
t.Errorf("error mashalling TSMPDiscoAdvertisement, expected: \n%x, \ngot:\n%x", tt.want, got)
} parsed, ok := p.AsTSMPDiscoKeyUpdate()
}) if !ok {
} t.Fatal("failed to parse TSMP disco key update")
}
if parsed.DiscoKey != discoKey {
t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey)
}
})
} }

@ -194,6 +194,10 @@ type Wrapper struct {
// false otherwise. // false otherwise.
OnICMPEchoResponseReceived func(*packet.Parsed) bool OnICMPEchoResponseReceived func(*packet.Parsed) bool
// GetDiscoPublicKey, if non-nil, returns the local node's disco public key.
// This is called when responding to TSMP disco key requests.
GetDiscoPublicKey func() key.DiscoPublic
// PeerAPIPort, if non-nil, returns the peerapi port that's // PeerAPIPort, if non-nil, returns the peerapi port that's
// running for the given IP address. // running for the given IP address.
PeerAPIPort func(netip.Addr) (port uint16, ok bool) PeerAPIPort func(netip.Addr) (port uint16, ok bool)
@ -211,8 +215,8 @@ type Wrapper struct {
metrics *metrics metrics *metrics
eventClient *eventbus.Client eventClient *eventbus.Client
discoKeyAdvertisementPub *eventbus.Publisher[DiscoKeyAdvertisement] discoKeyUpdatePub *eventbus.Publisher[DiscoKeyUpdate]
} }
type metrics struct { type metrics struct {
@ -227,6 +231,12 @@ func registerMetrics(reg *usermetric.Registry) *metrics {
} }
} }
// DiscoKeyUpdate is published on the event bus when a TSMP disco key update is received.
type DiscoKeyUpdate struct {
// SrcIP is the source address of the packet that carried the update.
SrcIP netip.Addr
// Key is the raw 32-byte disco public key copied from the update payload.
Key [32]byte
}
// tunInjectedRead is an injected packet pretending to be a tun.Read(). // tunInjectedRead is an injected packet pretending to be a tun.Read().
type tunInjectedRead struct { type tunInjectedRead struct {
// Only one of packet or data should be set, and are read in that order of // Only one of packet or data should be set, and are read in that order of
@ -288,7 +298,7 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry,
} }
w.eventClient = bus.Client("net.tstun") w.eventClient = bus.Client("net.tstun")
w.discoKeyAdvertisementPub = eventbus.Publish[DiscoKeyAdvertisement](w.eventClient) w.discoKeyUpdatePub = eventbus.Publish[DiscoKeyUpdate](w.eventClient)
w.vectorBuffer = make([][]byte, tdev.BatchSize()) w.vectorBuffer = make([][]byte, tdev.BatchSize())
for i := range w.vectorBuffer { for i := range w.vectorBuffer {
@ -1126,11 +1136,6 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i
return n, err return n, err
} }
type DiscoKeyAdvertisement struct {
Src netip.Addr
Key key.DiscoPublic
}
func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) { func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) {
if captHook != nil { if captHook != nil {
captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta) captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta)
@ -1141,16 +1146,21 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa
t.noteActivity() t.noteActivity()
t.injectOutboundPong(p, pingReq) t.injectOutboundPong(p, pingReq)
return filter.DropSilently, gro return filter.DropSilently, gro
} else if discoKeyAdvert, ok := p.AsTSMPDiscoAdvertisement(); ok {
t.discoKeyAdvertisementPub.Publish(DiscoKeyAdvertisement{
Src: discoKeyAdvert.Src,
Key: discoKeyAdvert.Key,
})
return filter.DropSilently, gro
} else if data, ok := p.AsTSMPPong(); ok { } else if data, ok := p.AsTSMPPong(); ok {
if f := t.OnTSMPPongReceived; f != nil { if f := t.OnTSMPPongReceived; f != nil {
f(data) f(data)
} }
} else if _, ok := p.AsTSMPDiscoKeyRequest(); ok {
t.noteActivity()
t.injectOutboundDiscoKeyUpdate(p)
return filter.DropSilently, gro
} else if discoKeyUpdate, ok := p.AsTSMPDiscoKeyUpdate(); ok {
// Publish to eventbus for subscribers
t.discoKeyUpdatePub.Publish(DiscoKeyUpdate{
SrcIP: p.Src.Addr(),
Key: discoKeyUpdate.DiscoKey,
})
return filter.DropSilently, gro
} }
} }
@ -1459,6 +1469,36 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque
t.InjectOutbound(packet.Generate(pong, nil)) t.InjectOutbound(packet.Generate(pong, nil))
} }
// injectOutboundDiscoKeyUpdate answers the TSMP disco key request in pp by
// injecting a TSMP disco key update addressed back to the requester.
//
// Nothing is sent when no GetDiscoPublicKey hook is installed, when the hook
// returns the zero key, or when pp is neither IPv4 nor IPv6.
func (t *Wrapper) injectOutboundDiscoKeyUpdate(pp *packet.Parsed) {
	if t.GetDiscoPublicKey == nil {
		return
	}
	pub := t.GetDiscoPublicKey()
	if pub.IsZero() {
		return
	}

	// Build the reply IP header by turning the request's header into a
	// response header (ToResponse) for the matching IP family.
	var reply packet.Header
	switch pp.IPVersion {
	case 4:
		hdr := pp.IP4Header()
		hdr.ToResponse()
		reply = hdr
	case 6:
		hdr := pp.IP6Header()
		hdr.ToResponse()
		reply = hdr
	default:
		return
	}

	t.InjectOutbound(packet.Generate(packet.TSMPDiscoKeyUpdate{
		IPHeader: reply,
		DiscoKey: pub.Raw32(),
	}, nil))
}
// InjectOutbound makes the Wrapper device behave as if a packet // InjectOutbound makes the Wrapper device behave as if a packet
// with the given contents was sent to the network. // with the given contents was sent to the network.
// It does not block, but takes ownership of the packet. // It does not block, but takes ownership of the packet.

@ -966,28 +966,57 @@ func TestCaptureHook(t *testing.T) {
} }
func TestTSMPDisco(t *testing.T) { func TestTSMPDisco(t *testing.T) {
t.Run("IPv6DiscoAdvert", func(t *testing.T) { t.Run("DiscoKeyRequest", func(t *testing.T) {
src := netip.MustParseAddr("2001:db8::1") src := netip.MustParseAddr("2001:db8::1")
dst := netip.MustParseAddr("2001:db8::2") dst := netip.MustParseAddr("2001:db8::2")
discoKey := key.NewDisco()
buf, _ := (&packet.TSMPDiscoKeyAdvertisement{ iph := packet.IP6Header{
Src: src, IPProto: ipproto.TSMP,
Dst: dst, Src: src,
Key: discoKey.Public(), Dst: dst,
}).Marshal() }
var payload [1]byte
payload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
buf := packet.Generate(iph, payload[:])
var p packet.Parsed var p packet.Parsed
p.Decode(buf) p.Decode(buf)
tda, ok := p.AsTSMPDiscoAdvertisement() _, ok := p.AsTSMPDiscoKeyRequest()
if !ok { if !ok {
t.Error("Unable to parse message as TSMPDiscoAdversitement") t.Error("Unable to parse message as TSMPDiscoKeyRequest")
} }
if tda.Src != src { })
t.Errorf("Src address did not match, expected %v, got %v", src, tda.Src)
t.Run("DiscoKeyUpdate", func(t *testing.T) {
src := netip.MustParseAddr("2001:db8::1")
dst := netip.MustParseAddr("2001:db8::2")
discoKey := key.NewDisco()
update := packet.TSMPDiscoKeyUpdate{
IPHeader: packet.IP6Header{
IPProto: ipproto.TSMP,
Src: src,
Dst: dst,
},
DiscoKey: discoKey.Public().Raw32(),
}
buf := make([]byte, update.Len())
if err := update.Marshal(buf); err != nil {
t.Fatal(err)
}
var p packet.Parsed
p.Decode(buf)
parsed, ok := p.AsTSMPDiscoKeyUpdate()
if !ok {
t.Error("Unable to parse message as TSMPDiscoKeyUpdate")
} }
if !reflect.DeepEqual(tda.Key, discoKey.Public()) { if parsed.DiscoKey != discoKey.Public().Raw32() {
t.Errorf("Key did not match, expected %q, got %q", discoKey.Public(), tda.Key) t.Errorf("Key did not match, expected %v, got %v", discoKey.Public().Raw32(), parsed.DiscoKey)
} }
}) })
} }

@ -721,6 +721,16 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, dm.n, true) update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, dm.n, true)
} }
// Request disco key from peer via TSMP if we receive a WireGuard handshake
// over DERP without recent disco success. This handles the "WireGuard-first"
// case where WireGuard establishes a tunnel via DERP before disco succeeds
// (e.g., control plane unreachable or stale disco keys).
if looksLikeWireGuardHandshake(b[:n]) && n > 0 {
c.mu.Lock()
c.requestDiscoKeyViaTSMPLocked(dm.src, ep)
c.mu.Unlock()
}
c.metrics.inboundPacketsDERPTotal.Add(1) c.metrics.inboundPacketsDERPTotal.Add(1)
c.metrics.inboundBytesDERPTotal.Add(int64(n)) c.metrics.inboundBytesDERPTotal.Add(int64(n))
return n, ep return n, ep

@ -178,9 +178,10 @@ type Conn struct {
// A publisher for synchronization points to ensure correct ordering of // A publisher for synchronization points to ensure correct ordering of
// config changes between magicsock and wireguard. // config changes between magicsock and wireguard.
syncPub *eventbus.Publisher[syncPoint] syncPub *eventbus.Publisher[syncPoint]
allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq] allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq]
portUpdatePub *eventbus.Publisher[router.PortUpdate] portUpdatePub *eventbus.Publisher[router.PortUpdate]
tsmpDiscoKeyRequestPub *eventbus.Publisher[TSMPDiscoKeyRequest]
// pconn4 and pconn6 are the underlying UDP sockets used to // pconn4 and pconn6 are the underlying UDP sockets used to
// send/receive packets for wireguard and other magicsock // send/receive packets for wireguard and other magicsock
@ -572,6 +573,14 @@ type UDPRelayAllocReq struct {
Message *disco.AllocateUDPRelayEndpointRequest Message *disco.AllocateUDPRelayEndpointRequest
} }
// TSMPDiscoKeyRequest is published on the event bus when magicsock needs to
// send a TSMP disco key request to a peer. Subscribers should inject the
// TSMP packet into the tunnel device.
type TSMPDiscoKeyRequest struct {
// DstIP is the address of the peer the request should be sent to.
DstIP netip.Addr
// MetricSent is a counter for the subscriber to increment when it handles
// the request.
MetricSent *clientmetric.Metric
}
// UDPRelayAllocResp represents a [*disco.AllocateUDPRelayEndpointResponse] // UDPRelayAllocResp represents a [*disco.AllocateUDPRelayEndpointResponse]
// that is yet to be transmitted over DERP (or delivered locally if // that is yet to be transmitted over DERP (or delivered locally if
// ReqRxFromNodeKey is self). This is signaled over an [eventbus.Bus] from // ReqRxFromNodeKey is self). This is signaled over an [eventbus.Bus] from
@ -691,6 +700,7 @@ func NewConn(opts Options) (*Conn, error) {
c.syncPub = eventbus.Publish[syncPoint](ec) c.syncPub = eventbus.Publish[syncPoint](ec)
c.allocRelayEndpointPub = eventbus.Publish[UDPRelayAllocReq](ec) c.allocRelayEndpointPub = eventbus.Publish[UDPRelayAllocReq](ec)
c.portUpdatePub = eventbus.Publish[router.PortUpdate](ec) c.portUpdatePub = eventbus.Publish[router.PortUpdate](ec)
c.tsmpDiscoKeyRequestPub = eventbus.Publish[TSMPDiscoKeyRequest](ec)
eventbus.SubscribeFunc(ec, c.onPortMapChanged) eventbus.SubscribeFunc(ec, c.onPortMapChanged)
eventbus.SubscribeFunc(ec, c.onFilterUpdate) eventbus.SubscribeFunc(ec, c.onFilterUpdate)
eventbus.SubscribeFunc(ec, c.onNodeViewsUpdate) eventbus.SubscribeFunc(ec, c.onNodeViewsUpdate)
@ -1800,6 +1810,15 @@ func looksLikeInitiationMsg(b []byte) bool {
binary.LittleEndian.Uint32(b) == device.MessageInitiationType binary.LittleEndian.Uint32(b) == device.MessageInitiationType
} }
// looksLikeWireGuardHandshake reports whether b has both the exact length and
// the leading little-endian message type of a WireGuard handshake initiation
// or handshake response.
func looksLikeWireGuardHandshake(b []byte) bool {
	if len(b) < 4 {
		return false
	}
	switch binary.LittleEndian.Uint32(b) {
	case device.MessageInitiationType:
		return len(b) == device.MessageInitiationSize
	case device.MessageResponseType:
		return len(b) == device.MessageResponseSize
	}
	return false
}
// receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6. // receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6.
// //
// size is the length of 'b' to report up to wireguard-go (only relevant if // size is the length of 'b' to report up to wireguard-go (only relevant if
@ -4104,6 +4123,12 @@ var (
metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff") metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff")
metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff") metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff")
metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff") metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff")
// TSMP disco key exchange
metricTSMPDiscoKeyRequestSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_sent")
metricTSMPDiscoKeyUpdateReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_received")
metricTSMPDiscoKeyUpdateApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_applied")
metricTSMPDiscoKeyUpdateUnknown = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_unknown_peer")
) )
// newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided // newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided
@ -4242,6 +4267,81 @@ func (le *lazyEndpoint) FromPeer(peerPublicKey [32]byte) {
// See http://go/corp/29422 & http://go/corp/30042 // See http://go/corp/29422 & http://go/corp/30042
le.c.peerMap.setNodeKeyForEpAddr(le.src, pubKey) le.c.peerMap.setNodeKeyForEpAddr(le.src, pubKey)
le.c.logf("magicsock: lazyEndpoint.FromPeer(%v) setting epAddr(%v) in peerMap for node(%v)", pubKey.ShortString(), le.src, ep.nodeAddr) le.c.logf("magicsock: lazyEndpoint.FromPeer(%v) setting epAddr(%v) in peerMap for node(%v)", pubKey.ShortString(), le.src, ep.nodeAddr)
le.c.requestDiscoKeyViaTSMPLocked(pubKey, ep)
}
// requestDiscoKeyViaTSMPLocked publishes a TSMPDiscoKeyRequest for ep's node
// address unless the endpoint's current disco key has a discoInfo entry whose
// lastPingTime is within discoPingInterval.
//
// Endpoints without a valid node address are ignored. The publish happens on
// a new goroutine. nodeKey is currently unused.
//
// c.mu must be held.
func (c *Conn) requestDiscoKeyViaTSMPLocked(nodeKey key.NodePublic, ep *endpoint) {
	if !ep.nodeAddr.IsValid() {
		return
	}

	if cur := ep.disco.Load(); cur != nil {
		if info := c.discoInfo[cur.key]; info != nil && time.Since(info.lastPingTime) < discoPingInterval {
			// A disco ping was seen recently; no need to ask for the key.
			return
		}
	}

	req := TSMPDiscoKeyRequest{DstIP: ep.nodeAddr, MetricSent: metricTSMPDiscoKeyRequestSent}
	go c.tsmpDiscoKeyRequestPub.Publish(req)
}
// HandleDiscoKeyUpdate processes a TSMP disco key update.
// The update may be solicited (in response to a request) or unsolicited.
// srcIP is the Tailscale IP of the peer that sent the update.
func (c *Conn) HandleDiscoKeyUpdate(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
discoKey := key.DiscoPublicFromRaw32(mem.B(update.DiscoKey[:]))
c.logf("magicsock: received disco key update %v from %v", discoKey.ShortString(), srcIP)
metricTSMPDiscoKeyUpdateReceived.Add(1)
c.mu.Lock()
defer c.mu.Unlock()
var nodeKey key.NodePublic
var found bool
for _, peer := range c.peers.All() {
for _, addr := range peer.Addresses().All() {
if addr.Addr() == srcIP {
nodeKey = peer.Key()
found = true
break
}
}
if found {
break
}
}
if !found {
c.logf("magicsock: disco key update from unknown peer %v", srcIP)
metricTSMPDiscoKeyUpdateUnknown.Add(1)
return
}
ep, ok := c.peerMap.endpointForNodeKey(nodeKey)
if !ok {
c.logf("magicsock: endpoint not found for node %v", nodeKey.ShortString())
return
}
oldDiscoKey := key.DiscoPublic{}
if epDisco := ep.disco.Load(); epDisco != nil {
oldDiscoKey = epDisco.key
}
c.discoInfoForKnownPeerLocked(discoKey)
ep.disco.Store(&endpointDisco{
key: discoKey,
short: discoKey.ShortString(),
})
c.peerMap.upsertEndpoint(ep, oldDiscoKey)
c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString())
metricTSMPDiscoKeyUpdateApplied.Add(1)
} }
// PeerRelays returns the current set of candidate peer relays. // PeerRelays returns the current set of candidate peer relays.

@ -64,6 +64,7 @@ import (
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
"tailscale.com/types/nettype" "tailscale.com/types/nettype"
"tailscale.com/types/ptr" "tailscale.com/types/ptr"
"tailscale.com/types/views"
"tailscale.com/util/cibuild" "tailscale.com/util/cibuild"
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/util/eventbus" "tailscale.com/util/eventbus"
@ -4302,3 +4303,66 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) {
keys = append(keys, newKey) keys = append(keys, newKey)
} }
} }
// TestSendTSMPDiscoKeyRequest checks that lazyEndpoint.FromPeer results in a
// TSMPDiscoKeyRequest being published on the event bus for the peer's node
// address.
func TestSendTSMPDiscoKeyRequest(t *testing.T) {
// Endpoint under test: it has a valid node address and a disco key.
ep := &endpoint{
nodeID: 1,
publicKey: key.NewNode().Public(),
nodeAddr: netip.MustParseAddr("100.64.0.1"),
}
discoKey := key.NewDisco().Public()
ep.disco.Store(&endpointDisco{
key: discoKey,
short: discoKey.ShortString(),
})
// Wire a bare Conn to a test bus, including the publisher that
// requestDiscoKeyViaTSMPLocked uses.
bus := eventbustest.NewBus(t)
conn := newConn(t.Logf)
conn.eventBus = bus
conn.eventClient = bus.Client("magicsock.Conn.test")
conn.tsmpDiscoKeyRequestPub = eventbus.Publish[TSMPDiscoKeyRequest](conn.eventClient)
ep.c = conn
// Buffered so the subscriber can signal without waiting for the test to
// reach the select below.
tsmpRequestCalled := make(chan struct{}, 1)
var capturedIP netip.Addr
ec := bus.Client("test")
defer ec.Close()
// Subscriber standing in for the engine: records the destination and
// signals that a request was published.
eventbus.SubscribeFunc(ec, func(req TSMPDiscoKeyRequest) {
capturedIP = req.DstIP
if req.MetricSent != nil {
req.MetricSent.Add(1)
}
tsmpRequestCalled <- struct{}{}
})
// Register the endpoint's node as a known peer.
conn.mu.Lock()
conn.peers = views.SliceOf([]tailcfg.NodeView{
(&tailcfg.Node{
Key: ep.publicKey,
Addresses: []netip.Prefix{
netip.MustParsePrefix("100.64.0.1/32"),
},
}).View(),
})
conn.mu.Unlock()
// FromPeer takes the peer's public key as a raw 32-byte array.
var pubKey [32]byte
copy(pubKey[:], ep.publicKey.AppendTo(nil))
conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{})
le := &lazyEndpoint{
c: conn,
src: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")},
}
// NOTE(review): this test never adds a discoInfo entry for discoKey, so
// the recent-ping check presumably does not suppress the request —
// confirm against newConn.
le.FromPeer(pubKey)
// The request is published asynchronously; wait up to one second for it.
select {
case <-tsmpRequestCalled:
if !capturedIP.IsValid() {
t.Error("TSMP request sent with invalid IP")
}
t.Logf("TSMP disco key request sent to %v", capturedIP)
case <-time.After(time.Second):
t.Error("TSMP disco key request was not sent")
}
}

@ -54,6 +54,7 @@ import (
"tailscale.com/util/execqueue" "tailscale.com/util/execqueue"
"tailscale.com/util/mak" "tailscale.com/util/mak"
"tailscale.com/util/set" "tailscale.com/util/set"
"tailscale.com/util/singleflight"
"tailscale.com/util/testenv" "tailscale.com/util/testenv"
"tailscale.com/util/usermetric" "tailscale.com/util/usermetric"
"tailscale.com/version" "tailscale.com/version"
@ -469,6 +470,13 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
return true return true
} }
e.tundev.GetDiscoPublicKey = func() key.DiscoPublic {
if e.magicConn == nil {
return key.DiscoPublic{}
}
return e.magicConn.DiscoPublicKey()
}
// wgdev takes ownership of tundev, will close it when closed. // wgdev takes ownership of tundev, will close it when closed.
e.logf("Creating WireGuard device...") e.logf("Creating WireGuard device...")
e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger) e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger)
@ -549,6 +557,36 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
} }
e.linkChangeQueue.Add(func() { e.linkChange(&cd) }) e.linkChangeQueue.Add(func() { e.linkChange(&cd) })
}) })
// Forward TSMP disco key updates seen by the tun wrapper to magicsock.
eventbus.SubscribeFunc(ec, func(update tstun.DiscoKeyUpdate) {
	e.logf("wgengine: got TSMP disco key update from %v via eventbus", update.SrcIP)
	if e.magicConn == nil {
		return
	}
	// Only the key is carried over; the IP header is left unset.
	e.magicConn.HandleDiscoKeyUpdate(update.SrcIP, packet.TSMPDiscoKeyUpdate{DiscoKey: update.Key})
})
// tsmpRequestGroup coalesces concurrent TSMP disco key requests to the same
// destination IP into a single send.
var tsmpRequestGroup singleflight.Group[netip.Addr, struct{}]
eventbus.SubscribeFunc(ec, func(req magicsock.TSMPDiscoKeyRequest) {
	go tsmpRequestGroup.Do(req.DstIP, func() (struct{}, error) {
		// DiscoKeyRequests are triggered by an incoming WireGuard handshake
		// initiation arriving before a disco ping, which is a likely
		// indicator that disco pings failed due to a lack of key
		// synchronization. If the requests are sent immediately, before the
		// handshake state is accepted in the WireGuard client state
		// machine, this starts a new session, and the two peer state
		// machines conflict, causing loss and additional delays. Delaying
		// the send avoids this, so coalesce duplicate sends, and delay them
		// by a short time to avoid the state machine conflict.
		time.Sleep(time.Millisecond)
		if err := e.sendTSMPDiscoKeyRequest(req.DstIP); err != nil {
			// Nothing was sent: do not log a send or count it as sent.
			e.logf("wgengine: failed to send TSMP disco key request: %v", err)
			return struct{}{}, err
		}
		e.logf("wgengine: sending TSMP disco key request to %v", req.DstIP)
		// Publishers are not required to supply a metric.
		if req.MetricSent != nil {
			req.MetricSent.Add(1)
		}
		return struct{}{}, nil
	})
})
e.eventClient = ec e.eventClient = ec
e.logf("Engine created.") e.logf("Engine created.")
return e, nil return e, nil
@ -1436,7 +1474,6 @@ func (e *userspaceEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size in
e.magicConn.Ping(peer, res, size, cb) e.magicConn.Ping(peer, res, size, cb)
case "TSMP": case "TSMP":
e.sendTSMPPing(ip, peer, res, cb) e.sendTSMPPing(ip, peer, res, cb)
e.sendTSMPDiscoAdvertisement(ip)
case "ICMP": case "ICMP":
e.sendICMPEchoRequest(ip, peer, res, cb) e.sendICMPEchoRequest(ip, peer, res, cb)
} }
@ -1557,29 +1594,6 @@ func (e *userspaceEngine) sendTSMPPing(ip netip.Addr, peer tailcfg.NodeView, res
e.tundev.InjectOutbound(tsmpPing) e.tundev.InjectOutbound(tsmpPing)
} }
func (e *userspaceEngine) sendTSMPDiscoAdvertisement(ip netip.Addr) {
srcIP, err := e.mySelfIPMatchingFamily(ip)
if err != nil {
e.logf("getting matching node: %s", err)
return
}
tdka := packet.TSMPDiscoKeyAdvertisement{
Src: srcIP,
Dst: ip,
Key: e.magicConn.DiscoPublicKey(),
}
payload, err := tdka.Marshal()
if err != nil {
e.logf("error generating TSMP Advertisement: %s", err)
metricTSMPDiscoKeyAdvertisementError.Add(1)
} else if err := e.tundev.InjectOutbound(payload); err != nil {
e.logf("error sending TSMP Advertisement: %s", err)
metricTSMPDiscoKeyAdvertisementError.Add(1)
} else {
metricTSMPDiscoKeyAdvertisementSent.Add(1)
}
}
func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPPongReply)) { func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPPongReply)) {
e.mu.Lock() e.mu.Lock()
defer e.mu.Unlock() defer e.mu.Unlock()
@ -1593,6 +1607,35 @@ func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPP
} }
} }
// sendTSMPDiscoKeyRequest injects a one-byte TSMP disco key request addressed
// to the peer at ip into the tun device.
//
// The source address is this node's address of the same family as ip. An
// error is returned if no such address is found or if injection fails.
func (e *userspaceEngine) sendTSMPDiscoKeyRequest(ip netip.Addr) error {
	srcIP, err := e.mySelfIPMatchingFamily(ip)
	if err != nil {
		return err
	}

	// Pick the IP header type from the source address family.
	var hdr packet.Header = packet.IP6Header{
		IPProto: ipproto.TSMP,
		Src:     srcIP,
		Dst:     ip,
	}
	if srcIP.Is4() {
		hdr = packet.IP4Header{
			IPProto: ipproto.TSMP,
			Src:     srcIP,
			Dst:     ip,
		}
	}

	// The whole TSMP payload is the single type byte.
	body := []byte{byte(packet.TSMPTypeDiscoKeyRequest)}
	return e.tundev.InjectOutbound(packet.Generate(hdr, body))
}
func (e *userspaceEngine) setICMPEchoResponseCallback(idSeq uint32, cb func()) { func (e *userspaceEngine) setICMPEchoResponseCallback(idSeq uint32, cb func()) {
e.mu.Lock() e.mu.Lock()
defer e.mu.Unlock() defer e.mu.Unlock()
@ -1746,9 +1789,6 @@ var (
metricNumMajorChanges = clientmetric.NewCounter("wgengine_major_changes") metricNumMajorChanges = clientmetric.NewCounter("wgengine_major_changes")
metricNumMinorChanges = clientmetric.NewCounter("wgengine_minor_changes") metricNumMinorChanges = clientmetric.NewCounter("wgengine_minor_changes")
metricTSMPDiscoKeyAdvertisementSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_sent")
metricTSMPDiscoKeyAdvertisementError = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_error")
) )
func (e *userspaceEngine) InstallCaptureHook(cb packet.CaptureCallback) { func (e *userspaceEngine) InstallCaptureHook(cb packet.CaptureCallback) {

@ -325,7 +325,7 @@ func TestUserspaceEnginePeerMTUReconfig(t *testing.T) {
} }
} }
func TestTSMPKeyAdvertisement(t *testing.T) { func TestTSMPDiscoKeyRequest(t *testing.T) {
var knobs controlknobs.Knobs var knobs controlknobs.Knobs
bus := eventbustest.NewBus(t) bus := eventbustest.NewBus(t)
@ -369,13 +369,12 @@ func TestTSMPKeyAdvertisement(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
addr := netip.MustParseAddr("100.100.99.1") peerAddr := netip.MustParseAddr("100.100.99.1")
previousValue := metricTSMPDiscoKeyAdvertisementSent.Value() err = ue.sendTSMPDiscoKeyRequest(peerAddr)
ue.sendTSMPDiscoAdvertisement(addr) if err != nil {
if val := metricTSMPDiscoKeyAdvertisementSent.Value(); val <= previousValue { t.Fatalf("sendTSMPDiscoKeyRequest failed: %v", err)
errs := metricTSMPDiscoKeyAdvertisementError.Value()
t.Errorf("Expected 1 disco key advert, got %d, errors %d", val, errs)
} }
// Remove config to have the engine shut down more consistently // Remove config to have the engine shut down more consistently
err = ue.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) err = ue.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{})
if err != nil { if err != nil {

Loading…
Cancel
Save