From b47cf04624a108ecff059ea3c61b23123ce9c28e Mon Sep 17 00:00:00 2001 From: Maisem Ali Date: Tue, 17 Oct 2023 19:51:40 +0000 Subject: [PATCH] util/linuxfw: fix broken tests These tests were broken at HEAD. CI currently does not run these as root, will figure out how to do that in a followup. Updates #5621 Updates #8555 Updates #8762 Signed-off-by: Maisem Ali --- util/linuxfw/nftables_runner_test.go | 246 ++++++++------------------- 1 file changed, 68 insertions(+), 178 deletions(-) diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index 1a451238b..b6ff44e7f 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -474,6 +474,10 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) { func newSysConn(t *testing.T) *nftables.Conn { t.Helper() + if os.Geteuid() != 0 { + t.Skip(t.Name(), " requires privileges to create a namespace in order to run") + return nil + } runtime.LockOSThread() @@ -512,12 +516,21 @@ func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner { } } -func TestAddAndDelNetfilterChains(t *testing.T) { - if os.Geteuid() != 0 { - t.Skip(t.Name(), " requires privileges to create a namespace in order to run") - return +func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wantCount int) { + t.Helper() + got, err := conn.ListChainsOfTableFamily(fam) + if err != nil { + t.Fatalf("conn.ListChainsOfTableFamily(%v) failed: %v", fam, err) + } + if len(got) != wantCount { + t.Fatalf("len(got) = %d, want %d", len(got), wantCount) } +} + +func TestAddAndDelNetfilterChains(t *testing.T) { conn := newSysConn(t) + checkChains(t, conn, nftables.TableFamilyIPv4, 0) + checkChains(t, conn, nftables.TableFamilyIPv6, 0) runner := newFakeNftablesRunner(t, conn) runner.AddChains() @@ -531,33 +544,22 @@ func TestAddAndDelNetfilterChains(t *testing.T) { t.Fatalf("len(tables) = %d, want 4", len(tables)) } - chainsV4, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) - if err != nil { - t.Fatalf("list chains failed: %v", err) - } - - if len(chainsV4) != 6 { - t.Fatalf("len(chainsV4) = %d, want 6", len(chainsV4)) - } - - chainsV6, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv6) - if err != nil { - t.Fatalf("list chains failed: %v", err) - } - - if len(chainsV6) != 6 { - t.Fatalf("len(chainsV6) = %d, want 6", len(chainsV6)) - } + checkChains(t, conn, nftables.TableFamilyIPv4, 6) + checkChains(t, conn, nftables.TableFamilyIPv6, 6) runner.DelChains() + // The default chains should still be present. + checkChains(t, conn, nftables.TableFamilyIPv4, 3) + checkChains(t, conn, nftables.TableFamilyIPv6, 3) + tables, err = conn.ListTables() if err != nil { t.Fatalf("conn.ListTables() failed: %v", err) } - if len(tables) != 0 { - t.Fatalf("len(tables) = %d, want 0", len(tables)) + if len(tables) != 4 { + t.Fatalf("len(tables) = %d, want 4", len(tables)) } } @@ -646,12 +648,19 @@ func findCommonBaseRules( return get, nil } -func TestNFTAddAndDelNetfilterBase(t *testing.T) { - if os.Geteuid() != 0 { - t.Skip(t.Name(), " requires privileges to create a namespace in order to run") - return +// checkChainRules verifies that the chain has the expected number of rules. +func checkChainRules(t *testing.T, conn *nftables.Conn, chain *nftables.Chain, wantCount int) { + t.Helper() + got, err := conn.GetRules(chain.Table, chain) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + if len(got) != wantCount { + t.Fatalf("got = %d, want %d", len(got), wantCount) } +} +func TestNFTAddAndDelNetfilterBase(t *testing.T) { conn := newSysConn(t) runner := newFakeNftablesRunner(t, conn) @@ -664,30 +673,9 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) { if err != nil { t.Fatalf("getTsChains() failed: %v", err) } - - inputV4Rules, err := conn.GetRules(runner.nft4.Filter, inputV4) - if err != nil { - t.Fatalf("conn.GetRules() failed: %v", err) - } - if len(inputV4Rules) != 2 { - t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules)) - } - - forwardV4Rules, err := conn.GetRules(runner.nft4.Filter, forwardV4) - if err != nil { - t.Fatalf("conn.GetRules() failed: %v", err) - } - if len(forwardV4Rules) != 4 { - t.Fatalf("len(forwardV4Rules) = %d, want 4", len(forwardV4Rules)) - } - - postroutingV4Rules, err := conn.GetRules(runner.nft4.Nat, postroutingV4) - if err != nil { - t.Fatalf("conn.GetRules() failed: %v", err) - } - if len(postroutingV4Rules) != 0 { - t.Fatalf("len(postroutingV4Rules) = %d, want 0", len(postroutingV4Rules)) - } + checkChainRules(t, conn, inputV4, 3) + checkChainRules(t, conn, forwardV4, 4) + checkChainRules(t, conn, postroutingV4, 0) _, err = findV4BaseRules(conn, inputV4, forwardV4, "testTunn") if err != nil { @@ -703,30 +691,9 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) { if err != nil { t.Fatalf("getTsChains() failed: %v", err) } - - inputV6Rules, err := conn.GetRules(runner.nft6.Filter, inputV6) - if err != nil { - t.Fatalf("conn.GetRules() failed: %v", err) - } - if len(inputV6Rules) != 0 { - t.Fatalf("len(inputV6Rules) = %d, want 0", len(inputV4Rules)) - } - - forwardV6Rules, err := conn.GetRules(runner.nft6.Filter, forwardV6) - if err != nil { - t.Fatalf("conn.GetRules() failed: %v", err) - } - if len(forwardV6Rules) != 3 { - t.Fatalf("len(forwardV6Rules) = %d, want 3", len(forwardV4Rules)) - } - - postroutingV6Rules, err := conn.GetRules(runner.nft6.Nat, postroutingV6) - if err != nil { - t.Fatalf("conn.GetRules() failed: %v", err) - } - if len(postroutingV6Rules) != 0 { - t.Fatalf("len(postroutingV6Rules) = %d, want 0", len(postroutingV4Rules)) - } + checkChainRules(t, conn, inputV6, 3) + checkChainRules(t, conn, forwardV6, 4) + checkChainRules(t, conn, postroutingV6, 0) _, err = findCommonBaseRules(conn, forwardV6, "testTunn") if err != nil { @@ -740,13 +707,7 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) { t.Fatalf("conn.ListChains() failed: %v", err) } for _, chain := range chains { - chainRules, err := conn.GetRules(chain.Table, chain) - if err != nil { - t.Fatalf("conn.GetRules() failed: %v", err) - } - if len(chainRules) != 0 { - t.Fatalf("len(chainRules) = %d, want 0", len(chainRules)) - } + checkChainRules(t, conn, chain, 0) } } @@ -790,36 +751,36 @@ func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nf } func TestNFTAddAndDelLoopbackRule(t *testing.T) { - if os.Geteuid() != 0 { - t.Skip(t.Name(), " requires privileges to create a namespace in order to run") - return - } - conn := newSysConn(t) runner := newFakeNftablesRunner(t, conn) runner.AddChains() defer runner.DelChains() - runner.AddBase("testTunn") - defer runner.DelBase() - - addr := netip.MustParseAddr("192.168.0.2") - addrV6 := netip.MustParseAddr("2001:db8::2") - runner.AddLoopbackRule(addr) - runner.AddLoopbackRule(addrV6) inputV4, _, _, err := getTsChains(conn, nftables.TableFamilyIPv4) if err != nil { t.Fatalf("getTsChains() failed: %v", err) } - inputV4Rules, err := conn.GetRules(runner.nft4.Filter, inputV4) + inputV6, _, _, err := getTsChains(conn, nftables.TableFamilyIPv6) if err != nil { - t.Fatalf("conn.GetRules() failed: %v", err) - } - if len(inputV4Rules) != 3 { - t.Fatalf("len(inputV4Rules) = %d, want 3", len(inputV4Rules)) + t.Fatalf("getTsChains() failed: %v", err) } + checkChainRules(t, conn, inputV4, 0) + checkChainRules(t, conn, inputV6, 0) + + runner.AddBase("testTunn") + defer runner.DelBase() + checkChainRules(t, conn, inputV4, 3) + checkChainRules(t, conn, inputV6, 3) + + addr := netip.MustParseAddr("192.168.0.2") + addrV6 := netip.MustParseAddr("2001:db8::2") + runner.AddLoopbackRule(addr) + runner.AddLoopbackRule(addrV6) + + checkChainRules(t, conn, inputV4, 4) + checkChainRules(t, conn, inputV6, 4) existingLoopBackRule, err := findLoopBackRule(conn, nftables.TableFamilyIPv4, runner.nft4.Filter, inputV4, addr) if err != nil { @@ -830,19 +791,6 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) { t.Fatalf("existingLoopBackRule.Handle = %d, want 0", existingLoopBackRule.Handle) } - inputV6, _, _, err := getTsChains(conn, nftables.TableFamilyIPv6) - if err != nil { - t.Fatalf("getTsChains() failed: %v", err) - } - - inputV6Rules, err := conn.GetRules(runner.nft6.Filter, inputV4) - if err != nil { - t.Fatalf("conn.GetRules() failed: %v", err) - } - if len(inputV6Rules) != 1 { - t.Fatalf("len(inputV4Rules) = %d, want 1", len(inputV4Rules)) - } - existingLoopBackRuleV6, err := findLoopBackRule(conn, nftables.TableFamilyIPv6, runner.nft6.Filter, inputV6, addrV6) if err != nil { t.Fatalf("findLoopBackRule() failed: %v", err) @@ -855,21 +803,11 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) { runner.DelLoopbackRule(addr) runner.DelLoopbackRule(addrV6) - inputV4Rules, err = conn.GetRules(runner.nft4.Filter, inputV4) - if err != nil { - t.Fatalf("conn.GetRules() failed: %v", err) - } - if len(inputV4Rules) != 2 { - t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules)) - } + checkChainRules(t, conn, inputV4, 3) + checkChainRules(t, conn, inputV6, 3) } func TestNFTAddAndDelHookRule(t *testing.T) { - if os.Geteuid() != 0 { - t.Skip(t.Name(), " requires privileges to create a namespace in order to run") - return - } - conn := newSysConn(t) runner := newFakeNftablesRunner(t, conn) runner.AddChains() @@ -880,72 +818,24 @@ func TestNFTAddAndDelHookRule(t *testing.T) { if err != nil { t.Fatalf("failed to get forwardChain: %v", err) } - - forwardChainRules, err := conn.GetRules(forwardChain.Table, forwardChain) - if err != nil { - t.Fatalf("failed to get rules: %v", err) - } - - if len(forwardChainRules) != 1 { - t.Fatalf("expected 1 rule in FORWARD chain, got %v", len(forwardChainRules)) - } - inputChain, err := getChainFromTable(conn, runner.nft4.Filter, "INPUT") if err != nil { t.Fatalf("failed to get inputChain: %v", err) } - - inputChainRules, err := conn.GetRules(inputChain.Table, inputChain) - if err != nil { - t.Fatalf("failed to get rules: %v", err) - } - - if len(inputChainRules) != 1 { - t.Fatalf("expected 1 rule in INPUT chain, got %v", len(inputChainRules)) - } - postroutingChain, err := getChainFromTable(conn, runner.nft4.Nat, "POSTROUTING") if err != nil { t.Fatalf("failed to get postroutingChain: %v", err) } - postroutingChainRules, err := conn.GetRules(postroutingChain.Table, postroutingChain) - if err != nil { - t.Fatalf("failed to get rules: %v", err) - } - - if len(postroutingChainRules) != 1 { - t.Fatalf("expected 1 rule in POSTROUTING chain, got %v", len(postroutingChainRules)) - } + checkChainRules(t, conn, forwardChain, 1) + checkChainRules(t, conn, inputChain, 1) + checkChainRules(t, conn, postroutingChain, 1) runner.DelHooks(t.Logf) - forwardChainRules, err = conn.GetRules(forwardChain.Table, forwardChain) - if err != nil { - t.Fatalf("failed to get rules: %v", err) - } - - if len(forwardChainRules) != 0 { - t.Fatalf("expected 0 rule in FORWARD chain, got %v", len(forwardChainRules)) - } - - inputChainRules, err = conn.GetRules(inputChain.Table, inputChain) - if err != nil { - t.Fatalf("failed to get rules: %v", err) - } - - if len(inputChainRules) != 0 { - t.Fatalf("expected 0 rule in INPUT chain, got %v", len(inputChainRules)) - } - - postroutingChainRules, err = conn.GetRules(postroutingChain.Table, postroutingChain) - if err != nil { - t.Fatalf("failed to get rules: %v", err) - } - - if len(postroutingChainRules) != 0 { - t.Fatalf("expected 0 rule in POSTROUTING chain, got %v", len(postroutingChainRules)) - } + checkChainRules(t, conn, forwardChain, 0) + checkChainRules(t, conn, inputChain, 0) + checkChainRules(t, conn, postroutingChain, 0) } type testFWDetector struct {