@ -20,6 +20,8 @@ import (
"github.com/mdlayher/netlink"
"github.com/vishvananda/netns"
"tailscale.com/net/tsaddr"
"tailscale.com/tstest"
"tailscale.com/types/logger"
)
// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing
@ -503,19 +505,6 @@ func cleanupSysConn(t *testing.T, ns netns.NsHandle) {
}
}
func newFakeNftablesRunner ( t * testing . T , conn * nftables . Conn ) * nftablesRunner {
nft4 := & nftable { Proto : nftables . TableFamilyIPv4 }
nft6 := & nftable { Proto : nftables . TableFamilyIPv6 }
return & nftablesRunner {
conn : conn ,
nft4 : nft4 ,
nft6 : nft6 ,
v6Available : true ,
v6NATAvailable : true ,
}
}
func checkChains ( t * testing . T , conn * nftables . Conn , fam nftables . TableFamily , wantCount int ) {
t . Helper ( )
got , err := conn . ListChainsOfTableFamily ( fam )
@ -526,42 +515,76 @@ func checkChains(t *testing.T, conn *nftables.Conn, fam nftables.TableFamily, wa
t . Fatalf ( "len(got) = %d, want %d" , len ( got ) , wantCount )
}
}
func checkTables ( t * testing . T , conn * nftables . Conn , fam nftables . TableFamily , wantCount int ) {
t . Helper ( )
got , err := conn . ListTablesOfFamily ( fam )
if err != nil {
t . Fatalf ( "conn.ListTablesOfFamily(%v) failed: %v" , fam , err )
}
if len ( got ) != wantCount {
t . Fatalf ( "len(got) = %d, want %d" , len ( got ) , wantCount )
}
}
func TestAddAndDelNetfilterChains ( t * testing . T ) {
type test struct {
hostHasIPv6 bool
initIPv4ChainCount int
initIPv6ChainCount int
ipv4TableCount int
ipv6TableCount int
ipv4ChainCount int
ipv6ChainCount int
ipv4ChainCountPostDelete int
ipv6ChainCountPostDelete int
}
tests := [ ] test {
{
hostHasIPv6 : true ,
initIPv4ChainCount : 0 ,
initIPv6ChainCount : 0 ,
ipv4TableCount : 2 ,
ipv6TableCount : 2 ,
ipv4ChainCount : 6 ,
ipv6ChainCount : 6 ,
ipv4ChainCountPostDelete : 3 ,
ipv6ChainCountPostDelete : 3 ,
} ,
{ // host without IPv6 support
ipv4TableCount : 2 ,
ipv4ChainCount : 6 ,
ipv4ChainCountPostDelete : 3 ,
} }
for _ , tt := range tests {
t . Logf ( "running a test case for IPv6 support: %v" , tt . hostHasIPv6 )
conn := newSysConn ( t )
checkChains ( t , conn , nftables . TableFamilyIPv4 , 0 )
checkChains ( t , conn , nftables . TableFamilyIPv6 , 0 )
runner := newFakeNftablesRunnerWithConn ( t , conn , tt . hostHasIPv6 )
// Check that we start off with no chains.
checkChains ( t , conn , nftables . TableFamilyIPv4 , tt . initIPv4ChainCount )
checkChains ( t , conn , nftables . TableFamilyIPv6 , tt . initIPv6ChainCount )
runner := newFakeNftablesRunner ( t , conn )
if err := runner . AddChains ( ) ; err != nil {
t . Fatalf ( "runner.AddChains() failed: %v" , err )
}
tables , err := conn . ListTables ( )
if err != nil {
t . Fatalf ( "conn.ListTables() failed: %v" , err )
}
if len ( tables ) != 4 {
t . Fatalf ( "len(tables) = %d, want 4" , len ( tables ) )
}
// Check that the amount of tables for each IP family is as expected.
checkTables ( t , conn , nftables . TableFamilyIPv4 , tt . ipv4TableCount )
checkTables ( t , conn , nftables . TableFamilyIPv6 , tt . ipv6TableCount )
checkChains ( t , conn , nftables . TableFamilyIPv4 , 6 )
checkChains ( t , conn , nftables . TableFamilyIPv6 , 6 )
// Check that the amount of chains for each IP family is as expected.
checkChains ( t , conn , nftables . TableFamilyIPv4 , tt . ipv4ChainCount )
checkChains ( t , conn , nftables . TableFamilyIPv6 , tt . ipv6ChainCount )
runner . DelChains ( )
// The default chains should still be present.
checkChains ( t , conn , nftables . TableFamilyIPv4 , 3 )
checkChains ( t , conn , nftables . TableFamilyIPv6 , 3 )
tables , err = conn . ListTables ( )
if err != nil {
t . Fatalf ( "conn.ListTables() failed: %v" , err )
if err := runner . DelChains ( ) ; err != nil {
t . Fatalf ( "runner.DelChains() failed: %v" , err )
}
if len ( tables ) != 4 {
t . Fatalf ( "len(tables) = %d, want 4" , len ( tables ) )
// Test that the tables as well as the default chains are still present.
checkChains ( t , conn , nftables . TableFamilyIPv4 , tt . ipv4ChainCountPostDelete )
checkChains ( t , conn , nftables . TableFamilyIPv6 , tt . ipv6ChainCountPostDelete )
checkTables ( t , conn , nftables . TableFamilyIPv4 , tt . ipv4TableCount )
checkTables ( t , conn , nftables . TableFamilyIPv6 , tt . ipv6TableCount )
}
}
@ -665,7 +688,8 @@ func checkChainRules(t *testing.T, conn *nftables.Conn, chain *nftables.Chain, w
func TestNFTAddAndDelNetfilterBase ( t * testing . T ) {
conn := newSysConn ( t )
runner := newFakeNftablesRunner ( t , conn )
runner := newFakeNftablesRunnerWithConn ( t , conn , true )
if err := runner . AddChains ( ) ; err != nil {
t . Fatalf ( "AddChains() failed: %v" , err )
}
@ -759,7 +783,7 @@ func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nf
func TestNFTAddAndDelLoopbackRule ( t * testing . T ) {
conn := newSysConn ( t )
runner := newFakeNftablesRunner ( t , conn )
runner := newFakeNftablesRunner WithConn ( t , conn , true )
if err := runner . AddChains ( ) ; err != nil {
t . Fatalf ( "AddChains() failed: %v" , err )
}
@ -817,7 +841,7 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) {
func TestNFTAddAndDelHookRule ( t * testing . T ) {
conn := newSysConn ( t )
runner := newFakeNftablesRunner ( t , conn )
runner := newFakeNftablesRunner WithConn ( t , conn , true )
if err := runner . AddChains ( ) ; err != nil {
t . Fatalf ( "AddChains() failed: %v" , err )
}
@ -868,11 +892,11 @@ func (t *testFWDetector) nftDetect() (int, error) {
// postrouting chains are cleaned up.
func TestCreateDummyPostroutingChains ( t * testing . T ) {
conn := newSysConn ( t )
runner := newFakeNftablesRunner ( t , conn )
runner := newFakeNftablesRunner WithConn ( t , conn , true )
if err := runner . createDummyPostroutingChains ( ) ; err != nil {
t . Fatalf ( "createDummyPostroutingChains() failed: %v" , err )
}
for _ , table := range runner . get NAT Tables( ) {
for _ , table := range runner . get Tables( ) {
nt , err := getTableIfExists ( conn , table . Proto , tsDummyTableName )
if err != nil {
t . Fatalf ( "getTableIfExists() failed: %v" , err )
@ -929,3 +953,14 @@ func TestPickFirewallModeFromInstalledRules(t *testing.T) {
} )
}
}
func newFakeNftablesRunnerWithConn ( t * testing . T , conn * nftables . Conn , hasIPv6 bool ) * nftablesRunner {
t . Helper ( )
if ! hasIPv6 {
tstest . Replace ( t , & checkIPv6ForTest , func ( logger . Logf ) error {
return errors . New ( "test: no IPv6" )
} )
}
return newNfTablesRunnerWithConn ( t . Logf , conn )
}