net/netns: interface probe prototype

WIP

Experiment with a netns alternative that doesn't rely
on the system routing table, but instead probes routes to
find a working interface.
jonathan/netns_probe
Jonathan Nobels 3 months ago
parent 9a6282b515
commit b59d58bb89

@ -72,6 +72,8 @@ func SetDisableBindConnToInterfaceAppleExt(logf logger.Logf, v bool) {
}
}
// probeInterfaces gates the interface-probing path in controlLogf: when it
// is true, controlLogf picks the interface to bind to by probing candidates
// instead of using the legacy routing-table / netmon lookup.
var probeInterfaces atomic.Bool
// Listener returns a new net.Listener with its Control hook func
// initialized as necessary to run in logical network namespace that
// doesn't route back into Tailscale.

@ -8,7 +8,6 @@ package netns
import (
"errors"
"fmt"
"log"
"net"
"net/netip"
"os"
@ -19,7 +18,6 @@ import (
"golang.org/x/sys/unix"
"tailscale.com/envknob"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"tailscale.com/version"
)
@ -37,23 +35,103 @@ var errInterfaceStateInvalid = errors.New("interface state invalid")
// controlLogf binds c to a particular interface as necessary to dial the
// provided (network, address).
func controlLogf(logf logger.Logf, netMon *netmon.Monitor, network, address string, c syscall.RawConn) error {
if disableBindConnToInterface.Load() || (version.IsMacGUIVariant() && disableBindConnToInterfaceAppleExt.Load()) {
if isLocalhost(address) {
return nil
}
if isLocalhost(address) {
/// FIXME: (barnstar) Temporary probeInterfaces logic. Maybe set via a cap? By platform? So may caps.
probeInterfaces.Store(true)
if probeInterfaces.Load() {
host, port, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("netns: control: SplitHostPort %q: %w", address, err)
}
opts := probeOpts{
logf: logf,
hpn: HostPortNetwork{Network: network, Host: host, Port: port},
filterf: filterInvalidIntefaces,
race: true,
cache: globalRouteCache,
}
// No netmon and no routing table.
iface, err := findInterfaceThatCanReach(opts)
if err != nil || iface == nil {
return err
}
bindFn := getBindFn(network, address)
logf("netns: post-probe binding to interface %q (index %d) for %s/%s", iface.Name, iface.Index, network, address)
return bindFn(c, uint32(iface.Index))
}
// Not probing? Then check if we should bind at all.
if disableBindConnToInterface.Load() || (version.IsMacGUIVariant() && disableBindConnToInterfaceAppleExt.Load()) {
return nil
}
idx, err := getInterfaceIndex(logf, netMon, address)
// Bind using the legacy RIB / netmon method.
idx, _ := getInterfaceIndex(logf, netMon, address)
bindFn := getBindFn(network, address)
return bindFn(c, uint32(idx))
}
// filterInvalidIntefaces reports whether iface should be kept as a probe
// candidate. It returns false for any interface whose name starts with one of
// a fixed set of prefixes ("awdl", "llw", "gif", "stf", "ipsec", "bond",
// "fwip", "utun") and true for everything else.
func filterInvalidIntefaces(iface net.Interface) bool {
	name := iface.Name
	for _, skip := range [...]string{"awdl", "llw", "gif", "stf", "ipsec", "bond", "fwip", "utun"} {
		if strings.HasPrefix(name, skip) {
			return false
		}
	}
	return true
}
// SetListenConfigInterfaceIndex installs a Control hook on lc that binds each
// socket to the interface with index ifIndex, using the bind function that
// getBindFn selects for the socket's network and address.
//
// It returns an error, leaving lc untouched, if lc is nil or if lc.Control
// has already been set.
func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error {
	switch {
	case lc == nil:
		return errors.New("nil ListenConfig")
	case lc.Control != nil:
		return errors.New("ListenConfig.Control already set")
	}
	lc.Control = func(network, address string, c syscall.RawConn) error {
		return getBindFn(network, address)(c, uint32(ifIndex))
	}
	return nil
}
func bindSocket6(c syscall.RawConn, idx uint32) error {
var sockErr error
err := c.Control(func(fd uintptr) {
sockErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, int(idx))
})
if err != nil {
// callee logged
return nil
return fmt.Errorf("RawConn.Control on %T: %w", c, err)
}
return sockErr
}
return bindConnToInterface(c, network, address, idx, logf)
// bindSocket4 sets the IP_BOUND_IF option (level IPPROTO_IP) on the socket
// behind c to idx. It returns a wrapped error if c.Control itself fails, and
// otherwise the result of the setsockopt call.
func bindSocket4(c syscall.RawConn, idx uint32) error {
	var optErr error
	setBoundIf := func(fd uintptr) {
		optErr = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, int(idx))
	}
	if ctlErr := c.Control(setBoundIf); ctlErr != nil {
		return fmt.Errorf("RawConn.Control on %T: %w", c, ctlErr)
	}
	return optErr
}
// Legacy
// getInterfaceIndex returns the interface index that we should bind to
// in order to send traffic to the provided address using netmon's view of
// the DefaultRouteInterfaceIndex and/or a direct query to the routing table.
func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string) (int, error) {
// Helper so we can log errors.
defaultIdx := func() (int, error) {
@ -115,14 +193,9 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string)
}
// If the address doesn't parse, use the default index.
addr, err := parseAddress(address)
if err != nil {
if err != errUnspecifiedHost {
logf("[unexpected] netns: error parsing address %q: %v", address, err)
}
return defaultIdx()
}
logf("netns: getting interface index for address %q", address)
addr, err := parseAddress(address)
idx, err := interfaceIndexFor(addr, true /* canRecurse */)
if err != nil {
logf("netns: error getting interface index for %q: %v", address, err)
@ -143,34 +216,6 @@ func getInterfaceIndex(logf logger.Logf, netMon *netmon.Monitor, address string)
return idx, err
}
// tailscaleInterface returns the current machine's Tailscale interface, if any.
// If none is found, (nil, nil) is returned.
// A non-nil error is only returned on a problem listing the system interfaces.
func tailscaleInterface() (*net.Interface, error) {
ifs, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifs {
if !strings.HasPrefix(iface.Name, "utun") {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, a := range addrs {
if ipnet, ok := a.(*net.IPNet); ok {
nip, ok := netip.AddrFromSlice(ipnet.IP)
if ok && tsaddr.IsTailscaleIP(nip.Unmap()) {
return &iface, nil
}
}
}
}
return nil, nil
}
// interfaceIndexFor returns the interface index that we should bind to in
// order to send traffic to the provided address.
func interfaceIndexFor(addr netip.Addr, canRecurse bool) (int, error) {
@ -276,40 +321,3 @@ func interfaceIndexFor(addr netip.Addr, canRecurse bool) (int, error) {
return 0, fmt.Errorf("no valid address found")
}
// SetListenConfigInterfaceIndex sets lc.Control such that sockets are bound
// to the provided interface index.
func SetListenConfigInterfaceIndex(lc *net.ListenConfig, ifIndex int) error {
if lc == nil {
return errors.New("nil ListenConfig")
}
if lc.Control != nil {
return errors.New("ListenConfig.Control already set")
}
lc.Control = func(network, address string, c syscall.RawConn) error {
return bindConnToInterface(c, network, address, ifIndex, log.Printf)
}
return nil
}
func bindConnToInterface(c syscall.RawConn, network, address string, ifIndex int, logf logger.Logf) error {
v6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6") // hacky test for v6
proto := unix.IPPROTO_IP
opt := unix.IP_BOUND_IF
if v6 {
proto = unix.IPPROTO_IPV6
opt = unix.IPV6_BOUND_IF
}
var sockErr error
err := c.Control(func(fd uintptr) {
sockErr = unix.SetsockoptInt(int(fd), proto, opt, ifIndex)
})
if sockErr != nil {
logf("[unexpected] netns: bindConnToInterface(%q, %q), v6=%v, index=%v: %v", network, address, v6, ifIndex, sockErr)
}
if err != nil {
return fmt.Errorf("RawConn.Control on %T: %w", c, err)
}
return sockErr
}

@ -5,27 +5,6 @@
package netns
import (
"errors"
"net"
"net/netip"
)
var errUnspecifiedHost = errors.New("unspecified host")
func parseAddress(address string) (addr netip.Addr, err error) {
host, _, err := net.SplitHostPort(address)
if err != nil {
// error means the string didn't contain a port number, so use the string directly
host = address
}
if host == "" {
return addr, errUnspecifiedHost
}
return netip.ParseAddr(host)
}
func UseSocketMark() bool {
return false
}

@ -0,0 +1,454 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package netns contains the common code for using the Go net package
// in a logical "network namespace" to avoid routing loops where
// Tailscale-created packets would otherwise loop back through
// Tailscale routes.
//
// Despite the name netns, the exact mechanism used differs by
// operating system, and perhaps even by version of the OS.
//
// The netns package also handles connecting via SOCKS proxies when
// configured by the environment.
package netns
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"syscall"
"time"
"github.com/gaissmai/bart"
"tailscale.com/net/netmon"
"tailscale.com/net/tsaddr"
"tailscale.com/syncs"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
)
// tailscaleInterface returns the first system interface that carries a
// Tailscale IP address, or (nil, nil) if no interface does.
//
// A non-nil error is returned only when listing the system interfaces fails;
// interfaces whose addresses cannot be read are silently skipped.
//
// TODO (barnstar): netmon *usually* knows this (at least for darwin), but
// this is more portable. It's still wildly different from the Windows method,
// which checks the description strings.
func tailscaleInterface() (*net.Interface, error) {
	all, err := net.Interfaces()
	if err != nil {
		return nil, err
	}
	for _, candidate := range all {
		addrs, err := candidate.Addrs()
		if err != nil {
			continue
		}
		for _, a := range addrs {
			ipnet, isIPNet := a.(*net.IPNet)
			if !isIPNet {
				continue
			}
			ip, ok := netip.AddrFromSlice(ipnet.IP)
			if ok && tsaddr.IsTailscaleIP(ip.Unmap()) {
				return &candidate, nil
			}
		}
	}
	return nil, nil
}
// inetReachability is the result of probing a single interface: the
// interface that was probed and whether the destination could be reached
// through it.
type inetReachability struct {
// iface is a copy of the probed interface.
iface net.Interface
// TODO (barnstar): These two fields are redundant. The prober always sets
// reachable to err == nil, so one of them could be dropped.
reachable bool
// err is the error returned by the probe, or nil if the probe succeeded.
err error
}
// HostPortNetwork identifies a probe destination: the host and port to dial
// and the network to dial them on, e.g. Host "example.com", Port "80",
// Network "tcp4".
type HostPortNetwork struct {
	Host    string
	Port    string
	Network string
}

// String renders hpn as "network/host:port". The host is written verbatim,
// so IPv6 literals are not bracketed.
func (hpn HostPortNetwork) String() string {
	hostPort := hpn.Host + ":" + hpn.Port
	return fmt.Sprintf("%s/%s", hpn.Network, hostPort)
}
// probeOpts bundles the inputs for an interface reachability probe.
type probeOpts struct {
logf logger.Logf // receives diagnostic output from the probe
hpn HostPortNetwork // destination to probe
race bool // if true, the probe returns as soon as the first interface reports the destination reachable
filterf interfaceFilter // optional pre-filter; interfaces for which it returns false are not probed
cache *routeCache // route cache consulted before probing; must be non-nil
}
// DefaultIfaceHintFn returns the index of the interface that the platform
// considers to be its default.
type DefaultIfaceHintFn func() int
// defaultIfaceHintFn, if non-nil, is consulted by findInterfaceThatCanReach
// to prefer the platform's default interface among the reachable candidates.
var defaultIfaceHintFn DefaultIfaceHintFn
// SetDefaultIFQueryFn registers fn as the function that returns the
// platform's high-level view of the default interface index.
// NOTE(review): defaultIfaceHintFn is written here and read during probing
// without synchronization; presumably this is only called once during
// start-up — confirm.
func SetDefaultIFQueryFn(fn DefaultIfaceHintFn) {
defaultIfaceHintFn = fn
}
// bindFn binds the socket behind c to the interface with index ifidx.
type bindFn func(c syscall.RawConn, ifidx uint32) error
// bindFnByAddrType picks the bind function for the given network and address:
// bindSocket6 when the destination looks like IPv6, bindSocket4 otherwise.
//
// The IPv6 detection is a crude heuristic: the network name ends in "6" or
// the address contains "]:" (a bracketed literal followed by a port).
func bindFnByAddrType(network, address string) bindFn {
	looksV6 := strings.Contains(address, "]:") || strings.HasSuffix(network, "6")
	if !looksV6 {
		return bindSocket4
	}
	return bindSocket6
}
// bindFunctionHook selects the bindFn to use for a given network and address.
type bindFunctionHook func(network, address string) bindFn
// getBindFn is the bindFunctionHook in use. It defaults to bindFnByAddrType.
var getBindFn bindFunctionHook = bindFnByAddrType
// NOTE(review): interfacesHookFn is not referenced by any code visible here;
// confirm it is unused before removing it.
var interfacesHookFn func() ([]net.Interface, error)
// interfacesHook lists the system's network interfaces. It defaults to
// net.Interfaces and is a variable so it can be swapped out.
var interfacesHook = net.Interfaces
// probeInterfacesReachability concurrently probes every candidate interface
// to determine which can reach opts.hpn.
//
// An interface is a candidate when it passes opts.filterf (if set), is up and
// running, is not a loopback interface, is not the Tailscale interface, and
// passes ifaceHasV4OrGlobalV6.
//
// If opts.race is set, the first reachable result is returned immediately as
// a one-element slice. Otherwise the results of all probes, reachable or not,
// are returned in the order in which the probes completed.
//
// An error is returned if the interfaces cannot be listed, if there are no
// candidates, or if the probes do not all complete before the overall
// deadline; in the last case the results collected so far are returned
// alongside the error.
func probeInterfacesReachability(opts probeOpts) ([]inetReachability, error) {
	// Overall deadline for collecting results from all probes.
	const probeTimeout = 600 * time.Millisecond

	ifaces, err := interfacesHook()
	if err != nil {
		opts.logf("netns: ProbeInterfacesReachability: net.Interfaces: %v", err)
		return nil, err
	}

	// Best effort: if the Tailscale interface can't be determined, tsiface is
	// nil and nothing is excluded on that basis.
	tsiface, _ := tailscaleInterface()

	candidates := make([]net.Interface, 0, len(ifaces))
	for _, iface := range ifaces {
		// Individual platforms can exclude potential interfaces based on
		// platform-specific logic. For example, on Darwin, we skip "utun"
		// interfaces.
		if opts.filterf != nil && !opts.filterf(iface) {
			continue
		}
		// Only consider up, running, non-loopback interfaces.
		if iface.Flags&net.FlagUp == 0 || iface.Flags&net.FlagLoopback != 0 || iface.Flags&net.FlagRunning == 0 {
			continue
		}
		// Skip the Tailscale interface.
		if tsiface != nil && iface.Index == tsiface.Index {
			continue
		}
		// Require a usable (global unicast) address.
		if !ifaceHasV4OrGlobalV6(&iface) {
			continue
		}
		candidates = append(candidates, iface)
	}
	if len(candidates) == 0 {
		opts.logf("netns: ProbeInterfacesReachability: no candidate interfaces found")
		return nil, errors.New("no candidate interfaces")
	}

	// The channel has room for one result per probe, so every goroutine can
	// deliver its result and exit even after we have returned early.
	results := make(chan inetReachability, len(candidates))
	for _, iface := range candidates {
		// Pass iface explicitly so that each goroutine owns its own copy
		// regardless of the loop-variable semantics of the Go version in use.
		go func(iface net.Interface) {
			err := reachabilityHook(&iface, opts.hpn)
			results <- inetReachability{iface: iface, reachable: err == nil, err: err}
		}(iface)
	}

	timer := time.NewTimer(probeTimeout)
	defer timer.Stop()

	out := make([]inetReachability, 0, len(candidates))
	for len(out) < len(candidates) {
		select {
		case r := <-results:
			// If we're racing, return the first reachable interface immediately.
			// TODO (barnstar): We should cache all reachable results so we can
			// try alternatives if we can't get the conn up and running later,
			// but still signal early if we're racing.
			if opts.race && r.reachable {
				return []inetReachability{r}, nil
			}
			// ... otherwise, collect all results including the unreachable ones.
			out = append(out, r)
		case <-timer.C:
			return out, fmt.Errorf("netns: probe timed out after %v; received %d/%d results", probeTimeout, len(out), len(candidates))
		}
	}
	return out, nil
}
// reachabilityHookFn probes whether hpn can be reached through iface. It
// returns nil if the probe succeeded.
type reachabilityHookFn func(iface *net.Interface, hpn HostPortNetwork) error
// reachabilityHook is the probe invoked for each candidate interface. It
// defaults to reachabilityCheck and is a variable so tests can replace it.
var reachabilityHook reachabilityHookFn = reachabilityCheck
// reachabilityCheck dials hpn with a socket bound to iface and reports
// whether the dial succeeded. The dial is limited to 300ms, and a connection
// that is established is closed again before returning.
func reachabilityCheck(iface *net.Interface, hpn HostPortNetwork) error {
	dialer := net.Dialer{
		// (barnstar) TODO: The bind step here is still platform specific
		Control: func(network, address string, c syscall.RawConn) error {
			return getBindFn(network, address)(c, uint32(iface.Index))
		},
	}

	// Per-probe timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
	defer cancel()

	conn, err := dialer.DialContext(ctx, hpn.Network, net.JoinHostPort(hpn.Host, hpn.Port))
	if err != nil {
		return err
	}
	conn.Close()
	return nil
}
// interfaceFilter is a pre-filter for probe candidates: it reports whether
// an interface should be considered (true) or skipped (false).
// Platform-specific code can provide a filter to exclude certain interfaces
// from consideration. For example, on Darwin, we exclude "utun" interfaces
// and various other types which will never provide general internet
// connectivity.
type interfaceFilter func(net.Interface) bool
// filterInPlace keeps the elements of s for which keep returns true,
// preserving their order, and returns the shortened slice. The result reuses
// s's backing array, so the contents of s are overwritten.
func filterInPlace[T any](s []T, keep func(T) bool) []T {
	kept := s[:0]
	for _, elem := range s {
		if !keep(elem) {
			continue
		}
		kept = append(kept, elem)
	}
	return kept
}
// errUnspecifiedHost is returned by parseAddress when the address has an
// empty host part.
var errUnspecifiedHost = errors.New("unspecified host")

// parseAddress extracts the IP address from address, which may be either
// "host:port" or a bare host. It returns errUnspecifiedHost if the host is
// empty, and the netip.ParseAddr error if the host is not an IP literal.
func parseAddress(address string) (addr netip.Addr, err error) {
	// If address doesn't split as host:port, treat the whole string as the host.
	host := address
	if h, _, splitErr := net.SplitHostPort(address); splitErr == nil {
		host = h
	}
	if host == "" {
		return addr, errUnspecifiedHost
	}
	return netip.ParseAddr(host)
}
// findInterfaceThatCanReach returns an interface through which opts.hpn can
// be reached, or (nil, nil) if no probed interface can reach it.
//
// When opts.hpn.Host is an IP literal, the route cache is consulted first and
// a hit is returned without probing. Otherwise the candidate interfaces
// (pre-filtered with opts.filterf) are probed. Among the reachable ones, the
// interface reported by the registered default-interface hint is preferred;
// failing that, the first reachable result is used. A result found for an IP
// literal is stored in the cache. A nil opts.cache simply disables caching.
//
// A non-nil error is returned only when the probe itself fails.
//
// TODO (barnstar): What this does NOT do is provide a way to flag an interface as "bad" if
// we can't get a connection up and running. Ideally we race for the first candidate, try
// it for a particular route, and if it fails, remove it from the route cache and try a
// "different" candidate. This requires the Dialer to be aware of this logic, and to be able
// to signal back to the route cache that a given interface is "bad" for a given destination.
// We also need to cache all of the candidates found during probing, along with some related
// state, so we can try them again later.
func findInterfaceThatCanReach(opts probeOpts) (iface *net.Interface, err error) {
	// The cache is keyed by IP address, so hostnames always bypass it. The
	// nil check keeps callers that have no cache configured from panicking.
	addr, parseErr := parseAddress(opts.hpn.Host)
	cacheable := parseErr == nil && addr.IsValid() && opts.cache != nil

	if cacheable {
		if cached := opts.cache.lookupCachedRoute(addr); cached != nil {
			opts.logf("netns: using cached interface %v for %v", cached.Name, opts.hpn)
			return cached, nil
		}
	}

	res, err := probeInterfacesReachability(opts)
	if err != nil {
		opts.logf("netns: ProbeInterfacesReachability error: %v", err)
		return nil, err
	}

	res = filterInPlace(res, func(r inetReachability) bool { return r.reachable })
	if len(res) == 0 {
		opts.logf("netns: could not find interface to reach %q:%q on %q", opts.hpn.Host, opts.hpn.Port, opts.hpn.Network)
		return nil, nil
	}

	candidatesNames := make([]string, 0, len(res))
	for _, r := range res {
		candidatesNames = append(candidatesNames, r.iface.Name)
	}
	opts.logf("netns: found candidate interfaces that can reach %v:%v on %v: %v", opts.hpn.Host, opts.hpn.Port, opts.hpn.Network, candidatesNames)

	// Default to the first reachable interface, but prefer the platform's
	// default interface when it is among the reachable ones.
	iface = &res[0].iface
	if defaultIfaceHintFn != nil {
		defIdx := defaultIfaceHintFn()
		for i := range res {
			if res[i].iface.Index == defIdx {
				opts.logf("netns: using default iface hint")
				iface = &res[i].iface
				break
			}
		}
	}
	opts.logf("netns: returning interface %v at %v for %v:%v", iface.Name, iface.Index, opts.hpn.Host, opts.hpn.Port)

	// Cache the result if we have a valid IP address.
	if cacheable {
		opts.cache.setCachedRoute(addr, iface)
	}
	return iface, nil
}
// ifaceHasV4AndGlobalV6Hook, if non-nil, replaces the address inspection
// performed by ifaceHasV4OrGlobalV6.
var ifaceHasV4AndGlobalV6Hook func(iface *net.Interface) bool

// ifaceHasV4OrGlobalV6 reports whether iface has at least one address that
// net.IP.IsGlobalUnicast accepts. If ifaceHasV4AndGlobalV6Hook is set, its
// answer is returned instead. An interface whose addresses cannot be listed
// is reported as false.
func ifaceHasV4OrGlobalV6(iface *net.Interface) bool {
	if hook := ifaceHasV4AndGlobalV6Hook; hook != nil {
		return hook(iface)
	}
	addrs, err := iface.Addrs()
	if err != nil {
		return false
	}
	for _, a := range addrs {
		if ipnet, ok := a.(*net.IPNet); ok && ipnet.IP.IsGlobalUnicast() {
			return true
		}
	}
	return false
}
// globalRouteCache is the package-wide route cache. It is nil until
// SetGlobalRouteCache is called.
var globalRouteCache *routeCache
// SetGlobalRouteCache sets the global route cache used by netns.
// It also subscribes the route cache to network change events from
// the provided event bus, logging via logf.
// rc must be non-nil: the subscription is set up on it immediately.
func SetGlobalRouteCache(rc *routeCache, e *eventbus.Bus, logf logger.Logf) {
globalRouteCache = rc
globalRouteCache.subscribeToNetworkChanges(e, logf)
}
// NewRouteCache returns an empty route cache with freshly allocated IPv4 and
// IPv6 tables. The returned cache is not subscribed to any event bus.
func NewRouteCache() *routeCache {
	rc := new(routeCache)
	rc.v4 = &bart.Table[*net.Interface]{}
	rc.v6 = &bart.Table[*net.Interface]{}
	return rc
}
// routeCache maps destination prefixes to the interface that should be used
// to reach them. IPv4 and IPv6 prefixes are held in separate tables.
type routeCache struct {
mu syncs.Mutex // held by the methods below while reading or replacing v4, v6 and ec
v4 *bart.Table[*net.Interface] // IPv4 routing table
v6 *bart.Table[*net.Interface] // IPv6 routing table
ec *eventbus.Client // event bus client created by subscribeToNetworkChanges; nil until then
}
// subscribeToNetworkChanges subscribes rc to netmon.ChangeDelta events on
// eventBus, using a new client named "routeCache". A client left over from an
// earlier call is closed first. Whenever a delta reports
// RebindLikelyRequired, all cached routes are cleared.
// NOTE(review): the callback acquires rc.mu via ClearAllCachedRoutes; this
// presumably relies on eventbus never invoking it synchronously from inside
// SubscribeFunc while rc.mu is held here — confirm.
func (rc *routeCache) subscribeToNetworkChanges(eventBus *eventbus.Bus, logf logger.Logf) {
rc.mu.Lock()
defer rc.mu.Unlock()
// Drop the previous subscription, if any, before creating a new client.
if rc.ec != nil {
rc.ec.Close()
}
rc.ec = eventBus.Client("routeCache")
eventbus.SubscribeFunc(rc.ec, func(cd netmon.ChangeDelta) {
if cd.RebindLikelyRequired {
logf("netns: routeCache: major clearing all cached routes due to network change: %v", cd)
rc.ClearAllCachedRoutes()
}
})
logf("netns: routeCache: subscribed to network change events")
}
// lookupCachedRoute returns the interface cached for the prefix that matches
// addr in the table for addr's address family, or nil if nothing matches.
func (rc *routeCache) lookupCachedRoute(addr netip.Addr) *net.Interface {
	rc.mu.Lock()
	defer rc.mu.Unlock()

	if hit, found := rc.tableForAddr(addr).Lookup(addr); found {
		return hit
	}
	return nil
}
// setCachedRoute caches iface as the route for the single address addr, i.e.
// for a host prefix of addrBits(addr) bits.
func (rc *routeCache) setCachedRoute(addr netip.Addr, iface *net.Interface) {
	rc.setCachedRoutePrefix(netip.PrefixFrom(addr, addrBits(addr)), iface)
}
// setCachedRoutePrefix caches iface as the route for prefix, inserting it
// into the table that matches the prefix's address family.
func (rc *routeCache) setCachedRoutePrefix(prefix netip.Prefix, iface *net.Interface) {
	rc.mu.Lock()
	defer rc.mu.Unlock()

	table := rc.tableForAddr(prefix.Addr())
	table.Insert(prefix, iface)
}
// clearCachedRoutePrefix deletes the entry for prefix from the table that
// matches the prefix's address family.
func (rc *routeCache) clearCachedRoutePrefix(prefix netip.Prefix) {
	rc.mu.Lock()
	defer rc.mu.Unlock()

	table := rc.tableForAddr(prefix.Addr())
	table.Delete(prefix)
}
// ClearCachedRoute deletes the host-prefix entry (addrBits(addr) bits) for
// addr. Entries for other prefixes are left in place.
func (rc *routeCache) ClearCachedRoute(addr netip.Addr) {
	rc.clearCachedRoutePrefix(netip.PrefixFrom(addr, addrBits(addr)))
}
// ClearAllCachedRoutes discards every cached route by replacing both the
// IPv4 and the IPv6 table with empty ones.
func (rc *routeCache) ClearAllCachedRoutes() {
	rc.mu.Lock()
	defer rc.mu.Unlock()

	rc.v4, rc.v6 = new(bart.Table[*net.Interface]), new(bart.Table[*net.Interface])
}
// addrBits returns the host-prefix length to use for addr: 128 when
// addr.Is6() is true, 32 otherwise.
func addrBits(addr netip.Addr) int {
	const (
		v4Bits = 32
		v6Bits = 128
	)
	if !addr.Is6() {
		return v4Bits
	}
	return v6Bits
}
// tableForAddr returns the table that holds routes for addr's address
// family: rc.v6 when addr.Is6() is true, rc.v4 otherwise. It does no locking
// of its own.
func (rc *routeCache) tableForAddr(addr netip.Addr) *bart.Table[*net.Interface] {
	switch {
	case addr.Is6():
		return rc.v6
	default:
		return rc.v4
	}
}

@ -14,7 +14,11 @@
package netns
import (
"errors"
"flag"
"net"
"net/netip"
"sync/atomic"
"testing"
)
@ -76,3 +80,738 @@ func TestIsLocalhost(t *testing.T) {
}
}
}
// TestGlobalRouteCache exercises routeCache directly: insert and lookup for
// both address families, longest-prefix matching, clearing by address, by
// prefix and wholesale, overwriting an entry, and the independence of the
// IPv4 and IPv6 tables. Every subtest builds its own cache with
// NewRouteCache, so the subtests do not share state.
func TestGlobalRouteCache(t *testing.T) {
// These locals are pointers and shadow the package-level iface1..iface3
// fixture values for the duration of this test. Lookups are compared by
// pointer identity.
iface1 := &net.Interface{Index: 1, Name: "eth0"}
iface2 := &net.Interface{Index: 2, Name: "eth1"}
iface3 := &net.Interface{Index: 3, Name: "wlan0"}
t.Run("insert and lookup IPv4", func(t *testing.T) {
routeCache := NewRouteCache()
addr := netip.MustParseAddr("10.0.1.5")
routeCache.setCachedRoute(addr, iface1)
got := routeCache.lookupCachedRoute(addr)
if got != iface1 {
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, iface1)
}
})
t.Run("insert and lookup IPv6", func(t *testing.T) {
routeCache := NewRouteCache()
addr := netip.MustParseAddr("2001:db8::1")
routeCache.setCachedRoute(addr, iface2)
got := routeCache.lookupCachedRoute(addr)
if got != iface2 {
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, iface2)
}
})
t.Run("lookup non-existent", func(t *testing.T) {
routeCache := NewRouteCache()
addr := netip.MustParseAddr("192.168.1.1")
got := routeCache.lookupCachedRoute(addr)
if got != nil {
t.Errorf("lookupCachedRoute(%v) = %v, want nil", addr, got)
}
})
t.Run("longest prefix match IPv4", func(t *testing.T) {
routeCache := NewRouteCache()
// Insert broader prefix
prefix1 := netip.MustParsePrefix("10.0.0.0/8")
routeCache.setCachedRoutePrefix(prefix1, iface1)
// Insert more specific prefix
prefix2 := netip.MustParsePrefix("10.0.1.0/24")
routeCache.setCachedRoutePrefix(prefix2, iface2)
// Insert even more specific prefix
prefix3 := netip.MustParsePrefix("10.0.1.128/25")
routeCache.setCachedRoutePrefix(prefix3, iface3)
tests := []struct {
addr string
want *net.Interface
}{
{"10.0.0.1", iface1}, // matches 10.0.0.0/8
{"10.0.1.1", iface2}, // matches 10.0.1.0/24
{"10.0.1.129", iface3}, // matches 10.0.1.128/25
{"10.0.1.127", iface2}, // matches 10.0.1.0/24 (not /25)
{"10.0.2.1", iface1}, // matches 10.0.0.0/8
{"192.168.1.1", nil}, // no match
}
for _, tt := range tests {
addr := netip.MustParseAddr(tt.addr)
got := routeCache.lookupCachedRoute(addr)
if got != tt.want {
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, tt.want)
}
}
})
t.Run("longest prefix match IPv6", func(t *testing.T) {
routeCache := NewRouteCache()
// Insert broader prefix
prefix1 := netip.MustParsePrefix("2001:db8::/32")
routeCache.setCachedRoutePrefix(prefix1, iface1)
// Insert more specific prefix
prefix2 := netip.MustParsePrefix("2001:db8:1::/48")
routeCache.setCachedRoutePrefix(prefix2, iface2)
tests := []struct {
addr string
want *net.Interface
}{
{"2001:db8::1", iface1}, // matches 2001:db8::/32
{"2001:db8:1::1", iface2}, // matches 2001:db8:1::/48
{"2001:db8:2::1", iface1}, // matches 2001:db8::/32
{"2001:db9::1", nil}, // no match
}
for _, tt := range tests {
addr := netip.MustParseAddr(tt.addr)
got := routeCache.lookupCachedRoute(addr)
if got != tt.want {
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr, got, tt.want)
}
}
})
t.Run("clear cached route by address", func(t *testing.T) {
routeCache := NewRouteCache()
addr := netip.MustParseAddr("10.0.1.5")
routeCache.setCachedRoute(addr, iface1)
// Verify it's there
if got := routeCache.lookupCachedRoute(addr); got != iface1 {
t.Errorf("before clear: lookupCachedRoute(%v) = %v, want %v", addr, got, iface1)
}
// Clear it
routeCache.ClearCachedRoute(addr)
// Verify it's gone
if got := routeCache.lookupCachedRoute(addr); got != nil {
t.Errorf("after clear: lookupCachedRoute(%v) = %v, want nil", addr, got)
}
})
t.Run("clear cached route by prefix", func(t *testing.T) {
routeCache := NewRouteCache()
prefix := netip.MustParsePrefix("10.0.1.0/24")
routeCache.setCachedRoutePrefix(prefix, iface1)
// Verify it's there
addr := netip.MustParseAddr("10.0.1.5")
if got := routeCache.lookupCachedRoute(addr); got != iface1 {
t.Errorf("before clear: lookupCachedRoute(%v) = %v, want %v", addr, got, iface1)
}
// Clear it
routeCache.clearCachedRoutePrefix(prefix)
// Verify it's gone
if got := routeCache.lookupCachedRoute(addr); got != nil {
t.Errorf("after clear: lookupCachedRoute(%v) = %v, want nil", addr, got)
}
})
t.Run("clear specific prefix preserves other prefixes", func(t *testing.T) {
routeCache := NewRouteCache()
prefix1 := netip.MustParsePrefix("10.0.0.0/8")
prefix2 := netip.MustParsePrefix("192.168.0.0/16")
routeCache.setCachedRoutePrefix(prefix1, iface1)
routeCache.setCachedRoutePrefix(prefix2, iface2)
// Clear only prefix1
routeCache.clearCachedRoutePrefix(prefix1)
// Verify prefix1 is gone
addr1 := netip.MustParseAddr("10.0.1.5")
if got := routeCache.lookupCachedRoute(addr1); got != nil {
t.Errorf("lookupCachedRoute(%v) = %v, want nil", addr1, got)
}
// Verify prefix2 is still there
addr2 := netip.MustParseAddr("192.168.1.1")
if got := routeCache.lookupCachedRoute(addr2); got != iface2 {
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr2, got, iface2)
}
})
t.Run("clear all cached routes", func(t *testing.T) {
routeCache := NewRouteCache()
// Insert multiple routes
addr1 := netip.MustParseAddr("10.0.1.5")
addr2 := netip.MustParseAddr("192.168.1.1")
addr3 := netip.MustParseAddr("2001:db8::1")
routeCache.setCachedRoute(addr1, iface1)
routeCache.setCachedRoute(addr2, iface2)
routeCache.setCachedRoute(addr3, iface3)
// Clear all
routeCache.ClearAllCachedRoutes()
// Verify all are gone
if got := routeCache.lookupCachedRoute(addr1); got != nil {
t.Errorf("after clear all: lookupCachedRoute(%v) = %v, want nil", addr1, got)
}
if got := routeCache.lookupCachedRoute(addr2); got != nil {
t.Errorf("after clear all: lookupCachedRoute(%v) = %v, want nil", addr2, got)
}
if got := routeCache.lookupCachedRoute(addr3); got != nil {
t.Errorf("after clear all: lookupCachedRoute(%v) = %v, want nil", addr3, got)
}
})
t.Run("overwrite existing route", func(t *testing.T) {
routeCache := NewRouteCache()
addr := netip.MustParseAddr("10.0.1.5")
routeCache.setCachedRoute(addr, iface1)
// Verify initial value
if got := routeCache.lookupCachedRoute(addr); got != iface1 {
t.Errorf("initial: lookupCachedRoute(%v) = %v, want %v", addr, got, iface1)
}
// Overwrite with different interface
routeCache.setCachedRoute(addr, iface2)
// Verify new value
if got := routeCache.lookupCachedRoute(addr); got != iface2 {
t.Errorf("after overwrite: lookupCachedRoute(%v) = %v, want %v", addr, got, iface2)
}
})
t.Run("IPv4 and IPv6 are separate", func(t *testing.T) {
routeCache := NewRouteCache()
addr4 := netip.MustParseAddr("10.0.1.5")
addr6 := netip.MustParseAddr("2001:db8::1")
routeCache.setCachedRoute(addr4, iface1)
routeCache.setCachedRoute(addr6, iface2)
// Verify both are stored independently
if got := routeCache.lookupCachedRoute(addr4); got != iface1 {
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr4, got, iface1)
}
if got := routeCache.lookupCachedRoute(addr6); got != iface2 {
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr6, got, iface2)
}
// Clear IPv4, verify IPv6 remains
routeCache.ClearCachedRoute(addr4)
if got := routeCache.lookupCachedRoute(addr4); got != nil {
t.Errorf("after clear v4: lookupCachedRoute(%v) = %v, want nil", addr4, got)
}
if got := routeCache.lookupCachedRoute(addr6); got != iface2 {
t.Errorf("after clear v4: lookupCachedRoute(%v) = %v, want %v", addr6, got, iface2)
}
})
}
// hookInterfaces makes interfacesHook return ifaces (and a nil error) for the
// remainder of the test. The hook that was installed before the call is put
// back when the test finishes, so nested or repeated use cannot clobber an
// outer override.
func hookInterfaces(t *testing.T, ifaces []net.Interface) {
	t.Helper()
	prev := interfacesHook
	interfacesHook = func() ([]net.Interface, error) {
		return ifaces, nil
	}
	t.Cleanup(func() {
		interfacesHook = prev
	})
}
// hookDefaultInterfaces installs allTestIfs as the interface list reported by
// interfacesHook for the duration of the test.
func hookDefaultInterfaces(t *testing.T) {
hookInterfaces(t, allTestIfs)
}
// Fixture interfaces for the probe tests. iface1 (eth0) and iface2 (wlan0)
// are up and running; iface3 (eth1) has neither flag set. allTestIfs holds
// copies of all three, in index order.
var (
	iface1 = net.Interface{
		Index:        1,
		MTU:          1500,
		Name:         "eth0",
		HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
		Flags:        net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning,
	}

	iface2 = net.Interface{
		Index:        2,
		MTU:          1500,
		Name:         "wlan0",
		HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x66},
		Flags:        net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning,
	}

	iface3 = net.Interface{
		Index:        3,
		MTU:          1500,
		Name:         "eth1",
		HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x77},
		Flags:        net.FlagBroadcast | net.FlagMulticast,
	}

	allTestIfs = []net.Interface{iface1, iface2, iface3}
)
func TestFindInterfaceThatCanReach(t *testing.T) {
origReachabilityHook := reachabilityHook
t.Cleanup(func() {
ifaceHasV4AndGlobalV6Hook = nil
reachabilityHook = origReachabilityHook
})
ifaceHasV4AndGlobalV6Hook = func(iface *net.Interface) bool {
return true
}
t.Run("uses route cache on hit", func(t *testing.T) {
cache := NewRouteCache()
hookDefaultInterfaces(t)
// Pre-populate cache
addr := netip.MustParseAddr("8.8.8.8")
cache.setCachedRoute(addr, &iface2)
// Hook should never be called when cache hits
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
t.Error("reachabilityHookFn should not be called when cache hits")
return nil
}
opts := probeOpts{
logf: t.Logf,
hpn: HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"},
cache: cache,
}
result, err := findInterfaceThatCanReach(opts)
if err != nil {
t.Fatalf("findInterfaceThatCanReach failed: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
if result.Name != "wlan0" {
t.Errorf("expected wlan0 from cache, got %s", result.Name)
}
})
t.Run("populates cache on miss", func(t *testing.T) {
cache := NewRouteCache()
hookDefaultInterfaces(t)
// All interfaces succeed
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
return nil
}
opts := probeOpts{
logf: t.Logf,
hpn: HostPortNetwork{Host: "1.1.1.1", Port: "53", Network: "udp"},
cache: cache,
}
result, err := findInterfaceThatCanReach(opts)
if err != nil {
t.Fatalf("findInterfaceThatCanReach failed: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
}
// Check cache was populated
addr := netip.MustParseAddr("1.1.1.1")
cached := cache.lookupCachedRoute(addr)
if cached == nil {
t.Error("expected cache to be populated")
} else if cached.Name != result.Name {
t.Errorf("cached interface %s != result interface %s", cached.Name, result.Name)
}
})
t.Run("returns nil when no interface reachable", func(t *testing.T) {
cache := NewRouteCache()
hookDefaultInterfaces(t)
// All interfaces fail
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
return errors.New("unreachable")
}
opts := probeOpts{
logf: t.Logf,
hpn: HostPortNetwork{Host: "192.0.2.1", Port: "53", Network: "udp"},
cache: cache,
}
result, err := findInterfaceThatCanReach(opts)
if err != nil {
t.Logf("expected error: %v", err)
}
if result != nil {
t.Errorf("expected nil result when unreachable, got %v", result)
}
})
t.Run("cache respects longest prefix match", func(t *testing.T) {
cache := NewRouteCache()
hookDefaultInterfaces(t)
// Cache 10.0.0.0/8 -> eth0
prefix1 := netip.MustParsePrefix("10.0.0.0/8")
cache.setCachedRoutePrefix(prefix1, &iface1)
// Cache 10.0.1.0/24 -> wlan0
prefix2 := netip.MustParsePrefix("10.0.1.0/24")
cache.setCachedRoutePrefix(prefix2, &iface2)
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
t.Error("should use cache, not probe")
return nil
}
// Test 10.0.1.5 -> should match more specific /24
opts1 := probeOpts{
logf: t.Logf,
hpn: HostPortNetwork{Host: "10.0.1.5", Port: "53", Network: "udp"},
cache: cache,
}
result1, _ := findInterfaceThatCanReach(opts1)
if result1 == nil || result1.Name != "wlan0" {
t.Errorf("expected wlan0 for 10.0.1.5, got %v", result1)
}
// Test 10.0.2.5 -> should match broader /8
opts2 := probeOpts{
logf: t.Logf,
hpn: HostPortNetwork{Host: "10.0.2.5", Port: "53", Network: "udp"},
cache: cache,
}
result2, _ := findInterfaceThatCanReach(opts2)
if result2 == nil || result2.Name != "eth0" {
t.Errorf("expected eth0 for 10.0.2.5, got %v", result2)
}
})
// In race mode the lookup is expected to return as soon as one probe
// succeeds. Only eth0's probe is allowed to finish before the call returns,
// so the result must be eth0.
t.Run("race mode returns first reachable", func(t *testing.T) {
	cache := NewRouteCache()
	hookDefaultInterfaces(t)

	// eth0 (iface1) answers immediately; wlan0 (iface2) and eth1 (iface3)
	// stay blocked on these channels until the subtest is finishing.
	wlan0Done := make(chan struct{})
	eth1Done := make(chan struct{})
	reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
		switch iface.Index {
		case iface1.Index: // eth0 - returns immediately
			return nil
		case iface2.Index: // wlan0 - waits for signal
			<-wlan0Done
			return nil
		case iface3.Index: // eth1 - waits for signal
			<-eth1Done
			return nil
		}
		return errors.New("unknown interface")
	}
	// Release the slow probes on the way out so that anything still blocked
	// inside the hook can complete.
	defer func() {
		close(wlan0Done)
		close(eth1Done)
	}()

	opts := probeOpts{
		logf:  t.Logf,
		hpn:   HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"},
		race:  true,
		cache: cache,
	}
	result, err := findInterfaceThatCanReach(opts)
	if err != nil {
		t.Fatalf("findInterfaceThatCanReach failed: %v", err)
	}
	if result == nil {
		t.Fatal("expected non-nil result in race mode")
	}
	// The wlan0 and eth1 probes cannot have completed yet (their channels
	// are closed only by the deferred func above), so the winner has to be
	// eth0. Previously the winner was only logged, never checked.
	if result.Index != iface1.Index {
		t.Errorf("race mode returned %s (index %d), want %s (index %d)",
			result.Name, result.Index, iface1.Name, iface1.Index)
	}
	t.Logf("race mode returned interface: %s", result.Name)
})
// Interfaces rejected by filterf must neither be probed nor returned.
t.Run("filterf excludes interfaces", func(t *testing.T) {
	cache := NewRouteCache()
	hookDefaultInterfaces(t)

	// Single source of truth for which interfaces the filter rejects, shared
	// by the filter, the probe hook and the final result check.
	excluded := func(name string) bool {
		return name == "wlan0" || name == "eth1"
	}

	probeCount := atomic.Int32{}
	reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
		probeCount.Add(1)
		// The original test only logged the probe count, so probing a
		// filtered-out interface went unnoticed. Fail explicitly instead.
		if excluded(iface.Name) {
			t.Errorf("probed %s, which filterf should have excluded", iface.Name)
		}
		return nil
	}
	opts := probeOpts{
		logf:  t.Logf,
		hpn:   HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"},
		cache: cache,
		// Returning false drops the interface from consideration.
		filterf: func(iface net.Interface) bool {
			return !excluded(iface.Name)
		},
	}
	result, err := findInterfaceThatCanReach(opts)
	if err != nil {
		t.Fatalf("findInterfaceThatCanReach failed: %v", err)
	}
	t.Logf("probed %d interfaces after filtering", probeCount.Load())
	if result != nil && excluded(result.Name) {
		t.Errorf("filterf should have excluded %s", result.Name)
	}
})
// A DNS name (rather than an IP literal) as the destination host must still
// yield an interface when every probe succeeds.
t.Run("handles hostname instead of IP", func(t *testing.T) {
	routes := NewRouteCache()
	hookDefaultInterfaces(t)

	// Any interface reaches the destination.
	reachabilityHook = func(_ *net.Interface, _ HostPortNetwork) error {
		return nil
	}

	got, probeErr := findInterfaceThatCanReach(probeOpts{
		logf:  t.Logf,
		hpn:   HostPortNetwork{Host: "example.com", Port: "443", Network: "tcp"},
		cache: routes,
	})
	if probeErr != nil {
		t.Fatalf("findInterfaceThatCanReach failed: %v", probeErr)
	}
	if got == nil {
		t.Fatal("expected non-nil result")
	}

	// Sanity-check the premise of this subtest: the host must not parse as
	// an IP address. Note that this does not inspect the route cache itself.
	if ip, err := netip.ParseAddr("example.com"); err == nil && ip.IsValid() {
		t.Error("example.com should not parse as valid IP")
	}
})
// With every interface reachable, the one named by defaultIfaceHintFn is
// expected to be chosen.
t.Run("default interface hint is respected", func(t *testing.T) {
	routes := NewRouteCache()
	hookDefaultInterfaces(t)

	// No interface is ever rejected by the probe.
	reachabilityHook = func(_ *net.Interface, _ HostPortNetwork) error {
		return nil
	}

	// Temporarily replace the hint with one that points at index 2
	// (iface2 / wlan0); the previous hint is put back when the subtest ends.
	savedHint := defaultIfaceHintFn
	defer func() { defaultIfaceHintFn = savedHint }()
	defaultIfaceHintFn = func() int { return 2 }

	got, probeErr := findInterfaceThatCanReach(probeOpts{
		logf:  t.Logf,
		hpn:   HostPortNetwork{Host: "1.1.1.1", Port: "53", Network: "udp"},
		cache: routes,
	})
	switch {
	case probeErr != nil:
		t.Fatalf("findInterfaceThatCanReach failed: %v", probeErr)
	case got == nil:
		t.Fatal("expected non-nil result")
	case got.Index != 2:
		t.Errorf("expected default hint interface (index 2), got index %d (%s)", got.Index, got.Name)
	}
})
t.Run("IPv6 address uses IPv6 cache table", func(t *testing.T) {
cache := NewRouteCache()
hookDefaultInterfaces(t)
// Pre-populate IPv6 cache
addr6 := netip.MustParseAddr("2001:4860:4860::8888")
cache.setCachedRoute(addr6, &iface3)
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
t.Error("should use cache for IPv6")
return nil
}
opts := probeOpts{
logf: t.Logf,
hpn: HostPortNetwork{Host: "2001:4860:4860::8888", Port: "53", Network: "udp6"},
cache: cache,
}
result, err := findInterfaceThatCanReach(opts)
if err != nil {
t.Fatalf("findInterfaceThatCanReach failed: %v", err)
}
if result == nil || result.Name != "eth1" {
t.Errorf("expected eth1 from IPv6 cache, got %v", result)
}
})
// One cache holds an IPv4 and an IPv6 entry pointing at different
// interfaces; each lookup must come back with its own family's entry.
t.Run("IPv4 and IPv6 caches are independent", func(t *testing.T) {
	routes := NewRouteCache()
	hookDefaultInterfaces(t)

	// 8.8.8.8 via iface1 (eth0), 2001:4860:4860::8888 via iface2 (wlan0).
	routes.setCachedRoute(netip.MustParseAddr("8.8.8.8"), &iface1)
	routes.setCachedRoute(netip.MustParseAddr("2001:4860:4860::8888"), &iface2)

	// Both lookups must be cache hits.
	reachabilityHook = func(_ *net.Interface, _ HostPortNetwork) error {
		t.Error("should use cache")
		return nil
	}

	// Errors are dropped on purpose: the nil/name check is the assertion.
	v4, _ := findInterfaceThatCanReach(probeOpts{
		logf:  t.Logf,
		hpn:   HostPortNetwork{Host: "8.8.8.8", Port: "53", Network: "udp"},
		cache: routes,
	})
	if v4 == nil || v4.Name != "eth0" {
		t.Errorf("IPv4: expected eth0, got %v", v4)
	}

	v6, _ := findInterfaceThatCanReach(probeOpts{
		logf:  t.Logf,
		hpn:   HostPortNetwork{Host: "2001:4860:4860::8888", Port: "53", Network: "udp6"},
		cache: routes,
	})
	if v6 == nil || v6.Name != "wlan0" {
		t.Errorf("IPv6: expected wlan0, got %v", v6)
	}
})
// Exercises the lookup with an empty destination host.
//
// NOTE(review): despite the subtest name, nothing below asserts that an
// error is returned; the outcome is only logged. Confirm the intended
// contract for an empty host and assert it.
t.Run("empty host returns error", func(t *testing.T) {
	routes := NewRouteCache()
	hookDefaultInterfaces(t)

	// Every probe succeeds, so any failure would come from host handling.
	reachabilityHook = func(_ *net.Interface, _ HostPortNetwork) error {
		return nil
	}

	got, probeErr := findInterfaceThatCanReach(probeOpts{
		logf:  t.Logf,
		hpn:   HostPortNetwork{Host: "", Port: "53", Network: "udp"},
		cache: routes,
	})
	// Informational only: record what came back when the call succeeded.
	if got != nil && probeErr == nil {
		t.Logf("handled empty host, returned %v", got)
	}
})
// A single cached /16 must answer lookups for any address inside it
// without falling back to probing.
t.Run("caches subnet prefix correctly", func(t *testing.T) {
	routes := NewRouteCache()
	hookDefaultInterfaces(t)

	// Seed 192.168.0.0/16 via iface1 (eth0).
	routes.setCachedRoutePrefix(netip.MustParsePrefix("192.168.0.0/16"), &iface1)

	// Every lookup below must be a cache hit.
	reachabilityHook = func(_ *net.Interface, _ HostPortNetwork) error {
		t.Error("should use cached subnet")
		return nil
	}

	// Addresses taken from the start, the middle and the end of the /16.
	for _, host := range []string{"192.168.0.1", "192.168.1.1", "192.168.255.254"} {
		// The error is dropped on purpose: the nil/name check is the assertion.
		got, _ := findInterfaceThatCanReach(probeOpts{
			logf:  t.Logf,
			hpn:   HostPortNetwork{Host: host, Port: "53", Network: "udp"},
			cache: routes,
		})
		if got != nil && got.Name == "eth0" {
			continue
		}
		t.Errorf("IP %s: expected eth0 from cached subnet, got %v", host, got)
	}
})
}
// TODO (barnstar): Working, but the sleep is egregious. How to test async eventbus properly?
// func TestRouteCacheEventBus(t *testing.T) {
// t.Run("insert and lookup IPv4", func(t *testing.T) {
// rc := NewRouteCache()
// bus := eventbus.New()
// b := bus.Client("netns_test")
// t.Cleanup(func() {
// b.Close()
// })
// route := netip.MustParseAddr("1.1.1.1")
// // Example of publishing a route cache clear event
// publisher := eventbus.Publish[netmon.ChangeDelta](b)
// SetGlobalRouteCache(rc, bus, t.Logf)
// rc.setCachedRoute(route, &net.Interface{Index: 1, Name: "eth0"})
// ifBeforeEvent := rc.lookupCachedRoute(route)
// if ifBeforeEvent == nil || ifBeforeEvent.Name != "eth0" {
// t.Fatalf("expected cached route before event, got %v", ifBeforeEvent)
// }
// publisher.Publish(netmon.ChangeDelta{RebindLikelyRequired: true})
// time.Sleep(100 * time.Millisecond)
// ifAfterEvent := rc.lookupCachedRoute(route)
// if ifAfterEvent != nil {
// t.Fatalf("expected cached route to be cleared after event, got %v", ifAfterEvent)
// }
// })
// }

@ -33,6 +33,7 @@ import (
"tailscale.com/net/dns/resolver"
"tailscale.com/net/ipset"
"tailscale.com/net/netmon"
"tailscale.com/net/netns"
"tailscale.com/net/packet"
"tailscale.com/net/sockstats"
"tailscale.com/net/tsaddr"
@ -391,6 +392,7 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
// TODO: there's probably a better place for this
sockstats.SetNetMon(e.netMon)
netns.SetGlobalRouteCache(netns.NewRouteCache(), e.eventBus, logf)
logf("link state: %+v", e.netMon.InterfaceState())

Loading…
Cancel
Save