cmd/containerboot,util/linuxfw: create a SNAT rule for dst/src only once, clean up if needed (#13658)

The AddSNATRuleForDst rule was adding a new rule each time it was called including:
- if a rule already existed
- if a rule matching the destination, but with different desired source already existed

This was causing issues especially for the in-progress egress HA proxies work,
where the rules are now refreshed more frequently, so more redundant rules
were being created.

This change:
- only creates the rule if it doesn't already exist
- if a rule for the same dst, but different source is found, delete it
- also ensures that egress proxies refresh firewall rules
if the node's tailnet IP changes

Updates tailscale/tailscale#13406

Signed-off-by: Irbe Krumina <irbe@tailscale.com>
angott/doh-clients-sleep-mode
Irbe Krumina 3 weeks ago committed by GitHub
parent a3c6a3a34f
commit 9bd158cc09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -117,7 +117,7 @@ func installEgressForwardingRule(_ context.Context, dstStr string, tsIPs []netip
if err := nfr.DNATNonTailscaleTraffic("tailscale0", dst); err != nil {
return fmt.Errorf("installing egress proxy rules: %w", err)
}
if err := nfr.AddSNATRuleForDst(local, dst); err != nil {
if err := nfr.EnsureSNATForDst(local, dst); err != nil {
return fmt.Errorf("installing egress proxy rules: %w", err)
}
if err := nfr.ClampMSSToPMTU("tailscale0", dst); err != nil {

@ -481,7 +481,11 @@ runLoop:
egressAddrs = node.Addresses().AsSlice()
newCurentEgressIPs = deephash.Hash(&egressAddrs)
egressIPsHaveChanged = newCurentEgressIPs != currentEgressIPs
if egressIPsHaveChanged && len(egressAddrs) != 0 {
// The firewall rules get (re-)installed:
// - on startup
// - when the tailnet IPs of the tailnet target have changed
// - when the tailnet IPs of this node have changed
if (egressIPsHaveChanged || ipsHaveChanged) && len(egressAddrs) != 0 {
var rulesInstalled bool
for _, egressAddr := range egressAddrs {
ea := egressAddr.Addr()

@ -196,8 +196,7 @@ func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *e
if !local.IsValid() {
return nil, fmt.Errorf("no valid local IP: %v", local)
}
// TODO(irbekrm): only create the SNAT rule if it does not already exist.
if err := ep.nfr.AddSNATRuleForDst(local, t); err != nil {
if err := ep.nfr.EnsureSNATForDst(local, t); err != nil {
return nil, fmt.Errorf("error setting up SNAT rule: %w", err)
}
}

@ -9,6 +9,7 @@ import (
"bytes"
"errors"
"fmt"
"log"
"net/netip"
"os"
"os/exec"
@ -371,9 +372,42 @@ func (i *iptablesRunner) AddDNATRule(origDst, dst netip.Addr) error {
return table.Insert("nat", "PREROUTING", 1, "--destination", origDst.String(), "-j", "DNAT", "--to-destination", dst.String())
}
func (i *iptablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
// EnsureSNATForDst sets up firewall to ensure that all traffic aimed for dst, has its source ip set to src:
// - creates a SNAT rule if not already present
// - ensures that any no longer valid SNAT rules for the same dst are removed
func (i *iptablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
table := i.getIPTByAddr(dst)
return table.Insert("nat", "POSTROUTING", 1, "--destination", dst.String(), "-j", "SNAT", "--to-source", src.String())
rules, err := table.List("nat", "POSTROUTING")
if err != nil {
return fmt.Errorf("error listing rules: %v", err)
}
// iptables accept either address or a CIDR value for the --destination flag, but converts an address to /32
// CIDR. Explicitly passing a /32 CIDR made it possible to test this rule.
dstPrefix, err := dst.Prefix(32)
if err != nil {
return fmt.Errorf("error calculating prefix of dst %v: %v", dst, err)
}
// wantsArgsPrefix is the prefix of the SNAT rule for the provided destination.
// We should only have one POSTROUTING rule with this prefix.
wantsArgsPrefix := fmt.Sprintf("-d %s -j SNAT --to-source", dstPrefix.String())
// wantsArgs is the actual SNAT rule that we want.
wantsArgs := fmt.Sprintf("%s %s", wantsArgsPrefix, src.String())
for _, r := range rules {
args := argsFromPostRoutingRule(r)
if strings.HasPrefix(args, wantsArgsPrefix) {
if strings.HasPrefix(args, wantsArgs) {
return nil
}
// SNAT rule matching the destination, but for a different source - delete.
if err := table.Delete("nat", "POSTROUTING", strings.Split(args, " ")...); err != nil {
// If we failed to delete don't crash the node- the proxy should still be functioning.
log.Printf("[unexpected] error deleting rule %s: %v, please report it.", r, err)
}
break
}
}
return table.Insert("nat", "POSTROUTING", 1, "-d", dstPrefix.String(), "-j", "SNAT", "--to-source", src.String())
}
func (i *iptablesRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error {
@ -731,3 +765,10 @@ func clearRules(proto iptables.Protocol, logf logger.Logf) error {
return multierr.New(errs...)
}
// argsFromPostRoutingRule accepts a rule as returned by iptables.List and, if it is a rule from POSTROUTING chain,
// returns the args part, else returns the original rule.
func argsFromPostRoutingRule(r string) string {
args, _ := strings.CutPrefix(r, "-A POSTROUTING ")
return args
}

@ -289,3 +289,77 @@ func TestAddAndDelSNATRule(t *testing.T) {
t.Fatal(err)
}
}
func TestEnsureSNATForDst_ipt(t *testing.T) {
ip1, ip2, ip3 := netip.MustParseAddr("100.99.99.99"), netip.MustParseAddr("100.88.88.88"), netip.MustParseAddr("100.77.77.77")
iptr := NewFakeIPTablesRunner()
// 1. A new rule gets added
mustCreateSNATRule_ipt(t, iptr, ip1, ip2)
checkSNATRule_ipt(t, iptr, ip1, ip2)
checkSNATRuleCount(t, iptr, ip1, 1)
// 2. Another call to EnsureSNATForDst with the same src and dst does not result in another rule being added.
mustCreateSNATRule_ipt(t, iptr, ip1, ip2)
checkSNATRule_ipt(t, iptr, ip1, ip2)
checkSNATRuleCount(t, iptr, ip1, 1) // still just 1 rule
// 3. Another call to EnsureSNATForDst with a different src and the same dst results in the earlier rule being
// deleted.
mustCreateSNATRule_ipt(t, iptr, ip3, ip2)
checkSNATRule_ipt(t, iptr, ip3, ip2)
checkSNATRuleCount(t, iptr, ip1, 1) // still just 1 rule
// 4. Another call to EnsureSNATForDst with a different dst should not get the earlier rule deleted.
mustCreateSNATRule_ipt(t, iptr, ip3, ip1)
checkSNATRule_ipt(t, iptr, ip3, ip1)
checkSNATRuleCount(t, iptr, ip1, 2) // now 2 rules
// 5. A call to EnsureSNATForDst with a match dst and a match port should not get deleted by EnsureSNATForDst for the same dst.
args := []string{"--destination", ip1.String(), "-j", "SNAT", "--to-source", "10.0.0.1"}
if err := iptr.getIPTByAddr(ip1).Insert("nat", "POSTROUTING", 1, args...); err != nil {
t.Fatalf("error adding SNAT rule: %v", err)
}
exists, err := iptr.getIPTByAddr(ip1).Exists("nat", "POSTROUTING", args...)
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if !exists {
t.Fatalf("SNAT rule for destination and port unexpectedly deleted")
}
mustCreateSNATRule_ipt(t, iptr, ip3, ip1)
checkSNATRuleCount(t, iptr, ip1, 3) // now 3 rules
}
func mustCreateSNATRule_ipt(t *testing.T, iptr *iptablesRunner, src, dst netip.Addr) {
t.Helper()
if err := iptr.EnsureSNATForDst(src, dst); err != nil {
t.Fatalf("error ensuring SNAT rule: %v", err)
}
}
func checkSNATRule_ipt(t *testing.T, iptr *iptablesRunner, src, dst netip.Addr) {
t.Helper()
dstPrefix, err := dst.Prefix(32)
if err != nil {
t.Fatalf("error converting addr to prefix: %v", err)
}
exists, err := iptr.getIPTByAddr(src).Exists("nat", "POSTROUTING", "-d", dstPrefix.String(), "-j", "SNAT", "--to-source", src.String())
if err != nil {
t.Fatalf("error checking if rule exists: %v", err)
}
if !exists {
t.Fatalf("SNAT rule for src %s dst %s should exist, but it does not", src, dst)
}
}
func checkSNATRuleCount(t *testing.T, iptr *iptablesRunner, ip netip.Addr, wantsRules int) {
t.Helper()
rules, err := iptr.getIPTByAddr(ip).List("nat", "POSTROUTING")
if err != nil {
t.Fatalf("error listing rules: %v", err)
}
if len(rules) != wantsRules {
t.Fatalf("wants %d rules, got %d", wantsRules, len(rules))
}
}

@ -27,32 +27,32 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) {
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP)
svcChains(t, 1, conn)
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4)
chainRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
checkPortMapRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
// Create another rule for service 'foo' to forward TCP traffic to the
// same IPv4 endpoint, but to a different port.
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1)
svcChains(t, 1, conn)
chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4)
chainRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
checkPortMapRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4)
// Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint
runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP)
svcChains(t, 2, conn)
chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6)
chainRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
checkPortMapRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
// Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP)
svcChains(t, 3, conn)
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4)
chainRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
checkPortMapRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4)
// Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint
runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP)
svcChains(t, 4, conn)
chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6)
chainRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
checkPortMapRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6)
// Delete service bar
runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP})
@ -95,36 +95,26 @@ func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) {
}
}
// chainRuleCount returns number of rules in a chain identified by service name and IP family.
func chainRuleCount(t *testing.T, svc string, count int, conn *nftables.Conn, fam nftables.TableFamily) {
// chainRuleCount verifies that the named chain in the given table contains the provided number of rules.
func chainRuleCount(t *testing.T, name string, numOfRules int, conn *nftables.Conn, fam nftables.TableFamily) {
t.Helper()
chains, err := conn.ListChainsOfTableFamily(fam)
if err != nil {
t.Fatalf("error listing chains: %v", err)
}
found := false
for _, ch := range chains {
if ch.Name == svc {
found = true
rules, err := conn.GetRules(ch.Table, ch)
if err != nil {
t.Fatalf("error getting rules: %v", err)
}
if len(rules) != count {
t.Fatalf("unexpected number of rules, wants %d got %d", count, len(rules))
}
break
if ch.Name == name {
checkChainRules(t, conn, ch, numOfRules)
return
}
}
if !found {
t.Fatalf("chain for service %s does not exist", svc)
}
t.Fatalf("chain %s does not exist", name)
}
// chainRule verifies that rule for the provided target IP and PortMap exists in
// a chain identified by service name and IP family.
func chainRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner *nftablesRunner, fam nftables.TableFamily) {
// checkPortMapRule verifies that rule for the provided target IP and PortMap exists in a chain identified by service
// name and IP family.
func checkPortMapRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner *nftablesRunner, fam nftables.TableFamily) {
t.Helper()
chains, err := runner.conn.ListChainsOfTableFamily(fam)
if err != nil {
@ -146,11 +136,17 @@ func chainRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner
t.Fatalf("error converting protocol: %v", err)
}
wantsRule := portMapRule(chain.Table, chain, "tailscale0", targetIP, pm.MatchPort, pm.TargetPort, p, meta)
gotRule, err := findRule(runner.conn, wantsRule)
checkRule(t, wantsRule, runner.conn)
}
// checkRule checks that the provided rules exists.
func checkRule(t *testing.T, rule *nftables.Rule, conn *nftables.Conn) {
t.Helper()
gotRule, err := findRule(conn, rule)
if err != nil {
t.Fatalf("error looking up rule: %v", err)
}
if gotRule == nil {
t.Fatalf("rule not found")
t.Fatal("rule not found")
}
}

@ -193,7 +193,7 @@ func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr)
return n.conn.Flush()
}
func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
polAccept := nftables.ChainPolicyAccept
table, err := n.getNFTByAddr(dst)
if err != nil {
@ -216,44 +216,26 @@ func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
if err != nil {
return fmt.Errorf("error ensuring postrouting chain: %w", err)
}
var daddrOffset, fam, daddrLen uint32
if dst.Is4() {
daddrOffset = 16
daddrLen = 4
fam = unix.NFPROTO_IPV4
} else {
daddrOffset = 24
daddrLen = 16
fam = unix.NFPROTO_IPV6
}
snatRule := &nftables.Rule{
Table: nat,
Chain: postRoutingCh,
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: daddrOffset,
Len: daddrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: dst.AsSlice(),
},
&expr.Immediate{
Register: 1,
Data: src.AsSlice(),
},
&expr.NAT{
Type: expr.NATTypeSourceNAT,
Family: fam,
RegAddrMin: 1,
},
},
rules, err := n.conn.GetRules(nat, postRoutingCh)
if err != nil {
return fmt.Errorf("error listing rules: %w", err)
}
snatRulePrefixMatch := fmt.Sprintf("dst:%s,src:", dst.String())
snatRuleFullMatch := fmt.Sprintf("%s%s", snatRulePrefixMatch, src.String())
for _, rule := range rules {
current := string(rule.UserData)
if strings.HasPrefix(string(rule.UserData), snatRulePrefixMatch) {
if strings.EqualFold(current, snatRuleFullMatch) {
return nil // already exists, do nothing
}
if err := n.conn.DelRule(rule); err != nil {
return fmt.Errorf("error deleting SNAT rule: %w", err)
}
}
}
n.conn.AddRule(snatRule)
rule := snatRule(nat, postRoutingCh, src, dst, []byte(snatRuleFullMatch))
n.conn.AddRule(rule)
return n.conn.Flush()
}
@ -557,11 +539,12 @@ type NetfilterRunner interface {
// in the Kubernetes ingress proxies.
DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error
// AddSNATRuleForDst adds a rule to the nat/POSTROUTING chain to SNAT
// traffic destined for dst to src.
// EnsureSNATForDst sets up firewall to mask the source for traffic destined for dst to src:
// - creates a SNAT rule if it doesn't already exist
// - deletes any pre-existing rules matching the destination
// This is used to forward traffic destined for the local machine over
// the Tailscale interface, as used in the Kubernetes egress proxies.
AddSNATRuleForDst(src, dst netip.Addr) error
EnsureSNATForDst(src, dst netip.Addr) error
// DNATNonTailscaleTraffic adds a rule to the nat/PREROUTING chain to DNAT
// all traffic inbound from any interface except exemptInterface to dst.
@ -2028,3 +2011,45 @@ func NfTablesCleanUp(logf logger.Logf) {
}
}
}
func snatRule(t *nftables.Table, ch *nftables.Chain, src, dst netip.Addr, meta []byte) *nftables.Rule {
var daddrOffset, fam, daddrLen uint32
if dst.Is4() {
daddrOffset = 16
daddrLen = 4
fam = unix.NFPROTO_IPV4
} else {
daddrOffset = 24
daddrLen = 16
fam = unix.NFPROTO_IPV6
}
return &nftables.Rule{
Table: t,
Chain: ch,
Exprs: []expr.Any{
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseNetworkHeader,
Offset: daddrOffset,
Len: daddrLen,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: dst.AsSlice(),
},
&expr.Immediate{
Register: 1,
Data: src.AsSlice(),
},
&expr.NAT{
Type: expr.NATTypeSourceNAT,
Family: fam,
RegAddrMin: 1,
RegAddrMax: 1,
},
},
UserData: meta,
}
}

@ -954,6 +954,37 @@ func TestPickFirewallModeFromInstalledRules(t *testing.T) {
}
}
// This test creates a temporary network namespace for the nftables rules being
// set up, so it needs to run in a privileged mode. Locally it needs to be run
// by root, else it will be silently skipped. In CI it runs in a privileged
// container.
func TestEnsureSNATForDst_nftables(t *testing.T) {
conn := newSysConn(t)
runner := newFakeNftablesRunnerWithConn(t, conn, true)
ip1, ip2, ip3 := netip.MustParseAddr("100.99.99.99"), netip.MustParseAddr("100.88.88.88"), netip.MustParseAddr("100.77.77.77")
// 1. A new rule gets added
mustCreateSNATRule_nft(t, runner, ip1, ip2)
chainRuleCount(t, "POSTROUTING", 1, conn, nftables.TableFamilyIPv4)
checkSNATRule_nft(t, runner, runner.nft4.Proto, ip1, ip2)
// 2. Another call to EnsureSNATForDst with the same src and dst does not result in another rule being added.
mustCreateSNATRule_nft(t, runner, ip1, ip2)
chainRuleCount(t, "POSTROUTING", 1, conn, nftables.TableFamilyIPv4) // still just one rule
checkSNATRule_nft(t, runner, runner.nft4.Proto, ip1, ip2)
// 3. Another call to EnsureSNATForDst with a different src and the same dst results in the earlier rule being
// deleted.
mustCreateSNATRule_nft(t, runner, ip3, ip2)
chainRuleCount(t, "POSTROUTING", 1, conn, nftables.TableFamilyIPv4) // still just one rule
checkSNATRule_nft(t, runner, runner.nft4.Proto, ip3, ip2)
// 4. Another call to EnsureSNATForDst with a different dst should not get the earlier rule deleted.
mustCreateSNATRule_nft(t, runner, ip3, ip1)
chainRuleCount(t, "POSTROUTING", 2, conn, nftables.TableFamilyIPv4) // now two rules
checkSNATRule_nft(t, runner, runner.nft4.Proto, ip3, ip1)
}
func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bool) *nftablesRunner {
t.Helper()
if !hasIPv6 {
@ -964,3 +995,32 @@ func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bo
}
return newNfTablesRunnerWithConn(t.Logf, conn)
}
func mustCreateSNATRule_nft(t *testing.T, runner *nftablesRunner, src, dst netip.Addr) {
t.Helper()
if err := runner.EnsureSNATForDst(src, dst); err != nil {
t.Fatalf("error ensuring SNAT rule: %v", err)
}
}
// checkSNATRule_nft verifies that a SNAT rule for the given destination and source exists.
func checkSNATRule_nft(t *testing.T, runner *nftablesRunner, fam nftables.TableFamily, src, dst netip.Addr) {
t.Helper()
chains, err := runner.conn.ListChainsOfTableFamily(fam)
if err != nil {
t.Fatalf("error listing chains: %v", err)
}
var chain *nftables.Chain
for _, ch := range chains {
if ch.Name == "POSTROUTING" {
chain = ch
break
}
}
if chain == nil {
t.Fatal("POSTROUTING chain does not exist")
}
meta := []byte(fmt.Sprintf("dst:%s,src:%s", dst.String(), src.String()))
wantsRule := snatRule(chain.Table, chain, src, dst, meta)
checkRule(t, wantsRule, runner.conn)
}

@ -530,7 +530,7 @@ func (n *fakeIPTablesRunner) DNATWithLoadBalancer(netip.Addr, []netip.Addr) erro
return errors.New("not implemented")
}
func (n *fakeIPTablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error {
func (n *fakeIPTablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
return errors.New("not implemented")
}

Loading…
Cancel
Save