pull/18034/merge
Brad Fitzpatrick 2 days ago committed by GitHub
commit 1185facf2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -137,6 +137,12 @@ func NewFakeIPTablesRunner() NetfilterRunner {
v6Available = true v6Available = true
} }
iptr := &iptablesRunner{ipt4, ipt6, v6Available, v6Available, v6Available} return &iptablesRunner{
return iptr af: FamilyBoth,
ipt4: ipt4,
ipt6: ipt6,
v6Available: v6Available,
v6NATAvailable: v6Available,
v6FilterAvailable: v6Available,
}
} }

@ -116,7 +116,9 @@ func newIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) {
ipt6: ipt6, ipt6: ipt6,
v6Available: supportsV6, v6Available: supportsV6,
v6NATAvailable: supportsV6NAT, v6NATAvailable: supportsV6NAT,
v6FilterAvailable: supportsV6Filter}, nil v6FilterAvailable: supportsV6Filter,
af: FamilyBoth,
}, nil
} }
// checkSupportsV6Filter returns whether the system has a "filter" table in the // checkSupportsV6Filter returns whether the system has a "filter" table in the

@ -18,8 +18,8 @@ import (
// EnsurePortMapRuleForSvc adds a prerouting rule that forwards traffic received // EnsurePortMapRuleForSvc adds a prerouting rule that forwards traffic received
// on match port and NOT on the provided interface to target IP and target port. // on match port and NOT on the provided interface to target IP and target port.
// Rule will only be added if it does not already exists. // Rule will only be added if it does not already exists.
func (i *iptablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { func (r *iptablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
table := i.getIPTByAddr(targetIP) table := r.getIPTByAddr(targetIP)
args := argsForPortMapRule(svc, tun, targetIP, pm) args := argsForPortMapRule(svc, tun, targetIP, pm)
exists, err := table.Exists("nat", "PREROUTING", args...) exists, err := table.Exists("nat", "PREROUTING", args...)
if err != nil { if err != nil {
@ -34,8 +34,8 @@ func (i *iptablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip
// DeleteMapRuleForSvc constructs a prerouting rule as would be created by // DeleteMapRuleForSvc constructs a prerouting rule as would be created by
// EnsurePortMapRuleForSvc with the provided args and, if such a rule exists, // EnsurePortMapRuleForSvc with the provided args and, if such a rule exists,
// deletes it. // deletes it.
func (i *iptablesRunner) DeletePortMapRuleForSvc(svc, excludeI string, targetIP netip.Addr, pm PortMap) error { func (r *iptablesRunner) DeletePortMapRuleForSvc(svc, excludeI string, targetIP netip.Addr, pm PortMap) error {
table := i.getIPTByAddr(targetIP) table := r.getIPTByAddr(targetIP)
args := argsForPortMapRule(svc, excludeI, targetIP, pm) args := argsForPortMapRule(svc, excludeI, targetIP, pm)
exists, err := table.Exists("nat", "PREROUTING", args...) exists, err := table.Exists("nat", "PREROUTING", args...)
if err != nil { if err != nil {
@ -51,8 +51,8 @@ func (i *iptablesRunner) DeletePortMapRuleForSvc(svc, excludeI string, targetIP
// VIPService IP address to a local address. This is used by the Kubernetes // VIPService IP address to a local address. This is used by the Kubernetes
// operator's network layer proxies to forward tailnet traffic for VIPServices // operator's network layer proxies to forward tailnet traffic for VIPServices
// to Kubernetes Services. // to Kubernetes Services.
func (i *iptablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { func (r *iptablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
table := i.getIPTByAddr(dst) table := r.getIPTByAddr(dst)
args := argsForIngressRule(svcName, origDst, dst) args := argsForIngressRule(svcName, origDst, dst)
exists, err := table.Exists("nat", "PREROUTING", args...) exists, err := table.Exists("nat", "PREROUTING", args...)
if err != nil { if err != nil {
@ -65,8 +65,8 @@ func (i *iptablesRunner) EnsureDNATRuleForSvc(svcName string, origDst, dst netip
} }
// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc. // DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc.
func (i *iptablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { func (r *iptablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
table := i.getIPTByAddr(dst) table := r.getIPTByAddr(dst)
args := argsForIngressRule(svcName, origDst, dst) args := argsForIngressRule(svcName, origDst, dst)
exists, err := table.Exists("nat", "PREROUTING", args...) exists, err := table.Exists("nat", "PREROUTING", args...)
if err != nil { if err != nil {
@ -81,10 +81,10 @@ func (i *iptablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip
// DeleteSvc constructs all possible rules that would have been created by // DeleteSvc constructs all possible rules that would have been created by
// EnsurePortMapRuleForSvc from the provided args and ensures that each one that // EnsurePortMapRuleForSvc from the provided args and ensures that each one that
// exists is deleted. // exists is deleted.
func (i *iptablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pms []PortMap) error { func (r *iptablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pms []PortMap) error {
for _, tip := range targetIPs { for _, tip := range targetIPs {
for _, pm := range pms { for _, pm := range pms {
if err := i.DeletePortMapRuleForSvc(svc, tun, tip, pm); err != nil { if err := r.DeletePortMapRuleForSvc(svc, tun, tip, pm); err != nil {
return fmt.Errorf("error deleting rule: %w", err) return fmt.Errorf("error deleting rule: %w", err)
} }
} }

@ -37,6 +37,7 @@ type iptablesInterface interface {
} }
type iptablesRunner struct { type iptablesRunner struct {
af AddressFamilies
ipt4 iptablesInterface ipt4 iptablesInterface
ipt6 iptablesInterface ipt6 iptablesInterface
@ -53,35 +54,39 @@ func checkIP6TablesExists() error {
return nil return nil
} }
func (r *iptablesRunner) SetAddressFamilies(af AddressFamilies) {
r.af = af
}
// HasIPV6 reports true if the system supports IPv6. // HasIPV6 reports true if the system supports IPv6.
func (i *iptablesRunner) HasIPV6() bool { func (r *iptablesRunner) HasIPV6() bool {
return i.v6Available return r.v6Available
} }
// HasIPV6Filter reports true if the system supports ip6tables filter table. // HasIPV6Filter reports true if the system supports ip6tables filter table.
func (i *iptablesRunner) HasIPV6Filter() bool { func (r *iptablesRunner) HasIPV6Filter() bool {
return i.v6FilterAvailable return r.v6FilterAvailable
} }
// HasIPV6NAT reports true if the system supports IPv6 NAT. // HasIPV6NAT reports true if the system supports IPv6 NAT.
func (i *iptablesRunner) HasIPV6NAT() bool { func (r *iptablesRunner) HasIPV6NAT() bool {
return i.v6NATAvailable return r.v6NATAvailable
} }
// getIPTByAddr returns the iptablesInterface with correct IP family // getIPTByAddr returns the iptablesInterface with correct IP family
// that we will be using for the given address. // that we will be using for the given address.
func (i *iptablesRunner) getIPTByAddr(addr netip.Addr) iptablesInterface { func (r *iptablesRunner) getIPTByAddr(addr netip.Addr) iptablesInterface {
nf := i.ipt4 nf := r.ipt4
if addr.Is6() { if addr.Is6() {
nf = i.ipt6 nf = r.ipt6
} }
return nf return nf
} }
// AddLoopbackRule adds an iptables rule to permit loopback traffic to // AddLoopbackRule adds an iptables rule to permit loopback traffic to
// a local Tailscale IP. // a local Tailscale IP.
func (i *iptablesRunner) AddLoopbackRule(addr netip.Addr) error { func (r *iptablesRunner) AddLoopbackRule(addr netip.Addr) error {
if err := i.getIPTByAddr(addr).Insert("filter", "ts-input", 1, "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil { if err := r.getIPTByAddr(addr).Insert("filter", "ts-input", 1, "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil {
return fmt.Errorf("adding loopback allow rule for %q: %w", addr, err) return fmt.Errorf("adding loopback allow rule for %q: %w", addr, err)
} }
@ -96,8 +101,8 @@ func tsChain(chain string) string {
// DelLoopbackRule removes the iptables rule permitting loopback // DelLoopbackRule removes the iptables rule permitting loopback
// traffic to a Tailscale IP. // traffic to a Tailscale IP.
func (i *iptablesRunner) DelLoopbackRule(addr netip.Addr) error { func (r *iptablesRunner) DelLoopbackRule(addr netip.Addr) error {
if err := i.getIPTByAddr(addr).Delete("filter", "ts-input", "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil { if err := r.getIPTByAddr(addr).Delete("filter", "ts-input", "-i", "lo", "-s", addr.String(), "-j", "ACCEPT"); err != nil {
return fmt.Errorf("deleting loopback allow rule for %q: %w", addr, err) return fmt.Errorf("deleting loopback allow rule for %q: %w", addr, err)
} }
@ -105,27 +110,27 @@ func (i *iptablesRunner) DelLoopbackRule(addr netip.Addr) error {
} }
// getTables gets the available iptablesInterface in iptables runner. // getTables gets the available iptablesInterface in iptables runner.
func (i *iptablesRunner) getTables() []iptablesInterface { func (r *iptablesRunner) getTables() []iptablesInterface {
if i.HasIPV6Filter() { if r.HasIPV6Filter() {
return []iptablesInterface{i.ipt4, i.ipt6} return []iptablesInterface{r.ipt4, r.ipt6}
} }
return []iptablesInterface{i.ipt4} return []iptablesInterface{r.ipt4}
} }
// getNATTables gets the available iptablesInterface in iptables runner. // getNATTables gets the available iptablesInterface in iptables runner.
// If the system does not support IPv6 NAT, only the IPv4 iptablesInterface // If the system does not support IPv6 NAT, only the IPv4 iptablesInterface
// is returned. // is returned.
func (i *iptablesRunner) getNATTables() []iptablesInterface { func (r *iptablesRunner) getNATTables() []iptablesInterface {
if i.HasIPV6NAT() { if r.HasIPV6NAT() {
return i.getTables() return r.getTables()
} }
return []iptablesInterface{i.ipt4} return []iptablesInterface{r.ipt4}
} }
// AddHooks inserts calls to tailscale's netfilter chains in // AddHooks inserts calls to tailscale's netfilter chains in
// the relevant main netfilter chains. The tailscale chains must // the relevant main netfilter chains. The tailscale chains must
// already exist. If they do not, an error is returned. // already exist. If they do not, an error is returned.
func (i *iptablesRunner) AddHooks() error { func (r *iptablesRunner) AddHooks() error {
// divert inserts a jump to the tailscale chain in the given table/chain. // divert inserts a jump to the tailscale chain in the given table/chain.
// If the jump already exists, it is a no-op. // If the jump already exists, it is a no-op.
divert := func(ipt iptablesInterface, table, chain string) error { divert := func(ipt iptablesInterface, table, chain string) error {
@ -145,7 +150,7 @@ func (i *iptablesRunner) AddHooks() error {
return nil return nil
} }
for _, ipt := range i.getTables() { for _, ipt := range r.getTables() {
if err := divert(ipt, "filter", "INPUT"); err != nil { if err := divert(ipt, "filter", "INPUT"); err != nil {
return err return err
} }
@ -154,7 +159,7 @@ func (i *iptablesRunner) AddHooks() error {
} }
} }
for _, ipt := range i.getNATTables() { for _, ipt := range r.getNATTables() {
if err := divert(ipt, "nat", "POSTROUTING"); err != nil { if err := divert(ipt, "nat", "POSTROUTING"); err != nil {
return err return err
} }
@ -164,7 +169,7 @@ func (i *iptablesRunner) AddHooks() error {
// AddChains creates custom Tailscale chains in netfilter via iptables // AddChains creates custom Tailscale chains in netfilter via iptables
// if the ts-chain doesn't already exist. // if the ts-chain doesn't already exist.
func (i *iptablesRunner) AddChains() error { func (r *iptablesRunner) AddChains() error {
// create creates a chain in the given table if it doesn't already exist. // create creates a chain in the given table if it doesn't already exist.
// If the chain already exists, it is a no-op. // If the chain already exists, it is a no-op.
create := func(ipt iptablesInterface, table, chain string) error { create := func(ipt iptablesInterface, table, chain string) error {
@ -179,7 +184,7 @@ func (i *iptablesRunner) AddChains() error {
return nil return nil
} }
for _, ipt := range i.getTables() { for _, ipt := range r.getTables() {
if err := create(ipt, "filter", "ts-input"); err != nil { if err := create(ipt, "filter", "ts-input"); err != nil {
return err return err
} }
@ -188,7 +193,7 @@ func (i *iptablesRunner) AddChains() error {
} }
} }
for _, ipt := range i.getNATTables() { for _, ipt := range r.getNATTables() {
if err := create(ipt, "nat", "ts-postrouting"); err != nil { if err := create(ipt, "nat", "ts-postrouting"); err != nil {
return err return err
} }
@ -199,12 +204,14 @@ func (i *iptablesRunner) AddChains() error {
// AddBase adds some basic processing rules to be supplemented by // AddBase adds some basic processing rules to be supplemented by
// later calls to other helpers. // later calls to other helpers.
func (i *iptablesRunner) AddBase(tunname string) error { func (r *iptablesRunner) AddBase(tunname string) error {
if err := i.addBase4(tunname); err != nil { if r.af&FamilyIPv4 != 0 {
if err := r.addBase4(tunname); err != nil {
return err return err
} }
if i.HasIPV6Filter() { }
if err := i.addBase6(tunname); err != nil { if r.af&FamilyIPv6 != 0 && r.HasIPV6Filter() {
if err := r.addBase6(tunname); err != nil {
return err return err
} }
} }
@ -213,7 +220,7 @@ func (i *iptablesRunner) AddBase(tunname string) error {
// addBase4 adds some basic IPv4 processing rules to be // addBase4 adds some basic IPv4 processing rules to be
// supplemented by later calls to other helpers. // supplemented by later calls to other helpers.
func (i *iptablesRunner) addBase4(tunname string) error { func (r *iptablesRunner) addBase4(tunname string) error {
// Only allow CGNAT range traffic to come from tailscale0. There // Only allow CGNAT range traffic to come from tailscale0. There
// is an exception carved out for ranges used by ChromeOS, for // is an exception carved out for ranges used by ChromeOS, for
// which we fall out of the Tailscale chain. // which we fall out of the Tailscale chain.
@ -221,17 +228,17 @@ func (i *iptablesRunner) addBase4(tunname string) error {
// Note, this will definitely break nodes that end up using the // Note, this will definitely break nodes that end up using the
// CGNAT range for other purposes :(. // CGNAT range for other purposes :(.
args := []string{"!", "-i", tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"} args := []string{"!", "-i", tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"}
if err := i.ipt4.Append("filter", "ts-input", args...); err != nil { if err := r.ipt4.Append("filter", "ts-input", args...); err != nil {
return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err)
} }
args = []string{"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"} args = []string{"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}
if err := i.ipt4.Append("filter", "ts-input", args...); err != nil { if err := r.ipt4.Append("filter", "ts-input", args...); err != nil {
return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err)
} }
// Explicitly allow all other inbound traffic to the tun interface // Explicitly allow all other inbound traffic to the tun interface
args = []string{"-i", tunname, "-j", "ACCEPT"} args = []string{"-i", tunname, "-j", "ACCEPT"}
if err := i.ipt4.Append("filter", "ts-input", args...); err != nil { if err := r.ipt4.Append("filter", "ts-input", args...); err != nil {
return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err)
} }
@ -247,35 +254,35 @@ func (i *iptablesRunner) addBase4(tunname string) error {
// filter/FORWARD, and set a packet mark that nat/POSTROUTING can // filter/FORWARD, and set a packet mark that nat/POSTROUTING can
// use to effectively run that same test again. // use to effectively run that same test again.
args = []string{"-i", tunname, "-j", "MARK", "--set-mark", subnetRouteMark + "/" + fwmarkMask} args = []string{"-i", tunname, "-j", "MARK", "--set-mark", subnetRouteMark + "/" + fwmarkMask}
if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil {
return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
} }
args = []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "ACCEPT"} args = []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "ACCEPT"}
if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil {
return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
} }
args = []string{"-o", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"} args = []string{"-o", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}
if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil {
return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
} }
args = []string{"-o", tunname, "-j", "ACCEPT"} args = []string{"-o", tunname, "-j", "ACCEPT"}
if err := i.ipt4.Append("filter", "ts-forward", args...); err != nil { if err := r.ipt4.Append("filter", "ts-forward", args...); err != nil {
return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err) return fmt.Errorf("adding %v in v4/filter/ts-forward: %w", args, err)
} }
return nil return nil
} }
func (i *iptablesRunner) AddDNATRule(origDst, dst netip.Addr) error { func (r *iptablesRunner) AddDNATRule(origDst, dst netip.Addr) error {
table := i.getIPTByAddr(dst) table := r.getIPTByAddr(dst)
return table.Insert("nat", "PREROUTING", 1, "--destination", origDst.String(), "-j", "DNAT", "--to-destination", dst.String()) return table.Insert("nat", "PREROUTING", 1, "--destination", origDst.String(), "-j", "DNAT", "--to-destination", dst.String())
} }
// EnsureSNATForDst sets up firewall to ensure that all traffic aimed for dst, has its source ip set to src: // EnsureSNATForDst sets up firewall to ensure that all traffic aimed for dst, has its source ip set to src:
// - creates a SNAT rule if not already present // - creates a SNAT rule if not already present
// - ensures that any no longer valid SNAT rules for the same dst are removed // - ensures that any no longer valid SNAT rules for the same dst are removed
func (i *iptablesRunner) EnsureSNATForDst(src, dst netip.Addr) error { func (r *iptablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
table := i.getIPTByAddr(dst) table := r.getIPTByAddr(dst)
rules, err := table.List("nat", "POSTROUTING") rules, err := table.List("nat", "POSTROUTING")
if err != nil { if err != nil {
return fmt.Errorf("error listing rules: %v", err) return fmt.Errorf("error listing rules: %v", err)
@ -309,15 +316,15 @@ func (i *iptablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
return table.Insert("nat", "POSTROUTING", 1, "-d", dstPrefix.String(), "-j", "SNAT", "--to-source", src.String()) return table.Insert("nat", "POSTROUTING", 1, "-d", dstPrefix.String(), "-j", "SNAT", "--to-source", src.String())
} }
func (i *iptablesRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error { func (r *iptablesRunner) DNATNonTailscaleTraffic(tun string, dst netip.Addr) error {
table := i.getIPTByAddr(dst) table := r.getIPTByAddr(dst)
return table.Insert("nat", "PREROUTING", 1, "!", "-i", tun, "-j", "DNAT", "--to-destination", dst.String()) return table.Insert("nat", "PREROUTING", 1, "!", "-i", tun, "-j", "DNAT", "--to-destination", dst.String())
} }
// DNATWithLoadBalancer adds iptables rules to forward all traffic received for // DNATWithLoadBalancer adds iptables rules to forward all traffic received for
// originDst to the backend dsts. Traffic will be load balanced using round robin. // originDst to the backend dsts. Traffic will be load balanced using round robin.
func (i *iptablesRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error { func (r *iptablesRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error {
table := i.getIPTByAddr(dsts[0]) table := r.getIPTByAddr(dsts[0])
if err := table.ClearChain("nat", "PREROUTING"); err != nil && !isNotExistError(err) { if err := table.ClearChain("nat", "PREROUTING"); err != nil && !isNotExistError(err) {
// If clearing the PREROUTING chain fails, fail the whole operation. This // If clearing the PREROUTING chain fails, fail the whole operation. This
// rule is currently only used in Kubernetes containers where a // rule is currently only used in Kubernetes containers where a
@ -335,35 +342,35 @@ func (i *iptablesRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.A
return table.Append("nat", "PREROUTING", "--destination", origDst.String(), "-j", "DNAT", "--to-destination", dsts[0].String()) return table.Append("nat", "PREROUTING", "--destination", origDst.String(), "-j", "DNAT", "--to-destination", dsts[0].String())
} }
func (i *iptablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { func (r *iptablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
table := i.getIPTByAddr(addr) table := r.getIPTByAddr(addr)
return table.Append("mangle", "FORWARD", "-o", tun, "-p", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS", "--clamp-mss-to-pmtu") return table.Append("mangle", "FORWARD", "-o", tun, "-p", "tcp", "--tcp-flags", "SYN,RST", "SYN", "-j", "TCPMSS", "--clamp-mss-to-pmtu")
} }
// addBase6 adds some basic IPv6 processing rules to be // addBase6 adds some basic IPv6 processing rules to be
// supplemented by later calls to other helpers. // supplemented by later calls to other helpers.
func (i *iptablesRunner) addBase6(tunname string) error { func (r *iptablesRunner) addBase6(tunname string) error {
// TODO: only allow traffic from Tailscale's ULA range to come // TODO: only allow traffic from Tailscale's ULA range to come
// from tailscale0. // from tailscale0.
// Explicitly allow all other inbound traffic to the tun interface // Explicitly allow all other inbound traffic to the tun interface
args := []string{"-i", tunname, "-j", "ACCEPT"} args := []string{"-i", tunname, "-j", "ACCEPT"}
if err := i.ipt6.Append("filter", "ts-input", args...); err != nil { if err := r.ipt6.Append("filter", "ts-input", args...); err != nil {
return fmt.Errorf("adding %v in v6/filter/ts-input: %w", args, err) return fmt.Errorf("adding %v in v6/filter/ts-input: %w", args, err)
} }
args = []string{"-i", tunname, "-j", "MARK", "--set-mark", subnetRouteMark + "/" + fwmarkMask} args = []string{"-i", tunname, "-j", "MARK", "--set-mark", subnetRouteMark + "/" + fwmarkMask}
if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil { if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil {
return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err)
} }
args = []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "ACCEPT"} args = []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "ACCEPT"}
if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil { if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil {
return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err)
} }
// TODO: drop forwarded traffic to tailscale0 from tailscale's ULA // TODO: drop forwarded traffic to tailscale0 from tailscale's ULA
// (see corresponding IPv4 CGNAT rule). // (see corresponding IPv4 CGNAT rule).
args = []string{"-o", tunname, "-j", "ACCEPT"} args = []string{"-o", tunname, "-j", "ACCEPT"}
if err := i.ipt6.Append("filter", "ts-forward", args...); err != nil { if err := r.ipt6.Append("filter", "ts-forward", args...); err != nil {
return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err) return fmt.Errorf("adding %v in v6/filter/ts-forward: %w", args, err)
} }
@ -371,8 +378,8 @@ func (i *iptablesRunner) addBase6(tunname string) error {
} }
// DelChains removes the custom Tailscale chains from netfilter via iptables. // DelChains removes the custom Tailscale chains from netfilter via iptables.
func (i *iptablesRunner) DelChains() error { func (r *iptablesRunner) DelChains() error {
for _, ipt := range i.getTables() { for _, ipt := range r.getTables() {
if err := delChain(ipt, "filter", "ts-input"); err != nil { if err := delChain(ipt, "filter", "ts-input"); err != nil {
return err return err
} }
@ -381,7 +388,7 @@ func (i *iptablesRunner) DelChains() error {
} }
} }
for _, ipt := range i.getNATTables() { for _, ipt := range r.getNATTables() {
if err := delChain(ipt, "nat", "ts-postrouting"); err != nil { if err := delChain(ipt, "nat", "ts-postrouting"); err != nil {
return err return err
} }
@ -392,7 +399,7 @@ func (i *iptablesRunner) DelChains() error {
// DelBase empties but does not remove custom Tailscale chains from // DelBase empties but does not remove custom Tailscale chains from
// netfilter via iptables. // netfilter via iptables.
func (i *iptablesRunner) DelBase() error { func (r *iptablesRunner) DelBase() error {
del := func(ipt iptablesInterface, table, chain string) error { del := func(ipt iptablesInterface, table, chain string) error {
if err := ipt.ClearChain(table, chain); err != nil { if err := ipt.ClearChain(table, chain); err != nil {
if isNotExistError(err) { if isNotExistError(err) {
@ -405,7 +412,7 @@ func (i *iptablesRunner) DelBase() error {
return nil return nil
} }
for _, ipt := range i.getTables() { for _, ipt := range r.getTables() {
if err := del(ipt, "filter", "ts-input"); err != nil { if err := del(ipt, "filter", "ts-input"); err != nil {
return err return err
} }
@ -413,7 +420,7 @@ func (i *iptablesRunner) DelBase() error {
return err return err
} }
} }
for _, ipt := range i.getNATTables() { for _, ipt := range r.getNATTables() {
if err := del(ipt, "nat", "ts-postrouting"); err != nil { if err := del(ipt, "nat", "ts-postrouting"); err != nil {
return err return err
} }
@ -424,8 +431,8 @@ func (i *iptablesRunner) DelBase() error {
// DelHooks deletes the calls to tailscale's netfilter chains // DelHooks deletes the calls to tailscale's netfilter chains
// in the relevant main netfilter chains. // in the relevant main netfilter chains.
func (i *iptablesRunner) DelHooks(logf logger.Logf) error { func (r *iptablesRunner) DelHooks(logf logger.Logf) error {
for _, ipt := range i.getTables() { for _, ipt := range r.getTables() {
if err := delTSHook(ipt, "filter", "INPUT", logf); err != nil { if err := delTSHook(ipt, "filter", "INPUT", logf); err != nil {
return err return err
} }
@ -433,7 +440,7 @@ func (i *iptablesRunner) DelHooks(logf logger.Logf) error {
return err return err
} }
} }
for _, ipt := range i.getNATTables() { for _, ipt := range r.getNATTables() {
if err := delTSHook(ipt, "nat", "POSTROUTING", logf); err != nil { if err := delTSHook(ipt, "nat", "POSTROUTING", logf); err != nil {
return err return err
} }
@ -444,9 +451,9 @@ func (i *iptablesRunner) DelHooks(logf logger.Logf) error {
// AddSNATRule adds a netfilter rule to SNAT traffic destined for // AddSNATRule adds a netfilter rule to SNAT traffic destined for
// local subnets. // local subnets.
func (i *iptablesRunner) AddSNATRule() error { func (r *iptablesRunner) AddSNATRule() error {
args := []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "MASQUERADE"} args := []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "MASQUERADE"}
for _, ipt := range i.getNATTables() { for _, ipt := range r.getNATTables() {
if err := ipt.Append("nat", "ts-postrouting", args...); err != nil { if err := ipt.Append("nat", "ts-postrouting", args...); err != nil {
return fmt.Errorf("adding %v in nat/ts-postrouting: %w", args, err) return fmt.Errorf("adding %v in nat/ts-postrouting: %w", args, err)
} }
@ -456,9 +463,9 @@ func (i *iptablesRunner) AddSNATRule() error {
// DelSNATRule removes the netfilter rule to SNAT traffic destined for // DelSNATRule removes the netfilter rule to SNAT traffic destined for
// local subnets. An error is returned if the rule does not exist. // local subnets. An error is returned if the rule does not exist.
func (i *iptablesRunner) DelSNATRule() error { func (r *iptablesRunner) DelSNATRule() error {
args := []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "MASQUERADE"} args := []string{"-m", "mark", "--mark", subnetRouteMark + "/" + fwmarkMask, "-j", "MASQUERADE"}
for _, ipt := range i.getNATTables() { for _, ipt := range r.getNATTables() {
if err := ipt.Delete("nat", "ts-postrouting", args...); err != nil { if err := ipt.Delete("nat", "ts-postrouting", args...); err != nil {
return fmt.Errorf("deleting %v in nat/ts-postrouting: %w", args, err) return fmt.Errorf("deleting %v in nat/ts-postrouting: %w", args, err)
} }
@ -472,7 +479,7 @@ func statefulRuleArgs(tunname string) []string {
// AddStatefulRule adds a netfilter rule for stateful packet filtering using // AddStatefulRule adds a netfilter rule for stateful packet filtering using
// conntrack. // conntrack.
func (i *iptablesRunner) AddStatefulRule(tunname string) error { func (r *iptablesRunner) AddStatefulRule(tunname string) error {
// Drop packets that are destined for the tailscale interface if // Drop packets that are destined for the tailscale interface if
// they're a new connection, per conntrack, to prevent hosts on the // they're a new connection, per conntrack, to prevent hosts on the
// same subnet from being able to use this device as a way to forward // same subnet from being able to use this device as a way to forward
@ -495,7 +502,7 @@ func (i *iptablesRunner) AddStatefulRule(tunname string) error {
// Tailscale from other hosts on the same network segment; we drop // Tailscale from other hosts on the same network segment; we drop
// INVALID packets as well. // INVALID packets as well.
args := statefulRuleArgs(tunname) args := statefulRuleArgs(tunname)
for _, ipt := range i.getTables() { for _, ipt := range r.getTables() {
// First, find the final "accept" rule. // First, find the final "accept" rule.
rules, err := ipt.List("filter", "ts-forward") rules, err := ipt.List("filter", "ts-forward")
if err != nil { if err != nil {
@ -517,9 +524,9 @@ func (i *iptablesRunner) AddStatefulRule(tunname string) error {
// DelStatefulRule removes the netfilter rule for stateful packet filtering // DelStatefulRule removes the netfilter rule for stateful packet filtering
// using conntrack. // using conntrack.
func (i *iptablesRunner) DelStatefulRule(tunname string) error { func (r *iptablesRunner) DelStatefulRule(tunname string) error {
args := statefulRuleArgs(tunname) args := statefulRuleArgs(tunname)
for _, ipt := range i.getTables() { for _, ipt := range r.getTables() {
if err := ipt.Delete("filter", "ts-forward", args...); err != nil { if err := ipt.Delete("filter", "ts-forward", args...); err != nil {
return fmt.Errorf("deleting %v in filter/ts-forward: %w", args, err) return fmt.Errorf("deleting %v in filter/ts-forward: %w", args, err)
} }
@ -540,13 +547,13 @@ func buildMagicsockPortRule(port uint16) []string {
// the specified UDP port, so magicsock can accept incoming connections. // the specified UDP port, so magicsock can accept incoming connections.
// network must be either "udp4" or "udp6" - this determines whether the rule // network must be either "udp4" or "udp6" - this determines whether the rule
// is added for IPv4 or IPv6. // is added for IPv4 or IPv6.
func (i *iptablesRunner) AddMagicsockPortRule(port uint16, network string) error { func (r *iptablesRunner) AddMagicsockPortRule(port uint16, network string) error {
var ipt iptablesInterface var ipt iptablesInterface
switch network { switch network {
case "udp4": case "udp4":
ipt = i.ipt4 ipt = r.ipt4
case "udp6": case "udp6":
ipt = i.ipt6 ipt = r.ipt6
default: default:
return fmt.Errorf("unsupported network %s", network) return fmt.Errorf("unsupported network %s", network)
} }
@ -564,13 +571,13 @@ func (i *iptablesRunner) AddMagicsockPortRule(port uint16, network string) error
// incoming traffic on a particular UDP port. // incoming traffic on a particular UDP port.
// network must be either "udp4" or "udp6" - this determines whether the rule // network must be either "udp4" or "udp6" - this determines whether the rule
// is removed for IPv4 or IPv6. // is removed for IPv4 or IPv6.
func (i *iptablesRunner) DelMagicsockPortRule(port uint16, network string) error { func (r *iptablesRunner) DelMagicsockPortRule(port uint16, network string) error {
var ipt iptablesInterface var ipt iptablesInterface
switch network { switch network {
case "udp4": case "udp4":
ipt = i.ipt4 ipt = r.ipt4
case "udp6": case "udp6":
ipt = i.ipt6 ipt = r.ipt6
default: default:
return fmt.Errorf("unsupported network %s", network) return fmt.Errorf("unsupported network %s", network)
} }

@ -32,13 +32,13 @@ import (
// - ensures that nat table exists // - ensures that nat table exists
// - ensures that there is a prerouting chain for the given service and IP family of the target address in the nat table // - ensures that there is a prerouting chain for the given service and IP family of the target address in the nat table
// - ensures that there is a portmapping rule mathcing the given portmap (only creates the rule if it does not already exist) // - ensures that there is a portmapping rule mathcing the given portmap (only creates the rule if it does not already exist)
func (n *nftablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { func (r *nftablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
t, ch, err := n.ensureChainForSvc(svc, targetIP) t, ch, err := r.ensureChainForSvc(svc, targetIP)
if err != nil { if err != nil {
return fmt.Errorf("error ensuring chain for %s: %w", svc, err) return fmt.Errorf("error ensuring chain for %s: %w", svc, err)
} }
meta := svcPortMapRuleMeta(svc, targetIP, pm) meta := svcPortMapRuleMeta(svc, targetIP, pm)
rule, err := n.findRuleByMetadata(t, ch, meta) rule, err := r.findRuleByMetadata(t, ch, meta)
if err != nil { if err != nil {
return fmt.Errorf("error looking up rule: %w", err) return fmt.Errorf("error looking up rule: %w", err)
} }
@ -51,8 +51,8 @@ func (n *nftablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip
} }
rule = portMapRule(t, ch, tun, targetIP, pm.MatchPort, pm.TargetPort, p, meta) rule = portMapRule(t, ch, tun, targetIP, pm.MatchPort, pm.TargetPort, p, meta)
n.conn.InsertRule(rule) r.conn.InsertRule(rule)
return n.conn.Flush() return r.conn.Flush()
} }
// DeletePortMapRuleForSvc deletes a portmapping rule in the given service/IP family chain. // DeletePortMapRuleForSvc deletes a portmapping rule in the given service/IP family chain.
@ -60,19 +60,19 @@ func (n *nftablesRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip
// The caller is expected to call DeleteSvc if the whole service (the chain) // The caller is expected to call DeleteSvc if the whole service (the chain)
// needs to be deleted, so we don't deal with the case where this is the only // needs to be deleted, so we don't deal with the case where this is the only
// rule in the chain here. // rule in the chain here.
func (n *nftablesRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { func (r *nftablesRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error {
table, err := n.getNFTByAddr(targetIP) table, err := r.getNFTByAddr(targetIP)
if err != nil { if err != nil {
return fmt.Errorf("error setting up nftables for IP family of %s: %w", targetIP, err) return fmt.Errorf("error setting up nftables for IP family of %s: %w", targetIP, err)
} }
t, err := getTableIfExists(n.conn, table.Proto, "nat") t, err := getTableIfExists(r.conn, table.Proto, "nat")
if err != nil { if err != nil {
return fmt.Errorf("error checking if nat table exists: %w", err) return fmt.Errorf("error checking if nat table exists: %w", err)
} }
if t == nil { if t == nil {
return nil return nil
} }
ch, err := getChainFromTable(n.conn, t, svc) ch, err := getChainFromTable(r.conn, t, svc)
if err != nil && !errors.Is(err, errorChainNotFound{t.Name, svc}) { if err != nil && !errors.Is(err, errorChainNotFound{t.Name, svc}) {
return fmt.Errorf("error checking if chain %s exists: %w", svc, err) return fmt.Errorf("error checking if chain %s exists: %w", svc, err)
} }
@ -80,56 +80,56 @@ func (n *nftablesRunner) DeletePortMapRuleForSvc(svc, tun string, targetIP netip
return nil // service chain does not exist, so neither does the portmapping rule return nil // service chain does not exist, so neither does the portmapping rule
} }
meta := svcPortMapRuleMeta(svc, targetIP, pm) meta := svcPortMapRuleMeta(svc, targetIP, pm)
rule, err := n.findRuleByMetadata(t, ch, meta) rule, err := r.findRuleByMetadata(t, ch, meta)
if err != nil { if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err) return fmt.Errorf("error checking if rule exists: %w", err)
} }
if rule == nil { if rule == nil {
return nil return nil
} }
if err := n.conn.DelRule(rule); err != nil { if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("error deleting rule: %w", err) return fmt.Errorf("error deleting rule: %w", err)
} }
return n.conn.Flush() return r.conn.Flush()
} }
// DeleteSvc deletes the chains for the given service if any exist. // DeleteSvc deletes the chains for the given service if any exist.
func (n *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error { func (r *nftablesRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, pm []PortMap) error {
for _, tip := range targetIPs { for _, tip := range targetIPs {
table, err := n.getNFTByAddr(tip) table, err := r.getNFTByAddr(tip)
if err != nil { if err != nil {
return fmt.Errorf("error setting up nftables for IP family of %s: %w", tip, err) return fmt.Errorf("error setting up nftables for IP family of %s: %w", tip, err)
} }
t, err := getTableIfExists(n.conn, table.Proto, "nat") t, err := getTableIfExists(r.conn, table.Proto, "nat")
if err != nil { if err != nil {
return fmt.Errorf("error checking if nat table exists: %w", err) return fmt.Errorf("error checking if nat table exists: %w", err)
} }
if t == nil { if t == nil {
return nil return nil
} }
ch, err := getChainFromTable(n.conn, t, svc) ch, err := getChainFromTable(r.conn, t, svc)
if err != nil && !errors.Is(err, errorChainNotFound{t.Name, svc}) { if err != nil && !errors.Is(err, errorChainNotFound{t.Name, svc}) {
return fmt.Errorf("error checking if chain %s exists: %w", svc, err) return fmt.Errorf("error checking if chain %s exists: %w", svc, err)
} }
if errors.Is(err, errorChainNotFound{t.Name, svc}) { if errors.Is(err, errorChainNotFound{t.Name, svc}) {
return nil return nil
} }
n.conn.DelChain(ch) r.conn.DelChain(ch)
} }
return n.conn.Flush() return r.conn.Flush()
} }
// EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the // EnsureDNATRuleForSvc adds a DNAT rule that forwards traffic from the
// VIPService IP address to a local address. This is used by the Kubernetes // VIPService IP address to a local address. This is used by the Kubernetes
// operator's network layer proxies to forward tailnet traffic for VIPServices // operator's network layer proxies to forward tailnet traffic for VIPServices
// to Kubernetes Services. // to Kubernetes Services.
func (n *nftablesRunner) EnsureDNATRuleForSvc(svc string, origDst, dst netip.Addr) error { func (r *nftablesRunner) EnsureDNATRuleForSvc(svc string, origDst, dst netip.Addr) error {
t, ch, err := n.ensurePreroutingChain(origDst) t, ch, err := r.ensurePreroutingChain(origDst)
if err != nil { if err != nil {
return fmt.Errorf("error ensuring chain for %s: %w", svc, err) return fmt.Errorf("error ensuring chain for %s: %w", svc, err)
} }
meta := svcRuleMeta(svc, origDst, dst) meta := svcRuleMeta(svc, origDst, dst)
rule, err := n.findRuleByMetadata(t, ch, meta) rule, err := r.findRuleByMetadata(t, ch, meta)
if err != nil { if err != nil {
return fmt.Errorf("error looking up rule: %w", err) return fmt.Errorf("error looking up rule: %w", err)
} }
@ -137,25 +137,25 @@ func (n *nftablesRunner) EnsureDNATRuleForSvc(svc string, origDst, dst netip.Add
return nil return nil
} }
rule = dnatRuleForChain(t, ch, origDst, dst, meta) rule = dnatRuleForChain(t, ch, origDst, dst, meta)
n.conn.InsertRule(rule) r.conn.InsertRule(rule)
return n.conn.Flush() return r.conn.Flush()
} }
// DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc. // DeleteDNATRuleForSvc deletes a DNAT rule created by EnsureDNATRuleForSvc.
// We use the metadata attached to the rule to look it up. // We use the metadata attached to the rule to look it up.
func (n *nftablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error { func (r *nftablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip.Addr) error {
table, err := n.getNFTByAddr(origDst) table, err := r.getNFTByAddr(origDst)
if err != nil { if err != nil {
return fmt.Errorf("error setting up nftables for IP family of %s: %w", origDst, err) return fmt.Errorf("error setting up nftables for IP family of %s: %w", origDst, err)
} }
t, err := getTableIfExists(n.conn, table.Proto, "nat") t, err := getTableIfExists(r.conn, table.Proto, "nat")
if err != nil { if err != nil {
return fmt.Errorf("error checking if nat table exists: %w", err) return fmt.Errorf("error checking if nat table exists: %w", err)
} }
if t == nil { if t == nil {
return nil return nil
} }
ch, err := getChainFromTable(n.conn, t, "PREROUTING") ch, err := getChainFromTable(r.conn, t, "PREROUTING")
if errors.Is(err, errorChainNotFound{tableName: "nat", chainName: "PREROUTING"}) { if errors.Is(err, errorChainNotFound{tableName: "nat", chainName: "PREROUTING"}) {
return nil return nil
} }
@ -163,17 +163,17 @@ func (n *nftablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst netip
return fmt.Errorf("error checking if chain PREROUTING exists: %w", err) return fmt.Errorf("error checking if chain PREROUTING exists: %w", err)
} }
meta := svcRuleMeta(svcName, origDst, dst) meta := svcRuleMeta(svcName, origDst, dst)
rule, err := n.findRuleByMetadata(t, ch, meta) rule, err := r.findRuleByMetadata(t, ch, meta)
if err != nil { if err != nil {
return fmt.Errorf("error checking if rule exists: %w", err) return fmt.Errorf("error checking if rule exists: %w", err)
} }
if rule == nil { if rule == nil {
return nil return nil
} }
if err := n.conn.DelRule(rule); err != nil { if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("error deleting rule: %w", err) return fmt.Errorf("error deleting rule: %w", err)
} }
return n.conn.Flush() return r.conn.Flush()
} }
func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule { func portMapRule(t *nftables.Table, ch *nftables.Chain, tun string, targetIP netip.Addr, matchPort, targetPort uint16, proto uint8, meta []byte) *nftables.Rule {
@ -239,11 +239,11 @@ func svcPortMapRuleMeta(svcName string, targetIP netip.Addr, pm PortMap) []byte
return []byte(fmt.Sprintf("svc:%s,targetIP:%s:matchPort:%v,targetPort:%v,proto:%v", svcName, targetIP.String(), pm.MatchPort, pm.TargetPort, pm.Protocol)) return []byte(fmt.Sprintf("svc:%s,targetIP:%s:matchPort:%v,targetPort:%v,proto:%v", svcName, targetIP.String(), pm.MatchPort, pm.TargetPort, pm.Protocol))
} }
func (n *nftablesRunner) findRuleByMetadata(t *nftables.Table, ch *nftables.Chain, meta []byte) (*nftables.Rule, error) { func (r *nftablesRunner) findRuleByMetadata(t *nftables.Table, ch *nftables.Chain, meta []byte) (*nftables.Rule, error) {
if n.conn == nil || t == nil || ch == nil || len(meta) == 0 { if r.conn == nil || t == nil || ch == nil || len(meta) == 0 {
return nil, nil return nil, nil
} }
rules, err := n.conn.GetRules(t, ch) rules, err := r.conn.GetRules(t, ch)
if err != nil { if err != nil {
return nil, fmt.Errorf("error listing rules: %w", err) return nil, fmt.Errorf("error listing rules: %w", err)
} }
@ -255,17 +255,17 @@ func (n *nftablesRunner) findRuleByMetadata(t *nftables.Table, ch *nftables.Chai
return nil, nil return nil, nil
} }
func (n *nftablesRunner) ensureChainForSvc(svc string, targetIP netip.Addr) (*nftables.Table, *nftables.Chain, error) { func (r *nftablesRunner) ensureChainForSvc(svc string, targetIP netip.Addr) (*nftables.Table, *nftables.Chain, error) {
polAccept := nftables.ChainPolicyAccept polAccept := nftables.ChainPolicyAccept
table, err := n.getNFTByAddr(targetIP) table, err := r.getNFTByAddr(targetIP)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", targetIP, err) return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", targetIP, err)
} }
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") nat, err := createTableIfNotExist(r.conn, table.Proto, "nat")
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error ensuring nat table: %w", err) return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
} }
svcCh, err := getOrCreateChain(n.conn, chainInfo{ svcCh, err := getOrCreateChain(r.conn, chainInfo{
table: nat, table: nat,
name: svc, name: svc,
chainType: nftables.ChainTypeNAT, chainType: nftables.ChainTypeNAT,

@ -74,21 +74,22 @@ type nftablesRunner struct {
nft6 *nftable // IPv6 tables or nil if the system does not support IPv6 nft6 *nftable // IPv6 tables or nil if the system does not support IPv6
v6Available bool // whether the host supports IPv6 v6Available bool // whether the host supports IPv6
af AddressFamilies
} }
func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) { func (r *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
polAccept := nftables.ChainPolicyAccept polAccept := nftables.ChainPolicyAccept
table, err := n.getNFTByAddr(dst) table, err := r.getNFTByAddr(dst)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err) return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err)
} }
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") nat, err := createTableIfNotExist(r.conn, table.Proto, "nat")
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error ensuring nat table: %w", err) return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
} }
// ensure prerouting chain exists // ensure prerouting chain exists
preroutingCh, err := getOrCreateChain(n.conn, chainInfo{ preroutingCh, err := getOrCreateChain(r.conn, chainInfo{
table: nat, table: nat,
name: "PREROUTING", name: "PREROUTING",
chainType: nftables.ChainTypeNAT, chainType: nftables.ChainTypeNAT,
@ -102,14 +103,14 @@ func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table,
return nat, preroutingCh, nil return nat, preroutingCh, nil
} }
func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { func (r *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
nat, preroutingCh, err := n.ensurePreroutingChain(dst) nat, preroutingCh, err := r.ensurePreroutingChain(dst)
if err != nil { if err != nil {
return err return err
} }
rule := dnatRuleForChain(nat, preroutingCh, origDst, dst, nil) rule := dnatRuleForChain(nat, preroutingCh, origDst, dst, nil)
n.conn.InsertRule(rule) r.conn.InsertRule(rule)
return n.conn.Flush() return r.conn.Flush()
} }
func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.Addr, meta []byte) *nftables.Rule { func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.Addr, meta []byte) *nftables.Rule {
@ -160,12 +161,12 @@ func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.
// TODO (irbekrm): instead of doing this load balance traffic evenly to all // TODO (irbekrm): instead of doing this load balance traffic evenly to all
// backend destinations. // backend destinations.
// https://github.com/tailscale/tailscale/commit/d37f2f508509c6c35ad724fd75a27685b90b575b#diff-a3bcbcd1ca198799f4f768dc56fea913e1945a6b3ec9dbec89325a84a19a85e7R148-R232 // https://github.com/tailscale/tailscale/commit/d37f2f508509c6c35ad724fd75a27685b90b575b#diff-a3bcbcd1ca198799f4f768dc56fea913e1945a6b3ec9dbec89325a84a19a85e7R148-R232
func (n *nftablesRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error { func (r *nftablesRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error {
return n.AddDNATRule(origDst, dsts[0]) return r.AddDNATRule(origDst, dsts[0])
} }
func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) error { func (r *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) error {
nat, preroutingCh, err := n.ensurePreroutingChain(dst) nat, preroutingCh, err := r.ensurePreroutingChain(dst)
if err != nil { if err != nil {
return err return err
} }
@ -197,23 +198,23 @@ func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr)
}, },
}, },
} }
n.conn.InsertRule(dnatRule) r.conn.InsertRule(dnatRule)
return n.conn.Flush() return r.conn.Flush()
} }
func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error { func (r *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
polAccept := nftables.ChainPolicyAccept polAccept := nftables.ChainPolicyAccept
table, err := n.getNFTByAddr(dst) table, err := r.getNFTByAddr(dst)
if err != nil { if err != nil {
return fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err) return fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err)
} }
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") nat, err := createTableIfNotExist(r.conn, table.Proto, "nat")
if err != nil { if err != nil {
return fmt.Errorf("error ensuring nat table exists: %w", err) return fmt.Errorf("error ensuring nat table exists: %w", err)
} }
// ensure postrouting chain exists // ensure postrouting chain exists
postRoutingCh, err := getOrCreateChain(n.conn, chainInfo{ postRoutingCh, err := getOrCreateChain(r.conn, chainInfo{
table: nat, table: nat,
name: "POSTROUTING", name: "POSTROUTING",
chainType: nftables.ChainTypeNAT, chainType: nftables.ChainTypeNAT,
@ -225,7 +226,7 @@ func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
return fmt.Errorf("error ensuring postrouting chain: %w", err) return fmt.Errorf("error ensuring postrouting chain: %w", err)
} }
rules, err := n.conn.GetRules(nat, postRoutingCh) rules, err := r.conn.GetRules(nat, postRoutingCh)
if err != nil { if err != nil {
return fmt.Errorf("error listing rules: %w", err) return fmt.Errorf("error listing rules: %w", err)
} }
@ -237,14 +238,14 @@ func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
if strings.EqualFold(current, snatRuleFullMatch) { if strings.EqualFold(current, snatRuleFullMatch) {
return nil // already exists, do nothing return nil // already exists, do nothing
} }
if err := n.conn.DelRule(rule); err != nil { if err := r.conn.DelRule(rule); err != nil {
return fmt.Errorf("error deleting SNAT rule: %w", err) return fmt.Errorf("error deleting SNAT rule: %w", err)
} }
} }
} }
rule := snatRule(nat, postRoutingCh, src, dst, []byte(snatRuleFullMatch)) rule := snatRule(nat, postRoutingCh, src, dst, []byte(snatRuleFullMatch))
n.conn.AddRule(rule) r.conn.AddRule(rule)
return n.conn.Flush() return r.conn.Flush()
} }
// ClampMSSToPMTU ensures that all packets with TCP flags (SYN, ACK, RST) set // ClampMSSToPMTU ensures that all packets with TCP flags (SYN, ACK, RST) set
@ -266,19 +267,19 @@ func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
// verdicts that would cause no further procesing within that chain. This // verdicts that would cause no further procesing within that chain. This
// functionality is currently invoked from outside wgengine (containerboot), so // functionality is currently invoked from outside wgengine (containerboot), so
// we don't want to race with wgengine for rule ordering within chains. // we don't want to race with wgengine for rule ordering within chains.
func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { func (r *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
polAccept := nftables.ChainPolicyAccept polAccept := nftables.ChainPolicyAccept
table, err := n.getNFTByAddr(addr) table, err := r.getNFTByAddr(addr)
if err != nil { if err != nil {
return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err) return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
} }
filterTable, err := createTableIfNotExist(n.conn, table.Proto, "filter") filterTable, err := createTableIfNotExist(r.conn, table.Proto, "filter")
if err != nil { if err != nil {
return fmt.Errorf("error ensuring filter table: %w", err) return fmt.Errorf("error ensuring filter table: %w", err)
} }
// ensure ts-clamp chain exists // ensure ts-clamp chain exists
fwChain, err := getOrCreateChain(n.conn, chainInfo{ fwChain, err := getOrCreateChain(r.conn, chainInfo{
table: filterTable, table: filterTable,
name: "ts-clamp", name: "ts-clamp",
chainType: nftables.ChainTypeFilter, chainType: nftables.ChainTypeFilter,
@ -344,8 +345,8 @@ func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
}, },
}, },
} }
n.conn.AddRule(clampRule) r.conn.AddRule(clampRule)
return n.conn.Flush() return r.conn.Flush()
} }
// deleteTableIfExists deletes a nftables table via connection c if it exists // deleteTableIfExists deletes a nftables table via connection c if it exists
@ -476,6 +477,14 @@ func getOrCreateChain(c *nftables.Conn, cinfo chainInfo) (*nftables.Chain, error
return chain, nil return chain, nil
} }
type AddressFamilies uint8
const (
FamilyBoth = FamilyIPv4 | FamilyIPv6
FamilyIPv4 = AddressFamilies(1 << iota)
FamilyIPv6
)
// NetfilterRunner abstracts helpers to run netfilter commands. It is // NetfilterRunner abstracts helpers to run netfilter commands. It is
// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner. // implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner.
type NetfilterRunner interface { type NetfilterRunner interface {
@ -499,6 +508,10 @@ type NetfilterRunner interface {
// DelChains removes chains added by AddChains. // DelChains removes chains added by AddChains.
DelChains() error DelChains() error
// SetAddressFamilies sets the address families (IPv4, IPv6, or both) that
// the runner should operate on.
SetAddressFamilies(AddressFamilies)
// AddBase adds rules reused by different other rules. // AddBase adds rules reused by different other rules.
AddBase(tunname string) error AddBase(tunname string) error
@ -521,16 +534,16 @@ type NetfilterRunner interface {
// using conntrack. // using conntrack.
DelStatefulRule(tunname string) error DelStatefulRule(tunname string) error
// HasIPV6 reports true if the system supports IPv6. // HasIPV6 reports whether the system supports IPv6.
HasIPV6() bool HasIPV6() bool
// HasIPV6NAT reports true if the system supports IPv6 NAT. // HasIPV6NAT reports whether the system supports IPv6 NAT.
HasIPV6NAT() bool HasIPV6NAT() bool
// HasIPV6Filter reports true if the system supports IPv6 filter tables // HasIPV6Filter reports whether the system supports IPv6 filter tables.
// This is only meaningful for iptables implementation, where hosts have // This is only meaningful for iptables implementation, where hosts have
// partial ipables support (i.e missing filter table). For nftables // partial ipables support (i.e missing filter table). For the nftables
// implementation, this will default to the value of HasIPv6(). // implementation, this will default to the value of HasIPV6.
HasIPV6Filter() bool HasIPV6Filter() bool
// AddDNATRule adds a rule to the nat/PREROUTING chain to DNAT traffic // AddDNATRule adds a rule to the nat/PREROUTING chain to DNAT traffic
@ -645,6 +658,7 @@ func newNfTablesRunnerWithConn(logf logger.Logf, conn *nftables.Conn) *nftablesR
nft4: nft4, nft4: nft4,
nft6: nft6, nft6: nft6,
v6Available: supportsV6, v6Available: supportsV6,
af: FamilyBoth,
} }
} }
@ -683,8 +697,12 @@ func newLoadDportExpr(destReg uint32) expr.Any {
} }
// HasIPV6 reports true if the system supports IPv6. // HasIPV6 reports true if the system supports IPv6.
func (n *nftablesRunner) HasIPV6() bool { func (r *nftablesRunner) HasIPV6() bool {
return n.v6Available return r.v6Available
}
func (r *nftablesRunner) SetAddressFamilies(af AddressFamilies) {
r.af = af
} }
// HasIPV6NAT returns true if the system supports IPv6. // HasIPV6NAT returns true if the system supports IPv6.
@ -692,15 +710,15 @@ func (n *nftablesRunner) HasIPV6() bool {
// NAT, so no need for a separate IPv6 NAT support check like we do for iptables. // NAT, so no need for a separate IPv6 NAT support check like we do for iptables.
// https://tldp.org/HOWTO/Linux+IPv6-HOWTO/ch18s04.html // https://tldp.org/HOWTO/Linux+IPv6-HOWTO/ch18s04.html
// https://wiki.nftables.org/wiki-nftables/index.php/Building_and_installing_nftables_from_sources // https://wiki.nftables.org/wiki-nftables/index.php/Building_and_installing_nftables_from_sources
func (n *nftablesRunner) HasIPV6NAT() bool { func (r *nftablesRunner) HasIPV6NAT() bool {
return n.v6Available return r.v6Available
} }
// HasIPV6Filter returns true if system supports IPv6. There are no known edge // HasIPV6Filter returns true if system supports IPv6. There are no known edge
// cases where nftables running on a host that supports IPv6 would not support // cases where nftables running on a host that supports IPv6 would not support
// filter table. // filter table.
func (n *nftablesRunner) HasIPV6Filter() bool { func (r *nftablesRunner) HasIPV6Filter() bool {
return n.v6Available return r.v6Available
} }
// findRule iterates through the rules to find the rule with matching expressions. // findRule iterates through the rules to find the rule with matching expressions.
@ -808,30 +826,30 @@ func insertLoopbackRule(
// getNFTByAddr returns the nftables with correct IP family // getNFTByAddr returns the nftables with correct IP family
// that we will be using for the given address. // that we will be using for the given address.
func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) (*nftable, error) { func (r *nftablesRunner) getNFTByAddr(addr netip.Addr) (*nftable, error) {
if addr.Is6() && !n.v6Available { if addr.Is6() && !r.v6Available {
return nil, fmt.Errorf("nftables for IPv6 are not available on this host") return nil, fmt.Errorf("nftables for IPv6 are not available on this host")
} }
if addr.Is6() { if addr.Is6() {
return n.nft6, nil return r.nft6, nil
} }
return n.nft4, nil return r.nft4, nil
} }
// AddLoopbackRule adds an nftables rule to permit loopback traffic to // AddLoopbackRule adds an nftables rule to permit loopback traffic to
// a local Tailscale IP. This rule is added only if it does not already exist. // a local Tailscale IP. This rule is added only if it does not already exist.
func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error { func (r *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
nf, err := n.getNFTByAddr(addr) nf, err := r.getNFTByAddr(addr)
if err != nil { if err != nil {
return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err) return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
} }
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput) inputChain, err := getChainFromTable(r.conn, nf.Filter, chainNameInput)
if err != nil { if err != nil {
return fmt.Errorf("get input chain: %w", err) return fmt.Errorf("get input chain: %w", err)
} }
if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil { if err := insertLoopbackRule(r.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil {
return fmt.Errorf("add loopback rule: %w", err) return fmt.Errorf("add loopback rule: %w", err)
} }
@ -840,13 +858,13 @@ func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
// DelLoopbackRule removes the nftables rule permitting loopback // DelLoopbackRule removes the nftables rule permitting loopback
// traffic to a Tailscale IP. // traffic to a Tailscale IP.
func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error { func (r *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
nf, err := n.getNFTByAddr(addr) nf, err := r.getNFTByAddr(addr)
if err != nil { if err != nil {
return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err) return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
} }
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput) inputChain, err := getChainFromTable(r.conn, nf.Filter, chainNameInput)
if err != nil { if err != nil {
return fmt.Errorf("get input chain: %w", err) return fmt.Errorf("get input chain: %w", err)
} }
@ -856,7 +874,7 @@ func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
return fmt.Errorf("create loopback rule: %w", err) return fmt.Errorf("create loopback rule: %w", err)
} }
existingLoopBackRule, err := findRule(n.conn, loopBackRule) existingLoopBackRule, err := findRule(r.conn, loopBackRule)
if err != nil { if err != nil {
return fmt.Errorf("find loop back rule: %w", err) return fmt.Errorf("find loop back rule: %w", err)
} }
@ -865,48 +883,48 @@ func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
return nil return nil
} }
if err := n.conn.DelRule(existingLoopBackRule); err != nil { if err := r.conn.DelRule(existingLoopBackRule); err != nil {
return fmt.Errorf("delete rule: %w", err) return fmt.Errorf("delete rule: %w", err)
} }
return n.conn.Flush() return r.conn.Flush()
} }
// getTables returns tables for IP families that this host was determined to // getTables returns tables for IP families that this host was determined to
// support (either IPv4 and IPv6 or just IPv4). // support (either IPv4 and IPv6 or just IPv4).
func (n *nftablesRunner) getTables() []*nftable { func (r *nftablesRunner) getTables() []*nftable {
if n.HasIPV6() { if r.HasIPV6() {
return []*nftable{n.nft4, n.nft6} return []*nftable{r.nft4, r.nft6}
} }
return []*nftable{n.nft4} return []*nftable{r.nft4}
} }
// AddChains creates custom Tailscale chains in netfilter via nftables // AddChains creates custom Tailscale chains in netfilter via nftables
// if the ts-chain doesn't already exist. // if the ts-chain doesn't already exist.
func (n *nftablesRunner) AddChains() error { func (r *nftablesRunner) AddChains() error {
polAccept := nftables.ChainPolicyAccept polAccept := nftables.ChainPolicyAccept
for _, table := range n.getTables() { for _, table := range r.getTables() {
// Create the filter table if it doesn't exist, this table name is the same // Create the filter table if it doesn't exist, this table name is the same
// as the name used by iptables-nft and ufw. We install rules into the // as the name used by iptables-nft and ufw. We install rules into the
// same conventional table so that `accept` verdicts from our jump // same conventional table so that `accept` verdicts from our jump
// chains are conclusive. // chains are conclusive.
filter, err := createTableIfNotExist(n.conn, table.Proto, "filter") filter, err := createTableIfNotExist(r.conn, table.Proto, "filter")
if err != nil { if err != nil {
return fmt.Errorf("create table: %w", err) return fmt.Errorf("create table: %w", err)
} }
table.Filter = filter table.Filter = filter
// Adding the "conventional chains" that are used by iptables-nft and ufw. // Adding the "conventional chains" that are used by iptables-nft and ufw.
if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil { if err = createChainIfNotExist(r.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
return fmt.Errorf("create forward chain: %w", err) return fmt.Errorf("create forward chain: %w", err)
} }
if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil { if err = createChainIfNotExist(r.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
return fmt.Errorf("create input chain: %w", err) return fmt.Errorf("create input chain: %w", err)
} }
// Adding the tailscale chains that contain our rules. // Adding the tailscale chains that contain our rules.
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil { if err = createChainIfNotExist(r.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
return fmt.Errorf("create forward chain: %w", err) return fmt.Errorf("create forward chain: %w", err)
} }
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil { if err = createChainIfNotExist(r.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
return fmt.Errorf("create input chain: %w", err) return fmt.Errorf("create input chain: %w", err)
} }
@ -914,22 +932,22 @@ func (n *nftablesRunner) AddChains() error {
// as the name used by iptables-nft and ufw. We install rules into the // as the name used by iptables-nft and ufw. We install rules into the
// same conventional table so that `accept` verdicts from our jump // same conventional table so that `accept` verdicts from our jump
// chains are conclusive. // chains are conclusive.
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") nat, err := createTableIfNotExist(r.conn, table.Proto, "nat")
if err != nil { if err != nil {
return fmt.Errorf("create table: %w", err) return fmt.Errorf("create table: %w", err)
} }
table.Nat = nat table.Nat = nat
// Adding the "conventional chains" that are used by iptables-nft and ufw. // Adding the "conventional chains" that are used by iptables-nft and ufw.
if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil { if err = createChainIfNotExist(r.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
return fmt.Errorf("create postrouting chain: %w", err) return fmt.Errorf("create postrouting chain: %w", err)
} }
// Adding the tailscale chain that contains our rules. // Adding the tailscale chain that contains our rules.
if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil { if err = createChainIfNotExist(r.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
return fmt.Errorf("create postrouting chain: %w", err) return fmt.Errorf("create postrouting chain: %w", err)
} }
} }
return n.conn.Flush() return r.conn.Flush()
} }
// These are dummy chains and tables we create to detect if nftables is // These are dummy chains and tables we create to detect if nftables is
@ -945,24 +963,24 @@ const (
// createDummyPostroutingChains creates dummy postrouting chains in netfilter // createDummyPostroutingChains creates dummy postrouting chains in netfilter
// via netfilter via nftables, as a last resort measure to detect that nftables // via netfilter via nftables, as a last resort measure to detect that nftables
// can be used. It cleans up the dummy chains after creation. // can be used. It cleans up the dummy chains after creation.
func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) { func (r *nftablesRunner) createDummyPostroutingChains() (retErr error) {
polAccept := ptr.To(nftables.ChainPolicyAccept) polAccept := ptr.To(nftables.ChainPolicyAccept)
for _, table := range n.getTables() { for _, table := range r.getTables() {
nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName) nat, err := createTableIfNotExist(r.conn, table.Proto, tsDummyTableName)
if err != nil { if err != nil {
return fmt.Errorf("create nat table: %w", err) return fmt.Errorf("create nat table: %w", err)
} }
defer func(fm nftables.TableFamily) { defer func(fm nftables.TableFamily) {
if err := deleteTableIfExists(n.conn, fm, tsDummyTableName); err != nil && retErr == nil { if err := deleteTableIfExists(r.conn, fm, tsDummyTableName); err != nil && retErr == nil {
retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err) retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err)
} }
}(table.Proto) }(table.Proto)
table.Nat = nat table.Nat = nat
if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil { if err = createChainIfNotExist(r.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil {
return fmt.Errorf("create %q chain: %w", tsDummyChainName, err) return fmt.Errorf("create %q chain: %w", tsDummyChainName, err)
} }
if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil { if err := deleteChainIfExists(r.conn, nat, tsDummyChainName); err != nil {
return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err) return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err)
} }
} }
@ -990,27 +1008,27 @@ func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) e
} }
// DelChains removes the custom Tailscale chains from netfilter via nftables. // DelChains removes the custom Tailscale chains from netfilter via nftables.
func (n *nftablesRunner) DelChains() error { func (r *nftablesRunner) DelChains() error {
for _, table := range n.getTables() { for _, table := range r.getTables() {
if err := deleteChainIfExists(n.conn, table.Filter, chainNameForward); err != nil { if err := deleteChainIfExists(r.conn, table.Filter, chainNameForward); err != nil {
return fmt.Errorf("delete chain: %w", err) return fmt.Errorf("delete chain: %w", err)
} }
if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil { if err := deleteChainIfExists(r.conn, table.Filter, chainNameInput); err != nil {
return fmt.Errorf("delete chain: %w", err) return fmt.Errorf("delete chain: %w", err)
} }
} }
if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil { if err := deleteChainIfExists(r.conn, r.nft4.Nat, chainNamePostrouting); err != nil {
return fmt.Errorf("delete chain: %w", err) return fmt.Errorf("delete chain: %w", err)
} }
if n.HasIPV6NAT() { if r.HasIPV6NAT() {
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil { if err := deleteChainIfExists(r.conn, r.nft6.Nat, chainNamePostrouting); err != nil {
return fmt.Errorf("delete chain: %w", err) return fmt.Errorf("delete chain: %w", err)
} }
} }
if err := n.conn.Flush(); err != nil { if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush: %w", err) return fmt.Errorf("flush: %w", err)
} }
@ -1050,10 +1068,10 @@ func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables
// AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING" // AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING"
// in tables and jump from those chains to tailscale chains. // in tables and jump from those chains to tailscale chains.
func (n *nftablesRunner) AddHooks() error { func (r *nftablesRunner) AddHooks() error {
conn := n.conn conn := r.conn
for _, table := range n.getTables() { for _, table := range r.getTables() {
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
if err != nil { if err != nil {
return fmt.Errorf("get INPUT chain: %w", err) return fmt.Errorf("get INPUT chain: %w", err)
@ -1104,10 +1122,10 @@ func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables
} }
// DelHooks is deleting the rules added to conventional chains to jump to tailscale chains. // DelHooks is deleting the rules added to conventional chains to jump to tailscale chains.
func (n *nftablesRunner) DelHooks(logf logger.Logf) error { func (r *nftablesRunner) DelHooks(logf logger.Logf) error {
conn := n.conn conn := r.conn
for _, table := range n.getTables() { for _, table := range r.getTables() {
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
if err != nil { if err != nil {
return fmt.Errorf("get INPUT chain: %w", err) return fmt.Errorf("get INPUT chain: %w", err)
@ -1437,23 +1455,23 @@ func removeAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *n
// the specified UDP port, so magicsock can accept incoming connections. // the specified UDP port, so magicsock can accept incoming connections.
// network must be either "udp4" or "udp6" - this determines whether the rule // network must be either "udp4" or "udp6" - this determines whether the rule
// is added for IPv4 or IPv6. // is added for IPv4 or IPv6.
func (n *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error { func (r *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error {
var filterTable *nftables.Table var filterTable *nftables.Table
switch network { switch network {
case "udp4": case "udp4":
filterTable = n.nft4.Filter filterTable = r.nft4.Filter
case "udp6": case "udp6":
filterTable = n.nft6.Filter filterTable = r.nft6.Filter
default: default:
return fmt.Errorf("unsupported network %s", network) return fmt.Errorf("unsupported network %s", network)
} }
inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput) inputChain, err := getChainFromTable(r.conn, filterTable, chainNameInput)
if err != nil { if err != nil {
return fmt.Errorf("get input chain: %v", err) return fmt.Errorf("get input chain: %v", err)
} }
err = addAcceptOnPortRule(n.conn, filterTable, inputChain, port) err = addAcceptOnPortRule(r.conn, filterTable, inputChain, port)
if err != nil { if err != nil {
return fmt.Errorf("add accept on port rule: %v", err) return fmt.Errorf("add accept on port rule: %v", err)
} }
@ -1465,23 +1483,23 @@ func (n *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error
// incoming traffic on a particular UDP port. // incoming traffic on a particular UDP port.
// network must be either "udp4" or "udp6" - this determines whether the rule // network must be either "udp4" or "udp6" - this determines whether the rule
// is removed for IPv4 or IPv6. // is removed for IPv4 or IPv6.
func (n *nftablesRunner) DelMagicsockPortRule(port uint16, network string) error { func (r *nftablesRunner) DelMagicsockPortRule(port uint16, network string) error {
var filterTable *nftables.Table var filterTable *nftables.Table
switch network { switch network {
case "udp4": case "udp4":
filterTable = n.nft4.Filter filterTable = r.nft4.Filter
case "udp6": case "udp6":
filterTable = n.nft6.Filter filterTable = r.nft6.Filter
default: default:
return fmt.Errorf("unsupported network %s", network) return fmt.Errorf("unsupported network %s", network)
} }
inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput) inputChain, err := getChainFromTable(r.conn, filterTable, chainNameInput)
if err != nil { if err != nil {
return fmt.Errorf("get input chain: %v", err) return fmt.Errorf("get input chain: %v", err)
} }
err = removeAcceptOnPortRule(n.conn, filterTable, inputChain, port) err = removeAcceptOnPortRule(r.conn, filterTable, inputChain, port)
if err != nil { if err != nil {
return fmt.Errorf("add accept on port rule: %v", err) return fmt.Errorf("add accept on port rule: %v", err)
} }
@ -1522,12 +1540,14 @@ func addAcceptIncomingPacketRule(conn *nftables.Conn, table *nftables.Table, cha
} }
// AddBase adds some basic processing rules. // AddBase adds some basic processing rules.
func (n *nftablesRunner) AddBase(tunname string) error { func (r *nftablesRunner) AddBase(tunname string) error {
if err := n.addBase4(tunname); err != nil { if r.af&FamilyIPv4 != 0 {
if err := r.addBase4(tunname); err != nil {
return fmt.Errorf("add base v4: %w", err) return fmt.Errorf("add base v4: %w", err)
} }
if n.HasIPV6() { }
if err := n.addBase6(tunname); err != nil { if r.af&FamilyIPv6 != 0 && r.HasIPV6() {
if err := r.addBase6(tunname); err != nil {
return fmt.Errorf("add base v6: %w", err) return fmt.Errorf("add base v6: %w", err)
} }
} }
@ -1535,41 +1555,41 @@ func (n *nftablesRunner) AddBase(tunname string) error {
} }
// addBase4 adds some basic IPv4 processing rules. // addBase4 adds some basic IPv4 processing rules.
func (n *nftablesRunner) addBase4(tunname string) error { func (r *nftablesRunner) addBase4(tunname string) error {
conn := n.conn conn := r.conn
inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput) inputChain, err := getChainFromTable(conn, r.nft4.Filter, chainNameInput)
if err != nil { if err != nil {
return fmt.Errorf("get input chain v4: %v", err) return fmt.Errorf("get input chain v4: %v", err)
} }
if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { if err = addReturnChromeOSVMRangeRule(conn, r.nft4.Filter, inputChain, tunname); err != nil {
return fmt.Errorf("add return chromeos vm range rule v4: %w", err) return fmt.Errorf("add return chromeos vm range rule v4: %w", err)
} }
if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { if err = addDropCGNATRangeRule(conn, r.nft4.Filter, inputChain, tunname); err != nil {
return fmt.Errorf("add drop cgnat range rule v4: %w", err) return fmt.Errorf("add drop cgnat range rule v4: %w", err)
} }
if err = addAcceptIncomingPacketRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { if err = addAcceptIncomingPacketRule(conn, r.nft4.Filter, inputChain, tunname); err != nil {
return fmt.Errorf("add accept incoming packet rule v4: %w", err) return fmt.Errorf("add accept incoming packet rule v4: %w", err)
} }
forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward) forwardChain, err := getChainFromTable(conn, r.nft4.Filter, chainNameForward)
if err != nil { if err != nil {
return fmt.Errorf("get forward chain v4: %v", err) return fmt.Errorf("get forward chain v4: %v", err)
} }
if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil { if err = addSetSubnetRouteMarkRule(conn, r.nft4.Filter, forwardChain, tunname); err != nil {
return fmt.Errorf("add set subnet route mark rule v4: %w", err) return fmt.Errorf("add set subnet route mark rule v4: %w", err)
} }
if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil { if err = addMatchSubnetRouteMarkRule(conn, r.nft4.Filter, forwardChain, Accept); err != nil {
return fmt.Errorf("add match subnet route mark rule v4: %w", err) return fmt.Errorf("add match subnet route mark rule v4: %w", err)
} }
if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil { if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, r.nft4.Filter, forwardChain, tunname); err != nil {
return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err) return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err)
} }
if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil { if err = addAcceptOutgoingPacketRule(conn, r.nft4.Filter, forwardChain, tunname); err != nil {
return fmt.Errorf("add accept outgoing packet rule v4: %w", err) return fmt.Errorf("add accept outgoing packet rule v4: %w", err)
} }
@ -1581,31 +1601,31 @@ func (n *nftablesRunner) addBase4(tunname string) error {
} }
// addBase6 adds some basic IPv6 processing rules. // addBase6 adds some basic IPv6 processing rules.
func (n *nftablesRunner) addBase6(tunname string) error { func (r *nftablesRunner) addBase6(tunname string) error {
conn := n.conn conn := r.conn
inputChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameInput) inputChain, err := getChainFromTable(conn, r.nft6.Filter, chainNameInput)
if err != nil { if err != nil {
return fmt.Errorf("get input chain v4: %v", err) return fmt.Errorf("get input chain v4: %v", err)
} }
if err = addAcceptIncomingPacketRule(conn, n.nft6.Filter, inputChain, tunname); err != nil { if err = addAcceptIncomingPacketRule(conn, r.nft6.Filter, inputChain, tunname); err != nil {
return fmt.Errorf("add accept incoming packet rule v6: %w", err) return fmt.Errorf("add accept incoming packet rule v6: %w", err)
} }
forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward) forwardChain, err := getChainFromTable(conn, r.nft6.Filter, chainNameForward)
if err != nil { if err != nil {
return fmt.Errorf("get forward chain v6: %w", err) return fmt.Errorf("get forward chain v6: %w", err)
} }
if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil { if err = addSetSubnetRouteMarkRule(conn, r.nft6.Filter, forwardChain, tunname); err != nil {
return fmt.Errorf("add set subnet route mark rule v6: %w", err) return fmt.Errorf("add set subnet route mark rule v6: %w", err)
} }
if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil { if err = addMatchSubnetRouteMarkRule(conn, r.nft6.Filter, forwardChain, Accept); err != nil {
return fmt.Errorf("add match subnet route mark rule v6: %w", err) return fmt.Errorf("add match subnet route mark rule v6: %w", err)
} }
if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil { if err = addAcceptOutgoingPacketRule(conn, r.nft6.Filter, forwardChain, tunname); err != nil {
return fmt.Errorf("add accept outgoing packet rule v6: %w", err) return fmt.Errorf("add accept outgoing packet rule v6: %w", err)
} }
@ -1618,10 +1638,10 @@ func (n *nftablesRunner) addBase6(tunname string) error {
// DelBase empties, but does not remove, custom Tailscale chains from // DelBase empties, but does not remove, custom Tailscale chains from
// netfilter via iptables. // netfilter via iptables.
func (n *nftablesRunner) DelBase() error { func (r *nftablesRunner) DelBase() error {
conn := n.conn conn := r.conn
for _, table := range n.getTables() { for _, table := range r.getTables() {
inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput) inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
if err != nil { if err != nil {
return fmt.Errorf("get input chain: %v", err) return fmt.Errorf("get input chain: %v", err)
@ -1699,10 +1719,10 @@ func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, cha
// AddSNATRule adds a netfilter rule to SNAT traffic destined for // AddSNATRule adds a netfilter rule to SNAT traffic destined for
// local subnets. // local subnets.
func (n *nftablesRunner) AddSNATRule() error { func (r *nftablesRunner) AddSNATRule() error {
conn := n.conn conn := r.conn
for _, table := range n.getTables() { for _, table := range r.getTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil { if err != nil {
return fmt.Errorf("get postrouting chain v4: %w", err) return fmt.Errorf("get postrouting chain v4: %w", err)
@ -1745,10 +1765,10 @@ func delMatchSubnetRouteMarkMasqRule(conn *nftables.Conn, table *nftables.Table,
// DelSNATRule removes the netfilter rule to SNAT traffic destined for // DelSNATRule removes the netfilter rule to SNAT traffic destined for
// local subnets. An error is returned if the rule does not exist. // local subnets. An error is returned if the rule does not exist.
func (n *nftablesRunner) DelSNATRule() error { func (r *nftablesRunner) DelSNATRule() error {
conn := n.conn conn := r.conn
for _, table := range n.getTables() { for _, table := range r.getTables() {
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
if err != nil { if err != nil {
return fmt.Errorf("get postrouting chain: %w", err) return fmt.Errorf("get postrouting chain: %w", err)
@ -1886,11 +1906,11 @@ func makeStatefulRuleExprs(tunname string) []expr.Any {
// AddStatefulRule adds a netfilter rule for stateful packet filtering using // AddStatefulRule adds a netfilter rule for stateful packet filtering using
// conntrack. // conntrack.
func (n *nftablesRunner) AddStatefulRule(tunname string) error { func (r *nftablesRunner) AddStatefulRule(tunname string) error {
conn := n.conn conn := r.conn
exprs := makeStatefulRuleExprs(tunname) exprs := makeStatefulRuleExprs(tunname)
for _, table := range n.getTables() { for _, table := range r.getTables() {
chain, err := getChainFromTable(conn, table.Filter, chainNameForward) chain, err := getChainFromTable(conn, table.Filter, chainNameForward)
if err != nil { if err != nil {
return fmt.Errorf("get forward chain: %w", err) return fmt.Errorf("get forward chain: %w", err)
@ -1922,11 +1942,11 @@ func (n *nftablesRunner) AddStatefulRule(tunname string) error {
// DelStatefulRule removes the netfilter rule for stateful packet filtering // DelStatefulRule removes the netfilter rule for stateful packet filtering
// using conntrack. // using conntrack.
func (n *nftablesRunner) DelStatefulRule(tunname string) error { func (r *nftablesRunner) DelStatefulRule(tunname string) error {
conn := n.conn conn := r.conn
exprs := makeStatefulRuleExprs(tunname) exprs := makeStatefulRuleExprs(tunname)
for _, table := range n.getTables() { for _, table := range r.getTables() {
chain, err := getChainFromTable(conn, table.Filter, chainNameForward) chain, err := getChainFromTable(conn, table.Filter, chainNameForward)
if err != nil { if err != nil {
return fmt.Errorf("get forward chain: %w", err) return fmt.Errorf("get forward chain: %w", err)

@ -12,6 +12,7 @@ import (
"net/netip" "net/netip"
"os" "os"
"os/exec" "os/exec"
"slices"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
@ -81,6 +82,7 @@ type linuxRouter struct {
nfr linuxfw.NetfilterRunner nfr linuxfw.NetfilterRunner
mu sync.Mutex mu sync.Mutex
af linuxfw.AddressFamilies
addrs map[netip.Prefix]bool addrs map[netip.Prefix]bool
routes map[netip.Prefix]bool routes map[netip.Prefix]bool
localRoutes map[netip.Prefix]bool localRoutes map[netip.Prefix]bool
@ -112,6 +114,7 @@ func newUserspaceRouterAdvanced(logf logger.Logf, tunname string, netMon *netmon
netfilterMode: netfilterOff, netfilterMode: netfilterOff,
netMon: netMon, netMon: netMon,
health: health, health: health,
af: linuxfw.FamilyBoth,
cmd: cmd, cmd: cmd,
@ -401,6 +404,7 @@ func (r *linuxRouter) setupNetfilterLocked(kind string) error {
if err != nil { if err != nil {
return fmt.Errorf("could not create new netfilter: %w", err) return fmt.Errorf("could not create new netfilter: %w", err)
} }
r.nfr.SetAddressFamilies(r.af)
return nil return nil
} }
@ -414,6 +418,12 @@ func (r *linuxRouter) Set(cfg *router.Config) error {
cfg = &shutdownConfig cfg = &shutdownConfig
} }
r.af = linuxfw.FamilyBoth
if !slices.ContainsFunc(cfg.LocalAddrs, func(p netip.Prefix) bool { return p.Addr().Is4() }) {
r.af = linuxfw.FamilyIPv6
}
r.nfr.SetAddressFamilies(r.af)
if cfg.NetfilterKind != r.netfilterKind { if cfg.NetfilterKind != r.netfilterKind {
if err := r.setNetfilterModeLocked(netfilterOff); err != nil { if err := r.setNetfilterModeLocked(netfilterOff); err != nil {
err = fmt.Errorf("could not disable existing netfilter: %w", err) err = fmt.Errorf("could not disable existing netfilter: %w", err)

@ -420,6 +420,11 @@ type fakeIPTablesRunner struct {
ipt4 map[string][]string ipt4 map[string][]string
ipt6 map[string][]string ipt6 map[string][]string
// we always assume ipv6 and ipv6 nat are enabled when testing // we always assume ipv6 and ipv6 nat are enabled when testing
af linuxfw.AddressFamilies
}
func (r *fakeIPTablesRunner) SetAddressFamilies(af linuxfw.AddressFamilies) {
r.af = af
} }
func newIPTablesRunner(t *testing.T) linuxfw.NetfilterRunner { func newIPTablesRunner(t *testing.T) linuxfw.NetfilterRunner {

Loading…
Cancel
Save