wgengine: configure wireguard peers lazily, as needed

wireguard-go uses 3 goroutines per peer (with reasonably large stacks
& buffers).

Rather than tell wireguard-go about all our peers, only tell it about
peers we're actively communicating with. That means we need hooks into
magicsock's packet receiving path and tstun's packet sending path to
lazily create a wireguard peer on demand from the network map.

This frees up lots of memory for iOS (where we have almost nothing
left for larger domains with many users).

We should ideally do this in wireguard-go itself one day, but that'd
be a pretty big change.

Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
pull/600/head
Brad Fitzpatrick 4 years ago committed by Brad Fitzpatrick
parent 5066b824a6
commit 16a9cfe2f4

@ -58,14 +58,15 @@ import (
// A Conn routes UDP packets and actively manages a list of its endpoints. // A Conn routes UDP packets and actively manages a list of its endpoints.
// It implements wireguard/conn.Bind. // It implements wireguard/conn.Bind.
type Conn struct { type Conn struct {
pconnPort uint16 // the preferred port from opts.Port; 0 means auto pconnPort uint16 // the preferred port from opts.Port; 0 means auto
pconn4 *RebindingUDPConn pconn4 *RebindingUDPConn
pconn6 *RebindingUDPConn // non-nil if IPv6 available pconn6 *RebindingUDPConn // non-nil if IPv6 available
epFunc func(endpoints []string) epFunc func(endpoints []string)
logf logger.Logf logf logger.Logf
sendLogLimit *rate.Limiter sendLogLimit *rate.Limiter
netChecker *netcheck.Client netChecker *netcheck.Client
idleFunc func() time.Duration // nil means unknown idleFunc func() time.Duration // nil means unknown
noteRecvActivity func(tailcfg.DiscoKey) // or nil, see Options.NoteRecvActivity
// bufferedIPv4From and bufferedIPv4Packet are owned by // bufferedIPv4From and bufferedIPv4Packet are owned by
// ReceiveIPv4, and used when both a DERP and IPv4 packet arrive // ReceiveIPv4, and used when both a DERP and IPv4 packet arrive
@ -89,6 +90,13 @@ type Conn struct {
// ============================================================ // ============================================================
mu sync.Mutex // guards all following fields mu sync.Mutex // guards all following fields
// canCreateEPUnlocked tracks at one place whether mu is
// already held. It's then checked in CreateEndpoint to avoid
// double-locking mu and thus deadlocking. mu should be held
// while setting this; but can be read without mu held.
// TODO(bradfitz): delete this shameful hack; refactor the one use
canCreateEPUnlocked syncs.AtomicBool
started bool // Start was called started bool // Start was called
closed bool // Close was called closed bool // Close was called
@ -104,8 +112,8 @@ type Conn struct {
nodeOfDisco map[tailcfg.DiscoKey]*tailcfg.Node nodeOfDisco map[tailcfg.DiscoKey]*tailcfg.Node
discoOfNode map[tailcfg.NodeKey]tailcfg.DiscoKey discoOfNode map[tailcfg.NodeKey]tailcfg.DiscoKey
discoOfAddr map[netaddr.IPPort]tailcfg.DiscoKey // validated non-DERP paths only discoOfAddr map[netaddr.IPPort]tailcfg.DiscoKey // validated non-DERP paths only
endpointOfDisco map[tailcfg.DiscoKey]*discoEndpoint endpointOfDisco map[tailcfg.DiscoKey]*discoEndpoint // those with activity only
sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte // nacl/box precomputed key sharedDiscoKey map[tailcfg.DiscoKey]*[32]byte // nacl/box precomputed key
// addrsByUDP is a map of every remote ip:port to a priority // addrsByUDP is a map of every remote ip:port to a priority
// list of endpoint addresses for a peer. // list of endpoint addresses for a peer.
@ -235,6 +243,17 @@ type Options struct {
// PacketListener optionally specifies how to create PacketConns. // PacketListener optionally specifies how to create PacketConns.
// It's meant for testing. // It's meant for testing.
PacketListener nettype.PacketListener PacketListener nettype.PacketListener
// NoteRecvActivity, if provided, is a func for magicsock to
// call whenever it receives a packet from a
// discovery-capable peer if it's been more than ~10 seconds
// since the last one. (10 seconds is somewhat arbitrary; the
// sole user just doesn't need or want it called on every
// packet, just every minute or two for Wireguard timeouts,
// and 10 seconds seems like a good trade-off between often
// enough and not too often.) The provided func likely calls
// Conn.CreateEndpoint, which acquires Conn.mu.
NoteRecvActivity func(tailcfg.DiscoKey)
} }
func (o *Options) logf() logger.Logf { func (o *Options) logf() logger.Logf {
@ -282,6 +301,7 @@ func NewConn(opts Options) (*Conn, error) {
c.epFunc = opts.endpointsFunc() c.epFunc = opts.endpointsFunc()
c.idleFunc = opts.IdleFunc c.idleFunc = opts.IdleFunc
c.packetListener = opts.PacketListener c.packetListener = opts.PacketListener
c.noteRecvActivity = opts.NoteRecvActivity
if err := c.initialBind(); err != nil { if err := c.initialBind(); err != nil {
return nil, err return nil, err
@ -1300,6 +1320,16 @@ func wgRecvAddr(e conn.Endpoint, ipp netaddr.IPPort, addr *net.UDPAddr) *net.UDP
return ipp.UDPAddr() return ipp.UDPAddr()
} }
// noteRecvActivity runs e's onRecvActivity callback when e is a
// *discoEndpoint (a discovery-capable peer). Any other endpoint type
// is ignored.
//
// Call it for every packet that arrives from e.
func noteRecvActivity(e conn.Endpoint) {
	de, isDisco := e.(*discoEndpoint)
	if !isDisco {
		return
	}
	de.onRecvActivity()
}
func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr, err error) { func (c *Conn) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, addr *net.UDPAddr, err error) {
Top: Top:
// First, process any buffered packet from earlier. // First, process any buffered packet from earlier.
@ -1307,6 +1337,7 @@ Top:
c.bufferedIPv4From = netaddr.IPPort{} c.bufferedIPv4From = netaddr.IPPort{}
addr = from.UDPAddr() addr = from.UDPAddr()
ep := c.findEndpoint(from, addr) ep := c.findEndpoint(from, addr)
noteRecvActivity(ep)
return copy(b, c.bufferedIPv4Packet), ep, wgRecvAddr(ep, from, addr), nil return copy(b, c.bufferedIPv4Packet), ep, wgRecvAddr(ep, from, addr), nil
} }
@ -1319,6 +1350,7 @@ Top:
var addrSet *AddrSet var addrSet *AddrSet
var discoEp *discoEndpoint var discoEp *discoEndpoint
var ipp netaddr.IPPort var ipp netaddr.IPPort
var didNoteRecvActivity bool
select { select {
case dm := <-c.derpRecvCh: case dm := <-c.derpRecvCh:
@ -1360,6 +1392,24 @@ Top:
c.mu.Lock() c.mu.Lock()
if dk, ok := c.discoOfNode[tailcfg.NodeKey(dm.src)]; ok { if dk, ok := c.discoOfNode[tailcfg.NodeKey(dm.src)]; ok {
discoEp = c.endpointOfDisco[dk] discoEp = c.endpointOfDisco[dk]
// If we know about the node (it's in discoOfNode) but don't know about the
// endpoint, that's because it's an idle peer that doesn't yet exist in the
// wireguard config. So run the receive hook, if defined, which should
// create the wireguard peer.
if discoEp == nil && c.noteRecvActivity != nil {
didNoteRecvActivity = true
c.mu.Unlock() // release lock before calling noteRecvActivity
c.noteRecvActivity(dk) // (calls back into CreateEndpoint)
// Now reacquire the lock. No invariants need to be rechecked; just
// 1-2 map lookups follow that are harmless if, say, the peer has
// been deleted during this time. In that case we'll treat it as a
// legacy pre-disco UDP receive and hand it to wireguard which'll
// likely just drop it.
c.mu.Lock()
discoEp = c.endpointOfDisco[dk]
c.logf("magicsock: DERP packet received from idle peer %v; created=%v", dm.src.ShortString(), discoEp != nil)
}
} }
if discoEp == nil { if discoEp == nil {
addrSet = c.addrsByKey[dm.src] addrSet = c.addrsByKey[dm.src]
@ -1398,6 +1448,9 @@ Top:
} else { } else {
ep = c.findEndpoint(ipp, addr) ep = c.findEndpoint(ipp, addr)
} }
if !didNoteRecvActivity {
noteRecvActivity(ep)
}
return n, ep, wgRecvAddr(ep, ipp, addr), nil return n, ep, wgRecvAddr(ep, ipp, addr), nil
} }
@ -1424,6 +1477,7 @@ func (c *Conn) ReceiveIPv6(b []byte) (int, conn.Endpoint, *net.UDPAddr, error) {
} }
ep := c.findEndpoint(ipp, addr) ep := c.findEndpoint(ipp, addr)
noteRecvActivity(ep)
return n, ep, wgRecvAddr(ep, ipp, addr), nil return n, ep, wgRecvAddr(ep, ipp, addr), nil
} }
} }
@ -1440,7 +1494,7 @@ const (
discoVerboseLog discoVerboseLog
) )
func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey key.Public, dstDisco tailcfg.DiscoKey, m disco.Message, logLevel discoLogLevel) (sent bool, err error) { func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey tailcfg.NodeKey, dstDisco tailcfg.DiscoKey, m disco.Message, logLevel discoLogLevel) (sent bool, err error) {
c.mu.Lock() c.mu.Lock()
if c.closed { if c.closed {
c.mu.Unlock() c.mu.Unlock()
@ -1458,7 +1512,7 @@ func (c *Conn) sendDiscoMessage(dst netaddr.IPPort, dstKey key.Public, dstDisco
c.mu.Unlock() c.mu.Unlock()
pkt = box.SealAfterPrecomputation(pkt, m.AppendMarshal(nil), &nonce, sharedKey) pkt = box.SealAfterPrecomputation(pkt, m.AppendMarshal(nil), &nonce, sharedKey)
sent, err = c.sendAddr(dst, dstKey, pkt) sent, err = c.sendAddr(dst, key.Public(dstKey), pkt)
if sent { if sent {
if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco) { if logLevel == discoLog || (logLevel == discoVerboseLog && debugDisco) {
c.logf("magicsock: disco: %v->%v (%v, %v) sent %v", c.discoShort, dstDisco.ShortString(), dstKey.ShortString(), derpStr(dst.String()), disco.MessageSummary(m)) c.logf("magicsock: disco: %v->%v (%v, %v) sent %v", c.discoShort, dstDisco.ShortString(), dstKey.ShortString(), derpStr(dst.String()), disco.MessageSummary(m))
@ -1507,16 +1561,45 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
return false return false
} }
de, ok := c.endpointOfDisco[sender] peerNode, ok := c.nodeOfDisco[sender]
if !ok { if !ok {
if debugDisco { if debugDisco {
c.logf("magicsock: disco: ignoring disco-looking frame, don't know about %v", sender.ShortString()) c.logf("magicsock: disco: ignoring disco-looking frame, don't know node for %v", sender.ShortString())
} }
// Returning false keeps passing it down, to WireGuard. // Returning false keeps passing it down, to WireGuard.
// WireGuard will almost surely reject it, but give it a chance. // WireGuard will almost surely reject it, but give it a chance.
return false return false
} }
de, ok := c.endpointOfDisco[sender]
if !ok {
// We don't have an active endpoint for this sender but we knew about the node, so
// it's an idle endpoint that doesn't yet exist in the wireguard config. We now have
// to notify the userspace engine (via noteRecvActivity) so wireguard-go can create
// an Endpoint (ultimately calling our CreateEndpoint).
if debugDisco {
c.logf("magicsock: disco: got message from inactive peer %v", sender.ShortString())
}
if c.noteRecvActivity == nil {
c.logf("magicsock: [unexpected] have node without endpoint, without c.noteRecvActivity hook")
return false
}
// noteRecvActivity calls back into CreateEndpoint, which we can't easily control,
// and CreateEndpoint expects to be called without c.mu held, but we hold it here, and
// it's too invasive for now to release it here and recheck invariants. So instead,
// use this unfortunate hack: set canCreateEPUnlocked which CreateEndpoint then
// checks to conditionally acquire the mutex. I'm so sorry.
c.canCreateEPUnlocked.Set(true)
c.noteRecvActivity(sender)
c.canCreateEPUnlocked.Set(false)
de, ok = c.endpointOfDisco[sender]
if !ok {
c.logf("magicsock: [unexpected] lazy endpoint not created for %v, %v", peerNode.Key.ShortString(), sender.ShortString())
return false
}
c.logf("magicsock: lazy endpoint created via disco message for %v, %v", peerNode.Key.ShortString(), sender.ShortString())
}
// First, do we even know (and thus care) about this sender? If not, // First, do we even know (and thus care) about this sender? If not,
// don't bother decrypting it. // don't bother decrypting it.
@ -1556,8 +1639,11 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
switch dm := dm.(type) { switch dm := dm.(type) {
case *disco.Ping: case *disco.Ping:
c.handlePingLocked(dm, de, src) c.handlePingLocked(dm, de, src, sender, peerNode)
case *disco.Pong: case *disco.Pong:
if de == nil {
return true
}
de.handlePongConnLocked(dm, src) de.handlePongConnLocked(dm, src)
case disco.CallMeMaybe: case disco.CallMeMaybe:
if src.IP != derpMagicIPAddr { if src.IP != derpMagicIPAddr {
@ -1565,26 +1651,40 @@ func (c *Conn) handleDiscoMessage(msg []byte, src netaddr.IPPort) bool {
c.logf("[unexpected] CallMeMaybe packets should only come via DERP") c.logf("[unexpected] CallMeMaybe packets should only come via DERP")
return true return true
} }
c.logf("magicsock: disco: %v<-%v (%v, %v) got call-me-maybe", c.discoShort, de.discoShort, de.publicKey.ShortString(), derpStr(src.String())) if de != nil {
go de.handleCallMeMaybe() c.logf("magicsock: disco: %v<-%v (%v, %v) got call-me-maybe", c.discoShort, de.discoShort, de.publicKey.ShortString(), derpStr(src.String()))
go de.handleCallMeMaybe()
}
} }
return true return true
} }
func (c *Conn) handlePingLocked(dm *disco.Ping, de *discoEndpoint, src netaddr.IPPort) { // de may be nil
likelyHeartBeat := src == de.lastPingFrom && time.Since(de.lastPingTime) < 5*time.Second func (c *Conn) handlePingLocked(dm *disco.Ping, de *discoEndpoint, src netaddr.IPPort, sender tailcfg.DiscoKey, peerNode *tailcfg.Node) {
de.lastPingFrom = src if peerNode == nil {
de.lastPingTime = time.Now() c.logf("magicsock: disco: [unexpected] ignoring ping from unknown peer Node")
return
}
likelyHeartBeat := de != nil && src == de.lastPingFrom && time.Since(de.lastPingTime) < 5*time.Second
var discoShort string
if de != nil {
discoShort = de.discoShort
de.lastPingFrom = src
de.lastPingTime = time.Now()
} else {
discoShort = sender.ShortString()
}
if !likelyHeartBeat || debugDisco { if !likelyHeartBeat || debugDisco {
c.logf("magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, de.discoShort, de.publicKey.ShortString(), src, dm.TxID[:6]) c.logf("magicsock: disco: %v<-%v (%v, %v) got ping tx=%x", c.discoShort, discoShort, peerNode.Key.ShortString(), src, dm.TxID[:6])
} }
// Remember this route if not present. // Remember this route if not present.
c.setAddrToDiscoLocked(src, de.discoKey, nil) c.setAddrToDiscoLocked(src, sender, nil)
pongDst := src ipDst := src
go de.sendDiscoMessage(pongDst, &disco.Pong{ discoDest := sender
go c.sendDiscoMessage(ipDst, peerNode.Key, discoDest, &disco.Pong{
TxID: dm.TxID, TxID: dm.TxID,
Src: src, Src: src,
}, discoVerboseLog) }, discoVerboseLog)
@ -2455,17 +2555,30 @@ func (c *Conn) CreateEndpoint(pubKey [32]byte, addrs string) (conn.Endpoint, err
if err != nil { if err != nil {
return nil, fmt.Errorf("magicsock: invalid discokey endpoint %q for %v: %w", addrs, pk.ShortString(), err) return nil, fmt.Errorf("magicsock: invalid discokey endpoint %q for %v: %w", addrs, pk.ShortString(), err)
} }
c.mu.Lock() if !c.canCreateEPUnlocked.Get() { // sorry
defer c.mu.Unlock() c.mu.Lock()
defer c.mu.Unlock()
}
de := &discoEndpoint{ de := &discoEndpoint{
c: c, c: c,
publicKey: pk, // peer public key (for WireGuard + DERP) publicKey: tailcfg.NodeKey(pk), // peer public key (for WireGuard + DERP)
discoKey: tailcfg.DiscoKey(discoKey), // for discovery mesages discoKey: tailcfg.DiscoKey(discoKey), // for discovery mesages
discoShort: tailcfg.DiscoKey(discoKey).ShortString(), discoShort: tailcfg.DiscoKey(discoKey).ShortString(),
wgEndpointHostPort: addrs, wgEndpointHostPort: addrs,
sentPing: map[stun.TxID]sentPing{}, sentPing: map[stun.TxID]sentPing{},
endpointState: map[netaddr.IPPort]*endpointState{}, endpointState: map[netaddr.IPPort]*endpointState{},
} }
lastRecvTime := new(int64) // atomic
de.onRecvActivity = func() {
now := time.Now().Unix()
old := atomic.LoadInt64(lastRecvTime)
if old == 0 || old <= now-10 {
atomic.StoreInt64(lastRecvTime, now)
if c.noteRecvActivity != nil {
c.noteRecvActivity(de.discoKey)
}
}
}
de.initFakeUDPAddr() de.initFakeUDPAddr()
de.updateFromNode(c.nodeOfDisco[de.discoKey]) de.updateFromNode(c.nodeOfDisco[de.discoKey])
c.endpointOfDisco[de.discoKey] = de c.endpointOfDisco[de.discoKey] = de
@ -2694,14 +2807,14 @@ func (c *Conn) UpdateStatus(sb *ipnstate.StatusBuilder) {
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
for dk, de := range c.endpointOfDisco { for dk, n := range c.nodeOfDisco {
ps := &ipnstate.PeerStatus{InMagicSock: true} ps := &ipnstate.PeerStatus{InMagicSock: true}
if node, ok := c.nodeOfDisco[dk]; ok { ps.Addrs = append(ps.Addrs, n.Endpoints...)
ps.Addrs = append(ps.Addrs, node.Endpoints...) ps.Relay = c.derpRegionCodeOfAddrLocked(n.DERP)
ps.Relay = c.derpRegionCodeOfAddrLocked(node.DERP) if de, ok := c.endpointOfDisco[dk]; ok {
de.populatePeerStatus(ps)
} }
de.populatePeerStatus(ps) sb.AddPeer(key.Public(n.Key), ps)
sb.AddPeer(de.publicKey, ps)
} }
// Old-style (pre-disco) peers: // Old-style (pre-disco) peers:
for k, as := range c.addrsByKey { for k, as := range c.addrsByKey {
@ -2731,12 +2844,13 @@ func udpAddrDebugString(ua net.UDPAddr) string {
type discoEndpoint struct { type discoEndpoint struct {
// These fields are initialized once and never modified. // These fields are initialized once and never modified.
c *Conn c *Conn
publicKey key.Public // peer public key (for WireGuard + DERP) publicKey tailcfg.NodeKey // peer public key (for WireGuard + DERP)
discoKey tailcfg.DiscoKey // for discovery mesages discoKey tailcfg.DiscoKey // for discovery mesages
discoShort string // ShortString of discoKey discoShort string // ShortString of discoKey
fakeWGAddr netaddr.IPPort // the UDP address we tell wireguard-go we're using fakeWGAddr netaddr.IPPort // the UDP address we tell wireguard-go we're using
fakeWGAddrStd *net.UDPAddr // the *net.UDPAddr form of fakeWGAddr fakeWGAddrStd *net.UDPAddr // the *net.UDPAddr form of fakeWGAddr
wgEndpointHostPort string // string from CreateEndpoint: "<hex-discovery-key>.disco.tailscale:12345" wgEndpointHostPort string // string from CreateEndpoint: "<hex-discovery-key>.disco.tailscale:12345"
onRecvActivity func()
// Owned by Conn.mu: // Owned by Conn.mu:
lastPingFrom netaddr.IPPort lastPingFrom netaddr.IPPort
@ -2958,10 +3072,10 @@ func (de *discoEndpoint) send(b []byte) error {
} }
var err error var err error
if !udpAddr.IsZero() { if !udpAddr.IsZero() {
_, err = de.c.sendAddr(udpAddr, de.publicKey, b) _, err = de.c.sendAddr(udpAddr, key.Public(de.publicKey), b)
} }
if !derpAddr.IsZero() { if !derpAddr.IsZero() {
if ok, _ := de.c.sendAddr(derpAddr, de.publicKey, b); ok && err != nil { if ok, _ := de.c.sendAddr(derpAddr, key.Public(de.publicKey), b); ok && err != nil {
// UDP failed but DERP worked, so good enough: // UDP failed but DERP worked, so good enough:
return nil return nil
} }

@ -947,7 +947,12 @@ func TestDiscoMessage(t *testing.T) {
peer1Priv := c.discoPrivate peer1Priv := c.discoPrivate
c.endpointOfDisco = map[tailcfg.DiscoKey]*discoEndpoint{ c.endpointOfDisco = map[tailcfg.DiscoKey]*discoEndpoint{
tailcfg.DiscoKey(peer1Pub): &discoEndpoint{ tailcfg.DiscoKey(peer1Pub): &discoEndpoint{
// ... // ... (enough for this test)
},
}
c.nodeOfDisco = map[tailcfg.DiscoKey]*tailcfg.Node{
tailcfg.DiscoKey(peer1Pub): &tailcfg.Node{
// ... (enough for this test)
}, },
} }

@ -66,6 +66,8 @@ type TUN struct {
_ [4]byte // force 64-bit alignment of following field on 32-bit _ [4]byte // force 64-bit alignment of following field on 32-bit
lastActivityAtomic int64 // unix seconds of last send or receive lastActivityAtomic int64 // unix seconds of last send or receive
destIPActivity atomic.Value // of map[packet.IP]func()
// buffer stores the oldest unconsumed packet from tdev. // buffer stores the oldest unconsumed packet from tdev.
// It is made a static buffer in order to avoid allocations. // It is made a static buffer in order to avoid allocations.
buffer [maxBufferSize]byte buffer [maxBufferSize]byte
@ -129,6 +131,14 @@ func WrapTUN(logf logger.Logf, tdev tun.Device) *TUN {
return tun return tun
} }
// SetDestIPActivityFuncs sets a map of funcs to run per packet
// destination (the map keys).
//
// The map is stored in t.destIPActivity (an atomic.Value). Read looks
// up the destination IP of each packet it decodes and, if the map has
// a non-nil func for that IP, calls it.
//
// The map ownership passes to the TUN. It must be non-nil.
func (t *TUN) SetDestIPActivityFuncs(m map[packet.IP]func()) {
t.destIPActivity.Store(m)
}
func (t *TUN) Close() error { func (t *TUN) Close() error {
select { select {
case <-t.closed: case <-t.closed:
@ -204,10 +214,7 @@ func (t *TUN) poll() {
} }
} }
func (t *TUN) filterOut(buf []byte) filter.Response { func (t *TUN) filterOut(p *packet.ParsedPacket) filter.Response {
p := parsedPacketPool.Get().(*packet.ParsedPacket)
defer parsedPacketPool.Put(p)
p.Decode(buf)
if t.PreFilterOut != nil { if t.PreFilterOut != nil {
if t.PreFilterOut(p, t) == filter.Drop { if t.PreFilterOut(p, t) == filter.Drop {
@ -271,8 +278,18 @@ func (t *TUN) Read(buf []byte, offset int) (int, error) {
} }
} }
p := parsedPacketPool.Get().(*packet.ParsedPacket)
defer parsedPacketPool.Put(p)
p.Decode(buf[offset : offset+n])
if m, ok := t.destIPActivity.Load().(map[packet.IP]func()); ok {
if fn := m[p.DstIP]; fn != nil {
fn()
}
}
if !t.disableFilter { if !t.disableFilter {
response := t.filterOut(buf[offset : offset+n]) response := t.filterOut(p)
if response != filter.Accept { if response != filter.Accept {
// Wireguard considers read errors fatal; pretend nothing was read // Wireguard considers read errors fatal; pretend nothing was read
return 0, nil return 0, nil

@ -8,6 +8,7 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -59,6 +60,15 @@ const (
// magicDNSDomain is the parent domain for Tailscale nodes. // magicDNSDomain is the parent domain for Tailscale nodes.
const magicDNSDomain = "b.tailscale.net" const magicDNSDomain = "b.tailscale.net"
// Lazy wireguard-go configuration parameters.
const (
// lazyPeerIdleThreshold is the idle duration after
// which we remove a peer from the wireguard configuration.
// (This includes peers that have never been active, which
// effectively have infinite idleness.)
lazyPeerIdleThreshold = 5 * time.Minute
)
type userspaceEngine struct { type userspaceEngine struct {
logf logger.Logf logf logger.Logf
reqCh chan struct{} reqCh chan struct{}
@ -76,10 +86,14 @@ type userspaceEngine struct {
// incorrectly sent to us. // incorrectly sent to us.
localAddrs atomic.Value // of map[packet.IP]bool localAddrs atomic.Value // of map[packet.IP]bool
wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below
lastEngineSig string lastCfgFull wgcfg.Config
lastRouterSig string lastRouterSig string // of router.Config
lastCfg wgcfg.Config lastEngineSigFull string // of full wireguard config
lastEngineSigTrim string // of trimmed wireguard config
recvActivityAt map[tailcfg.DiscoKey]time.Time
sentActivityAt map[packet.IP]*int64 // value is atomic int64 of unixtime
destIPActivityFuncs map[packet.IP]func()
mu sync.Mutex // guards following; see lock order comment below mu sync.Mutex // guards following; see lock order comment below
closing bool // Close was called (even if we're still closing) closing bool // Close was called (even if we're still closing)
@ -210,10 +224,11 @@ func newUserspaceEngineAdvanced(conf EngineConfig) (_ Engine, reterr error) {
e.RequestStatus() e.RequestStatus()
} }
magicsockOpts := magicsock.Options{ magicsockOpts := magicsock.Options{
Logf: logf, Logf: logf,
Port: conf.ListenPort, Port: conf.ListenPort,
EndpointsFunc: endpointsFn, EndpointsFunc: endpointsFn,
IdleFunc: e.tundev.IdleDuration, IdleFunc: e.tundev.IdleDuration,
NoteRecvActivity: e.noteReceiveActivity,
} }
e.magicConn, err = magicsock.NewConn(magicsockOpts) e.magicConn, err = magicsock.NewConn(magicsockOpts)
if err != nil { if err != nil {
@ -513,8 +528,8 @@ func (e *userspaceEngine) pinger(peerKey wgcfg.Key, ips []wgcfg.IP) {
var srcIP packet.IP var srcIP packet.IP
e.wgLock.Lock() e.wgLock.Lock()
if len(e.lastCfg.Addresses) > 0 { if len(e.lastCfgFull.Addresses) > 0 {
srcIP = packet.NewIP(e.lastCfg.Addresses[0].IP.IP()) srcIP = packet.NewIP(e.lastCfgFull.Addresses[0].IP.IP())
} }
e.wgLock.Unlock() e.wgLock.Unlock()
@ -554,6 +569,198 @@ func updateSig(last *string, v interface{}) (changed bool) {
return false return false
} }
// isTrimmablePeer reports whether p may be left out of the
// wireguard-go configuration while idle and added back on demand.
//
// A peer qualifies only if both of the following hold:
//
//   - It has exactly one endpoint and that endpoint's host ends in
//     ".disco.tailscale", i.e. it supports discovery, so we know who
//     it is when we receive its data and don't need to rely on
//     wireguard-go figuring it out.
//   - It has exactly one allowed IP, which is an IPv4 /32. This
//     restriction exists only for implementation simplicity and is
//     the common case for most peers. Subnet router nodes will just
//     always be created in the wireguard-go config.
func isTrimmablePeer(p *wgcfg.Peer) bool {
	if len(p.AllowedIPs) != 1 {
		return false
	}
	if len(p.Endpoints) != 1 {
		return false
	}
	if host := p.Endpoints[0].Host; !strings.HasSuffix(host, ".disco.tailscale") {
		return false
	}

	// TODO: IPv6 support, once we support IPv6 within the tunnel. In that case,
	// len(p.AllowedIPs) probably will be more than 1.
	only := p.AllowedIPs[0]
	return only.Mask == 32 && only.IP.Is4()
}
// noteReceiveActivity is called by magicsock when a packet has been
// received from the peer using discovery key dk. Magicsock rate-limits
// these calls per peer (roughly one every 10 seconds).
//
// It records the receive time for tracked (trimmable) peers and, if
// the peer has never been seen or has been quiet for more than half of
// lazyPeerIdleThreshold, re-evaluates the trimmed wireguard config so
// that a peer which was dropped for idleness gets added back.
//
// It acquires e.wgLock.
func (e *userspaceEngine) noteReceiveActivity(dk tailcfg.DiscoKey) {
	e.wgLock.Lock()
	defer e.wgLock.Unlock()

	was, ok := e.recvActivityAt[dk]
	if !ok {
		// Not a trimmable peer we care about tracking. (See isTrimmablePeer)
		return
	}
	now := time.Now()
	e.recvActivityAt[dk] = now

	// If the last activity time jumped a bunch (say, at least
	// half the idle timeout) then see if we need to reprogram
	// Wireguard. This could probably be just
	// lazyPeerIdleThreshold without the divide by 2, but
	// maybeReconfigWireguardLocked is cheap enough to call every
	// couple minutes (just not on every packet).
	//
	// The gap is now-was, which is positive for a peer that has been
	// quiet; comparing against a negative threshold would only ever
	// fire if the clock went backwards.
	if was.IsZero() || now.Sub(was) > lazyPeerIdleThreshold/2 {
		// There is no caller to return the error to;
		// maybeReconfigWireguardLocked logs its own failures.
		_ = e.maybeReconfigWireguardLocked()
	}
}
// isActiveSince reports whether the peer identified by (dk, ip) has
// had a packet sent to or received from it since t.
//
// Receive activity is looked up by discovery key in e.recvActivityAt;
// send activity is looked up in e.sentActivityAt by the last four
// bytes of ip.Addr, read as a big-endian packet.IP, and compared at
// one-second (Unix time) granularity.
//
// e.wgLock must be held.
func (e *userspaceEngine) isActiveSince(dk tailcfg.DiscoKey, ip wgcfg.IP, t time.Time) bool {
	if lastRecv := e.recvActivityAt[dk]; lastRecv.After(t) {
		return true
	}

	dst := packet.IP(binary.BigEndian.Uint32(ip.Addr[12:]))
	if sentAt, tracked := e.sentActivityAt[dst]; tracked {
		return atomic.LoadInt64(sentAt) >= t.Unix()
	}
	return false
}
// discoKeyFromPeer extracts the DiscoKey encoded in the endpoint host
// of a wireguard config Peer.
//
// Invariant: isTrimmablePeer(p) == true, so p has exactly one endpoint
// whose Host has the form "<64-hex-digits>.disco.tailscale". If the
// host is shorter than 64 bytes or its first 64 bytes do not parse as
// a hex public key, the zero DiscoKey is returned.
func discoKeyFromPeer(p *wgcfg.Peer) tailcfg.DiscoKey {
	const hexLen = 64 // hex digits in an encoded disco key

	host := p.Endpoints[0].Host
	if len(host) >= hexLen {
		if pub, err := key.NewPublicFromHexMem(mem.S(host[:hexLen])); err == nil {
			return tailcfg.DiscoKey(pub)
		}
	}
	return tailcfg.DiscoKey{}
}
// maybeReconfigWireguardLocked derives a trimmed wireguard-go config
// from e.lastCfgFull and applies it to e.wgdev if it differs from the
// trimmed config that was applied last time.
//
// Peers that are not trimmable (see isTrimmablePeer) are always kept.
// Trimmable peers are kept only if they have had send or receive
// activity within lazyPeerIdleThreshold. The activity-tracking maps
// are refreshed for all trimmable peers on every call.
//
// It returns the error from wgdev.Reconfig, if any, after logging it.
//
// e.wgLock must be held.
func (e *userspaceEngine) maybeReconfigWireguardLocked() error {
	full := e.lastCfgFull

	// Compute a minimal config to pass to wireguard-go
	// based on the full config. Prune off all the peers
	// and only add the active ones back.
	trimmed := full
	trimmed.Peers = nil

	// We'll only keep a peer around if it's been active in
	// the past 5 minutes. That's more than WireGuard's key
	// rotation time anyway so it's no harm if we remove it
	// later if it's been inactive.
	activeCutoff := time.Now().Add(-lazyPeerIdleThreshold)

	// Not all peers can be trimmed from the network map (see
	// isTrimmablePeer). For those that are trimmable, keep track
	// of their DiscoKey and Tailscale IPs. These are the ones
	// we'll need to install tracking hooks for to watch their
	// send/receive activity.
	trackDisco := make([]tailcfg.DiscoKey, 0, len(full.Peers))
	trackIPs := make([]wgcfg.IP, 0, len(full.Peers))

	for i := range full.Peers {
		p := &full.Peers[i]
		if !isTrimmablePeer(p) {
			trimmed.Peers = append(trimmed.Peers, *p)
			continue
		}
		tsIP := p.AllowedIPs[0].IP
		dk := discoKeyFromPeer(p)
		trackDisco = append(trackDisco, dk)
		trackIPs = append(trackIPs, tsIP)
		if e.isActiveSince(dk, tsIP, activeCutoff) {
			trimmed.Peers = append(trimmed.Peers, *p)
		}
	}

	// Refresh the tracking hooks before the no-change check below.
	// The set of tracked peers can change while the trimmed config
	// stays the same (for example when a new, still-idle peer joins
	// the full config). Such a peer must still be tracked, or its
	// first packet could never trigger its lazy creation.
	e.updateActivityMapsLocked(trackDisco, trackIPs)

	if !updateSig(&e.lastEngineSigTrim, trimmed) {
		// No changes
		return nil
	}

	e.logf("wgengine: Reconfig: configuring userspace wireguard config (with %d/%d peers)", len(trimmed.Peers), len(full.Peers))
	if err := e.wgdev.Reconfig(&trimmed); err != nil {
		e.logf("wgdev.Reconfig: %v", err)
		return err
	}
	return nil
}
// updateActivityMapsLocked updates the data structures used for tracking the activity
// of wireguard peers that we might add/remove dynamically from the real config
// as given to wireguard-go.
//
// trackDisco lists the discovery keys whose receive times should be
// tracked and trackIPs lists the Tailscale IPs whose send times should
// be tracked. Existing timestamps and send hooks are carried over for
// keys that remain tracked; entries for keys no longer listed are
// dropped. The resulting per-destination hook map is handed to
// e.tundev via SetDestIPActivityFuncs.
//
// e.wgLock must be held.
func (e *userspaceEngine) updateActivityMapsLocked(trackDisco []tailcfg.DiscoKey, trackIPs []wgcfg.IP) {
	// Generate the new map of which discokeys we want to track
	// receive times for.
	mr := make(map[tailcfg.DiscoKey]time.Time, len(trackDisco)) // TODO: only recreate this if set of keys changed
	for _, dk := range trackDisco {
		// Preserve old times in the new map, but also
		// populate map entries for new trackDisco values with
		// time.Time{} zero values. (Only entries in this map
		// are tracked, so the Time zero values allow it to be
		// tracked later)
		mr[dk] = e.recvActivityAt[dk]
	}
	e.recvActivityAt = mr

	oldTime := e.sentActivityAt
	e.sentActivityAt = make(map[packet.IP]*int64, len(oldTime))
	oldFunc := e.destIPActivityFuncs
	e.destIPActivityFuncs = make(map[packet.IP]func(), len(oldFunc))

	for _, wip := range trackIPs {
		// The tracked address is the last four bytes of wip.Addr,
		// read as a big-endian packet.IP.
		pip := packet.IP(binary.BigEndian.Uint32(wip.Addr[12:]))
		timePtr := oldTime[pip]
		if timePtr == nil {
			timePtr = new(int64)
		}
		e.sentActivityAt[pip] = timePtr

		fn := oldFunc[pip]
		if fn == nil {
			// This is the func that gets run on every outgoing packet for tracked IPs:
			fn = func() {
				now, old := time.Now().Unix(), atomic.LoadInt64(timePtr)
				if old > now-10 {
					// Updated within the last 10 seconds; keep the
					// per-packet path cheap.
					return
				}
				atomic.StoreInt64(timePtr, now)
				// On the first packet, or after a gap of a minute or
				// more, the peer might not be in the wireguard config
				// (any longer), so go check. Shorter gaps mean the
				// peer was recently active and needs no recheck.
				if old == 0 || now-old >= 60 {
					e.wgLock.Lock()
					defer e.wgLock.Unlock()
					// No caller to report to; failures are logged
					// by maybeReconfigWireguardLocked itself.
					_ = e.maybeReconfigWireguardLocked()
				}
			}
		}
		e.destIPActivityFuncs[pip] = fn
	}
	e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs)
}
func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) error { func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config) error {
if routerCfg == nil { if routerCfg == nil {
panic("routerCfg must not be nil") panic("routerCfg must not be nil")
@ -588,29 +795,24 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config)
routerCfg.Domains = append([]string{magicDNSDomain}, routerCfg.Domains...) routerCfg.Domains = append([]string{magicDNSDomain}, routerCfg.Domains...)
} }
engineChanged := updateSig(&e.lastEngineSig, cfg) engineChanged := updateSig(&e.lastEngineSigFull, cfg)
routerChanged := updateSig(&e.lastRouterSig, routerCfg) routerChanged := updateSig(&e.lastRouterSig, routerCfg)
if !engineChanged && !routerChanged { if !engineChanged && !routerChanged {
return ErrNoChanges return ErrNoChanges
} }
e.lastCfg = cfg.Copy() e.lastCfgFull = cfg.Copy()
if engineChanged { // Tell magicsock about the new (or initial) private key
e.logf("wgengine: Reconfig: configuring userspace wireguard config") // (which is needed by DERP) before wgdev gets it, as wgdev
// Tell magicsock about the new (or initial) private key // will start trying to handshake, which we want to be able to
// (which is needed by DERP) before wgdev gets it, as wgdev // go over DERP.
// will start trying to handshake, which we want to be able to if err := e.magicConn.SetPrivateKey(cfg.PrivateKey); err != nil {
// go over DERP. e.logf("wgengine: Reconfig: SetPrivateKey: %v", err)
if err := e.magicConn.SetPrivateKey(cfg.PrivateKey); err != nil { }
e.logf("wgengine: Reconfig: SetPrivateKey: %v", err) e.magicConn.UpdatePeers(peerSet)
}
if err := e.wgdev.Reconfig(cfg); err != nil {
e.logf("wgdev.Reconfig: %v", err)
return err
}
e.magicConn.UpdatePeers(peerSet) if err := e.maybeReconfigWireguardLocked(); err != nil {
return err
} }
if routerChanged { if routerChanged {
@ -758,15 +960,9 @@ func (e *userspaceEngine) getStatus() (*Status, error) {
var peers []PeerStatus var peers []PeerStatus
for _, pk := range e.peerSequence { for _, pk := range e.peerSequence {
p := pp[pk] if p, ok := pp[pk]; ok { // ignore idle ones not in wireguard-go's config
if p == nil { peers = append(peers, *p)
p = &PeerStatus{}
} }
peers = append(peers, *p)
}
if len(pp) != len(e.peerSequence) {
e.logf("wg status returned %v peers, expected %v", len(pp), len(e.peerSequence))
} }
return &Status{ return &Status{

Loading…
Cancel
Save