@ -474,6 +474,10 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) {
func newSysConn ( t * testing . T ) * nftables . Conn {
func newSysConn ( t * testing . T ) * nftables . Conn {
t . Helper ( )
t . Helper ( )
if os . Geteuid ( ) != 0 {
t . Skip ( t . Name ( ) , " requires privileges to create a namespace in order to run" )
return nil
}
runtime . LockOSThread ( )
runtime . LockOSThread ( )
@ -512,12 +516,21 @@ func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner {
}
}
}
}
func TestAddAndDelNetfilterChains ( t * testing . T ) {
func checkChains ( t * testing . T , conn * nftables . Conn , fam nftables . TableFamily , wantCount int ) {
if os . Geteuid ( ) != 0 {
t . Helper ( )
t . Skip ( t . Name ( ) , " requires privileges to create a namespace in order to run" )
got , err := conn . ListChainsOfTableFamily ( fam )
return
if err != nil {
t . Fatalf ( "conn.ListChainsOfTableFamily(%v) failed: %v" , fam , err )
}
if len ( got ) != wantCount {
t . Fatalf ( "len(got) = %d, want %d" , len ( got ) , wantCount )
}
}
}
func TestAddAndDelNetfilterChains ( t * testing . T ) {
conn := newSysConn ( t )
conn := newSysConn ( t )
checkChains ( t , conn , nftables . TableFamilyIPv4 , 0 )
checkChains ( t , conn , nftables . TableFamilyIPv6 , 0 )
runner := newFakeNftablesRunner ( t , conn )
runner := newFakeNftablesRunner ( t , conn )
runner . AddChains ( )
runner . AddChains ( )
@ -531,33 +544,22 @@ func TestAddAndDelNetfilterChains(t *testing.T) {
t . Fatalf ( "len(tables) = %d, want 4" , len ( tables ) )
t . Fatalf ( "len(tables) = %d, want 4" , len ( tables ) )
}
}
chainsV4 , err := conn . ListChainsOfTableFamily ( nftables . TableFamilyIPv4 )
checkChains ( t , conn , nftables . TableFamilyIPv4 , 6 )
if err != nil {
checkChains ( t , conn , nftables . TableFamilyIPv6 , 6 )
t . Fatalf ( "list chains failed: %v" , err )
}
if len ( chainsV4 ) != 6 {
t . Fatalf ( "len(chainsV4) = %d, want 6" , len ( chainsV4 ) )
}
chainsV6 , err := conn . ListChainsOfTableFamily ( nftables . TableFamilyIPv6 )
if err != nil {
t . Fatalf ( "list chains failed: %v" , err )
}
if len ( chainsV6 ) != 6 {
t . Fatalf ( "len(chainsV6) = %d, want 6" , len ( chainsV6 ) )
}
runner . DelChains ( )
runner . DelChains ( )
// The default chains should still be present.
checkChains ( t , conn , nftables . TableFamilyIPv4 , 3 )
checkChains ( t , conn , nftables . TableFamilyIPv6 , 3 )
tables , err = conn . ListTables ( )
tables , err = conn . ListTables ( )
if err != nil {
if err != nil {
t . Fatalf ( "conn.ListTables() failed: %v" , err )
t . Fatalf ( "conn.ListTables() failed: %v" , err )
}
}
if len ( tables ) != 0 {
if len ( tables ) != 4 {
t . Fatalf ( "len(tables) = %d, want 0 ", len ( tables ) )
t . Fatalf ( "len(tables) = %d, want 4" , len ( tables ) )
}
}
}
}
@ -646,12 +648,19 @@ func findCommonBaseRules(
return get , nil
return get , nil
}
}
func TestNFTAddAndDelNetfilterBase ( t * testing . T ) {
// checkChainRules verifies that the chain has the expected number of rules.
if os . Geteuid ( ) != 0 {
func checkChainRules ( t * testing . T , conn * nftables . Conn , chain * nftables . Chain , wantCount int ) {
t . Skip ( t . Name ( ) , " requires privileges to create a namespace in order to run" )
t . Helper ( )
return
got , err := conn . GetRules ( chain . Table , chain )
if err != nil {
t . Fatalf ( "conn.GetRules() failed: %v" , err )
}
if len ( got ) != wantCount {
t . Fatalf ( "got = %d, want %d" , len ( got ) , wantCount )
}
}
}
func TestNFTAddAndDelNetfilterBase ( t * testing . T ) {
conn := newSysConn ( t )
conn := newSysConn ( t )
runner := newFakeNftablesRunner ( t , conn )
runner := newFakeNftablesRunner ( t , conn )
@ -664,30 +673,9 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) {
if err != nil {
if err != nil {
t . Fatalf ( "getTsChains() failed: %v" , err )
t . Fatalf ( "getTsChains() failed: %v" , err )
}
}
checkChainRules ( t , conn , inputV4 , 3 )
inputV4Rules , err := conn . GetRules ( runner . nft4 . Filter , inputV4 )
checkChainRules ( t , conn , forwardV4 , 4 )
if err != nil {
checkChainRules ( t , conn , postroutingV4 , 0 )
t . Fatalf ( "conn.GetRules() failed: %v" , err )
}
if len ( inputV4Rules ) != 2 {
t . Fatalf ( "len(inputV4Rules) = %d, want 2" , len ( inputV4Rules ) )
}
forwardV4Rules , err := conn . GetRules ( runner . nft4 . Filter , forwardV4 )
if err != nil {
t . Fatalf ( "conn.GetRules() failed: %v" , err )
}
if len ( forwardV4Rules ) != 4 {
t . Fatalf ( "len(forwardV4Rules) = %d, want 4" , len ( forwardV4Rules ) )
}
postroutingV4Rules , err := conn . GetRules ( runner . nft4 . Nat , postroutingV4 )
if err != nil {
t . Fatalf ( "conn.GetRules() failed: %v" , err )
}
if len ( postroutingV4Rules ) != 0 {
t . Fatalf ( "len(postroutingV4Rules) = %d, want 0" , len ( postroutingV4Rules ) )
}
_ , err = findV4BaseRules ( conn , inputV4 , forwardV4 , "testTunn" )
_ , err = findV4BaseRules ( conn , inputV4 , forwardV4 , "testTunn" )
if err != nil {
if err != nil {
@ -703,30 +691,9 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) {
if err != nil {
if err != nil {
t . Fatalf ( "getTsChains() failed: %v" , err )
t . Fatalf ( "getTsChains() failed: %v" , err )
}
}
checkChainRules ( t , conn , inputV6 , 3 )
inputV6Rules , err := conn . GetRules ( runner . nft6 . Filter , inputV6 )
checkChainRules ( t , conn , forwardV6 , 4 )
if err != nil {
checkChainRules ( t , conn , postroutingV6 , 0 )
t . Fatalf ( "conn.GetRules() failed: %v" , err )
}
if len ( inputV6Rules ) != 0 {
t . Fatalf ( "len(inputV6Rules) = %d, want 0" , len ( inputV4Rules ) )
}
forwardV6Rules , err := conn . GetRules ( runner . nft6 . Filter , forwardV6 )
if err != nil {
t . Fatalf ( "conn.GetRules() failed: %v" , err )
}
if len ( forwardV6Rules ) != 3 {
t . Fatalf ( "len(forwardV6Rules) = %d, want 3" , len ( forwardV4Rules ) )
}
postroutingV6Rules , err := conn . GetRules ( runner . nft6 . Nat , postroutingV6 )
if err != nil {
t . Fatalf ( "conn.GetRules() failed: %v" , err )
}
if len ( postroutingV6Rules ) != 0 {
t . Fatalf ( "len(postroutingV6Rules) = %d, want 0" , len ( postroutingV4Rules ) )
}
_ , err = findCommonBaseRules ( conn , forwardV6 , "testTunn" )
_ , err = findCommonBaseRules ( conn , forwardV6 , "testTunn" )
if err != nil {
if err != nil {
@ -740,13 +707,7 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) {
t . Fatalf ( "conn.ListChains() failed: %v" , err )
t . Fatalf ( "conn.ListChains() failed: %v" , err )
}
}
for _ , chain := range chains {
for _ , chain := range chains {
chainRules , err := conn . GetRules ( chain . Table , chain )
checkChainRules ( t , conn , chain , 0 )
if err != nil {
t . Fatalf ( "conn.GetRules() failed: %v" , err )
}
if len ( chainRules ) != 0 {
t . Fatalf ( "len(chainRules) = %d, want 0" , len ( chainRules ) )
}
}
}
}
}
@ -790,36 +751,36 @@ func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nf
}
}
func TestNFTAddAndDelLoopbackRule ( t * testing . T ) {
func TestNFTAddAndDelLoopbackRule ( t * testing . T ) {
if os . Geteuid ( ) != 0 {
t . Skip ( t . Name ( ) , " requires privileges to create a namespace in order to run" )
return
}
conn := newSysConn ( t )
conn := newSysConn ( t )
runner := newFakeNftablesRunner ( t , conn )
runner := newFakeNftablesRunner ( t , conn )
runner . AddChains ( )
runner . AddChains ( )
defer runner . DelChains ( )
defer runner . DelChains ( )
runner . AddBase ( "testTunn" )
defer runner . DelBase ( )
addr := netip . MustParseAddr ( "192.168.0.2" )
addrV6 := netip . MustParseAddr ( "2001:db8::2" )
runner . AddLoopbackRule ( addr )
runner . AddLoopbackRule ( addrV6 )
inputV4 , _ , _ , err := getTsChains ( conn , nftables . TableFamilyIPv4 )
inputV4 , _ , _ , err := getTsChains ( conn , nftables . TableFamilyIPv4 )
if err != nil {
if err != nil {
t . Fatalf ( "getTsChains() failed: %v" , err )
t . Fatalf ( "getTsChains() failed: %v" , err )
}
}
inputV 4Rules, err := conn . GetRules ( runner . nft4 . Filter , inputV4 )
inputV6 , _ , _ , err := getTsChains ( conn , nftables . TableFamilyIPv6 )
if err != nil {
if err != nil {
t . Fatalf ( "conn.GetRules() failed: %v" , err )
t . Fatalf ( "getTsChains() failed: %v" , err )
}
if len ( inputV4Rules ) != 3 {
t . Fatalf ( "len(inputV4Rules) = %d, want 3" , len ( inputV4Rules ) )
}
}
checkChainRules ( t , conn , inputV4 , 0 )
checkChainRules ( t , conn , inputV6 , 0 )
runner . AddBase ( "testTunn" )
defer runner . DelBase ( )
checkChainRules ( t , conn , inputV4 , 3 )
checkChainRules ( t , conn , inputV6 , 3 )
addr := netip . MustParseAddr ( "192.168.0.2" )
addrV6 := netip . MustParseAddr ( "2001:db8::2" )
runner . AddLoopbackRule ( addr )
runner . AddLoopbackRule ( addrV6 )
checkChainRules ( t , conn , inputV4 , 4 )
checkChainRules ( t , conn , inputV6 , 4 )
existingLoopBackRule , err := findLoopBackRule ( conn , nftables . TableFamilyIPv4 , runner . nft4 . Filter , inputV4 , addr )
existingLoopBackRule , err := findLoopBackRule ( conn , nftables . TableFamilyIPv4 , runner . nft4 . Filter , inputV4 , addr )
if err != nil {
if err != nil {
@ -830,19 +791,6 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {
t . Fatalf ( "existingLoopBackRule.Handle = %d, want 0" , existingLoopBackRule . Handle )
t . Fatalf ( "existingLoopBackRule.Handle = %d, want 0" , existingLoopBackRule . Handle )
}
}
inputV6 , _ , _ , err := getTsChains ( conn , nftables . TableFamilyIPv6 )
if err != nil {
t . Fatalf ( "getTsChains() failed: %v" , err )
}
inputV6Rules , err := conn . GetRules ( runner . nft6 . Filter , inputV4 )
if err != nil {
t . Fatalf ( "conn.GetRules() failed: %v" , err )
}
if len ( inputV6Rules ) != 1 {
t . Fatalf ( "len(inputV4Rules) = %d, want 1" , len ( inputV4Rules ) )
}
existingLoopBackRuleV6 , err := findLoopBackRule ( conn , nftables . TableFamilyIPv6 , runner . nft6 . Filter , inputV6 , addrV6 )
existingLoopBackRuleV6 , err := findLoopBackRule ( conn , nftables . TableFamilyIPv6 , runner . nft6 . Filter , inputV6 , addrV6 )
if err != nil {
if err != nil {
t . Fatalf ( "findLoopBackRule() failed: %v" , err )
t . Fatalf ( "findLoopBackRule() failed: %v" , err )
@ -855,21 +803,11 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {
runner . DelLoopbackRule ( addr )
runner . DelLoopbackRule ( addr )
runner . DelLoopbackRule ( addrV6 )
runner . DelLoopbackRule ( addrV6 )
inputV4Rules , err = conn . GetRules ( runner . nft4 . Filter , inputV4 )
checkChainRules ( t , conn , inputV4 , 3 )
if err != nil {
checkChainRules ( t , conn , inputV6 , 3 )
t . Fatalf ( "conn.GetRules() failed: %v" , err )
}
if len ( inputV4Rules ) != 2 {
t . Fatalf ( "len(inputV4Rules) = %d, want 2" , len ( inputV4Rules ) )
}
}
}
func TestNFTAddAndDelHookRule ( t * testing . T ) {
func TestNFTAddAndDelHookRule ( t * testing . T ) {
if os . Geteuid ( ) != 0 {
t . Skip ( t . Name ( ) , " requires privileges to create a namespace in order to run" )
return
}
conn := newSysConn ( t )
conn := newSysConn ( t )
runner := newFakeNftablesRunner ( t , conn )
runner := newFakeNftablesRunner ( t , conn )
runner . AddChains ( )
runner . AddChains ( )
@ -880,72 +818,24 @@ func TestNFTAddAndDelHookRule(t *testing.T) {
if err != nil {
if err != nil {
t . Fatalf ( "failed to get forwardChain: %v" , err )
t . Fatalf ( "failed to get forwardChain: %v" , err )
}
}
forwardChainRules , err := conn . GetRules ( forwardChain . Table , forwardChain )
if err != nil {
t . Fatalf ( "failed to get rules: %v" , err )
}
if len ( forwardChainRules ) != 1 {
t . Fatalf ( "expected 1 rule in FORWARD chain, got %v" , len ( forwardChainRules ) )
}
inputChain , err := getChainFromTable ( conn , runner . nft4 . Filter , "INPUT" )
inputChain , err := getChainFromTable ( conn , runner . nft4 . Filter , "INPUT" )
if err != nil {
if err != nil {
t . Fatalf ( "failed to get inputChain: %v" , err )
t . Fatalf ( "failed to get inputChain: %v" , err )
}
}
inputChainRules , err := conn . GetRules ( inputChain . Table , inputChain )
if err != nil {
t . Fatalf ( "failed to get rules: %v" , err )
}
if len ( inputChainRules ) != 1 {
t . Fatalf ( "expected 1 rule in INPUT chain, got %v" , len ( inputChainRules ) )
}
postroutingChain , err := getChainFromTable ( conn , runner . nft4 . Nat , "POSTROUTING" )
postroutingChain , err := getChainFromTable ( conn , runner . nft4 . Nat , "POSTROUTING" )
if err != nil {
if err != nil {
t . Fatalf ( "failed to get postroutingChain: %v" , err )
t . Fatalf ( "failed to get postroutingChain: %v" , err )
}
}
postroutingChainRules , err := conn . GetRules ( postroutingChain . Table , postroutingChain )
checkChainRules ( t , conn , forwardChain , 1 )
if err != nil {
checkChainRules ( t , conn , inputChain , 1 )
t . Fatalf ( "failed to get rules: %v" , err )
checkChainRules ( t , conn , postroutingChain , 1 )
}
if len ( postroutingChainRules ) != 1 {
t . Fatalf ( "expected 1 rule in POSTROUTING chain, got %v" , len ( postroutingChainRules ) )
}
runner . DelHooks ( t . Logf )
runner . DelHooks ( t . Logf )
forwardChainRules , err = conn . GetRules ( forwardChain . Table , forwardChain )
checkChainRules ( t , conn , forwardChain , 0 )
if err != nil {
checkChainRules ( t , conn , inputChain , 0 )
t . Fatalf ( "failed to get rules: %v" , err )
checkChainRules ( t , conn , postroutingChain , 0 )
}
if len ( forwardChainRules ) != 0 {
t . Fatalf ( "expected 0 rule in FORWARD chain, got %v" , len ( forwardChainRules ) )
}
inputChainRules , err = conn . GetRules ( inputChain . Table , inputChain )
if err != nil {
t . Fatalf ( "failed to get rules: %v" , err )
}
if len ( inputChainRules ) != 0 {
t . Fatalf ( "expected 0 rule in INPUT chain, got %v" , len ( inputChainRules ) )
}
postroutingChainRules , err = conn . GetRules ( postroutingChain . Table , postroutingChain )
if err != nil {
t . Fatalf ( "failed to get rules: %v" , err )
}
if len ( postroutingChainRules ) != 0 {
t . Fatalf ( "expected 0 rule in POSTROUTING chain, got %v" , len ( postroutingChainRules ) )
}
}
}
type testFWDetector struct {
type testFWDetector struct {