diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index 4d46ea104..a4d65857a 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -13,6 +13,7 @@ import ( "net" "net/netip" "reflect" + "strings" "github.com/google/nftables" "github.com/google/nftables/expr" @@ -26,12 +27,16 @@ const ( chainNamePostrouting = "ts-postrouting" ) +// chainTypeRegular is an nftables chain that does not apply to a hook. +const chainTypeRegular = "" + type chainInfo struct { table *nftables.Table name string chainType nftables.ChainType chainHook *nftables.ChainHook chainPriority *nftables.ChainPriority + chainPolicy *nftables.ChainPolicy } type nftable struct { @@ -40,6 +45,21 @@ type nftable struct { Nat *nftables.Table } +// nftablesRunner implements a netfilterRunner using the netlink based nftables +// library. As nftables allows for arbitrary tables and chains, there is a need +// to follow conventions in order to integrate well with a surrounding +// ecosystem. The rules installed by nftablesRunner have the following +// properties: +// - Install rules that intend to take precedence over rules installed by +// other software. Tailscale provides packet filtering for tailnet traffic +// inside the daemon based on the tailnet ACL rules. +// - As nftables "accept" is not final, rules from high priority tables (low +// numbers) will fall through to lower priority tables (high numbers). In +// order to effectively be 'final', we install "jump" rules into conventional +// tables and chains that will reach an accept verdict inside those tables. +// - The table and chain conventions followed here are those used by +// `iptables-nft` and `ufw`, so that those tools co-exist and do not +// negatively affect Tailscale function. 
type nftablesRunner struct { conn *nftables.Conn nft4 *nftable @@ -116,6 +136,11 @@ func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Ch return ret, nil } +// isTSChain returns true if the chain name starts with "ts-". +func isTSChain(name string) bool { + return strings.HasPrefix(name, "ts-") +} + // createChainIfNotExist creates a chain with the given name in the given table // if it does not exist. func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { @@ -123,8 +148,11 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) { return fmt.Errorf("get chain: %w", err) } else if err == nil { - // Chain already exists - if chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority { + // The chain already exists. If it is a TS chain, check the + // type/hook/priority, but for "conventional chains" assume they're what + // we expect (in case iptables-nft/ufw make minor behavior changes in + // the future). + if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority) { return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name) } return nil @@ -136,6 +164,7 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { Type: cinfo.chainType, Hooknum: cinfo.chainHook, Priority: cinfo.chainPriority, + Policy: cinfo.chainPolicy, }) if err := c.Flush(); err != nil { @@ -228,6 +257,10 @@ ruleLoop: } for i, e := range r.Exprs { + // Skip counter expressions, as they will not match.
+ if _, ok := e.(*expr.Counter); ok { + continue + } if !reflect.DeepEqual(e, rule.Exprs[i]) { continue ruleLoop } @@ -388,27 +421,49 @@ func (n *nftablesRunner) getNATTables() []*nftable { // AddChains creates custom Tailscale chains in netfilter via nftables // if the ts-chain doesn't already exist. func (n *nftablesRunner) AddChains() error { + polAccept := nftables.ChainPolicyAccept for _, table := range n.getTables() { - filter, err := createTableIfNotExist(n.conn, table.Proto, "ts-filter") + // Create the filter table if it doesn't exist, this table name is the same + // as the name used by iptables-nft and ufw. We install rules into the + // same conventional table so that `accept` verdicts from our jump + // chains are conclusive. + filter, err := createTableIfNotExist(n.conn, table.Proto, "filter") if err != nil { return fmt.Errorf("create table: %w", err) } table.Filter = filter - if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityRef(-1)}); err != nil { + // Adding the "conventional chains" that are used by iptables-nft and ufw. + if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil { + return fmt.Errorf("create forward chain: %w", err) + } + if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil { + return fmt.Errorf("create input chain: %w", err) + } + // Adding the tailscale chains that contain our rules. 
+ if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create forward chain: %w", err) } - if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityRef(-1)}); err != nil { + if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create input chain: %w", err) } } for _, table := range n.getNATTables() { - nat, err := createTableIfNotExist(n.conn, table.Proto, "ts-nat") + // Create the nat table if it doesn't exist, this table name is the same + // as the name used by iptables-nft and ufw. We install rules into the + // same conventional table so that `accept` verdicts from our jump + // chains are conclusive. + nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") if err != nil { return fmt.Errorf("create table: %w", err) } table.Nat = nat - if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATDest}); err != nil { + // Adding the "conventional chains" that are used by iptables-nft and ufw. + if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil { + return fmt.Errorf("create postrouting chain: %w", err) + } + // Adding the tailscale chain that contains our rules. 
+ if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create postrouting chain: %w", err) } } @@ -445,19 +500,16 @@ func (n *nftablesRunner) DelChains() error { if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil { return fmt.Errorf("delete chain: %w", err) } - n.conn.DelTable(table.Filter) } if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil { return fmt.Errorf("delete chain: %w", err) } - n.conn.DelTable(n.nft4.Nat) if n.v6NATAvailable { if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil { return fmt.Errorf("delete chain: %w", err) } - n.conn.DelTable(n.nft6.Nat) } if err := n.conn.Flush(); err != nil { @@ -467,15 +519,128 @@ func (n *nftablesRunner) DelChains() error { return nil } -// AddHooks is defined to satisfy the interface. NfTables does not require -// AddHooks, since we don't have any default tables or chains in nftables. +// createHookRule creates a rule to jump from a hooked chain to a regular chain. +func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule { + exprs := []expr.Any{ + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictJump, + Chain: toChainName, + }, + } + + rule := &nftables.Rule{ + Table: table, + Chain: fromChain, + Exprs: exprs, + } + + return rule +} + +// addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain. 
+func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { + rule := createHookRule(table, fromChain, toChainName) + _ = conn.InsertRule(rule) + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush add rule: %w", err) + } + + return nil +} + +// AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING" +// in tables and jump from those chains to tailscale chains. func (n *nftablesRunner) AddHooks() error { + conn := n.conn + + for _, table := range n.getTables() { + inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = addHookRule(conn, table.Filter, inputChain, chainNameInput) + if err != nil { + return fmt.Errorf("Addhook: %w", err) + } + forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") + if err != nil { + return fmt.Errorf("get FORWARD chain: %w", err) + } + err = addHookRule(conn, table.Filter, forwardChain, chainNameForward) + if err != nil { + return fmt.Errorf("Addhook: %w", err) + } + } + + for _, table := range n.getNATTables() { + postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting) + if err != nil { + return fmt.Errorf("Addhook: %w", err) + } + } return nil } -// DelHooks is defined to satisfy the interface. NfTables does not require -// DelHooks, since we don't have any default tables or chains in nftables. +// delHookRule deletes a rule that jumps from a hooked chain to a regular chain. 
+func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { + rule := createHookRule(table, fromChain, toChainName) + existingRule, err := findRule(conn, rule) + if err != nil { + return fmt.Errorf("Failed to find hook rule: %w", err) + } + + if existingRule == nil { + return nil + } + + _ = conn.DelRule(existingRule) + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush del hook rule: %w", err) + } + return nil +} + +// DelHooks is deleting the rules added to conventional chains to jump to tailscale chains. func (n *nftablesRunner) DelHooks(logf logger.Logf) error { + conn := n.conn + + for _, table := range n.getTables() { + inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = delHookRule(conn, table.Filter, inputChain, chainNameInput) + if err != nil { + return fmt.Errorf("delhook: %w", err) + } + forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") + if err != nil { + return fmt.Errorf("get FORWARD chain: %w", err) + } + err = delHookRule(conn, table.Filter, forwardChain, chainNameForward) + if err != nil { + return fmt.Errorf("delhook: %w", err) + } + } + + for _, table := range n.getNATTables() { + postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") + if err != nil { + return fmt.Errorf("get INPUT chain: %w", err) + } + err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting) + if err != nil { + return fmt.Errorf("delhook: %w", err) + } + } + return nil } @@ -953,25 +1118,62 @@ func (n *nftablesRunner) DelSNATRule() error { return nil } +// cleanupChain removes a jump rule from hookChainName to tsChainName, and then +// the entire chain tsChainName. Errors are logged, but attempts to remove both +// the jump rule and chain continue even if one errors. 
+func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) { + // remove the jump first, before removing the jump destination. + defaultChain, err := getChainFromTable(conn, table, hookChainName) + if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) { + logf("cleanup: did not find default chain: %s", err) + } + if !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) { + // delete hook in convention chain + _ = delHookRule(conn, table, defaultChain, tsChainName) + } + + tsChain, err := getChainFromTable(conn, table, tsChainName) + if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) { + logf("cleanup: did not find ts-chain: %s", err) + } + + if tsChain != nil { + // flush and delete ts-chain + conn.FlushChain(tsChain) + conn.DelChain(tsChain) + err = conn.Flush() + logf("cleanup: delete and flush chain %s: %s", tsChainName, err) + } +} + // NfTablesCleanUp removes all Tailscale added nftables rules. // Any errors that occur are logged to the provided logf. func NfTablesCleanUp(logf logger.Logf) { conn, err := nftables.New() if err != nil { - logf("ERROR: nftables connection: %w", err) + logf("cleanup: nftables connection: %s", err) } tables, err := conn.ListTables() // both v4 and v6 if err != nil { - logf("ERROR: list tables: %w", err) + logf("cleanup: list tables: %s", err) } for _, table := range tables { + // These table names were used briefly in 1.48.0. 
if table.Name == "ts-filter" || table.Name == "ts-nat" { conn.DelTable(table) if err := conn.Flush(); err != nil { - logf("ERROR: flush table %s: %w", table.Name, err) + logf("cleanup: flush delete table %s: %s", table.Name, err) } } + + if table.Name == "filter" { + cleanupChain(logf, conn, table, "INPUT", chainNameInput) + cleanupChain(logf, conn, table, "FORWARD", chainNameForward) + } + if table.Name == "nat" { + cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting) + } } } diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index ab4543b2d..ad068957e 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -101,6 +101,48 @@ func newTestConn(t *testing.T, want [][]byte) *nftables.Conn { return conn } +func TestInsertHookRule(t *testing.T) { + proto := nftables.TableFamilyIPv4 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip ts-filter-test + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip ts-filter-test ts-input-test { type filter hook input priority 0 \; } + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x03\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), + // nft add chain ip ts-filter-test ts-jumpto + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x0e\x00\x03\x00\x74\x73\x2d\x6a\x75\x6d\x70\x74\x6f\x00\x00\x00"), + // nft add rule ip ts-filter-test ts-input-test counter jump ts-jumpto +
[]byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x02\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x70\x00\x04\x80\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x40\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x20\x00\x02\x80\x1c\x00\x02\x80\x08\x00\x01\x00\xff\xff\xff\xfd\x0e\x00\x02\x00\x74\x73\x2d\x6a\x75\x6d\x70\x74\x6f\x00\x00\x00"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + testConn := newTestConn(t, want) + table := testConn.AddTable(&nftables.Table{ + Family: proto, + Name: "ts-filter-test", + }) + + fromchain := testConn.AddChain(&nftables.Chain{ + Name: "ts-input-test", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }) + + tochain := testConn.AddChain(&nftables.Chain{ + Name: "ts-jumpto", + Table: table, + }) + + err := addHookRule(testConn, table, fromchain, tochain.Name) + if err != nil { + t.Fatal(err) + } + +} + func TestInsertLoopbackRule(t *testing.T) { proto := nftables.TableFamilyIPv4 want := [][]byte{ @@ -461,8 +503,8 @@ func TestAddAndDelNetfilterChains(t *testing.T) { t.Fatalf("list chains failed: %v", err) } - if len(chainsV4) != 3 { - t.Fatalf("len(chainsV4) = %d, want 3", len(chainsV4)) + if len(chainsV4) != 6 { + t.Fatalf("len(chainsV4) = %d, want 6", len(chainsV4)) } chainsV6, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv6) @@ -470,8 +512,8 @@ func TestAddAndDelNetfilterChains(t *testing.T) { t.Fatalf("list chains failed: %v", err) } - if len(chainsV6) != 3 { - t.Fatalf("len(chainsV6) = %d, want 3", len(chainsV6)) + if len(chainsV6) != 6 { + t.Fatalf("len(chainsV6) = %d, want 6", len(chainsV6)) } runner.DelChains() @@ -788,3 +830,87 @@ func 
TestNFTAddAndDelLoopbackRule(t *testing.T) { t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules)) } } + +func TestNFTAddAndDelHookRule(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip(t.Name(), " requires privileges to create a namespace in order to run") + return + } + + conn := newSysConn(t) + runner := newFakeNftablesRunner(t, conn) + runner.AddChains() + defer runner.DelChains() + runner.AddHooks() + + forwardChain, err := getChainFromTable(conn, runner.nft4.Filter, "FORWARD") + if err != nil { + t.Fatalf("failed to get forwardChain: %v", err) + } + + forwardChainRules, err := conn.GetRules(forwardChain.Table, forwardChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(forwardChainRules) != 1 { + t.Fatalf("expected 1 rule in FORWARD chain, got %v", len(forwardChainRules)) + } + + inputChain, err := getChainFromTable(conn, runner.nft4.Filter, "INPUT") + if err != nil { + t.Fatalf("failed to get inputChain: %v", err) + } + + inputChainRules, err := conn.GetRules(inputChain.Table, inputChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(inputChainRules) != 1 { + t.Fatalf("expected 1 rule in INPUT chain, got %v", len(inputChainRules)) + } + + postroutingChain, err := getChainFromTable(conn, runner.nft4.Nat, "POSTROUTING") + if err != nil { + t.Fatalf("failed to get postroutingChain: %v", err) + } + + postroutingChainRules, err := conn.GetRules(postroutingChain.Table, postroutingChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(postroutingChainRules) != 1 { + t.Fatalf("expected 1 rule in POSTROUTING chain, got %v", len(postroutingChainRules)) + } + + runner.DelHooks(t.Logf) + + forwardChainRules, err = conn.GetRules(forwardChain.Table, forwardChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(forwardChainRules) != 0 { + t.Fatalf("expected 0 rule in FORWARD chain, got %v", len(forwardChainRules)) + } + + inputChainRules, err = 
conn.GetRules(inputChain.Table, inputChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(inputChainRules) != 0 { + t.Fatalf("expected 0 rule in INPUT chain, got %v", len(inputChainRules)) + } + + postroutingChainRules, err = conn.GetRules(postroutingChain.Table, postroutingChain) + if err != nil { + t.Fatalf("failed to get rules: %v", err) + } + + if len(postroutingChainRules) != 0 { + t.Fatalf("expected 0 rule in POSTROUTING chain, got %v", len(postroutingChainRules)) + } +}