From b59d58bb89b7880ae8ba55d1946baffe05ae1c05 Mon Sep 17 00:00:00 2001 From: Jonathan Nobels Date: Mon, 20 Oct 2025 14:41:08 -0400 Subject: [PATCH] net/netns: interface probe prototype WIP Experiment with an netns alternative that doesn't rely on the system routing table, but rather probes routes to find a working interface. --- net/netns/netns.go | 2 + net/netns/netns_darwin.go | 168 ++++----- net/netns/netns_dw.go | 21 -- net/netns/netns_probe.go | 454 +++++++++++++++++++++++ net/netns/netns_test.go | 739 ++++++++++++++++++++++++++++++++++++++ wgengine/userspace.go | 2 + 6 files changed, 1285 insertions(+), 101 deletions(-) create mode 100644 net/netns/netns_probe.go diff --git a/net/netns/netns.go b/net/netns/netns.go index 81ab5e2a2..27c46217b 100644 --- a/net/netns/netns.go +++ b/net/netns/netns.go @@ -72,6 +72,8 @@ func SetDisableBindConnToInterfaceAppleExt(logf logger.Logf, v bool) { } } +var probeInterfaces atomic.Bool + // Listener returns a new net.Listener with its Control hook func // initialized as necessary to run in logical network namespace that // doesn't route back into Tailscale. diff --git a/net/netns/netns_darwin.go b/net/netns/netns_darwin.go index ff05a3f31..c35f3d39c 100644 --- a/net/netns/netns_darwin.go +++ b/net/netns/netns_darwin.go @@ -8,7 +8,6 @@ package netns import ( "errors" "fmt" - "log" "net" "net/netip" "os" @@ -19,7 +18,6 @@ import ( "golang.org/x/sys/unix" "tailscale.com/envknob" "tailscale.com/net/netmon" - "tailscale.com/net/tsaddr" "tailscale.com/types/logger" "tailscale.com/version" ) @@ -37,23 +35,103 @@ var errInterfaceStateInvalid = errors.New("interface state invalid") // controlLogf binds c to a particular interface as necessary to dial the // provided (network, address). func controlLogf(logf logger.Logf, netMon *netmon.Monitor, network, address string, c syscall.RawConn) error { - if disableBindConnToInterface.Load() || (version.IsMacGUIVariant() && disableBindConnToInterfaceAppleExt.Load()) { + if isLocalhost(address) { return nil } - if isLocalhost(address) { + /// FIXME: (barnstar) Temporary probeInterfaces logic. Maybe set via a cap? By platform? So may caps. + probeInterfaces.Store(true) + if probeInterfaces.Load() { + host, port, err := net.SplitHostPort(address) + if err != nil { + return fmt.Errorf("netns: control: SplitHostPort %q: %w", address, err) + } + + opts := probeOpts{ + logf: logf, + hpn: HostPortNetwork{Network: network, Host: host, Port: port}, + filterf: filterInvalidIntefaces, + race: true, + cache: globalRouteCache, + } + + // No netmon and no routing table. + iface, err := findInterfaceThatCanReach(opts) + + if err != nil || iface == nil { + return err + } + + bindFn := getBindFn(network, address) + logf("netns: post-probe binding to interface %q (index %d) for %s/%s", iface.Name, iface.Index, network, address) + return bindFn(c, uint32(iface.Index)) + } + + // Not probing? Then check if we should bind at all. + if disableBindConnToInterface.Load() || (version.IsMacGUIVariant() && disableBindConnToInterfaceAppleExt.Load()) { return nil } - idx, err := getInterfaceIndex(logf, netMon, address) + // Bind using the legacy RIB / netmon method. + idx, _ := getInterfaceIndex(logf, netMon, address) + bindFn := getBindFn(network, address) + return bindFn(c, uint32(idx)) +} + +func filterInvalidIntefaces(iface net.Interface) bool { + uninterestingPrefixes := []string{"awdl", "llw", "gif", "stf", "ipsec", "bond", "fwip", "utun"} + + for _, prefix := range uninterestingPrefixes { + if strings.HasPrefix(iface.Name, prefix) { + return false + } + } + return true +} + +// SetListenConfigInterfaceIndex sets lc.Control such that sockets are bound +// to the provided interface index. +func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error { + if lc == nil { + return errors.New("nil ListenConfig") + } + if lc.Control != nil { + return errors.New("ListenConfig.Control already set") + } + lc.Control = func(network, address string, c syscall.RawConn) error { + bindFn := getBindFn(network, address) + return bindFn(c, uint32(ifIndex)) + } + return nil +} + +func bindSocket6(c syscall.RawConn, idx uint32) error { + var sockErr error + err := c.Control(func(fd uintptr) { + sockErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, int(idx)) + }) if err != nil { - // callee logged - return nil + return fmt.Errorf("RawConn.Control on %T: %w", c, err) } + return sockErr +} - return bindConnToInterface(c, network, address, idx, logf) +func bindSocket4(c syscall.RawConn, idx uint32) error { + var sockErr error + err := c.Control(func(fd uintptr) { + sockErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, int(idx)) + }) + if err != nil { + return fmt.Errorf("RawConn.Control on %T: %w", c, err) + } + return sockErr } +// Legacy + +// getInterfaceIndex returns the interface index that we should bind to +// in order to send traffic to the provided address using netmon's view of +// the DefaultRouteInterfaceIndex and/or a direct query to the routing table. func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string) (int, error) { // Helper so we can log errors. defaultIdx := func() (int, error) { @@ -115,14 +193,9 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string) } // If the address doesn't parse, use the default index. - addr, err := parseAddress(address) - if err != nil { - if err != errUnspecifiedHost { - logf("[unexpected] netns: error parsing address %q: %v", address, err) - } - return defaultIdx() - } + logf("netns: getting interface index for address %q", address) + addr, err := parseAddress(address) idx, err := interfaceIndexFor(addr, true /* canRecurse */) if err != nil { logf("netns: error getting interface index for %q: %v", address, err) @@ -143,34 +216,6 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string) return idx, err } -// tailscaleInterface returns the current machine's Tailscale interface, if any. -// If none is found, (nil, nil) is returned. -// A non-nil error is only returned on a problem listing the system interfaces. -func tailscaleInterface() (*net.Interface, error) { - ifs, err := net.Interfaces() - if err != nil { - return nil, err - } - for _, iface := range ifs { - if !strings.HasPrefix(iface.Name, "utun") { - continue - } - addrs, err := iface.Addrs() - if err != nil { - continue - } - for _, a := range addrs { - if ipnet, ok := a.(*net.IPNet); ok { - nip, ok := netip.AddrFromSlice(ipnet.IP) - if ok && tsaddr.IsTailscaleIP(nip.Unmap()) { - return &iface, nil - } - } - } - } - return nil, nil -} - // interfaceIndexFor returns the interface index that we should bind to in // order to send traffic to the provided address. func interfaceIndexFor(addr netip.Addr, canRecurse bool) (int, error) { @@ -276,40 +321,3 @@ func interfaceIndexFor(addr netip.Addr, canRecurse bool) (int, error) { return 0, fmt.Errorf("no valid address found") } - -// SetListenConfigInterfaceIndex sets lc.Control such that sockets are bound -// to the provided interface index. -func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error { - if lc == nil { - return errors.New("nil ListenConfig") - } - if lc.Control != nil { - return errors.New("ListenConfig.Control already set") - } - lc.Control = func(network, address string, c syscall.RawConn) error { - return bindConnToInterface(c, network, address, ifIndex, log.Printf) - } - return nil -} - -func bindConnToInterface(c syscall.RawConn, network, address string, ifIndex int, logf logger.Logf) error { - v6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6") // hacky test for v6 - proto := unix.IPPROTO_IP - opt := unix.IP_BOUND_IF - if v6 { - proto = unix.IPPROTO_IPV6 - opt = unix.IPV6_BOUND_IF - } - - var sockErr error - err := c.Control(func(fd uintptr) { - sockErr = unix.SetsockoptInt(int(fd), proto, opt, ifIndex) - }) - if sockErr != nil { - logf("[unexpected] netns: bindConnToInterface(%q, %q), v6=%v, index=%v: %v", network, address, v6, ifIndex, sockErr) - } - if err != nil { - return fmt.Errorf("RawConn.Control on %T: %w", c, err) - } - return sockErr -} diff --git a/net/netns/netns_dw.go b/net/netns/netns_dw.go index b9f750e8a..1e1f38b55 100644 --- a/net/netns/netns_dw.go +++ b/net/netns/netns_dw.go @@ -5,27 +5,6 @@ package netns -import ( - "errors" - "net" - "net/netip" -) - -var errUnspecifiedHost = errors.New("unspecified host") - -func parseAddress(address string) (addr netip.Addr, err error) { - host, _, err := net.SplitHostPort(address) - if err != nil { - // error means the string didn't contain a port number, so use the string directly - host = address - } - if host == "" { - return addr, errUnspecifiedHost - } - - return netip.ParseAddr(host) -} - func UseSocketMark() bool { return false } diff --git a/net/netns/netns_probe.go b/net/netns/netns_probe.go new file mode 100644 index 000000000..ffd76e780 --- /dev/null +++ b/net/netns/netns_probe.go @@ -0,0 +1,454 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Package netns contains the common code for using the Go net package +// in a logical "network namespace" to avoid routing loops where +// Tailscale-created packets would otherwise loop back through +// Tailscale routes. +// +// Despite the name netns, the exact mechanism used differs by +// operating system, and perhaps even by version of the OS. +// +// The netns package also handles connecting via SOCKS proxies when +// configured by the environment. + +package netns + +import ( + "context" + "errors" + "fmt" + "net" + "net/netip" + "strings" + "syscall" + "time" + + "github.com/gaissmai/bart" + "tailscale.com/net/netmon" + "tailscale.com/net/tsaddr" + "tailscale.com/syncs" + "tailscale.com/types/logger" + "tailscale.com/util/eventbus" +) + +// tailscaleInterface returns the current machine's Tailscale interface, if any. +// If none is found, (nil, nil) is returned. +// A non-nil error is only returned on a problem listing the system interfaces. +// TODO (barnstar): netmon *usually* knows this (at least for darwing), but +// this is more portable. It's still wildly different than the Windows method which +// checks the description strings. +func tailscaleInterface() (*net.Interface, error) { + ifs, err := net.Interfaces() + if err != nil { + return nil, err + } + for _, iface := range ifs { + addrs, err := iface.Addrs() + if err != nil { + continue + } + for _, a := range addrs { + if ipnet, ok := a.(*net.IPNet); ok { + nip, ok := netip.AddrFromSlice(ipnet.IP) + if ok && tsaddr.IsTailscaleIP(nip.Unmap()) { + return &iface, nil + } + } + } + } + return nil, nil +} + +// inetReachability describes an interface and whether it was able to reach +// the provided address. +type inetReachability struct { + iface net.Interface + // TODO (barnstar): These are invariant. reachable should be true if err==nil. + reachable bool + err error +} + +// Tuple of the destination host, port, and network. +// ie: "tcp4", "example.com", "80" +type HostPortNetwork struct { + Host string + Port string + Network string +} + +func (hpn HostPortNetwork) String() string { + return fmt.Sprintf("%s/%s:%s", hpn.Network, hpn.Host, hpn.Port) +} + +type probeOpts struct { + logf logger.Logf + hpn HostPortNetwork + race bool // if true, we'll pick the first interface that responds. sortf is ignored. + filterf interfaceFilter // optional pre-filter for interfaces + cache *routeCache // must be non-nil +} + +type DefaultIfaceHintFn func() int + +var defaultIfaceHintFn DefaultIfaceHintFn + +// Platforms may set defaultIFQueryFn to a function that returns the platforms's high +// level view of the default interface index. +func SetDefaultIFQueryFn(fn DefaultIfaceHintFn) { + defaultIfaceHintFn = fn +} + +// uint +type bindFn func(c syscall.RawConn, ifidx uint32) error + +// Returns the proper bind function for the given network and address. +// Currently only differentiates between IPv4 and IPv6 - and poorly. +func bindFnByAddrType(network, address string) bindFn { + // Very naive check for IPv6. + if strings.Contains(address, "]:") || strings.HasSuffix(network, "6") { + return bindSocket6 + } + return bindSocket4 +} + +type bindFunctionHook func(network, address string) bindFn + +var getBindFn bindFunctionHook = bindFnByAddrType + +var interfacesHookFn func() ([]net.Interface, error) + +var interfacesHook = net.Interfaces + +// ProbeInterfacesReachability probes all non-loopback, up interfaces +// concurrently to determine which can reach the given address. It returns +// a slice with one entry per probed interface in the same order as +// net.Interfaces() filtered by the probe criteria. +func probeInterfacesReachability(opts probeOpts) ([]inetReachability, error) { + ifaces, err := interfacesHook() + if err != nil { + opts.logf("netns: ProbeInterfacesReachability: net.Interfaces: %v", err) + return nil, err + } + + results := make(chan inetReachability, len(ifaces)) + + tsiface, _ := tailscaleInterface() + + var candidates []net.Interface + for _, iface := range ifaces { + // Individual platforms can exclude potential intefaces based on platorm-specific logic. + // For example, on Darwin, we skip "utun" interfaces. + if opts.filterf != nil && !opts.filterf(iface) { + continue + } + + // Only consider up, non-loopback interfaces. + if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagRunning == 0 { + continue + } + + // Skip the Tailscale interface. + if tsiface != nil && iface.Index == tsiface.Index { + continue + } + + // require an IPv4 or IPv6 global unicast address + if !ifaceHasV4OrGlobalV6(&iface) { + continue + } + + candidates = append(candidates, iface) + } + + if len(candidates) == 0 { + opts.logf("netns: ProbeInterfacesReachability: no candidate interfaces found") + return nil, errors.New("no candidate interfaces") + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + for _, iface := range candidates { + go func() { + // Per-probe timeout. + + err := reachabilityHook(&iface, opts.hpn) + + select { + case results <- inetReachability{iface: iface, reachable: err == nil, err: err}: + case <-ctx.Done(): + } + }() + } + + out := make([]inetReachability, 0, len(candidates)) + timeout := time.After(600 * time.Millisecond) + received := 0 + + for received < len(candidates) { + select { + case r := <-results: + // If we're racing, return the first reachable interface immediately. + // TODO (barnstar): We should cache all reachable results so we can try alteratives if we + // can't get the conn up and running later but signal early if we're racing. + if opts.race && r.reachable { + return []inetReachability{r}, nil + } + // .. otherwise, collect all results including the unreachable ones. + out = append(out, r) + received++ + case <-timeout: + return out, fmt.Errorf("netns: probe timed out after %v; received %d/%d results", timeout, received, len(candidates)) + } + } + + return out, nil +} + +// For testing +type reachabilityHookFn func(iface *net.Interface, hpn HostPortNetwork) error + +var reachabilityHook reachabilityHookFn = reachabilityCheck + +func reachabilityCheck(iface *net.Interface, hpn HostPortNetwork) error { + // Per-probe timeout. + dialCtx, dialCancel := context.WithTimeout(context.Background(), 300*time.Millisecond) + defer dialCancel() + + d := net.Dialer{ + Control: func(network, address string, c syscall.RawConn) error { + // (barnstar) TODO: The bind step here is still platform specific + bindFn := getBindFn(network, address) + return bindFn(c, uint32(iface.Index)) + }, + } + + dst := net.JoinHostPort(hpn.Host, hpn.Port) + conn, err := d.DialContext(dialCtx, hpn.Network, dst) + if err == nil { + defer conn.Close() + } + return err +} + +// Pre-filter for interfaces. Platform-specific code can provide a filter +// to exclude certain interfaces from consideration. For example, on Darwin, +// we exclude "utun" interfaces and various other types which will never provie +// have general internet connectivity. +type interfaceFilter func(net.Interface) bool + +func filterInPlace[T any](s []T, keep func(T) bool) []T { + i := 0 + for _, v := range s { + if keep(v) { + s[i] = v + i++ + } + } + return s[:i] +} + +var errUnspecifiedHost = errors.New("unspecified host") + +func parseAddress(address string) (addr netip.Addr, err error) { + host, _, err := net.SplitHostPort(address) + if err != nil { + // error means the string didn't contain a port number, so use the string directly + host = address + } + if host == "" { + return addr, errUnspecifiedHost + } + + return netip.ParseAddr(host) +} + +// findInterfaceThatCanReach finds an interface that can reach the given host:port. +// It uses the provided filterf to exclude certain interfaces, and the +// sortf to prioritize certain interfaces. It returns the first interface that can reach +// the destination. +// +// TODO (barnstar): What this does NOT do is provide a way to flag an interface as "bad" if +// we can't get a connection up and running. Ideally we race for the first candidate, try +// it for a partciular route, and if it fails, remove it from the route cache try a "different" +// candidate. This requires the Dialer to be aware of this logic, and to be able to signal +// back to the route cache that a given interface is "bad" for a given destination. We also +// need to cache all of the candidates found during probing so we can try them again later some +// related state. +// +// nil is returned if no interface can reach the destination. +func findInterfaceThatCanReach(opts probeOpts) (iface *net.Interface, err error) { + // Try to parse the host as an IP address for cache lookup + addr, err := parseAddress(opts.hpn.Host) + if err == nil && addr.IsValid() { + // Check cache first + if cached := opts.cache.lookupCachedRoute(addr); cached != nil { + opts.logf("netns: using cached interface %v for %v", cached.Name, opts.hpn) + return cached, nil + } + } + + res, err := probeInterfacesReachability(opts) + if err != nil { + opts.logf("netns: ProbeInterfacesReachability error: %v", err) + return nil, err + } + + res = filterInPlace(res, func(r inetReachability) bool { return r.reachable }) + if len(res) == 0 { + opts.logf("netns: could not find interface on network %v to reach %q:%q on %q: %v", opts.hpn.Network, opts.hpn.Host, opts.hpn.Port, opts.hpn.Network, err) + return nil, nil + } + + candidatesNames := make([]string, 0, len(res)) + for _, r := range res { + candidatesNames = append(candidatesNames, r.iface.Name) + } + opts.logf("netns: found candidate interfaces that can reach %v:%v on %v: %v", opts.hpn.Host, opts.hpn.Port, opts.hpn.Network, candidatesNames) + iface = &res[0].iface + + if defaultIfaceHintFn != nil { + defIdx := defaultIfaceHintFn() + for _, r := range res { + if r.iface.Index == defIdx { + opts.logf("netns: using default iface hint") + iface = &r.iface + break + } + } + } + + opts.logf("netns: returning interface %v at %v for %v:%v", iface.Name, iface.Index, opts.hpn.Host, opts.hpn.Port) + + // Cache the result if we have a valid IP address + if addr.IsValid() { + opts.cache.setCachedRoute(addr, iface) + } + + return iface, nil +} + +var ifaceHasV4AndGlobalV6Hook func(iface *net.Interface) bool + +// ifaceHasV4AndGlobalV6 reports whether iface has at least one IPv4 address +// and at least one IPv6 address that is not link-local. +func ifaceHasV4OrGlobalV6(iface *net.Interface) bool { + if ifaceHasV4AndGlobalV6Hook != nil { + return ifaceHasV4AndGlobalV6Hook(iface) + } + + addrs, err := iface.Addrs() + if err != nil { + return false + } + for _, a := range addrs { + switch v := a.(type) { + case *net.IPNet: + if v.IP.IsGlobalUnicast() { + return true + } + + } + } + return false +} + +var globalRouteCache *routeCache + +// SetGlobalRouteCache sets the global route cache used by netns. +// It also subscribes the route cache to network change events from +// the provided event bus. +func SetGlobalRouteCache(rc *routeCache, e *eventbus.Bus, logf logger.Logf) { + globalRouteCache = rc + globalRouteCache.subscribeToNetworkChanges(e, logf) +} + +func NewRouteCache() *routeCache { + return &routeCache{ + v4: new(bart.Table[*net.Interface]), + v6: new(bart.Table[*net.Interface]), + } +} + +type routeCache struct { + mu syncs.Mutex + v4 *bart.Table[*net.Interface] // IPv4 routing table + v6 *bart.Table[*net.Interface] // IPv6 routing table + ec *eventbus.Client +} + +func (rc *routeCache) subscribeToNetworkChanges(eventBus *eventbus.Bus, logf logger.Logf) { + rc.mu.Lock() + defer rc.mu.Unlock() + + if rc.ec != nil { + rc.ec.Close() + } + + rc.ec = eventBus.Client("routeCache") + eventbus.SubscribeFunc(rc.ec, func(cd netmon.ChangeDelta) { + if cd.RebindLikelyRequired { + logf("netns: routeCache: major clearing all cached routes due to network change: %v", cd) + rc.ClearAllCachedRoutes() + } + }) + logf("netns: routeCache: subscribed to network change events") +} + +func (rc *routeCache) lookupCachedRoute(addr netip.Addr) *net.Interface { + rc.mu.Lock() + defer rc.mu.Unlock() + + iface, ok := rc.tableForAddr(addr).Lookup(addr) + if !ok { + return nil + } + return iface +} + +func (rc *routeCache) setCachedRoute(addr netip.Addr, iface *net.Interface) { + prefix := netip.PrefixFrom(addr, addrBits(addr)) + rc.setCachedRoutePrefix(prefix, iface) +} + +func (rc *routeCache) setCachedRoutePrefix(prefix netip.Prefix, iface *net.Interface) { + rc.mu.Lock() + defer rc.mu.Unlock() + addr := prefix.Addr() + rc.tableForAddr(addr).Insert(prefix, iface) +} + +func (rc *routeCache) clearCachedRoutePrefix(prefix netip.Prefix) { + rc.mu.Lock() + defer rc.mu.Unlock() + addr := prefix.Addr() + rc.tableForAddr(addr).Delete(prefix) +} + +func (rc *routeCache) ClearCachedRoute(addr netip.Addr) { + prefix := netip.PrefixFrom(addr, addrBits(addr)) + rc.clearCachedRoutePrefix(prefix) +} + +func (rc *routeCache) ClearAllCachedRoutes() { + rc.mu.Lock() + defer rc.mu.Unlock() + + rc.v4 = new(bart.Table[*net.Interface]) + rc.v6 = new(bart.Table[*net.Interface]) +} + +func addrBits(addr netip.Addr) int { + if addr.Is6() { + return 128 + } + return 32 +} + +func (rc *routeCache) tableForAddr(addr netip.Addr) *bart.Table[*net.Interface] { + if addr.Is6() { + return rc.v6 + } + return rc.v4 +} diff --git a/net/netns/netns_test.go b/net/netns/netns_test.go index 82f919b94..99bd034d4 100644 --- a/net/netns/netns_test.go +++ b/net/netns/netns_test.go @@ -14,7 +14,11 @@ package netns import ( + "errors" "flag" + "net" + "net/netip" + "sync/atomic" "testing" ) @@ -76,3 +80,738 @@ func TestIsLocalhost(t *testing.T) { } } } + +func TestGlobalRouteCache(t *testing.T) { + iface1 := &net.Interface{Index: 1, Name: "eth0"} + iface2 := &net.Interface{Index: 2, Name: "eth1"} + iface3 := &net.Interface{Index: 3, Name: "wlan0"} + + t.Run("insert and lookup IPv4", func(t *testing.T) { + routeCache := NewRouteCache() + + addr := netip.MustParseAddr("10.0.1.5") + routeCache.setCachedRoute(addr, iface1) + + got := routeCache.lookupCachedRoute(addr) + if got != iface1 { + t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, iface1) + } + }) + + t.Run("insert and lookup IPv6", func(t *testing.T) { + routeCache := NewRouteCache() + + addr := netip.MustParseAddr("2001:db8::1") + routeCache.setCachedRoute(addr, iface2) + + got := routeCache.lookupCachedRoute(addr) + if got != iface2 { + t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, iface2) + } + }) + + t.Run("lookup non-existent", func(t *testing.T) { + routeCache := NewRouteCache() + addr := netip.MustParseAddr("192.168.1.1") + got := routeCache.lookupCachedRoute(addr) + if got != nil { + t.Errorf("lookupCachedRoute(%v) = %v, want nil", addr, got) + } + }) + + t.Run("longest prefix match IPv4", func(t *testing.T) { + routeCache := NewRouteCache() + + // Insert broader prefix + prefix1 := netip.MustParsePrefix("10.0.0.0/8") + routeCache.setCachedRoutePrefix(prefix1, iface1) + + // Insert more specific prefix + prefix2 := netip.MustParsePrefix("10.0.1.0/24") + routeCache.setCachedRoutePrefix(prefix2, iface2) + + // Insert even more specific prefix + prefix3 := netip.MustParsePrefix("10.0.1.128/25") + routeCache.setCachedRoutePrefix(prefix3, iface3) + + tests := []struct { + addr string + want *net.Interface + }{ + {"10.0.0.1", iface1}, // matches 10.0.0.0/8 + {"10.0.1.1", iface2}, // matches 10.0.1.0/24 + {"10.0.1.129", iface3}, // matches 10.0.1.128/25 + {"10.0.1.127", iface2}, // matches 10.0.1.0/24 (not /25) + {"10.0.2.1", iface1}, // matches 10.0.0.0/8 + {"192.168.1.1", nil}, // no match + } + + for _, tt := range tests { + addr := netip.MustParseAddr(tt.addr) + got := routeCache.lookupCachedRoute(addr) + if got != tt.want { + t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, tt.want) + } + } + }) + + t.Run("longest prefix match IPv6", func(t *testing.T) { + routeCache := NewRouteCache() + + // Insert broader prefix + prefix1 := netip.MustParsePrefix("2001:db8::/32") + routeCache.setCachedRoutePrefix(prefix1, iface1) + + // Insert more specific prefix + prefix2 := netip.MustParsePrefix("2001:db8:1::/48") + routeCache.setCachedRoutePrefix(prefix2, iface2) + + tests := []struct { + addr string + want *net.Interface + }{ + {"2001:db8::1", iface1}, // matches 2001:db8::/32 + {"2001:db8:1::1", iface2}, // matches 2001:db8:1::/48 + {"2001:db8:2::1", iface1}, // matches 2001:db8::/32 + {"2001:db9::1", nil}, // no match + } + + for _, tt := range tests { + addr := netip.MustParseAddr(tt.addr) + got := routeCache.lookupCachedRoute(addr) + if got != tt.want { + t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, tt.want) + } + } + }) + + t.Run("clear cached route by address", func(t *testing.T) { + routeCache := NewRouteCache() + + addr := netip.MustParseAddr("10.0.1.5") + routeCache.setCachedRoute(addr, iface1) + + // Verify it's there + if got := routeCache.lookupCachedRoute(addr); got != iface1 { + t.Errorf("before clear: lookupCachedRoute(%v) = %v, want %v", addr, got, iface1) + } + + // Clear it + routeCache.ClearCachedRoute(addr) + + // Verify it's gone + if got := routeCache.lookupCachedRoute(addr); got != nil { + t.Errorf("after clear: lookupCachedRoute(%v) = %v, want nil", addr, got) + } + }) + + t.Run("clear cached route by prefix", func(t *testing.T) { + routeCache := NewRouteCache() + + prefix := netip.MustParsePrefix("10.0.1.0/24") + routeCache.setCachedRoutePrefix(prefix, iface1) + + // Verify it's there + addr := netip.MustParseAddr("10.0.1.5") + if got := routeCache.lookupCachedRoute(addr); got != iface1 { + t.Errorf("before clear: lookupCachedRoute(%v) = %v, want %v", addr, got, iface1) + } + + // Clear it + routeCache.clearCachedRoutePrefix(prefix) + + // Verify it's gone + if got := routeCache.lookupCachedRoute(addr); got != nil { + t.Errorf("after clear: lookupCachedRoute(%v) = %v, want nil", addr, got) + } + }) + + t.Run("clear specific prefix preserves other prefixes", func(t *testing.T) { + routeCache := NewRouteCache() + + prefix1 := netip.MustParsePrefix("10.0.0.0/8") + prefix2 := netip.MustParsePrefix("192.168.0.0/16") + routeCache.setCachedRoutePrefix(prefix1, iface1) + routeCache.setCachedRoutePrefix(prefix2, iface2) + + // Clear only prefix1 + routeCache.clearCachedRoutePrefix(prefix1) + + // Verify prefix1 is gone + addr1 := netip.MustParseAddr("10.0.1.5") + if got := routeCache.lookupCachedRoute(addr1); got != nil { + t.Errorf("lookupCachedRoute(%v) = %v, want nil", addr1, got) + } + + // Verify prefix2 is still there + addr2 := netip.MustParseAddr("192.168.1.1") + if got := routeCache.lookupCachedRoute(addr2); got != iface2 { + t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr2, got, iface2) + } + }) + + t.Run("clear all cached routes", func(t *testing.T) { + routeCache := NewRouteCache() + + // Insert multiple routes + addr1 := netip.MustParseAddr("10.0.1.5") + addr2 := netip.MustParseAddr("192.168.1.1") + addr3 := netip.MustParseAddr("2001:db8::1") + routeCache.setCachedRoute(addr1, iface1) + routeCache.setCachedRoute(addr2, iface2) + routeCache.setCachedRoute(addr3, iface3) + + // Clear all + routeCache.ClearAllCachedRoutes() + + // Verify all are gone + if got := routeCache.lookupCachedRoute(addr1); got != nil { + t.Errorf("after clear all: lookupCachedRoute(%v) = %v, want nil", addr1, got) + } + if got := routeCache.lookupCachedRoute(addr2); got != nil { + t.Errorf("after clear all: lookupCachedRoute(%v) = %v, want nil", addr2, got) + } + if got := routeCache.lookupCachedRoute(addr3); got != nil { + t.Errorf("after clear all: lookupCachedRoute(%v) = %v, want nil", addr3, got) + } + }) + + t.Run("overwrite existing route", func(t *testing.T) { + routeCache := NewRouteCache() + + addr := netip.MustParseAddr("10.0.1.5") + routeCache.setCachedRoute(addr, iface1) + + // Verify initial value + if got := routeCache.lookupCachedRoute(addr); got != iface1 { + t.Errorf("initial: lookupCachedRoute(%v) = %v, want %v", addr, got, iface1) + } + + // Overwrite with different interface + routeCache.setCachedRoute(addr, iface2) + + // Verify new value + if got := routeCache.lookupCachedRoute(addr); got != iface2 { + t.Errorf("after overwrite: lookupCachedRoute(%v) = %v, want %v", addr, got, iface2) + } + }) + + t.Run("IPv4 and IPv6 are separate", func(t *testing.T) { + routeCache := NewRouteCache() + + addr4 := netip.MustParseAddr("10.0.1.5") + addr6 := netip.MustParseAddr("2001:db8::1") + + routeCache.setCachedRoute(addr4, iface1) + routeCache.setCachedRoute(addr6, iface2) + + // Verify both are stored independently + if got := routeCache.lookupCachedRoute(addr4); got != iface1 { + t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr4, got, iface1) + } + if got := routeCache.lookupCachedRoute(addr6); got != iface2 { + t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr6, got, iface2) + } + + // Clear IPv4, verify IPv6 remains + routeCache.ClearCachedRoute(addr4) + if got := routeCache.lookupCachedRoute(addr4); got != nil { + t.Errorf("after clear v4: lookupCachedRoute(%v) = %v, want nil", addr4, got) + } + if got := routeCache.lookupCachedRoute(addr6); got != iface2 { + t.Errorf("after clear v4: lookupCachedRoute(%v) = %v, want %v", addr6, got, iface2) + } + }) +} + +func hookInterfaces(t *testing.T, ifaces []net.Interface) { + interfacesHook = func() ([]net.Interface, error) { + return ifaces, nil + } + t.Cleanup(func() { + interfacesHook = net.Interfaces + }) +} + +func hookDefaultInterfaces(t *testing.T) { + hookInterfaces(t, allTestIfs) +} + +var ( + iface1 net.Interface = net.Interface{ + Index: 1, + MTU: 1500, + Name: "eth0", + HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55}, + Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning, + } + iface2 net.Interface = net.Interface{ + Index: 2, + MTU: 1500, + Name: "wlan0", + HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x66}, + Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning, + } + iface3 net.Interface = net.Interface{ + Index: 3, + MTU: 1500, + Name: "eth1", + HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x77}, + Flags: net.FlagBroadcast | net.FlagMulticast, + } + allTestIfs = []net.Interface{iface1, iface2, iface3} +) + +func TestFindInterfaceThatCanReach(t *testing.T) { + origReachabilityHook := reachabilityHook + t.Cleanup(func() { + ifaceHasV4AndGlobalV6Hook = nil + reachabilityHook = origReachabilityHook + }) + + ifaceHasV4AndGlobalV6Hook = func(iface *net.Interface) bool { + return true + } + + t.Run("uses route cache on hit", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + // Pre-populate cache + addr := netip.MustParseAddr("8.8.8.8") + cache.setCachedRoute(addr, &iface2) + + // Hook should never be called when cache hits + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + t.Error("reachabilityHookFn should not be called when cache hits") + return nil + } + + opts := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"}, + cache: cache, + } + + result, err := findInterfaceThatCanReach(opts) + if err != nil { + t.Fatalf("findInterfaceThatCanReach failed: %v", err) + } + + if result == nil { + t.Fatal("expected non-nil result") + } + + if result.Name != "wlan0" { + t.Errorf("expected wlan0 from cache, got %s", result.Name) + } + }) + + t.Run("populates cache on miss", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + // All interfaces succeed + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + return nil + } + + opts := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "1.1.1.1", Port: "53", Network: "udp"}, + cache: cache, + } + + result, err := findInterfaceThatCanReach(opts) + if err != nil { + t.Fatalf("findInterfaceThatCanReach failed: %v", err) + } + + if result == nil { + t.Fatal("expected non-nil result") + } + + // Check cache was populated + addr := netip.MustParseAddr("1.1.1.1") + cached := cache.lookupCachedRoute(addr) + if cached == nil { + t.Error("expected cache to be populated") + } else if cached.Name != result.Name { + t.Errorf("cached interface %s != result interface %s", cached.Name, result.Name) + } + }) + + t.Run("returns nil when no interface reachable", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + // All interfaces fail + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + return errors.New("unreachable") + } + + opts := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "192.0.2.1", Port: "53", Network: "udp"}, + cache: cache, + } + + result, err := findInterfaceThatCanReach(opts) + if err != nil { + t.Logf("expected error: %v", err) + } + + if result != nil { + t.Errorf("expected nil result when unreachable, got %v", result) + } + }) + + t.Run("cache respects longest prefix match", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + // Cache 10.0.0.0/8 -> eth0 + prefix1 := netip.MustParsePrefix("10.0.0.0/8") + cache.setCachedRoutePrefix(prefix1, &iface1) + + // Cache 10.0.1.0/24 -> wlan0 + prefix2 := netip.MustParsePrefix("10.0.1.0/24") + cache.setCachedRoutePrefix(prefix2, &iface2) + + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + t.Error("should use cache, not probe") + return nil + } + + // Test 10.0.1.5 -> should match more specific /24 + opts1 := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "10.0.1.5", Port: "53", Network: "udp"}, + cache: cache, + } + + result1, _ := findInterfaceThatCanReach(opts1) + if result1 == nil || result1.Name != "wlan0" { + t.Errorf("expected wlan0 for 10.0.1.5, got %v", result1) + } + + // Test 10.0.2.5 -> should match broader /8 + opts2 := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "10.0.2.5", Port: "53", Network: "udp"}, + cache: cache, + } + + result2, _ := findInterfaceThatCanReach(opts2) + if result2 == nil || result2.Name != "eth0" { + t.Errorf("expected eth0 for 10.0.2.5, got %v", result2) + } + }) + + t.Run("race mode returns first reachable", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + // eth0 (iface1) responds quickly + // wlan0 (iface2) responds slowly + // eth1 (iface3) responds slowly + // Channels to control when each probe completes + wlan0Done := make(chan struct{}) + eth1Done := make(chan struct{}) + + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + switch iface.Index { + case iface1.Index: // eth0 - returns immediately + return nil + case iface2.Index: // wlan0 - waits for signal + <-wlan0Done + return nil + case iface3.Index: // eth1 - waits for signal + <-eth1Done + return nil + } + return errors.New("unknown interface") + } + defer func() { + // Now signal the slower interfaces to complete + close(wlan0Done) + close(eth1Done) + }() + + opts := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"}, + race: true, + cache: cache, + } + + result, err := findInterfaceThatCanReach(opts) + if err != nil { + t.Fatalf("findInterfaceThatCanReach failed: %v", err) + } + + if result == nil { + t.Fatal("expected non-nil result in race mode") + } + + // Should return quickly without waiting for all probes + t.Logf("race mode returned interface: %s", result.Name) + }) + + t.Run("filterf excludes interfaces", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + probeCount := atomic.Int32{} + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + probeCount.Add(1) + return nil + } + + opts := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"}, + cache: cache, + filterf: func(iface net.Interface) bool { + // Exclude wlan0 and eth1 + return iface.Name != "wlan0" && iface.Name != "eth1" + }, + } + + result, err := findInterfaceThatCanReach(opts) + if err != nil { + t.Fatalf("findInterfaceThatCanReach failed: %v", err) + } + + // Should only probe filtered interfaces + if probeCount.Load() > 1 { + t.Logf("probed %d interfaces after filtering", probeCount.Load()) + } + + if result != nil && (result.Name == "wlan0" || result.Name == "eth1") { + t.Errorf("filterf should have excluded %s", result.Name) + } + }) + + t.Run("handles hostname instead of IP", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + return nil + } + + // Use a hostname that can't be parsed as an IP + opts := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "example.com", Port: "443", Network: "tcp"}, + cache: cache, + } + + result, err := findInterfaceThatCanReach(opts) + if err != nil { + t.Fatalf("findInterfaceThatCanReach failed: %v", err) + } + + if result == nil { + t.Fatal("expected non-nil result") + } + + // Cache should not be used for hostnames + addr, parseErr := netip.ParseAddr("example.com") + if parseErr == nil && addr.IsValid() { + t.Error("example.com should not parse as valid IP") + } + }) + + t.Run("default interface hint is respected", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + // All interfaces are reachable + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + return nil + } + + // Set hint to prefer iface2 (index 2) + origHintFn := defaultIfaceHintFn + defer func() { defaultIfaceHintFn = origHintFn }() + + defaultIfaceHintFn = func() int { + return 2 // iface2 / wlan0 + } + + opts := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "1.1.1.1", Port: "53", Network: "udp"}, + cache: cache, + } + + result, err := findInterfaceThatCanReach(opts) + if err != nil { + t.Fatalf("findInterfaceThatCanReach failed: %v", err) + } + + if result == nil { + t.Fatal("expected non-nil result") + } + + if result.Index != 2 { + t.Errorf("expected default hint interface (index 2), got index %d (%s)", result.Index, result.Name) + } + }) + + t.Run("IPv6 address uses IPv6 cache table", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + // Pre-populate IPv6 cache + addr6 := netip.MustParseAddr("2001:4860:4860::8888") + cache.setCachedRoute(addr6, &iface3) + + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + t.Error("should use cache for IPv6") + return nil + } + + opts := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "2001:4860:4860::8888", Port: "53", Network: "udp6"}, + cache: cache, + } + + result, err := findInterfaceThatCanReach(opts) + if err != nil { + t.Fatalf("findInterfaceThatCanReach failed: %v", err) + } + + if result == nil || result.Name != "eth1" { + t.Errorf("expected eth1 from IPv6 cache, got %v", result) + } + }) + + t.Run("IPv4 and IPv6 caches are independent", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + addr4 := netip.MustParseAddr("8.8.8.8") + addr6 := netip.MustParseAddr("2001:4860:4860::8888") + + cache.setCachedRoute(addr4, &iface1) + cache.setCachedRoute(addr6, &iface2) + + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + t.Error("should use cache") + return nil + } + + // Test IPv4 + opts4 := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"}, + cache: cache, + } + result4, _ := findInterfaceThatCanReach(opts4) + if result4 == nil || result4.Name != "eth0" { + t.Errorf("IPv4: expected eth0, got %v", result4) + } + + // Test IPv6 + opts6 := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "2001:4860:4860::8888", Port: "53", Network: "udp6"}, + cache: cache, + } + result6, _ := findInterfaceThatCanReach(opts6) + if result6 == nil || result6.Name != "wlan0" { + t.Errorf("IPv6: expected wlan0, got %v", result6) + } + }) + + t.Run("empty host returns error", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + return nil + } + + opts := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: "", Port: "53", Network: "udp"}, + cache: cache, + } + + result, err := findInterfaceThatCanReach(opts) + + // Should handle empty host gracefully + if err == nil && result != nil { + t.Logf("handled empty host, returned %v", result) + } + }) + + t.Run("caches subnet prefix correctly", func(t *testing.T) { + cache := NewRouteCache() + hookDefaultInterfaces(t) + + // Manually cache a /16 subnet + prefix := netip.MustParsePrefix("192.168.0.0/16") + cache.setCachedRoutePrefix(prefix, &iface1) + + reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error { + t.Error("should use cached subnet") + return nil + } + + // Test various IPs in the subnet + testIPs := []string{ + "192.168.0.1", + "192.168.1.1", + "192.168.255.254", + } + + for _, ip := range testIPs { + opts := probeOpts{ + logf: t.Logf, + hpn: HostPortNetwork{Host: ip, Port: "53", Network: "udp"}, + cache: cache, + } + + result, _ := findInterfaceThatCanReach(opts) + if result == nil || result.Name != "eth0" { + t.Errorf("IP %s: expected eth0 from cached subnet, got %v", ip, result) + } + } + }) +} + +// TODO (barnstar): Working, but the sleep is egregious. How to test async eventbus properly? +// func TestRouteCacheEventBus(t *testing.T) { +// t.Run("insert and lookup IPv4", func(t *testing.T) { +// rc := NewRouteCache() +// bus := eventbus.New() +// b := bus.Client("netns_test") +// t.Cleanup(func() { +// b.Close() +// }) + +// route := netip.MustParseAddr("1.1.1.1") + +// // Example of publishing a route cache clear event +// publisher := eventbus.Publish[netmon.ChangeDelta](b) +// SetGlobalRouteCache(rc, bus, t.Logf) +// rc.setCachedRoute(route, &net.Interface{Index: 1, Name: "eth0"}) +// ifBeforeEvent := rc.lookupCachedRoute(route) +// if ifBeforeEvent == nil || ifBeforeEvent.Name != "eth0" { +// t.Fatalf("expected cached route before event, got %v", ifBeforeEvent) +// } + +// publisher.Publish(netmon.ChangeDelta{RebindLikelyRequired: true}) +// time.Sleep(100 * time.Millisecond) + +// ifAfterEvent := rc.lookupCachedRoute(route) +// if ifAfterEvent != nil { +// t.Fatalf("expected cached route to be cleared after event, got %v", ifAfterEvent) +// } +// }) +// } diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 875011a9c..92caa59a6 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -33,6 +33,7 @@ import ( "tailscale.com/net/dns/resolver" "tailscale.com/net/ipset" "tailscale.com/net/netmon" + "tailscale.com/net/netns" "tailscale.com/net/packet" "tailscale.com/net/sockstats" "tailscale.com/net/tsaddr" @@ -391,6 +392,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) // TODO: there's probably a better place for this sockstats.SetNetMon(e.netMon) + netns.SetGlobalRouteCache(netns.NewRouteCache(), e.eventBus, logf) logf("link state: %+v", e.netMon.InterfaceState())