wgengine/magicsock: de-dup disco pings (#7093)

Fixes #7078

Signed-off-by: Jordan Whited <jordan@tailscale.com>
pull/7102/head
Jordan Whited 2 years ago committed by GitHub
parent 0dc9cbc9ab
commit 7921198c05
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -171,9 +171,10 @@ func (m *peerMap) forEachEndpoint(f func(ep *endpoint)) {
} }
} }
// forEachEndpointWithDiscoKey invokes f on every endpoint in m // forEachEndpointWithDiscoKey invokes f on every endpoint in m that has the
// that has the provided DiscoKey. // provided DiscoKey until f returns false or there are no endpoints left to
func (m *peerMap) forEachEndpointWithDiscoKey(dk key.DiscoPublic, f func(ep *endpoint)) { // iterate.
func (m *peerMap) forEachEndpointWithDiscoKey(dk key.DiscoPublic, f func(*endpoint) (keepGoing bool)) {
for nk := range m.nodesOfDisco[dk] { for nk := range m.nodesOfDisco[dk] {
pi, ok := m.byNodeKey[nk] pi, ok := m.byNodeKey[nk]
if !ok { if !ok {
@ -184,7 +185,9 @@ func (m *peerMap) forEachEndpointWithDiscoKey(dk key.DiscoPublic, f func(ep *end
// into Conn. // into Conn.
continue continue
} }
f(pi.ep) if !f(pi.ep) {
return
}
} }
} }
@ -1781,7 +1784,7 @@ func (c *Conn) receiveIPv6(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in
reportToCaller := false reportToCaller := false
for i, msg := range batch.msgs[:numMsgs] { for i, msg := range batch.msgs[:numMsgs] {
ipp := msg.Addr.(*net.UDPAddr).AddrPort() ipp := msg.Addr.(*net.UDPAddr).AddrPort()
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6, c.closeDisco6 == nil); ok { if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint6); ok {
metricRecvDataIPv6.Add(1) metricRecvDataIPv6.Add(1)
eps[i] = ep eps[i] = ep
sizes[i] = msg.N sizes[i] = msg.N
@ -1819,7 +1822,7 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in
reportToCaller := false reportToCaller := false
for i, msg := range batch.msgs[:numMsgs] { for i, msg := range batch.msgs[:numMsgs] {
ipp := msg.Addr.(*net.UDPAddr).AddrPort() ipp := msg.Addr.(*net.UDPAddr).AddrPort()
if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4, c.closeDisco4 == nil); ok { if ep, ok := c.receiveIP(msg.Buffers[0][:msg.N], ipp, &c.ippEndpoint4); ok {
metricRecvDataIPv4.Add(1) metricRecvDataIPv4.Add(1)
eps[i] = ep eps[i] = ep
sizes[i] = msg.N sizes[i] = msg.N
@ -1838,20 +1841,14 @@ func (c *Conn) receiveIPv4(buffs [][]byte, sizes []int, eps []conn.Endpoint) (in
// //
// ok is whether this read should be reported up to wireguard-go (our // ok is whether this read should be reported up to wireguard-go (our
// caller). // caller).
func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache, checkDisco bool) (ep *endpoint, ok bool) { func (c *Conn) receiveIP(b []byte, ipp netip.AddrPort, cache *ippEndpointCache) (ep *endpoint, ok bool) {
if stun.Is(b) { if stun.Is(b) {
c.stunReceiveFunc.Load()(b, ipp) c.stunReceiveFunc.Load()(b, ipp)
return nil, false return nil, false
} }
if checkDisco {
if c.handleDiscoMessage(b, ipp, key.NodePublic{}) { if c.handleDiscoMessage(b, ipp, key.NodePublic{}) {
return nil, false return nil, false
} }
} else if disco.LooksLikeDiscoWrapper(b) {
// Caller told us to ignore disco traffic, don't let it fall
// through to wireguard-go.
return nil, false
}
if !c.havePrivateKey.Load() { if !c.havePrivateKey.Load() {
// If we have no private key, we're logged out or // If we have no private key, we're logged out or
// stopped. Don't try to pass these wireguard packets // stopped. Don't try to pass these wireguard packets
@ -2132,11 +2129,11 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke
// There might be multiple nodes for the sender's DiscoKey. // There might be multiple nodes for the sender's DiscoKey.
// Ask each to handle it, stopping once one reports that // Ask each to handle it, stopping once one reports that
// the Pong's TxID was theirs. // the Pong's TxID was theirs.
handled := false c.peerMap.forEachEndpointWithDiscoKey(sender, func(ep *endpoint) (keepGoing bool) {
c.peerMap.forEachEndpointWithDiscoKey(sender, func(ep *endpoint) { if ep.handlePongConnLocked(dm, di, src) {
if !handled && ep.handlePongConnLocked(dm, di, src) { return false
handled = true
} }
return true
}) })
case *disco.CallMeMaybe: case *disco.CallMeMaybe:
metricRecvDiscoCallMeMaybe.Add(1) metricRecvDiscoCallMeMaybe.Add(1)
@ -2230,19 +2227,29 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src netip.AddrPort, di *discoInf
// Remember this route if not present. // Remember this route if not present.
var numNodes int var numNodes int
var dup bool
if isDerp { if isDerp {
if ep, ok := c.peerMap.endpointForNodeKey(derpNodeSrc); ok { if ep, ok := c.peerMap.endpointForNodeKey(derpNodeSrc); ok {
ep.addCandidateEndpoint(src) if ep.addCandidateEndpoint(src, dm.TxID) {
return
}
numNodes = 1 numNodes = 1
} }
} else { } else {
c.peerMap.forEachEndpointWithDiscoKey(di.discoKey, func(ep *endpoint) { c.peerMap.forEachEndpointWithDiscoKey(di.discoKey, func(ep *endpoint) (keepGoing bool) {
ep.addCandidateEndpoint(src) if ep.addCandidateEndpoint(src, dm.TxID) {
dup = true
return false
}
numNodes++ numNodes++
if numNodes == 1 && dstKey.IsZero() { if numNodes == 1 && dstKey.IsZero() {
dstKey = ep.publicKey dstKey = ep.publicKey
} }
return true
}) })
if dup {
return
}
if numNodes > 1 { if numNodes > 1 {
// Zero it out if it's ambiguous, so sendDiscoMessage logging // Zero it out if it's ambiguous, so sendDiscoMessage logging
// isn't confusing. // isn't confusing.
@ -3625,7 +3632,7 @@ func ippDebugString(ua netip.AddrPort) string {
// recalculated. // recalculated.
type endpointSendFunc func([][]byte) error type endpointSendFunc func([][]byte) error
// discoEndpoint is a wireguard/conn.Endpoint that picks the best // endpoint is a wireguard/conn.Endpoint that picks the best
// available path to communicate with a peer, based on network // available path to communicate with a peer, based on network
// conditions and what the peer supports. // conditions and what the peer supports.
type endpoint struct { type endpoint struct {
@ -3740,6 +3747,12 @@ type endpointState struct {
// updated and use it to discard old candidates. // updated and use it to discard old candidates.
lastGotPing time.Time lastGotPing time.Time
// lastGotPingTxID, if lastGotPing is non-zero, contains the TxID for the
// last incoming ping. This is used to de-dup incoming pings that we may
// see on both the raw disco socket on Linux, and UDP socket. We cannot rely
// solely on the raw socket disco handling due to https://github.com/tailscale/tailscale/issues/7078.
lastGotPingTxID stun.TxID
// callMeMaybeTime, if non-zero, is the time this endpoint // callMeMaybeTime, if non-zero, is the time this endpoint
// was advertised last via a call-me-maybe disco message. // was advertised last via a call-me-maybe disco message.
callMeMaybeTime time.Time callMeMaybeTime time.Time
@ -4195,27 +4208,34 @@ func (de *endpoint) updateFromNode(n *tailcfg.Node, heartbeatDisabled bool) {
} }
// addCandidateEndpoint adds ep as an endpoint to which we should send // addCandidateEndpoint adds ep as an endpoint to which we should send
// future pings. // future pings. If there is an existing endpointState for ep, and forRxPingTxID
// matches the last received ping TxID, this function reports true, otherwise
// false.
// //
// This is called once we've already verified that we got a valid // This is called once we've already verified that we got a valid
// discovery message from de via ep. // discovery message from de via ep.
func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort) { func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.TxID) (duplicatePing bool) {
de.mu.Lock() de.mu.Lock()
defer de.mu.Unlock() defer de.mu.Unlock()
if st, ok := de.endpointState[ep]; ok { if st, ok := de.endpointState[ep]; ok {
if st.lastGotPing.IsZero() { if st.lastGotPing.IsZero() {
// Already-known endpoint from the network map. // Already-known endpoint from the network map.
return return false
}
if forRxPingTxID == st.lastGotPingTxID {
return true
} }
st.lastGotPing = time.Now() st.lastGotPing = time.Now()
return st.lastGotPingTxID = forRxPingTxID
return false
} }
// Newly discovered endpoint. Exciting! // Newly discovered endpoint. Exciting!
de.c.dlogf("[v1] magicsock: disco: adding %v as candidate endpoint for %v (%s)", ep, de.discoShort, de.publicKey.ShortString()) de.c.dlogf("[v1] magicsock: disco: adding %v as candidate endpoint for %v (%s)", ep, de.discoShort, de.publicKey.ShortString())
de.endpointState[ep] = &endpointState{ de.endpointState[ep] = &endpointState{
lastGotPing: time.Now(), lastGotPing: time.Now(),
lastGotPingTxID: forRxPingTxID,
} }
// If for some reason this gets very large, do some cleanup. // If for some reason this gets very large, do some cleanup.
@ -4228,6 +4248,7 @@ func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort) {
size2 := len(de.endpointState) size2 := len(de.endpointState)
de.c.dlogf("[v1] magicsock: disco: addCandidateEndpoint pruned %v candidate set from %v to %v entries", size, size2) de.c.dlogf("[v1] magicsock: disco: addCandidateEndpoint pruned %v candidate set from %v to %v entries", size, size2)
} }
return false
} }
// noteConnectivityChange is called when connectivity changes enough // noteConnectivityChange is called when connectivity changes enough

Loading…
Cancel
Save