@ -41,8 +41,9 @@ type chainInfo struct {
chainPolicy * nftables . ChainPolicy
chainPolicy * nftables . ChainPolicy
}
}
// nftable contains nat and filter tables for the given IP family (Proto).
type nftable struct {
type nftable struct {
Proto nftables . TableFamily
Proto nftables . TableFamily // IPv4 or IPv6
Filter * nftables . Table
Filter * nftables . Table
Nat * nftables . Table
Nat * nftables . Table
}
}
@ -69,11 +70,10 @@ type nftable struct {
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains
type nftablesRunner struct {
type nftablesRunner struct {
conn * nftables . Conn
conn * nftables . Conn
nft4 * nftable
nft4 * nftable // IPv4 tables
nft6 * nftable
nft6 * nftable // IPv6 tables
v6Available bool
v6Available bool // whether the host supports IPv6
v6NATAvailable bool
}
}
func ( n * nftablesRunner ) ensurePreroutingChain ( dst netip . Addr ) ( * nftables . Table , * nftables . Chain , error ) {
func ( n * nftablesRunner ) ensurePreroutingChain ( dst netip . Addr ) ( * nftables . Table , * nftables . Chain , error ) {
@ -598,6 +598,10 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
if err != nil {
if err != nil {
return nil , fmt . Errorf ( "nftables connection: %w" , err )
return nil , fmt . Errorf ( "nftables connection: %w" , err )
}
}
return newNfTablesRunnerWithConn ( logf , conn ) , nil
}
func newNfTablesRunnerWithConn ( logf logger . Logf , conn * nftables . Conn ) * nftablesRunner {
nft4 := & nftable { Proto : nftables . TableFamilyIPv4 }
nft4 := & nftable { Proto : nftables . TableFamilyIPv4 }
v6err := CheckIPv6 ( logf )
v6err := CheckIPv6 ( logf )
@ -609,8 +613,8 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
if supportsV6 {
if supportsV6 {
nft6 = & nftable { Proto : nftables . TableFamilyIPv6 }
nft6 = & nftable { Proto : nftables . TableFamilyIPv6 }
logf ( "v6nat availability: true" )
}
}
logf ( "netfilter running in nftables mode, v6 = %v" , supportsV6 )
// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables
// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables
@ -619,7 +623,7 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
nft4 : nft4 ,
nft4 : nft4 ,
nft6 : nft6 ,
nft6 : nft6 ,
v6Available : supportsV6 ,
v6Available : supportsV6 ,
} , nil
}
}
}
// newLoadSaddrExpr creates a new nftables expression that loads the source
// newLoadSaddrExpr creates a new nftables expression that loads the source
@ -837,24 +841,15 @@ func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
return n . conn . Flush ( )
return n . conn . Flush ( )
}
}
// getTables gets the available nftable in nftables runner.
// getTables returns tables for IP families that this host was determined to
// support (either IPv4 and IPv6 or just IPv4).
func ( n * nftablesRunner ) getTables ( ) [ ] * nftable {
func ( n * nftablesRunner ) getTables ( ) [ ] * nftable {
if n . v6Available {
if n . HasIPV6 ( ) {
return [ ] * nftable { n . nft4 , n . nft6 }
return [ ] * nftable { n . nft4 , n . nft6 }
}
}
return [ ] * nftable { n . nft4 }
return [ ] * nftable { n . nft4 }
}
}
// getNATTables gets the available nftable in nftables runner.
// If the system does not support IPv6 NAT, only the IPv4 nftable
// will be returned.
func ( n * nftablesRunner ) getNATTables ( ) [ ] * nftable {
if n . v6NATAvailable {
return n . getTables ( )
}
return [ ] * nftable { n . nft4 }
}
// AddChains creates custom Tailscale chains in netfilter via nftables
// AddChains creates custom Tailscale chains in netfilter via nftables
// if the ts-chain doesn't already exist.
// if the ts-chain doesn't already exist.
func ( n * nftablesRunner ) AddChains ( ) error {
func ( n * nftablesRunner ) AddChains ( ) error {
@ -883,9 +878,7 @@ func (n *nftablesRunner) AddChains() error {
if err = createChainIfNotExist ( n . conn , chainInfo { filter , chainNameInput , chainTypeRegular , nil , nil , nil } ) ; err != nil {
if err = createChainIfNotExist ( n . conn , chainInfo { filter , chainNameInput , chainTypeRegular , nil , nil , nil } ) ; err != nil {
return fmt . Errorf ( "create input chain: %w" , err )
return fmt . Errorf ( "create input chain: %w" , err )
}
}
}
for _ , table := range n . getNATTables ( ) {
// Create the nat table if it doesn't exist, this table name is the same
// Create the nat table if it doesn't exist, this table name is the same
// as the name used by iptables-nft and ufw. We install rules into the
// as the name used by iptables-nft and ufw. We install rules into the
// same conventional table so that `accept` verdicts from our jump
// same conventional table so that `accept` verdicts from our jump
@ -923,7 +916,7 @@ const (
// can be used. It cleans up the dummy chains after creation.
// can be used. It cleans up the dummy chains after creation.
func ( n * nftablesRunner ) createDummyPostroutingChains ( ) ( retErr error ) {
func ( n * nftablesRunner ) createDummyPostroutingChains ( ) ( retErr error ) {
polAccept := ptr . To ( nftables . ChainPolicyAccept )
polAccept := ptr . To ( nftables . ChainPolicyAccept )
for _ , table := range n . get NAT Tables( ) {
for _ , table := range n . get Tables( ) {
nat , err := createTableIfNotExist ( n . conn , table . Proto , tsDummyTableName )
nat , err := createTableIfNotExist ( n . conn , table . Proto , tsDummyTableName )
if err != nil {
if err != nil {
return fmt . Errorf ( "create nat table: %w" , err )
return fmt . Errorf ( "create nat table: %w" , err )
@ -980,7 +973,7 @@ func (n *nftablesRunner) DelChains() error {
return fmt . Errorf ( "delete chain: %w" , err )
return fmt . Errorf ( "delete chain: %w" , err )
}
}
if n . v6NATAvailable {
if n . HasIPV6NAT ( ) {
if err := deleteChainIfExists ( n . conn , n . nft6 . Nat , chainNamePostrouting ) ; err != nil {
if err := deleteChainIfExists ( n . conn , n . nft6 . Nat , chainNamePostrouting ) ; err != nil {
return fmt . Errorf ( "delete chain: %w" , err )
return fmt . Errorf ( "delete chain: %w" , err )
}
}
@ -1046,9 +1039,7 @@ func (n *nftablesRunner) AddHooks() error {
if err != nil {
if err != nil {
return fmt . Errorf ( "Addhook: %w" , err )
return fmt . Errorf ( "Addhook: %w" , err )
}
}
}
for _ , table := range n . getNATTables ( ) {
postroutingChain , err := getChainFromTable ( conn , table . Nat , "POSTROUTING" )
postroutingChain , err := getChainFromTable ( conn , table . Nat , "POSTROUTING" )
if err != nil {
if err != nil {
return fmt . Errorf ( "get INPUT chain: %w" , err )
return fmt . Errorf ( "get INPUT chain: %w" , err )
@ -1102,9 +1093,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
if err != nil {
if err != nil {
return fmt . Errorf ( "delhook: %w" , err )
return fmt . Errorf ( "delhook: %w" , err )
}
}
}
for _ , table := range n . getNATTables ( ) {
postroutingChain , err := getChainFromTable ( conn , table . Nat , "POSTROUTING" )
postroutingChain , err := getChainFromTable ( conn , table . Nat , "POSTROUTING" )
if err != nil {
if err != nil {
return fmt . Errorf ( "get INPUT chain: %w" , err )
return fmt . Errorf ( "get INPUT chain: %w" , err )
@ -1612,9 +1601,7 @@ func (n *nftablesRunner) DelBase() error {
return fmt . Errorf ( "get forward chain: %v" , err )
return fmt . Errorf ( "get forward chain: %v" , err )
}
}
conn . FlushChain ( forwardChain )
conn . FlushChain ( forwardChain )
}
for _ , table := range n . getNATTables ( ) {
postrouteChain , err := getChainFromTable ( conn , table . Nat , chainNamePostrouting )
postrouteChain , err := getChainFromTable ( conn , table . Nat , chainNamePostrouting )
if err != nil {
if err != nil {
return fmt . Errorf ( "get postrouting chain v4: %v" , err )
return fmt . Errorf ( "get postrouting chain v4: %v" , err )
@ -1684,7 +1671,7 @@ func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, cha
func ( n * nftablesRunner ) AddSNATRule ( ) error {
func ( n * nftablesRunner ) AddSNATRule ( ) error {
conn := n . conn
conn := n . conn
for _ , table := range n . get NAT Tables( ) {
for _ , table := range n . get Tables( ) {
chain , err := getChainFromTable ( conn , table . Nat , chainNamePostrouting )
chain , err := getChainFromTable ( conn , table . Nat , chainNamePostrouting )
if err != nil {
if err != nil {
return fmt . Errorf ( "get postrouting chain v4: %w" , err )
return fmt . Errorf ( "get postrouting chain v4: %w" , err )
@ -1727,7 +1714,7 @@ func (n *nftablesRunner) DelSNATRule() error {
& expr . Masq { } ,
& expr . Masq { } ,
}
}
for _ , table := range n . get NAT Tables( ) {
for _ , table := range n . get Tables( ) {
chain , err := getChainFromTable ( conn , table . Nat , chainNamePostrouting )
chain , err := getChainFromTable ( conn , table . Nat , chainNamePostrouting )
if err != nil {
if err != nil {
return fmt . Errorf ( "get postrouting chain v4: %w" , err )
return fmt . Errorf ( "get postrouting chain v4: %w" , err )