// Copyright (c) Tailscale Inc & AUTHORS // SPDX-License-Identifier: BSD-3-Clause //go:build linux package linuxfw import ( "encoding/binary" "encoding/hex" "errors" "fmt" "net" "net/netip" "reflect" "strings" "github.com/google/nftables" "github.com/google/nftables/expr" "golang.org/x/sys/unix" "tailscale.com/net/tsaddr" "tailscale.com/types/logger" "tailscale.com/types/ptr" ) const ( chainNameForward = "ts-forward" chainNameInput = "ts-input" chainNamePostrouting = "ts-postrouting" ) // chainTypeRegular is an nftables chain that does not apply to a hook. const chainTypeRegular = "" type chainInfo struct { table *nftables.Table name string chainType nftables.ChainType chainHook *nftables.ChainHook chainPriority *nftables.ChainPriority chainPolicy *nftables.ChainPolicy } type nftable struct { Proto nftables.TableFamily Filter *nftables.Table Nat *nftables.Table } // nftablesRunner implements a netfilterRunner using the netlink based nftables // library. As nftables allows for arbitrary tables and chains, there is a need // to follow conventions in order to integrate well with a surrounding // ecosystem. The rules installed by nftablesRunner have the following // properties: // - Install rules that intend to take precedence over rules installed by // other software. Tailscale provides packet filtering for tailnet traffic // inside the daemon based on the tailnet ACL rules. // - As nftables "accept" is not final, rules from high priority tables (low // numbers) will fall through to lower priority tables (high numbers). In // order to effectively be 'final', we install "jump" rules into conventional // tables and chains that will reach an accept verdict inside those tables. // - The table and chain conventions followed here are those used by // `iptables-nft` and `ufw`, so that those tools co-exist and do not // negatively affect Tailscale function. // - Be mindful that 1) all chains attached to a given hook (i.e the forward hook) // will be processed in priority order till either a rule in one of the chains issues a drop verdict // or there are no more chains for that hook // 2) processing of individual rules within a chain will stop once one of them issues a final verdict (accept, drop). // https://wiki.nftables.org/wiki-nftables/index.php/Configuring_chains type nftablesRunner struct { conn *nftables.Conn nft4 *nftable nft6 *nftable v6Available bool v6NATAvailable bool } func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) { polAccept := nftables.ChainPolicyAccept table := n.getNFTByAddr(dst) nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") if err != nil { return nil, nil, fmt.Errorf("error ensuring nat table: %w", err) } // ensure prerouting chain exists preroutingCh, err := getOrCreateChain(n.conn, chainInfo{ table: nat, name: "PREROUTING", chainType: nftables.ChainTypeNAT, chainHook: nftables.ChainHookPrerouting, chainPriority: nftables.ChainPriorityNATDest, chainPolicy: &polAccept, }) if err != nil { return nil, nil, fmt.Errorf("error ensuring prerouting chain: %w", err) } return nat, preroutingCh, nil } func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error { nat, preroutingCh, err := n.ensurePreroutingChain(dst) if err != nil { return err } var daddrOffset, fam, dadderLen uint32 if origDst.Is4() { daddrOffset = 16 dadderLen = 4 fam = unix.NFPROTO_IPV4 } else { daddrOffset = 24 dadderLen = 16 fam = unix.NFPROTO_IPV6 } dnatRule := &nftables.Rule{ Table: nat, Chain: preroutingCh, Exprs: []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: daddrOffset, Len: dadderLen, }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: origDst.AsSlice(), }, &expr.Immediate{ Register: 1, Data: dst.AsSlice(), }, &expr.NAT{ Type: expr.NATTypeDestNAT, Family: fam, RegAddrMin: 1, }, }, } n.conn.InsertRule(dnatRule) return n.conn.Flush() } // DNATWithLoadBalancer currently just forwards all traffic destined for origDst // to the first IP address from the backend targets. // TODO (irbekrm): instead of doing this load balance traffic evenly to all // backend destinations. // https://github.com/tailscale/tailscale/commit/d37f2f508509c6c35ad724fd75a27685b90b575b#diff-a3bcbcd1ca198799f4f768dc56fea913e1945a6b3ec9dbec89325a84a19a85e7R148-R232 func (n *nftablesRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error { return n.AddDNATRule(origDst, dsts[0]) } func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) error { nat, preroutingCh, err := n.ensurePreroutingChain(dst) if err != nil { return err } var famConst uint32 if dst.Is4() { famConst = unix.NFPROTO_IPV4 } else { famConst = unix.NFPROTO_IPV6 } dnatRule := &nftables.Rule{ Table: nat, Chain: preroutingCh, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, &expr.Cmp{ Op: expr.CmpOpNeq, Register: 1, Data: []byte(tunname), }, &expr.Immediate{ Register: 1, Data: dst.AsSlice(), }, &expr.NAT{ Type: expr.NATTypeDestNAT, Family: famConst, RegAddrMin: 1, }, }, } n.conn.InsertRule(dnatRule) return n.conn.Flush() } func (n *nftablesRunner) AddSNATRuleForDst(src, dst netip.Addr) error { polAccept := nftables.ChainPolicyAccept table := n.getNFTByAddr(dst) nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") if err != nil { return fmt.Errorf("error ensuring nat table exists: %w", err) } // ensure postrouting chain exists postRoutingCh, err := getOrCreateChain(n.conn, chainInfo{ table: nat, name: "POSTROUTING", chainType: nftables.ChainTypeNAT, chainHook: nftables.ChainHookPostrouting, chainPriority: nftables.ChainPriorityNATSource, chainPolicy: &polAccept, }) if err != nil { return fmt.Errorf("error ensuring postrouting chain: %w", err) } var daddrOffset, fam, daddrLen uint32 if dst.Is4() { daddrOffset = 16 daddrLen = 4 fam = unix.NFPROTO_IPV4 } else { daddrOffset = 24 daddrLen = 16 fam = unix.NFPROTO_IPV6 } snatRule := &nftables.Rule{ Table: nat, Chain: postRoutingCh, Exprs: []expr.Any{ &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseNetworkHeader, Offset: daddrOffset, Len: daddrLen, }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: dst.AsSlice(), }, &expr.Immediate{ Register: 1, Data: src.AsSlice(), }, &expr.NAT{ Type: expr.NATTypeSourceNAT, Family: fam, RegAddrMin: 1, }, }, } n.conn.AddRule(snatRule) return n.conn.Flush() } // ClampMSSToPMTU ensures that all packets with TCP flags (SYN, ACK, RST) set // being forwarded via the given interface (tun) have MSS set to - 40 (IP and TCP headers). This can be useful if this tailscale // instance is expected to run as a forwarding proxy, forwarding packets from an // endpoint with higher MTU in an environment where path MTU discovery is // expected to not work (such as the proxies created by the Tailscale Kubernetes // operator). ClamMSSToPMTU creates a new base-chain ts-clamp in the filter // table with accept policy and priority -150. In practice, this means that for // SYN packets the clamp rule in this chain will likely run first and accept the // packet. This is fine because 1) nftables run ALL chains with the same hook // type unless a rule in one of them drops the packet and 2) this chain does not // have functionality to drop the packet- so in practice a matching clamp rule // will always be followed by the custom tailscale filtering rules in the other // chains attached to the filter hook (FORWARD, ts-forward). // We do not want to place the clamping rule into FORWARD/ts-forward chains // because wgengine populates those chains with rules that contain accept // verdicts that would cause no further procesing within that chain. This // functionality is currently invoked from outside wgengine (containerboot), so // we don't want to race with wgengine for rule ordering within chains. func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error { polAccept := nftables.ChainPolicyAccept table := n.getNFTByAddr(addr) filterTable, err := createTableIfNotExist(n.conn, table.Proto, "filter") if err != nil { return fmt.Errorf("error ensuring filter table: %w", err) } // ensure ts-clamp chain exists fwChain, err := getOrCreateChain(n.conn, chainInfo{ table: filterTable, name: "ts-clamp", chainType: nftables.ChainTypeFilter, chainHook: nftables.ChainHookForward, chainPriority: nftables.ChainPriorityMangle, chainPolicy: &polAccept, }) if err != nil { return fmt.Errorf("error ensuring forward chain: %w", err) } clampRule := &nftables.Rule{ Table: filterTable, Chain: fwChain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: []byte(tun), }, &expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1}, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_TCP}, }, &expr.Payload{ DestRegister: 1, Base: expr.PayloadBaseTransportHeader, Offset: 13, Len: 1, }, &expr.Bitwise{ DestRegister: 1, SourceRegister: 1, Len: 1, Mask: []byte{0x02}, Xor: []byte{0x00}, }, &expr.Cmp{ Op: expr.CmpOpNeq, // match any packet with a TCP flag set (SYN, ACK, RST) Register: 1, Data: []byte{0x00}, }, &expr.Rt{ Register: 1, Key: expr.RtTCPMSS, }, &expr.Byteorder{ DestRegister: 1, SourceRegister: 1, Op: expr.ByteorderHton, Len: 2, Size: 2, }, &expr.Exthdr{ SourceRegister: 1, Type: 2, Offset: 2, Len: 2, Op: expr.ExthdrOpTcpopt, }, }, } n.conn.AddRule(clampRule) return n.conn.Flush() } // deleteTableIfExists deletes a nftables table via connection c if it exists // within the given family. func deleteTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) error { t, err := getTableIfExists(c, family, name) if err != nil { return fmt.Errorf("get table: %w", err) } if t == nil { // Table does not exist, so nothing to delete. return nil } c.DelTable(t) if err := c.Flush(); err != nil { if t, err = getTableIfExists(c, family, name); t == nil && err == nil { // Check if the table still exists. If it does not, then the error // is due to the table not existing, so we can ignore it. Maybe a // concurrent process deleted the table. return nil } return fmt.Errorf("del table: %w", err) } return nil } // getTableIfExists returns the table with the given name from the given family // if it exists. If none match, it returns (nil, nil). func getTableIfExists(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { tables, err := c.ListTables() if err != nil { return nil, fmt.Errorf("get tables: %w", err) } for _, table := range tables { if table.Name == name && table.Family == family { return table, nil } } return nil, nil } // createTableIfNotExist creates a nftables table via connection c if it does // not exist within the given family. func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { if t, err := getTableIfExists(c, family, name); err != nil { return nil, fmt.Errorf("get table: %w", err) } else if t != nil { return t, nil } t := c.AddTable(&nftables.Table{ Family: family, Name: name, }) if err := c.Flush(); err != nil { return nil, fmt.Errorf("add table: %w", err) } return t, nil } type errorChainNotFound struct { chainName string tableName string } func (e errorChainNotFound) Error() string { return fmt.Sprintf("chain %s not found in table %s", e.chainName, e.tableName) } // getChainFromTable returns the chain with the given name from the given table. // Note that a chain name is unique within a table. func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) { chains, err := c.ListChainsOfTableFamily(table.Family) if err != nil { return nil, fmt.Errorf("list chains: %w", err) } for _, chain := range chains { // Table family is already checked so table name is unique if chain.Table.Name == table.Name && chain.Name == name { return chain, nil } } return nil, errorChainNotFound{table.Name, name} } // isTSChain reports whether `name` begins with "ts-" (and is thus a // Tailscale-managed chain). func isTSChain(name string) bool { return strings.HasPrefix(name, "ts-") } // createChainIfNotExist creates a chain with the given name in the given table // if it does not exist. func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { _, err := getOrCreateChain(c, cinfo) return err } func getOrCreateChain(c *nftables.Conn, cinfo chainInfo) (*nftables.Chain, error) { chain, err := getChainFromTable(c, cinfo.table, cinfo.name) if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) { return nil, fmt.Errorf("get chain: %w", err) } else if err == nil { // The chain already exists. If it is a TS chain, check the // type/hook/priority, but for "conventional chains" assume they're what // we expect (in case iptables-nft/ufw make minor behavior changes in // the future). if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || *chain.Hooknum != *cinfo.chainHook || *chain.Priority != *cinfo.chainPriority) { return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name) } return chain, nil } chain = c.AddChain(&nftables.Chain{ Name: cinfo.name, Table: cinfo.table, Type: cinfo.chainType, Hooknum: cinfo.chainHook, Priority: cinfo.chainPriority, Policy: cinfo.chainPolicy, }) if err := c.Flush(); err != nil { return nil, fmt.Errorf("add chain: %w", err) } return chain, nil } // NetfilterRunner abstracts helpers to run netfilter commands. It is // implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner. type NetfilterRunner interface { // AddLoopbackRule adds a rule to permit loopback traffic to addr. This rule // is added only if it does not already exist. AddLoopbackRule(addr netip.Addr) error // DelLoopbackRule removes the rule added by AddLoopbackRule. DelLoopbackRule(addr netip.Addr) error // AddHooks adds rules to conventional chains like "FORWARD", "INPUT" and // "POSTROUTING" to jump from those chains to tailscale chains. AddHooks() error // DelHooks deletes rules added by AddHooks. DelHooks(logf logger.Logf) error // AddChains creates custom Tailscale chains. AddChains() error // DelChains removes chains added by AddChains. DelChains() error // AddBase adds rules reused by different other rules. AddBase(tunname string) error // DelBase removes rules added by AddBase. DelBase() error // AddSNATRule adds the netfilter rule to SNAT incoming traffic over // the Tailscale interface destined for local subnets. An error is // returned if the rule already exists. AddSNATRule() error // DelSNATRule removes the rule added by AddSNATRule. DelSNATRule() error // AddStatefulRule adds a netfilter rule for stateful packet filtering // using conntrack. AddStatefulRule(tunname string) error // DelStatefulRule removes a netfilter rule for stateful packet filtering // using conntrack. DelStatefulRule(tunname string) error // HasIPV6 reports true if the system supports IPv6. HasIPV6() bool // HasIPV6NAT reports true if the system supports IPv6 NAT. HasIPV6NAT() bool // HasIPV6Filter reports true if the system supports IPv6 filter tables // This is only meaningful for iptables implementation, where hosts have // partial ipables support (i.e missing filter table). For nftables // implementation, this will default to the value of HasIPv6(). HasIPV6Filter() bool // AddDNATRule adds a rule to the nat/PREROUTING chain to DNAT traffic // destined for the given original destination to the given new destination. // This is used to forward all traffic destined for the Tailscale interface // to the provided destination, as used in the Kubernetes ingress proxies. AddDNATRule(origDst, dst netip.Addr) error // DNATWithLoadBalancer adds a rule to the nat/PREROUTING chain to DNAT // traffic destined for the given original destination to the given new // destination(s) using round robin to load balance if more than one // destination is provided. This is used to forward all traffic destined // for the Tailscale interface to the provided destination(s), as used // in the Kubernetes ingress proxies. DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error // AddSNATRuleForDst adds a rule to the nat/POSTROUTING chain to SNAT // traffic destined for dst to src. // This is used to forward traffic destined for the local machine over // the Tailscale interface, as used in the Kubernetes egress proxies. AddSNATRuleForDst(src, dst netip.Addr) error // DNATNonTailscaleTraffic adds a rule to the nat/PREROUTING chain to DNAT // all traffic inbound from any interface except exemptInterface to dst. // This is used to forward traffic destined for the local machine over // the Tailscale interface, as used in the Kubernetes egress proxies. DNATNonTailscaleTraffic(exemptInterface string, dst netip.Addr) error // ClampMSSToPMTU adds a rule to the mangle/FORWARD chain to clamp MSS for // traffic destined for the provided tun interface. ClampMSSToPMTU(tun string, addr netip.Addr) error // AddMagicsockPortRule adds a rule to the ts-input chain to accept // incoming traffic on the specified port, to allow magicsock to // communicate. AddMagicsockPortRule(port uint16, network string) error // DelMagicsockPortRule removes the rule created by AddMagicsockPortRule, // if it exists. DelMagicsockPortRule(port uint16, network string) error } // New creates a NetfilterRunner, auto-detecting whether to use // nftables or iptables. // As nftables is still experimental, iptables will be used unless // either the TS_DEBUG_FIREWALL_MODE environment variable, or the prefHint // parameter, is set to one of "nftables" or "auto". func New(logf logger.Logf, prefHint string) (NetfilterRunner, error) { mode := detectFirewallMode(logf, prefHint) switch mode { case FirewallModeIPTables: return newIPTablesRunner(logf) case FirewallModeNfTables: return newNfTablesRunner(logf) default: return nil, fmt.Errorf("unknown firewall mode %v", mode) } } // newNfTablesRunner creates a new nftablesRunner without guaranteeing // the existence of the tables and chains. func newNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) { conn, err := nftables.New() if err != nil { return nil, fmt.Errorf("nftables connection: %w", err) } nft4 := &nftable{Proto: nftables.TableFamilyIPv4} v6err := CheckIPv6(logf) if v6err != nil { logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err) } supportsV6 := v6err == nil var nft6 *nftable if supportsV6 { nft6 = &nftable{Proto: nftables.TableFamilyIPv6} logf("v6nat availability: true") } // TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables return &nftablesRunner{ conn: conn, nft4: nft4, nft6: nft6, v6Available: supportsV6, }, nil } // newLoadSaddrExpr creates a new nftables expression that loads the source // address of the packet into the given register. func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) { switch proto { case nftables.TableFamilyIPv4: return &expr.Payload{ DestRegister: destReg, Base: expr.PayloadBaseNetworkHeader, Offset: 12, Len: 4, }, nil case nftables.TableFamilyIPv6: return &expr.Payload{ DestRegister: destReg, Base: expr.PayloadBaseNetworkHeader, Offset: 8, Len: 16, }, nil default: return nil, fmt.Errorf("table family %v is neither IPv4 nor IPv6", proto) } } // newLoadDportExpr creates a new nftables express that loads the desination port // of a TCP/UDP packet into the given register. func newLoadDportExpr(destReg uint32) expr.Any { return &expr.Payload{ DestRegister: destReg, Base: expr.PayloadBaseTransportHeader, Offset: 2, Len: 2, } } // HasIPV6 reports true if the system supports IPv6. func (n *nftablesRunner) HasIPV6() bool { return n.v6Available } // HasIPV6NAT returns true if the system supports IPv6. // Kernel support for nftables was added after support for IPv6 // NAT, so no need for a separate IPv6 NAT support check like we do for iptables. // https://tldp.org/HOWTO/Linux+IPv6-HOWTO/ch18s04.html // https://wiki.nftables.org/wiki-nftables/index.php/Building_and_installing_nftables_from_sources func (n *nftablesRunner) HasIPV6NAT() bool { return n.v6Available } // HasIPV6Filter returns true if system supports IPv6. There are no known edge // cases where nftables running on a host that supports IPv6 would not support // filter table. func (n *nftablesRunner) HasIPV6Filter() bool { return n.v6Available } // findRule iterates through the rules to find the rule with matching expressions. func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) { rules, err := conn.GetRules(rule.Table, rule.Chain) if err != nil { return nil, fmt.Errorf("get nftables rules: %w", err) } if len(rules) == 0 { return nil, nil } ruleLoop: for _, r := range rules { if len(r.Exprs) != len(rule.Exprs) { continue } for i, e := range r.Exprs { // Skip counter expressions, as they will not match. if _, ok := e.(*expr.Counter); ok { continue } if !reflect.DeepEqual(e, rule.Exprs[i]) { continue ruleLoop } } return r, nil } return nil, nil } func createLoopbackRule( proto nftables.TableFamily, table *nftables.Table, chain *nftables.Chain, addr netip.Addr, ) (*nftables.Rule, error) { saddrExpr, err := newLoadSaddrExpr(proto, 1) if err != nil { return nil, fmt.Errorf("newLoadSaddrExpr: %w", err) } loopBackRule := &nftables.Rule{ Table: table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{ Key: expr.MetaKeyIIFNAME, Register: 1, }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: []byte("lo"), }, saddrExpr, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: addr.AsSlice(), }, &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictAccept, }, }, } return loopBackRule, nil } // insertLoopbackRule inserts the TS loop back rule into // the given chain as the first rule if it does not exist. func insertLoopbackRule( conn *nftables.Conn, proto nftables.TableFamily, table *nftables.Table, chain *nftables.Chain, addr netip.Addr) error { loopBackRule, err := createLoopbackRule(proto, table, chain, addr) if err != nil { return fmt.Errorf("create loopback rule: %w", err) } // If TestDial is set, we are running in test mode and we should not // find rule because header will mismatch. if conn.TestDial == nil { // Check if the rule already exists. rule, err := findRule(conn, loopBackRule) if err != nil { return fmt.Errorf("find rule: %w", err) } if rule != nil { // Rule already exists, no need to insert. return nil } } // This inserts the rule to the top of the chain _ = conn.InsertRule(loopBackRule) if err = conn.Flush(); err != nil { return fmt.Errorf("insert rule: %w", err) } return nil } // getNFTByAddr returns the nftables with correct IP family // that we will be using for the given address. func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable { if addr.Is6() { return n.nft6 } return n.nft4 } // AddLoopbackRule adds an nftables rule to permit loopback traffic to // a local Tailscale IP. This rule is added only if it does not already exist. func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error { nf := n.getNFTByAddr(addr) inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput) if err != nil { return fmt.Errorf("get input chain: %w", err) } if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil { return fmt.Errorf("add loopback rule: %w", err) } return nil } // DelLoopbackRule removes the nftables rule permitting loopback // traffic to a Tailscale IP. func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error { nf := n.getNFTByAddr(addr) inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput) if err != nil { return fmt.Errorf("get input chain: %w", err) } loopBackRule, err := createLoopbackRule(nf.Proto, nf.Filter, inputChain, addr) if err != nil { return fmt.Errorf("create loopback rule: %w", err) } existingLoopBackRule, err := findRule(n.conn, loopBackRule) if err != nil { return fmt.Errorf("find loop back rule: %w", err) } if existingLoopBackRule == nil { // Rule does not exist, no need to delete. return nil } if err := n.conn.DelRule(existingLoopBackRule); err != nil { return fmt.Errorf("delete rule: %w", err) } return n.conn.Flush() } // getTables gets the available nftable in nftables runner. func (n *nftablesRunner) getTables() []*nftable { if n.v6Available { return []*nftable{n.nft4, n.nft6} } return []*nftable{n.nft4} } // getNATTables gets the available nftable in nftables runner. // If the system does not support IPv6 NAT, only the IPv4 nftable // will be returned. func (n *nftablesRunner) getNATTables() []*nftable { if n.v6NATAvailable { return n.getTables() } return []*nftable{n.nft4} } // AddChains creates custom Tailscale chains in netfilter via nftables // if the ts-chain doesn't already exist. func (n *nftablesRunner) AddChains() error { polAccept := nftables.ChainPolicyAccept for _, table := range n.getTables() { // Create the filter table if it doesn't exist, this table name is the same // as the name used by iptables-nft and ufw. We install rules into the // same conventional table so that `accept` verdicts from our jump // chains are conclusive. filter, err := createTableIfNotExist(n.conn, table.Proto, "filter") if err != nil { return fmt.Errorf("create table: %w", err) } table.Filter = filter // Adding the "conventional chains" that are used by iptables-nft and ufw. if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil { return fmt.Errorf("create forward chain: %w", err) } if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil { return fmt.Errorf("create input chain: %w", err) } // Adding the tailscale chains that contain our rules. if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create forward chain: %w", err) } if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create input chain: %w", err) } } for _, table := range n.getNATTables() { // Create the nat table if it doesn't exist, this table name is the same // as the name used by iptables-nft and ufw. We install rules into the // same conventional table so that `accept` verdicts from our jump // chains are conclusive. nat, err := createTableIfNotExist(n.conn, table.Proto, "nat") if err != nil { return fmt.Errorf("create table: %w", err) } table.Nat = nat // Adding the "conventional chains" that are used by iptables-nft and ufw. if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil { return fmt.Errorf("create postrouting chain: %w", err) } // Adding the tailscale chain that contains our rules. if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil { return fmt.Errorf("create postrouting chain: %w", err) } } return n.conn.Flush() } // These are dummy chains and tables we create to detect if nftables is // available. We create them, then delete them. If we can create and delete // them, then we can use nftables. If we can't, then we assume that we're // running on a system that doesn't support nftables. See // createDummyPostroutingChains. const ( tsDummyChainName = "ts-test-postrouting" tsDummyTableName = "ts-test-nat" ) // createDummyPostroutingChains creates dummy postrouting chains in netfilter // via netfilter via nftables, as a last resort measure to detect that nftables // can be used. It cleans up the dummy chains after creation. func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) { polAccept := ptr.To(nftables.ChainPolicyAccept) for _, table := range n.getNATTables() { nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName) if err != nil { return fmt.Errorf("create nat table: %w", err) } defer func(fm nftables.TableFamily) { if err := deleteTableIfExists(n.conn, fm, tsDummyTableName); err != nil && retErr == nil { retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err) } }(table.Proto) table.Nat = nat if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil { return fmt.Errorf("create %q chain: %w", tsDummyChainName, err) } if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil { return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err) } } return nil } // deleteChainIfExists deletes a chain if it exists. func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error { chain, err := getChainFromTable(c, table, name) if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) { return fmt.Errorf("get chain: %w", err) } else if err != nil { // If the chain doesn't exist, we don't need to delete it. return nil } c.FlushChain(chain) c.DelChain(chain) if err := c.Flush(); err != nil { return fmt.Errorf("flush and delete chain: %w", err) } return nil } // DelChains removes the custom Tailscale chains from netfilter via nftables. func (n *nftablesRunner) DelChains() error { for _, table := range n.getTables() { if err := deleteChainIfExists(n.conn, table.Filter, chainNameForward); err != nil { return fmt.Errorf("delete chain: %w", err) } if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil { return fmt.Errorf("delete chain: %w", err) } } if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil { return fmt.Errorf("delete chain: %w", err) } if n.v6NATAvailable { if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil { return fmt.Errorf("delete chain: %w", err) } } if err := n.conn.Flush(); err != nil { return fmt.Errorf("flush: %w", err) } return nil } // createHookRule creates a rule to jump from a hooked chain to a regular chain. func createHookRule(table *nftables.Table, fromChain *nftables.Chain, toChainName string) *nftables.Rule { exprs := []expr.Any{ &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictJump, Chain: toChainName, }, } rule := &nftables.Rule{ Table: table, Chain: fromChain, Exprs: exprs, } return rule } // addHookRule adds a rule to jump from a hooked chain to a regular chain at top of the hooked chain. func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { rule := createHookRule(table, fromChain, toChainName) _ = conn.InsertRule(rule) if err := conn.Flush(); err != nil { return fmt.Errorf("flush add rule: %w", err) } return nil } // AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING" // in tables and jump from those chains to tailscale chains. func (n *nftablesRunner) AddHooks() error { conn := n.conn for _, table := range n.getTables() { inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") if err != nil { return fmt.Errorf("get INPUT chain: %w", err) } err = addHookRule(conn, table.Filter, inputChain, chainNameInput) if err != nil { return fmt.Errorf("Addhook: %w", err) } forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") if err != nil { return fmt.Errorf("get FORWARD chain: %w", err) } err = addHookRule(conn, table.Filter, forwardChain, chainNameForward) if err != nil { return fmt.Errorf("Addhook: %w", err) } } for _, table := range n.getNATTables() { postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") if err != nil { return fmt.Errorf("get INPUT chain: %w", err) } err = addHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting) if err != nil { return fmt.Errorf("Addhook: %w", err) } } return nil } // delHookRule deletes a rule that jumps from a hooked chain to a regular chain. func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables.Chain, toChainName string) error { rule := createHookRule(table, fromChain, toChainName) existingRule, err := findRule(conn, rule) if err != nil { return fmt.Errorf("Failed to find hook rule: %w", err) } if existingRule == nil { return nil } _ = conn.DelRule(existingRule) if err := conn.Flush(); err != nil { return fmt.Errorf("flush del hook rule: %w", err) } return nil } // DelHooks is deleting the rules added to conventional chains to jump to tailscale chains. func (n *nftablesRunner) DelHooks(logf logger.Logf) error { conn := n.conn for _, table := range n.getTables() { inputChain, err := getChainFromTable(conn, table.Filter, "INPUT") if err != nil { return fmt.Errorf("get INPUT chain: %w", err) } err = delHookRule(conn, table.Filter, inputChain, chainNameInput) if err != nil { return fmt.Errorf("delhook: %w", err) } forwardChain, err := getChainFromTable(conn, table.Filter, "FORWARD") if err != nil { return fmt.Errorf("get FORWARD chain: %w", err) } err = delHookRule(conn, table.Filter, forwardChain, chainNameForward) if err != nil { return fmt.Errorf("delhook: %w", err) } } for _, table := range n.getNATTables() { postroutingChain, err := getChainFromTable(conn, table.Nat, "POSTROUTING") if err != nil { return fmt.Errorf("get INPUT chain: %w", err) } err = delHookRule(conn, table.Nat, postroutingChain, chainNamePostrouting) if err != nil { return fmt.Errorf("delhook: %w", err) } } return nil } // maskof returns the mask of the given prefix in big endian bytes. func maskof(pfx netip.Prefix) []byte { mask := make([]byte, 4) binary.BigEndian.PutUint32(mask, ^(uint32(0xffff_ffff) >> pfx.Bits())) return mask } // createRangeRule creates a rule that matches packets with source IP from the give // range (like CGNAT range or ChromeOSVM range) and the interface is not the tunname, // and makes the given decision. Only IPv4 is supported. func createRangeRule( table *nftables.Table, chain *nftables.Chain, tunname string, rng netip.Prefix, decision expr.VerdictKind, ) (*nftables.Rule, error) { if rng.Addr().Is6() { return nil, errors.New("IPv6 is not supported") } saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1) if err != nil { return nil, fmt.Errorf("newLoadSaddrExpr: %w", err) } netip := rng.Addr().AsSlice() mask := maskof(rng) rule := &nftables.Rule{ Table: table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ Op: expr.CmpOpNeq, Register: 1, Data: []byte(tunname), }, saddrExpr, &expr.Bitwise{ SourceRegister: 1, DestRegister: 1, Len: 4, Mask: mask, Xor: []byte{0x00, 0x00, 0x00, 0x00}, }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: netip, }, &expr.Counter{}, &expr.Verdict{ Kind: decision, }, }, } return rule, nil } // addReturnChromeOSVMRangeRule adds a rule to return if the source IP // is in the ChromeOS VM range. func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { rule, err := createRangeRule(table, chain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn) if err != nil { return fmt.Errorf("create rule: %w", err) } _ = c.AddRule(rule) if err = c.Flush(); err != nil { return fmt.Errorf("add rule: %w", err) } return nil } // addDropCGNATRangeRule adds a rule to drop if the source IP is in the // CGNAT range. func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop) if err != nil { return fmt.Errorf("create rule: %w", err) } _ = c.AddRule(rule) if err = c.Flush(); err != nil { return fmt.Errorf("add rule: %w", err) } return nil } // createSetSubnetRouteMarkRule creates a rule to set the subnet route // mark if the packet is from the given interface. func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) { hexTsFwmarkMaskNeg := getTailscaleFwmarkMaskNeg() hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() rule := &nftables.Rule{ Table: table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname), }, &expr.Counter{}, &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, &expr.Bitwise{ SourceRegister: 1, DestRegister: 1, Len: 4, Mask: hexTsFwmarkMaskNeg, Xor: hexTSSubnetRouteMark, }, &expr.Meta{ Key: expr.MetaKeyMARK, SourceRegister: true, Register: 1, }, }, } return rule, nil } // addSetSubnetRouteMarkRule adds a rule to set the subnet route mark // if the packet is from the given interface. func addSetSubnetRouteMarkRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { rule, err := createSetSubnetRouteMarkRule(table, chain, tunname) if err != nil { return fmt.Errorf("create rule: %w", err) } _ = c.AddRule(rule) if err := c.Flush(); err != nil { return fmt.Errorf("add rule: %w", err) } return nil } // createDropOutgoingPacketFromCGNATRangeRuleWithTunname creates a rule to drop // outgoing packets from the CGNAT range. func createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) { _, ipNet, err := net.ParseCIDR(tsaddr.CGNATRange().String()) if err != nil { return nil, fmt.Errorf("parse cidr: %v", err) } mask, err := hex.DecodeString(ipNet.Mask.String()) if err != nil { return nil, fmt.Errorf("decode mask: %v", err) } netip := ipNet.IP.Mask(ipNet.Mask).To4() saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1) if err != nil { return nil, fmt.Errorf("newLoadSaddrExpr: %v", err) } rule := &nftables.Rule{ Table: table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname), }, saddrExpr, &expr.Bitwise{ SourceRegister: 1, DestRegister: 1, Len: 4, Mask: mask, Xor: []byte{0x00, 0x00, 0x00, 0x00}, }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: netip, }, &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictDrop, }, }, } return rule, nil } // addDropOutgoingPacketFromCGNATRangeRuleWithTunname adds a rule to drop // outgoing packets from the CGNAT range. func addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { rule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table, chain, tunname) if err != nil { return fmt.Errorf("create rule: %w", err) } _ = conn.AddRule(rule) if err := conn.Flush(); err != nil { return fmt.Errorf("add rule: %w", err) } return nil } // createAcceptOutgoingPacketRule creates a rule to accept outgoing packets // from the given interface. func createAcceptOutgoingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule { return &nftables.Rule{ Table: table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname), }, &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictAccept, }, }, } } // addAcceptOutgoingPacketRule adds a rule to accept outgoing packets // from the given interface. func addAcceptOutgoingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { rule := createAcceptOutgoingPacketRule(table, chain, tunname) _ = conn.AddRule(rule) if err := conn.Flush(); err != nil { return fmt.Errorf("flush add rule: %w", err) } return nil } // createAcceptOnPortRule creates a rule to accept incoming packets to // a given destination UDP port. func createAcceptOnPortRule(table *nftables.Table, chain *nftables.Chain, port uint16) *nftables.Rule { portBytes := make([]byte, 2) binary.BigEndian.PutUint16(portBytes, port) return &nftables.Rule{ Table: table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{ Key: expr.MetaKeyL4PROTO, Register: 1, }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: []byte{unix.IPPROTO_UDP}, }, newLoadDportExpr(1), &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: portBytes, }, &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictAccept, }, }, } } // addAcceptOnPortRule adds a rule to accept incoming packets to // a given destination UDP port. func addAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error { rule := createAcceptOnPortRule(table, chain, port) _ = conn.AddRule(rule) if err := conn.Flush(); err != nil { return fmt.Errorf("flush add rule: %w", err) } return nil } // addAcceptOnPortRule removes a rule to accept incoming packets to // a given destination UDP port. func removeAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, port uint16) error { rule := createAcceptOnPortRule(table, chain, port) rule, err := findRule(conn, rule) if err != nil { return fmt.Errorf("find rule: %v", err) } _ = conn.DelRule(rule) if err := conn.Flush(); err != nil { return fmt.Errorf("flush del rule: %w", err) } return nil } // AddMagicsockPortRule adds a rule to nftables to allow incoming traffic on // the specified UDP port, so magicsock can accept incoming connections. // network must be either "udp4" or "udp6" - this determines whether the rule // is added for IPv4 or IPv6. func (n *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error { var filterTable *nftables.Table switch network { case "udp4": filterTable = n.nft4.Filter case "udp6": filterTable = n.nft6.Filter default: return fmt.Errorf("unsupported network %s", network) } inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput) if err != nil { return fmt.Errorf("get input chain: %v", err) } err = addAcceptOnPortRule(n.conn, filterTable, inputChain, port) if err != nil { return fmt.Errorf("add accept on port rule: %v", err) } return nil } // DelMagicsockPortRule removes a rule added by AddMagicsockPortRule to accept // incoming traffic on a particular UDP port. // network must be either "udp4" or "udp6" - this determines whether the rule // is removed for IPv4 or IPv6. func (n *nftablesRunner) DelMagicsockPortRule(port uint16, network string) error { var filterTable *nftables.Table switch network { case "udp4": filterTable = n.nft4.Filter case "udp6": filterTable = n.nft6.Filter default: return fmt.Errorf("unsupported network %s", network) } inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput) if err != nil { return fmt.Errorf("get input chain: %v", err) } err = removeAcceptOnPortRule(n.conn, filterTable, inputChain, port) if err != nil { return fmt.Errorf("add accept on port rule: %v", err) } return nil } // createAcceptIncomingPacketRule creates a rule to accept incoming packets to // the given interface. func createAcceptIncomingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule { return &nftables.Rule{ Table: table, Chain: chain, Exprs: []expr.Any{ &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname), }, &expr.Counter{}, &expr.Verdict{ Kind: expr.VerdictAccept, }, }, } } func addAcceptIncomingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { rule := createAcceptIncomingPacketRule(table, chain, tunname) _ = conn.AddRule(rule) if err := conn.Flush(); err != nil { return fmt.Errorf("flush add rule: %w", err) } return nil } // AddBase adds some basic processing rules. func (n *nftablesRunner) AddBase(tunname string) error { if err := n.addBase4(tunname); err != nil { return fmt.Errorf("add base v4: %w", err) } if n.HasIPV6() { if err := n.addBase6(tunname); err != nil { return fmt.Errorf("add base v6: %w", err) } } return nil } // addBase4 adds some basic IPv4 processing rules. func (n *nftablesRunner) addBase4(tunname string) error { conn := n.conn inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput) if err != nil { return fmt.Errorf("get input chain v4: %v", err) } if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { return fmt.Errorf("add return chromeos vm range rule v4: %w", err) } if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { return fmt.Errorf("add drop cgnat range rule v4: %w", err) } if err = addAcceptIncomingPacketRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { return fmt.Errorf("add accept incoming packet rule v4: %w", err) } forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward) if err != nil { return fmt.Errorf("get forward chain v4: %v", err) } if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil { return fmt.Errorf("add set subnet route mark rule v4: %w", err) } if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil { return fmt.Errorf("add match subnet route mark rule v4: %w", err) } if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil { return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err) } if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil { return fmt.Errorf("add accept outgoing packet rule v4: %w", err) } if err = conn.Flush(); err != nil { return fmt.Errorf("flush base v4: %w", err) } return nil } // addBase6 adds some basic IPv6 processing rules. func (n *nftablesRunner) addBase6(tunname string) error { conn := n.conn inputChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameInput) if err != nil { return fmt.Errorf("get input chain v4: %v", err) } if err = addAcceptIncomingPacketRule(conn, n.nft6.Filter, inputChain, tunname); err != nil { return fmt.Errorf("add accept incoming packet rule v6: %w", err) } forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward) if err != nil { return fmt.Errorf("get forward chain v6: %w", err) } if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil { return fmt.Errorf("add set subnet route mark rule v6: %w", err) } if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil { return fmt.Errorf("add match subnet route mark rule v6: %w", err) } if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil { return fmt.Errorf("add accept outgoing packet rule v6: %w", err) } if err = conn.Flush(); err != nil { return fmt.Errorf("flush base v6: %w", err) } return nil } // DelBase empties, but does not remove, custom Tailscale chains from // netfilter via iptables. func (n *nftablesRunner) DelBase() error { conn := n.conn for _, table := range n.getTables() { inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput) if err != nil { return fmt.Errorf("get input chain: %v", err) } conn.FlushChain(inputChain) forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward) if err != nil { return fmt.Errorf("get forward chain: %v", err) } conn.FlushChain(forwardChain) } for _, table := range n.getNATTables() { postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { return fmt.Errorf("get postrouting chain v4: %v", err) } conn.FlushChain(postrouteChain) } return conn.Flush() } // createMatchSubnetRouteMarkRule creates a rule that matches packets // with the subnet route mark and takes the specified action. func createMatchSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, action MatchDecision) (*nftables.Rule, error) { hexTSFwmarkMask := getTailscaleFwmarkMask() hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() var endAction expr.Any endAction = &expr.Verdict{Kind: expr.VerdictAccept} if action == Masq { endAction = &expr.Masq{} } exprs := []expr.Any{ &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, &expr.Bitwise{ SourceRegister: 1, DestRegister: 1, Len: 4, Mask: hexTSFwmarkMask, Xor: []byte{0x00, 0x00, 0x00, 0x00}, }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: hexTSSubnetRouteMark, }, &expr.Counter{}, endAction, } rule := &nftables.Rule{ Table: table, Chain: chain, Exprs: exprs, } return rule, nil } // addMatchSubnetRouteMarkRule adds a rule that matches packets with // the subnet route mark and takes the specified action. func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, action MatchDecision) error { rule, err := createMatchSubnetRouteMarkRule(table, chain, action) if err != nil { return fmt.Errorf("create match subnet route mark rule: %w", err) } _ = conn.AddRule(rule) if err := conn.Flush(); err != nil { return fmt.Errorf("flush add rule: %w", err) } return nil } // AddSNATRule adds a netfilter rule to SNAT traffic destined for // local subnets. func (n *nftablesRunner) AddSNATRule() error { conn := n.conn for _, table := range n.getNATTables() { chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { return fmt.Errorf("get postrouting chain v4: %w", err) } if err = addMatchSubnetRouteMarkRule(conn, table.Nat, chain, Masq); err != nil { return fmt.Errorf("add match subnet route mark rule v4: %w", err) } } if err := conn.Flush(); err != nil { return fmt.Errorf("flush add SNAT rule: %w", err) } return nil } // DelSNATRule removes the netfilter rule to SNAT traffic destined for // local subnets. An error is returned if the rule does not exist. func (n *nftablesRunner) DelSNATRule() error { conn := n.conn hexTSFwmarkMask := getTailscaleFwmarkMask() hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() exprs := []expr.Any{ &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, &expr.Bitwise{ SourceRegister: 1, DestRegister: 1, Len: 4, Mask: hexTSFwmarkMask, }, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: hexTSSubnetRouteMark, }, &expr.Counter{}, &expr.Masq{}, } for _, table := range n.getNATTables() { chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) if err != nil { return fmt.Errorf("get postrouting chain v4: %w", err) } rule := &nftables.Rule{ Table: table.Nat, Chain: chain, Exprs: exprs, } SNATRule, err := findRule(conn, rule) if err != nil { return fmt.Errorf("find SNAT rule v4: %w", err) } if SNATRule != nil { _ = conn.DelRule(SNATRule) } } if err := conn.Flush(); err != nil { return fmt.Errorf("flush del SNAT rule: %w", err) } return nil } func nativeUint32(v uint32) []byte { b := make([]byte, 4) binary.NativeEndian.PutUint32(b, v) return b } func makeStatefulRuleExprs(tunname string) []expr.Any { return []expr.Any{ // Check if the output interface is the Tailscale interface by // first loding the OIFNAME into register 1 and comparing it // against our tunname. // // 'cmp' implicitly breaks from a rule if a comparison fails, // so if we continue past this rule we know that the packet is // going to our TUN. &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, &expr.Cmp{ Op: expr.CmpOpEq, Register: 1, Data: []byte(tunname), }, // Store the conntrack state in register 1 &expr.Ct{ Register: 1, Key: expr.CtKeySTATE, }, // Mask the state in register 1 to "hide" the ESTABLISHED and // RELATED bits (which are expected and fine); if there are any // other bits, we want them to remain. // // This operation is, in the kernel: // dst[i] = (src[i] & mask[i]) ^ xor[i] // // So, we can mask by setting the inverse of the bits we want // to remove; i.e. ESTABLISHED = 0b00000010, RELATED = // 0b00000100, so, if we assume an 8-bit state (in reality, // it's 32-bit), we can mask with 0b11111001 to clear those // bits and keep everything else (e.g. the INVALID bit which is // 0b00000001). // // TODO(andrew-d): for now, let's also allow // CtStateBitUNTRACKED, which is a state for packets that are not // tracked (marked so explicitly with an iptables rule using // --notrack); we should figure out if we want to allow this or not. &expr.Bitwise{ SourceRegister: 1, DestRegister: 1, Len: 4, Mask: nativeUint32(^(0 | expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED | expr.CtStateBitUNTRACKED)), // Xor is unused but must be specified Xor: nativeUint32(0), }, // Compare against the expected state (0, i.e. no bits set // other than maybe ESTABLISHED and RELATED). We want this // comparison to fail if there are no bits set, so that this // rule's evaluation stops and we don't fall through to the // "Drop" verdict. // // For example, if the state is ESTABLISHED (and we want to // break from this rule/accept this packet): // state = ESTABLISHED // register1 = 0b0 (since the bitwise operation cleared the ESTABLISHED bit) // // compare register1 (0b0) != 0: false // -> comparison implicitly breaks // -> continue to the next rule // // For example, if the state is NEW (and we want to continue to // the next expression and thus drop this packet): // state = NEW // register1 = 0b1000 // // compare register1 (0b1000) != 0: true // -> comparison continues to next expr &expr.Cmp{ Op: expr.CmpOpNeq, Register: 1, Data: []byte{0, 0, 0, 0}, }, // If we get here, we know that this packet is going to our TUN // device, and has a conntrack state set other than ESTABLISHED // or RELATED. We thus count and drop the packet. &expr.Counter{}, &expr.Verdict{Kind: expr.VerdictDrop}, } // TODO(andrew-d): iptables-nft writes a rule that dumps as: // // match name conntrack rev 3 // // I think this is using expr.Match against the following struct // (xt_conntrack_mtinfo3): // // https://github.com/torvalds/linux/blob/master/include/uapi/linux/netfilter/xt_conntrack.h#L64-L77 // // We could probably do something similar here, but I'm not sure if // there's any advantage. Below is an example Match statement if we // decide to do that, based on dumping the rule that iptables-nft // generates: // // _ = expr.Match{ // Name: "conntrack", // Rev: 3, // Info: &xt.ConntrackMtinfo3{ // ConntrackMtinfo2: xt.ConntrackMtinfo2{ // ConntrackMtinfoBase: xt.ConntrackMtinfoBase{ // MatchFlags: xt.ConntrackState, // InvertFlags: xt.ConntrackState, // }, // // Mask the state to remove ESTABLISHED and // // RELATED before comparing. // StateMask: expr.CtStateBitESTABLISHED | expr.CtStateBitRELATED, // }, // }, // } } // AddStatefulRule adds a netfilter rule for stateful packet filtering using // conntrack. func (n *nftablesRunner) AddStatefulRule(tunname string) error { conn := n.conn exprs := makeStatefulRuleExprs(tunname) for _, table := range n.getTables() { chain, err := getChainFromTable(conn, table.Filter, chainNameForward) if err != nil { return fmt.Errorf("get forward chain: %w", err) } // First, find the 'accept' rule that we want to insert our rule before. acceptRule := createAcceptOutgoingPacketRule(table.Filter, chain, tunname) rule, err := findRule(conn, acceptRule) if err != nil { return fmt.Errorf("find accept rule: %w", err) } conn.InsertRule(&nftables.Rule{ Table: table.Filter, Chain: chain, Exprs: exprs, // Specifying Position in an Insert operation means to // insert this rule before the specified rule. Position: rule.Handle, }) } if err := conn.Flush(); err != nil { return fmt.Errorf("flush add stateful rule: %w", err) } return nil } // DelStatefulRule removes the netfilter rule for stateful packet filtering // using conntrack. func (n *nftablesRunner) DelStatefulRule(tunname string) error { conn := n.conn exprs := makeStatefulRuleExprs(tunname) for _, table := range n.getTables() { chain, err := getChainFromTable(conn, table.Filter, chainNameForward) if err != nil { return fmt.Errorf("get forward chain: %w", err) } rule, err := findRule(conn, &nftables.Rule{ Table: table.Filter, Chain: chain, Exprs: exprs, }) if err != nil { return fmt.Errorf("find stateful rule: %w", err) } if rule != nil { conn.DelRule(rule) } } if err := conn.Flush(); err != nil { return fmt.Errorf("flush del stateful rule: %w", err) } return nil } // cleanupChain removes a jump rule from hookChainName to tsChainName, and then // the entire chain tsChainName. Errors are logged, but attempts to remove both // the jump rule and chain continue even if one errors. func cleanupChain(logf logger.Logf, conn *nftables.Conn, table *nftables.Table, hookChainName, tsChainName string) { // remove the jump first, before removing the jump destination. defaultChain, err := getChainFromTable(conn, table, hookChainName) if err != nil && !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) { logf("cleanup: did not find default chain: %s", err) } if !errors.Is(err, errorChainNotFound{table.Name, hookChainName}) { // delete hook in convention chain _ = delHookRule(conn, table, defaultChain, tsChainName) } tsChain, err := getChainFromTable(conn, table, tsChainName) if err != nil && !errors.Is(err, errorChainNotFound{table.Name, tsChainName}) { logf("cleanup: did not find ts-chain: %s", err) } if tsChain != nil { // flush and delete ts-chain conn.FlushChain(tsChain) conn.DelChain(tsChain) err = conn.Flush() logf("cleanup: delete and flush chain %s: %s", tsChainName, err) } } // NfTablesCleanUp removes all Tailscale added nftables rules. // Any errors that occur are logged to the provided logf. func NfTablesCleanUp(logf logger.Logf) { conn, err := nftables.New() if err != nil { logf("cleanup: nftables connection: %s", err) } tables, err := conn.ListTables() // both v4 and v6 if err != nil { logf("cleanup: list tables: %s", err) } for _, table := range tables { // These table names were used briefly in 1.48.0. if table.Name == "ts-filter" || table.Name == "ts-nat" { conn.DelTable(table) if err := conn.Flush(); err != nil { logf("cleanup: flush delete table %s: %s", table.Name, err) } } if table.Name == "filter" { cleanupChain(logf, conn, table, "INPUT", chainNameInput) cleanupChain(logf, conn, table, "FORWARD", chainNameForward) } if table.Name == "nat" { cleanupChain(logf, conn, table, "POSTROUTING", chainNamePostrouting) } } }