Refactoring netns configuration to fix thread safety

issues.
jonathan/netns_probe
Jonathan Nobels 6 days ago
parent b59d58bb89
commit 5e37be0fb6

@ -22,9 +22,60 @@ import (
"tailscale.com/net/netknob"
"tailscale.com/net/netmon"
"tailscale.com/syncs"
"tailscale.com/types/logger"
"tailscale.com/util/eventbus"
)
type Opts struct {
rc *routeCache
e *eventbus.Bus
tunName string
logf logger.Logf
}
func NewOpts(rc *routeCache, e *eventbus.Bus, tunName string, logf logger.Logf) Opts {
return Opts{
rc: rc,
e: e,
tunName: tunName,
logf: logf,
}
}
var netns struct {
mu syncs.Mutex
rc *routeCache
tunName string
logf logger.Logf
}
func cache() *routeCache {
netns.mu.Lock()
defer netns.mu.Unlock()
return netns.rc
}
// SetGlobalRouteCache sets the global route cache used by netns.
// It also subscribes the route cache to network change events from
// the provided event bus.
func Configure(opts Opts) {
netns.mu.Lock()
defer netns.mu.Unlock()
netns.rc = opts.rc
netns.rc.subscribeToNetworkChanges(opts.e, opts.logf)
netns.tunName = opts.tunName
netns.logf = opts.logf
opts.logf("netns: configured with tun as %q", opts.tunName)
}
func tunName() string {
netns.mu.Lock()
defer netns.mu.Unlock()
return netns.tunName
}
var disabled atomic.Bool
// SetEnabled enables or disables netns for the process.

@ -52,7 +52,7 @@ func controlLogf(logf logger.Logf, netMon *netmon.Monitor, network, address stri
hpn: HostPortNetwork{Network: network, Host: host, Port: port},
filterf: filterInvalidIntefaces,
race: true,
cache: globalRouteCache,
cache: cache(),
}
// No netmon and no routing table.

@ -39,11 +39,17 @@ import (
// this is more portable. It's still wildly different than the Windows method which
// checks the description strings.
func tailscaleInterface() (*net.Interface, error) {
tunName := tunName()
ifs, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, iface := range ifs {
if tunName == iface.Name {
return &iface, nil
}
addrs, err := iface.Addrs()
if err != nil {
continue
@ -166,19 +172,21 @@ func probeInterfacesReachability(opts probeOpts) ([]inetReachability, error) {
return nil, errors.New("no candidate interfaces")
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Close this channel to abort ongoing probes if we're racing and are only interested
// in the first result.
done := make(chan struct{})
for _, iface := range candidates {
go func() {
// Per-probe timeout.
err := reachabilityHook(&iface, opts.hpn)
select {
case results <- inetReachability{iface: iface, reachable: err == nil, err: err}:
case <-ctx.Done():
case <-done:
return
default:
}
// Per-probe timeout.
err := reachabilityHook(&iface, opts.hpn)
results <- inetReachability{iface: iface, reachable: err == nil, err: err}
}()
}
@ -193,6 +201,7 @@ func probeInterfacesReachability(opts probeOpts) ([]inetReachability, error) {
// TODO (barnstar): We should cache all reachable results so we can try alteratives if we
// can't get the conn up and running later but signal early if we're racing.
if opts.race && r.reachable {
close(done)
return []inetReachability{r}, nil
}
// .. otherwise, collect all results including the unreachable ones.
@ -284,7 +293,8 @@ func findInterfaceThatCanReach(opts probeOpts) (iface *net.Interface, err error)
if err == nil && addr.IsValid() {
// Check cache first
if cached := opts.cache.lookupCachedRoute(addr); cached != nil {
opts.logf("netns: using cached interface %v for %v", cached.Name, opts.hpn)
hits, misses, total := opts.cache.stats()
opts.logf("netns: cachHit for %v cache stats: hits=%d misses=%d total=%d", addr, hits, misses, total)
return cached, nil
}
}
@ -305,7 +315,6 @@ func findInterfaceThatCanReach(opts probeOpts) (iface *net.Interface, err error)
for _, r := range res {
candidatesNames = append(candidatesNames, r.iface.Name)
}
opts.logf("netns: found candidate interfaces that can reach %v:%v on %v: %v", opts.hpn.Host, opts.hpn.Port, opts.hpn.Network, candidatesNames)
iface = &res[0].iface
if defaultIfaceHintFn != nil {
@ -319,8 +328,6 @@ func findInterfaceThatCanReach(opts probeOpts) (iface *net.Interface, err error)
}
}
opts.logf("netns: returning interface %v at %v for %v:%v", iface.Name, iface.Index, opts.hpn.Host, opts.hpn.Port)
// Cache the result if we have a valid IP address
if addr.IsValid() {
opts.cache.setCachedRoute(addr, iface)
@ -354,43 +361,40 @@ func ifaceHasV4OrGlobalV6(iface *net.Interface) bool {
return false
}
var globalRouteCache *routeCache
// SetGlobalRouteCache sets the global route cache used by netns.
// It also subscribes the route cache to network change events from
// the provided event bus.
func SetGlobalRouteCache(rc *routeCache, e *eventbus.Bus, logf logger.Logf) {
globalRouteCache = rc
globalRouteCache.subscribeToNetworkChanges(e, logf)
}
func NewRouteCache() *routeCache {
return &routeCache{
v4: new(bart.Table[*net.Interface]),
v6: new(bart.Table[*net.Interface]),
table: new(bart.Table[*net.Interface]),
}
}
type routeCache struct {
mu syncs.Mutex
v4 *bart.Table[*net.Interface] // IPv4 routing table
v6 *bart.Table[*net.Interface] // IPv6 routing table
ec *eventbus.Client
mu syncs.Mutex
table *bart.Table[*net.Interface] // IPv4 routing table
ec *eventbus.Client
hits int
misses int
}
func (rc *routeCache) subscribeToNetworkChanges(eventBus *eventbus.Bus, logf logger.Logf) {
func (rc *routeCache) stats() (hits, misses, total int) {
rc.mu.Lock()
defer rc.mu.Unlock()
return rc.hits, rc.misses, rc.table.Size()
}
func (rc *routeCache) subscribeToNetworkChanges(eventBus *eventbus.Bus, logf logger.Logf) {
rc.mu.Lock()
if rc.ec != nil {
rc.ec.Close()
}
rc.ec = eventBus.Client("routeCache")
rc.mu.Unlock()
eventbus.SubscribeFunc(rc.ec, func(cd netmon.ChangeDelta) {
if cd.RebindLikelyRequired {
logf("netns: routeCache: major clearing all cached routes due to network change: %v", cd)
rc.ClearAllCachedRoutes()
rc.Reset()
}
})
logf("netns: routeCache: subscribed to network change events")
@ -400,10 +404,12 @@ func (rc *routeCache) lookupCachedRoute(addr netip.Addr) *net.Interface {
rc.mu.Lock()
defer rc.mu.Unlock()
iface, ok := rc.tableForAddr(addr).Lookup(addr)
iface, ok := rc.table.Lookup(addr)
if !ok {
rc.misses++
return nil
}
rc.hits++
return iface
}
@ -415,15 +421,13 @@ func (rc *routeCache) setCachedRoute(addr netip.Addr, iface *net.Interface) {
func (rc *routeCache) setCachedRoutePrefix(prefix netip.Prefix, iface *net.Interface) {
rc.mu.Lock()
defer rc.mu.Unlock()
addr := prefix.Addr()
rc.tableForAddr(addr).Insert(prefix, iface)
rc.table.Insert(prefix, iface)
}
func (rc *routeCache) clearCachedRoutePrefix(prefix netip.Prefix) {
rc.mu.Lock()
defer rc.mu.Unlock()
addr := prefix.Addr()
rc.tableForAddr(addr).Delete(prefix)
rc.table.Delete(prefix)
}
func (rc *routeCache) ClearCachedRoute(addr netip.Addr) {
@ -431,12 +435,14 @@ func (rc *routeCache) ClearCachedRoute(addr netip.Addr) {
rc.clearCachedRoutePrefix(prefix)
}
func (rc *routeCache) ClearAllCachedRoutes() {
func (rc *routeCache) Reset() {
rc.mu.Lock()
defer rc.mu.Unlock()
rc.v4 = new(bart.Table[*net.Interface])
rc.v6 = new(bart.Table[*net.Interface])
rc.hits = 0
rc.misses = 0
rc.table = new(bart.Table[*net.Interface])
}
func addrBits(addr netip.Addr) int {
@ -445,10 +451,3 @@ func addrBits(addr netip.Addr) int {
}
return 32
}
func (rc *routeCache) tableForAddr(addr netip.Addr) *bart.Table[*net.Interface] {
if addr.Is6() {
return rc.v6
}
return rc.v4
}

@ -262,7 +262,7 @@ func TestGlobalRouteCache(t *testing.T) {
routeCache.setCachedRoute(addr3, iface3)
// Clear all
routeCache.ClearAllCachedRoutes()
routeCache.Reset()
// Verify all are gone
if got := routeCache.lookupCachedRoute(addr1); got != nil {
@ -295,33 +295,6 @@ func TestGlobalRouteCache(t *testing.T) {
t.Errorf("after overwrite: lookupCachedRoute(%v) = %v, want %v", addr, got, iface2)
}
})
t.Run("IPv4 and IPv6 are separate", func(t *testing.T) {
routeCache := NewRouteCache()
addr4 := netip.MustParseAddr("10.0.1.5")
addr6 := netip.MustParseAddr("2001:db8::1")
routeCache.setCachedRoute(addr4, iface1)
routeCache.setCachedRoute(addr6, iface2)
// Verify both are stored independently
if got := routeCache.lookupCachedRoute(addr4); got != iface1 {
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr4, got, iface1)
}
if got := routeCache.lookupCachedRoute(addr6); got != iface2 {
t.Errorf("lookupCachedRoute(%v) = %v, want %v", addr6, got, iface2)
}
// Clear IPv4, verify IPv6 remains
routeCache.ClearCachedRoute(addr4)
if got := routeCache.lookupCachedRoute(addr4); got != nil {
t.Errorf("after clear v4: lookupCachedRoute(%v) = %v, want nil", addr4, got)
}
if got := routeCache.lookupCachedRoute(addr6); got != iface2 {
t.Errorf("after clear v4: lookupCachedRoute(%v) = %v, want %v", addr6, got, iface2)
}
})
}
func hookInterfaces(t *testing.T, ifaces []net.Interface) {
@ -338,28 +311,28 @@ func hookDefaultInterfaces(t *testing.T) {
}
var (
iface1 net.Interface = net.Interface{
interfaceEth0 net.Interface = net.Interface{
Index: 1,
MTU: 1500,
Name: "eth0",
HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x55},
Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning,
}
iface2 net.Interface = net.Interface{
interfaceWlan0 net.Interface = net.Interface{
Index: 2,
MTU: 1500,
Name: "wlan0",
HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x66},
Flags: net.FlagUp | net.FlagBroadcast | net.FlagMulticast | net.FlagRunning,
}
iface3 net.Interface = net.Interface{
interfaceEth1 net.Interface = net.Interface{
Index: 3,
MTU: 1500,
Name: "eth1",
HardwareAddr: net.HardwareAddr{0x00, 0x11, 0x22, 0x33, 0x44, 0x77},
Flags: net.FlagBroadcast | net.FlagMulticast,
}
allTestIfs = []net.Interface{iface1, iface2, iface3}
allTestIfs = []net.Interface{interfaceEth0, interfaceWlan0, interfaceEth1}
)
func TestFindInterfaceThatCanReach(t *testing.T) {
@ -379,7 +352,7 @@ func TestFindInterfaceThatCanReach(t *testing.T) {
// Pre-populate cache
addr := netip.MustParseAddr("8.8.8.8")
cache.setCachedRoute(addr, &iface2)
cache.setCachedRoute(addr, &interfaceWlan0)
// Hook should never be called when cache hits
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
@ -472,11 +445,11 @@ func TestFindInterfaceThatCanReach(t *testing.T) {
// Cache 10.0.0.0/8 -> eth0
prefix1 := netip.MustParsePrefix("10.0.0.0/8")
cache.setCachedRoutePrefix(prefix1, &iface1)
cache.setCachedRoutePrefix(prefix1, &interfaceEth0)
// Cache 10.0.1.0/24 -> wlan0
prefix2 := netip.MustParsePrefix("10.0.1.0/24")
cache.setCachedRoutePrefix(prefix2, &iface2)
cache.setCachedRoutePrefix(prefix2, &interfaceWlan0)
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
t.Error("should use cache, not probe")
@ -521,19 +494,18 @@ func TestFindInterfaceThatCanReach(t *testing.T) {
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
switch iface.Index {
case iface1.Index: // eth0 - returns immediately
case interfaceEth0.Index: // eth0 - returns immediately
return nil
case iface2.Index: // wlan0 - waits for signal
case interfaceWlan0.Index: // wlan0 - waits for signal
<-wlan0Done
return nil
case iface3.Index: // eth1 - waits for signal
case interfaceEth1.Index: // eth1 - waits for signal
<-eth1Done
return nil
}
return errors.New("unknown interface")
}
defer func() {
// Now signal the slower interfaces to complete
close(wlan0Done)
close(eth1Done)
}()
@ -667,7 +639,7 @@ func TestFindInterfaceThatCanReach(t *testing.T) {
// Pre-populate IPv6 cache
addr6 := netip.MustParseAddr("2001:4860:4860::8888")
cache.setCachedRoute(addr6, &iface3)
cache.setCachedRoute(addr6, &interfaceEth1)
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
t.Error("should use cache for IPv6")
@ -697,8 +669,8 @@ func TestFindInterfaceThatCanReach(t *testing.T) {
addr4 := netip.MustParseAddr("8.8.8.8")
addr6 := netip.MustParseAddr("2001:4860:4860::8888")
cache.setCachedRoute(addr4, &iface1)
cache.setCachedRoute(addr6, &iface2)
cache.setCachedRoute(addr4, &interfaceEth0)
cache.setCachedRoute(addr6, &interfaceWlan0)
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
t.Error("should use cache")
@ -756,7 +728,7 @@ func TestFindInterfaceThatCanReach(t *testing.T) {
// Manually cache a /16 subnet
prefix := netip.MustParsePrefix("192.168.0.0/16")
cache.setCachedRoutePrefix(prefix, &iface1)
cache.setCachedRoutePrefix(prefix, &interfaceEth0)
reachabilityHook = func(iface *net.Interface, hpn HostPortNetwork) error {
t.Error("should use cached subnet")

@ -392,7 +392,8 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error)
// TODO: there's probably a better place for this
sockstats.SetNetMon(e.netMon)
netns.SetGlobalRouteCache(netns.NewRouteCache(), e.eventBus, logf)
rc := netns.NewRouteCache()
netns.Configure(netns.NewOpts(rc, e.eventBus, tunName, logf))
logf("link state: %+v", e.netMon.InterfaceState())

Loading…
Cancel
Save