From 6cfcb3cae493122f12cca500c72895a7f287cfdd Mon Sep 17 00:00:00 2001 From: James Tucker Date: Sat, 8 Apr 2023 15:36:47 -0700 Subject: [PATCH] wgengine/magicsock: fix synchronization of endpoint disco fields Identified in review in #7821 endpoint.discoKey and endpoint.discoShort are often accessed without first taking endpoint.mu. The arrangement with endpoint.mu is inconvenient for a good number of those call-sites, so it is instead replaced with an atomic pointer to carry both pieces of disco info. This will also help with #7821 that wants to add explicit checks/guards to disable disco behaviors when disco keys are missing which is necessarily implicitly mostly covered by this change. Updates #7821 Signed-off-by: James Tucker --- wgengine/magicsock/magicsock.go | 119 +++++++++++++++++++-------- wgengine/magicsock/magicsock_test.go | 18 ++-- 2 files changed, 99 insertions(+), 38 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 95d0ef5e8..8f976a7da 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -203,13 +203,17 @@ func (m *peerMap) upsertEndpoint(ep *endpoint, oldDiscoKey key.DiscoPublic) { if m.byNodeKey[ep.publicKey] == nil { m.byNodeKey[ep.publicKey] = newPeerInfo(ep) } - if oldDiscoKey != ep.discoKey { + epDisco := ep.disco.Load() + if epDisco == nil || oldDiscoKey != epDisco.key { delete(m.nodesOfDisco[oldDiscoKey], ep.publicKey) } - set := m.nodesOfDisco[ep.discoKey] + if epDisco == nil { + return + } + set := m.nodesOfDisco[epDisco.key] if set == nil { set = map[key.NodePublic]bool{} - m.nodesOfDisco[ep.discoKey] = set + m.nodesOfDisco[epDisco.key] = set } set[ep.publicKey] = true } @@ -238,8 +242,13 @@ func (m *peerMap) deleteEndpoint(ep *endpoint) { return } ep.stopAndReset() + + epDisco := ep.disco.Load() + pi := m.byNodeKey[ep.publicKey] - delete(m.nodesOfDisco[ep.discoKey], ep.publicKey) + if epDisco != nil { + delete(m.nodesOfDisco[epDisco.key], ep.publicKey) + } delete(m.byNodeKey, ep.publicKey) if pi == nil { // Kneejerk paranoia from earlier issue 2801. @@ -2269,14 +2278,18 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke c.logf("magicsock: disco: ignoring CallMeMaybe from %v; %v is unknown", sender.ShortString(), derpNodeSrc.ShortString()) return } - if ep.discoKey != di.discoKey { + epDisco := ep.disco.Load() + if epDisco == nil { + return + } + if epDisco.key != di.discoKey { metricRecvDiscoCallMeMaybeBadDisco.Add(1) c.logf("[unexpected] CallMeMaybe from peer via DERP whose netmap discokey != disco source") return } di.setNodeKey(nodeKey) c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", - c.discoShort, ep.discoShort, + c.discoShort, epDisco.short, ep.publicKey.ShortString(), derpStr(src.String()), len(dm.MyNumber)) go ep.handleCallMeMaybe(dm) @@ -2293,15 +2306,21 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netip.AddrPort, derpNodeSrc ke // c.mu must be held. func (c *Conn) unambiguousNodeKeyOfPingLocked(dm *disco.Ping, dk key.DiscoPublic, derpNodeSrc key.NodePublic) (nk key.NodePublic, ok bool) { if !derpNodeSrc.IsZero() { - if ep, ok := c.peerMap.endpointForNodeKey(derpNodeSrc); ok && ep.discoKey == dk { - return derpNodeSrc, true + if ep, ok := c.peerMap.endpointForNodeKey(derpNodeSrc); ok { + epDisco := ep.disco.Load() + if epDisco != nil && epDisco.key == dk { + return derpNodeSrc, true + } } } // Pings after 1.16.0 contains its node source. See if it maps back. if !dm.NodeKey.IsZero() { - if ep, ok := c.peerMap.endpointForNodeKey(dm.NodeKey); ok && ep.discoKey == dk { - return dm.NodeKey, true + if ep, ok := c.peerMap.endpointForNodeKey(dm.NodeKey); ok { + epDisco := ep.disco.Load() + if epDisco != nil && epDisco.key == dk { + return dm.NodeKey, true + } } } @@ -2409,11 +2428,16 @@ func (c *Conn) enqueueCallMeMaybe(derpAddr netip.AddrPort, de *endpoint) { c.mu.Lock() defer c.mu.Unlock() + epDisco := de.disco.Load() + if epDisco == nil { + return + } + if !c.lastEndpointsTime.After(time.Now().Add(-endpointsFreshEnoughDuration)) { c.dlogf("[v1] magicsock: want call-me-maybe but endpoints stale; restunning") mak.Set(&c.onEndpointRefreshed, de, func() { - c.dlogf("[v1] magicsock: STUN done; sending call-me-maybe to %v %v", de.discoShort, de.publicKey.ShortString()) + c.dlogf("[v1] magicsock: STUN done; sending call-me-maybe to %v %v", epDisco.short, de.publicKey.ShortString()) c.enqueueCallMeMaybe(derpAddr, de) }) // TODO(bradfitz): make a new 'reSTUNQuickly' method @@ -2432,12 +2456,12 @@ func (c *Conn) enqueueCallMeMaybe(derpAddr netip.AddrPort, de *endpoint) { for _, ep := range c.lastEndpoints { eps = append(eps, ep.Addr) } - go de.c.sendDiscoMessage(derpAddr, de.publicKey, de.discoKey, &disco.CallMeMaybe{MyNumber: eps}, discoLog) + go de.c.sendDiscoMessage(derpAddr, de.publicKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) if debugSendCallMeUnknownPeer() { // Send a callMeMaybe packet to a non-existent peer unknownKey := key.NewNode().Public() c.logf("magicsock: sending CallMeMaybe to unknown peer per TS_DEBUG_SEND_CALLME_UNKNOWN_PEER") - go de.c.sendDiscoMessage(derpAddr, unknownKey, de.discoKey, &disco.CallMeMaybe{MyNumber: eps}, discoLog) + go de.c.sendDiscoMessage(derpAddr, unknownKey, epDisco.key, &disco.CallMeMaybe{MyNumber: eps}, discoLog) } } @@ -2713,7 +2737,10 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { c.peerMap.deleteEndpoint(ep) continue } - oldDiscoKey := ep.discoKey + var oldDiscoKey key.DiscoPublic + if epDisco := ep.disco.Load(); epDisco != nil { + oldDiscoKey = epDisco.key + } ep.updateFromNode(n, heartbeatDisabled) c.peerMap.upsertEndpoint(ep, oldDiscoKey) // maybe update discokey mappings in peerMap continue @@ -2736,8 +2763,10 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) { if len(n.Addresses) > 0 { ep.nodeAddr = n.Addresses[0].Addr() } - ep.discoKey = n.DiscoKey - ep.discoShort = n.DiscoKey.ShortString() + ep.disco.Store(&endpointDisco{ + key: n.DiscoKey, + short: n.DiscoKey.ShortString(), + }) ep.initFakeUDPAddr() if debugDisco() { // rather than making a new knob c.logf("magicsock: created endpoint key=%s: disco=%s; %v", n.Key.ShortString(), n.DiscoKey.ShortString(), logger.ArgWriter(func(w *bufio.Writer) { @@ -4040,6 +4069,11 @@ func ippDebugString(ua netip.AddrPort) string { // recalculated. type endpointSendFunc func([][]byte) error +type endpointDisco struct { + key key.DiscoPublic // for discovery messages. + short string // ShortString of discoKey. +} + // endpoint is a wireguard/conn.Endpoint that picks the best // available path to communicate with a peer, based on network // conditions and what the peer supports. @@ -4057,12 +4091,11 @@ type endpoint struct { fakeWGAddr netip.AddrPort // the UDP address we tell wireguard-go we're using nodeAddr netip.Addr // the node's first tailscale address; used for logging & wireguard rate-limiting (Issue 6686) + disco atomic.Pointer[endpointDisco] // if the peer supports disco, the key and short string + // mu protects all following fields. mu sync.Mutex // Lock ordering: Conn.mu, then endpoint.mu - discoKey key.DiscoPublic // for discovery messages. Should never be the zero value. - discoShort string // ShortString of discoKey. Empty if peer can't disco. - heartBeatTimer *time.Timer // nil when idle lastSend mono.Time // last time there was outgoing packets sent to this peer (from wireguard-go) lastFullPing mono.Time // last time we pinged all endpoints @@ -4249,11 +4282,19 @@ func (de *endpoint) noteRecvActivity() { } } +func (de *endpoint) discoShort() string { + var short string + if d := de.disco.Load(); d != nil { + short = d.short + } + return short +} + // String exists purely so wireguard-go internals can log.Printf("%v") // its internal conn.Endpoints and we don't end up with data races // from fmt (via log) reading mutex fields and such. func (de *endpoint) String() string { - return fmt.Sprintf("magicsock.endpoint{%v, %v}", de.publicKey.ShortString(), de.discoShort) + return fmt.Sprintf("magicsock.endpoint{%v, %v}", de.publicKey.ShortString(), de.discoShort()) } func (de *endpoint) ClearSrc() {} @@ -4298,7 +4339,7 @@ func (de *endpoint) heartbeat() { if mono.Since(de.lastSend) > sessionActiveTimeout { // Session's idle. Stop heartbeating. - de.c.dlogf("[v1] magicsock: disco: ending heartbeats for idle session to %v (%v)", de.publicKey.ShortString(), de.discoShort) + de.c.dlogf("[v1] magicsock: disco: ending heartbeats for idle session to %v (%v)", de.publicKey.ShortString(), de.discoShort()) return } @@ -4451,7 +4492,7 @@ func (de *endpoint) pingTimeout(txid stun.TxID) { return } if debugDisco() || !de.bestAddr.IsValid() || mono.Now().After(de.trustBestAddrUntil) { - de.c.dlogf("[v1] magicsock: disco: timeout waiting for pong %x from %v (%v, %v)", txid[:6], sp.to, de.publicKey.ShortString(), de.discoShort) + de.c.dlogf("[v1] magicsock: disco: timeout waiting for pong %x from %v (%v, %v)", txid[:6], sp.to, de.publicKey.ShortString(), de.discoShort()) } de.removeSentPingLocked(txid, sp) } @@ -4512,6 +4553,10 @@ func (de *endpoint) startPingLocked(ep netip.AddrPort, now mono.Time, purpose di if runtime.GOOS == "js" { return } + epDisco := de.disco.Load() + if epDisco == nil { + return + } if purpose != pingCLI { st, ok := de.endpointState[ep] if !ok { @@ -4534,7 +4579,7 @@ func (de *endpoint) startPingLocked(ep netip.AddrPort, now mono.Time, purpose di if purpose == pingHeartbeat { logLevel = discoVerboseLog } - go de.sendDiscoPing(ep, de.discoKey, txid, logLevel) + go de.sendDiscoPing(ep, epDisco.key, txid, logLevel) } func (de *endpoint) sendPingsLocked(now mono.Time, sendCallMeMaybe bool) { @@ -4556,7 +4601,7 @@ func (de *endpoint) sendPingsLocked(now mono.Time, sendCallMeMaybe bool) { sentAny = true if firstPing && sendCallMeMaybe { - de.c.dlogf("[v1] magicsock: disco: send, starting discovery for %v (%v)", de.publicKey.ShortString(), de.discoShort) + de.c.dlogf("[v1] magicsock: disco: send, starting discovery for %v (%v)", de.publicKey.ShortString(), de.discoShort()) } de.startPingLocked(ep, now, pingDiscovery) @@ -4582,10 +4627,18 @@ func (de *endpoint) updateFromNode(n *tailcfg.Node, heartbeatDisabled bool) { de.heartbeatDisabled = heartbeatDisabled de.expired = n.Expired - if de.discoKey != n.DiscoKey { - de.c.logf("[v1] magicsock: disco: node %s changed from %s to %s", de.publicKey.ShortString(), de.discoKey, n.DiscoKey) - de.discoKey = n.DiscoKey - de.discoShort = de.discoKey.ShortString() + epDisco := de.disco.Load() + var discoKey key.DiscoPublic + if epDisco != nil { + discoKey = epDisco.key + } + + if discoKey != n.DiscoKey { + de.c.logf("[v1] magicsock: disco: node %s changed from %s to %s", de.publicKey.ShortString(), discoKey, n.DiscoKey) + de.disco.Store(&endpointDisco{ + key: n.DiscoKey, + short: n.DiscoKey.ShortString(), + }) de.debugUpdates.Add(EndpointChange{ When: time.Now(), What: "updateFromNode-resetLocked", @@ -4681,7 +4734,7 @@ func (de *endpoint) addCandidateEndpoint(ep netip.AddrPort, forRxPingTxID stun.T } // Newly discovered endpoint. Exciting! - de.c.dlogf("[v1] magicsock: disco: adding %v as candidate endpoint for %v (%s)", ep, de.discoShort, de.publicKey.ShortString()) + de.c.dlogf("[v1] magicsock: disco: adding %v as candidate endpoint for %v (%s)", ep, de.discoShort(), de.publicKey.ShortString()) de.endpointState[ep] = &endpointState{ lastGotPing: time.Now(), lastGotPingTxID: forRxPingTxID, @@ -4750,7 +4803,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip } if sp.purpose != pingHeartbeat { - de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pong.src=%v%v", de.c.discoShort, de.discoShort, de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), m.Src, logger.ArgWriter(func(bw *bufio.Writer) { + de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pong.src=%v%v", de.c.discoShort, de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), m.Src, logger.ArgWriter(func(bw *bufio.Writer) { if sp.to != src { fmt.Fprintf(bw, " ping.to=%v", sp.to) } @@ -4768,7 +4821,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src netip if !isDerp { thisPong := addrLatency{sp.to, latency} if betterAddr(thisPong, de.bestAddr) { - de.c.logf("magicsock: disco: node %v %v now using %v", de.publicKey.ShortString(), de.discoShort, sp.to) + de.c.logf("magicsock: disco: node %v %v now using %v", de.publicKey.ShortString(), de.discoShort(), sp.to) de.debugUpdates.Add(EndpointChange{ When: time.Now(), What: "handlePingLocked-bestAddr-update", @@ -4896,7 +4949,7 @@ func (de *endpoint) handleCallMeMaybe(m *disco.CallMeMaybe) { }) de.c.dlogf("[v1] magicsock: disco: call-me-maybe from %v %v added new endpoints: %v", - de.publicKey.ShortString(), de.discoShort, + de.publicKey.ShortString(), de.discoShort(), logger.ArgWriter(func(w *bufio.Writer) { for i, ep := range newEPs { if i > 0 { @@ -4953,7 +5006,7 @@ func (de *endpoint) stopAndReset() { defer de.mu.Unlock() if closing := de.c.closing.Load(); !closing { - de.c.logf("[v1] magicsock: doing cleanup for discovery key %s", de.discoKey.ShortString()) + de.c.logf("[v1] magicsock: doing cleanup for discovery key %s", de.discoShort()) } de.debugUpdates.Add(EndpointChange{ diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 45fac2b98..3a553ddb7 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -1143,10 +1143,14 @@ func TestDiscoMessage(t *testing.T) { Key: key.NewNode().Public(), DiscoKey: peer1Pub, } - c.peerMap.upsertEndpoint(&endpoint{ + ep := &endpoint{ publicKey: n.Key, - discoKey: n.DiscoKey, - }, key.DiscoPublic{}) + } + ep.disco.Store(&endpointDisco{ + key: n.DiscoKey, + short: n.DiscoKey.ShortString(), + }) + c.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) const payload = "why hello" @@ -1458,8 +1462,12 @@ func TestSetNetworkMapChangingNodeKey(t *testing.T) { if ok && de.publicKey != nodeKey2 { t.Fatalf("discoEndpoint public key = %q; want %q", de.publicKey, nodeKey2) } - if de.discoKey != discoKey { - t.Errorf("discoKey = %v; want %v", de.discoKey, discoKey) + deDisco := de.disco.Load() + if deDisco == nil { + t.Fatalf("discoEndpoint disco is nil") + } + if deDisco.key != discoKey { + t.Errorf("discoKey = %v; want %v", deDisco.key, discoKey) } if _, ok := conn.peerMap.endpointForNodeKey(nodeKey1); ok { t.Errorf("didn't expect to find node for key1")