diff --git a/cmd/containerboot/forwarding.go b/cmd/containerboot/forwarding.go index 050bf31c7..04d34836c 100644 --- a/cmd/containerboot/forwarding.go +++ b/cmd/containerboot/forwarding.go @@ -117,7 +117,7 @@ func installEgressForwardingRule(_ context.Context, dstStr string, tsIPs []netip if err := nfr.DNATNonTailscaleTraffic("tailscale0", dst); err != nil { return fmt.Errorf("installing egress proxy rules: %w", err) } - if err := nfr.AddSNATRuleForDst(local, dst); err != nil { + if err := nfr.EnsureSNATForDst(local, dst); err != nil { return fmt.Errorf("installing egress proxy rules: %w", err) } if err := nfr.ClampMSSToPMTU("tailscale0", dst); err != nil { diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index 6d2ad6cfc..86612d1a6 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -481,7 +481,11 @@ runLoop: egressAddrs = node.Addresses().AsSlice() newCurentEgressIPs = deephash.Hash(&egressAddrs) egressIPsHaveChanged = newCurentEgressIPs != currentEgressIPs - if egressIPsHaveChanged && len(egressAddrs) != 0 { + // The firewall rules get (re-)installed: + // - on startup + // - when the tailnet IPs of the tailnet target have changed + // - when the tailnet IPs of this node have changed + if (egressIPsHaveChanged || ipsHaveChanged) && len(egressAddrs) != 0 { var rulesInstalled bool for _, egressAddr := range egressAddrs { ea := egressAddr.Addr() diff --git a/cmd/containerboot/services.go b/cmd/containerboot/services.go index a3d7cdad2..41436fe53 100644 --- a/cmd/containerboot/services.go +++ b/cmd/containerboot/services.go @@ -196,8 +196,7 @@ func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *e if !local.IsValid() { return nil, fmt.Errorf("no valid local IP: %v", local) } - // TODO(irbekrm): only create the SNAT rule if it does not already exist. 
- if err := ep.nfr.AddSNATRuleForDst(local, t); err != nil { + if err := ep.nfr.EnsureSNATForDst(local, t); err != nil { return nil, fmt.Errorf("error setting up SNAT rule: %w", err) } } diff --git a/util/linuxfw/iptables_runner.go b/util/linuxfw/iptables_runner.go index e221ad596..9a6fc0224 100644 --- a/util/linuxfw/iptables_runner.go +++ b/util/linuxfw/iptables_runner.go @@ -9,6 +9,7 @@ import ( "bytes" "errors" "fmt" + "log" "net/netip" "os" "os/exec" @@ -371,9 +372,42 @@ func (i *iptablesRunner) AddDNATRule(origDst, dst netip.Addr) error { return table.Insert("nat", "PREROUTING", 1, "--destination", origDst.String(), "-j", "DNAT", "--to-destination", dst.String()) } -func (i *iptablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error { +// EnsureSNATForDst sets up the firewall to ensure that all traffic destined for dst has its source IP set to src: +// - creates a SNAT rule if not already present +// - ensures that any no longer valid SNAT rules for the same dst are removed +func (i *iptablesRunner) EnsureSNATForDst(src, dst netip.Addr) error { table := i.getIPTByAddr(dst) - return table.Insert("nat", "POSTROUTING", 1, "--destination", dst.String(), "-j", "SNAT", "--to-source", src.String()) + rules, err := table.List("nat", "POSTROUTING") + if err != nil { + return fmt.Errorf("error listing rules: %w", err) + } + // iptables accepts either an address or a CIDR value for the --destination flag, but converts an address to a + // full-length CIDR. Pass it explicitly; dst.BitLen() (32 or 128) keeps IPv6 dsts from being masked to a /32 network. + dstPrefix, err := dst.Prefix(dst.BitLen()) + if err != nil { + return fmt.Errorf("error calculating prefix of dst %v: %w", dst, err) + } + + // wantsArgsPrefix is the prefix of the SNAT rule for the provided destination. + // We should only have one POSTROUTING rule with this prefix. + wantsArgsPrefix := fmt.Sprintf("-d %s -j SNAT --to-source", dstPrefix.String()) + // wantsArgs is the actual SNAT rule that we want. 
+ wantsArgs := fmt.Sprintf("%s %s", wantsArgsPrefix, src.String()) + for _, r := range rules { + args := argsFromPostRoutingRule(r) + if strings.HasPrefix(args, wantsArgsPrefix) { + if args == wantsArgs || strings.HasPrefix(args, wantsArgs+" ") { + return nil + } + // SNAT rule matching the destination, but for a different source - delete. + if err := table.Delete("nat", "POSTROUTING", strings.Split(args, " ")...); err != nil { + // If we failed to delete, don't crash the node - the proxy should still be functioning. + log.Printf("[unexpected] error deleting rule %s: %v, please report it.", r, err) + } + break + } + } + return table.Insert("nat", "POSTROUTING", 1, "-d", dstPrefix.String(), "-j", "SNAT", "--to-source", src.String()) } func (i *iptablesRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error { @@ -731,3 +765,10 @@ func clearRules(proto iptables.Protocol, logf logger.Logf) error { return multierr.New(errs...) } + +// argsFromPostRoutingRule accepts a rule as returned by iptables.List and, if it is a rule from POSTROUTING chain, +// returns the args part, else returns the original rule. +func argsFromPostRoutingRule(r string) string { + args, _ := strings.CutPrefix(r, "-A POSTROUTING ") + return args +} diff --git a/util/linuxfw/iptables_runner_test.go b/util/linuxfw/iptables_runner_test.go index 2363e4ed3..56f13c78a 100644 --- a/util/linuxfw/iptables_runner_test.go +++ b/util/linuxfw/iptables_runner_test.go @@ -289,3 +289,77 @@ func TestAddAndDelSNATRule(t *testing.T) { t.Fatal(err) } } + +func TestEnsureSNATForDst_ipt(t *testing.T) { + ip1, ip2, ip3 := netip.MustParseAddr("100.99.99.99"), netip.MustParseAddr("100.88.88.88"), netip.MustParseAddr("100.77.77.77") + iptr := NewFakeIPTablesRunner() + + // 1. A new rule gets added + mustCreateSNATRule_ipt(t, iptr, ip1, ip2) + checkSNATRule_ipt(t, iptr, ip1, ip2) + checkSNATRuleCount(t, iptr, ip1, 1) + + // 2. Another call to EnsureSNATForDst with the same src and dst does not result in another rule being added. 
+ mustCreateSNATRule_ipt(t, iptr, ip1, ip2) + checkSNATRule_ipt(t, iptr, ip1, ip2) + checkSNATRuleCount(t, iptr, ip1, 1) // still just 1 rule + + // 3. Another call to EnsureSNATForDst with a different src and the same dst results in the earlier rule being + // deleted. + mustCreateSNATRule_ipt(t, iptr, ip3, ip2) + checkSNATRule_ipt(t, iptr, ip3, ip2) + checkSNATRuleCount(t, iptr, ip1, 1) // still just 1 rule + + // 4. Another call to EnsureSNATForDst with a different dst should not get the earlier rule deleted. + mustCreateSNATRule_ipt(t, iptr, ip3, ip1) + checkSNATRule_ipt(t, iptr, ip3, ip1) + checkSNATRuleCount(t, iptr, ip1, 2) // now 2 rules + + // 5. A call to EnsureSNATForDst with a match dst and a match port should not get deleted by EnsureSNATForDst for the same dst. + args := []string{"--destination", ip1.String(), "-j", "SNAT", "--to-source", "10.0.0.1"} + if err := iptr.getIPTByAddr(ip1).Insert("nat", "POSTROUTING", 1, args...); err != nil { + t.Fatalf("error adding SNAT rule: %v", err) + } + exists, err := iptr.getIPTByAddr(ip1).Exists("nat", "POSTROUTING", args...) 
+ if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if !exists { + t.Fatalf("SNAT rule for destination and port unexpectedly deleted") + } + mustCreateSNATRule_ipt(t, iptr, ip3, ip1) + checkSNATRuleCount(t, iptr, ip1, 3) // now 3 rules +} + +func mustCreateSNATRule_ipt(t *testing.T, iptr *iptablesRunner, src, dst netip.Addr) { + t.Helper() + if err := iptr.EnsureSNATForDst(src, dst); err != nil { + t.Fatalf("error ensuring SNAT rule: %v", err) + } +} + +func checkSNATRule_ipt(t *testing.T, iptr *iptablesRunner, src, dst netip.Addr) { + t.Helper() + dstPrefix, err := dst.Prefix(32) + if err != nil { + t.Fatalf("error converting addr to prefix: %v", err) + } + exists, err := iptr.getIPTByAddr(src).Exists("nat", "POSTROUTING", "-d", dstPrefix.String(), "-j", "SNAT", "--to-source", src.String()) + if err != nil { + t.Fatalf("error checking if rule exists: %v", err) + } + if !exists { + t.Fatalf("SNAT rule for src %s dst %s should exist, but it does not", src, dst) + } +} + +func checkSNATRuleCount(t *testing.T, iptr *iptablesRunner, ip netip.Addr, wantsRules int) { + t.Helper() + rules, err := iptr.getIPTByAddr(ip).List("nat", "POSTROUTING") + if err != nil { + t.Fatalf("error listing rules: %v", err) + } + if len(rules) != wantsRules { + t.Fatalf("wants %d rules, got %d", wantsRules, len(rules)) + } +} diff --git a/util/linuxfw/nftables_for_svcs_test.go b/util/linuxfw/nftables_for_svcs_test.go index 8a735d602..d2df6e4bd 100644 --- a/util/linuxfw/nftables_for_svcs_test.go +++ b/util/linuxfw/nftables_for_svcs_test.go @@ -27,32 +27,32 @@ func Test_nftablesRunner_EnsurePortMapRuleForSvc(t *testing.T) { runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP) svcChains(t, 1, conn) chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv4) - chainRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) + checkPortMapRule(t, "foo", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) // Create another rule for service 'foo' to forward TCP 
traffic to the // same IPv4 endpoint, but to a different port. runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv4, pmTCP1) svcChains(t, 1, conn) chainRuleCount(t, "foo", 2, conn, nftables.TableFamilyIPv4) - chainRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4) + checkPortMapRule(t, "foo", ipv4, pmTCP1, runner, nftables.TableFamilyIPv4) // Create a rule for service 'foo' to forward TCP traffic to an IPv6 endpoint runner.EnsurePortMapRuleForSvc("foo", "tailscale0", ipv6, pmTCP) svcChains(t, 2, conn) chainRuleCount(t, "foo", 1, conn, nftables.TableFamilyIPv6) - chainRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) + checkPortMapRule(t, "foo", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) // Create a rule for service 'bar' to forward TCP traffic to IPv4 endpoint runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv4, pmTCP) svcChains(t, 3, conn) chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv4) - chainRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) + checkPortMapRule(t, "bar", ipv4, pmTCP, runner, nftables.TableFamilyIPv4) // Create a rule for service 'bar' to forward TCP traffic to an IPv6 endpoint runner.EnsurePortMapRuleForSvc("bar", "tailscale0", ipv6, pmTCP) svcChains(t, 4, conn) chainRuleCount(t, "bar", 1, conn, nftables.TableFamilyIPv6) - chainRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) + checkPortMapRule(t, "bar", ipv6, pmTCP, runner, nftables.TableFamilyIPv6) // Delete service bar runner.DeleteSvc("bar", "tailscale0", []netip.Addr{ipv4, ipv6}, []PortMap{pmTCP}) @@ -95,36 +95,26 @@ func svcChains(t *testing.T, wantCount int, conn *nftables.Conn) { } } -// chainRuleCount returns number of rules in a chain identified by service name and IP family. -func chainRuleCount(t *testing.T, svc string, count int, conn *nftables.Conn, fam nftables.TableFamily) { +// chainRuleCount verifies that the named chain in the given table contains the provided number of rules. 
+func chainRuleCount(t *testing.T, name string, numOfRules int, conn *nftables.Conn, fam nftables.TableFamily) { t.Helper() chains, err := conn.ListChainsOfTableFamily(fam) if err != nil { t.Fatalf("error listing chains: %v", err) } - found := false for _, ch := range chains { - if ch.Name == svc { - found = true - rules, err := conn.GetRules(ch.Table, ch) - if err != nil { - t.Fatalf("error getting rules: %v", err) - } - if len(rules) != count { - t.Fatalf("unexpected number of rules, wants %d got %d", count, len(rules)) - } - break + if ch.Name == name { + checkChainRules(t, conn, ch, numOfRules) + return } } - if !found { - t.Fatalf("chain for service %s does not exist", svc) - } + t.Fatalf("chain %s does not exist", name) } -// chainRule verifies that rule for the provided target IP and PortMap exists in -// a chain identified by service name and IP family. -func chainRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner *nftablesRunner, fam nftables.TableFamily) { +// checkPortMapRule verifies that a rule for the provided target IP and PortMap exists in a chain identified by service +// name and IP family. +func checkPortMapRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner *nftablesRunner, fam nftables.TableFamily) { t.Helper() chains, err := runner.conn.ListChainsOfTableFamily(fam) if err != nil { @@ -146,11 +136,17 @@ func chainRule(t *testing.T, svc string, targetIP netip.Addr, pm PortMap, runner t.Fatalf("error converting protocol: %v", err) } wantsRule := portMapRule(chain.Table, chain, "tailscale0", targetIP, pm.MatchPort, pm.TargetPort, p, meta) - gotRule, err := findRule(runner.conn, wantsRule) + checkRule(t, wantsRule, runner.conn) +} + +// checkRule checks that the provided rule exists. 
+func checkRule(t *testing.T, rule *nftables.Rule, conn *nftables.Conn) { + t.Helper() + gotRule, err := findRule(conn, rule) if err != nil { t.Fatalf("error looking up rule: %v", err) } if gotRule == nil { - t.Fatalf("rule not found") + t.Fatal("rule not found") } } diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index a7a407222..0f411521b 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -193,7 +193,7 @@ func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) return n.conn.Flush() } -func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error { +func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error { polAccept := nftables.ChainPolicyAccept table, err := n.getNFTByAddr(dst) if err != nil { @@ -216,44 +216,26 @@ func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error { if err != nil { return fmt.Errorf("error ensuring postrouting chain: %w", err) } - var daddrOffset, fam, daddrLen uint32 - if dst.Is4() { - daddrOffset = 16 - daddrLen = 4 - fam = unix.NFPROTO_IPV4 - } else { - daddrOffset = 24 - daddrLen = 16 - fam = unix.NFPROTO_IPV6 - } - snatRule := &nftables.Rule{ - Table: nat, - Chain: postRoutingCh, - Exprs: []expr.Any{ - &expr.Payload{ - DestRegister: 1, - Base: expr.PayloadBaseNetworkHeader, - Offset: daddrOffset, - Len: daddrLen, - }, - &expr.Cmp{ - Op: expr.CmpOpEq, - Register: 1, - Data: dst.AsSlice(), - }, - &expr.Immediate{ - Register: 1, - Data: src.AsSlice(), - }, - &expr.NAT{ - Type: expr.NATTypeSourceNAT, - Family: fam, - RegAddrMin: 1, - }, - }, + rules, err := n.conn.GetRules(nat, postRoutingCh) + if err != nil { + return fmt.Errorf("error listing rules: %w", err) + } + snatRulePrefixMatch := fmt.Sprintf("dst:%s,src:", dst.String()) + snatRuleFullMatch := fmt.Sprintf("%s%s", snatRulePrefixMatch, src.String()) + for _, rule := range rules { + current := string(rule.UserData) + if strings.HasPrefix(string(rule.UserData), 
snatRulePrefixMatch) { + if strings.EqualFold(current, snatRuleFullMatch) { + return nil // already exists, do nothing + } + if err := n.conn.DelRule(rule); err != nil { + return fmt.Errorf("error deleting SNAT rule: %w", err) + } + } } - n.conn.AddRule(snatRule) + rule := snatRule(nat, postRoutingCh, src, dst, []byte(snatRuleFullMatch)) + n.conn.AddRule(rule) return n.conn.Flush() } @@ -557,11 +539,12 @@ type NetfilterRunner interface { // in the Kubernetes ingress proxies. DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error - // AddSNATRuleForDst adds a rule to the nat/POSTROUTING chain to SNAT - // traffic destined for dst to src. + // EnsureSNATForDst sets up firewall to mask the source for traffic destined for dst to src: + // - creates a SNAT rule if it doesn't already exist + // - deletes any pre-existing rules matching the destination // This is used to forward traffic destined for the local machine over // the Tailscale interface, as used in the Kubernetes egress proxies. - AddSNATRuleForDst(src, dst netip.Addr) error + EnsureSNATForDst(src, dst netip.Addr) error // DNATNonTailscaleTraffic adds a rule to the nat/PREROUTING chain to DNAT // all traffic inbound from any interface except exemptInterface to dst. 
@@ -2028,3 +2011,45 @@ func NfTablesCleanUp(logf logger.Logf) { } } } + +func snatRule(t *nftables.Table, ch *nftables.Chain, src, dst netip.Addr, meta []byte) *nftables.Rule { + var daddrOffset, fam, daddrLen uint32 + if dst.Is4() { + daddrOffset = 16 + daddrLen = 4 + fam = unix.NFPROTO_IPV4 + } else { + daddrOffset = 24 + daddrLen = 16 + fam = unix.NFPROTO_IPV6 + } + + return &nftables.Rule{ + Table: t, + Chain: ch, + Exprs: []expr.Any{ + &expr.Payload{ + DestRegister: 1, + Base: expr.PayloadBaseNetworkHeader, + Offset: daddrOffset, + Len: daddrLen, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: dst.AsSlice(), + }, + &expr.Immediate{ + Register: 1, + Data: src.AsSlice(), + }, + &expr.NAT{ + Type: expr.NATTypeSourceNAT, + Family: fam, + RegAddrMin: 1, + RegAddrMax: 1, + }, + }, + UserData: meta, + } +} diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index ebf514c79..712a7b939 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -954,6 +954,37 @@ func TestPickFirewallModeFromInstalledRules(t *testing.T) { } } +// This test creates a temporary network namespace for the nftables rules being +// set up, so it needs to run in a privileged mode. Locally it needs to be run +// by root, else it will be silently skipped. In CI it runs in a privileged +// container. +func TestEnsureSNATForDst_nftables(t *testing.T) { + conn := newSysConn(t) + runner := newFakeNftablesRunnerWithConn(t, conn, true) + ip1, ip2, ip3 := netip.MustParseAddr("100.99.99.99"), netip.MustParseAddr("100.88.88.88"), netip.MustParseAddr("100.77.77.77") + + // 1. A new rule gets added + mustCreateSNATRule_nft(t, runner, ip1, ip2) + chainRuleCount(t, "POSTROUTING", 1, conn, nftables.TableFamilyIPv4) + checkSNATRule_nft(t, runner, runner.nft4.Proto, ip1, ip2) + + // 2. Another call to EnsureSNATForDst with the same src and dst does not result in another rule being added. 
+ mustCreateSNATRule_nft(t, runner, ip1, ip2) + chainRuleCount(t, "POSTROUTING", 1, conn, nftables.TableFamilyIPv4) // still just one rule + checkSNATRule_nft(t, runner, runner.nft4.Proto, ip1, ip2) + + // 3. Another call to EnsureSNATForDst with a different src and the same dst results in the earlier rule being + // deleted. + mustCreateSNATRule_nft(t, runner, ip3, ip2) + chainRuleCount(t, "POSTROUTING", 1, conn, nftables.TableFamilyIPv4) // still just one rule + checkSNATRule_nft(t, runner, runner.nft4.Proto, ip3, ip2) + + // 4. Another call to EnsureSNATForDst with a different dst should not get the earlier rule deleted. + mustCreateSNATRule_nft(t, runner, ip3, ip1) + chainRuleCount(t, "POSTROUTING", 2, conn, nftables.TableFamilyIPv4) // now two rules + checkSNATRule_nft(t, runner, runner.nft4.Proto, ip3, ip1) +} + func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bool) *nftablesRunner { t.Helper() if !hasIPv6 { @@ -964,3 +995,32 @@ func newFakeNftablesRunnerWithConn(t *testing.T, conn *nftables.Conn, hasIPv6 bo } return newNfTablesRunnerWithConn(t.Logf, conn) } + +func mustCreateSNATRule_nft(t *testing.T, runner *nftablesRunner, src, dst netip.Addr) { + t.Helper() + if err := runner.EnsureSNATForDst(src, dst); err != nil { + t.Fatalf("error ensuring SNAT rule: %v", err) + } +} + +// checkSNATRule_nft verifies that a SNAT rule for the given destination and source exists. 
+func checkSNATRule_nft(t *testing.T, runner *nftablesRunner, fam nftables.TableFamily, src, dst netip.Addr) { + t.Helper() + chains, err := runner.conn.ListChainsOfTableFamily(fam) + if err != nil { + t.Fatalf("error listing chains: %v", err) + } + var chain *nftables.Chain + for _, ch := range chains { + if ch.Name == "POSTROUTING" { + chain = ch + break + } + } + if chain == nil { + t.Fatal("POSTROUTING chain does not exist") + } + meta := []byte(fmt.Sprintf("dst:%s,src:%s", dst.String(), src.String())) + wantsRule := snatRule(chain.Table, chain, src, dst, meta) + checkRule(t, wantsRule, runner.conn) +} diff --git a/wgengine/router/router_linux_test.go b/wgengine/router/router_linux_test.go index 893ff4a70..dce69550d 100644 --- a/wgengine/router/router_linux_test.go +++ b/wgengine/router/router_linux_test.go @@ -530,7 +530,7 @@ func (n *fakeIPTablesRunner) DNATWithLoadBalancer(netip.Addr, []netip.Addr) erro return errors.New("not implemented") } -func (n *fakeIPTablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error { +func (n *fakeIPTablesRunner) EnsureSNATForDst(src, dst netip.Addr) error { return errors.New("not implemented") }