util/linuxfw: fix IPv6 availability check for nftables (#12009)

* util/linuxfw: fix IPv6 NAT availability check for nftables

When running firewall in nftables mode,
there is no need for a separate NAT availability check
(unlike with iptables, there are no hosts that support nftables, but not IPv6 NAT - see tailscale/tailscale#11353).
This change fixes a firewall NAT availability check that was using the no-longer set ipv6NATAvailable field
by removing the field and using a method that, for nftables, just checks that IPv6 is available.

Updates tailscale/tailscale#12008

Signed-off-by: Irbe Krumina <irbe@tailscale.com>
pull/12130/head
Irbe Krumina 6 months ago committed by GitHub
parent 8aa5c3534d
commit 7ef2f72135
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -87,7 +87,7 @@ func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
} }
supportsV6Filter = checkSupportsV6Filter(ipt6, logf) supportsV6Filter = checkSupportsV6Filter(ipt6, logf)
supportsV6NAT = checkSupportsV6NAT(ipt6, logf) supportsV6NAT = checkSupportsV6NAT(ipt6, logf)
logf("v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT) logf("netfilter running in iptables mode v6 = %v, v6filter = %v, v6nat = %v", supportsV6, supportsV6Filter, supportsV6NAT)
} }
return &iptablesRunner{ return &iptablesRunner{
ipt4: ipt4, ipt4: ipt4,

@ -104,12 +104,19 @@ func getTailscaleSubnetRouteMark() []byte {
return []byte{0x00, 0x04, 0x00, 0x00} return []byte{0x00, 0x04, 0x00, 0x00}
} }
// checkIPv6ForTest can be set in tests.
var checkIPv6ForTest func(logger.Logf) error
// checkIPv6 checks whether the system appears to have a working IPv6 // checkIPv6 checks whether the system appears to have a working IPv6
// network stack. It returns an error explaining what looks wrong or // network stack. It returns an error explaining what looks wrong or
// missing. It does not check that IPv6 is currently functional or // missing. It does not check that IPv6 is currently functional or
// that there's a global address, just that the system would support // that there's a global address, just that the system would support
// IPv6 if it were on an IPv6 network. // IPv6 if it were on an IPv6 network.
func CheckIPv6(logf logger.Logf) error { func CheckIPv6(logf logger.Logf) error {
if f := checkIPv6ForTest; f != nil {
return f(logf)
}
_, err := os.Stat("/proc/sys/net/ipv6") _, err := os.Stat("/proc/sys/net/ipv6")
if os.IsNotExist(err) { if os.IsNotExist(err) {
return err return err

@ -41,8 +41,9 @@ type chainInfo struct {
chainPolicy *nftables.ChainPolicy chainPolicy *nftables.ChainPolicy
} }
// nftable contains nat and filter tables for the given IP family (Proto).
type nftable struct { type nftable struct {
Proto nftables.TableFamily Proto nftables.TableFamily // IPv4 or IPv6
Filter *nftables.Table Filter *nftables.Table
Nat *nftables.Table Nat *nftables.Table
} }
@ -69,11 +70,10 @@ type nftable struct {
// https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains
type nftablesRunner struct { type nftablesRunner struct {
conn *nftables.Conn conn *nftables.Conn
nft4 *nftable nft4 *nftable // IPv4 tables
nft6 *nftable nft6 *nftable // IPv6 tables
v6Available bool v6Available bool // whether the host supports IPv6
v6NATAvailable bool
} }
func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) { func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
@ -598,6 +598,10 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("nftables connection: %w", err) return nil, fmt.Errorf("nftables connection: %w", err)
} }
return newNfTablesRunnerWithConn(logf, conn), nil
}
func newNfTablesRunnerWithConn(logf logger.Logf, conn *nftables.Conn) *nftablesRunner {
nft4 := &nftable{Proto: nftables.TableFamilyIPv4} nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
v6err := CheckIPv6(logf) v6err := CheckIPv6(logf)
@ -609,8 +613,8 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
if supportsV6 { if supportsV6 {
nft6 = &nftable{Proto: nftables.TableFamilyIPv6} nft6 = &nftable{Proto: nftables.TableFamilyIPv6}
logf("v6nat availability: true")
} }
logf("netfilter running in nftables mode, v6 = %v", supportsV6)
// TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables // TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables
@ -619,7 +623,7 @@ func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) {
nft4: nft4, nft4: nft4,
nft6: nft6, nft6: nft6,
v6Available: supportsV6, v6Available: supportsV6,
}, nil }
} }
// newLoadSaddrExpr creates a new nftables expression that loads the source // newLoadSaddrExpr creates a new nftables expression that loads the source
@ -837,24 +841,15 @@ func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
return n.conn.Flush() return n.conn.Flush()
} }
// getTables gets the available nftable in nftables runner. // getTables returns tables for IP families that this host was determined to
// support (either IPv4 and IPv6 or just IPv4).
func (n *nftablesRunner) getTables() []*nftable { func (n *nftablesRunner) getTables() []*nftable {
if n.v6Available { if n.HasIPV6() {
return []*nftable{n.nft4, n.nft6} return []*nftable{n.nft4, n.nft6}
} }
return []*nftable{n.nft4} return []*nftable{n.nft4}
} }
// getNATTables gets the available nftable in nftables runner.
// If the system does not support IPv6 NAT, only the IPv4 nftable
// will be returned.
func (n *nftablesRunner) getNATTables() []*nftable {
if n.v6NATAvailable {
return n.getTables()
}
return []*nftable{n.nft4}
}
// AddChains creates custom Tailscale chains in netfilter via nftables // AddChains creates custom Tailscale chains in netfilter via nftables
// if the ts-chain doesn't already exist. // if the ts-chain doesn't already exist.
func (n *nftablesRunner) AddChains() error { func (n *nftablesRunner) AddChains() error {
@ -883,9 +878,7 @@ func (n *nftablesRunner) AddChains() error {
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil { if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
return fmt.Errorf("create input chain: %w", err) return fmt.Errorf("create input chain: %w", err)
} }
}
for _, table := range n.getNATTables() {
// Create the nat table if it doesn't exist, this table name is the same // Create the nat table if it doesn't exist, this table name is the same
// as the name used by iptables-nft and ufw. We install rules into the // as the name used by iptables-nft and ufw. We install rules into the
// same conventional table so that `accept` verdicts from our jump // same conventional table so that `accept` verdicts from our jump
@ -923,7 +916,7 @@ const (
// can be used. It cleans up the dummy chains after creation. // can be used. It cleans up the dummy chains after creation.
func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) { func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
polAccept := ptr.To(nftables.ChainPolicyAccept) polAccept := ptr.To(nftables.ChainPolicyAccept)
for _, table := range n.getNATTables() { for _, table := range n.getTables() {
nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName) nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
if err != nil { if err != nil {
return fmt.Errorf("create nat table: %w", err) return fmt.Errorf("create nat table: %w", err)
@ -980,7 +973,7 @@ func (n *nftablesRunner) DelChains() error {
return fmt.Errorf("delete chain: %w", err) return fmt.Errorf("delete chain: %w", err)
} }
if n.v6NATAvailable { if n.HasIPV6NAT() {
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil { if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
return fmt.Errorf("delete chain: %w", err) return fmt.Errorf("delete chain: %w", err)
} }
@ -1046,9 +1039,7 @@ func (n *nftablesRunner) AddHooks() error {
if err != nil { if err != nil {
return fmt.Errorf("Addhook: %w", err) return fmt.Errorf("Addhook: %w", err)
} }
}
for _, table := range n.getNATTables() {
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
if err != nil { if err != nil {
return fmt.Errorf("get INPUT chain: %w", err) return fmt.Errorf("get INPUT chain: %w", err)
@ -1102,9 +1093,7 @@ func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
if err != nil { if err != nil {
return fmt.Errorf("delhook: %w", err) return fmt.Errorf("delhook: %w", err)
} }
}
for _, table := range n.getNATTables() {
postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING")
if err != nil { if err != nil {
return fmt.Errorf("get INPUT chain: %w", err) return fmt.Errorf("get INPUT chain: %w", err)
@ -1612,9 +1601,7 @@ func (n *nftablesRunner) DelBase() error {
return fmt.Errorf("get forward chain: %v", err) return fmt.Errorf("get forward chain: %v", err)
} }
conn.FlushChain(forwardChain) conn.FlushChain(forwardChain)
}
for _, table := range n.getNATTables() {
postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil { if err != nil {
return fmt.Errorf("get postrouting chain v4: %v", err) return fmt.Errorf("get postrouting chain v4: %v", err)
@ -1684,7 +1671,7 @@ func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, cha
func (n *nftablesRunner) AddSNATRule() error { func (n *nftablesRunner) AddSNATRule() error {
conn := n.conn conn := n.conn
for _, table := range n.getNATTables() { for _, table := range n.getTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil { if err != nil {
return fmt.Errorf("get postrouting chain v4: %w", err) return fmt.Errorf("get postrouting chain v4: %w", err)
@ -1727,7 +1714,7 @@ func (n *nftablesRunner) DelSNATRule() error {
&expr.Masq{}, &expr.Masq{},
} }
for _, table := range n.getNATTables() { for _, table := range n.getTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil { if err != nil {
return fmt.Errorf("get postrouting chain v4: %w", err) return fmt.Errorf("get postrouting chain v4: %w", err)

@ -20,6 +20,8 @@ import (
"github.com/mdlayher/netlink" "github.com/mdlayher/netlink"
"github.com/vishvananda/netns" "github.com/vishvananda/netns"
"tailscale.com/net/tsaddr" "tailscale.com/net/tsaddr"
"tailscale.com/tstest"
"tailscale.com/types/logger"
) )
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing // nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
@ -503,19 +505,6 @@ func cleanupSysConn(t *testing.T, ns netns.NsHandle) {
} }
} }
func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner {
nft4 := &nftable{Proto: nftables.TableFamilyIPv4}
nft6 := &nftable{Proto: nftables.TableFamilyIPv6}
return &nftablesRunner{
conn: conn,
nft4: nft4,
nft6: nft6,
v6Available: true,
v6NATAvailable: true,
}
}
func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) { func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) {
t.Helper() t.Helper()
got, err := conn.ListChainsOfTableFamily(fam) got, err := conn.ListChainsOfTableFamily(fam)
@ -526,42 +515,76 @@ func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wa
t.Fatalf("len(got) = %d, want %d", len(got), wantCount) t.Fatalf("len(got) = %d, want %d", len(got), wantCount)
} }
} }
func checkTables(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) {
func TestAddAndDelNetfilterChains(t *testing.T) { t.Helper()
conn := newSysConn(t) got, err := conn.ListTablesOfFamily(fam)
checkChains(t, conn, nftables.TableFamilyIPv4, 0)
checkChains(t, conn, nftables.TableFamilyIPv6, 0)
runner := newFakeNftablesRunner(t, conn)
if err := runner.AddChains(); err != nil {
t.Fatalf("runner.AddChains() failed: %v", err)
}
tables, err := conn.ListTables()
if err != nil { if err != nil {
t.Fatalf("conn.ListTables() failed: %v", err) t.Fatalf("conn.ListTablesOfFamily(%v) failed: %v", fam, err)
} }
if len(got) != wantCount {
if len(tables) != 4 { t.Fatalf("len(got) = %d, want %d", len(got), wantCount)
t.Fatalf("len(tables) = %d, want 4", len(tables))
} }
}
checkChains(t, conn, nftables.TableFamilyIPv4, 6) func TestAddAndDelNetfilterChains(t *testing.T) {
checkChains(t, conn, nftables.TableFamilyIPv6, 6) type test struct {
hostHasIPv6 bool
initIPv4ChainCount int
initIPv6ChainCount int
ipv4TableCount int
ipv6TableCount int
ipv4ChainCount int
ipv6ChainCount int
ipv4ChainCountPostDelete int
ipv6ChainCountPostDelete int
}
tests := []test{
{
hostHasIPv6: true,
initIPv4ChainCount: 0,
initIPv6ChainCount: 0,
ipv4TableCount: 2,
ipv6TableCount: 2,
ipv4ChainCount: 6,
ipv6ChainCount: 6,
ipv4ChainCountPostDelete: 3,
ipv6ChainCountPostDelete: 3,
},
{ // host without IPv6 support
ipv4TableCount: 2,
ipv4ChainCount: 6,
ipv4ChainCountPostDelete: 3,
}}
for _, tt := range tests {
t.Logf("running a test case for IPv6 support: %v", tt.hostHasIPv6)
conn := newSysConn(t)
runner := newFakeNftablesRunnerWithConn(t, conn, tt.hostHasIPv6)
// Check that we start off with no chains.
checkChains(t, conn, nftables.TableFamilyIPv4, tt.initIPv4ChainCount)
checkChains(t, conn, nftables.TableFamilyIPv6, tt.initIPv6ChainCount)
runner.DelChains() if err := runner.AddChains(); err != nil {
t.Fatalf("runner.AddChains() failed: %v", err)
}
// The default chains should still be present. // Check that the amount of tables for each IP family is as expected.
checkChains(t, conn, nftables.TableFamilyIPv4, 3) checkTables(t, conn, nftables.TableFamilyIPv4, tt.ipv4TableCount)
checkChains(t, conn, nftables.TableFamilyIPv6, 3) checkTables(t, conn, nftables.TableFamilyIPv6, tt.ipv6TableCount)
tables, err = conn.ListTables() // Check that the amount of chains for each IP family is as expected.
if err != nil { checkChains(t, conn, nftables.TableFamilyIPv4, tt.ipv4ChainCount)
t.Fatalf("conn.ListTables() failed: %v", err) checkChains(t, conn, nftables.TableFamilyIPv6, tt.ipv6ChainCount)
}
if err := runner.DelChains(); err != nil {
t.Fatalf("runner.DelChains() failed: %v", err)
}
if len(tables) != 4 { // Test that the tables as well as the default chains are still present.
t.Fatalf("len(tables) = %d, want 4", len(tables)) checkChains(t, conn, nftables.TableFamilyIPv4, tt.ipv4ChainCountPostDelete)
checkChains(t, conn, nftables.TableFamilyIPv6, tt.ipv6ChainCountPostDelete)
checkTables(t, conn, nftables.TableFamilyIPv4, tt.ipv4TableCount)
checkTables(t, conn, nftables.TableFamilyIPv6, tt.ipv6TableCount)
} }
} }
@ -665,7 +688,8 @@ func checkChainRules(t *testing.T, conn *nftables.Conn, chain *nftables.Chain, w
func TestNFTAddAndDelNetfilterBase(t *testing.T) { func TestNFTAddAndDelNetfilterBase(t *testing.T) {
conn := newSysConn(t) conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn) runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.AddChains(); err != nil { if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err) t.Fatalf("AddChains() failed: %v", err)
} }
@ -759,7 +783,7 @@ func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nf
func TestNFTAddAndDelLoopbackRule(t *testing.T) { func TestNFTAddAndDelLoopbackRule(t *testing.T) {
conn := newSysConn(t) conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn) runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.AddChains(); err != nil { if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err) t.Fatalf("AddChains() failed: %v", err)
} }
@ -817,7 +841,7 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {
func TestNFTAddAndDelHookRule(t *testing.T) { func TestNFTAddAndDelHookRule(t *testing.T) {
conn := newSysConn(t) conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn) runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.AddChains(); err != nil { if err := runner.AddChains(); err != nil {
t.Fatalf("AddChains() failed: %v", err) t.Fatalf("AddChains() failed: %v", err)
} }
@ -868,11 +892,11 @@ func (t *testFWDetector) nftDetect() (int, error) {
// postrouting chains are cleaned up. // postrouting chains are cleaned up.
func TestCreateDummyPostroutingChains(t *testing.T) { func TestCreateDummyPostroutingChains(t *testing.T) {
conn := newSysConn(t) conn := newSysConn(t)
runner := newFakeNftablesRunner(t, conn) runner := newFakeNftablesRunnerWithConn(t, conn, true)
if err := runner.createDummyPostroutingChains(); err != nil { if err := runner.createDummyPostroutingChains(); err != nil {
t.Fatalf("createDummyPostroutingChains() failed: %v", err) t.Fatalf("createDummyPostroutingChains() failed: %v", err)
} }
for _, table := range runner.getNATTables() { for _, table := range runner.getTables() {
nt, err := getTableIfExists(conn, table.Proto, tsDummyTableName) nt, err := getTableIfExists(conn, table.Proto, tsDummyTableName)
if err != nil { if err != nil {
t.Fatalf("getTableIfExists() failed: %v", err) t.Fatalf("getTableIfExists() failed: %v", err)
@ -929,3 +953,14 @@ func TestPickFirewallModeFromInstalledRules(t *testing.T) {
}) })
} }
} }
func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bool) *nftablesRunner {
t.Helper()
if !hasIPv6 {
tstest.Replace(t, &checkIPv6ForTest, func(logger.Logf) error {
return errors.New("test: no IPv6")
})
}
return newNfTablesRunnerWithConn(t.Logf, conn)
}

Loading…
Cancel
Save