diff --git a/go.mod b/go.mod index bd6fe441d..74ba5a379 100644 --- a/go.mod +++ b/go.mod @@ -94,7 +94,7 @@ require ( github.com/tailscale/setec v0.0.0-20250205144240-8898a29c3fbb github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 - github.com/tailscale/wireguard-go v0.0.0-20250716170648-1d0488a3d7da + github.com/tailscale/wireguard-go v0.0.0-20251112210417-234d45e2e930 github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e github.com/tc-hib/winres v0.2.1 github.com/tcnksm/go-httpstat v0.2.0 diff --git a/go.sum b/go.sum index 111c99ac9..2cfbb18a4 100644 --- a/go.sum +++ b/go.sum @@ -1010,6 +1010,8 @@ github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15 github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6/go.mod h1:ZXRML051h7o4OcI0d3AaILDIad/Xw0IkXaHM17dic1Y= github.com/tailscale/wireguard-go v0.0.0-20250716170648-1d0488a3d7da h1:jVRUZPRs9sqyKlYHHzHjAqKN+6e/Vog6NpHYeNPJqOw= github.com/tailscale/wireguard-go v0.0.0-20250716170648-1d0488a3d7da/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= +github.com/tailscale/wireguard-go v0.0.0-20251112210417-234d45e2e930 h1:SR7Lyxe99k+7IOr1Vfs112Lr5gnhkEa3EArXs6+KXDc= +github.com/tailscale/wireguard-go v0.0.0-20251112210417-234d45e2e930/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA= diff --git a/ipn/ipnlocal/extension_host.go b/ipn/ipnlocal/extension_host.go index ca802ab89..f4ec86719 100644 --- a/ipn/ipnlocal/extension_host.go +++ b/ipn/ipnlocal/extension_host.go @@ -8,7 +8,9 @@ import ( "errors" "fmt" "maps" + "os" "reflect" + "runtime" "slices" "strings" "sync" @@ -570,6 +572,9 @@ func (h *ExtensionHost) shutdownWorkQueue() { // for in-flight callbacks associated with those operations to finish. if err := h.workQueue.Wait(ctx); err != nil { h.logf("work queue shutdown failed: %v", err) + b := make([]byte, 2<<20) + n := runtime.Stack(b, true) + os.WriteFile("/tmp/shutdown-hang-stacks.txt", b[:n], 0644) } } diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index fbf34aa42..f9f771f5c 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -520,6 +520,8 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo b.currentNodeAtomic.Store(nb) nb.ready() + sys.Engine.Get().SetPeerByIPLookupFunc(b.lookupPeerByIP) + if sys.InitialConfig != nil { if err := b.initPrefsFromConfig(sys.InitialConfig); err != nil { return nil, err @@ -651,6 +653,25 @@ func (b *LocalBackend) currentNode() *nodeBackend { return b.currentNodeAtomic.Load() } +func (b *LocalBackend) lookupPeerByIP(ip netip.Addr) (peerKey key.NodePublic, ok bool) { + nb := b.currentNode() + nb.mu.Lock() + defer nb.mu.Unlock() + + nid, ok := nb.nodeByAddr[ip] + if !ok { + log.Printf("lookupPeerByIP: %v -> no node ID", ip) + return key.NodePublic{}, false + } + peer, ok := nb.peers[nid] + if !ok { + log.Printf("lookupPeerByIP: no node ID %v", nid) + return key.NodePublic{}, false + } + log.Printf("lookupPeerByIP: %v -> %v (%v)", ip, peer.Name(), peer.Key()) + return peer.Key(), true +} + // FindExtensionByName returns an active extension with the given name, // or nil if no such extension exists. func (b *LocalBackend) FindExtensionByName(name string) any { diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 6e07c7a3d..4e050b8f9 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -55,6 +55,9 @@ const MaxPacketSize = device.MaxContentSize // TAPDebug is whether super verbose TAP debugging is enabled. const TAPDebug = false +// TUNDebug is whether super verbose TUN debugging is enabled. +const TUNDebug = false + var ( // ErrClosed is returned when attempting an operation on a closed Wrapper. ErrClosed = errors.New("device closed") @@ -248,6 +251,16 @@ type tunVectorReadResult struct { dataOffset int } +func (r *tunVectorReadResult) String() string { + if r.err != nil { + return fmt.Sprintf("err=%v", r.err) + } + if r.data == nil { + return fmt.Sprintf("injected packet: pk=%p, len(injected.data)=%d", r.injected.packet, len(r.injected.data)) + } + return fmt.Sprintf("len(data)=%d, off=%d, data=% 02x", len(r.data), r.dataOffset, r.data) +} + // Start unblocks any Wrapper.Read calls that have already started // and makes the Wrapper functional. // @@ -930,6 +943,9 @@ func (t *Wrapper) awaitStart() { } } +// Read implements the tun.Device.Read method, which sends outbound packets. +// (mnemonic: the kernel is reading asking what to send, and we're implementing +// that read by providing packets to send) func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { if !t.started.Load() { t.awaitStart() @@ -939,6 +955,10 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { if !ok { return 0, io.EOF } + if TUNDebug { + t.logf("tstun: Wrapper.Read got outbound: %s", &res) + } + if res.err != nil && len(res.data) == 0 { return 0, res.err } @@ -1065,7 +1085,8 @@ func invertGSOChecksum(pkt []byte, gso netstack_GSO) { pkt[at+1] = ^pkt[at+1] } -// injectedRead handles injected reads, which bypass filters. +// injectedRead handles injected reads (outbound packets from +// [Wrapper.InjectOutbound] and callers), which bypass filters. func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []int, offset int) (n int, err error) { var gso netstack_GSO diff --git a/types/key/node.go b/types/key/node.go index 11ee1fa3c..aa9941cce 100644 --- a/types/key/node.go +++ b/types/key/node.go @@ -74,6 +74,11 @@ func NodePrivateFromRaw32(raw mem.RO) NodePrivate { return ret } +// NodePrivateAs returns a NodePrivate as a named fixed-size array of bytes. +// +// It's intended for interoperability with wireguard-go's device.NoisePrivateKey type. +func NodePrivateAs[T ~[32]byte](k NodePrivate) T { return k.k } + func ParseNodePrivateUntyped(raw mem.RO) (NodePrivate, error) { var ret NodePrivate if err := parseHex(ret.k[:], raw, mem.B(nil)); err != nil { diff --git a/wgengine/userspace.go b/wgengine/userspace.go index a369fa343..0fc3f7fe4 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -10,10 +10,9 @@ import ( "errors" "fmt" "io" - "maps" + "log" "math" "net/netip" - "reflect" "runtime" "slices" "strings" @@ -132,10 +131,8 @@ type userspaceEngine struct { wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below lastCfgFull wgcfg.Config - lastNMinPeers int lastRouter *router.Config - lastEngineFull *wgcfg.Config // of full wireguard config, not trimmed - lastEngineInputs *maybeReconfigInputs + lastEngineFull *wgcfg.Config // of full wireguard config, not trimmed lastDNSConfig dns.ConfigView // or invalid if none lastIsSubnetRouter bool // was the node a primary subnet router in the last run. recvActivityAt map[key.NodePublic]mono.Time @@ -687,7 +684,7 @@ func (e *userspaceEngine) noteRecvActivity(nk key.NodePublic) { // couple minutes (just not on every packet). if e.trimmedNodes[nk] { e.logf("wgengine: idle peer %v now active, reconfiguring WireGuard", nk.ShortString()) - e.maybeReconfigWireguardLocked(nil) + e.maybeReconfigWireguardLocked(false) } } @@ -706,36 +703,13 @@ func (e *userspaceEngine) isActiveSinceLocked(nk key.NodePublic, ip netip.Addr, return timePtr.LoadAtomic().After(t) } -// maybeReconfigInputs holds the inputs to the maybeReconfigWireguardLocked -// function. If these things don't change between calls, there's nothing to do. -type maybeReconfigInputs struct { - WGConfig *wgcfg.Config - TrimmedNodes map[key.NodePublic]bool - TrackNodes views.Slice[key.NodePublic] - TrackIPs views.Slice[netip.Addr] -} - -func (i *maybeReconfigInputs) Equal(o *maybeReconfigInputs) bool { - return reflect.DeepEqual(i, o) -} - -func (i *maybeReconfigInputs) Clone() *maybeReconfigInputs { - if i == nil { - return nil - } - v := *i - v.WGConfig = i.WGConfig.Clone() - v.TrimmedNodes = maps.Clone(i.TrimmedNodes) - return &v -} - // discoChanged are the set of peers whose disco keys have changed, implying they've restarted. // If a peer is in this set and was previously in the live wireguard config, // it needs to be first removed and then re-added to flush out its wireguard session key. // If discoChanged is nil or empty, this extra removal step isn't done. // // e.wgLock must be held. -func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.NodePublic]bool) error { +func (e *userspaceEngine) maybeReconfigWireguardLocked(forceReconfig bool) error { if hook := e.testMaybeReconfigHook; hook != nil { hook() return nil @@ -744,104 +718,8 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node full := e.lastCfgFull e.wgLogger.SetPeers(full.Peers) - // Compute a minimal config to pass to wireguard-go - // based on the full config. Prune off all the peers - // and only add the active ones back. - min := full - min.Peers = make([]wgcfg.Peer, 0, e.lastNMinPeers) - - // We'll only keep a peer around if it's been active in - // the past 5 minutes. That's more than WireGuard's key - // rotation time anyway so it's no harm if we remove it - // later if it's been inactive. - var activeCutoff mono.Time - if buildfeatures.HasLazyWG { - activeCutoff = e.timeNow().Add(-lazyPeerIdleThreshold) - } - - // Not all peers can be trimmed from the network map (see - // isTrimmablePeer). For those that are trimmable, keep track of - // their NodeKey and Tailscale IPs. These are the ones we'll need - // to install tracking hooks for to watch their send/receive - // activity. - var trackNodes []key.NodePublic - var trackIPs []netip.Addr - if buildfeatures.HasLazyWG { - trackNodes = make([]key.NodePublic, 0, len(full.Peers)) - trackIPs = make([]netip.Addr, 0, len(full.Peers)) - } - - // Don't re-alloc the map; the Go compiler optimizes map clears as of - // Go 1.11, so we can re-use the existing + allocated map. - if e.trimmedNodes != nil { - clear(e.trimmedNodes) - } else { - e.trimmedNodes = make(map[key.NodePublic]bool) - } - - needRemoveStep := false - for i := range full.Peers { - p := &full.Peers[i] - nk := p.PublicKey - if !buildfeatures.HasLazyWG || !e.isTrimmablePeer(p, len(full.Peers)) { - min.Peers = append(min.Peers, *p) - if discoChanged[nk] { - needRemoveStep = true - } - continue - } - trackNodes = append(trackNodes, nk) - recentlyActive := false - for _, cidr := range p.AllowedIPs { - trackIPs = append(trackIPs, cidr.Addr()) - recentlyActive = recentlyActive || e.isActiveSinceLocked(nk, cidr.Addr(), activeCutoff) - } - if recentlyActive { - min.Peers = append(min.Peers, *p) - if discoChanged[nk] { - needRemoveStep = true - } - } else { - e.trimmedNodes[nk] = true - } - } - e.lastNMinPeers = len(min.Peers) - - if changed := checkchange.Update(&e.lastEngineInputs, &maybeReconfigInputs{ - WGConfig: &min, - TrimmedNodes: e.trimmedNodes, - TrackNodes: views.SliceOf(trackNodes), - TrackIPs: views.SliceOf(trackIPs), - }); !changed { - return nil - } - - if buildfeatures.HasLazyWG { - e.updateActivityMapsLocked(trackNodes, trackIPs) - } - - if needRemoveStep { - minner := min - minner.Peers = nil - numRemove := 0 - for _, p := range min.Peers { - if discoChanged[p.PublicKey] { - numRemove++ - continue - } - minner.Peers = append(minner.Peers, p) - } - if numRemove > 0 { - e.logf("wgengine: Reconfig: removing session keys for %d peers", numRemove) - if err := wgcfg.ReconfigDevice(e.wgdev, &minner, e.logf); err != nil { - e.logf("wgdev.Reconfig: %v", err) - return err - } - } - } - - e.logf("wgengine: Reconfig: configuring userspace WireGuard config (with %d/%d peers)", len(min.Peers), len(full.Peers)) - if err := wgcfg.ReconfigDevice(e.wgdev, &min, e.logf); err != nil { + e.logf("wgengine: Reconfig: configuring userspace WireGuard config (with %d peers)", len(full.Peers)) + if err := wgcfg.ReconfigDevice(e.wgdev, &full, e.logf); err != nil { e.logf("wgdev.Reconfig: %v", err) return err } @@ -896,7 +774,7 @@ func (e *userspaceEngine) updateActivityMapsLocked(trackNodes []key.NodePublic, if elapsed >= packetSendRecheckWireguardThreshold { e.wgLock.Lock() defer e.wgLock.Unlock() - e.maybeReconfigWireguardLocked(nil) + e.maybeReconfigWireguardLocked(false) } } } @@ -1029,10 +907,8 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, } // See if any peers have changed disco keys, which means they've restarted. - // If so, we need to update the wireguard-go/device.Device in two phases: - // once without the node which has restarted, to clear its wireguard session key, - // and a second time with it. - discoChanged := make(map[key.NodePublic]bool) + // If we see that, we clear our wireguard-go session state for that peer. + forceReconfig := false { prevEP := make(map[key.NodePublic]key.DiscoPublic) for i := range e.lastCfgFull.Peers { @@ -1047,7 +923,8 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, } pub := p.PublicKey if old, ok := prevEP[pub]; ok && old != p.DiscoKey { - discoChanged[pub] = true + e.wgdev.RemovePeer(pub.Raw32()) + forceReconfig = true // to make sure we add it back e.logf("wgengine: Reconfig: %s changed from %q to %q", pub.ShortString(), old, p.DiscoKey) } } @@ -1066,7 +943,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, e.magicConn.SetPreferredPort(listenPort) e.magicConn.UpdatePMTUD() - if err := e.maybeReconfigWireguardLocked(discoChanged); err != nil { + if err := e.maybeReconfigWireguardLocked(forceReconfig); err != nil { return err } @@ -1205,8 +1082,8 @@ func (e *userspaceEngine) PeerByKey(pubKey key.NodePublic) (_ wgint.Peer, ok boo if dev == nil { return wgint.Peer{}, false } - peer := dev.LookupPeer(pubKey.Raw32()) - if peer == nil { + peer, ok := dev.LookupActivePeer(pubKey.Raw32()) + if !ok || peer == nil { return wgint.Peer{}, false } return wgint.PeerOf(peer), true @@ -1765,3 +1642,30 @@ func (e *userspaceEngine) reconfigureVPNIfNecessary() error { } return e.reconfigureVPN() } + +func (e *userspaceEngine) SetPeerByIPLookupFunc(fn func(netip.Addr) (key.NodePublic, bool)) { + e.wgdev.SetPeerByIPLookupFunc(func(addr netip.Addr) (_ *device.Peer, ok bool) { + pk, ok := fn(addr) + if !ok { + return nil, false + } + // TODO(bradfitz): optimize this LookupPeer map lookup on each packet; + // store it in the leaf of the bart lookup. + if peer, ok := e.wgdev.LookupActivePeer(pk.Raw32()); ok { + log.Printf("XXX active peer for %v found in LookupActivePeer", pk.ShortString()) + return peer, true + } + + peer := e.wgdev.LookupPeer(pk.Raw32()) + if peer == nil { + return nil, false + } + log.Printf("XXX making new peer for %v", pk.ShortString()) + ep, err := e.magicConn.ParseEndpoint(fmt.Sprintf("%02x", pk.Raw32())) + if err != nil { + return nil, false + } + peer.SetEndpointFromPacket(ep) + return peer, peer != nil + }) +} diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index 9cc4ed3b5..175b43744 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -179,3 +179,7 @@ func (e *watchdogEngine) InstallCaptureHook(cb packet.CaptureCallback) { func (e *watchdogEngine) PeerByKey(pubKey key.NodePublic) (_ wgint.Peer, ok bool) { return e.wrap.PeerByKey(pubKey) } + +func (e *watchdogEngine) SetPeerByIPLookupFunc(fn func(netip.Addr) (key.NodePublic, bool)) { + e.wrap.SetPeerByIPLookupFunc(fn) +} diff --git a/wgengine/wgcfg/device.go b/wgengine/wgcfg/device.go index ee7eb91c9..ae9bfbbe5 100644 --- a/wgengine/wgcfg/device.go +++ b/wgengine/wgcfg/device.go @@ -4,13 +4,17 @@ package wgcfg import ( - "errors" - "io" - "sort" + "fmt" + "log" + "net/netip" + "runtime" + "sync/atomic" + "time" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" + "tailscale.com/types/key" "tailscale.com/types/logger" ) @@ -21,26 +25,6 @@ func NewDevice(tunDev tun.Device, bind conn.Bind, logger *device.Logger) *device return ret } -func DeviceConfig(d *device.Device) (*Config, error) { - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcGetOperation(w) - w.Close() - }() - cfg, fromErr := FromUAPI(r) - r.Close() - getErr := <-errc - err := errors.Join(getErr, fromErr) - if err != nil { - return nil, err - } - sort.Slice(cfg.Peers, func(i, j int) bool { - return cfg.Peers[i].PublicKey.Less(cfg.Peers[j].PublicKey) - }) - return cfg, nil -} - // ReconfigDevice replaces the existing device configuration with cfg. func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) { defer func() { @@ -49,20 +33,44 @@ func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) } }() - prev, err := DeviceConfig(d) - if err != nil { - return err + d.SetPrivateKey(key.NodePrivateAs[device.NoisePrivateKey](cfg.PrivateKey)) + + peers := map[device.NoisePublicKey][]netip.Prefix{} // public key → allowed IPs + for _, p := range cfg.Peers { + peers[p.PublicKey.Raw32()] = p.AllowedIPs } + d.RemoveMatchingPeers(func(pk device.NoisePublicKey) bool { + _, exists := peers[pk] + return !exists + }) - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcSetOperation(r) - r.Close() - }() + var lastStack atomic.Int64 + + d.SetPeerLookupFunc(func(pubk device.NoisePublicKey) (_ *device.NewPeerConfig, ok bool) { + allowedIPs, ok := peers[pubk] + if !ok { + return nil, false + } + + var buf []byte + now := time.Now().Unix() + if lastStack.Swap(now) != now { + buf = make([]byte, 4<<10) + buf = buf[:runtime.Stack(buf, false)] + } + + log.Printf("XXX wgcfg.ReconfigDevice: lookup for peer %v, found=%v => %v, stack: %s", pubk, ok, allowedIPs, buf) + bind := d.Bind() + ep, err := bind.ParseEndpoint(fmt.Sprintf("%02x", pubk[:])) + if err != nil { + logf("wgcfg: failed to parse endpoint for peer %v: %v", pubk, err) + return nil, false + } + return &device.NewPeerConfig{ + AllowedIPs: allowedIPs, + Endpoint: ep, + }, ok + }) - toErr := cfg.ToUAPI(logf, w, prev) - w.Close() - setErr := <-errc - return errors.Join(setErr, toErr) + return nil } diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go deleted file mode 100644 index 9138d6e5a..000000000 --- a/wgengine/wgcfg/device_test.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "bytes" - "io" - "net/netip" - "os" - "sort" - "strings" - "sync" - "testing" - - "github.com/tailscale/wireguard-go/conn" - "github.com/tailscale/wireguard-go/device" - "github.com/tailscale/wireguard-go/tun" - "go4.org/mem" - "tailscale.com/types/key" -) - -func TestDeviceConfig(t *testing.T) { - newK := func() (key.NodePublic, key.NodePrivate) { - t.Helper() - k := key.NewNode() - return k.Public(), k - } - k1, pk1 := newK() - ip1 := netip.MustParsePrefix("10.0.0.1/32") - - k2, pk2 := newK() - ip2 := netip.MustParsePrefix("10.0.0.2/32") - - k3, _ := newK() - ip3 := netip.MustParsePrefix("10.0.0.3/32") - - cfg1 := &Config{ - PrivateKey: pk1, - Peers: []Peer{{ - PublicKey: k2, - AllowedIPs: []netip.Prefix{ip2}, - }}, - } - - cfg2 := &Config{ - PrivateKey: pk2, - Peers: []Peer{{ - PublicKey: k1, - AllowedIPs: []netip.Prefix{ip1}, - PersistentKeepalive: 5, - }}, - } - - device1 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1")) - device2 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2")) - defer device1.Close() - defer device2.Close() - - cmp := func(t *testing.T, d *device.Device, want *Config) { - t.Helper() - got, err := DeviceConfig(d) - if err != nil { - t.Fatal(err) - } - prev := new(Config) - gotbuf := new(strings.Builder) - err = got.ToUAPI(t.Logf, gotbuf, prev) - gotStr := gotbuf.String() - if err != nil { - t.Errorf("got.ToUAPI(): error: %v", err) - return - } - wantbuf := new(strings.Builder) - err = want.ToUAPI(t.Logf, wantbuf, prev) - wantStr := wantbuf.String() - if err != nil { - t.Errorf("want.ToUAPI(): error: %v", err) - return - } - if gotStr != wantStr { - buf := new(bytes.Buffer) - w := bufio.NewWriter(buf) - if err := d.IpcGetOperation(w); err != nil { - t.Errorf("on error, could not IpcGetOperation: %v", err) - } - w.Flush() - t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) - } - } - - t.Run("device1 config", func(t *testing.T) { - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device2 config", func(t *testing.T) { - if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device2, cfg2) - }) - - // This is only to test that Config and Reconfig are properly synchronized. - t.Run("device2 config/reconfig", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(2) - - go func() { - ReconfigDevice(device2, cfg2, t.Logf) - wg.Done() - }() - - go func() { - DeviceConfig(device2) - wg.Done() - }() - - wg.Wait() - }) - - t.Run("device1 modify peer", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 1, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1 replace endpoint", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 2, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1 add new peer", func(t *testing.T) { - cfg1.Peers = append(cfg1.Peers, Peer{ - PublicKey: k3, - AllowedIPs: []netip.Prefix{ip3}, - }) - sort.Slice(cfg1.Peers, func(i, j int) bool { - return cfg1.Peers[i].PublicKey.Less(cfg1.Peers[j].PublicKey) - }) - - origCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - peer0 := func(cfg *Config) Peer { - p, ok := cfg.PeerWithKey(k2) - if !ok { - t.Helper() - t.Fatal("failed to look up peer 2") - } - return p - } - peersEqual := func(p, q Peer) bool { - return p.PublicKey == q.PublicKey && p.DiscoKey == q.DiscoKey && p.PersistentKeepalive == q.PersistentKeepalive && cidrsEqual(p.AllowedIPs, q.AllowedIPs) - } - if !peersEqual(peer0(origCfg), peer0(newCfg)) { - t.Error("reconfig modified old peer") - } - }) - - t.Run("device1 remove peer", func(t *testing.T) { - removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey - cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - _, ok := newCfg.PeerWithKey(removeKey) - if ok { - t.Error("reconfig failed to remove peer") - } - }) -} - -// TODO: replace with a loopback tunnel -type nilTun struct { - events chan tun.Event - closed chan struct{} -} - -func newNilTun() tun.Device { - return &nilTun{ - events: make(chan tun.Event), - closed: make(chan struct{}), - } -} - -func (t *nilTun) File() *os.File { return nil } -func (t *nilTun) Flush() error { return nil } -func (t *nilTun) MTU() (int, error) { return 1420, nil } -func (t *nilTun) Name() (string, error) { return "niltun", nil } -func (t *nilTun) Events() <-chan tun.Event { return t.events } - -func (t *nilTun) Read(data [][]byte, sizes []int, offset int) (int, error) { - <-t.closed - return 0, io.EOF -} - -func (t *nilTun) Write(data [][]byte, offset int) (int, error) { - <-t.closed - return 0, io.EOF -} - -func (t *nilTun) Close() error { - close(t.events) - close(t.closed) - return nil -} - -func (t *nilTun) BatchSize() int { return 1 } - -// A noopBind is a conn.Bind that does no actual binding work. -type noopBind struct{} - -func (noopBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) { - return nil, 1, nil -} -func (noopBind) Close() error { return nil } -func (noopBind) SetMark(mark uint32) error { return nil } -func (noopBind) Send(b [][]byte, ep conn.Endpoint, offset int) error { return nil } -func (noopBind) ParseEndpoint(s string) (conn.Endpoint, error) { - return dummyEndpoint(s), nil -} -func (noopBind) BatchSize() int { return 1 } - -// A dummyEndpoint is a string holding the endpoint destination. -type dummyEndpoint string - -func (e dummyEndpoint) ClearSrc() {} -func (e dummyEndpoint) SrcToString() string { return "" } -func (e dummyEndpoint) DstToString() string { return string(e) } -func (e dummyEndpoint) DstToBytes() []byte { return nil } -func (e dummyEndpoint) DstIP() netip.Addr { return netip.Addr{} } -func (dummyEndpoint) SrcIP() netip.Addr { return netip.Addr{} } diff --git a/wgengine/wgcfg/parser_test.go b/wgengine/wgcfg/parser_test.go index a5d7ad44f..377afd7bb 100644 --- a/wgengine/wgcfg/parser_test.go +++ b/wgengine/wgcfg/parser_test.go @@ -4,15 +4,9 @@ package wgcfg import ( - "bufio" - "bytes" - "io" - "net/netip" "reflect" "runtime" "testing" - - "tailscale.com/types/key" ) func noError(t *testing.T, err error) bool { @@ -58,38 +52,3 @@ func TestParseEndpoint(t *testing.T) { t.Error("Error was expected") } } - -func BenchmarkFromUAPI(b *testing.B) { - newK := func() (key.NodePublic, key.NodePrivate) { - b.Helper() - k := key.NewNode() - return k.Public(), k - } - k1, pk1 := newK() - ip1 := netip.MustParsePrefix("10.0.0.1/32") - - peer := Peer{ - PublicKey: k1, - AllowedIPs: []netip.Prefix{ip1}, - } - cfg1 := &Config{ - PrivateKey: pk1, - Peers: []Peer{peer, peer, peer, peer}, - } - - buf := new(bytes.Buffer) - w := bufio.NewWriter(buf) - if err := cfg1.ToUAPI(b.Logf, w, &Config{}); err != nil { - b.Fatal(err) - } - w.Flush() - r := bytes.NewReader(buf.Bytes()) - b.ReportAllocs() - for range b.N { - r.Seek(0, io.SeekStart) - _, err := FromUAPI(r) - if err != nil { - b.Errorf("failed from UAPI: %v", err) - } - } -} diff --git a/wgengine/wgcfg/writer.go b/wgengine/wgcfg/writer.go deleted file mode 100644 index 9cdd31df2..000000000 --- a/wgengine/wgcfg/writer.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) Tailscale Inc & AUTHORS -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "fmt" - "io" - "net/netip" - "strconv" - - "tailscale.com/types/key" - "tailscale.com/types/logger" -) - -// ToUAPI writes cfg in UAPI format to w. -// Prev is the previous device Config. -// -// Prev is required so that we can remove now-defunct peers without having to -// remove and re-add all peers, and so that we can avoid writing information -// about peers that have not changed since the previous time we wrote our -// Config. -func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error { - var stickyErr error - set := func(key, value string) { - if stickyErr != nil { - return - } - _, err := fmt.Fprintf(w, "%s=%s\n", key, value) - if err != nil { - stickyErr = err - } - } - setUint16 := func(key string, value uint16) { - set(key, strconv.FormatUint(uint64(value), 10)) - } - setPeer := func(peer Peer) { - set("public_key", peer.PublicKey.UntypedHexString()) - } - - // Device config. - if !prev.PrivateKey.Equal(cfg.PrivateKey) { - set("private_key", cfg.PrivateKey.UntypedHexString()) - } - - old := make(map[key.NodePublic]Peer) - for _, p := range prev.Peers { - old[p.PublicKey] = p - } - - // Add/configure all new peers. - for _, p := range cfg.Peers { - oldPeer, wasPresent := old[p.PublicKey] - - // We only want to write the peer header/version if we're about - // to change something about that peer, or if it's a new peer. - // Figure out up-front whether we'll need to do anything for - // this peer, and skip doing anything if not. - // - // If the peer was not present in the previous config, this - // implies that this is a new peer; set all of these to 'true' - // to ensure that we're writing the full peer configuration. - willSetEndpoint := oldPeer.WGEndpoint != p.PublicKey || !wasPresent - willChangeIPs := !cidrsEqual(oldPeer.AllowedIPs, p.AllowedIPs) || !wasPresent - willChangeKeepalive := oldPeer.PersistentKeepalive != p.PersistentKeepalive // if not wasPresent, no need to redundantly set zero (default) - - if !willSetEndpoint && !willChangeIPs && !willChangeKeepalive { - // It's safe to skip doing anything here; wireguard-go - // will not remove a peer if it's unspecified unless we - // tell it to (which we do below if necessary). - continue - } - - setPeer(p) - set("protocol_version", "1") - - // Avoid setting endpoints if the correct one is already known - // to WireGuard, because doing so generates a bit more work in - // calling magicsock's ParseEndpoint for effectively a no-op. - if willSetEndpoint { - if wasPresent { - // We had an endpoint, and it was wrong. - // By construction, this should not happen. - // If it does, keep going so that we can recover from it, - // but log so that we know about it, - // because it is an indicator of other failed invariants. - // See corp issue 3016. - logf("[unexpected] endpoint changed from %s to %s", oldPeer.WGEndpoint, p.PublicKey) - } - set("endpoint", p.PublicKey.UntypedHexString()) - } - - // TODO: replace_allowed_ips is expensive. - // If p.AllowedIPs is a strict superset of oldPeer.AllowedIPs, - // then skip replace_allowed_ips and instead add only - // the new ipps with allowed_ip. - if willChangeIPs { - set("replace_allowed_ips", "true") - for _, ipp := range p.AllowedIPs { - set("allowed_ip", ipp.String()) - } - } - - // Set PersistentKeepalive after the peer is otherwise configured, - // because it can trigger handshake packets. - if willChangeKeepalive { - setUint16("persistent_keepalive_interval", p.PersistentKeepalive) - } - } - - // Remove peers that were present but should no longer be. - for _, p := range cfg.Peers { - delete(old, p.PublicKey) - } - for _, p := range old { - setPeer(p) - set("remove", "true") - } - - if stickyErr != nil { - stickyErr = fmt.Errorf("ToUAPI: %w", stickyErr) - } - return stickyErr -} - -func cidrsEqual(x, y []netip.Prefix) bool { - // TODO: re-implement using netaddr.IPSet.Equal. - if len(x) != len(y) { - return false - } - // First see if they're equal in order, without allocating. - exact := true - for i := range x { - if x[i] != y[i] { - exact = false - break - } - } - if exact { - return true - } - - // Otherwise, see if they're the same, but out of order. - m := make(map[netip.Prefix]bool) - for _, v := range x { - m[v] = true - } - for _, v := range y { - if !m[v] { - return false - } - } - return true -} diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index be7873147..5b748fefe 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -97,6 +97,8 @@ type Engine interface { // WireGuard status changes. SetStatusCallback(StatusCallback) + SetPeerByIPLookupFunc(func(netip.Addr) (key.NodePublic, bool)) + // RequestStatus requests a WireGuard status update right // away, sent to the callback registered via SetStatusCallback. RequestStatus()