From 16a9cfe2f4ce7d9afa093ec0b910bfb97fc40cab Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Thu, 23 Jul 2020 15:15:28 -0700 Subject: [PATCH] wgengine: configure wireguard peers lazily, as needed wireguard-go uses 3 goroutines per peer (with reasonably large stacks & buffers). Rather than tell wireguard-go about all our peers, only tell it about peers we're actively communicating with. That means we need hooks into magicsock's packet receiving path and tstun's packet sending path to lazily create a wireguard peer on demand from the network map. This frees up lots of memory for iOS (where we have almost nothing left for larger domains with many users). We should ideally do this in wireguard-go itself one day, but that'd be a pretty big change. Signed-off-by: Brad Fitzpatrick --- wgengine/magicsock/magicsock.go | 188 +++++++++++++++---- wgengine/magicsock/magicsock_test.go | 7 +- wgengine/tstun/tun.go | 27 ++- wgengine/userspace.go | 268 +++++++++++++++++++++++---- 4 files changed, 411 insertions(+), 79 deletions(-) diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index 516ea0610..c7902a35d 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -58,14 +58,15 @@ import ( // A Conn routes UDP packets and actively manages a list of its endpoints. // It implements wireguard/conn.Bind. type Conn struct { - pconnPort uint16 // the preferred port from opts.Port; 0 means auto - pconn4 *RebindingUDPConn - pconn6 *RebindingUDPConn // non-nil if IPv6 available - epFunc func(endpoints []string) - logf logger.Logf - sendLogLimit *rate.Limiter - netChecker *netcheck.Client - idleFunc func() time.Duration // nil means unknown + pconnPort uint16 // the preferred port from opts.Port; 0 means auto + pconn4 *RebindingUDPConn + pconn6 *RebindingUDPConn // non-nil if IPv6 available + epFunc func(endpoints []string) + logf logger.Logf + sendLogLimit *rate.Limiter + netChecker *netcheck.Client + idleFunc func() time.Duration // nil means unknown + noteRecvActivity func(tailcfg.DiscoKey) // or nil, see Options.NoteRecvActivity // bufferedIPv4From and bufferedIPv4Packet are owned by // ReceiveIPv4, and used when both a DERP and IPv4 packet arrive @@ -89,6 +90,13 @@ type Conn struct { // ============================================================ mu sync.Mutex // guards all following fields + // canCreateEPUnlocked tracks at one place whether mu is + // already held. It's then checked in CreateEndpoint to avoid + // double-locking mu and thus deadlocking. mu should be held + // while setting this; but can be read without mu held. + // TODO(bradfitz): delete this shameful hack; refactor the one use + canCreateEPUnlocked syncs.AtomicBool + started bool // Start was called closed bool // Close was called @@ -104,8 +112,8 @@ type Conn struct { nodeOfDisco map[tailcfg.DiscoKey]*tailcfg.Node discoOfNode map[tailcfg.NodeKey]tailcfg.DiscoKey discoOfAddr map[netaddr.IPPort]tailcfg.DiscoKey // validated non-DERP paths only - endpointOfDisco map[tailcfg.DiscoKey]*discoEndpoint - sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte // nacl/box precomputed key + endpointOfDisco map[tailcfg.DiscoKey]*discoEndpoint // those with activity only + sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte // nacl/box precomputed key // addrsByUDP is a map of every remote ip:port to a priority // list of endpoint addresses for a peer. @@ -235,6 +243,17 @@ type Options struct { // PacketListener optionally specifies how to create PacketConns. // It's meant for testing. PacketListener nettype.PacketListener + + // NoteRecvActivity, if provided, is a func for magicsock to + // call whenever it receives a packet from a a + // discovery-capable peer if it's been more than ~10 seconds + // since the last one. (10 seconds is somewhat arbitrary; the + // sole user just doesn't need or want it called on every + // packet, just every minute or two for Wireguard timeouts, + // and 10 seconds seems like a good trade-off between often + // enough and not too often.) The provided func likely calls + // Conn.CreateEndpoint, which acquires Conn.mu. + NoteRecvActivity func(tailcfg.DiscoKey) } func (o *Options) logf() logger.Logf { @@ -282,6 +301,7 @@ func NewConn(opts Options) (*Conn, error) { c.epFunc = opts.endpointsFunc() c.idleFunc = opts.IdleFunc c.packetListener = opts.PacketListener + c.noteRecvActivity = opts.NoteRecvActivity if err := c.initialBind(); err != nil { return nil, err @@ -1300,6 +1320,16 @@ func wgRecvAddr(e conn.Endpoint, ipp netaddr.IPPort, addr *net.UDPAddr) *net.UDP return ipp.UDPAddr() } +// noteRecvActivity calls the magicsock.Conn.noteRecvActivity hook if +// e is a discovery-capable peer. +// +// This should be called whenever a packet arrives from e. +func noteRecvActivity(e conn.Endpoint) { + if de, ok := e.(*discoEndpoint); ok { + de.onRecvActivity() + } +} + func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr, err error) { Top: // First, process any buffered packet from earlier. @@ -1307,6 +1337,7 @@ Top: c.bufferedIPv4From = netaddr.IPPort{} addr = from.UDPAddr() ep := c.findEndpoint(from, addr) + noteRecvActivity(ep) return copy(b, c.bufferedIPv4Packet), ep, wgRecvAddr(ep, from, addr), nil } @@ -1319,6 +1350,7 @@ Top: var addrSet *AddrSet var discoEp *discoEndpoint var ipp netaddr.IPPort + var didNoteRecvActivity bool select { case dm := <-c.derpRecvCh: @@ -1360,6 +1392,24 @@ Top: c.mu.Lock() if dk, ok := c.discoOfNode[tailcfg.NodeKey(dm.src)]; ok { discoEp = c.endpointOfDisco[dk] + // If we know about the node (it's in discoOfNode) but don't know about the + // endpoint, that's because it's an idle peer that doesn't yet exist in the + // wireguard config. So run the receive hook, if defined, which should + // create the wireguard peer. + if discoEp == nil && c.noteRecvActivity != nil { + didNoteRecvActivity = true + c.mu.Unlock() // release lock before calling noteRecvActivity + c.noteRecvActivity(dk) // (calls back into CreateEndpoint) + // Now require the lock. No invariants need to be rechecked; just + // 1-2 map lookups follow that are harmless if, say, the peer has + // been deleted during this time. In that case we'll treate it as a + // legacy pre-disco UDP receive and hand it to wireguard which'll + // likely just drop it. + c.mu.Lock() + + discoEp = c.endpointOfDisco[dk] + c.logf("magicsock: DERP packet received from idle peer %v; created=%v", dm.src.ShortString(), discoEp != nil) + } } if discoEp == nil { addrSet = c.addrsByKey[dm.src] @@ -1398,6 +1448,9 @@ Top: } else { ep = c.findEndpoint(ipp, addr) } + if !didNoteRecvActivity { + noteRecvActivity(ep) + } return n, ep, wgRecvAddr(ep, ipp, addr), nil } @@ -1424,6 +1477,7 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) { } ep := c.findEndpoint(ipp, addr) + noteRecvActivity(ep) return n, ep, wgRecvAddr(ep, ipp, addr), nil } } @@ -1440,7 +1494,7 @@ const ( discoVerboseLog ) -func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey key.Public, dstDisco tailcfg.DiscoKey, m disco.Message, logLevel discoLogLevel) (sent bool, err error) { +func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstDisco tailcfg.DiscoKey, m disco.Message, logLevel discoLogLevel) (sent bool, err error) { c.mu.Lock() if c.closed { c.mu.Unlock() @@ -1458,7 +1512,7 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey key.Public, dstDisco c.mu.Unlock() pkt = box.SealAfterPrecomputation(pkt, m.AppendMarshal(nil), &nonce, sharedKey) - sent, err = c.sendAddr(dst, dstKey, pkt) + sent, err = c.sendAddr(dst, key.Public(dstKey), pkt) if sent { if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco) { c.logf("magicsock: disco: %v->%v (%v, %v) sent %v", c.discoShort, dstDisco.ShortString(), dstKey.ShortString(), derpStr(dst.String()), disco.MessageSummary(m)) @@ -1507,16 +1561,45 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool { return false } - de, ok := c.endpointOfDisco[sender] + peerNode, ok := c.nodeOfDisco[sender] if !ok { if debugDisco { - c.logf("magicsock: disco: ignoring disco-looking frame, don't know about %v", sender.ShortString()) + c.logf("magicsock: disco: ignoring disco-looking frame, don't know node for %v", sender.ShortString()) } // Returning false keeps passing it down, to WireGuard. // WireGuard will almost surely reject it, but give it a chance. return false } + de, ok := c.endpointOfDisco[sender] + if !ok { + // We don't have an active endpoint for this sender but we knew about the node, so + // it's an idle endpoint that doesn't yet exist in the wireguard config. We now have + // to notify the userspace engine (via noteRecvActivity) so wireguard-go can create + // an Endpoint (ultimately calling our CreateEndpoint). + if debugDisco { + c.logf("magicsock: disco: got message from inactive peer %v", sender.ShortString()) + } + if c.noteRecvActivity == nil { + c.logf("magicsock: [unexpected] have node without endpoint, without c.noteRecvActivity hook") + return false + } + // noteRecvActivity calls back into CreateEndpoint, which we can't easily control, + // and CreateEndpoint expects to be called with c.mu held, but we hold it here, and + // it's too invasive for now to release it here and recheck invariants. So instead, + // use this unfortunate hack: set canCreateEPUnlocked which CreateEndpoint then + // checks to conditionally acquire the mutex. I'm so sorry. + c.canCreateEPUnlocked.Set(true) + c.noteRecvActivity(sender) + c.canCreateEPUnlocked.Set(false) + de, ok = c.endpointOfDisco[sender] + if !ok { + c.logf("magicsock: [unexpected] lazy endpoint not created for %v, %v", peerNode.Key.ShortString(), sender.ShortString()) + return false + } + c.logf("magicsock: lazy endpoint created via disco message for %v, %v", peerNode.Key.ShortString(), sender.ShortString()) + } + // First, do we even know (and thus care) about this sender? If not, // don't bother decrypting it. @@ -1556,8 +1639,11 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool { switch dm := dm.(type) { case *disco.Ping: - c.handlePingLocked(dm, de, src) + c.handlePingLocked(dm, de, src, sender, peerNode) case *disco.Pong: + if de == nil { + return true + } de.handlePongConnLocked(dm, src) case disco.CallMeMaybe: if src.IP != derpMagicIPAddr { @@ -1565,26 +1651,40 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool { c.logf("[unexpected] CallMeMaybe packets should only come via DERP") return true } - c.logf("magicsock: disco: %v<-%v (%v, %v) got call-me-maybe", c.discoShort, de.discoShort, de.publicKey.ShortString(), derpStr(src.String())) - go de.handleCallMeMaybe() + if de != nil { + c.logf("magicsock: disco: %v<-%v (%v, %v) got call-me-maybe", c.discoShort, de.discoShort, de.publicKey.ShortString(), derpStr(src.String())) + go de.handleCallMeMaybe() + } } return true } -func (c *Conn) handlePingLocked(dm *disco.Ping, de *discoEndpoint, src netaddr.IPPort) { - likelyHeartBeat := src == de.lastPingFrom && time.Since(de.lastPingTime) < 5*time.Second - de.lastPingFrom = src - de.lastPingTime = time.Now() +// de may be nil +func (c *Conn) handlePingLocked(dm *disco.Ping, de *discoEndpoint, src netaddr.IPPort, sender tailcfg.DiscoKey, peerNode *tailcfg.Node) { + if peerNode == nil { + c.logf("magicsock: disco: [unexpected] ignoring ping from unknown peer Node") + return + } + likelyHeartBeat := de != nil && src == de.lastPingFrom && time.Since(de.lastPingTime) < 5*time.Second + var discoShort string + if de != nil { + discoShort = de.discoShort + de.lastPingFrom = src + de.lastPingTime = time.Now() + } else { + discoShort = sender.ShortString() + } if !likelyHeartBeat || debugDisco { - c.logf("magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, de.discoShort, de.publicKey.ShortString(), src, dm.TxID[:6]) + c.logf("magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, discoShort, peerNode.Key.ShortString(), src, dm.TxID[:6]) } // Remember this route if not present. - c.setAddrToDiscoLocked(src, de.discoKey, nil) + c.setAddrToDiscoLocked(src, sender, nil) - pongDst := src - go de.sendDiscoMessage(pongDst, &disco.Pong{ + ipDst := src + discoDest := sender + go c.sendDiscoMessage(ipDst, peerNode.Key, discoDest, &disco.Pong{ TxID: dm.TxID, Src: src, }, discoVerboseLog) @@ -2455,17 +2555,30 @@ func (c *Conn) CreateEndpoint(pubKey [32]byte, addrs string) (conn.Endpoint, err if err != nil { return nil, fmt.Errorf("magicsock: invalid discokey endpoint %q for %v: %w", addrs, pk.ShortString(), err) } - c.mu.Lock() - defer c.mu.Unlock() + if !c.canCreateEPUnlocked.Get() { // sorry + c.mu.Lock() + defer c.mu.Unlock() + } de := &discoEndpoint{ c: c, - publicKey: pk, // peer public key (for WireGuard + DERP) + publicKey: tailcfg.NodeKey(pk), // peer public key (for WireGuard + DERP) discoKey: tailcfg.DiscoKey(discoKey), // for discovery mesages discoShort: tailcfg.DiscoKey(discoKey).ShortString(), wgEndpointHostPort: addrs, sentPing: map[stun.TxID]sentPing{}, endpointState: map[netaddr.IPPort]*endpointState{}, } + lastRecvTime := new(int64) // atomic + de.onRecvActivity = func() { + now := time.Now().Unix() + old := atomic.LoadInt64(lastRecvTime) + if old == 0 || old <= now-10 { + atomic.StoreInt64(lastRecvTime, now) + if c.noteRecvActivity != nil { + c.noteRecvActivity(de.discoKey) + } + } + } de.initFakeUDPAddr() de.updateFromNode(c.nodeOfDisco[de.discoKey]) c.endpointOfDisco[de.discoKey] = de @@ -2694,14 +2807,14 @@ func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) { c.mu.Lock() defer c.mu.Unlock() - for dk, de := range c.endpointOfDisco { + for dk, n := range c.nodeOfDisco { ps := &ipnstate.PeerStatus{InMagicSock: true} - if node, ok := c.nodeOfDisco[dk]; ok { - ps.Addrs = append(ps.Addrs, node.Endpoints...) - ps.Relay = c.derpRegionCodeOfAddrLocked(node.DERP) + ps.Addrs = append(ps.Addrs, n.Endpoints...) + ps.Relay = c.derpRegionCodeOfAddrLocked(n.DERP) + if de, ok := c.endpointOfDisco[dk]; ok { + de.populatePeerStatus(ps) } - de.populatePeerStatus(ps) - sb.AddPeer(de.publicKey, ps) + sb.AddPeer(key.Public(n.Key), ps) } // Old-style (pre-disco) peers: for k, as := range c.addrsByKey { @@ -2731,12 +2844,13 @@ func udpAddrDebugString(ua net.UDPAddr) string { type discoEndpoint struct { // These fields are initialized once and never modified. c *Conn - publicKey key.Public // peer public key (for WireGuard + DERP) + publicKey tailcfg.NodeKey // peer public key (for WireGuard + DERP) discoKey tailcfg.DiscoKey // for discovery mesages discoShort string // ShortString of discoKey fakeWGAddr netaddr.IPPort // the UDP address we tell wireguard-go we're using fakeWGAddrStd *net.UDPAddr // the *net.UDPAddr form of fakeWGAddr wgEndpointHostPort string // string from CreateEndpoint: ".disco.tailscale:12345" + onRecvActivity func() // Owned by Conn.mu: lastPingFrom netaddr.IPPort @@ -2958,10 +3072,10 @@ func (de *discoEndpoint) send(b []byte) error { } var err error if !udpAddr.IsZero() { - _, err = de.c.sendAddr(udpAddr, de.publicKey, b) + _, err = de.c.sendAddr(udpAddr, key.Public(de.publicKey), b) } if !derpAddr.IsZero() { - if ok, _ := de.c.sendAddr(derpAddr, de.publicKey, b); ok && err != nil { + if ok, _ := de.c.sendAddr(derpAddr, key.Public(de.publicKey), b); ok && err != nil { // UDP failed but DERP worked, so good enough: return nil } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index 30628bac6..dcad71745 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -947,7 +947,12 @@ func TestDiscoMessage(t *testing.T) { peer1Priv := c.discoPrivate c.endpointOfDisco = map[tailcfg.DiscoKey]*discoEndpoint{ tailcfg.DiscoKey(peer1Pub): &discoEndpoint{ - // ... + // ... (enough for this test) + }, + } + c.nodeOfDisco = map[tailcfg.DiscoKey]*tailcfg.Node{ + tailcfg.DiscoKey(peer1Pub): &tailcfg.Node{ + // ... (enough for this test) }, } diff --git a/wgengine/tstun/tun.go b/wgengine/tstun/tun.go index 48d0e6c43..fb1804a14 100644 --- a/wgengine/tstun/tun.go +++ b/wgengine/tstun/tun.go @@ -66,6 +66,8 @@ type TUN struct { _ [4]byte // force 64-bit alignment of following field on 32-bit lastActivityAtomic int64 // unix seconds of last send or receive + destIPActivity atomic.Value // of map[packet.IP]func() + // buffer stores the oldest unconsumed packet from tdev. // It is made a static buffer in order to avoid allocations. buffer [maxBufferSize]byte @@ -129,6 +131,14 @@ func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN { return tun } +// SetDestIPActivityFuncs sets a map of funcs to run per packet +// destination (the map keys). +// +// The map ownership passes to the TUN. It must be non-nil. +func (t *TUN) SetDestIPActivityFuncs(m map[packet.IP]func()) { + t.destIPActivity.Store(m) +} + func (t *TUN) Close() error { select { case <-t.closed: @@ -204,10 +214,7 @@ func (t *TUN) poll() { } } -func (t *TUN) filterOut(buf []byte) filter.Response { - p := parsedPacketPool.Get().(*packet.ParsedPacket) - defer parsedPacketPool.Put(p) - p.Decode(buf) +func (t *TUN) filterOut(p *packet.ParsedPacket) filter.Response { if t.PreFilterOut != nil { if t.PreFilterOut(p, t) == filter.Drop { @@ -271,8 +278,18 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) { } } + p := parsedPacketPool.Get().(*packet.ParsedPacket) + defer parsedPacketPool.Put(p) + p.Decode(buf[offset : offset+n]) + + if m, ok := t.destIPActivity.Load().(map[packet.IP]func()); ok { + if fn := m[p.DstIP]; fn != nil { + fn() + } + } + if !t.disableFilter { - response := t.filterOut(buf[offset : offset+n]) + response := t.filterOut(p) if response != filter.Accept { // Wireguard considers read errors fatal; pretend nothing was read return 0, nil diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 25a0797e9..b1baf8a4e 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "context" + "encoding/binary" "errors" "fmt" "io" @@ -59,6 +60,15 @@ const ( // magicDNSDomain is the parent domain for Tailscale nodes. const magicDNSDomain = "b.tailscale.net" +// Lazy wireguard-go configuration parameters. +const ( + // lazyPeerIdleThreshold is the idle duration after + // which we remove a peer from the wireguard configuration. + // (This includes peers that have never been idle, which + // effectively have infinite idleness) + lazyPeerIdleThreshold = 5 * time.Minute +) + type userspaceEngine struct { logf logger.Logf reqCh chan struct{} @@ -76,10 +86,14 @@ type userspaceEngine struct { // incorrectly sent to us. localAddrs atomic.Value // of map[packet.IP]bool - wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below - lastEngineSig string - lastRouterSig string - lastCfg wgcfg.Config + wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below + lastCfgFull wgcfg.Config + lastRouterSig string // of router.Config + lastEngineSigFull string // of full wireguard config + lastEngineSigTrim string // of trimmed wireguard config + recvActivityAt map[tailcfg.DiscoKey]time.Time + sentActivityAt map[packet.IP]*int64 // value is atomic int64 of unixtime + destIPActivityFuncs map[packet.IP]func() mu sync.Mutex // guards following; see lock order comment below closing bool // Close was called (even if we're still closing) @@ -210,10 +224,11 @@ func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) { e.RequestStatus() } magicsockOpts := magicsock.Options{ - Logf: logf, - Port: conf.ListenPort, - EndpointsFunc: endpointsFn, - IdleFunc: e.tundev.IdleDuration, + Logf: logf, + Port: conf.ListenPort, + EndpointsFunc: endpointsFn, + IdleFunc: e.tundev.IdleDuration, + NoteRecvActivity: e.noteReceiveActivity, } e.magicConn, err = magicsock.NewConn(magicsockOpts) if err != nil { @@ -513,8 +528,8 @@ func (e *userspaceEngine) pinger(peerKey wgcfg.Key, ips []wgcfg.IP) { var srcIP packet.IP e.wgLock.Lock() - if len(e.lastCfg.Addresses) > 0 { - srcIP = packet.NewIP(e.lastCfg.Addresses[0].IP.IP()) + if len(e.lastCfgFull.Addresses) > 0 { + srcIP = packet.NewIP(e.lastCfgFull.Addresses[0].IP.IP()) } e.wgLock.Unlock() @@ -554,6 +569,198 @@ func updateSig(last *string, v interface{}) (changed bool) { return false } +// isTrimmablePeer reports whether p is a peer that we can trim out of the +// network map. +// +// We can only trim peers that both a) support discovery (because we +// know who they are when we receive their data and don't need to rely +// on wireguard-go figuring it out) and b) for implementation +// simplicity, have only one IP address (an IPv4 /32), which is the +// common case for most peers. Subnet router nodes will just always be +// created in the wireguard-go config. +func isTrimmablePeer(p *wgcfg.Peer) bool { + if len(p.AllowedIPs) != 1 || len(p.Endpoints) != 1 { + return false + } + if !strings.HasSuffix(p.Endpoints[0].Host, ".disco.tailscale") { + return false + } + aip := p.AllowedIPs[0] + // TODO: IPv6 support, once we support IPv6 within the tunnel. In that case, + // len(p.AllowedIPs) probably will be more than 1. + if aip.Mask != 32 || !aip.IP.Is4() { + return false + } + return true +} + +// noteReceiveActivity is called by magicsock when a packet has been received +// by the peer using discovery key dk. Magicsock calls this no more than +// every 10 seconds for a given peer. +func (e *userspaceEngine) noteReceiveActivity(dk tailcfg.DiscoKey) { + e.wgLock.Lock() + defer e.wgLock.Unlock() + + now := time.Now() + was, ok := e.recvActivityAt[dk] + if !ok { + // Not a trimmable peer we care about tracking. (See isTrimmablePeer) + return + } + e.recvActivityAt[dk] = now + + // If the last activity time jumped a bunch (say, at least + // half the idle timeout) then see if we need to reprogram + // Wireguard. This could probably be just + // lazyPeerIdleThreshold without the divide by 2, but + // maybeReconfigWireguardLocked is cheap enough to call every + // couple minutes (just not on every packet). + if was.IsZero() || now.Sub(was) < -lazyPeerIdleThreshold/2 { + e.maybeReconfigWireguardLocked() + } +} + +// isActiveSince reports whether the peer identified by (dk, ip) has +// had a packet sent to or received from it since t. +// +// e.wgLock must be held. +func (e *userspaceEngine) isActiveSince(dk tailcfg.DiscoKey, ip wgcfg.IP, t time.Time) bool { + if e.recvActivityAt[dk].After(t) { + return true + } + pip := packet.IP(binary.BigEndian.Uint32(ip.Addr[12:])) + timePtr, ok := e.sentActivityAt[pip] + if !ok { + return false + } + unixTime := atomic.LoadInt64(timePtr) + return unixTime >= t.Unix() +} + +// discoKeyFromPeer returns the DiscoKey for a wireguard config's Peer. +// +// Invariant: isTrimmablePeer(p) == true, so it should have 1 endpoint with +// Host of form "<64-hex-digits>.disco.tailscale". If invariant is violated, +// we return the zero value. +func discoKeyFromPeer(p *wgcfg.Peer) tailcfg.DiscoKey { + host := p.Endpoints[0].Host + if len(host) < 64 { + return tailcfg.DiscoKey{} + } + k, err := key.NewPublicFromHexMem(mem.S(host[:64])) + if err != nil { + return tailcfg.DiscoKey{} + } + return tailcfg.DiscoKey(k) +} + +// e.wgLock must be held. +func (e *userspaceEngine) maybeReconfigWireguardLocked() error { + full := e.lastCfgFull + + // Compute a minimal config to pass to wireguard-go + // based on the full config. Prune off all the peers + // and only add the active ones back. + min := full + min.Peers = nil + + // We'll only keep a peer around if it's been active in + // the past 5 minutes. That's more than WireGuard's key + // rotation time anyway so it's no harm if we remove it + // later if it's been inactive. + activeCutoff := time.Now().Add(-lazyPeerIdleThreshold) + + // Not all peers can be trimmed from the network map (see + // isTrimmablePeer). For those are are trimmable, keep track + // of their DiscoKey and Tailscale IPs. These are the ones + // we'll need to install tracking hooks for to watch their + // send/receive activity. + trackDisco := make([]tailcfg.DiscoKey, 0, len(full.Peers)) + trackIPs := make([]wgcfg.IP, 0, len(full.Peers)) + + for i := range full.Peers { + p := &full.Peers[i] + if !isTrimmablePeer(p) { + min.Peers = append(min.Peers, *p) + continue + } + tsIP := p.AllowedIPs[0].IP + dk := discoKeyFromPeer(p) + trackDisco = append(trackDisco, dk) + trackIPs = append(trackIPs, tsIP) + if e.isActiveSince(dk, tsIP, activeCutoff) { + min.Peers = append(min.Peers, *p) + } + } + + if !updateSig(&e.lastEngineSigTrim, min) { + // No changes + return nil + } + + e.updateActivityMapsLocked(trackDisco, trackIPs) + + e.logf("wgengine: Reconfig: configuring userspace wireguard config (with %d/%d peers)", len(min.Peers), len(full.Peers)) + if err := e.wgdev.Reconfig(&min); err != nil { + e.logf("wgdev.Reconfig: %v", err) + return err + } + return nil +} + +// updateActivityMapsLocked updates the data structures used for tracking the activity +// of wireguard peers that we might add/remove dynamically from the real config +// as given to wireguard-go. +// +// e.wgLock must be held. +func (e *userspaceEngine) updateActivityMapsLocked(trackDisco []tailcfg.DiscoKey, trackIPs []wgcfg.IP) { + // Generate the new map of which discokeys we want to track + // receive times for. + mr := map[tailcfg.DiscoKey]time.Time{} // TODO: only recreate this if set of keys changed + for _, dk := range trackDisco { + // Preserve old times in the new map, but also + // populate map entries for new trackDisco values with + // time.Time{} zero values. (Only entries in this map + // are tracked, so the Time zero values allow it to be + // tracked later) + mr[dk] = e.recvActivityAt[dk] + } + e.recvActivityAt = mr + + oldTime := e.sentActivityAt + e.sentActivityAt = make(map[packet.IP]*int64, len(oldTime)) + oldFunc := e.destIPActivityFuncs + e.destIPActivityFuncs = make(map[packet.IP]func(), len(oldFunc)) + + for _, wip := range trackIPs { + pip := packet.IP(binary.BigEndian.Uint32(wip.Addr[12:])) + timePtr := oldTime[pip] + if timePtr == nil { + timePtr = new(int64) + } + e.sentActivityAt[pip] = timePtr + + fn := oldFunc[pip] + if fn == nil { + // This is the func that gets run on every outgoing packet for tracked IPs: + fn = func() { + now, old := time.Now().Unix(), atomic.LoadInt64(timePtr) + if old > now-10 { + return + } + atomic.StoreInt64(timePtr, now) + if old == 0 || (now-old) <= 60 { + e.wgLock.Lock() + defer e.wgLock.Unlock() + e.maybeReconfigWireguardLocked() + } + } + } + e.destIPActivityFuncs[pip] = fn + } + e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs) +} + func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) error { if routerCfg == nil { panic("routerCfg must not be nil") @@ -588,29 +795,24 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) routerCfg.Domains = append([]string{magicDNSDomain}, routerCfg.Domains...) } - engineChanged := updateSig(&e.lastEngineSig, cfg) + engineChanged := updateSig(&e.lastEngineSigFull, cfg) routerChanged := updateSig(&e.lastRouterSig, routerCfg) if !engineChanged && !routerChanged { return ErrNoChanges } - e.lastCfg = cfg.Copy() - - if engineChanged { - e.logf("wgengine: Reconfig: configuring userspace wireguard config") - // Tell magicsock about the new (or initial) private key - // (which is needed by DERP) before wgdev gets it, as wgdev - // will start trying to handshake, which we want to be able to - // go over DERP. - if err := e.magicConn.SetPrivateKey(cfg.PrivateKey); err != nil { - e.logf("wgengine: Reconfig: SetPrivateKey: %v", err) - } - - if err := e.wgdev.Reconfig(cfg); err != nil { - e.logf("wgdev.Reconfig: %v", err) - return err - } + e.lastCfgFull = cfg.Copy() + + // Tell magicsock about the new (or initial) private key + // (which is needed by DERP) before wgdev gets it, as wgdev + // will start trying to handshake, which we want to be able to + // go over DERP. + if err := e.magicConn.SetPrivateKey(cfg.PrivateKey); err != nil { + e.logf("wgengine: Reconfig: SetPrivateKey: %v", err) + } + e.magicConn.UpdatePeers(peerSet) - e.magicConn.UpdatePeers(peerSet) + if err := e.maybeReconfigWireguardLocked(); err != nil { + return err } if routerChanged { @@ -758,15 +960,9 @@ func (e *userspaceEngine) getStatus() (*Status, error) { var peers []PeerStatus for _, pk := range e.peerSequence { - p := pp[pk] - if p == nil { - p = &PeerStatus{} + if p, ok := pp[pk]; ok { // ignore idle ones not in wireguard-go's config + peers = append(peers, *p) } - peers = append(peers, *p) - } - - if len(pp) != len(e.peerSequence) { - e.logf("wg status returned %v peers, expected %v", len(pp), len(e.peerSequence)) } return &Status{