diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index da756150a..74c28853e 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -171,9 +171,10 @@ func (m *peerMap) forEachEndpoint(f func(ep *endpoint)) { } } -// forEachEndpointWithDiscoKey invokes f on every endpoint in m -// that has the provided DiscoKey. -func (m *peerMap) forEachEndpointWithDiscoKey(dk key.DiscoPublic, f func(ep *endpoint)) { +// forEachEndpointWithDiscoKey invokes f on every endpoint in m that has the +// provided DiscoKey until f returns false or there are no endpoints left to +// iterate. +func (m *peerMap) forEachEndpointWithDiscoKey(dk key.DiscoPublic, f func(*endpoint) (keepGoing bool)) { for nk := range m.nodesOfDisco[dk] { pi, ok := m.byNodeKey[nk] if !ok { @@ -184,7 +185,9 @@ func (m *peerMap) forEachEndpointWithDiscoKey(dk key.DiscoPublic, f func(ep *end // into Conn. continue } - f(pi.ep) + if !f(pi.ep) { + return + } } } @@ -1781,7 +1784,7 @@ func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in reportToCaller := false for i, msg := range batch.msgs[:numMsgs] { ipp := msg.Addr.(*net.UDPAddr).AddrPort() - if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6, c.closeDisco6 == nil); ok { + if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6); ok { metricRecvDataIPv6.Add(1) eps[i] = ep sizes[i] = msg.N @@ -1819,7 +1822,7 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in reportToCaller := false for i, msg := range batch.msgs[:numMsgs] { ipp := msg.Addr.(*net.UDPAddr).AddrPort() - if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4, c.closeDisco4 == nil); ok { + if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4); ok { metricRecvDataIPv4.Add(1) eps[i] = ep sizes[i] = msg.N @@ -1838,18 +1841,12 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in // // ok is whether this read should be reported up to wireguard-go (our // caller). -func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache, checkDisco bool) (ep *endpoint, ok bool) { +func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) (ep *endpoint, ok bool) { if stun.Is(b) { c.stunReceiveFunc.Load()(b, ipp) return nil, false } - if checkDisco { - if c.handleDiscoMessage(b, ipp, key.NodePublic{}) { - return nil, false - } - } else if disco.LooksLikeDiscoWrapper(b) { - // Caller told us to ignore disco traffic, don't let it fall - // through to wireguard-go. + if c.handleDiscoMessage(b, ipp, key.NodePublic{}) { return nil, false } if !c.havePrivateKey.Load() { @@ -2132,11 +2129,11 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke // There might be multiple nodes for the sender's DiscoKey. // Ask each to handle it, stopping once one reports that // the Pong's TxID was theirs. - handled := false - c.peerMap.forEachEndpointWithDiscoKey(sender, func(ep *endpoint) { - if !handled && ep.handlePongConnLocked(dm, di, src) { - handled = true + c.peerMap.forEachEndpointWithDiscoKey(sender, func(ep *endpoint) (keepGoing bool) { + if ep.handlePongConnLocked(dm, di, src) { + return false } + return true }) case *disco.CallMeMaybe: metricRecvDiscoCallMeMaybe.Add(1) @@ -2230,19 +2227,29 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInf // Remember this route if not present. var numNodes int + var dup bool if isDerp { if ep, ok := c.peerMap.endpointForNodeKey(derpNodeSrc); ok { - ep.addCandidateEndpoint(src) + if ep.addCandidateEndpoint(src, dm.TxID) { + return + } numNodes = 1 } } else { - c.peerMap.forEachEndpointWithDiscoKey(di.discoKey, func(ep *endpoint) { - ep.addCandidateEndpoint(src) + c.peerMap.forEachEndpointWithDiscoKey(di.discoKey, func(ep *endpoint) (keepGoing bool) { + if ep.addCandidateEndpoint(src, dm.TxID) { + dup = true + return false + } numNodes++ if numNodes == 1 && dstKey.IsZero() { dstKey = ep.publicKey } + return true }) + if dup { + return + } if numNodes > 1 { // Zero it out if it's ambiguous, so sendDiscoMessage logging // isn't confusing. @@ -3625,7 +3632,7 @@ func ippDebugString(ua netip.AddrPort) string { // recalculated. type endpointSendFunc func([][]byte) error -// discoEndpoint is a wireguard/conn.Endpoint that picks the best +// endpoint is a wireguard/conn.Endpoint that picks the best // available path to communicate with a peer, based on network // conditions and what the peer supports. type endpoint struct { @@ -3740,6 +3747,12 @@ type endpointState struct { // updated and use it to discard old candidates. lastGotPing time.Time + // lastGotPingTxID, if lastGotPing is non-zero, contains the TxID for the + // last incoming ping. This is used to de-dup incoming pings that we may + // see on both the raw disco socket on Linux, and UDP socket. We cannot rely + // solely on the raw socket disco handling due to https://github.com/tailscale/tailscale/issues/7078. + lastGotPingTxID stun.TxID + // callMeMaybeTime, if non-zero, is the time this endpoint // was advertised last via a call-me-maybe disco message. callMeMaybeTime time.Time @@ -4195,27 +4208,34 @@ func (de *endpoint) updateFromNode(n *tailcfg.Node, heartbeatDisabled bool) { } // addCandidateEndpoint adds ep as an endpoint to which we should send -// future pings. +// future pings. If there is an existing endpointState for ep, and forRxPingTxID +// matches the last received ping TxID, this function reports true, otherwise +// false. // // This is called once we've already verified that we got a valid // discovery message from de via ep. -func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort) { +func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.TxID) (duplicatePing bool) { de.mu.Lock() defer de.mu.Unlock() if st, ok := de.endpointState[ep]; ok { if st.lastGotPing.IsZero() { // Already-known endpoint from the network map. - return + return false + } + if forRxPingTxID == st.lastGotPingTxID { + return true } st.lastGotPing = time.Now() - return + st.lastGotPingTxID = forRxPingTxID + return false } // Newly discovered endpoint. Exciting! de.c.dlogf("[v1] magicsock: disco: adding %v as candidate endpoint for %v (%s)", ep, de.discoShort, de.publicKey.ShortString()) de.endpointState[ep] = &endpointState{ - lastGotPing: time.Now(), + lastGotPing: time.Now(), + lastGotPingTxID: forRxPingTxID, } // If for some reason this gets very large, do some cleanup. @@ -4228,6 +4248,7 @@ func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort) { size2 := len(de.endpointState) de.c.dlogf("[v1] magicsock: disco: addCandidateEndpoint pruned %v candidate set from %v to %v entries", size, size2) } + return false } // noteConnectivityChange is called when connectivity changes enough