diff --git a/util/linuxfw/iptables_runner.go b/util/linuxfw/iptables_runner.go index dea764aca..e2e04af9c 100644 --- a/util/linuxfw/iptables_runner.go +++ b/util/linuxfw/iptables_runner.go @@ -87,7 +87,7 @@ func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { } supportsV6Filter = checkSupportsV6Filter(ipt6, logf) supportsV6NAT = checkSupportsV6NAT(ipt6, logf) - logf("v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT) + logf("netfilter running in iptables mode v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT) } return &iptablesRunner{ ipt4: ipt4, diff --git a/util/linuxfw/linuxfw.go b/util/linuxfw/linuxfw.go index 0f47328f9..2e8c1330b 100644 --- a/util/linuxfw/linuxfw.go +++ b/util/linuxfw/linuxfw.go @@ -104,12 +104,19 @@ func getTailscaleSubnetRouteMark() []byte { return []byte{0x00, 0x04, 0x00, 0x00} } +// checkIPv6ForTest can be set in tests. +var checkIPv6ForTest func(logger.Logf) error + // checkIPv6 checks whether the system appears to have a working IPv6 // network stack. It returns an error explaining what looks wrong or // missing. It does not check that IPv6 is currently functional or // that there's a global address, just that the system would support // IPv6 if it were on an IPv6 network. func CheckIPv6(logf logger.Logf) error { + if f := checkIPv6ForTest; f != nil { + return f(logf) + } + _, err := os.Stat("/proc/sys/net/ipv6") if os.IsNotExist(err) { return err diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index 1bb508a98..aac28209f 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -41,8 +41,9 @@ type chainInfo struct { chainPolicy *nftables.ChainPolicy } +// nftable contains nat and filter tables for the given IP family (Proto). type nftable struct { - Proto nftables.TableFamily + Proto nftables.TableFamily // IPv4 or IPv6 Filter *nftables.Table Nat *nftables.Table } @@ -69,11 +70,10 @@ type nftable struct { // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains type nftablesRunner struct { conn *nftables.Conn - nft4 *nftable - nft6 *nftable + nft4 *nftable // IPv4 tables + nft6 *nftable // IPv6 tables - v6Available bool - v6NATAvailable bool + v6Available bool // whether the host supports IPv6 } func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) { @@ -598,6 +598,10 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) { if err != nil { return nil, fmt.Errorf("nftables connection: %w", err) } + return newNfTablesRunnerWithConn(logf, conn), nil +} + +func newNfTablesRunnerWithConn(logf logger.Logf, conn *nftables.Conn) *nftablesRunner { nft4 := &nftable{Proto: nftables.TableFamilyIPv4} v6err := CheckIPv6(logf) @@ -609,8 +613,8 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) { if supportsV6 { nft6 = &nftable{Proto: nftables.TableFamilyIPv6} - logf("v6nat availability: true") } + logf("netfilter running in nftables mode, v6 = %v", supportsV6) // TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables @@ -619,7 +623,7 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) { nft4: nft4, nft6: nft6, v6Available: supportsV6, - }, nil + } } // newLoadSaddrExpr creates a new nftables expression that loads the source @@ -837,24 +841,15 @@ func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error { return n.conn.Flush() } -// getTables gets the available nftable in nftables runner. +// getTables returns tables for IP families that this host was determined to +// support (either IPv4 and IPv6 or just IPv4). func (n *nftablesRunner) getTables() []*nftable { - if n.v6Available { + if n.HasIPV6() { return []*nftable{n.nft4, n.nft6} } return []*nftable{n.nft4} } -// getNATTables gets the available nftable in nftables runner. -// If the system does not support IPv6 NAT, only the IPv4 nftable -// will be returned. -func (n *nftablesRunner) getNATTables() []*nftable { - if n.v6NATAvailable { - return n.getTables() - } - return []*nftable{n.nft4} -} - // AddChains creates custom Tailscale chains in netfilter via nftables // if the ts-chain doesn't already exist. func (n *nftablesRunner) AddChains() error { @@ -883,9 +878,7 @@ func (n *nftablesRunner) AddChains() error { if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create input chain: %w", err) } - } - for _, table := range n.getNATTables() { // Create the nat table if it doesn't exist, this table name is the same // as the name used by iptables-nft and ufw. We install rules into the // same conventional table so that `accept` verdicts from our jump @@ -923,7 +916,7 @@ const ( // can be used. It cleans up the dummy chains after creation. func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) { polAccept := ptr.To(nftables.ChainPolicyAccept) - for _, table := range n.getNATTables() { + for _, table := range n.getTables() { nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName) if err != nil { return fmt.Errorf("create nat table: %w", err) @@ -980,7 +973,7 @@ func (n *nftablesRunner) DelChains() error { return fmt.Errorf("delete chain: %w", err) } - if n.v6NATAvailable { + if n.HasIPV6NAT() { if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil { return fmt.Errorf("delete chain: %w", err) } @@ -1046,9 +1039,7 @@ func (n *nftablesRunner) AddHooks() error { if err != nil { return fmt.Errorf("Addhook: %w", err) } - } - for _, table := range n.getNATTables() { postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") if err != nil { return fmt.Errorf("get INPUT chain: %w", err) @@ -1102,9 +1093,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error { if err != nil { return fmt.Errorf("delhook: %w", err) } - } - for _, table := range n.getNATTables() { postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") if err != nil { return fmt.Errorf("get INPUT chain: %w", err) @@ -1612,9 +1601,7 @@ func (n *nftablesRunner) DelBase() error { return fmt.Errorf("get forward chain: %v", err) } conn.FlushChain(forwardChain) - } - for _, table := range n.getNATTables() { postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { return fmt.Errorf("get postrouting chain v4: %v", err) @@ -1684,7 +1671,7 @@ func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, cha func (n *nftablesRunner) AddSNATRule() error { conn := n.conn - for _, table := range n.getNATTables() { + for _, table := range n.getTables() { chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { return fmt.Errorf("get postrouting chain v4: %w", err) @@ -1727,7 +1714,7 @@ func (n *nftablesRunner) DelSNATRule() error { &expr.Masq{}, } - for _, table := range n.getNATTables() { + for _, table := range n.getTables() { chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { return fmt.Errorf("get postrouting chain v4: %w", err) diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index a545d3c3c..ebf514c79 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -20,6 +20,8 @@ import ( "github.com/mdlayher/netlink" "github.com/vishvananda/netns" "tailscale.com/net/tsaddr" + "tailscale.com/tstest" + "tailscale.com/types/logger" ) // nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing @@ -503,19 +505,6 @@ func cleanupSysConn(t *testing.T, ns netns.NsHandle) { } } -func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner { - nft4 := &nftable{Proto: nftables.TableFamilyIPv4} - nft6 := &nftable{Proto: nftables.TableFamilyIPv6} - - return &nftablesRunner{ - conn: conn, - nft4: nft4, - nft6: nft6, - v6Available: true, - v6NATAvailable: true, - } -} - func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) { t.Helper() got, err := conn.ListChainsOfTableFamily(fam) @@ -526,42 +515,76 @@ func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wa t.Fatalf("len(got) = %d, want %d", len(got), wantCount) } } - -func TestAddAndDelNetfilterChains(t *testing.T) { - conn := newSysConn(t) - checkChains(t, conn, nftables.TableFamilyIPv4, 0) - checkChains(t, conn, nftables.TableFamilyIPv6, 0) - - runner := newFakeNftablesRunner(t, conn) - if err := runner.AddChains(); err != nil { - t.Fatalf("runner.AddChains() failed: %v", err) - } - - tables, err := conn.ListTables() +func checkTables(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) { + t.Helper() + got, err := conn.ListTablesOfFamily(fam) if err != nil { - t.Fatalf("conn.ListTables() failed: %v", err) + t.Fatalf("conn.ListTablesOfFamily(%v) failed: %v", fam, err) } - - if len(tables) != 4 { - t.Fatalf("len(tables) = %d, want 4", len(tables)) + if len(got) != wantCount { + t.Fatalf("len(got) = %d, want %d", len(got), wantCount) } +} - checkChains(t, conn, nftables.TableFamilyIPv4, 6) - checkChains(t, conn, nftables.TableFamilyIPv6, 6) +func TestAddAndDelNetfilterChains(t *testing.T) { + type test struct { + hostHasIPv6 bool + initIPv4ChainCount int + initIPv6ChainCount int + ipv4TableCount int + ipv6TableCount int + ipv4ChainCount int + ipv6ChainCount int + ipv4ChainCountPostDelete int + ipv6ChainCountPostDelete int + } + tests := []test{ + { + hostHasIPv6: true, + initIPv4ChainCount: 0, + initIPv6ChainCount: 0, + ipv4TableCount: 2, + ipv6TableCount: 2, + ipv4ChainCount: 6, + ipv6ChainCount: 6, + ipv4ChainCountPostDelete: 3, + ipv6ChainCountPostDelete: 3, + }, + { // host without IPv6 support + ipv4TableCount: 2, + ipv4ChainCount: 6, + ipv4ChainCountPostDelete: 3, + }} + for _, tt := range tests { + t.Logf("running a test case for IPv6 support: %v", tt.hostHasIPv6) + conn := newSysConn(t) + runner := newFakeNftablesRunnerWithConn(t, conn, tt.hostHasIPv6) + + // Check that we start off with no chains. + checkChains(t, conn, nftables.TableFamilyIPv4, tt.initIPv4ChainCount) + checkChains(t, conn, nftables.TableFamilyIPv6, tt.initIPv6ChainCount) - runner.DelChains() + if err := runner.AddChains(); err != nil { + t.Fatalf("runner.AddChains() failed: %v", err) + } - // The default chains should still be present. - checkChains(t, conn, nftables.TableFamilyIPv4, 3) - checkChains(t, conn, nftables.TableFamilyIPv6, 3) + // Check that the amount of tables for each IP family is as expected. + checkTables(t, conn, nftables.TableFamilyIPv4, tt.ipv4TableCount) + checkTables(t, conn, nftables.TableFamilyIPv6, tt.ipv6TableCount) - tables, err = conn.ListTables() - if err != nil { - t.Fatalf("conn.ListTables() failed: %v", err) - } + // Check that the amount of chains for each IP family is as expected. + checkChains(t, conn, nftables.TableFamilyIPv4, tt.ipv4ChainCount) + checkChains(t, conn, nftables.TableFamilyIPv6, tt.ipv6ChainCount) + + if err := runner.DelChains(); err != nil { + t.Fatalf("runner.DelChains() failed: %v", err) + } - if len(tables) != 4 { - t.Fatalf("len(tables) = %d, want 4", len(tables)) + // Test that the tables as well as the default chains are still present. + checkChains(t, conn, nftables.TableFamilyIPv4, tt.ipv4ChainCountPostDelete) + checkChains(t, conn, nftables.TableFamilyIPv6, tt.ipv6ChainCountPostDelete) + checkTables(t, conn, nftables.TableFamilyIPv4, tt.ipv4TableCount) + checkTables(t, conn, nftables.TableFamilyIPv6, tt.ipv6TableCount) } } @@ -665,7 +688,8 @@ func checkChainRules(t *testing.T, conn *nftables.Conn, chain *nftables.Chain, w func TestNFTAddAndDelNetfilterBase(t *testing.T) { conn := newSysConn(t) - runner := newFakeNftablesRunner(t, conn) + runner := newFakeNftablesRunnerWithConn(t, conn, true) + if err := runner.AddChains(); err != nil { t.Fatalf("AddChains() failed: %v", err) } @@ -759,7 +783,7 @@ func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nf func TestNFTAddAndDelLoopbackRule(t *testing.T) { conn := newSysConn(t) - runner := newFakeNftablesRunner(t, conn) + runner := newFakeNftablesRunnerWithConn(t, conn, true) if err := runner.AddChains(); err != nil { t.Fatalf("AddChains() failed: %v", err) } @@ -817,7 +841,7 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) { func TestNFTAddAndDelHookRule(t *testing.T) { conn := newSysConn(t) - runner := newFakeNftablesRunner(t, conn) + runner := newFakeNftablesRunnerWithConn(t, conn, true) if err := runner.AddChains(); err != nil { t.Fatalf("AddChains() failed: %v", err) } @@ -868,11 +892,11 @@ func (t *testFWDetector) nftDetect() (int, error) { // postrouting chains are cleaned up. func TestCreateDummyPostroutingChains(t *testing.T) { conn := newSysConn(t) - runner := newFakeNftablesRunner(t, conn) + runner := newFakeNftablesRunnerWithConn(t, conn, true) if err := runner.createDummyPostroutingChains(); err != nil { t.Fatalf("createDummyPostroutingChains() failed: %v", err) } - for _, table := range runner.getNATTables() { + for _, table := range runner.getTables() { nt, err := getTableIfExists(conn, table.Proto, tsDummyTableName) if err != nil { t.Fatalf("getTableIfExists() failed: %v", err) @@ -929,3 +953,14 @@ func TestPickFirewallModeFromInstalledRules(t *testing.T) { }) } } + +func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bool) *nftablesRunner { + t.Helper() + if !hasIPv6 { + tstest.Replace(t, &checkIPv6ForTest, func(logger.Logf) error { + return errors.New("test: no IPv6") + }) + + } + return newNfTablesRunnerWithConn(t.Logf, conn) +}