util/linuxfw: move detection logic

Just a refactor to consolidate the firewall detection logic in a single
package so that it can be reused in a later commit by containerboot.

Updates #9310

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/9739/head
Maisem Ali 8 months ago committed by Maisem Ali
parent 56c0a75ea9
commit 05a1f5bf71

@ -0,0 +1,110 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build linux
package linuxfw
import (
"tailscale.com/envknob"
"tailscale.com/hostinfo"
"tailscale.com/types/logger"
"tailscale.com/version/distro"
)
func detectFirewallMode(logf logger.Logf) FirewallMode {
if distro.Get() == distro.Gokrazy {
// Reduce startup logging on gokrazy. There's no way to do iptables on
// gokrazy anyway.
logf("GoKrazy should use nftables.")
hostinfo.SetFirewallMode("nft-gokrazy")
return FirewallModeNfTables
}
envMode := envknob.String("TS_DEBUG_FIREWALL_MODE")
// We now use iptables as default and have "auto" and "nftables" as
// options for people to test further.
switch envMode {
case "auto":
return pickFirewallModeFromInstalledRules(logf, linuxFWDetector{})
case "nftables":
logf("envknob TS_DEBUG_FIREWALL_MODE=nftables set")
hostinfo.SetFirewallMode("nft-forced")
return FirewallModeNfTables
case "iptables":
logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set")
hostinfo.SetFirewallMode("ipt-forced")
default:
logf("default choosing iptables")
hostinfo.SetFirewallMode("ipt-default")
}
return FirewallModeIPTables
}
// tableDetector abstracts helpers to detect the firewall mode.
// It is implemented for testing purposes.
type tableDetector interface {
iptDetect() (int, error)
nftDetect() (int, error)
}
type linuxFWDetector struct{}
// iptDetect returns the number of iptables rules in the current namespace.
func (l linuxFWDetector) iptDetect() (int, error) {
return detectIptables()
}
// nftDetect returns the number of nftables rules in the current namespace.
func (l linuxFWDetector) nftDetect() (int, error) {
return detectNetfilter()
}
// pickFirewallModeFromInstalledRules returns the firewall mode to use based on
// the environment and the system's capabilities.
func pickFirewallModeFromInstalledRules(logf logger.Logf, det tableDetector) FirewallMode {
if distro.Get() == distro.Gokrazy {
// Reduce startup logging on gokrazy. There's no way to do iptables on
// gokrazy anyway.
return FirewallModeNfTables
}
iptAva, nftAva := true, true
iptRuleCount, err := det.iptDetect()
if err != nil {
logf("detect iptables rule: %v", err)
iptAva = false
}
nftRuleCount, err := det.nftDetect()
if err != nil {
logf("detect nftables rule: %v", err)
nftAva = false
}
logf("nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount)
switch {
case nftRuleCount > 0 && iptRuleCount == 0:
logf("nftables is currently in use")
hostinfo.SetFirewallMode("nft-inuse")
return FirewallModeNfTables
case iptRuleCount > 0 && nftRuleCount == 0:
logf("iptables is currently in use")
hostinfo.SetFirewallMode("ipt-inuse")
return FirewallModeIPTables
case nftAva:
// if both iptables and nftables are available but
// neither/both are currently used, use nftables.
logf("nftables is available")
hostinfo.SetFirewallMode("nft")
return FirewallModeNfTables
case iptAva:
logf("iptables is available")
hostinfo.SetFirewallMode("ipt")
return FirewallModeIPTables
default:
// if neither iptables nor nftables are available, use iptablesRunner as a dummy
// runner which exists but won't do anything. Creating iptablesRunner errors only
// if the iptables command is missing or doesnt support "--version", as long as it
// can determine a version then itll carry on.
hostinfo.SetFirewallMode("ipt-fb")
return FirewallModeIPTables
}
}

@ -23,13 +23,13 @@ func DebugIptables(logf logger.Logf) error {
return nil return nil
} }
// DetectIptables returns the number of iptables rules that are present in the // detectIptables returns the number of iptables rules that are present in the
// system, ignoring the default "ACCEPT" rule present in the standard iptables // system, ignoring the default "ACCEPT" rule present in the standard iptables
// chains. // chains.
// //
// It only returns an error when there is no iptables binary, or when iptables -S // It only returns an error when there is no iptables binary, or when iptables -S
// fails. In all other cases, it returns the number of non-default rules. // fails. In all other cases, it returns the number of non-default rules.
func DetectIptables() (int, error) { func detectIptables() (int, error) {
// run "iptables -S" to get the list of rules using iptables // run "iptables -S" to get the list of rules using iptables
// exec.Command returns an error if the binary is not found // exec.Command returns an error if the binary is not found
cmd := exec.Command("iptables", "-S") cmd := exec.Command("iptables", "-S")

@ -45,11 +45,11 @@ func checkIP6TablesExists() error {
return nil return nil
} }
// NewIPTablesRunner constructs a NetfilterRunner that programs iptables rules. // newIPTablesRunner constructs a NetfilterRunner that programs iptables rules.
// If the underlying iptables library fails to initialize, that error is // If the underlying iptables library fails to initialize, that error is
// returned. The runner probes for IPv6 support once at initialization time and // returned. The runner probes for IPv6 support once at initialization time and
// if not found, no IPv6 rules will be modified for the lifetime of the runner. // if not found, no IPv6 rules will be modified for the lifetime of the runner.
func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
if err != nil { if err != nil {
return nil, err return nil, err
@ -79,12 +79,12 @@ func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
return &iptablesRunner{ipt4, ipt6, supportsV6, supportsV6NAT}, nil return &iptablesRunner{ipt4, ipt6, supportsV6, supportsV6NAT}, nil
} }
// HasIPV6 returns true if the system supports IPv6. // HasIPV6 reports true if the system supports IPv6.
func (i *iptablesRunner) HasIPV6() bool { func (i *iptablesRunner) HasIPV6() bool {
return i.v6Available return i.v6Available
} }
// HasIPV6NAT returns true if the system supports IPv6 NAT. // HasIPV6NAT reports true if the system supports IPv6 NAT.
func (i *iptablesRunner) HasIPV6NAT() bool { func (i *iptablesRunner) HasIPV6NAT() bool {
return i.v6NATAvailable return i.v6NATAvailable
} }

@ -25,16 +25,16 @@ func DebugNetfilter(logf logger.Logf) error {
} }
// DetectNetfilter is not supported on non-Linux platforms. // DetectNetfilter is not supported on non-Linux platforms.
func DetectNetfilter() (int, error) { func detectNetfilter() (int, error) {
return 0, ErrUnsupported return 0, ErrUnsupported
} }
// DebugIptables is not supported on non-Linux platforms. // DebugIptables is not supported on non-Linux platforms.
func DebugIptables(logf logger.Logf) error { func debugIptables(logf logger.Logf) error {
return ErrUnsupported return ErrUnsupported
} }
// DetectIptables is not supported on non-Linux platforms. // DetectIptables is not supported on non-Linux platforms.
func DetectIptables() (int, error) { func detectIptables() (int, error) {
return 0, ErrUnsupported return 0, ErrUnsupported
} }

@ -103,8 +103,8 @@ func DebugNetfilter(logf logger.Logf) error {
return nil return nil
} }
// DetectNetfilter returns the number of nftables rules present in the system. // detectNetfilter returns the number of nftables rules present in the system.
func DetectNetfilter() (int, error) { func detectNetfilter() (int, error) {
conn, err := nftables.New() conn, err := nftables.New()
if err != nil { if err != nil {
return 0, FWModeNotSupportedError{ return 0, FWModeNotSupportedError{

@ -175,9 +175,67 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
return nil return nil
} }
// NewNfTablesRunner creates a new nftablesRunner without guaranteeing // NetfilterRunner abstracts helpers to run netfilter commands. It is
// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner.
type NetfilterRunner interface {
// AddLoopbackRule adds a rule to permit loopback traffic to addr. This rule
// is added only if it does not already exist.
AddLoopbackRule(addr netip.Addr) error
// DelLoopbackRule removes the rule added by AddLoopbackRule.
DelLoopbackRule(addr netip.Addr) error
// AddHooks adds rules to conventional chains like "FORWARD", "INPUT" and
// "POSTROUTING" to jump from those chains to tailscale chains.
AddHooks() error
// DelHooks deletes rules added by AddHooks.
DelHooks(logf logger.Logf) error
// AddChains creates custom Tailscale chains.
AddChains() error
// DelChains removes chains added by AddChains.
DelChains() error
// AddBase adds rules reused by different other rules.
AddBase(tunname string) error
// DelBase removes rules added by AddBase.
DelBase() error
// AddSNATRule adds the netfilter rule to SNAT incoming traffic over
// the Tailscale interface destined for local subnets. An error is
// returned if the rule already exists.
AddSNATRule() error
// DelSNATRule removes the rule added by AddSNATRule.
DelSNATRule() error
// HasIPV6 reports true if the system supports IPv6.
HasIPV6() bool
// HasIPV6NAT reports true if the system supports IPv6 NAT.
HasIPV6NAT() bool
}
// New creates a NetfilterRunner using either nftables or iptables.
// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set.
func New(logf logger.Logf) (NetfilterRunner, error) {
mode := detectFirewallMode(logf)
switch mode {
case FirewallModeIPTables:
return newIPTablesRunner(logf)
case FirewallModeNfTables:
return newNfTablesRunner(logf)
default:
return nil, fmt.Errorf("unknown firewall mode %v", mode)
}
}
// newNfTablesRunner creates a new nftablesRunner without guaranteeing
// the existence of the tables and chains. // the existence of the tables and chains.
func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) { func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
conn, err := nftables.New() conn, err := nftables.New()
if err != nil { if err != nil {
return nil, fmt.Errorf("nftables connection: %w", err) return nil, fmt.Errorf("nftables connection: %w", err)
@ -231,7 +289,7 @@ func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, err
} }
} }
// HasIPV6 returns true if the system supports IPv6. // HasIPV6 reports true if the system supports IPv6.
func (n *nftablesRunner) HasIPV6() bool { func (n *nftablesRunner) HasIPV6() bool {
return n.v6Available return n.v6Available
} }

@ -7,6 +7,7 @@ package linuxfw
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"net/netip" "net/netip"
"os" "os"
@ -946,3 +947,63 @@ func TestNFTAddAndDelHookRule(t *testing.T) {
t.Fatalf("expected 0 rule in POSTROUTING chain, got %v", len(postroutingChainRules)) t.Fatalf("expected 0 rule in POSTROUTING chain, got %v", len(postroutingChainRules))
} }
} }
type testFWDetector struct {
iptRuleCount, nftRuleCount int
iptErr, nftErr error
}
func (t *testFWDetector) iptDetect() (int, error) {
return t.iptRuleCount, t.iptErr
}
func (t *testFWDetector) nftDetect() (int, error) {
return t.nftRuleCount, t.nftErr
}
func TestPickFirewallModeFromInstalledRules(t *testing.T) {
tests := []struct {
name string
det *testFWDetector
want FirewallMode
}{
{
name: "using iptables legacy",
det: &testFWDetector{iptRuleCount: 1},
want: FirewallModeIPTables,
},
{
name: "using nftables",
det: &testFWDetector{nftRuleCount: 1},
want: FirewallModeNfTables,
},
{
name: "using both iptables and nftables",
det: &testFWDetector{iptRuleCount: 2, nftRuleCount: 2},
want: FirewallModeNfTables,
},
{
name: "not using any firewall, both available",
det: &testFWDetector{},
want: FirewallModeNfTables,
},
{
name: "not using any firewall, iptables available only",
det: &testFWDetector{iptRuleCount: 1, nftErr: errors.New("nft error")},
want: FirewallModeIPTables,
},
{
name: "not using any firewall, nftables available only",
det: &testFWDetector{iptErr: errors.New("iptables error"), nftRuleCount: 1},
want: FirewallModeNfTables,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := pickFirewallModeFromInstalledRules(t.Logf, tt.det)
if got != tt.want {
t.Errorf("chooseFireWallMode() = %v, want %v", got, tt.want)
}
})
}
}

@ -22,7 +22,6 @@ import (
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"golang.org/x/time/rate" "golang.org/x/time/rate"
"tailscale.com/envknob" "tailscale.com/envknob"
"tailscale.com/hostinfo"
"tailscale.com/net/netmon" "tailscale.com/net/netmon"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/preftype" "tailscale.com/types/preftype"
@ -37,145 +36,6 @@ const (
netfilterOn = preftype.NetfilterOn netfilterOn = preftype.NetfilterOn
) )
// netfilterRunner abstracts helpers to run netfilter commands. It is
// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner.
type netfilterRunner interface {
AddLoopbackRule(addr netip.Addr) error
DelLoopbackRule(addr netip.Addr) error
AddHooks() error
DelHooks(logf logger.Logf) error
AddChains() error
DelChains() error
AddBase(tunname string) error
DelBase() error
AddSNATRule() error
DelSNATRule() error
HasIPV6() bool
HasIPV6NAT() bool
}
// tableDetector abstracts helpers to detect the firewall mode.
// It is implemented for testing purposes.
type tableDetector interface {
iptDetect() (int, error)
nftDetect() (int, error)
}
type linuxFWDetector struct{}
// iptDetect returns the number of iptables rules in the current namespace.
func (l *linuxFWDetector) iptDetect() (int, error) {
return linuxfw.DetectIptables()
}
// nftDetect returns the number of nftables rules in the current namespace.
func (l *linuxFWDetector) nftDetect() (int, error) {
return linuxfw.DetectNetfilter()
}
// chooseFireWallMode returns the firewall mode to use based on the
// environment and the system's capabilities.
func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMode {
if distro.Get() == distro.Gokrazy {
// Reduce startup logging on gokrazy. There's no way to do iptables on
// gokrazy anyway.
return linuxfw.FirewallModeNfTables
}
iptAva, nftAva := true, true
iptRuleCount, err := det.iptDetect()
if err != nil {
logf("detect iptables rule: %v", err)
iptAva = false
}
nftRuleCount, err := det.nftDetect()
if err != nil {
logf("detect nftables rule: %v", err)
nftAva = false
}
logf("nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount)
switch {
case nftRuleCount > 0 && iptRuleCount == 0:
logf("nftables is currently in use")
hostinfo.SetFirewallMode("nft-inuse")
return linuxfw.FirewallModeNfTables
case iptRuleCount > 0 && nftRuleCount == 0:
logf("iptables is currently in use")
hostinfo.SetFirewallMode("ipt-inuse")
return linuxfw.FirewallModeIPTables
case nftAva:
// if both iptables and nftables are available but
// neither/both are currently used, use nftables.
logf("nftables is available")
hostinfo.SetFirewallMode("nft")
return linuxfw.FirewallModeNfTables
case iptAva:
logf("iptables is available")
hostinfo.SetFirewallMode("ipt")
return linuxfw.FirewallModeIPTables
default:
// if neither iptables nor nftables are available, use iptablesRunner as a dummy
// runner which exists but won't do anything. Creating iptablesRunner errors only
// if the iptables command is missing or doesnt support "--version", as long as it
// can determine a version then itll carry on.
hostinfo.SetFirewallMode("ipt-fb")
return linuxfw.FirewallModeIPTables
}
}
// newNetfilterRunner creates a netfilterRunner using either nftables or iptables.
// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set.
func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) {
tableDetector := &linuxFWDetector{}
var mode linuxfw.FirewallMode
// We now use iptables as default and have "auto" and "nftables" as
// options for people to test further.
switch {
case distro.Get() == distro.Gokrazy:
// Reduce startup logging on gokrazy. There's no way to do iptables on
// gokrazy anyway.
logf("GoKrazy should use nftables.")
hostinfo.SetFirewallMode("nft-gokrazy")
mode = linuxfw.FirewallModeNfTables
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "nftables":
logf("envknob TS_DEBUG_FIREWALL_MODE=nftables set")
hostinfo.SetFirewallMode("nft-forced")
mode = linuxfw.FirewallModeNfTables
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "auto":
mode = chooseFireWallMode(logf, tableDetector)
case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables":
logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set")
hostinfo.SetFirewallMode("ipt-forced")
mode = linuxfw.FirewallModeIPTables
default:
logf("default choosing iptables")
hostinfo.SetFirewallMode("ipt-default")
mode = linuxfw.FirewallModeIPTables
}
var nfr netfilterRunner
var err error
switch mode {
case linuxfw.FirewallModeIPTables:
logf("using iptables")
nfr, err = linuxfw.NewIPTablesRunner(logf)
if err != nil {
return nil, err
}
case linuxfw.FirewallModeNfTables:
logf("using nftables")
nfr, err = linuxfw.NewNfTablesRunner(logf)
if err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("unknown firewall mode: %v", mode)
}
return nfr, nil
}
type linuxRouter struct { type linuxRouter struct {
closed atomic.Bool closed atomic.Bool
logf func(fmt string, args ...any) logf func(fmt string, args ...any)
@ -200,7 +60,7 @@ type linuxRouter struct {
// ipPolicyPrefBase is the base priority at which ip rules are installed. // ipPolicyPrefBase is the base priority at which ip rules are installed.
ipPolicyPrefBase int ipPolicyPrefBase int
nfr netfilterRunner nfr linuxfw.NetfilterRunner
cmd commandRunner cmd commandRunner
} }
@ -210,7 +70,7 @@ func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Moni
return nil, err return nil, err
} }
nfr, err := newNetfilterRunner(logf) nfr, err := linuxfw.New(logf)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -222,7 +82,7 @@ func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Moni
return newUserspaceRouterAdvanced(logf, tunname, netMon, nfr, cmd) return newUserspaceRouterAdvanced(logf, tunname, netMon, nfr, cmd)
} }
func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, nfr netfilterRunner, cmd commandRunner) (Router, error) { func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, nfr linuxfw.NetfilterRunner, cmd commandRunner) (Router, error) {
r := &linuxRouter{ r := &linuxRouter{
logf: logf, logf: logf,
tunname: tunname, tunname: tunname,

@ -372,7 +372,7 @@ type fakeIPTablesRunner struct {
//we always assume ipv6 and ipv6 nat are enabled when testing //we always assume ipv6 and ipv6 nat are enabled when testing
} }
func newIPTablesRunner(t *testing.T) netfilterRunner { func newIPTablesRunner(t *testing.T) linuxfw.NetfilterRunner {
return &fakeIPTablesRunner{ return &fakeIPTablesRunner{
t: t, t: t,
ipt4: map[string][]string{ ipt4: map[string][]string{
@ -603,7 +603,7 @@ type fakeOS struct {
rules []string rules []string
//This test tests on the router level, so we will not bother //This test tests on the router level, so we will not bother
//with using iptables or nftables, chose the simpler one. //with using iptables or nftables, chose the simpler one.
nfr netfilterRunner nfr linuxfw.NetfilterRunner
} }
func NewFakeOS(t *testing.T) *fakeOS { func NewFakeOS(t *testing.T) *fakeOS {
@ -1063,63 +1063,3 @@ func adjustFwmask(t *testing.T, s string) string {
return fwmaskAdjustRe.ReplaceAllString(s, "$1") return fwmaskAdjustRe.ReplaceAllString(s, "$1")
} }
type testFWDetector struct {
iptRuleCount, nftRuleCount int
iptErr, nftErr error
}
func (t *testFWDetector) iptDetect() (int, error) {
return t.iptRuleCount, t.iptErr
}
func (t *testFWDetector) nftDetect() (int, error) {
return t.nftRuleCount, t.nftErr
}
func TestChooseFireWallMode(t *testing.T) {
tests := []struct {
name string
det *testFWDetector
want linuxfw.FirewallMode
}{
{
name: "using iptables legacy",
det: &testFWDetector{iptRuleCount: 1},
want: linuxfw.FirewallModeIPTables,
},
{
name: "using nftables",
det: &testFWDetector{nftRuleCount: 1},
want: linuxfw.FirewallModeNfTables,
},
{
name: "using both iptables and nftables",
det: &testFWDetector{iptRuleCount: 2, nftRuleCount: 2},
want: linuxfw.FirewallModeNfTables,
},
{
name: "not using any firewall, both available",
det: &testFWDetector{},
want: linuxfw.FirewallModeNfTables,
},
{
name: "not using any firewall, iptables available only",
det: &testFWDetector{iptRuleCount: 1, nftErr: errors.New("nft error")},
want: linuxfw.FirewallModeIPTables,
},
{
name: "not using any firewall, nftables available only",
det: &testFWDetector{iptErr: errors.New("iptables error"), nftRuleCount: 1},
want: linuxfw.FirewallModeNfTables,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := chooseFireWallMode(t.Logf, tt.det)
if got != tt.want {
t.Errorf("chooseFireWallMode() = %v, want %v", got, tt.want)
}
})
}
}

Loading…
Cancel
Save