util/linuxfw: add additional nftable detection logic

We were previously using the netlink API to see if there are chains/rules that
already exist. This works fine in environments where there is either full
nftable support or no support at all. However, we have identified certain
environments which have partial nftable support and the only feasible way of
detecting such an environment is to try to create some of the chains that we
need.

This adds a check to create a dummy postrouting chain which is immediately
deleted. The goal of the check is to ensure we are able to use nftables and
that it won't error out later. This check is only done in the path where we
detected that the system has no preexisting nftable rules.

Updates #5621
Updates #8555
Updates #8762

Signed-off-by: Maisem Ali <maisem@tailscale.com>
pull/9838/head
Maisem Ali 1 year ago committed by Maisem Ali
parent b47cf04624
commit c3a8e63100

@ -105,6 +105,29 @@ func DebugNetfilter(logf logger.Logf) error {
// detectNetfilter returns the number of nftables rules present in the system.
func detectNetfilter() (int, error) {
// Frist try creating a dummy postrouting chain. Emperically, we have
// noticed that on some devices there is partial nftables support and the
// kernel rejects some chains that are valid on other devices. This is a
// workaround to detect that case.
//
// This specifically allows us to run in on GKE nodes using COS images which
// have partial nftables support (as of 2023-10-18). When we try to create a
// dummy postrouting chain, we get an error like:
// add chain: conn.Receive: netlink receive: no such file or directory
nft, err := newNfTablesRunner(logger.Discard)
if err != nil {
return 0, FWModeNotSupportedError{
Mode: FirewallModeNfTables,
Err: fmt.Errorf("cannot create nftables runner: %w", err),
}
}
if err := nft.createDummyPostroutingChains(); err != nil {
return 0, FWModeNotSupportedError{
Mode: FirewallModeNfTables,
Err: err,
}
}
conn, err := nftables.New()
if err != nil {
return 0, FWModeNotSupportedError{
@ -129,6 +152,7 @@ func detectNetfilter() (int, error) {
}
validRules += len(rules)
}
return validRules, nil
}

@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"tailscale.com/types/ptr"
)
const (
@ -316,8 +317,33 @@ func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
return n.conn.Flush()
}
// createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family.
func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
// deleteTableIfExists deletes a nftables table via connection c if it exists
// within the given family.
func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error {
t, err := getTableIfExists(c, family, name)
if err != nil {
return fmt.Errorf("get table: %w", err)
}
if t == nil {
// Table does not exist, so nothing to delete.
return nil
}
c.DelTable(t)
if err := c.Flush(); err != nil {
if t, err = getTableIfExists(c, family, name); t == nil && err == nil {
// Check if the table still exists. If it does not, then the error
// is due to the table not existing, so we can ignore it. Maybe a
// concurrent process deleted the table.
return nil
}
return fmt.Errorf("del table: %w", err)
}
return nil
}
// getTableIfExists returns the table with the given name from the given family
// if it exists. If none match, it returns (nil, nil).
func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
tables, err := c.ListTables()
if err != nil {
return nil, fmt.Errorf("get tables: %w", err)
@ -327,7 +353,17 @@ func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name s
return table, nil
}
}
return nil, nil
}
// createTableIfNotExist creates a nftables table via connection c if it does
// not exist within the given family.
func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) {
if t, err := getTableIfExists(c, family, name); err != nil {
return nil, fmt.Errorf("get table: %w", err)
} else if t != nil {
return t, nil
}
t := c.AddTable(&nftables.Table{
Family: family,
Name: name,
@ -365,24 +401,6 @@ func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*n
return nil, errorChainNotFound{table.Name, name}
}
// getChainsFromTable returns all chains from the given table.
func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) {
chains, err := c.ListChainsOfTableFamily(table.Family)
if err != nil {
return nil, fmt.Errorf("list chains: %w", err)
}
var ret []*nftables.Chain
for _, chain := range chains {
// Table family is already checked so table name is unique
if chain.Table.Name == table.Name {
ret = append(ret, chain)
}
}
return ret, nil
}
// isTSChain reports whether `name` begins with "ts-" (and is thus a
// Tailscale-managed chain).
func isTSChain(name string) bool {
@ -804,6 +822,43 @@ func (n *nftablesRunner) AddChains() error {
return n.conn.Flush()
}
// These are dummy chains and tables we create to detect if nftables is
// available. We create them, then delete them. If we can create and delete
// them, then we can use nftables. If we can't, then we assume that we're
// running on a system that doesn't support nftables. See
// createDummyPostroutingChains.
const (
tsDummyChainName = "ts-test-postrouting"
tsDummyTableName = "ts-test-nat"
)
// createDummyPostroutingChains creates dummy postrouting chains in netfilter
// via netfilter via nftables, as a last resort measure to detect that nftables
// can be used. It cleans up the dummy chains after creation.
func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
polAccept := ptr.To(nftables.ChainPolicyAccept)
for _, table := range n.getNATTables() {
nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
if err != nil {
return fmt.Errorf("create nat table: %w", err)
}
defer func(fm nftables.TableFamily) {
if err := deleteTableIfExists(n.conn, table.Proto, tsDummyTableName); err != nil && retErr == nil {
retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err)
}
}(table.Proto)
table.Nat = nat
if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil {
return fmt.Errorf("create %q chain: %w", tsDummyChainName, err)
}
if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil {
return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err)
}
}
return nil
}
// deleteChainIfExists deletes a chain if it exists.
func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error {
chain, err := getChainFromTable(c, table, name)

@ -851,6 +851,26 @@ func (t *testFWDetector) nftDetect() (int, error) {
return t.nftRuleCount, t.nftErr
}
// TestCreateDummyPostroutingChains tests that on a system with nftables
// available, the function does not return an error and that the dummy
// postrouting chains are cleaned up.
func TestCreateDummyPostroutingChains(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn)
if err := runner.createDummyPostroutingChains(); err != nil {
t.Fatalf("createDummyPostroutingChains() failed: %v", err)
}
for _, table := range runner.getNATTables() {
nt, err := getTableIfExists(conn, table.Proto, tsDummyTableName)
if err != nil {
t.Fatalf("getTableIfExists() failed: %v", err)
}
if nt != nil {
t.Fatalf("expected table to be nil, got %v", nt)
}
}
}
func TestPickFirewallModeFromInstalledRules(t *testing.T) {
tests := []struct {
name string

Loading…
Cancel
Save