You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
tailscale/wgengine/magicsock/tsmp_disco_test.go

292 lines
8.9 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
	"fmt"
	"net/netip"
	"testing"
	"time"

	"github.com/tailscale/wireguard-go/tun/tuntest"
	"tailscale.com/net/netaddr"
	"tailscale.com/net/packet"
	"tailscale.com/tailcfg"
	"tailscale.com/tstest"
	"tailscale.com/types/ipproto"
	"tailscale.com/types/key"
	"tailscale.com/types/netmap"
	"tailscale.com/util/set"
	"tailscale.com/wgengine/wgcfg/nmcfg"
)
// TestTSMPDiscoKeyExchange brings up two magicsock stacks (m1 and m2) behind
// a local DERP/STUN server, wires the TSMP disco-key hooks between each
// stack's TUN wrapper and its Conn, rotates m1's disco key, and then injects
// a ping on m1's TUN addressed to m2.
//
// The test passes once m2's status reports a non-empty CurAddr for m1 and the
// TSMP disco-key request/update metrics have all increased relative to their
// values before the rotation. It fails if that does not happen within
// exchangeTimeout.
func TestTSMPDiscoKeyExchange(t *testing.T) {
	// exchangeTimeout bounds the wait for m2 to report a current address for
	// m1 after the rotation, so a broken exchange fails the test with a clear
	// message instead of spinning until the global `go test` timeout.
	const exchangeTimeout = 30 * time.Second

	tstest.ResourceCheck(t)

	// Set up DERP and STUN servers
	derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1))
	defer cleanup()

	// Create two magicsock peers
	m1 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
	defer m1.Close()
	m2 := newMagicStack(t, t.Logf, localhostListener{}, derpMap)
	defer m2.Close()

	// Wire up TSMP hooks to enable disco key exchange
	// This mimics what userspaceEngine does in wgengine/userspace.go

	// Hook 0: GetDiscoPublicKey - allows TSMP replies to include current disco key
	m1.tsTun.GetDiscoPublicKey = m1.conn.DiscoPublicKey
	m2.tsTun.GetDiscoPublicKey = m2.conn.DiscoPublicKey

	// Hook 1: OnTSMPDiscoKeyReceived - handle incoming TSMP disco key updates
	m1.tsTun.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
		t.Logf("m1: received TSMP disco key update from %v", srcIP)
		m1.conn.HandleDiscoKeyUpdate(srcIP, update)
	}
	m2.tsTun.OnTSMPDiscoKeyReceived = func(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) {
		t.Logf("m2: received TSMP disco key update from %v", srcIP)
		m2.conn.HandleDiscoKeyUpdate(srcIP, update)
	}

	// sendTSMPDiscoKeyRequest builds a one-byte TSMP disco key request
	// addressed to dstIP and injects it outbound on the *other* stack's TUN
	// wrapper (the sender is whichever stack does not own dstIP). It returns
	// an error if dstIP belongs to neither stack.
	sendTSMPDiscoKeyRequest := func(dstIP netip.Addr) error {
		var srcIP netip.Addr
		var stack *magicStack
		switch dstIP {
		case m1.IP():
			srcIP = m2.IP()
			stack = m2
			t.Logf("m2: sending disco key request to m1")
		case m2.IP():
			srcIP = m1.IP()
			stack = m1
			t.Logf("m1: sending disco key request to m2")
		default:
			// Without this guard stack would be nil and the InjectOutbound
			// call below would panic instead of reporting the problem.
			return fmt.Errorf("no test stack owns destination %v", dstIP)
		}

		// equivalent to the implementation in userspace.Engine
		iph := packet.IP4Header{
			IPProto: ipproto.TSMP,
			Src:     srcIP,
			Dst:     dstIP,
		}
		var tsmpPayload [1]byte
		tsmpPayload[0] = byte(packet.TSMPTypeDiscoKeyRequest)
		tsmpRequest := packet.Generate(iph, tsmpPayload[:])
		return stack.tsTun.InjectOutbound(tsmpRequest)
	}

	// Hook 2: SetSendTSMPDiscoKeyRequest - send TSMP disco key requests
	m1.conn.SetSendTSMPDiscoKeyRequest(sendTSMPDiscoKeyRequest)
	m2.conn.SetSendTSMPDiscoKeyRequest(sendTSMPDiscoKeyRequest)

	// Get initial disco keys
	disco1Original := m1.conn.DiscoPublicKey()
	disco2 := m2.conn.DiscoPublicKey()
	t.Logf("m1: node=%v disco=%v", m1.Public().ShortString(), disco1Original.ShortString())
	t.Logf("m2: node=%v disco=%v", m2.Public().ShortString(), disco2.ShortString())

	// Wait for initial endpoints
	var eps1, eps2 []tailcfg.Endpoint
	select {
	case eps1 = <-m1.epCh:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for m1 endpoints")
	}
	select {
	case eps2 = <-m2.epCh:
	case <-time.After(2 * time.Second):
		t.Fatal("timeout waiting for m2 endpoints")
	}

	// Build initial network maps and establish connection.
	// Each map lists the other stack as its only peer, advertising the
	// disco key that stack had *before* the rotation below.
	nm1 := &netmap.NetworkMap{
		NodeKey: m1.Public(),
		SelfNode: (&tailcfg.Node{
			Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
		}).View(),
		Peers: []tailcfg.NodeView{
			(&tailcfg.Node{
				ID:         2,
				Key:        m2.Public(),
				DiscoKey:   disco2,
				Addresses:  []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
				AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
				Endpoints:  epFromTyped(eps2),
				HomeDERP:   1,
			}).View(),
		},
	}
	nm2 := &netmap.NetworkMap{
		NodeKey: m2.Public(),
		SelfNode: (&tailcfg.Node{
			Addresses: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 2), 32)},
		}).View(),
		Peers: []tailcfg.NodeView{
			(&tailcfg.Node{
				ID:         1,
				Key:        m1.Public(),
				DiscoKey:   disco1Original,
				Addresses:  []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
				AllowedIPs: []netip.Prefix{netip.PrefixFrom(netaddr.IPv4(100, 64, 0, 1), 32)},
				Endpoints:  epFromTyped(eps1),
				HomeDERP:   1,
			}).View(),
		},
	}

	cfg1, err := nmcfg.WGCfg(m1.privateKey, nm1, t.Logf, 0, "")
	if err != nil {
		t.Fatal(err)
	}
	cfg2, err := nmcfg.WGCfg(m2.privateKey, nm2, t.Logf, 0, "")
	if err != nil {
		t.Fatal(err)
	}

	nv1 := NodeViewsUpdate{
		SelfNode: nm1.SelfNode,
		Peers:    nm1.Peers,
	}
	m1.conn.onNodeViewsUpdate(nv1)
	peerSet1 := set.Set[key.NodePublic]{}
	peerSet1.Add(m2.Public())
	m1.conn.UpdatePeers(peerSet1)

	nv2 := NodeViewsUpdate{
		SelfNode: nm2.SelfNode,
		Peers:    nm2.Peers,
	}
	m2.conn.onNodeViewsUpdate(nv2)
	peerSet2 := set.Set[key.NodePublic]{}
	peerSet2.Add(m1.Public())
	m2.conn.UpdatePeers(peerSet2)

	if err := m1.Reconfig(cfg1); err != nil {
		t.Fatal(err)
	}
	if err := m2.Reconfig(cfg2); err != nil {
		t.Fatal(err)
	}
	t.Logf("=== INITIAL CONFIGURATION COMPLETE ===")

	// Start goroutines to drain TUN inbound channels so TSMP packets can be received.
	// Each goroutine exits when the test's context is cancelled.
	drainTun := func(name string, stack *magicStack) {
		go func() {
			for {
				select {
				case <-t.Context().Done():
					return
				case pkt := <-stack.tun.Inbound:
					var p packet.Parsed
					p.Decode(pkt)
					if p.IPProto == ipproto.TSMP {
						t.Logf("%s: received TSMP packet on TUN inbound: %d bytes", name, len(pkt))
					} else if p.IPProto == ipproto.ICMPv4 {
						t.Logf("%s: received ICMPv4 packet on TUN inbound: %d bytes", name, len(pkt))
					} else {
						t.Logf("%s: received packet on TUN inbound: %d bytes, proto=%v", name, len(pkt), p.IPProto)
					}
				}
			}
		}()
	}
	drainTun("m1", m1)
	drainTun("m2", m2)

	// Snapshot the metrics so the assertions below compare deltas rather
	// than absolute values.
	initialRequestsSent := metricTSMPDiscoKeyRequestSent.Value()
	initialUpdatesReceived := metricTSMPDiscoKeyUpdateReceived.Value()
	initialUpdatesApplied := metricTSMPDiscoKeyUpdateApplied.Value()
	t.Logf("Initial metrics: requests_sent=%d updates_received=%d updates_applied=%d",
		initialRequestsSent, initialUpdatesReceived, initialUpdatesApplied)

	t.Logf("=== ROTATING m1's DISCO KEY ===")
	m1.conn.RotateDiscoKey()
	disco1New := m1.conn.DiscoPublicKey()
	if disco1Original.Compare(disco1New) == 0 {
		t.Fatal("disco key failed to rotate")
	}
	t.Logf("Rotated: %v -> %v", disco1Original.ShortString(), disco1New.ShortString())

	t.Logf("=== SENDING PACKETS TO TRIGGER TSMP EXCHANGE ===")
	ping1to2 := tuntest.Ping(netip.MustParseAddr("100.64.0.2"), netip.MustParseAddr("100.64.0.1"))

	// Inject a single ping on m1's TUN, addressed to m2. The intent is that
	// this triggers m1's handshake initiation and, when m2 receives the
	// encrypted packet, the FromPeer -> TSMP path on m2.
	// The send is non-blocking; if it cannot proceed immediately the ping is
	// dropped, so log that case to make a later timeout diagnosable.
	select {
	case m1.tun.Outbound <- ping1to2:
	default:
		t.Logf("m1: TUN outbound not ready; initial ping to m2 was dropped")
	}

	// Poll m2's status until it reports a current address for m1, giving up
	// after exchangeTimeout.
	deadline := time.Now().Add(exchangeTimeout)
	var curAddr string
	for curAddr == "" {
		if time.Now().After(deadline) {
			t.Fatalf("timeout after %v waiting for m2 to report a current address for m1 after disco key rotation", exchangeTimeout)
		}
		time.Sleep(time.Millisecond)
		st := m2.Status()
		if ps, ok := st.Peer[m1.Public()]; ok {
			curAddr = ps.CurAddr
		}
	}

	t.Logf("Connection established after disco key rotation")
	t.Logf("m2 -> m1 via %v", curAddr)
	t.Logf("Disco key rotation: %v -> %v", disco1Original.ShortString(), disco1New.ShortString())

	// Verify TSMP metrics incremented
	finalRequestsSent := metricTSMPDiscoKeyRequestSent.Value()
	finalUpdatesReceived := metricTSMPDiscoKeyUpdateReceived.Value()
	finalUpdatesApplied := metricTSMPDiscoKeyUpdateApplied.Value()
	t.Logf("Final metrics: requests_sent=%d updates_received=%d updates_applied=%d",
		finalRequestsSent, finalUpdatesReceived, finalUpdatesApplied)

	// Check that at least one TSMP request was sent
	if finalRequestsSent <= initialRequestsSent {
		t.Errorf("Expected TSMP disco key request to be sent, but metric did not increment: %d -> %d",
			initialRequestsSent, finalRequestsSent)
	} else {
		t.Logf("✓ TSMP disco key request sent (metric: %d -> %d)",
			initialRequestsSent, finalRequestsSent)
	}

	// Check that at least one TSMP update was received
	if finalUpdatesReceived <= initialUpdatesReceived {
		t.Errorf("Expected TSMP disco key update to be received, but metric did not increment: %d -> %d",
			initialUpdatesReceived, finalUpdatesReceived)
	} else {
		t.Logf("✓ TSMP disco key update received (metric: %d -> %d)",
			initialUpdatesReceived, finalUpdatesReceived)
	}

	// Check that at least one TSMP update was applied
	if finalUpdatesApplied <= initialUpdatesApplied {
		t.Errorf("Expected TSMP disco key update to be applied, but metric did not increment: %d -> %d",
			initialUpdatesApplied, finalUpdatesApplied)
	} else {
		t.Logf("✓ TSMP disco key update applied (metric: %d -> %d)",
			initialUpdatesApplied, finalUpdatesApplied)
	}

	// Error counters are only reported, not asserted on: non-zero values are
	// logged as warnings.
	requestErrors := metricTSMPDiscoKeyRequestError.Value()
	if requestErrors > 0 {
		t.Logf("Warning: TSMP disco key request errors: %d", requestErrors)
	}
	unknownPeers := metricTSMPDiscoKeyUpdateUnknown.Value()
	if unknownPeers > 0 {
		t.Logf("Warning: TSMP disco key updates from unknown peers: %d", unknownPeers)
	}

	t.Logf("TSMP disco key exchange infrastructure is functional")
}