@ -13,6 +13,7 @@ import (
"net"
"net/netip"
"reflect"
"strings"
"github.com/google/nftables"
"github.com/google/nftables/expr"
@ -26,12 +27,16 @@ const (
chainNamePostrouting = "ts-postrouting"
)
// chainTypeRegular is an nftables chain that does not apply to a hook.
const chainTypeRegular = ""
type chainInfo struct {
table * nftables . Table
name string
chainType nftables . ChainType
chainHook * nftables . ChainHook
chainPriority * nftables . ChainPriority
chainPolicy * nftables . ChainPolicy
}
type nftable struct {
@ -40,6 +45,21 @@ type nftable struct {
Nat * nftables . Table
}
// nftablesRunner implements a netfilterRunner using the netlink based nftables
// library. As nftables allows for arbitrary tables and chains, there is a need
// to follow conventions in order to integrate well with a surrounding
// ecosystem. The rules installed by nftablesRunner have the following
// properties:
// - Install rules that intend to take precedence over rules installed by
// other software. Tailscale provides packet filtering for tailnet traffic
// inside the daemon based on the tailnet ACL rules.
// - As nftables "accept" is not final, rules from high priority tables (low
// numbers) will fall through to lower priority tables (high numbers). In
// order to effectively be 'final', we install "jump" rules into conventional
// tables and chains that will reach an accept verdict inside those tables.
// - The table and chain conventions followed here are those used by
// `iptables-nft` and `ufw`, so that those tools co-exist and do not
// negatively affect Tailscale function.
type nftablesRunner struct {
conn * nftables . Conn
nft4 * nftable
@ -116,6 +136,11 @@ func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Ch
return ret , nil
}
// isTSChain retruns true if the chain name starts with ts
func isTSChain ( name string ) bool {
return strings . HasPrefix ( name , "ts-" )
}
// createChainIfNotExist creates a chain with the given name in the given table
// if it does not exist.
func createChainIfNotExist ( c * nftables . Conn , cinfo chainInfo ) error {
@ -123,8 +148,11 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
if err != nil && ! errors . Is ( err , errorChainNotFound { cinfo . table . Name , cinfo . name } ) {
return fmt . Errorf ( "get chain: %w" , err )
} else if err == nil {
// Chain already exists
if chain . Type != cinfo . chainType || chain . Hooknum != cinfo . chainHook || chain . Priority != cinfo . chainPriority {
// The chain already exists. If it is a TS chain, check the
// type/hook/priority, but for "conventional chains" assume they're what
// we expect (in case iptables-nft/ufw make minor behavior changes in
// the future).
if isTSChain ( chain . Name ) && ( chain . Type != cinfo . chainType || chain . Hooknum != cinfo . chainHook || chain . Priority != cinfo . chainPriority ) {
return fmt . Errorf ( "chain %s already exists with different type/hook/priority" , cinfo . name )
}
return nil
@ -136,6 +164,7 @@ func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error {
Type : cinfo . chainType ,
Hooknum : cinfo . chainHook ,
Priority : cinfo . chainPriority ,
Policy : cinfo . chainPolicy ,
} )
if err := c . Flush ( ) ; err != nil {
@ -228,6 +257,10 @@ ruleLoop:
}
for i , e := range r . Exprs {
// Skip counter expressions, as they will not match.
if _ , ok := e . ( * expr . Counter ) ; ok {
continue
}
if ! reflect . DeepEqual ( e , rule . Exprs [ i ] ) {
continue ruleLoop
}
@ -388,27 +421,49 @@ func (n *nftablesRunner) getNATTables() []*nftable {
// AddChains creates custom Tailscale chains in netfilter via nftables
// if the ts-chain doesn't already exist.
func ( n * nftablesRunner ) AddChains ( ) error {
polAccept := nftables . ChainPolicyAccept
for _ , table := range n . getTables ( ) {
filter , err := createTableIfNotExist ( n . conn , table . Proto , "ts-filter" )
// Create the filter table if it doesn't exist, this table name is the same
// as the name used by iptables-nft and ufw. We install rules into the
// same conventional table so that `accept` verdicts from our jump
// chains are conclusive.
filter , err := createTableIfNotExist ( n . conn , table . Proto , "filter" )
if err != nil {
return fmt . Errorf ( "create table: %w" , err )
}
table . Filter = filter
if err = createChainIfNotExist ( n . conn , chainInfo { filter , chainNameForward , nftables . ChainTypeFilter , nftables . ChainHookForward , nftables . ChainPriorityRef ( - 1 ) } ) ; err != nil {
// Adding the "conventional chains" that are used by iptables-nft and ufw.
if err = createChainIfNotExist ( n . conn , chainInfo { filter , "FORWARD" , nftables . ChainTypeFilter , nftables . ChainHookForward , nftables . ChainPriorityFilter , & polAccept } ) ; err != nil {
return fmt . Errorf ( "create forward chain: %w" , err )
}
if err = createChainIfNotExist ( n . conn , chainInfo { filter , "INPUT" , nftables . ChainTypeFilter , nftables . ChainHookInput , nftables . ChainPriorityFilter , & polAccept } ) ; err != nil {
return fmt . Errorf ( "create input chain: %w" , err )
}
// Adding the tailscale chains that contain our rules.
if err = createChainIfNotExist ( n . conn , chainInfo { filter , chainNameForward , chainTypeRegular , nil , nil , nil } ) ; err != nil {
return fmt . Errorf ( "create forward chain: %w" , err )
}
if err = createChainIfNotExist ( n . conn , chainInfo { filter , chainNameInput , nftables . ChainTypeFilter , nftables . ChainHookInput , nftables . ChainPriorityRef ( - 1 ) } ) ; err != nil {
if err = createChainIfNotExist ( n . conn , chainInfo { filter , chainNameInput , chainTypeRegular, nil , nil , nil } ) ; err != nil {
return fmt . Errorf ( "create input chain: %w" , err )
}
}
for _ , table := range n . getNATTables ( ) {
nat , err := createTableIfNotExist ( n . conn , table . Proto , "ts-nat" )
// Create the nat table if it doesn't exist, this table name is the same
// as the name used by iptables-nft and ufw. We install rules into the
// same conventional table so that `accept` verdicts from our jump
// chains are conclusive.
nat , err := createTableIfNotExist ( n . conn , table . Proto , "nat" )
if err != nil {
return fmt . Errorf ( "create table: %w" , err )
}
table . Nat = nat
if err = createChainIfNotExist ( n . conn , chainInfo { nat , chainNamePostrouting , nftables . ChainTypeNAT , nftables . ChainHookPostrouting , nftables . ChainPriorityNATDest } ) ; err != nil {
// Adding the "conventional chains" that are used by iptables-nft and ufw.
if err = createChainIfNotExist ( n . conn , chainInfo { nat , "POSTROUTING" , nftables . ChainTypeNAT , nftables . ChainHookPostrouting , nftables . ChainPriorityNATSource , & polAccept } ) ; err != nil {
return fmt . Errorf ( "create postrouting chain: %w" , err )
}
// Adding the tailscale chain that contains our rules.
if err = createChainIfNotExist ( n . conn , chainInfo { nat , chainNamePostrouting , chainTypeRegular , nil , nil , nil } ) ; err != nil {
return fmt . Errorf ( "create postrouting chain: %w" , err )
}
}
@ -445,19 +500,16 @@ func (n *nftablesRunner) DelChains() error {
if err := deleteChainIfExists ( n . conn , table . Filter , chainNameInput ) ; err != nil {
return fmt . Errorf ( "delete chain: %w" , err )
}
n . conn . DelTable ( table . Filter )
}
if err := deleteChainIfExists ( n . conn , n . nft4 . Nat , chainNamePostrouting ) ; err != nil {
return fmt . Errorf ( "delete chain: %w" , err )
}
n . conn . DelTable ( n . nft4 . Nat )
if n . v6NATAvailable {
if err := deleteChainIfExists ( n . conn , n . nft6 . Nat , chainNamePostrouting ) ; err != nil {
return fmt . Errorf ( "delete chain: %w" , err )
}
n . conn . DelTable ( n . nft6 . Nat )
}
if err := n . conn . Flush ( ) ; err != nil {
@ -467,15 +519,128 @@ func (n *nftablesRunner) DelChains() error {
return nil
}
// AddHooks is defined to satisfy the interface. NfTables does not require
// AddHooks, since we don't have any default tables or chains in nftables.
// createHookRule creates a rule to jump from a hooked chain to a regular chain.
func createHookRule ( table * nftables . Table , fromChain * nftables . Chain , toChainName string ) * nftables . Rule {
exprs := [ ] expr . Any {
& expr . Counter { } ,
& expr . Verdict {
Kind : expr . VerdictJump ,
Chain : toChainName ,
} ,
}
rule := & nftables . Rule {
Table : table ,
Chain : fromChain ,
Exprs : exprs ,
}
return rule
}
// addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain.
func addHookRule ( conn * nftables . Conn , table * nftables . Table , fromChain * nftables . Chain , toChainName string ) error {
rule := createHookRule ( table , fromChain , toChainName )
_ = conn . InsertRule ( rule )
if err := conn . Flush ( ) ; err != nil {
return fmt . Errorf ( "flush add rule: %w" , err )
}
return nil
}
// AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING"
// in tables and jump from those chains to tailscale chains.
func ( n * nftablesRunner ) AddHooks ( ) error {
conn := n . conn
for _ , table := range n . getTables ( ) {
inputChain , err := getChainFromTable ( conn , table . Filter , "INPUT" )
if err != nil {
return fmt . Errorf ( "get INPUT chain: %w" , err )
}
err = addHookRule ( conn , table . Filter , inputChain , chainNameInput )
if err != nil {
return fmt . Errorf ( "Addhook: %w" , err )
}
forwardChain , err := getChainFromTable ( conn , table . Filter , "FORWARD" )
if err != nil {
return fmt . Errorf ( "get FORWARD chain: %w" , err )
}
err = addHookRule ( conn , table . Filter , forwardChain , chainNameForward )
if err != nil {
return fmt . Errorf ( "Addhook: %w" , err )
}
}
for _ , table := range n . getNATTables ( ) {
postroutingChain , err := getChainFromTable ( conn , table . Nat , "POSTROUTING" )
if err != nil {
return fmt . Errorf ( "get INPUT chain: %w" , err )
}
err = addHookRule ( conn , table . Nat , postroutingChain , chainNamePostrouting )
if err != nil {
return fmt . Errorf ( "Addhook: %w" , err )
}
}
return nil
}
// delHookRule deletes a rule that jumps from a hooked chain to a regular chain.
func delHookRule ( conn * nftables . Conn , table * nftables . Table , fromChain * nftables . Chain , toChainName string ) error {
rule := createHookRule ( table , fromChain , toChainName )
existingRule , err := findRule ( conn , rule )
if err != nil {
return fmt . Errorf ( "Failed to find hook rule: %w" , err )
}
if existingRule == nil {
return nil
}
_ = conn . DelRule ( existingRule )
if err := conn . Flush ( ) ; err != nil {
return fmt . Errorf ( "flush del hook rule: %w" , err )
}
return nil
}
// DelHooks is defined to satisfy the interface. NfTables does not require
// DelHooks, since we don't have any default tables or chains in nftables.
// DelHooks is deleting the rules added to conventional chains to jump to tailscale chains.
func ( n * nftablesRunner ) DelHooks ( logf logger . Logf ) error {
conn := n . conn
for _ , table := range n . getTables ( ) {
inputChain , err := getChainFromTable ( conn , table . Filter , "INPUT" )
if err != nil {
return fmt . Errorf ( "get INPUT chain: %w" , err )
}
err = delHookRule ( conn , table . Filter , inputChain , chainNameInput )
if err != nil {
return fmt . Errorf ( "delhook: %w" , err )
}
forwardChain , err := getChainFromTable ( conn , table . Filter , "FORWARD" )
if err != nil {
return fmt . Errorf ( "get FORWARD chain: %w" , err )
}
err = delHookRule ( conn , table . Filter , forwardChain , chainNameForward )
if err != nil {
return fmt . Errorf ( "delhook: %w" , err )
}
}
for _ , table := range n . getNATTables ( ) {
postroutingChain , err := getChainFromTable ( conn , table . Nat , "POSTROUTING" )
if err != nil {
return fmt . Errorf ( "get INPUT chain: %w" , err )
}
err = delHookRule ( conn , table . Nat , postroutingChain , chainNamePostrouting )
if err != nil {
return fmt . Errorf ( "delhook: %w" , err )
}
}
return nil
}
@ -953,25 +1118,62 @@ func (n *nftablesRunner) DelSNATRule() error {
return nil
}
// cleanupChain removes a jump rule from hookChainName to tsChainName, and then
// the entire chain tsChainName. Errors are logged, but attempts to remove both
// the jump rule and chain continue even if one errors.
func cleanupChain ( logf logger . Logf , conn * nftables . Conn , table * nftables . Table , hookChainName , tsChainName string ) {
// remove the jump first, before removing the jump destination.
defaultChain , err := getChainFromTable ( conn , table , hookChainName )
if err != nil && ! errors . Is ( err , errorChainNotFound { table . Name , hookChainName } ) {
logf ( "cleanup: did not find default chain: %s" , err )
}
if ! errors . Is ( err , errorChainNotFound { table . Name , hookChainName } ) {
// delete hook in convention chain
_ = delHookRule ( conn , table , defaultChain , tsChainName )
}
tsChain , err := getChainFromTable ( conn , table , tsChainName )
if err != nil && ! errors . Is ( err , errorChainNotFound { table . Name , tsChainName } ) {
logf ( "cleanup: did not find ts-chain: %s" , err )
}
if tsChain != nil {
// flush and delete ts-chain
conn . FlushChain ( tsChain )
conn . DelChain ( tsChain )
err = conn . Flush ( )
logf ( "cleanup: delete and flush chain %s: %s" , tsChainName , err )
}
}
// NfTablesCleanUp removes all Tailscale added nftables rules.
// Any errors that occur are logged to the provided logf.
func NfTablesCleanUp ( logf logger . Logf ) {
conn , err := nftables . New ( )
if err != nil {
logf ( "ERROR: nftables connection: %w" , err )
logf ( " cleanup: nftables connection: %s ", err )
}
tables , err := conn . ListTables ( ) // both v4 and v6
if err != nil {
logf ( "ERROR: list tables: %w" , err )
logf ( " cleanup: list tables: %s ", err )
}
for _ , table := range tables {
// These table names were used briefly in 1.48.0.
if table . Name == "ts-filter" || table . Name == "ts-nat" {
conn . DelTable ( table )
if err := conn . Flush ( ) ; err != nil {
logf ( "ERROR: flush table %s: %w" , table . Name , err )
logf ( "cleanup: flush delete table %s: %s" , table . Name , err )
}
}
if table . Name == "filter" {
cleanupChain ( logf , conn , table , "INPUT" , chainNameInput )
cleanupChain ( logf , conn , table , "FORWARD" , chainNameForward )
}
if table . Name == "nat" {
cleanupChain ( logf , conn , table , "POSTROUTING" , chainNamePostrouting )
}
}
}