diff --git a/util/linuxfw/nftables.go b/util/linuxfw/nftables.go index 8bf99a963..5c2723440 100644 --- a/util/linuxfw/nftables.go +++ b/util/linuxfw/nftables.go @@ -105,6 +105,29 @@ func DebugNetfilter(logf logger.Logf) error { // detectNetfilter returns the number of nftables rules present in the system. func detectNetfilter() (int, error) { + // Frist try creating a dummy postrouting chain. Emperically, we have + // noticed that on some devices there is partial nftables support and the + // kernel rejects some chains that are valid on other devices. This is a + // workaround to detect that case. + // + // This specifically allows us to run in on GKE nodes using COS images which + // have partial nftables support (as of 2023-10-18). When we try to create a + // dummy postrouting chain, we get an error like: + // add chain: conn.Receive: netlink receive: no such file or directory + nft, err := newNfTablesRunner(logger.Discard) + if err != nil { + return 0, FWModeNotSupportedError{ + Mode: FirewallModeNfTables, + Err: fmt.Errorf("cannot create nftables runner: %w", err), + } + } + if err := nft.createDummyPostroutingChains(); err != nil { + return 0, FWModeNotSupportedError{ + Mode: FirewallModeNfTables, + Err: err, + } + } + conn, err := nftables.New() if err != nil { return 0, FWModeNotSupportedError{ @@ -129,6 +152,7 @@ func detectNetfilter() (int, error) { } validRules += len(rules) } + return validRules, nil } diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index 0d438d9f5..bc1eecd9c 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" + "tailscale.com/types/ptr" ) const ( @@ -316,8 +317,33 @@ func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { return n.conn.Flush() } -// createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family. -func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { +// deleteTableIfExists deletes a nftables table via connection c if it exists +// within the given family. +func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error { + t, err := getTableIfExists(c, family, name) + if err != nil { + return fmt.Errorf("get table: %w", err) + } + if t == nil { + // Table does not exist, so nothing to delete. + return nil + } + c.DelTable(t) + if err := c.Flush(); err != nil { + if t, err = getTableIfExists(c, family, name); t == nil && err == nil { + // Check if the table still exists. If it does not, then the error + // is due to the table not existing, so we can ignore it. Maybe a + // concurrent process deleted the table. + return nil + } + return fmt.Errorf("del table: %w", err) + } + return nil +} + +// getTableIfExists returns the table with the given name from the given family +// if it exists. If none match, it returns (nil, nil). +func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { tables, err := c.ListTables() if err != nil { return nil, fmt.Errorf("get tables: %w", err) @@ -327,7 +353,17 @@ func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name s return table, nil } } + return nil, nil +} +// createTableIfNotExist creates a nftables table via connection c if it does +// not exist within the given family. +func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { + if t, err := getTableIfExists(c, family, name); err != nil { + return nil, fmt.Errorf("get table: %w", err) + } else if t != nil { + return t, nil + } t := c.AddTable(&nftables.Table{ Family: family, Name: name, @@ -365,24 +401,6 @@ func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*n return nil, errorChainNotFound{table.Name, name} } -// getChainsFromTable returns all chains from the given table. -func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) { - chains, err := c.ListChainsOfTableFamily(table.Family) - if err != nil { - return nil, fmt.Errorf("list chains: %w", err) - } - - var ret []*nftables.Chain - for _, chain := range chains { - // Table family is already checked so table name is unique - if chain.Table.Name == table.Name { - ret = append(ret, chain) - } - } - - return ret, nil -} - // isTSChain reports whether `name` begins with "ts-" (and is thus a // Tailscale-managed chain). func isTSChain(name string) bool { @@ -804,6 +822,43 @@ func (n *nftablesRunner) AddChains() error { return n.conn.Flush() } +// These are dummy chains and tables we create to detect if nftables is +// available. We create them, then delete them. If we can create and delete +// them, then we can use nftables. If we can't, then we assume that we're +// running on a system that doesn't support nftables. See +// createDummyPostroutingChains. +const ( + tsDummyChainName = "ts-test-postrouting" + tsDummyTableName = "ts-test-nat" +) + +// createDummyPostroutingChains creates dummy postrouting chains in netfilter +// via netfilter via nftables, as a last resort measure to detect that nftables +// can be used. It cleans up the dummy chains after creation. +func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) { + polAccept := ptr.To(nftables.ChainPolicyAccept) + for _, table := range n.getNATTables() { + nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName) + if err != nil { + return fmt.Errorf("create nat table: %w", err) + } + defer func(fm nftables.TableFamily) { + if err := deleteTableIfExists(n.conn, table.Proto, tsDummyTableName); err != nil && retErr == nil { + retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err) + } + }(table.Proto) + + table.Nat = nat + if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil { + return fmt.Errorf("create %q chain: %w", tsDummyChainName, err) + } + if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil { + return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err) + } + } + return nil +} + // deleteChainIfExists deletes a chain if it exists. func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error { chain, err := getChainFromTable(c, table, name) diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index b6ff44e7f..6b09317b9 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -851,6 +851,26 @@ func (t *testFWDetector) nftDetect() (int, error) { return t.nftRuleCount, t.nftErr } +// TestCreateDummyPostroutingChains tests that on a system with nftables +// available, the function does not return an error and that the dummy +// postrouting chains are cleaned up. +func TestCreateDummyPostroutingChains(t *testing.T) { + conn := newSysConn(t) + runner := newFakeNftablesRunner(t, conn) + if err := runner.createDummyPostroutingChains(); err != nil { + t.Fatalf("createDummyPostroutingChains() failed: %v", err) + } + for _, table := range runner.getNATTables() { + nt, err := getTableIfExists(conn, table.Proto, tsDummyTableName) + if err != nil { + t.Fatalf("getTableIfExists() failed: %v", err) + } + if nt != nil { + t.Fatalf("expected table to be nil, got %v", nt) + } + } +} + func TestPickFirewallModeFromInstalledRules(t *testing.T) { tests := []struct { name string