@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"golang.org/x/sys/unix"
"tailscale.com/net/tsaddr"
"tailscale.com/net/tsaddr"
"tailscale.com/types/logger"
"tailscale.com/types/logger"
"tailscale.com/types/ptr"
)
)
const (
const (
@ -316,8 +317,33 @@ func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
return n . conn . Flush ( )
return n . conn . Flush ( )
}
}
// createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family.
// deleteTableIfExists deletes a nftables table via connection c if it exists
func createTableIfNotExist ( c * nftables . Conn , family nftables . TableFamily , name string ) ( * nftables . Table , error ) {
// within the given family.
func deleteTableIfExists ( c * nftables . Conn , family nftables . TableFamily , name string ) error {
t , err := getTableIfExists ( c , family , name )
if err != nil {
return fmt . Errorf ( "get table: %w" , err )
}
if t == nil {
// Table does not exist, so nothing to delete.
return nil
}
c . DelTable ( t )
if err := c . Flush ( ) ; err != nil {
if t , err = getTableIfExists ( c , family , name ) ; t == nil && err == nil {
// Check if the table still exists. If it does not, then the error
// is due to the table not existing, so we can ignore it. Maybe a
// concurrent process deleted the table.
return nil
}
return fmt . Errorf ( "del table: %w" , err )
}
return nil
}
// getTableIfExists returns the table with the given name from the given family
// if it exists. If none match, it returns (nil, nil).
func getTableIfExists ( c * nftables . Conn , family nftables . TableFamily , name string ) ( * nftables . Table , error ) {
tables , err := c . ListTables ( )
tables , err := c . ListTables ( )
if err != nil {
if err != nil {
return nil , fmt . Errorf ( "get tables: %w" , err )
return nil , fmt . Errorf ( "get tables: %w" , err )
@ -327,7 +353,17 @@ func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name s
return table , nil
return table , nil
}
}
}
}
return nil , nil
}
// createTableIfNotExist creates a nftables table via connection c if it does
// not exist within the given family.
func createTableIfNotExist ( c * nftables . Conn , family nftables . TableFamily , name string ) ( * nftables . Table , error ) {
if t , err := getTableIfExists ( c , family , name ) ; err != nil {
return nil , fmt . Errorf ( "get table: %w" , err )
} else if t != nil {
return t , nil
}
t := c . AddTable ( & nftables . Table {
t := c . AddTable ( & nftables . Table {
Family : family ,
Family : family ,
Name : name ,
Name : name ,
@ -365,24 +401,6 @@ func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*n
return nil , errorChainNotFound { table . Name , name }
return nil , errorChainNotFound { table . Name , name }
}
}
// getChainsFromTable returns all chains from the given table.
func getChainsFromTable ( c * nftables . Conn , table * nftables . Table ) ( [ ] * nftables . Chain , error ) {
chains , err := c . ListChainsOfTableFamily ( table . Family )
if err != nil {
return nil , fmt . Errorf ( "list chains: %w" , err )
}
var ret [ ] * nftables . Chain
for _ , chain := range chains {
// Table family is already checked so table name is unique
if chain . Table . Name == table . Name {
ret = append ( ret , chain )
}
}
return ret , nil
}
// isTSChain reports whether `name` begins with "ts-" (and is thus a
// isTSChain reports whether `name` begins with "ts-" (and is thus a
// Tailscale-managed chain).
// Tailscale-managed chain).
func isTSChain ( name string ) bool {
func isTSChain ( name string ) bool {
@ -804,6 +822,43 @@ func (n *nftablesRunner) AddChains() error {
return n . conn . Flush ( )
return n . conn . Flush ( )
}
}
// These are dummy chains and tables we create to detect if nftables is
// available. We create them, then delete them. If we can create and delete
// them, then we can use nftables. If we can't, then we assume that we're
// running on a system that doesn't support nftables. See
// createDummyPostroutingChains.
const (
tsDummyChainName = "ts-test-postrouting"
tsDummyTableName = "ts-test-nat"
)
// createDummyPostroutingChains creates dummy postrouting chains in netfilter
// via netfilter via nftables, as a last resort measure to detect that nftables
// can be used. It cleans up the dummy chains after creation.
func ( n * nftablesRunner ) createDummyPostroutingChains ( ) ( retErr error ) {
polAccept := ptr . To ( nftables . ChainPolicyAccept )
for _ , table := range n . getNATTables ( ) {
nat , err := createTableIfNotExist ( n . conn , table . Proto , tsDummyTableName )
if err != nil {
return fmt . Errorf ( "create nat table: %w" , err )
}
defer func ( fm nftables . TableFamily ) {
if err := deleteTableIfExists ( n . conn , table . Proto , tsDummyTableName ) ; err != nil && retErr == nil {
retErr = fmt . Errorf ( "delete %q table: %w" , tsDummyTableName , err )
}
} ( table . Proto )
table . Nat = nat
if err = createChainIfNotExist ( n . conn , chainInfo { nat , tsDummyChainName , nftables . ChainTypeNAT , nftables . ChainHookPostrouting , nftables . ChainPriorityNATSource , polAccept } ) ; err != nil {
return fmt . Errorf ( "create %q chain: %w" , tsDummyChainName , err )
}
if err := deleteChainIfExists ( n . conn , nat , tsDummyChainName ) ; err != nil {
return fmt . Errorf ( "delete %q chain: %w" , tsDummyChainName , err )
}
}
return nil
}
// deleteChainIfExists deletes a chain if it exists.
// deleteChainIfExists deletes a chain if it exists.
func deleteChainIfExists ( c * nftables . Conn , table * nftables . Table , name string ) error {
func deleteChainIfExists ( c * nftables . Conn , table * nftables . Table , name string ) error {
chain , err := getChainFromTable ( c , table , name )
chain , err := getChainFromTable ( c , table , name )