wgengine/magicsock: don't store tailcfg.Nodes alongside endpoints.

Updates #2752

Signed-off-by: David Anderson <danderson@tailscale.com>
pull/2773/head
David Anderson 3 years ago committed by Dave Anderson
parent b2181608b5
commit 1a899344bd

@ -74,8 +74,7 @@ func useDerpRoute() bool {
// peerInfo is all the information magicsock tracks about a particular // peerInfo is all the information magicsock tracks about a particular
// peer. // peer.
type peerInfo struct { type peerInfo struct {
node *tailcfg.Node // always present ep *endpoint // optional, if wireguard-go isn't currently talking to this peer.
ep *endpoint // optional, if wireguard-go isn't currently talking to this peer.
// ipPorts is an inverted version of peerMap.byIPPort (below), so // ipPorts is an inverted version of peerMap.byIPPort (below), so
// that when we're deleting this node, we can rapidly find out the // that when we're deleting this node, we can rapidly find out the
// keys that need deleting from peerMap.byIPPort without having to // keys that need deleting from peerMap.byIPPort without having to
@ -112,30 +111,6 @@ func (m *peerMap) nodeCount() int {
return len(m.byNodeKey) return len(m.byNodeKey)
} }
// nodeForDiscoKey returns the tailcfg.Node for dk. ok is true only if
// the disco key is known to us.
func (m *peerMap) nodeForDiscoKey(dk tailcfg.DiscoKey) (n *tailcfg.Node, ok bool) {
if dk.IsZero() {
return nil, false
}
if info, ok := m.byDiscoKey[dk]; ok {
return info.node, true
}
return nil, false
}
// nodeForNodeKey returns the tailcfg.Node for nk. ok is true only if
// the node key is known to us.
func (m *peerMap) nodeForNodeKey(nk tailcfg.NodeKey) (n *tailcfg.Node, ok bool) {
if nk.IsZero() {
return nil, false
}
if info, ok := m.byNodeKey[nk]; ok {
return info.node, true
}
return nil, false
}
// endpointForDiscoKey returns the endpoint for dk, or nil // endpointForDiscoKey returns the endpoint for dk, or nil
// if dk is not known to us. // if dk is not known to us.
func (m *peerMap) endpointForDiscoKey(dk tailcfg.DiscoKey) (ep *endpoint, ok bool) { func (m *peerMap) endpointForDiscoKey(dk tailcfg.DiscoKey) (ep *endpoint, ok bool) {
@ -178,20 +153,14 @@ func (m *peerMap) forEachDiscoEndpoint(f func(ep *endpoint)) {
} }
} }
// forEachNode invokes f on every tailcfg.Node in m.
func (m *peerMap) forEachNode(f func(n *tailcfg.Node)) {
for _, pi := range m.byNodeKey {
f(pi.node)
}
}
// upsertDiscoEndpoint stores endpoint in the peerInfo for // upsertDiscoEndpoint stores endpoint in the peerInfo for
// ep.publicKey, and updates indexes. m must already have a // ep.publicKey, and updates indexes. m must already have a
// tailcfg.Node for ep.publicKey. // tailcfg.Node for ep.publicKey.
func (m *peerMap) upsertDiscoEndpoint(ep *endpoint) { func (m *peerMap) upsertDiscoEndpoint(ep *endpoint) {
pi := m.byNodeKey[ep.publicKey] pi := m.byNodeKey[ep.publicKey]
if pi == nil { if pi == nil {
panic("can't have disco endpoint for unknown node") pi = newPeerInfo()
m.byNodeKey[ep.publicKey] = pi
} }
old := pi.ep old := pi.ep
pi.ep = ep pi.ep = ep
@ -201,25 +170,6 @@ func (m *peerMap) upsertDiscoEndpoint(ep *endpoint) {
m.byDiscoKey[ep.discoKey] = pi m.byDiscoKey[ep.discoKey] = pi
} }
// upsertNode stores n in the peerInfo for n.Key, creating the
// peerInfo if necessary, and updates indexes.
func (m *peerMap) upsertNode(n *tailcfg.Node) {
if n == nil {
panic("node can't be nil")
}
pi := m.byDiscoKey[n.DiscoKey]
if pi == nil {
pi = newPeerInfo()
}
old := pi.node
pi.node = n
m.byDiscoKey[n.DiscoKey] = pi
if old != nil && old.Key != n.Key {
delete(m.byNodeKey, old.Key)
}
m.byNodeKey[n.Key] = pi
}
// SetDiscoKeyForIPPort makes future peer lookups by ipp return the // SetDiscoKeyForIPPort makes future peer lookups by ipp return the
// same peer info as the lookup by dk. // same peer info as the lookup by dk.
func (m *peerMap) setDiscoKeyForIPPort(ipp netaddr.IPPort, dk tailcfg.DiscoKey) { func (m *peerMap) setDiscoKeyForIPPort(ipp netaddr.IPPort, dk tailcfg.DiscoKey) {
@ -249,25 +199,6 @@ func (m *peerMap) deleteDiscoEndpoint(ep *endpoint) {
} }
} }
// deleteNode deletes the peerInfo associated with n, and updates
// indexes.
func (m *peerMap) deleteNode(n *tailcfg.Node) {
if n == nil {
return
}
pi := m.byNodeKey[n.Key]
if pi != nil && pi.ep != nil {
pi.ep.stopAndReset()
}
delete(m.byNodeKey, n.Key)
if !n.DiscoKey.IsZero() {
delete(m.byDiscoKey, n.DiscoKey)
}
for ip := range pi.ipPorts {
delete(m.byIPPort, ip)
}
}
// A Conn routes UDP packets and actively manages a list of its endpoints. // A Conn routes UDP packets and actively manages a list of its endpoints.
// It implements wireguard/conn.Bind. // It implements wireguard/conn.Bind.
type Conn struct { type Conn struct {
@ -937,37 +868,13 @@ func (c *Conn) Ping(peer *tailcfg.Node, res *ipnstate.PingResult, cb func(*ipnst
} }
} }
de, ok := c.peerMap.endpointForNodeKey(peer.Key) ep, ok := c.peerMap.endpointForNodeKey(peer.Key)
if !ok { if !ok {
node, ok := c.peerMap.nodeForNodeKey(peer.Key) res.Err = "unknown peer"
if !ok { cb(res)
res.Err = "unknown peer" return
cb(res)
return
}
c.mu.Unlock() // temporarily release
if c.noteRecvActivity != nil {
c.noteRecvActivity(node.DiscoKey)
}
c.mu.Lock() // re-acquire
// re-check at least basic invariant:
if c.privateKey.IsZero() {
res.Err = "local tailscaled stopped"
cb(res)
return
}
de, ok = c.peerMap.endpointForNodeKey(peer.Key)
if !ok {
res.Err = "internal error: failed to get endpoint for node key"
cb(res)
return
}
c.logf("[v1] magicsock: started peer %v for ping to %v", de.discoKey.ShortString(), peer.Key.ShortString())
} }
de.cliPing(res, cb) ep.cliPing(res, cb)
} }
// c.mu must be held // c.mu must be held
@ -1010,8 +917,8 @@ func (c *Conn) DiscoPublicKey() tailcfg.DiscoKey {
func (c *Conn) PeerHasDiscoKey(k tailcfg.NodeKey) bool { func (c *Conn) PeerHasDiscoKey(k tailcfg.NodeKey) bool {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
if info, ok := c.peerMap.nodeForNodeKey(k); ok { if ep, ok := c.peerMap.endpointForNodeKey(k); ok {
return info.DiscoKey.IsZero() return ep.discoKey.IsZero()
} }
return false return false
} }
@ -1757,67 +1664,17 @@ func (c *Conn) processDERPReadResult(dm derpReadResult, b []byte) (n int, ep *en
return 0, nil return 0, nil
} }
didNoteRecvActivity := false
c.mu.Lock()
nk := tailcfg.NodeKey(dm.src)
var ok bool var ok bool
ep, ok = c.peerMap.endpointForNodeKey(nk) c.mu.Lock()
ep, ok = c.peerMap.endpointForNodeKey(tailcfg.NodeKey(dm.src))
c.mu.Unlock()
if !ok { if !ok {
node, ok := c.peerMap.nodeForNodeKey(nk) // We don't know anything about this node key, nothing to
if !ok { // record or process.
// We don't know anything about this node key, nothing to return 0, nil
// record or process.
c.mu.Unlock()
return 0, nil
}
// We know about the node, but have no disco endpoint for
// it. That's because it's an idle peer that doesn't yet exist
// in the wireguard config. If we have a receive hook, run it
// to get the endpoint created.
if c.noteRecvActivity == nil {
// No hook to lazily create endpoints, nothing we can do.
c.mu.Unlock()
return 0, nil
}
didNoteRecvActivity = true
// release lock before calling the activity callback, because
// it ends up calling back into magicsock to create the
// endpoint.
c.mu.Unlock()
c.noteRecvActivity(node.DiscoKey)
// Reacquire the lock. Because we were unlocked for a while,
// it's possible that even after all this, we still won't be
// able to find a disco endpoint for the node (e.g. because
// the peer was deleted from the netmap in the interim). Don't
// assume that ep != nil.
c.mu.Lock()
c.logf("magicsock: DERP packet received from idle peer %v; created=%v", dm.src.ShortString(), ep != nil)
ep, ok = c.peerMap.endpointForNodeKey(nk)
if !ok {
// There are a few edge cases where we can still end up
// with a nil ep here. Among them are: the peer was
// deleted while we were unlocked above (harmless, we no
// longer want to talk to that peer anyway), or there is a
// race between magicsock becoming aware of a new peer and
// WireGuard becoming aware, *and* lazy wg reconfiguration
// is disabled (at least test code does this, as of
// 2021-08).
//
// Either way, the bottom line is: we thought we might
// know about this peer, but it turns out we don't know
// enough to hand it to WireGuard, so, we want to drop the
// packet. If this was in error due to a race, the peer
// will eventually retry and heal things.
return 0, nil
}
} }
c.mu.Unlock()
if !didNoteRecvActivity { c.noteRecvActivityFromEndpoint(ep)
c.noteRecvActivityFromEndpoint(ep)
}
return n, ep return n, ep
} }
@ -1914,60 +1771,14 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) (isDiscoMsg bo
return return
} }
peerNode, ok := c.peerMap.nodeForDiscoKey(sender) ep, ok := c.peerMap.endpointForDiscoKey(sender)
if !ok { if !ok {
if debugDisco { if debugDisco {
c.logf("magicsock: disco: ignoring disco-looking frame, don't know node for %v", sender.ShortString()) c.logf("magicsock: disco: ignoring disco-looking frame, don't know endpoint for %v", sender.ShortString())
} }
return return
} }
if !ep.canP2P() {
needsRecvActivityCall := false
isLazyCreate := false
de, ok := c.peerMap.endpointForDiscoKey(sender)
if !ok {
// We know about the node, but it doesn't currently have active WireGuard state.
c.logf("magicsock: got disco message from idle peer, starting lazy conf for %v, %v", peerNode.Key.ShortString(), sender.ShortString())
if c.noteRecvActivity == nil {
c.logf("magicsock: [unexpected] have node without endpoint, without c.noteRecvActivity hook")
return
}
needsRecvActivityCall = true
isLazyCreate = true
} else if de.isFirstRecvActivityInAwhile() {
needsRecvActivityCall = true
}
if needsRecvActivityCall && c.noteRecvActivity != nil {
// We can't hold Conn.mu while calling noteRecvActivity.
// noteRecvActivity acquires userspaceEngine.wgLock (and per our
// lock ordering rules: wgLock must come first), and also calls
// back into our Conn.ParseEndpoint, which would double-acquire
// Conn.mu.
c.mu.Unlock()
c.noteRecvActivity(sender)
c.mu.Lock() // re-acquire
// Now, recheck invariants that might've changed while we'd
// released the lock, which isn't much:
if c.closed || c.privateKey.IsZero() {
return
}
de, ok = c.peerMap.endpointForDiscoKey(sender)
if !ok {
if _, ok := c.peerMap.nodeForDiscoKey(sender); !ok {
// They just disappeared while we'd released the lock.
return false
}
c.logf("magicsock: [unexpected] lazy endpoint not created for %v, %v", peerNode.Key.ShortString(), sender.ShortString())
return
}
if isLazyCreate {
c.logf("magicsock: lazy endpoint created via disco message for %v, %v", peerNode.Key.ShortString(), sender.ShortString())
}
}
if !de.canP2P() {
// This endpoint allegedly sent us a disco packet, but we know // This endpoint allegedly sent us a disco packet, but we know
// they can't speak disco. Drop. // they can't speak disco. Drop.
return return
@ -2014,9 +1825,9 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) (isDiscoMsg bo
switch dm := dm.(type) { switch dm := dm.(type) {
case *disco.Ping: case *disco.Ping:
c.handlePingLocked(dm, de, src, sender, peerNode) c.handlePingLocked(dm, ep, src, sender)
case *disco.Pong: case *disco.Pong:
de.handlePongConnLocked(dm, src) ep.handlePongConnLocked(dm, src)
case *disco.CallMeMaybe: case *disco.CallMeMaybe:
if src.IP() != derpMagicIPAddr { if src.IP() != derpMagicIPAddr {
// CallMeMaybe messages should only come via DERP. // CallMeMaybe messages should only come via DERP.
@ -2024,24 +1835,20 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) (isDiscoMsg bo
return return
} }
c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints", c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints",
c.discoShort, de.discoShort, c.discoShort, ep.discoShort,
de.publicKey.ShortString(), derpStr(src.String()), ep.publicKey.ShortString(), derpStr(src.String()),
len(dm.MyNumber)) len(dm.MyNumber))
go de.handleCallMeMaybe(dm) go ep.handleCallMeMaybe(dm)
} }
return return
} }
func (c *Conn) handlePingLocked(dm *disco.Ping, de *endpoint, src netaddr.IPPort, sender tailcfg.DiscoKey, peerNode *tailcfg.Node) { func (c *Conn) handlePingLocked(dm *disco.Ping, de *endpoint, src netaddr.IPPort, sender tailcfg.DiscoKey) {
if peerNode == nil {
c.logf("magicsock: disco: [unexpected] ignoring ping from unknown peer Node")
return
}
likelyHeartBeat := src == de.lastPingFrom && time.Since(de.lastPingTime) < 5*time.Second likelyHeartBeat := src == de.lastPingFrom && time.Since(de.lastPingTime) < 5*time.Second
de.lastPingFrom = src de.lastPingFrom = src
de.lastPingTime = time.Now() de.lastPingTime = time.Now()
if !likelyHeartBeat || debugDisco { if !likelyHeartBeat || debugDisco {
c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, de.discoShort, peerNode.Key.ShortString(), src, dm.TxID[:6]) c.logf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, de.discoShort, de.publicKey.ShortString(), src, dm.TxID[:6])
} }
// Remember this route if not present. // Remember this route if not present.
@ -2050,7 +1857,7 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, de *endpoint, src netaddr.IPPort
ipDst := src ipDst := src
discoDest := sender discoDest := sender
go c.sendDiscoMessage(ipDst, peerNode.Key, discoDest, &disco.Pong{ go c.sendDiscoMessage(ipDst, de.publicKey, discoDest, &disco.Pong{
TxID: dm.TxID, TxID: dm.TxID,
Src: src, Src: src,
}, discoVerboseLog) }, discoVerboseLog)
@ -2308,59 +2115,61 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
// we'll fall through to the next pass, which allocates but can // we'll fall through to the next pass, which allocates but can
// handle full set updates. // handle full set updates.
for _, n := range nm.Peers { for _, n := range nm.Peers {
c.peerMap.upsertNode(n) if ep, ok := c.peerMap.endpointForNodeKey(n.Key); ok {
if _, ok := c.peerMap.endpointForNodeKey(n.Key); !ok { ep.updateFromNode(n)
ep := &endpoint{ continue
c: c, }
publicKey: n.Key,
sentPing: map[stun.TxID]sentPing{},
endpointState: map[netaddr.IPPort]*endpointState{},
}
if !n.DiscoKey.IsZero() {
ep.discoKey = n.DiscoKey
ep.discoShort = n.DiscoKey.ShortString()
}
epDef := wgcfg.Endpoints{
PublicKey: wgkey.Key(n.Key),
DiscoKey: n.DiscoKey,
}
// We have to make the endpoint string we return to
// WireGuard be the right kind of json that wgcfg expects
// to get back out of uapi, so we have to do this somewhat
// unnecessary json encoding here.
// TODO(danderson): remove this in the wgcfg.Endpoints refactor.
epBytes, err := json.Marshal(epDef)
if err != nil {
c.logf("[unexpected] magicsock: creating endpoint: failed to marshal endpoints json %w", err)
}
ep.wgEndpoint = string(epBytes)
ep.initFakeUDPAddr()
c.logf("magicsock: created endpoint key=%s: disco=%s; %v", n.Key.ShortString(), n.DiscoKey.ShortString(), logger.ArgWriter(func(w *bufio.Writer) {
const derpPrefix = "127.3.3.40:"
if strings.HasPrefix(n.DERP, derpPrefix) {
ipp, _ := netaddr.ParseIPPort(n.DERP)
regionID := int(ipp.Port())
code := c.derpRegionCodeLocked(regionID)
if code != "" {
code = "(" + code + ")"
}
fmt.Fprintf(w, "derp=%v%s ", regionID, code)
}
for _, a := range n.AllowedIPs { ep := &endpoint{
if a.IsSingleIP() { c: c,
fmt.Fprintf(w, "aip=%v ", a.IP()) publicKey: n.Key,
} else { sentPing: map[stun.TxID]sentPing{},
fmt.Fprintf(w, "aip=%v ", a) endpointState: map[netaddr.IPPort]*endpointState{},
} }
if !n.DiscoKey.IsZero() {
ep.discoKey = n.DiscoKey
ep.discoShort = n.DiscoKey.ShortString()
}
epDef := wgcfg.Endpoints{
PublicKey: wgkey.Key(n.Key),
DiscoKey: n.DiscoKey,
}
// We have to make the endpoint string we return to
// WireGuard be the right kind of json that wgcfg expects
// to get back out of uapi, so we have to do this somewhat
// unnecessary json encoding here.
// TODO(danderson): remove this in the wgcfg.Endpoints refactor.
epBytes, err := json.Marshal(epDef)
if err != nil {
c.logf("[unexpected] magicsock: creating endpoint: failed to marshal endpoints json %w", err)
}
ep.wgEndpoint = string(epBytes)
ep.initFakeUDPAddr()
c.logf("magicsock: created endpoint key=%s: disco=%s; %v", n.Key.ShortString(), n.DiscoKey.ShortString(), logger.ArgWriter(func(w *bufio.Writer) {
const derpPrefix = "127.3.3.40:"
if strings.HasPrefix(n.DERP, derpPrefix) {
ipp, _ := netaddr.ParseIPPort(n.DERP)
regionID := int(ipp.Port())
code := c.derpRegionCodeLocked(regionID)
if code != "" {
code = "(" + code + ")"
} }
for _, ep := range n.Endpoints { fmt.Fprintf(w, "derp=%v%s ", regionID, code)
fmt.Fprintf(w, "ep=%v ", ep) }
for _, a := range n.AllowedIPs {
if a.IsSingleIP() {
fmt.Fprintf(w, "aip=%v ", a.IP())
} else {
fmt.Fprintf(w, "aip=%v ", a)
} }
})) }
ep.updateFromNode(n) for _, ep := range n.Endpoints {
c.peerMap.upsertDiscoEndpoint(ep) fmt.Fprintf(w, "ep=%v ", ep)
} }
}))
ep.updateFromNode(n)
c.peerMap.upsertDiscoEndpoint(ep)
} }
// If the set of nodes changed since the last SetNetworkMap, the // If the set of nodes changed since the last SetNetworkMap, the
@ -2373,11 +2182,11 @@ func (c *Conn) SetNetworkMap(nm *netmap.NetworkMap) {
for _, n := range nm.Peers { for _, n := range nm.Peers {
keep[n.Key] = true keep[n.Key] = true
} }
c.peerMap.forEachNode(func(n *tailcfg.Node) { c.peerMap.forEachDiscoEndpoint(func(ep *endpoint) {
if !keep[n.Key] { if !keep[ep.publicKey] {
c.peerMap.deleteNode(n) c.peerMap.deleteDiscoEndpoint(ep)
if !n.DiscoKey.IsZero() { if !ep.discoKey.IsZero() {
delete(c.sharedDiscoKey, n.DiscoKey) delete(c.sharedDiscoKey, ep.discoKey)
} }
} }
}) })
@ -2873,14 +2682,15 @@ func (c *Conn) ParseEndpoint(endpointStr string) (conn.Endpoint, error) {
if c.closed { if c.closed {
return nil, errConnClosed return nil, errConnClosed
} }
if ep, ok := c.peerMap.endpointForNodeKey(tailcfg.NodeKey(pk)); ok { ep, ok := c.peerMap.endpointForNodeKey(tailcfg.NodeKey(pk))
return ep, nil if !ok {
// We should never be telling WireGuard about a new peer
// before magicsock knows about it.
c.logf("[unexpected] magicsock: ParseEndpoint: unknown node key=%s", pk.ShortString())
return nil, fmt.Errorf("magicsock: ParseEndpoint: unknown peer %q", pk.ShortString())
} }
// We should never be telling WireGuard about a new peer return ep, nil
// before magicsock knows about it.
c.logf("[unexpected] magicsock: ParseEndpoint: unknown node key=%s", pk.ShortString())
return nil, fmt.Errorf("magicsock: ParseEndpoint: unknown peer %q", pk.ShortString())
} }
// RebindingUDPConn is a UDP socket that can be re-bound. // RebindingUDPConn is a UDP socket that can be re-bound.
@ -3152,14 +2962,11 @@ func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) {
ss.TailAddrDeprecated = tailAddr4 ss.TailAddrDeprecated = tailAddr4
}) })
c.peerMap.forEachNode(func(n *tailcfg.Node) { c.peerMap.forEachDiscoEndpoint(func(ep *endpoint) {
ps := &ipnstate.PeerStatus{InMagicSock: true} ps := &ipnstate.PeerStatus{InMagicSock: true}
ps.Addrs = append(ps.Addrs, n.Endpoints...) //ps.Addrs = append(ps.Addrs, n.Endpoints...)
ps.Relay = c.derpRegionCodeOfAddrLocked(n.DERP) ep.populatePeerStatus(ps)
if ep, ok := c.peerMap.endpointForNodeKey(n.Key); ok { sb.AddPeer(key.Public(ep.publicKey), ps)
ep.populatePeerStatus(ps)
}
sb.AddPeer(key.Public(n.Key), ps)
}) })
c.foreachActiveDerpSortedLocked(func(node int, ad activeDerp) { c.foreachActiveDerpSortedLocked(func(node int, ad activeDerp) {
@ -3908,6 +3715,8 @@ func (de *endpoint) populatePeerStatus(ps *ipnstate.PeerStatus) {
de.mu.Lock() de.mu.Lock()
defer de.mu.Unlock() defer de.mu.Unlock()
ps.Relay = de.c.derpRegionCodeOfIDLocked(int(de.derpAddr.Port()))
if de.lastSend.IsZero() { if de.lastSend.IsZero() {
return return
} }

@ -1080,7 +1080,6 @@ func TestDiscoMessage(t *testing.T) {
Key: tailcfg.NodeKey(key.NewPrivate().Public()), Key: tailcfg.NodeKey(key.NewPrivate().Public()),
DiscoKey: peer1Pub, DiscoKey: peer1Pub,
} }
c.peerMap.upsertNode(n)
c.peerMap.upsertDiscoEndpoint(&endpoint{ c.peerMap.upsertDiscoEndpoint(&endpoint{
publicKey: n.Key, publicKey: n.Key,
discoKey: n.DiscoKey, discoKey: n.DiscoKey,

Loading…
Cancel
Save