From 05a1f5bf7169a3a662fa457621a4d0b0b885aecc Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Tue, 10 Oct 2023 18:26:52 -0700 Subject: [PATCH] util/linuxfw: move detection logic Just a refactor to consolidate the firewall detection logic in a single package so that it can be reused in a later commit by containerboot. Updates #9310 Signed-off-by: Maisem Ali --- util/linuxfw/detector.go | 110 ++++++++++++++++++++ util/linuxfw/iptables.go | 4 +- util/linuxfw/iptables_runner.go | 8 +- util/linuxfw/linuxfw_unsupported.go | 6 +- util/linuxfw/nftables.go | 4 +- util/linuxfw/nftables_runner.go | 64 +++++++++++- util/linuxfw/nftables_runner_test.go | 61 +++++++++++ wgengine/router/router_linux.go | 146 +-------------------------- wgengine/router/router_linux_test.go | 64 +----------- 9 files changed, 248 insertions(+), 219 deletions(-) create mode 100644 util/linuxfw/detector.go diff --git a/util/linuxfw/detector.go b/util/linuxfw/detector.go new file mode 100644 index 000000000..17b47e2b5 --- /dev/null +++ b/util/linuxfw/detector.go @@ -0,0 +1,110 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "tailscale.com/envknob" + "tailscale.com/hostinfo" + "tailscale.com/types/logger" + "tailscale.com/version/distro" +) + +func detectFirewallMode(logf logger.Logf) FirewallMode { + if distro.Get() == distro.Gokrazy { + // Reduce startup logging on gokrazy. There's no way to do iptables on + // gokrazy anyway. + logf("GoKrazy should use nftables.") + hostinfo.SetFirewallMode("nft-gokrazy") + return FirewallModeNfTables + } + + envMode := envknob.String("TS_DEBUG_FIREWALL_MODE") + // We now use iptables as default and have "auto" and "nftables" as + // options for people to test further. + switch envMode { + case "auto": + return pickFirewallModeFromInstalledRules(logf, linuxFWDetector{}) + case "nftables": + logf("envknob TS_DEBUG_FIREWALL_MODE=nftables set") + hostinfo.SetFirewallMode("nft-forced") + return FirewallModeNfTables + case "iptables": + logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set") + hostinfo.SetFirewallMode("ipt-forced") + default: + logf("default choosing iptables") + hostinfo.SetFirewallMode("ipt-default") + } + return FirewallModeIPTables +} + +// tableDetector abstracts helpers to detect the firewall mode. +// It is implemented for testing purposes. +type tableDetector interface { + iptDetect() (int, error) + nftDetect() (int, error) +} + +type linuxFWDetector struct{} + +// iptDetect returns the number of iptables rules in the current namespace. +func (l linuxFWDetector) iptDetect() (int, error) { + return detectIptables() +} + +// nftDetect returns the number of nftables rules in the current namespace. +func (l linuxFWDetector) nftDetect() (int, error) { + return detectNetfilter() +} + +// pickFirewallModeFromInstalledRules returns the firewall mode to use based on +// the environment and the system's capabilities. +func pickFirewallModeFromInstalledRules(logf logger.Logf, det tableDetector) FirewallMode { + if distro.Get() == distro.Gokrazy { + // Reduce startup logging on gokrazy. There's no way to do iptables on + // gokrazy anyway. + return FirewallModeNfTables + } + iptAva, nftAva := true, true + iptRuleCount, err := det.iptDetect() + if err != nil { + logf("detect iptables rule: %v", err) + iptAva = false + } + nftRuleCount, err := det.nftDetect() + if err != nil { + logf("detect nftables rule: %v", err) + nftAva = false + } + logf("nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount) + switch { + case nftRuleCount > 0 && iptRuleCount == 0: + logf("nftables is currently in use") + hostinfo.SetFirewallMode("nft-inuse") + return FirewallModeNfTables + case iptRuleCount > 0 && nftRuleCount == 0: + logf("iptables is currently in use") + hostinfo.SetFirewallMode("ipt-inuse") + return FirewallModeIPTables + case nftAva: + // if both iptables and nftables are available but + // neither/both are currently used, use nftables. + logf("nftables is available") + hostinfo.SetFirewallMode("nft") + return FirewallModeNfTables + case iptAva: + logf("iptables is available") + hostinfo.SetFirewallMode("ipt") + return FirewallModeIPTables + default: + // if neither iptables nor nftables are available, use iptablesRunner as a dummy + // runner which exists but won't do anything. Creating iptablesRunner errors only + // if the iptables command is missing or doesn’t support "--version", as long as it + // can determine a version then it’ll carry on. + hostinfo.SetFirewallMode("ipt-fb") + return FirewallModeIPTables + } +} diff --git a/util/linuxfw/iptables.go b/util/linuxfw/iptables.go index 3cc612d03..7231c83fe 100644 --- a/util/linuxfw/iptables.go +++ b/util/linuxfw/iptables.go @@ -23,13 +23,13 @@ func DebugIptables(logf logger.Logf) error { return nil } -// DetectIptables returns the number of iptables rules that are present in the +// detectIptables returns the number of iptables rules that are present in the // system, ignoring the default "ACCEPT" rule present in the standard iptables // chains. // // It only returns an error when there is no iptables binary, or when iptables -S // fails. In all other cases, it returns the number of non-default rules. -func DetectIptables() (int, error) { +func detectIptables() (int, error) { // run "iptables -S" to get the list of rules using iptables // exec.Command returns an error if the binary is not found cmd := exec.Command("iptables", "-S") diff --git a/util/linuxfw/iptables_runner.go b/util/linuxfw/iptables_runner.go index 3c4199ece..d703190bc 100644 --- a/util/linuxfw/iptables_runner.go +++ b/util/linuxfw/iptables_runner.go @@ -45,11 +45,11 @@ func checkIP6TablesExists() error { return nil } -// NewIPTablesRunner constructs a NetfilterRunner that programs iptables rules. +// newIPTablesRunner constructs a NetfilterRunner that programs iptables rules. // If the underlying iptables library fails to initialize, that error is // returned. The runner probes for IPv6 support once at initialization time and // if not found, no IPv6 rules will be modified for the lifetime of the runner. -func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { +func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { ipt4, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) if err != nil { return nil, err @@ -79,12 +79,12 @@ func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { return &iptablesRunner{ipt4, ipt6, supportsV6, supportsV6NAT}, nil } -// HasIPV6 returns true if the system supports IPv6. +// HasIPV6 reports true if the system supports IPv6. func (i *iptablesRunner) HasIPV6() bool { return i.v6Available } -// HasIPV6NAT returns true if the system supports IPv6 NAT. +// HasIPV6NAT reports true if the system supports IPv6 NAT. func (i *iptablesRunner) HasIPV6NAT() bool { return i.v6NATAvailable } diff --git a/util/linuxfw/linuxfw_unsupported.go b/util/linuxfw/linuxfw_unsupported.go index 4c6029af1..003d4bdff 100644 --- a/util/linuxfw/linuxfw_unsupported.go +++ b/util/linuxfw/linuxfw_unsupported.go @@ -25,16 +25,16 @@ func DebugNetfilter(logf logger.Logf) error { } // DetectNetfilter is not supported on non-Linux platforms. -func DetectNetfilter() (int, error) { +func detectNetfilter() (int, error) { return 0, ErrUnsupported } // DebugIptables is not supported on non-Linux platforms. -func DebugIptables(logf logger.Logf) error { +func debugIptables(logf logger.Logf) error { return ErrUnsupported } // DetectIptables is not supported on non-Linux platforms. -func DetectIptables() (int, error) { +func detectIptables() (int, error) { return 0, ErrUnsupported } diff --git a/util/linuxfw/nftables.go b/util/linuxfw/nftables.go index afe6dfa6e..8bf99a963 100644 --- a/util/linuxfw/nftables.go +++ b/util/linuxfw/nftables.go @@ -103,8 +103,8 @@ func DebugNetfilter(logf logger.Logf) error { return nil } -// DetectNetfilter returns the number of nftables rules present in the system. -func DetectNetfilter() (int, error) { +// detectNetfilter returns the number of nftables rules present in the system. +func detectNetfilter() (int, error) { conn, err := nftables.New() if err != nil { return 0, FWModeNotSupportedError{ diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index d7588107c..d87610dda 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -175,9 +175,67 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { return nil } -// NewNfTablesRunner creates a new nftablesRunner without guaranteeing +// NetfilterRunner abstracts helpers to run netfilter commands. It is +// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner. +type NetfilterRunner interface { + // AddLoopbackRule adds a rule to permit loopback traffic to addr. This rule + // is added only if it does not already exist. + AddLoopbackRule(addr netip.Addr) error + + // DelLoopbackRule removes the rule added by AddLoopbackRule. + DelLoopbackRule(addr netip.Addr) error + + // AddHooks adds rules to conventional chains like "FORWARD", "INPUT" and + // "POSTROUTING" to jump from those chains to tailscale chains. + AddHooks() error + + // DelHooks deletes rules added by AddHooks. + DelHooks(logf logger.Logf) error + + // AddChains creates custom Tailscale chains. + AddChains() error + + // DelChains removes chains added by AddChains. + DelChains() error + + // AddBase adds rules reused by different other rules. + AddBase(tunname string) error + + // DelBase removes rules added by AddBase. + DelBase() error + + // AddSNATRule adds the netfilter rule to SNAT incoming traffic over + // the Tailscale interface destined for local subnets. An error is + // returned if the rule already exists. + AddSNATRule() error + + // DelSNATRule removes the rule added by AddSNATRule. + DelSNATRule() error + + // HasIPV6 reports true if the system supports IPv6. + HasIPV6() bool + + // HasIPV6NAT reports true if the system supports IPv6 NAT. + HasIPV6NAT() bool +} + +// New creates a NetfilterRunner using either nftables or iptables. +// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set. +func New(logf logger.Logf) (NetfilterRunner, error) { + mode := detectFirewallMode(logf) + switch mode { + case FirewallModeIPTables: + return newIPTablesRunner(logf) + case FirewallModeNfTables: + return newNfTablesRunner(logf) + default: + return nil, fmt.Errorf("unknown firewall mode %v", mode) + } +} + +// newNfTablesRunner creates a new nftablesRunner without guaranteeing // the existence of the tables and chains. -func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) { +func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) { conn, err := nftables.New() if err != nil { return nil, fmt.Errorf("nftables connection: %w", err) @@ -231,7 +289,7 @@ func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, err } } -// HasIPV6 returns true if the system supports IPv6. +// HasIPV6 reports true if the system supports IPv6. func (n *nftablesRunner) HasIPV6() bool { return n.v6Available } diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index b8c66363f..1a451238b 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -7,6 +7,7 @@ package linuxfw import ( "bytes" + "errors" "fmt" "net/netip" "os" @@ -946,3 +947,63 @@ func TestNFTAddAndDelHookRule(t *testing.T) { t.Fatalf("expected 0 rule in POSTROUTING chain, got %v", len(postroutingChainRules)) } } + +type testFWDetector struct { + iptRuleCount, nftRuleCount int + iptErr, nftErr error +} + +func (t *testFWDetector) iptDetect() (int, error) { + return t.iptRuleCount, t.iptErr +} + +func (t *testFWDetector) nftDetect() (int, error) { + return t.nftRuleCount, t.nftErr +} + +func TestPickFirewallModeFromInstalledRules(t *testing.T) { + tests := []struct { + name string + det *testFWDetector + want FirewallMode + }{ + { + name: "using iptables legacy", + det: &testFWDetector{iptRuleCount: 1}, + want: FirewallModeIPTables, + }, + { + name: "using nftables", + det: &testFWDetector{nftRuleCount: 1}, + want: FirewallModeNfTables, + }, + { + name: "using both iptables and nftables", + det: &testFWDetector{iptRuleCount: 2, nftRuleCount: 2}, + want: FirewallModeNfTables, + }, + { + name: "not using any firewall, both available", + det: &testFWDetector{}, + want: FirewallModeNfTables, + }, + { + name: "not using any firewall, iptables available only", + det: &testFWDetector{iptRuleCount: 1, nftErr: errors.New("nft error")}, + want: FirewallModeIPTables, + }, + { + name: "not using any firewall, nftables available only", + det: &testFWDetector{iptErr: errors.New("iptables error"), nftRuleCount: 1}, + want: FirewallModeNfTables, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := pickFirewallModeFromInstalledRules(t.Logf, tt.det) + if got != tt.want { + t.Errorf("chooseFireWallMode() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/wgengine/router/router_linux.go b/wgengine/router/router_linux.go index 8a7273bd2..17bf38693 100644 --- a/wgengine/router/router_linux.go +++ b/wgengine/router/router_linux.go @@ -22,7 +22,6 @@ import ( "golang.org/x/sys/unix" "golang.org/x/time/rate" "tailscale.com/envknob" - "tailscale.com/hostinfo" "tailscale.com/net/netmon" "tailscale.com/types/logger" "tailscale.com/types/preftype" @@ -37,145 +36,6 @@ const ( netfilterOn = preftype.NetfilterOn ) -// netfilterRunner abstracts helpers to run netfilter commands. It is -// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner. -type netfilterRunner interface { - AddLoopbackRule(addr netip.Addr) error - DelLoopbackRule(addr netip.Addr) error - AddHooks() error - DelHooks(logf logger.Logf) error - AddChains() error - DelChains() error - AddBase(tunname string) error - DelBase() error - AddSNATRule() error - DelSNATRule() error - - HasIPV6() bool - HasIPV6NAT() bool -} - -// tableDetector abstracts helpers to detect the firewall mode. -// It is implemented for testing purposes. -type tableDetector interface { - iptDetect() (int, error) - nftDetect() (int, error) -} - -type linuxFWDetector struct{} - -// iptDetect returns the number of iptables rules in the current namespace. -func (l *linuxFWDetector) iptDetect() (int, error) { - return linuxfw.DetectIptables() -} - -// nftDetect returns the number of nftables rules in the current namespace. -func (l *linuxFWDetector) nftDetect() (int, error) { - return linuxfw.DetectNetfilter() -} - -// chooseFireWallMode returns the firewall mode to use based on the -// environment and the system's capabilities. -func chooseFireWallMode(logf logger.Logf, det tableDetector) linuxfw.FirewallMode { - if distro.Get() == distro.Gokrazy { - // Reduce startup logging on gokrazy. There's no way to do iptables on - // gokrazy anyway. - return linuxfw.FirewallModeNfTables - } - iptAva, nftAva := true, true - iptRuleCount, err := det.iptDetect() - if err != nil { - logf("detect iptables rule: %v", err) - iptAva = false - } - nftRuleCount, err := det.nftDetect() - if err != nil { - logf("detect nftables rule: %v", err) - nftAva = false - } - logf("nftables rule count: %d, iptables rule count: %d", nftRuleCount, iptRuleCount) - switch { - case nftRuleCount > 0 && iptRuleCount == 0: - logf("nftables is currently in use") - hostinfo.SetFirewallMode("nft-inuse") - return linuxfw.FirewallModeNfTables - case iptRuleCount > 0 && nftRuleCount == 0: - logf("iptables is currently in use") - hostinfo.SetFirewallMode("ipt-inuse") - return linuxfw.FirewallModeIPTables - case nftAva: - // if both iptables and nftables are available but - // neither/both are currently used, use nftables. - logf("nftables is available") - hostinfo.SetFirewallMode("nft") - return linuxfw.FirewallModeNfTables - case iptAva: - logf("iptables is available") - hostinfo.SetFirewallMode("ipt") - return linuxfw.FirewallModeIPTables - default: - // if neither iptables nor nftables are available, use iptablesRunner as a dummy - // runner which exists but won't do anything. Creating iptablesRunner errors only - // if the iptables command is missing or doesn’t support "--version", as long as it - // can determine a version then it’ll carry on. - hostinfo.SetFirewallMode("ipt-fb") - return linuxfw.FirewallModeIPTables - } -} - -// newNetfilterRunner creates a netfilterRunner using either nftables or iptables. -// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set. -func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) { - tableDetector := &linuxFWDetector{} - var mode linuxfw.FirewallMode - - // We now use iptables as default and have "auto" and "nftables" as - // options for people to test further. - switch { - case distro.Get() == distro.Gokrazy: - // Reduce startup logging on gokrazy. There's no way to do iptables on - // gokrazy anyway. - logf("GoKrazy should use nftables.") - hostinfo.SetFirewallMode("nft-gokrazy") - mode = linuxfw.FirewallModeNfTables - case envknob.String("TS_DEBUG_FIREWALL_MODE") == "nftables": - logf("envknob TS_DEBUG_FIREWALL_MODE=nftables set") - hostinfo.SetFirewallMode("nft-forced") - mode = linuxfw.FirewallModeNfTables - case envknob.String("TS_DEBUG_FIREWALL_MODE") == "auto": - mode = chooseFireWallMode(logf, tableDetector) - case envknob.String("TS_DEBUG_FIREWALL_MODE") == "iptables": - logf("envknob TS_DEBUG_FIREWALL_MODE=iptables set") - hostinfo.SetFirewallMode("ipt-forced") - mode = linuxfw.FirewallModeIPTables - default: - logf("default choosing iptables") - hostinfo.SetFirewallMode("ipt-default") - mode = linuxfw.FirewallModeIPTables - } - - var nfr netfilterRunner - var err error - switch mode { - case linuxfw.FirewallModeIPTables: - logf("using iptables") - nfr, err = linuxfw.NewIPTablesRunner(logf) - if err != nil { - return nil, err - } - case linuxfw.FirewallModeNfTables: - logf("using nftables") - nfr, err = linuxfw.NewNfTablesRunner(logf) - if err != nil { - return nil, err - } - default: - return nil, fmt.Errorf("unknown firewall mode: %v", mode) - } - - return nfr, nil -} - type linuxRouter struct { closed atomic.Bool logf func(fmt string, args ...any) @@ -200,7 +60,7 @@ type linuxRouter struct { // ipPolicyPrefBase is the base priority at which ip rules are installed. ipPolicyPrefBase int - nfr netfilterRunner + nfr linuxfw.NetfilterRunner cmd commandRunner } @@ -210,7 +70,7 @@ func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Moni return nil, err } - nfr, err := newNetfilterRunner(logf) + nfr, err := linuxfw.New(logf) if err != nil { return nil, err } @@ -222,7 +82,7 @@ func newUserspaceRouter(logf logger.Logf, tunDev tun.Device, netMon *netmon.Moni return newUserspaceRouterAdvanced(logf, tunname, netMon, nfr, cmd) } -func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, nfr netfilterRunner, cmd commandRunner) (Router, error) { +func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon.Monitor, nfr linuxfw.NetfilterRunner, cmd commandRunner) (Router, error) { r := &linuxRouter{ logf: logf, tunname: tunname, diff --git a/wgengine/router/router_linux_test.go b/wgengine/router/router_linux_test.go index 761cdc44b..d77708f51 100644 --- a/wgengine/router/router_linux_test.go +++ b/wgengine/router/router_linux_test.go @@ -372,7 +372,7 @@ type fakeIPTablesRunner struct { //we always assume ipv6 and ipv6 nat are enabled when testing } -func newIPTablesRunner(t *testing.T) netfilterRunner { +func newIPTablesRunner(t *testing.T) linuxfw.NetfilterRunner { return &fakeIPTablesRunner{ t: t, ipt4: map[string][]string{ @@ -603,7 +603,7 @@ type fakeOS struct { rules []string //This test tests on the router level, so we will not bother //with using iptables or nftables, chose the simpler one. - nfr netfilterRunner + nfr linuxfw.NetfilterRunner } func NewFakeOS(t *testing.T) *fakeOS { @@ -1063,63 +1063,3 @@ func adjustFwmask(t *testing.T, s string) string { return fwmaskAdjustRe.ReplaceAllString(s, "$1") } - -type testFWDetector struct { - iptRuleCount, nftRuleCount int - iptErr, nftErr error -} - -func (t *testFWDetector) iptDetect() (int, error) { - return t.iptRuleCount, t.iptErr -} - -func (t *testFWDetector) nftDetect() (int, error) { - return t.nftRuleCount, t.nftErr -} - -func TestChooseFireWallMode(t *testing.T) { - tests := []struct { - name string - det *testFWDetector - want linuxfw.FirewallMode - }{ - { - name: "using iptables legacy", - det: &testFWDetector{iptRuleCount: 1}, - want: linuxfw.FirewallModeIPTables, - }, - { - name: "using nftables", - det: &testFWDetector{nftRuleCount: 1}, - want: linuxfw.FirewallModeNfTables, - }, - { - name: "using both iptables and nftables", - det: &testFWDetector{iptRuleCount: 2, nftRuleCount: 2}, - want: linuxfw.FirewallModeNfTables, - }, - { - name: "not using any firewall, both available", - det: &testFWDetector{}, - want: linuxfw.FirewallModeNfTables, - }, - { - name: "not using any firewall, iptables available only", - det: &testFWDetector{iptRuleCount: 1, nftErr: errors.New("nft error")}, - want: linuxfw.FirewallModeIPTables, - }, - { - name: "not using any firewall, nftables available only", - det: &testFWDetector{iptErr: errors.New("iptables error"), nftRuleCount: 1}, - want: linuxfw.FirewallModeNfTables, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := chooseFireWallMode(t.Logf, tt.det) - if got != tt.want { - t.Errorf("chooseFireWallMode() = %v, want %v", got, tt.want) - } - }) - } -}