From a3c7b21cd18866cd8ceeee9b102c9e4171fcd719 Mon Sep 17 00:00:00 2001 From: KevinLiang10 Date: Mon, 19 Jun 2023 20:05:14 +0000 Subject: [PATCH] util/linuxfw: add nftables support This commit adds nftable rule injection for tailscaled. If tailscaled is started with envknob TS_DEBUG_USE_NETLINK_NFTABLES = true, the router will use nftables to manage firewall rules. Updates: #391 Signed-off-by: KevinLiang10 --- go.mod | 2 +- util/linuxfw/iptables_runner.go | 17 +- util/linuxfw/linuxfw.go | 33 +- util/linuxfw/nftables_runner.go | 977 +++++++++++++++++++++++++++ util/linuxfw/nftables_runner_test.go | 790 ++++++++++++++++++++++ wgengine/router/router_linux.go | 29 +- wgengine/router/router_linux_test.go | 8 +- 7 files changed, 1833 insertions(+), 23 deletions(-) create mode 100644 util/linuxfw/nftables_runner.go create mode 100644 util/linuxfw/nftables_runner_test.go diff --git a/go.mod b/go.mod index d07c44634..97314f32a 100644 --- a/go.mod +++ b/go.mod @@ -70,6 +70,7 @@ require ( github.com/toqueteos/webbrowser v1.2.0 github.com/u-root/u-root v0.11.0 github.com/vishvananda/netlink v1.2.1-beta.2 + github.com/vishvananda/netns v0.0.4 go.uber.org/zap v1.24.0 go4.org/mem v0.0.0-20220726221520-4f986261bf13 go4.org/netipx v0.0.0-20230303233057-f1b76eb4bb35 @@ -322,7 +323,6 @@ require ( github.com/ultraware/whitespace v0.0.5 // indirect github.com/uudashr/gocognit v1.0.6 // indirect github.com/vbatts/tar-split v0.11.2 // indirect - github.com/vishvananda/netns v0.0.4 // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xanzy/ssh-agent v0.3.3 // indirect github.com/yagipy/maintidx v1.0.0 // indirect diff --git a/util/linuxfw/iptables_runner.go b/util/linuxfw/iptables_runner.go index 754a22b22..14f2fa536 100644 --- a/util/linuxfw/iptables_runner.go +++ b/util/linuxfw/iptables_runner.go @@ -8,6 +8,7 @@ package linuxfw import ( "fmt" "net/netip" + "os/exec" "strings" "github.com/coreos/go-iptables/iptables" @@ -36,6 +37,14 @@ type iptablesRunner struct { v6NATAvailable bool } +func checkIP6TablesExists() error { + // Some distros ship ip6tables separately from iptables. + if _, err := exec.LookPath("ip6tables"); err != nil { + return fmt.Errorf("path not found: %w", err) + } + return nil +} + // NewIPTablesRunner constructs a NetfilterRunner that programs iptables rules. // If the underlying iptables library fails to initialize, that error is // returned. The runner probes for IPv6 support once at initialization time and @@ -48,9 +57,13 @@ func NewIPTablesRunner(logf logger.Logf) (*iptablesRunner, error) { supportsV6, supportsV6NAT := false, false v6err := checkIPv6(logf) - if v6err != nil { + ip6terr := checkIP6TablesExists() + switch { + case v6err != nil: logf("disabling tunneled IPv6 due to system IPv6 config: %v", v6err) - } else { + case ip6terr != nil: + logf("disabling tunneled IPv6 due to missing ip6tables: %v", ip6terr) + default: supportsV6 = true supportsV6NAT = supportsV6 && checkSupportsV6NAT() logf("v6nat = %v", supportsV6NAT) diff --git a/util/linuxfw/linuxfw.go b/util/linuxfw/linuxfw.go index dc50aa6cc..6ec152a4f 100644 --- a/util/linuxfw/linuxfw.go +++ b/util/linuxfw/linuxfw.go @@ -20,6 +20,15 @@ import ( "tailscale.com/types/logger" ) +// MatchDecision is the decision made by the firewall for a packet matched by a rule. +// It is used to decide whether to accept or masquerade a packet in addMatchSubnetRouteMarkRule. +type MatchDecision int + +const ( + Accept MatchDecision = iota + Masq +) + // The following bits are added to packet marks for Tailscale use. // // We tried to pick bits sufficiently out of the way that it's @@ -44,16 +53,12 @@ const ( // We claim bits 16:23 entirely. For now we only use the lower four // bits, leaving the higher 4 bits for future use. TailscaleFwmarkMask = "0xff0000" - TailscaleFwmarkMaskNeg = "0xff00ffff" TailscaleFwmarkMaskNum = 0xff0000 // Packet is from Tailscale and to a subnet route destination, so // is allowed to be routed through this machine. TailscaleSubnetRouteMark = "0x40000" TailscaleSubnetRouteMarkNum = 0x40000 - // This one is same value but padded to even number of digit, so - // hex decoding can work correctly. - TailscaleSubnetRouteMarkHexStr = "0x040000" // Packet was originated by tailscaled itself, and must not be // routed over the Tailscale network. @@ -61,6 +66,21 @@ const ( TailscaleBypassMarkNum = 0x80000 ) +// getTailscaleFwmarkMaskNeg returns the negation of TailscaleFwmarkMask in bytes. +func getTailscaleFwmarkMaskNeg() []byte { + return []byte{0xff, 0x00, 0xff, 0xff} +} + +// getTailscaleFwmarkMask returns the TailscaleFwmarkMask in bytes. +func getTailscaleFwmarkMask() []byte { + return []byte{0x00, 0xff, 0x00, 0x00} +} + +// getTailscaleSubnetRouteMark returns the TailscaleSubnetRouteMark in bytes. +func getTailscaleSubnetRouteMark() []byte { + return []byte{0x00, 0x04, 0x00, 0x00} +} + // errCode extracts and returns the process exit code from err, or // zero if err is nil. func errCode(err error) int { @@ -122,11 +142,6 @@ func checkIPv6(logf logger.Logf) error { return fmt.Errorf("kernel doesn't support IPv6 policy routing: %w", err) } - // Some distros ship ip6tables separately from iptables. - if _, err := exec.LookPath("ip6tables"); err != nil { - return err - } - return nil } diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go new file mode 100644 index 000000000..4d46ea104 --- /dev/null +++ b/util/linuxfw/nftables_runner.go @@ -0,0 +1,977 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "net" + "net/netip" + "reflect" + + "github.com/google/nftables" + "github.com/google/nftables/expr" + "tailscale.com/net/tsaddr" + "tailscale.com/types/logger" +) + +const ( + chainNameForward = "ts-forward" + chainNameInput = "ts-input" + chainNamePostrouting = "ts-postrouting" +) + +type chainInfo struct { + table *nftables.Table + name string + chainType nftables.ChainType + chainHook *nftables.ChainHook + chainPriority *nftables.ChainPriority +} + +type nftable struct { + Proto nftables.TableFamily + Filter *nftables.Table + Nat *nftables.Table +} + +type nftablesRunner struct { + conn *nftables.Conn + nft4 *nftable + nft6 *nftable + + v6Available bool + v6NATAvailable bool +} + +// createTableIfNotExist creates a nftables table via connection c if it does not exist within the given family. +func createTableIfNotExist(c *nftables.Conn, family nftables.TableFamily, name string) (*nftables.Table, error) { + tables, err := c.ListTables() + if err != nil { + return nil, fmt.Errorf("get tables: %w", err) + } + for _, table := range tables { + if table.Name == name && table.Family == family { + return table, nil + } + } + + t := c.AddTable(&nftables.Table{ + Family: family, + Name: name, + }) + if err := c.Flush(); err != nil { + return nil, fmt.Errorf("add table: %w", err) + } + return t, nil +} + +type errorChainNotFound struct { + chainName string + tableName string +} + +func (e errorChainNotFound) Error() string { + return fmt.Sprintf("chain %s not found in table %s", e.chainName, e.tableName) +} + +// getChainFromTable returns the chain with the given name from the given table. +// Note that a chain name is unique within a table. +func getChainFromTable(c *nftables.Conn, table *nftables.Table, name string) (*nftables.Chain, error) { + chains, err := c.ListChainsOfTableFamily(table.Family) + if err != nil { + return nil, fmt.Errorf("list chains: %w", err) + } + + for _, chain := range chains { + // Table family is already checked so table name is unique + if chain.Table.Name == table.Name && chain.Name == name { + return chain, nil + } + } + + return nil, errorChainNotFound{table.Name, name} +} + +// getChainsFromTable returns all chains from the given table. +func getChainsFromTable(c *nftables.Conn, table *nftables.Table) ([]*nftables.Chain, error) { + chains, err := c.ListChainsOfTableFamily(table.Family) + if err != nil { + return nil, fmt.Errorf("list chains: %w", err) + } + + var ret []*nftables.Chain + for _, chain := range chains { + // Table family is already checked so table name is unique + if chain.Table.Name == table.Name { + ret = append(ret, chain) + } + } + + return ret, nil +} + +// createChainIfNotExist creates a chain with the given name in the given table +// if it does not exist. +func createChainIfNotExist(c *nftables.Conn, cinfo chainInfo) error { + chain, err := getChainFromTable(c, cinfo.table, cinfo.name) + if err != nil && !errors.Is(err, errorChainNotFound{cinfo.table.Name, cinfo.name}) { + return fmt.Errorf("get chain: %w", err) + } else if err == nil { + // Chain already exists + if chain.Type != cinfo.chainType || chain.Hooknum != cinfo.chainHook || chain.Priority != cinfo.chainPriority { + return fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name) + } + return nil + } + + _ = c.AddChain(&nftables.Chain{ + Name: cinfo.name, + Table: cinfo.table, + Type: cinfo.chainType, + Hooknum: cinfo.chainHook, + Priority: cinfo.chainPriority, + }) + + if err := c.Flush(); err != nil { + return fmt.Errorf("add chain: %w", err) + } + + return nil +} + +// NewNfTablesRunner creates a new nftablesRunner without guaranteeing +// the existence of the tables and chains. +func NewNfTablesRunner(logf logger.Logf) (*nftablesRunner, error) { + conn, err := nftables.New() + if err != nil { + return nil, fmt.Errorf("nftables connection: %w", err) + } + nft4 := &nftable{Proto: nftables.TableFamilyIPv4} + + v6err := checkIPv6(logf) + if v6err != nil { + logf("disabling tunneled IPv6 due to system IPv6 config: %w", v6err) + } + supportsV6 := v6err == nil + supportsV6NAT := supportsV6 && checkSupportsV6NAT() + + var nft6 *nftable + if supportsV6 { + logf("v6nat availability: %v", supportsV6NAT) + nft6 = &nftable{Proto: nftables.TableFamilyIPv6} + } + + // TODO(KevinLiang10): convert iptables rule to nftable rules if they exist in the iptables + + return &nftablesRunner{ + conn: conn, + nft4: nft4, + nft6: nft6, + v6Available: supportsV6, + v6NATAvailable: supportsV6NAT, + }, nil +} + +// newLoadSaddrExpr creates a new nftables expression that loads the source +// address of the packet into the given register. +func newLoadSaddrExpr(proto nftables.TableFamily, destReg uint32) (expr.Any, error) { + switch proto { + case nftables.TableFamilyIPv4: + return &expr.Payload{ + DestRegister: destReg, + Base: expr.PayloadBaseNetworkHeader, + Offset: 12, + Len: 4, + }, nil + case nftables.TableFamilyIPv6: + return &expr.Payload{ + DestRegister: destReg, + Base: expr.PayloadBaseNetworkHeader, + Offset: 8, + Len: 16, + }, nil + default: + return nil, fmt.Errorf("table family %v is neither IPv4 nor IPv6", proto) + } +} + +// HasIPV6 returns true if the system supports IPv6. +func (n *nftablesRunner) HasIPV6() bool { + return n.v6Available +} + +// HasIPV6NAT returns true if the system supports IPv6 NAT. +func (n *nftablesRunner) HasIPV6NAT() bool { + return n.v6NATAvailable +} + +// findRule iterates through the rules to find the rule with matching expressions. +func findRule(conn *nftables.Conn, rule *nftables.Rule) (*nftables.Rule, error) { + rules, err := conn.GetRules(rule.Table, rule.Chain) + if err != nil { + return nil, fmt.Errorf("get nftables rules: %w", err) + } + if len(rules) == 0 { + return nil, nil + } + +ruleLoop: + for _, r := range rules { + if len(r.Exprs) != len(rule.Exprs) { + continue + } + + for i, e := range r.Exprs { + if !reflect.DeepEqual(e, rule.Exprs[i]) { + continue ruleLoop + } + } + return r, nil + } + + return nil, nil +} + +func createLoopbackRule( + proto nftables.TableFamily, + table *nftables.Table, + chain *nftables.Chain, + addr netip.Addr, +) (*nftables.Rule, error) { + saddrExpr, err := newLoadSaddrExpr(proto, 1) + if err != nil { + return nil, fmt.Errorf("newLoadSaddrExpr: %w", err) + } + loopBackRule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte("lo"), + }, + saddrExpr, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: addr.AsSlice(), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + } + return loopBackRule, nil +} + +// insertLoopbackRule inserts the TS loop back rule into +// the given chain as the first rule if it does not exist. +func insertLoopbackRule( + conn *nftables.Conn, proto nftables.TableFamily, + table *nftables.Table, chain *nftables.Chain, addr netip.Addr) error { + + loopBackRule, err := createLoopbackRule(proto, table, chain, addr) + if err != nil { + return fmt.Errorf("create loopback rule: %w", err) + } + + // If TestDial is set, we are running in test mode and we should not + // find rule because header will mismatch. + if conn.TestDial == nil { + // Check if the rule already exists. + rule, err := findRule(conn, loopBackRule) + if err != nil { + return fmt.Errorf("find rule: %w", err) + } + if rule != nil { + // Rule already exists, no need to insert. + return nil + } + } + + // This inserts the rule to the top of the chain + _ = conn.InsertRule(loopBackRule) + + if err = conn.Flush(); err != nil { + return fmt.Errorf("insert rule: %w", err) + } + return nil +} + +// getNFTByAddr returns the nftables with correct IP family +// that we will be using for the given address. +func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) *nftable { + if addr.Is6() { + return n.nft6 + } + return n.nft4 +} + +// AddLoopbackRule adds an nftables rule to permit loopback traffic to +// a local Tailscale IP. This rule is added only if it does not already exist. +func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error { + nf := n.getNFTByAddr(addr) + + inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput) + if err != nil { + return fmt.Errorf("get input chain: %w", err) + } + + if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil { + return fmt.Errorf("add loopback rule: %w", err) + } + + return nil +} + +// DelLoopbackRule removes the nftables rule permitting loopback +// traffic to a Tailscale IP. +func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error { + nf := n.getNFTByAddr(addr) + + inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput) + if err != nil { + return fmt.Errorf("get input chain: %w", err) + } + + loopBackRule, err := createLoopbackRule(nf.Proto, nf.Filter, inputChain, addr) + if err != nil { + return fmt.Errorf("create loopback rule: %w", err) + } + + existingLoopBackRule, err := findRule(n.conn, loopBackRule) + if err != nil { + return fmt.Errorf("find loop back rule: %w", err) + } + if existingLoopBackRule == nil { + // Rule does not exist, no need to delete. + return nil + } + + if err := n.conn.DelRule(existingLoopBackRule); err != nil { + return fmt.Errorf("delete rule: %w", err) + } + + return n.conn.Flush() +} + +// getTables gets the available nftable in nftables runner. +func (n *nftablesRunner) getTables() []*nftable { + if n.v6Available { + return []*nftable{n.nft4, n.nft6} + } + return []*nftable{n.nft4} +} + +// getNATTables gets the available nftable in nftables runner. +// If the system does not support IPv6 NAT, only the IPv4 nftable +// will be returned. +func (n *nftablesRunner) getNATTables() []*nftable { + if n.v6NATAvailable { + return n.getTables() + } + return []*nftable{n.nft4} +} + +// AddChains creates custom Tailscale chains in netfilter via nftables +// if the ts-chain doesn't already exist. +func (n *nftablesRunner) AddChains() error { + for _, table := range n.getTables() { + filter, err := createTableIfNotExist(n.conn, table.Proto, "ts-filter") + if err != nil { + return fmt.Errorf("create table: %w", err) + } + table.Filter = filter + if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityRef(-1)}); err != nil { + return fmt.Errorf("create forward chain: %w", err) + } + if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityRef(-1)}); err != nil { + return fmt.Errorf("create input chain: %w", err) + } + } + + for _, table := range n.getNATTables() { + nat, err := createTableIfNotExist(n.conn, table.Proto, "ts-nat") + if err != nil { + return fmt.Errorf("create table: %w", err) + } + table.Nat = nat + if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATDest}); err != nil { + return fmt.Errorf("create postrouting chain: %w", err) + } + } + + return n.conn.Flush() +} + +// deleteChainIfExists deletes a chain if it exists. +func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) error { + chain, err := getChainFromTable(c, table, name) + if err != nil && !errors.Is(err, errorChainNotFound{table.Name, name}) { + return fmt.Errorf("get chain: %w", err) + } else if err != nil { + // If the chain doesn't exist, we don't need to delete it. + return nil + } + + c.FlushChain(chain) + c.DelChain(chain) + + if err := c.Flush(); err != nil { + return fmt.Errorf("flush and delete chain: %w", err) + } + + return nil +} + +// DelChains removes the custom Tailscale chains from netfilter via nftables. +func (n *nftablesRunner) DelChains() error { + for _, table := range n.getTables() { + if err := deleteChainIfExists(n.conn, table.Filter, chainNameForward); err != nil { + return fmt.Errorf("delete chain: %w", err) + } + if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil { + return fmt.Errorf("delete chain: %w", err) + } + n.conn.DelTable(table.Filter) + } + + if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil { + return fmt.Errorf("delete chain: %w", err) + } + n.conn.DelTable(n.nft4.Nat) + + if n.v6NATAvailable { + if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil { + return fmt.Errorf("delete chain: %w", err) + } + n.conn.DelTable(n.nft6.Nat) + } + + if err := n.conn.Flush(); err != nil { + return fmt.Errorf("flush: %w", err) + } + + return nil +} + +// AddHooks is defined to satisfy the interface. NfTables does not require +// AddHooks, since we don't have any default tables or chains in nftables. +func (n *nftablesRunner) AddHooks() error { + return nil +} + +// DelHooks is defined to satisfy the interface. NfTables does not require +// DelHooks, since we don't have any default tables or chains in nftables. +func (n *nftablesRunner) DelHooks(logf logger.Logf) error { + return nil +} + +// maskof returns the mask of the given prefix in big endian bytes. +func maskof(pfx netip.Prefix) []byte { + mask := make([]byte, 4) + binary.BigEndian.PutUint32(mask, ^(uint32(0xffff_ffff) >> pfx.Bits())) + return mask +} + +// createRangeRule creates a rule that matches packets with source IP from the give +// range (like CGNAT range or ChromeOSVM range) and the interface is not the tunname, +// and makes the given decision. Only IPv4 is supported. +func createRangeRule( + table *nftables.Table, chain *nftables.Chain, + tunname string, rng netip.Prefix, decision expr.VerdictKind, +) (*nftables.Rule, error) { + if rng.Addr().Is6() { + return nil, errors.New("IPv6 is not supported") + } + saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1) + if err != nil { + return nil, fmt.Errorf("newLoadSaddrExpr: %w", err) + } + netip := rng.Addr().AsSlice() + mask := maskof(rng) + rule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte(tunname), + }, + saddrExpr, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: mask, + Xor: []byte{0x00, 0x00, 0x00, 0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: netip, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: decision, + }, + }, + } + return rule, nil + +} + +// addReturnChromeOSVMRangeRule adds a rule to return if the source IP +// is in the ChromeOS VM range. +func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { + rule, err := createRangeRule(table, chain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + _ = c.AddRule(rule) + if err = c.Flush(); err != nil { + return fmt.Errorf("add rule: %w", err) + } + return nil +} + +// addDropCGNATRangeRule adds a rule to drop if the source IP is in the +// CGNAT range. +func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { + rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + _ = c.AddRule(rule) + if err = c.Flush(); err != nil { + return fmt.Errorf("add rule: %w", err) + } + return nil +} + +// createSetSubnetRouteMarkRule creates a rule to set the subnet route +// mark if the packet is from the given interface. +func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) { + hexTsFwmarkMaskNeg := getTailscaleFwmarkMaskNeg() + hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() + + rule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte(tunname), + }, + &expr.Counter{}, + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: hexTsFwmarkMaskNeg, + Xor: hexTSSubnetRouteMark, + }, + &expr.Meta{ + Key: expr.MetaKeyMARK, + SourceRegister: true, + Register: 1, + }, + }, + } + return rule, nil +} + +// addSetSubnetRouteMarkRule adds a rule to set the subnet route mark +// if the packet is from the given interface. +func addSetSubnetRouteMarkRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { + rule, err := createSetSubnetRouteMarkRule(table, chain, tunname) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + _ = c.AddRule(rule) + + if err := c.Flush(); err != nil { + return fmt.Errorf("add rule: %w", err) + } + + return nil +} + +// createDropOutgoingPacketFromCGNATRangeRuleWithTunname creates a rule to drop +// outgoing packets from the CGNAT range. +func createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) { + _, ipNet, err := net.ParseCIDR(tsaddr.CGNATRange().String()) + if err != nil { + return nil, fmt.Errorf("parse cidr: %v", err) + } + mask, err := hex.DecodeString(ipNet.Mask.String()) + if err != nil { + return nil, fmt.Errorf("decode mask: %v", err) + } + netip := ipNet.IP.Mask(ipNet.Mask).To4() + saddrExpr, err := newLoadSaddrExpr(nftables.TableFamilyIPv4, 1) + if err != nil { + return nil, fmt.Errorf("newLoadSaddrExpr: %v", err) + } + rule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte(tunname), + }, + saddrExpr, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: mask, + Xor: []byte{0x00, 0x00, 0x00, 0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: netip, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictDrop, + }, + }, + } + return rule, nil +} + +// addDropOutgoingPacketFromCGNATRangeRuleWithTunname adds a rule to drop +// outgoing packets from the CGNAT range. +func addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { + rule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(table, chain, tunname) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + _ = conn.AddRule(rule) + + if err := conn.Flush(); err != nil { + return fmt.Errorf("add rule: %w", err) + } + return nil +} + +// createAcceptOutgoingPacketRule creates a rule to accept outgoing packets +// from the given interface. +func createAcceptOutgoingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule { + return &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1}, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte(tunname), + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + } +} + +// addAcceptOutgoingPacketRule adds a rule to accept outgoing packets +// from the given interface. +func addAcceptOutgoingPacketRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { + rule := createAcceptOutgoingPacketRule(table, chain, tunname) + _ = conn.AddRule(rule) + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush add rule: %w", err) + } + + return nil +} + +// AddBase adds some basic processing rules. +func (n *nftablesRunner) AddBase(tunname string) error { + if err := n.addBase4(tunname); err != nil { + return fmt.Errorf("add base v4: %w", err) + } + if n.HasIPV6() { + if err := n.addBase6(tunname); err != nil { + return fmt.Errorf("add base v6: %w", err) + } + } + return nil +} + +// addBase4 adds some basic IPv4 processing rules. +func (n *nftablesRunner) addBase4(tunname string) error { + conn := n.conn + + inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput) + if err != nil { + return fmt.Errorf("get input chain v4: %v", err) + } + if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { + return fmt.Errorf("add return chromeos vm range rule v4: %w", err) + } + if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { + return fmt.Errorf("add drop cgnat range rule v4: %w", err) + } + + forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward) + if err != nil { + return fmt.Errorf("get forward chain v4: %v", err) + } + + if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil { + return fmt.Errorf("add set subnet route mark rule v4: %w", err) + } + + if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil { + return fmt.Errorf("add match subnet route mark rule v4: %w", err) + } + + if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil { + return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err) + } + + if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil { + return fmt.Errorf("add accept outgoing packet rule v4: %w", err) + } + + if err = conn.Flush(); err != nil { + return fmt.Errorf("flush base v4: %w", err) + } + + return nil +} + +// addBase6 adds some basic IPv6 processing rules. +func (n *nftablesRunner) addBase6(tunname string) error { + conn := n.conn + + forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward) + if err != nil { + return fmt.Errorf("get forward chain v6: %w", err) + } + + if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil { + return fmt.Errorf("add set subnet route mark rule v6: %w", err) + } + + if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil { + return fmt.Errorf("add match subnet route mark rule v6: %w", err) + } + + if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil { + return fmt.Errorf("add accept outgoing packet rule v6: %w", err) + } + + if err = conn.Flush(); err != nil { + return fmt.Errorf("flush base v6: %w", err) + } + + return nil +} + +// DelBase empties, but does not remove, custom Tailscale chains from +// netfilter via iptables. +func (n *nftablesRunner) DelBase() error { + conn := n.conn + + for _, table := range n.getTables() { + inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput) + if err != nil { + return fmt.Errorf("get input chain: %v", err) + } + conn.FlushChain(inputChain) + forwardChain, err := getChainFromTable(conn, table.Filter, chainNameForward) + if err != nil { + return fmt.Errorf("get forward chain: %v", err) + } + conn.FlushChain(forwardChain) + } + + for _, table := range n.getNATTables() { + postrouteChain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) + if err != nil { + return fmt.Errorf("get postrouting chain v4: %v", err) + } + conn.FlushChain(postrouteChain) + } + + return conn.Flush() +} + +// createMatchSubnetRouteMarkRule creates a rule that matches packets +// with the subnet route mark and takes the specified action. +func createMatchSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, action MatchDecision) (*nftables.Rule, error) { + hexTSFwmarkMask := getTailscaleFwmarkMask() + hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() + + var endAction expr.Any + endAction = &expr.Verdict{Kind: expr.VerdictAccept} + if action == Masq { + endAction = &expr.Masq{} + } + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: hexTSFwmarkMask, + Xor: []byte{0x00, 0x00, 0x00, 0x00}, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: hexTSSubnetRouteMark, + }, + &expr.Counter{}, + endAction, + } + + rule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: exprs, + } + return rule, nil +} + +// addMatchSubnetRouteMarkRule adds a rule that matches packets with +// the subnet route mark and takes the specified action. +func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, chain *nftables.Chain, action MatchDecision) error { + rule, err := createMatchSubnetRouteMarkRule(table, chain, action) + if err != nil { + return fmt.Errorf("create match subnet route mark rule: %w", err) + } + _ = conn.AddRule(rule) + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush add rule: %w", err) + } + + return nil +} + +// AddSNATRule adds a netfilter rule to SNAT traffic destined for +// local subnets. +func (n *nftablesRunner) AddSNATRule() error { + conn := n.conn + + for _, table := range n.getNATTables() { + chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) + if err != nil { + return fmt.Errorf("get postrouting chain v4: %w", err) + } + + if err = addMatchSubnetRouteMarkRule(conn, table.Nat, chain, Masq); err != nil { + return fmt.Errorf("add match subnet route mark rule v4: %w", err) + } + } + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush add SNAT rule: %w", err) + } + + return nil +} + +// DelSNATRule removes the netfilter rule to SNAT traffic destined for +// local subnets. An error is returned if the rule does not exist. +func (n *nftablesRunner) DelSNATRule() error { + conn := n.conn + + hexTSFwmarkMask := getTailscaleFwmarkMask() + hexTSSubnetRouteMark := getTailscaleSubnetRouteMark() + + exprs := []expr.Any{ + &expr.Meta{Key: expr.MetaKeyMARK, Register: 1}, + &expr.Bitwise{ + SourceRegister: 1, + DestRegister: 1, + Len: 4, + Mask: hexTSFwmarkMask, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: hexTSSubnetRouteMark, + }, + &expr.Counter{}, + &expr.Masq{}, + } + + for _, table := range n.getNATTables() { + chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting) + if err != nil { + return fmt.Errorf("get postrouting chain v4: %w", err) + } + + rule := &nftables.Rule{ + Table: table.Nat, + Chain: chain, + Exprs: exprs, + } + + SNATRule, err := findRule(conn, rule) + if err != nil { + return fmt.Errorf("find SNAT rule v4: %w", err) + } + + _ = conn.DelRule(SNATRule) + } + + if err := conn.Flush(); err != nil { + return fmt.Errorf("flush del SNAT rule: %w", err) + } + + return nil +} + +// NfTablesCleanUp removes all Tailscale added nftables rules. +// Any errors that occur are logged to the provided logf. +func NfTablesCleanUp(logf logger.Logf) { + conn, err := nftables.New() + if err != nil { + logf("ERROR: nftables connection: %w", err) + } + + tables, err := conn.ListTables() // both v4 and v6 + if err != nil { + logf("ERROR: list tables: %w", err) + } + + for _, table := range tables { + if table.Name == "ts-filter" || table.Name == "ts-nat" { + conn.DelTable(table) + if err := conn.Flush(); err != nil { + logf("ERROR: flush table %s: %w", table.Name, err) + } + } + } +} diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go new file mode 100644 index 000000000..ab4543b2d --- /dev/null +++ b/util/linuxfw/nftables_runner_test.go @@ -0,0 +1,790 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package linuxfw + +import ( + "bytes" + "fmt" + "net/netip" + "os" + "runtime" + "strings" + "testing" + + "github.com/google/nftables" + "github.com/google/nftables/expr" + "github.com/mdlayher/netlink" + "github.com/vishvananda/netns" + "tailscale.com/net/tsaddr" +) + +// nfdump returns a hexdump of 4 bytes per line (like nft --debug=all), allowing +// users to make sense of large byte literals more easily. +func nfdump(b []byte) string { + var buf bytes.Buffer + i := 0 + for ; i < len(b); i += 4 { + // TODO: show printable characters as ASCII + fmt.Fprintf(&buf, "%02x %02x %02x %02x\n", + b[i], + b[i+1], + b[i+2], + b[i+3]) + } + for ; i < len(b); i++ { + fmt.Fprintf(&buf, "%02x ", b[i]) + } + return buf.String() +} + +func TestMaskof(t *testing.T) { + pfx, err := netip.ParsePrefix("192.168.1.0/24") + if err != nil { + t.Fatal(err) + } + want := []byte{0xff, 0xff, 0xff, 0x00} + if got := maskof(pfx); !bytes.Equal(got, want) { + t.Errorf("got %v; want %v", got, want) + } +} + +// linediff returns a side-by-side diff of two nfdump() return values, flagging +// lines which are not equal with an exclamation point prefix. +func linediff(a, b string) string { + var buf bytes.Buffer + fmt.Fprintf(&buf, "got -- want\n") + linesA := strings.Split(a, "\n") + linesB := strings.Split(b, "\n") + for idx, lineA := range linesA { + if idx >= len(linesB) { + break + } + lineB := linesB[idx] + prefix := "! " + if lineA == lineB { + prefix = " " + } + fmt.Fprintf(&buf, "%s%s -- %s\n", prefix, lineA, lineB) + } + return buf.String() +} + +func newTestConn(t *testing.T, want [][]byte) *nftables.Conn { + conn, err := nftables.New(nftables.WithTestDial( + func(req []netlink.Message) ([]netlink.Message, error) { + for idx, msg := range req { + b, err := msg.MarshalBinary() + if err != nil { + t.Fatal(err) + } + if len(b) < 16 { + continue + } + b = b[16:] + if len(want) == 0 { + t.Errorf("no want entry for message %d: %x", idx, b) + continue + } + if got, want := b, want[0]; !bytes.Equal(got, want) { + t.Errorf("message %d: %s", idx, linediff(nfdump(got), nfdump(want))) + } + want = want[1:] + } + return req, nil + })) + if err != nil { + t.Fatal(err) + } + return conn +} + +func TestInsertLoopbackRule(t *testing.T) { + proto := nftables.TableFamilyIPv4 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip ts-filter-test + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip ts-filter-test ts-input-test { type filter hook input priority 0 \; } + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x03\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), + // nft add rule ip ts-filter-test ts-input-test iifname "lo" ip saddr 192.168.0.2 counter accept + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x02\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x10\x01\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x06\x08\x00\x01\x00\x00\x00\x00\x01\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x06\x00\x01\x00\x6c\x6f\x00\x00\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x0c\x08\x00\x04\x00\x00\x00\x00\x04\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\xc0\xa8\x00\x02\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + testConn := newTestConn(t, want) + table := testConn.AddTable(&nftables.Table{ + Family: proto, + Name: "ts-filter-test", + }) + + chain := testConn.AddChain(&nftables.Chain{ + Name: "ts-input-test", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }) + + addr := netip.MustParseAddr("192.168.0.2") + + err := insertLoopbackRule(testConn, proto, table, chain, addr) + if err != nil { + t.Fatal(err) + } +} + +func TestInsertLoopbackRuleV6(t *testing.T) { + protoV6 := nftables.TableFamilyIPv6 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip6 ts-filter-test + []byte("\x0a\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip6 ts-filter-test ts-input-test { type filter hook input priority 0\; } + []byte("\x0a\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x03\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), + // nft add rule ip6 ts-filter-test ts-input-test iifname "lo" ip6 addr 2001:db8::1 counter accept + []byte("\x0a\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x02\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x1c\x01\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x06\x08\x00\x01\x00\x00\x00\x00\x01\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x06\x00\x01\x00\x6c\x6f\x00\x00\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x08\x08\x00\x04\x00\x00\x00\x00\x10\x38\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x2c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x18\x00\x03\x80\x14\x00\x01\x00\x20\x01\x0d\xb8\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + testConn := newTestConn(t, want) + tableV6 := testConn.AddTable(&nftables.Table{ + Family: protoV6, + Name: "ts-filter-test", + }) + + chainV6 := testConn.AddChain(&nftables.Chain{ + Name: "ts-input-test", + Table: tableV6, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }) + + addrV6 := netip.MustParseAddr("2001:db8::1") + + err := insertLoopbackRule(testConn, protoV6, tableV6, chainV6, addrV6) + if err != nil { + t.Fatal(err) + } +} + +func TestAddReturnChromeOSVMRangeRule(t *testing.T) { + proto := nftables.TableFamilyIPv4 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip ts-filter-test + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip ts-filter-test ts-input-test { type filter hook input priority 0\; } + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x03\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), + // nft add rule ip ts-filter-test ts-input-test iifname != "testTunn" ip saddr 100.115.92.0/23 counter return + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x02\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x58\x01\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x06\x08\x00\x01\x00\x00\x00\x00\x01\x30\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x10\x00\x03\x80\x0c\x00\x01\x00\x74\x65\x73\x74\x54\x75\x6e\x6e\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x0c\x08\x00\x04\x00\x00\x00\x00\x04\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\xff\xff\xfe\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x64\x73\x5c\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\xff\xff\xff\xfb"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + testConn := newTestConn(t, want) + table := testConn.AddTable(&nftables.Table{ + Family: proto, + Name: "ts-filter-test", + }) + chain := testConn.AddChain(&nftables.Chain{ + Name: "ts-input-test", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }) + err := addReturnChromeOSVMRangeRule(testConn, table, chain, "testTunn") + if err != nil { + t.Fatal(err) + } +} + +func TestAddDropCGNATRangeRule(t *testing.T) { + proto := nftables.TableFamilyIPv4 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip ts-filter-test + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip ts-filter-test ts-input-test { type filter hook input priority filter; } + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x03\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), + // nft add rule ip ts-filter-test ts-input-test iifname != "testTunn" ip saddr 100.64.0.0/10 counter drop + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x12\x00\x02\x00\x74\x73\x2d\x69\x6e\x70\x75\x74\x2d\x74\x65\x73\x74\x00\x00\x00\x58\x01\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x06\x08\x00\x01\x00\x00\x00\x00\x01\x30\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x10\x00\x03\x80\x0c\x00\x01\x00\x74\x65\x73\x74\x54\x75\x6e\x6e\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x0c\x08\x00\x04\x00\x00\x00\x00\x04\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\xff\xc0\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x64\x40\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + testConn := newTestConn(t, want) + table := testConn.AddTable(&nftables.Table{ + Family: proto, + Name: "ts-filter-test", + }) + chain := testConn.AddChain(&nftables.Chain{ + Name: "ts-input-test", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookInput, + Priority: nftables.ChainPriorityFilter, + }) + err := addDropCGNATRangeRule(testConn, table, chain, "testTunn") + if err != nil { + t.Fatal(err) + } +} + +func TestAddSetSubnetRouteMarkRule(t *testing.T) { + proto := nftables.TableFamilyIPv4 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip ts-filter-test + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip ts-filter-test ts-forward-test { type filter hook forward priority 0\; } + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x03\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), + // nft add rule ip ts-filter-test ts-forward-test iifname "testTunn" counter meta mark set mark and 0xff00ffff xor 0x40000 + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x02\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\x10\x01\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x06\x08\x00\x01\x00\x00\x00\x00\x01\x30\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x10\x00\x03\x80\x0c\x00\x01\x00\x74\x65\x73\x74\x54\x75\x6e\x6e\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\xff\x00\xff\xff\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x04\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x03\x00\x00\x00\x00\x01"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + testConn := newTestConn(t, want) + table := testConn.AddTable(&nftables.Table{ + Family: proto, + Name: "ts-filter-test", + }) + chain := testConn.AddChain(&nftables.Chain{ + Name: "ts-forward-test", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }) + err := addSetSubnetRouteMarkRule(testConn, table, chain, "testTunn") + if err != nil { + t.Fatal(err) + } +} + +func TestAddDropOutgoingPacketFromCGNATRangeRuleWithTunname(t *testing.T) { + proto := nftables.TableFamilyIPv4 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip ts-filter-test + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip ts-filter-test ts-forward-test { type filter hook forward priority 0\; } + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x03\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), + // nft add rule ip ts-filter-test ts-forward-test oifname "testTunn" ip saddr 100.64.0.0/10 counter drop + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x02\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\x58\x01\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x07\x08\x00\x01\x00\x00\x00\x00\x01\x30\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x10\x00\x03\x80\x0c\x00\x01\x00\x74\x65\x73\x74\x54\x75\x6e\x6e\x34\x00\x01\x80\x0c\x00\x01\x00\x70\x61\x79\x6c\x6f\x61\x64\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x0c\x08\x00\x04\x00\x00\x00\x00\x04\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\xff\xc0\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x64\x40\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + testConn := newTestConn(t, want) + table := testConn.AddTable(&nftables.Table{ + Family: proto, + Name: "ts-filter-test", + }) + chain := testConn.AddChain(&nftables.Chain{ + Name: "ts-forward-test", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }) + err := addDropOutgoingPacketFromCGNATRangeRuleWithTunname(testConn, table, chain, "testTunn") + if err != nil { + t.Fatal(err) + } +} + +func TestAddAcceptOutgoingPacketRule(t *testing.T) { + proto := nftables.TableFamilyIPv4 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip ts-filter-test + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip ts-filter-test ts-forward-test { type filter hook forward priority 0\; } + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x03\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), + // nft add rule ip ts-filter-test ts-forward-test oifname "testTunn" counter accept + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x02\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\xb4\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x07\x08\x00\x01\x00\x00\x00\x00\x01\x30\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x10\x00\x03\x80\x0c\x00\x01\x00\x74\x65\x73\x74\x54\x75\x6e\x6e\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + testConn := newTestConn(t, want) + table := testConn.AddTable(&nftables.Table{ + Family: proto, + Name: "ts-filter-test", + }) + chain := testConn.AddChain(&nftables.Chain{ + Name: "ts-forward-test", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }) + err := addAcceptOutgoingPacketRule(testConn, table, chain, "testTunn") + if err != nil { + t.Fatal(err) + } +} + +func TestAddMatchSubnetRouteMarkRuleMasq(t *testing.T) { + proto := nftables.TableFamilyIPv4 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip ts-nat-test + []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip ts-nat-test ts-postrouting-test { type nat hook postrouting priority 100; } + []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x03\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x04\x08\x00\x02\x00\x00\x00\x00\x64\x08\x00\x07\x00\x6e\x61\x74\x00"), + // nft add rule ip ts-nat-test ts-postrouting-test meta mark & 0x00ff0000 == 0x00040000 counter masquerade + []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\xf4\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x04\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + testConn := newTestConn(t, want) + table := testConn.AddTable(&nftables.Table{ + Family: proto, + Name: "ts-nat-test", + }) + chain := testConn.AddChain(&nftables.Chain{ + Name: "ts-postrouting-test", + Table: table, + Type: nftables.ChainTypeNAT, + Hooknum: nftables.ChainHookPostrouting, + Priority: nftables.ChainPriorityNATSource, + }) + err := addMatchSubnetRouteMarkRule(testConn, table, chain, Accept) + if err != nil { + t.Fatal(err) + } +} + +func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) { + proto := nftables.TableFamilyIPv4 + want := [][]byte{ + // batch begin + []byte("\x00\x00\x00\x0a"), + // nft add table ip ts-filter-test + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x08\x00\x02\x00\x00\x00\x00\x00"), + // nft add chain ip ts-filter-test ts-forward-test { type filter hook forward priority 0\; } + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x03\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), + // nft add rule ip ts-filter-test ts-forward-test meta mark and 0x00ff0000 eq 0x00040000 counter accept + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x02\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\xf4\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x04\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01"), + // batch end + []byte("\x00\x00\x00\x0a"), + } + testConn := newTestConn(t, want) + table := testConn.AddTable(&nftables.Table{ + Family: proto, + Name: "ts-filter-test", + }) + chain := testConn.AddChain(&nftables.Chain{ + Name: "ts-forward-test", + Table: table, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookForward, + Priority: nftables.ChainPriorityFilter, + }) + err := addMatchSubnetRouteMarkRule(testConn, table, chain, Accept) + if err != nil { + t.Fatal(err) + } +} + +func newSysConn(t *testing.T) *nftables.Conn { + t.Helper() + + runtime.LockOSThread() + + ns, err := netns.New() + if err != nil { + t.Fatalf("netns.New() failed: %v", err) + } + c, err := nftables.New(nftables.WithNetNSFd(int(ns))) + if err != nil { + t.Fatalf("nftables.New() failed: %v", err) + } + + t.Cleanup(func() { cleanupSysConn(t, ns) }) + + return c +} + +func cleanupSysConn(t *testing.T, ns netns.NsHandle) { + defer runtime.UnlockOSThread() + + if err := ns.Close(); err != nil { + t.Fatalf("newNS.Close() failed: %v", err) + } +} + +func newFakeNftablesRunner(t *testing.T, conn *nftables.Conn) *nftablesRunner { + nft4 := &nftable{Proto: nftables.TableFamilyIPv4} + nft6 := &nftable{Proto: nftables.TableFamilyIPv6} + + return &nftablesRunner{ + conn: conn, + nft4: nft4, + nft6: nft6, + v6Available: true, + v6NATAvailable: true, + } +} + +func TestAddAndDelNetfilterChains(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip(t.Name(), " requires privileges to create a namespace in order to run") + return + } + conn := newSysConn(t) + + runner := newFakeNftablesRunner(t, conn) + runner.AddChains() + + tables, err := conn.ListTables() + if err != nil { + t.Fatalf("conn.ListTables() failed: %v", err) + } + + if len(tables) != 4 { + t.Fatalf("len(tables) = %d, want 4", len(tables)) + } + + chainsV4, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + t.Fatalf("list chains failed: %v", err) + } + + if len(chainsV4) != 3 { + t.Fatalf("len(chainsV4) = %d, want 3", len(chainsV4)) + } + + chainsV6, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv6) + if err != nil { + t.Fatalf("list chains failed: %v", err) + } + + if len(chainsV6) != 3 { + t.Fatalf("len(chainsV6) = %d, want 3", len(chainsV6)) + } + + runner.DelChains() + + tables, err = conn.ListTables() + if err != nil { + t.Fatalf("conn.ListTables() failed: %v", err) + } + + if len(tables) != 0 { + t.Fatalf("len(tables) = %d, want 0", len(tables)) + } +} + +func getTsChains( + conn *nftables.Conn, + proto nftables.TableFamily) (*nftables.Chain, *nftables.Chain, *nftables.Chain, error) { + chains, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + if err != nil { + return nil, nil, nil, fmt.Errorf("list chains failed: %w", err) + } + var chainInput, chainForward, chainPostrouting *nftables.Chain + for _, chain := range chains { + switch chain.Name { + case "ts-input": + chainInput = chain + case "ts-forward": + chainForward = chain + case "ts-postrouting": + chainPostrouting = chain + } + } + return chainInput, chainForward, chainPostrouting, nil +} + +// findV4BaseRules verifies that the base rules are present in the input and forward chains. +func findV4BaseRules( + conn *nftables.Conn, + inpChain *nftables.Chain, + forwChain *nftables.Chain, + tunname string) ([]*nftables.Rule, error) { + want := []*nftables.Rule{} + rule, err := createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn) + if err != nil { + return nil, fmt.Errorf("create rule: %w", err) + } + want = append(want, rule) + rule, err = createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop) + if err != nil { + return nil, fmt.Errorf("create rule: %w", err) + } + want = append(want, rule) + rule, err = createDropOutgoingPacketFromCGNATRangeRuleWithTunname(forwChain.Table, forwChain, tunname) + if err != nil { + return nil, fmt.Errorf("create rule: %w", err) + } + want = append(want, rule) + + get := []*nftables.Rule{} + for _, rule := range want { + getRule, err := findRule(conn, rule) + if err != nil { + return nil, fmt.Errorf("find rule: %w", err) + } + get = append(get, getRule) + } + return get, nil +} + +func findCommonBaseRules( + conn *nftables.Conn, + forwChain *nftables.Chain, + tunname string) ([]*nftables.Rule, error) { + want := []*nftables.Rule{} + rule, err := createSetSubnetRouteMarkRule(forwChain.Table, forwChain, tunname) + if err != nil { + return nil, fmt.Errorf("create rule: %w", err) + } + want = append(want, rule) + rule, err = createMatchSubnetRouteMarkRule(forwChain.Table, forwChain, Accept) + if err != nil { + return nil, fmt.Errorf("create rule: %w", err) + } + want = append(want, rule) + rule = createAcceptOutgoingPacketRule(forwChain.Table, forwChain, tunname) + want = append(want, rule) + + get := []*nftables.Rule{} + for _, rule := range want { + getRule, err := findRule(conn, rule) + if err != nil { + return nil, fmt.Errorf("find rule: %w", err) + } + get = append(get, getRule) + } + + return get, nil +} + +func TestNFTAddAndDelNetfilterBase(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip(t.Name(), " requires privileges to create a namespace in order to run") + return + } + + conn := newSysConn(t) + + runner := newFakeNftablesRunner(t, conn) + runner.AddChains() + defer runner.DelChains() + runner.AddBase("testTunn") + + // check number of rules in each IPv4 TS chain + inputV4, forwardV4, postroutingV4, err := getTsChains(conn, nftables.TableFamilyIPv4) + if err != nil { + t.Fatalf("getTsChains() failed: %v", err) + } + + inputV4Rules, err := conn.GetRules(runner.nft4.Filter, inputV4) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + if len(inputV4Rules) != 2 { + t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules)) + } + + forwardV4Rules, err := conn.GetRules(runner.nft4.Filter, forwardV4) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + if len(forwardV4Rules) != 4 { + t.Fatalf("len(forwardV4Rules) = %d, want 4", len(forwardV4Rules)) + } + + postroutingV4Rules, err := conn.GetRules(runner.nft4.Nat, postroutingV4) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + if len(postroutingV4Rules) != 0 { + t.Fatalf("len(postroutingV4Rules) = %d, want 0", len(postroutingV4Rules)) + } + + _, err = findV4BaseRules(conn, inputV4, forwardV4, "testTunn") + if err != nil { + t.Fatalf("missing v4 base rule: %v", err) + } + _, err = findCommonBaseRules(conn, forwardV4, "testTunn") + if err != nil { + t.Fatalf("missing v4 base rule: %v", err) + } + + // Check number of rules in each IPv6 TS chain. + inputV6, forwardV6, postroutingV6, err := getTsChains(conn, nftables.TableFamilyIPv6) + if err != nil { + t.Fatalf("getTsChains() failed: %v", err) + } + + inputV6Rules, err := conn.GetRules(runner.nft6.Filter, inputV6) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + if len(inputV6Rules) != 0 { + t.Fatalf("len(inputV6Rules) = %d, want 0", len(inputV4Rules)) + } + + forwardV6Rules, err := conn.GetRules(runner.nft6.Filter, forwardV6) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + if len(forwardV6Rules) != 3 { + t.Fatalf("len(forwardV6Rules) = %d, want 3", len(forwardV4Rules)) + } + + postroutingV6Rules, err := conn.GetRules(runner.nft6.Nat, postroutingV6) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + if len(postroutingV6Rules) != 0 { + t.Fatalf("len(postroutingV6Rules) = %d, want 0", len(postroutingV4Rules)) + } + + _, err = findCommonBaseRules(conn, forwardV6, "testTunn") + if err != nil { + t.Fatalf("missing v6 base rule: %v", err) + } + + runner.DelBase() + + chains, err := conn.ListChains() + if err != nil { + t.Fatalf("conn.ListChains() failed: %v", err) + } + for _, chain := range chains { + chainRules, err := conn.GetRules(chain.Table, chain) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + if len(chainRules) != 0 { + t.Fatalf("len(chainRules) = %d, want 0", len(chainRules)) + } + } +} + +func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nftables.Table, chain *nftables.Chain, addr netip.Addr) (*nftables.Rule, error) { + matchingAddr := addr.AsSlice() + saddrExpr, err := newLoadSaddrExpr(proto, 1) + if err != nil { + return nil, fmt.Errorf("get expr: %w", err) + } + loopBackRule := &nftables.Rule{ + Table: table, + Chain: chain, + Exprs: []expr.Any{ + &expr.Meta{ + Key: expr.MetaKeyIIFNAME, + Register: 1, + }, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: []byte("lo"), + }, + saddrExpr, + &expr.Cmp{ + Op: expr.CmpOpEq, + Register: 1, + Data: matchingAddr, + }, + &expr.Counter{}, + &expr.Verdict{ + Kind: expr.VerdictAccept, + }, + }, + } + + existingLoopBackRule, err := findRule(conn, loopBackRule) + if err != nil { + return nil, fmt.Errorf("find loop back rule: %w", err) + } + return existingLoopBackRule, nil +} + +func TestNFTAddAndDelLoopbackRule(t *testing.T) { + if os.Geteuid() != 0 { + t.Skip(t.Name(), " requires privileges to create a namespace in order to run") + return + } + + conn := newSysConn(t) + + runner := newFakeNftablesRunner(t, conn) + runner.AddChains() + defer runner.DelChains() + runner.AddBase("testTunn") + defer runner.DelBase() + + addr := netip.MustParseAddr("192.168.0.2") + addrV6 := netip.MustParseAddr("2001:db8::2") + runner.AddLoopbackRule(addr) + runner.AddLoopbackRule(addrV6) + + inputV4, _, _, err := getTsChains(conn, nftables.TableFamilyIPv4) + if err != nil { + t.Fatalf("getTsChains() failed: %v", err) + } + + inputV4Rules, err := conn.GetRules(runner.nft4.Filter, inputV4) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + if len(inputV4Rules) != 3 { + t.Fatalf("len(inputV4Rules) = %d, want 3", len(inputV4Rules)) + } + + existingLoopBackRule, err := findLoopBackRule(conn, nftables.TableFamilyIPv4, runner.nft4.Filter, inputV4, addr) + if err != nil { + t.Fatalf("findLoopBackRule() failed: %v", err) + } + + if existingLoopBackRule.Position != 0 { + t.Fatalf("existingLoopBackRule.Handle = %d, want 0", existingLoopBackRule.Handle) + } + + inputV6, _, _, err := getTsChains(conn, nftables.TableFamilyIPv6) + if err != nil { + t.Fatalf("getTsChains() failed: %v", err) + } + + inputV6Rules, err := conn.GetRules(runner.nft6.Filter, inputV4) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + if len(inputV6Rules) != 1 { + t.Fatalf("len(inputV4Rules) = %d, want 1", len(inputV4Rules)) + } + + existingLoopBackRuleV6, err := findLoopBackRule(conn, nftables.TableFamilyIPv6, runner.nft6.Filter, inputV6, addrV6) + if err != nil { + t.Fatalf("findLoopBackRule() failed: %v", err) + } + + if existingLoopBackRuleV6.Position != 0 { + t.Fatalf("existingLoopBackRule.Handle = %d, want 0", existingLoopBackRule.Handle) + } + + runner.DelLoopbackRule(addr) + runner.DelLoopbackRule(addrV6) + + inputV4Rules, err = conn.GetRules(runner.nft4.Filter, inputV4) + if err != nil { + t.Fatalf("conn.GetRules() failed: %v", err) + } + if len(inputV4Rules) != 2 { + t.Fatalf("len(inputV4Rules) = %d, want 2", len(inputV4Rules)) + } +} diff --git a/wgengine/router/router_linux.go b/wgengine/router/router_linux.go index ee39849e6..6b723c845 100644 --- a/wgengine/router/router_linux.go +++ b/wgengine/router/router_linux.go @@ -36,9 +36,8 @@ const ( netfilterOn = preftype.NetfilterOn ) -// netfilterRunner abstracts helpers to run netfilter commands. It -// exists purely to swap out go-iptables for a fake implementation in -// tests. +// netfilterRunner abstracts helpers to run netfilter commands. It is +// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner. type netfilterRunner interface { AddLoopbackRule(addr netip.Addr) error DelLoopbackRule(addr netip.Addr) error @@ -55,14 +54,24 @@ type netfilterRunner interface { HasIPV6NAT() bool } +// newNetfilterRunner creates a netfilterRunner using either nftables or iptables. +// As nftables is still experimental, iptables will be used unless TS_DEBUG_USE_NETLINK_NFTABLES is set. func newNetfilterRunner(logf logger.Logf) (netfilterRunner, error) { var nfr netfilterRunner var err error - nfr, err = linuxfw.NewIPTablesRunner(logf) - if err != nil { - return nil, err + if envknob.Bool("TS_DEBUG_USE_NETLINK_NFTABLES") { + logf("router: using nftables") + nfr, err = linuxfw.NewNfTablesRunner(logf) + if err != nil { + return nil, err + } + } else { + logf("router: using iptables") + nfr, err = linuxfw.NewIPTablesRunner(logf) + if err != nil { + return nil, err + } } - return nfr, nil } @@ -489,9 +498,11 @@ func (r *linuxRouter) setNetfilterMode(mode preftype.NetfilterMode) error { if err := r.nfr.DelBase(); err != nil { return err } + // AddHooks adds the ts loopback rule. if err := r.nfr.AddHooks(); err != nil { return err } + // AddBase adds base ts rules if err := r.nfr.AddBase(r.tunname); err != nil { return err } @@ -1278,9 +1289,13 @@ func normalizeCIDR(cidr netip.Prefix) string { return cidr.Masked().String() } +// cleanup removes all the rules and routes that were added by the linux router. +// The function calls cleanup for both iptables and nftables since which ever +// netfilter runner is used, the cleanup function for the other one doesn't do anything. func cleanup(logf logger.Logf, interfaceName string) { if interfaceName != "userspace-networking" { linuxfw.IPTablesCleanup(logf) + linuxfw.NfTablesCleanUp(logf) } } diff --git a/wgengine/router/router_linux_test.go b/wgengine/router/router_linux_test.go index d5b3219ec..5d0263993 100644 --- a/wgengine/router/router_linux_test.go +++ b/wgengine/router/router_linux_test.go @@ -453,18 +453,18 @@ func (n *fakeIPTablesRunner) AddLoopbackRule(addr netip.Addr) error { } func (n *fakeIPTablesRunner) AddBase(tunname string) error { - if err := n.AddBase4(tunname); err != nil { + if err := n.addBase4(tunname); err != nil { return err } if n.HasIPV6() { - if err := n.AddBase6(tunname); err != nil { + if err := n.addBase6(tunname); err != nil { return err } } return nil } -func (n *fakeIPTablesRunner) AddBase4(tunname string) error { +func (n *fakeIPTablesRunner) addBase4(tunname string) error { curIPT := n.ipt4 newRules := []struct{ chain, rule string }{ {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.ChromeOSVMRange().String())}, @@ -482,7 +482,7 @@ func (n *fakeIPTablesRunner) AddBase4(tunname string) error { return nil } -func (n *fakeIPTablesRunner) AddBase6(tunname string) error { +func (n *fakeIPTablesRunner) addBase6(tunname string) error { curIPT := n.ipt6 newRules := []struct{ chain, rule string }{ {"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, linuxfw.TailscaleSubnetRouteMark, linuxfw.TailscaleFwmarkMask)},