diff --git a/net/packet/tsmp.go b/net/packet/tsmp.go index 8fad1d503..b9b805b1a 100644 --- a/net/packet/tsmp.go +++ b/net/packet/tsmp.go @@ -15,12 +15,10 @@ import ( "fmt" "net/netip" - "go4.org/mem" "tailscale.com/types/ipproto" - "tailscale.com/types/key" ) -const minTSMPSize = 7 // the rejected body is 7 bytes +const minTSMPSize = 1 // minimum is 1 byte for the type field (e.g., disco key request 'd') // TailscaleRejectedHeader is a TSMP message that says that one // Tailscale node has rejected the connection from another. Unlike a @@ -75,8 +73,11 @@ const ( // TSMPTypePong is the type byte for a TailscalePongResponse. TSMPTypePong TSMPType = 'o' - // TSPMTypeDiscoAdvertisement is the type byte for sending disco keys - TSMPTypeDiscoAdvertisement TSMPType = 'a' + // TSMPTypeDiscoKeyRequest is the type byte for a disco key request. + TSMPTypeDiscoKeyRequest TSMPType = 'd' + + // TSMPTypeDiscoKeyUpdate is the type byte for a disco key update. + TSMPTypeDiscoKeyUpdate TSMPType = 'D' ) type TailscaleRejectReason byte @@ -265,52 +266,62 @@ func (h TSMPPongReply) Marshal(buf []byte) error { return nil } -// TSMPDiscoKeyAdvertisement is a TSMP message that's used for distributing Disco Keys. +// TSMPDiscoKeyRequest is a TSMP message that requests a peer's disco key. // -// On the wire, after the IP header, it's currently 33 bytes: -// - 'a' (TSMPTypeDiscoAdvertisement) -// - 32 disco key bytes -type TSMPDiscoKeyAdvertisement struct { - Src, Dst netip.Addr - Key key.DiscoPublic -} +// On the wire, after the IP header, it's currently 1 byte: +// - 'd' (TSMPTypeDiscoKeyRequest) +type TSMPDiscoKeyRequest struct{} -func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) { - var iph Header - if ka.Src.Is4() { - iph = IP4Header{ - IPProto: ipproto.TSMP, - Src: ka.Src, - Dst: ka.Dst, - } - } else { - iph = IP6Header{ - IPProto: ipproto.TSMP, - Src: ka.Src, - Dst: ka.Dst, - } +func (pp *Parsed) AsTSMPDiscoKeyRequest() (h TSMPDiscoKeyRequest, ok bool) { + if pp.IPProto != ipproto.TSMP { + return } - payload := make([]byte, 0, 33) - payload = append(payload, byte(TSMPTypeDiscoAdvertisement)) - payload = ka.Key.AppendTo(payload) - if len(payload) != 33 { - // Mostly to safeguard against ourselves changing this in the future. - return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload)) + p := pp.Payload() + if len(p) < 1 || p[0] != byte(TSMPTypeDiscoKeyRequest) { + return } + return h, true +} - return Generate(iph, payload), nil +// TSMPDiscoKeyUpdate is a TSMP message that contains a disco public key. +// It may be sent in response to a request, or unsolicited when a node +// believes its peer may have stale disco key information. +// +// On the wire, after the IP header, it's currently 33 bytes: +// - 'D' (TSMPTypeDiscoKeyUpdate) +// - 32 bytes disco public key +type TSMPDiscoKeyUpdate struct { + IPHeader Header + DiscoKey [32]byte // raw disco public key bytes } -func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok bool) { +// AsTSMPDiscoKeyUpdate returns pp as a TSMPDiscoKeyUpdate and whether it is one. +// The update.IPHeader field is not populated. +func (pp *Parsed) AsTSMPDiscoKeyUpdate() (update TSMPDiscoKeyUpdate, ok bool) { if pp.IPProto != ipproto.TSMP { return } p := pp.Payload() - if len(p) < 33 || p[0] != byte(TSMPTypeDiscoAdvertisement) { + if len(p) < 33 || p[0] != byte(TSMPTypeDiscoKeyUpdate) { return } - tka.Src = pp.Src.Addr() - tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33])) + copy(update.DiscoKey[:], p[1:33]) + return update, true +} - return tka, true +func (h TSMPDiscoKeyUpdate) Len() int { + return h.IPHeader.Len() + 33 +} + +func (h TSMPDiscoKeyUpdate) Marshal(buf []byte) error { + if len(buf) < h.Len() { + return errSmallBuffer + } + if err := h.IPHeader.Marshal(buf); err != nil { + return err + } + buf = buf[h.IPHeader.Len():] + buf[0] = byte(TSMPTypeDiscoKeyUpdate) + copy(buf[1:33], h.DiscoKey[:]) + return nil } diff --git a/net/packet/tsmp_test.go b/net/packet/tsmp_test.go index d8f1d38d5..52e829d70 100644 --- a/net/packet/tsmp_test.go +++ b/net/packet/tsmp_test.go @@ -4,14 +4,8 @@ package packet import ( - "bytes" - "encoding/hex" "net/netip" - "slices" "testing" - - "go4.org/mem" - "tailscale.com/types/key" ) func TestTailscaleRejectedHeader(t *testing.T) { @@ -78,61 +72,168 @@ func TestTailscaleRejectedHeader(t *testing.T) { } } -func TestTSMPDiscoKeyAdvertisementMarshal(t *testing.T) { - var ( - // IPv4: Ver(4)Len(5), TOS, Len(53), ID, Flags, TTL(64), Proto(99), Cksum - headerV4, _ = hex.DecodeString("45000035000000004063705d") - // IPv6: Ver(6)TCFlow, Len(33), NextHdr(99), HopLim(64) - headerV6, _ = hex.DecodeString("6000000000216340") - - packetType = []byte{'a'} - testKey = bytes.Repeat([]byte{'a'}, 32) - - // IPs - srcV4 = netip.MustParseAddr("1.2.3.4") - dstV4 = netip.MustParseAddr("4.3.2.1") - srcV6 = netip.MustParseAddr("2001:db8::1") - dstV6 = netip.MustParseAddr("2001:db8::2") - ) - - join := func(parts ...[]byte) []byte { - return bytes.Join(parts, nil) +func TestTSMPDiscoKeyRequest(t *testing.T) { + t.Run("Manual", func(t *testing.T) { + var payload [1]byte + payload[0] = byte(TSMPTypeDiscoKeyRequest) + + var p Parsed + p.IPProto = TSMP + p.dataofs = 40 // simulate after IP header + buf := make([]byte, 40+1) + copy(buf[40:], payload[:]) + p.b = buf + p.length = len(buf) + + _, ok := p.AsTSMPDiscoKeyRequest() + if !ok { + t.Fatal("failed to parse TSMP disco key request") + } + }) + + t.Run("RoundTripIPv4", func(t *testing.T) { + src := netip.MustParseAddr("100.64.0.1") + dst := netip.MustParseAddr("100.64.0.2") + + iph := IP4Header{ + IPProto: TSMP, + Src: src, + Dst: dst, + } + + var payload [1]byte + payload[0] = byte(TSMPTypeDiscoKeyRequest) + + pkt := Generate(iph, payload[:]) + t.Logf("Generated packet: %d bytes, hex: %x", len(pkt), pkt) + + // Manually check what decode4 would see + if len(pkt) >= 4 { + declaredLen := int(uint16(pkt[2])<<8 | uint16(pkt[3])) + t.Logf("Packet buffer length: %d, IP header declares length: %d", len(pkt), declaredLen) + t.Logf("Protocol byte at [9]: 0x%02x = %d", pkt[9], pkt[9]) + } + + var p Parsed + p.Decode(pkt) + t.Logf("Decoded: IPVersion=%d IPProto=%v Src=%v Dst=%v length=%d dataofs=%d", + p.IPVersion, p.IPProto, p.Src, p.Dst, p.length, p.dataofs) + + if p.IPVersion != 4 { + t.Errorf("IPVersion = %d, want 4", p.IPVersion) + } + if p.IPProto != TSMP { + t.Errorf("IPProto = %v, want TSMP", p.IPProto) + } + if p.Src.Addr() != src { + t.Errorf("Src = %v, want %v", p.Src.Addr(), src) + } + if p.Dst.Addr() != dst { + t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst) + } + + _, ok := p.AsTSMPDiscoKeyRequest() + if !ok { + t.Fatal("failed to parse TSMP disco key request from generated packet") + } + }) + + t.Run("RoundTripIPv6", func(t *testing.T) { + src := netip.MustParseAddr("2001:db8::1") + dst := netip.MustParseAddr("2001:db8::2") + + iph := IP6Header{ + IPProto: TSMP, + Src: src, + Dst: dst, + } + + var payload [1]byte + payload[0] = byte(TSMPTypeDiscoKeyRequest) + + pkt := Generate(iph, payload[:]) + t.Logf("Generated packet: %d bytes", len(pkt)) + + var p Parsed + p.Decode(pkt) + + if p.IPVersion != 6 { + t.Errorf("IPVersion = %d, want 6", p.IPVersion) + } + if p.IPProto != TSMP { + t.Errorf("IPProto = %v, want TSMP", p.IPProto) + } + if p.Src.Addr() != src { + t.Errorf("Src = %v, want %v", p.Src.Addr(), src) + } + if p.Dst.Addr() != dst { + t.Errorf("Dst = %v, want %v", p.Dst.Addr(), dst) + } + + _, ok := p.AsTSMPDiscoKeyRequest() + if !ok { + t.Fatal("failed to parse TSMP disco key request from generated packet") + } + }) +} + +func TestTSMPDiscoKeyUpdate(t *testing.T) { + var discoKey [32]byte + for i := range discoKey { + discoKey[i] = byte(i + 10) } - tests := []struct { - name string - tka TSMPDiscoKeyAdvertisement - want []byte - }{ - { - name: "v4Header", - tka: TSMPDiscoKeyAdvertisement{ - Src: srcV4, - Dst: dstV4, - Key: key.DiscoPublicFromRaw32(mem.B(testKey)), + t.Run("IPv4", func(t *testing.T) { + update := TSMPDiscoKeyUpdate{ + IPHeader: IP4Header{ + IPProto: TSMP, + Src: netip.MustParseAddr("1.2.3.4"), + Dst: netip.MustParseAddr("5.6.7.8"), }, - want: join(headerV4, srcV4.AsSlice(), dstV4.AsSlice(), packetType, testKey), - }, - { - name: "v6Header", - tka: TSMPDiscoKeyAdvertisement{ - Src: srcV6, - Dst: dstV6, - Key: key.DiscoPublicFromRaw32(mem.B(testKey)), + DiscoKey: discoKey, + } + + pkt := make([]byte, update.Len()) + if err := update.Marshal(pkt); err != nil { + t.Fatal(err) + } + + var p Parsed + p.Decode(pkt) + + parsed, ok := p.AsTSMPDiscoKeyUpdate() + if !ok { + t.Fatal("failed to parse TSMP disco key update") + } + if parsed.DiscoKey != discoKey { + t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey) + } + }) + + t.Run("IPv6", func(t *testing.T) { + update := TSMPDiscoKeyUpdate{ + IPHeader: IP6Header{ + IPProto: TSMP, + Src: netip.MustParseAddr("2001:db8::1"), + Dst: netip.MustParseAddr("2001:db8::2"), }, - want: join(headerV6, srcV6.AsSlice(), dstV6.AsSlice(), packetType, testKey), - }, - } + DiscoKey: discoKey, + } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.tka.Marshal() - if err != nil { - t.Errorf("error mashalling TSMPDiscoAdvertisement: %s", err) - } - if !slices.Equal(got, tt.want) { - t.Errorf("error mashalling TSMPDiscoAdvertisement, expected: \n%x, \ngot:\n%x", tt.want, got) - } - }) - } + pkt := make([]byte, update.Len()) + if err := update.Marshal(pkt); err != nil { + t.Fatal(err) + } + + var p Parsed + p.Decode(pkt) + + parsed, ok := p.AsTSMPDiscoKeyUpdate() + if !ok { + t.Fatal("failed to parse TSMP disco key update") + } + if parsed.DiscoKey != discoKey { + t.Errorf("disco key mismatch: got %v, want %v", parsed.DiscoKey, discoKey) + } + }) } diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 6e07c7a3d..0619276ae 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -194,6 +194,10 @@ type Wrapper struct { // false otherwise. OnICMPEchoResponseReceived func(*packet.Parsed) bool + // GetDiscoPublicKey, if non-nil, returns the local node's disco public key. + // This is called when responding to TSMP disco key requests. + GetDiscoPublicKey func() key.DiscoPublic + // PeerAPIPort, if non-nil, returns the peerapi port that's // running for the given IP address. PeerAPIPort func(netip.Addr) (port uint16, ok bool) @@ -211,8 +215,8 @@ type Wrapper struct { metrics *metrics - eventClient *eventbus.Client - discoKeyAdvertisementPub *eventbus.Publisher[DiscoKeyAdvertisement] + eventClient *eventbus.Client + discoKeyUpdatePub *eventbus.Publisher[DiscoKeyUpdate] } type metrics struct { @@ -227,6 +231,12 @@ func registerMetrics(reg *usermetric.Registry) *metrics { } } +// DiscoKeyUpdate is published on the event bus when a TSMP disco key update is received. +type DiscoKeyUpdate struct { + SrcIP netip.Addr + Key [32]byte +} + // tunInjectedRead is an injected packet pretending to be a tun.Read(). type tunInjectedRead struct { // Only one of packet or data should be set, and are read in that order of @@ -288,7 +298,7 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry, } w.eventClient = bus.Client("net.tstun") - w.discoKeyAdvertisementPub = eventbus.Publish[DiscoKeyAdvertisement](w.eventClient) + w.discoKeyUpdatePub = eventbus.Publish[DiscoKeyUpdate](w.eventClient) w.vectorBuffer = make([][]byte, tdev.BatchSize()) for i := range w.vectorBuffer { @@ -1126,11 +1136,6 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i return n, err } -type DiscoKeyAdvertisement struct { - Src netip.Addr - Key key.DiscoPublic -} - func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) { if captHook != nil { captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta) @@ -1141,16 +1146,21 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa t.noteActivity() t.injectOutboundPong(p, pingReq) return filter.DropSilently, gro - } else if discoKeyAdvert, ok := p.AsTSMPDiscoAdvertisement(); ok { - t.discoKeyAdvertisementPub.Publish(DiscoKeyAdvertisement{ - Src: discoKeyAdvert.Src, - Key: discoKeyAdvert.Key, - }) - return filter.DropSilently, gro } else if data, ok := p.AsTSMPPong(); ok { if f := t.OnTSMPPongReceived; f != nil { f(data) } + } else if _, ok := p.AsTSMPDiscoKeyRequest(); ok { + t.noteActivity() + t.injectOutboundDiscoKeyUpdate(p) + return filter.DropSilently, gro + } else if discoKeyUpdate, ok := p.AsTSMPDiscoKeyUpdate(); ok { + // Publish to eventbus for subscribers + t.discoKeyUpdatePub.Publish(DiscoKeyUpdate{ + SrcIP: p.Src.Addr(), + Key: discoKeyUpdate.DiscoKey, + }) + return filter.DropSilently, gro } } @@ -1459,6 +1469,36 @@ func (t *Wrapper) injectOutboundPong(pp *packet.Parsed, req packet.TSMPPingReque t.InjectOutbound(packet.Generate(pong, nil)) } +func (t *Wrapper) injectOutboundDiscoKeyUpdate(pp *packet.Parsed) { + if t.GetDiscoPublicKey == nil { + return + } + + discoKey := t.GetDiscoPublicKey() + if discoKey.IsZero() { + return + } + + update := packet.TSMPDiscoKeyUpdate{ + DiscoKey: discoKey.Raw32(), + } + + switch pp.IPVersion { + case 4: + h4 := pp.IP4Header() + h4.ToResponse() + update.IPHeader = h4 + case 6: + h6 := pp.IP6Header() + h6.ToResponse() + update.IPHeader = h6 + default: + return + } + + t.InjectOutbound(packet.Generate(update, nil)) +} + // InjectOutbound makes the Wrapper device behave as if a packet // with the given contents was sent to the network. // It does not block, but takes ownership of the packet. diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index c7d0708df..2d33228b8 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -966,28 +966,57 @@ func TestCaptureHook(t *testing.T) { } func TestTSMPDisco(t *testing.T) { - t.Run("IPv6DiscoAdvert", func(t *testing.T) { + t.Run("DiscoKeyRequest", func(t *testing.T) { src := netip.MustParseAddr("2001:db8::1") dst := netip.MustParseAddr("2001:db8::2") - discoKey := key.NewDisco() - buf, _ := (&packet.TSMPDiscoKeyAdvertisement{ - Src: src, - Dst: dst, - Key: discoKey.Public(), - }).Marshal() + + iph := packet.IP6Header{ + IPProto: ipproto.TSMP, + Src: src, + Dst: dst, + } + + var payload [1]byte + payload[0] = byte(packet.TSMPTypeDiscoKeyRequest) + buf := packet.Generate(iph, payload[:]) var p packet.Parsed p.Decode(buf) - tda, ok := p.AsTSMPDiscoAdvertisement() + _, ok := p.AsTSMPDiscoKeyRequest() if !ok { - t.Error("Unable to parse message as TSMPDiscoAdversitement") + t.Error("Unable to parse message as TSMPDiscoKeyRequest") } - if tda.Src != src { - t.Errorf("Src address did not match, expected %v, got %v", src, tda.Src) + }) + + t.Run("DiscoKeyUpdate", func(t *testing.T) { + src := netip.MustParseAddr("2001:db8::1") + dst := netip.MustParseAddr("2001:db8::2") + discoKey := key.NewDisco() + + update := packet.TSMPDiscoKeyUpdate{ + IPHeader: packet.IP6Header{ + IPProto: ipproto.TSMP, + Src: src, + Dst: dst, + }, + DiscoKey: discoKey.Public().Raw32(), + } + + buf := make([]byte, update.Len()) + if err := update.Marshal(buf); err != nil { + t.Fatal(err) + } + + var p packet.Parsed + p.Decode(buf) + + parsed, ok := p.AsTSMPDiscoKeyUpdate() + if !ok { + t.Error("Unable to parse message as TSMPDiscoKeyUpdate") } - if !reflect.DeepEqual(tda.Key, discoKey.Public()) { - t.Errorf("Key did not match, expected %q, got %q", discoKey.Public(), tda.Key) + if parsed.DiscoKey != discoKey.Public().Raw32() { + t.Errorf("Key did not match, expected %v, got %v", discoKey.Public().Raw32(), parsed.DiscoKey) } }) } diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index 37a4f1a64..e20acadc4 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -721,6 +721,16 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en update(0, netip.AddrPortFrom(ep.nodeAddr, 0), srcAddr.ap, 1, dm.n, true) } + // Request disco key from peer via TSMP if we receive a WireGuard handshake + // over DERP without recent disco success. This handles the "WireGuard-first" + // case where WireGuard establishes a tunnel via DERP before disco succeeds + // (e.g., control plane unreachable or stale disco keys). + if looksLikeWireGuardHandshake(b[:n]) && n > 0 { + c.mu.Lock() + c.requestDiscoKeyViaTSMPLocked(dm.src, ep) + c.mu.Unlock() + } + c.metrics.inboundPacketsDERPTotal.Add(1) c.metrics.inboundBytesDERPTotal.Add(int64(n)) return n, ep diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 064838a2d..da4f5c47d 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -178,9 +178,10 @@ type Conn struct { // A publisher for synchronization points to ensure correct ordering of // config changes between magicsock and wireguard. - syncPub *eventbus.Publisher[syncPoint] - allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq] - portUpdatePub *eventbus.Publisher[router.PortUpdate] + syncPub *eventbus.Publisher[syncPoint] + allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq] + portUpdatePub *eventbus.Publisher[router.PortUpdate] + tsmpDiscoKeyRequestPub *eventbus.Publisher[TSMPDiscoKeyRequest] // pconn4 and pconn6 are the underlying UDP sockets used to // send/receive packets for wireguard and other magicsock @@ -572,6 +573,14 @@ type UDPRelayAllocReq struct { Message *disco.AllocateUDPRelayEndpointRequest } +// TSMPDiscoKeyRequest is published on the event bus when magicsock needs to +// send a TSMP disco key request to a peer. Subscribers should inject the +// TSMP packet into the tunnel device. +type TSMPDiscoKeyRequest struct { + DstIP netip.Addr + MetricSent *clientmetric.Metric +} + // UDPRelayAllocResp represents a [*disco.AllocateUDPRelayEndpointResponse] // that is yet to be transmitted over DERP (or delivered locally if // ReqRxFromNodeKey is self). This is signaled over an [eventbus.Bus] from @@ -691,6 +700,7 @@ func NewConn(opts Options) (*Conn, error) { c.syncPub = eventbus.Publish[syncPoint](ec) c.allocRelayEndpointPub = eventbus.Publish[UDPRelayAllocReq](ec) c.portUpdatePub = eventbus.Publish[router.PortUpdate](ec) + c.tsmpDiscoKeyRequestPub = eventbus.Publish[TSMPDiscoKeyRequest](ec) eventbus.SubscribeFunc(ec, c.onPortMapChanged) eventbus.SubscribeFunc(ec, c.onFilterUpdate) eventbus.SubscribeFunc(ec, c.onNodeViewsUpdate) @@ -1800,6 +1810,15 @@ func looksLikeInitiationMsg(b []byte) bool { binary.LittleEndian.Uint32(b) == device.MessageInitiationType } +func looksLikeWireGuardHandshake(b []byte) bool { + if len(b) < 4 { + return false + } + msgType := binary.LittleEndian.Uint32(b) + return (len(b) == device.MessageInitiationSize && msgType == device.MessageInitiationType) || + (len(b) == device.MessageResponseSize && msgType == device.MessageResponseType) +} + // receiveIP is the shared bits of ReceiveIPv4 and ReceiveIPv6. // // size is the length of 'b' to report up to wireguard-go (only relevant if @@ -4104,6 +4123,12 @@ var ( metricUDPLifetimeCycleCompleteAt10sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_10s_cliff") metricUDPLifetimeCycleCompleteAt30sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_30s_cliff") metricUDPLifetimeCycleCompleteAt60sCliff = newUDPLifetimeCounter("magicsock_udp_lifetime_cycle_complete_at_60s_cliff") + + // TSMP disco key exchange + metricTSMPDiscoKeyRequestSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_request_sent") + metricTSMPDiscoKeyUpdateReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_received") + metricTSMPDiscoKeyUpdateApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_applied") + metricTSMPDiscoKeyUpdateUnknown = clientmetric.NewCounter("magicsock_tsmp_disco_key_update_unknown_peer") ) // newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided @@ -4242,6 +4267,81 @@ func (le *lazyEndpoint) FromPeer(peerPublicKey [32]byte) { // See http://go/corp/29422 & http://go/corp/30042 le.c.peerMap.setNodeKeyForEpAddr(le.src, pubKey) le.c.logf("magicsock: lazyEndpoint.FromPeer(%v) setting epAddr(%v) in peerMap for node(%v)", pubKey.ShortString(), le.src, ep.nodeAddr) + + le.c.requestDiscoKeyViaTSMPLocked(pubKey, ep) +} + +// requestDiscoKeyViaTSMPLocked sends a TSMP disco key request to a peer if there +// hasn't been a recent disco ping. +// c.mu must be held. +func (c *Conn) requestDiscoKeyViaTSMPLocked(nodeKey key.NodePublic, ep *endpoint) { + if !ep.nodeAddr.IsValid() { + return + } + + epDisco := ep.disco.Load() + if epDisco != nil { + di := c.discoInfo[epDisco.key] + recentDiscoPing := di != nil && time.Since(di.lastPingTime) < discoPingInterval + + if recentDiscoPing { + return + } + } + + go c.tsmpDiscoKeyRequestPub.Publish(TSMPDiscoKeyRequest{DstIP: ep.nodeAddr, MetricSent: metricTSMPDiscoKeyRequestSent}) +} + +// HandleDiscoKeyUpdate processes a TSMP disco key update. +// The update may be solicited (in response to a request) or unsolicited. +// srcIP is the Tailscale IP of the peer that sent the update. +func (c *Conn) HandleDiscoKeyUpdate(srcIP netip.Addr, update packet.TSMPDiscoKeyUpdate) { + discoKey := key.DiscoPublicFromRaw32(mem.B(update.DiscoKey[:])) + c.logf("magicsock: received disco key update %v from %v", discoKey.ShortString(), srcIP) + metricTSMPDiscoKeyUpdateReceived.Add(1) + + c.mu.Lock() + defer c.mu.Unlock() + + var nodeKey key.NodePublic + var found bool + for _, peer := range c.peers.All() { + for _, addr := range peer.Addresses().All() { + if addr.Addr() == srcIP { + nodeKey = peer.Key() + found = true + break + } + } + if found { + break + } + } + + if !found { + c.logf("magicsock: disco key update from unknown peer %v", srcIP) + metricTSMPDiscoKeyUpdateUnknown.Add(1) + return + } + + ep, ok := c.peerMap.endpointForNodeKey(nodeKey) + if !ok { + c.logf("magicsock: endpoint not found for node %v", nodeKey.ShortString()) + return + } + + oldDiscoKey := key.DiscoPublic{} + if epDisco := ep.disco.Load(); epDisco != nil { + oldDiscoKey = epDisco.key + } + c.discoInfoForKnownPeerLocked(discoKey) + ep.disco.Store(&endpointDisco{ + key: discoKey, + short: discoKey.ShortString(), + }) + c.peerMap.upsertEndpoint(ep, oldDiscoKey) + c.logf("magicsock: updated disco key for peer %v to %v", nodeKey.ShortString(), discoKey.ShortString()) + metricTSMPDiscoKeyUpdateApplied.Add(1) } // PeerRelays returns the current set of candidate peer relays. diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 4e1024886..b7694ba56 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -64,6 +64,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/types/nettype" "tailscale.com/types/ptr" + "tailscale.com/types/views" "tailscale.com/util/cibuild" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" @@ -4302,3 +4303,66 @@ func TestRotateDiscoKeyMultipleTimes(t *testing.T) { keys = append(keys, newKey) } } + +func TestSendTSMPDiscoKeyRequest(t *testing.T) { + ep := &endpoint{ + nodeID: 1, + publicKey: key.NewNode().Public(), + nodeAddr: netip.MustParseAddr("100.64.0.1"), + } + discoKey := key.NewDisco().Public() + ep.disco.Store(&endpointDisco{ + key: discoKey, + short: discoKey.ShortString(), + }) + bus := eventbustest.NewBus(t) + conn := newConn(t.Logf) + conn.eventBus = bus + conn.eventClient = bus.Client("magicsock.Conn.test") + conn.tsmpDiscoKeyRequestPub = eventbus.Publish[TSMPDiscoKeyRequest](conn.eventClient) + ep.c = conn + + tsmpRequestCalled := make(chan struct{}, 1) + var capturedIP netip.Addr + ec := bus.Client("test") + defer ec.Close() + eventbus.SubscribeFunc(ec, func(req TSMPDiscoKeyRequest) { + capturedIP = req.DstIP + if req.MetricSent != nil { + req.MetricSent.Add(1) + } + tsmpRequestCalled <- struct{}{} + }) + + conn.mu.Lock() + conn.peers = views.SliceOf([]tailcfg.NodeView{ + (&tailcfg.Node{ + Key: ep.publicKey, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + }, + }).View(), + }) + conn.mu.Unlock() + + var pubKey [32]byte + copy(pubKey[:], ep.publicKey.AppendTo(nil)) + conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) + + le := &lazyEndpoint{ + c: conn, + src: epAddr{ap: netip.MustParseAddrPort("127.0.0.1:7777")}, + } + + le.FromPeer(pubKey) + + select { + case <-tsmpRequestCalled: + if !capturedIP.IsValid() { + t.Error("TSMP request sent with invalid IP") + } + t.Logf("TSMP disco key request sent to %v", capturedIP) + case <-time.After(time.Second): + t.Error("TSMP disco key request was not sent") + } +} diff --git a/wgengine/userspace.go b/wgengine/userspace.go index a369fa343..c0e79633a 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -54,6 +54,7 @@ import ( "tailscale.com/util/execqueue" "tailscale.com/util/mak" "tailscale.com/util/set" + "tailscale.com/util/singleflight" "tailscale.com/util/testenv" "tailscale.com/util/usermetric" "tailscale.com/version" @@ -469,6 +470,13 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) return true } + e.tundev.GetDiscoPublicKey = func() key.DiscoPublic { + if e.magicConn == nil { + return key.DiscoPublic{} + } + return e.magicConn.DiscoPublicKey() + } + // wgdev takes ownership of tundev, will close it when closed. e.logf("Creating WireGuard device...") e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger) @@ -549,6 +557,36 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) } e.linkChangeQueue.Add(func() { e.linkChange(&cd) }) }) + eventbus.SubscribeFunc(ec, func(update tstun.DiscoKeyUpdate) { + e.logf("wgengine: got TSMP disco key update from %v via eventbus", update.SrcIP) + if e.magicConn != nil { + pkt := packet.TSMPDiscoKeyUpdate{ + DiscoKey: update.Key, + } + e.magicConn.HandleDiscoKeyUpdate(update.SrcIP, pkt) + } + }) + var tsmpRequestGroup singleflight.Group[netip.Addr, struct{}] + eventbus.SubscribeFunc(ec, func(req magicsock.TSMPDiscoKeyRequest) { + go tsmpRequestGroup.Do(req.DstIP, func() (struct{}, error) { + // DiscoKeyRequests are triggered by an incoming WireGuard handshake + // initiation arriving before a disco ping, which is a likely + // indicator that disco pings failed due to a lack of key + // synchronization. If the requests are sent immediately, before the + // handshake state is accepted in the WireGuard client state + // machine, this starts a new session, and the two peer state + // machines conflict, causing loss and additional delays. Delaying + // the send avoids this, so coalesce duplicate sends, and delay them + // by a short time to avoid the state machine conflict. + time.Sleep(time.Millisecond) + if err := e.sendTSMPDiscoKeyRequest(req.DstIP); err != nil { + e.logf("wgengine: failed to send TSMP disco key request: %v", err) + } + e.logf("wgengine: sending TSMP disco key request to %v", req.DstIP) + req.MetricSent.Add(1) + return struct{}{}, nil + }) + }) e.eventClient = ec e.logf("Engine created.") return e, nil @@ -1436,7 +1474,6 @@ func (e *userspaceEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size in e.magicConn.Ping(peer, res, size, cb) case "TSMP": e.sendTSMPPing(ip, peer, res, cb) - e.sendTSMPDiscoAdvertisement(ip) case "ICMP": e.sendICMPEchoRequest(ip, peer, res, cb) } @@ -1557,29 +1594,6 @@ func (e *userspaceEngine) sendTSMPPing(ip netip.Addr, peer tailcfg.NodeView, res e.tundev.InjectOutbound(tsmpPing) } -func (e *userspaceEngine) sendTSMPDiscoAdvertisement(ip netip.Addr) { - srcIP, err := e.mySelfIPMatchingFamily(ip) - if err != nil { - e.logf("getting matching node: %s", err) - return - } - tdka := packet.TSMPDiscoKeyAdvertisement{ - Src: srcIP, - Dst: ip, - Key: e.magicConn.DiscoPublicKey(), - } - payload, err := tdka.Marshal() - if err != nil { - e.logf("error generating TSMP Advertisement: %s", err) - metricTSMPDiscoKeyAdvertisementError.Add(1) - } else if err := e.tundev.InjectOutbound(payload); err != nil { - e.logf("error sending TSMP Advertisement: %s", err) - metricTSMPDiscoKeyAdvertisementError.Add(1) - } else { - metricTSMPDiscoKeyAdvertisementSent.Add(1) - } -} - func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPPongReply)) { e.mu.Lock() defer e.mu.Unlock() @@ -1593,6 +1607,35 @@ func (e *userspaceEngine) setTSMPPongCallback(data [8]byte, cb func(packet.TSMPP } } +// sendTSMPDiscoKeyRequest sends a TSMP disco key request to the given peer IP. +func (e *userspaceEngine) sendTSMPDiscoKeyRequest(ip netip.Addr) error { + srcIP, err := e.mySelfIPMatchingFamily(ip) + if err != nil { + return err + } + + var iph packet.Header + if srcIP.Is4() { + iph = packet.IP4Header{ + IPProto: ipproto.TSMP, + Src: srcIP, + Dst: ip, + } + } else { + iph = packet.IP6Header{ + IPProto: ipproto.TSMP, + Src: srcIP, + Dst: ip, + } + } + + var tsmpPayload [1]byte + tsmpPayload[0] = byte(packet.TSMPTypeDiscoKeyRequest) + + tsmpRequest := packet.Generate(iph, tsmpPayload[:]) + return e.tundev.InjectOutbound(tsmpRequest) +} + func (e *userspaceEngine) setICMPEchoResponseCallback(idSeq uint32, cb func()) { e.mu.Lock() defer e.mu.Unlock() @@ -1746,9 +1789,6 @@ var ( metricNumMajorChanges = clientmetric.NewCounter("wgengine_major_changes") metricNumMinorChanges = clientmetric.NewCounter("wgengine_minor_changes") - - metricTSMPDiscoKeyAdvertisementSent = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_sent") - metricTSMPDiscoKeyAdvertisementError = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_error") ) func (e *userspaceEngine) InstallCaptureHook(cb packet.CaptureCallback) { diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index 0a1d2924d..abcf2f64f 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -325,7 +325,7 @@ func TestUserspaceEnginePeerMTUReconfig(t *testing.T) { } } -func TestTSMPKeyAdvertisement(t *testing.T) { +func TestTSMPDiscoKeyRequest(t *testing.T) { var knobs controlknobs.Knobs bus := eventbustest.NewBus(t) @@ -369,13 +369,12 @@ func TestTSMPKeyAdvertisement(t *testing.T) { t.Fatal(err) } - addr := netip.MustParseAddr("100.100.99.1") - previousValue := metricTSMPDiscoKeyAdvertisementSent.Value() - ue.sendTSMPDiscoAdvertisement(addr) - if val := metricTSMPDiscoKeyAdvertisementSent.Value(); val <= previousValue { - errs := metricTSMPDiscoKeyAdvertisementError.Value() - t.Errorf("Expected 1 disco key advert, got %d, errors %d", val, errs) + peerAddr := netip.MustParseAddr("100.100.99.1") + err = ue.sendTSMPDiscoKeyRequest(peerAddr) + if err != nil { + t.Fatalf("sendTSMPDiscoKeyRequest failed: %v", err) } + // Remove config to have the engine shut down more consistently err = ue.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}) if err != nil {