@ -27,32 +27,32 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) {
runner . EnsurePortMapRuleForSvc ( "foo" , "tailscale0" , ipv4 , pmTCP )
runner . EnsurePortMapRuleForSvc ( "foo" , "tailscale0" , ipv4 , pmTCP )
svcChains ( t , 1 , conn )
svcChains ( t , 1 , conn )
chainRuleCount ( t , "foo" , 1 , conn , nftables . TableFamilyIPv4 )
chainRuleCount ( t , "foo" , 1 , conn , nftables . TableFamilyIPv4 )
ch ain Rule( t , "foo" , ipv4 , pmTCP , runner , nftables . TableFamilyIPv4 )
ch eckPortMap Rule( t , "foo" , ipv4 , pmTCP , runner , nftables . TableFamilyIPv4 )
// Create another rule for service 'foo' to forward TCP traffic to the
// Create another rule for service 'foo' to forward TCP traffic to the
// same IPv4 endpoint, but to a different port.
// same IPv4 endpoint, but to a different port.
runner . EnsurePortMapRuleForSvc ( "foo" , "tailscale0" , ipv4 , pmTCP1 )
runner . EnsurePortMapRuleForSvc ( "foo" , "tailscale0" , ipv4 , pmTCP1 )
svcChains ( t , 1 , conn )
svcChains ( t , 1 , conn )
chainRuleCount ( t , "foo" , 2 , conn , nftables . TableFamilyIPv4 )
chainRuleCount ( t , "foo" , 2 , conn , nftables . TableFamilyIPv4 )
ch ain Rule( t , "foo" , ipv4 , pmTCP1 , runner , nftables . TableFamilyIPv4 )
ch eckPortMap Rule( t , "foo" , ipv4 , pmTCP1 , runner , nftables . TableFamilyIPv4 )
// Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint
// Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint
runner . EnsurePortMapRuleForSvc ( "foo" , "tailscale0" , ipv6 , pmTCP )
runner . EnsurePortMapRuleForSvc ( "foo" , "tailscale0" , ipv6 , pmTCP )
svcChains ( t , 2 , conn )
svcChains ( t , 2 , conn )
chainRuleCount ( t , "foo" , 1 , conn , nftables . TableFamilyIPv6 )
chainRuleCount ( t , "foo" , 1 , conn , nftables . TableFamilyIPv6 )
ch ain Rule( t , "foo" , ipv6 , pmTCP , runner , nftables . TableFamilyIPv6 )
ch eckPortMap Rule( t , "foo" , ipv6 , pmTCP , runner , nftables . TableFamilyIPv6 )
// Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint
// Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint
runner . EnsurePortMapRuleForSvc ( "bar" , "tailscale0" , ipv4 , pmTCP )
runner . EnsurePortMapRuleForSvc ( "bar" , "tailscale0" , ipv4 , pmTCP )
svcChains ( t , 3 , conn )
svcChains ( t , 3 , conn )
chainRuleCount ( t , "bar" , 1 , conn , nftables . TableFamilyIPv4 )
chainRuleCount ( t , "bar" , 1 , conn , nftables . TableFamilyIPv4 )
ch ain Rule( t , "bar" , ipv4 , pmTCP , runner , nftables . TableFamilyIPv4 )
ch eckPortMap Rule( t , "bar" , ipv4 , pmTCP , runner , nftables . TableFamilyIPv4 )
// Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint
// Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint
runner . EnsurePortMapRuleForSvc ( "bar" , "tailscale0" , ipv6 , pmTCP )
runner . EnsurePortMapRuleForSvc ( "bar" , "tailscale0" , ipv6 , pmTCP )
svcChains ( t , 4 , conn )
svcChains ( t , 4 , conn )
chainRuleCount ( t , "bar" , 1 , conn , nftables . TableFamilyIPv6 )
chainRuleCount ( t , "bar" , 1 , conn , nftables . TableFamilyIPv6 )
ch ain Rule( t , "bar" , ipv6 , pmTCP , runner , nftables . TableFamilyIPv6 )
ch eckPortMap Rule( t , "bar" , ipv6 , pmTCP , runner , nftables . TableFamilyIPv6 )
// Delete service bar
// Delete service bar
runner . DeleteSvc ( "bar" , "tailscale0" , [ ] netip . Addr { ipv4 , ipv6 } , [ ] PortMap { pmTCP } )
runner . DeleteSvc ( "bar" , "tailscale0" , [ ] netip . Addr { ipv4 , ipv6 } , [ ] PortMap { pmTCP } )
@ -95,36 +95,26 @@ func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) {
}
}
}
}
// chainRuleCount returns number of rules in a chain identified by service name and IP family .
// chainRuleCount verifies that the named chain in the given table contains the provided number of rules .
func chainRuleCount ( t * testing . T , svc string , count int , conn * nftables . Conn , fam nftables . TableFamily ) {
func chainRuleCount ( t * testing . T , name string , numOfRules int , conn * nftables . Conn , fam nftables . TableFamily ) {
t . Helper ( )
t . Helper ( )
chains , err := conn . ListChainsOfTableFamily ( fam )
chains , err := conn . ListChainsOfTableFamily ( fam )
if err != nil {
if err != nil {
t . Fatalf ( "error listing chains: %v" , err )
t . Fatalf ( "error listing chains: %v" , err )
}
}
found := false
for _ , ch := range chains {
for _ , ch := range chains {
if ch . Name == svc {
if ch . Name == name {
found = true
checkChainRules ( t , conn , ch , numOfRules )
rules , err := conn . GetRules ( ch . Table , ch )
return
if err != nil {
t . Fatalf ( "error getting rules: %v" , err )
}
if len ( rules ) != count {
t . Fatalf ( "unexpected number of rules, wants %d got %d" , count , len ( rules ) )
}
break
}
}
}
}
if ! found {
t . Fatalf ( "chain %s does not exist" , name )
t . Fatalf ( "chain for service %s does not exist" , svc )
}
}
}
// ch ain Rule verifies that rule for the provided target IP and PortMap exists in
// checkPortMapRule verifies that rule for the provided target IP and PortMap exists in a chain identified by service
// a chain identified by service name and IP family.
// name and IP family.
func ch ain Rule( t * testing . T , svc string , targetIP netip . Addr , pm PortMap , runner * nftablesRunner , fam nftables . TableFamily ) {
func ch eckPortMap Rule( t * testing . T , svc string , targetIP netip . Addr , pm PortMap , runner * nftablesRunner , fam nftables . TableFamily ) {
t . Helper ( )
t . Helper ( )
chains , err := runner . conn . ListChainsOfTableFamily ( fam )
chains , err := runner . conn . ListChainsOfTableFamily ( fam )
if err != nil {
if err != nil {
@ -146,11 +136,17 @@ func chainRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner
t . Fatalf ( "error converting protocol: %v" , err )
t . Fatalf ( "error converting protocol: %v" , err )
}
}
wantsRule := portMapRule ( chain . Table , chain , "tailscale0" , targetIP , pm . MatchPort , pm . TargetPort , p , meta )
wantsRule := portMapRule ( chain . Table , chain , "tailscale0" , targetIP , pm . MatchPort , pm . TargetPort , p , meta )
gotRule , err := findRule ( runner . conn , wantsRule )
checkRule ( t , wantsRule , runner . conn )
}
// checkRule checks that the provided rules exists.
func checkRule ( t * testing . T , rule * nftables . Rule , conn * nftables . Conn ) {
t . Helper ( )
gotRule , err := findRule ( conn , rule )
if err != nil {
if err != nil {
t . Fatalf ( "error looking up rule: %v" , err )
t . Fatalf ( "error looking up rule: %v" , err )
}
}
if gotRule == nil {
if gotRule == nil {
t . Fatalf ( "rule not found" )
t . Fatal ( "rule not found" )
}
}
}
}