|
|
|
@ -74,21 +74,22 @@ type nftablesRunner struct {
|
|
|
|
nft6 *nftable // IPv6 tables or nil if the system does not support IPv6
|
|
|
|
nft6 *nftable // IPv6 tables or nil if the system does not support IPv6
|
|
|
|
|
|
|
|
|
|
|
|
v6Available bool // whether the host supports IPv6
|
|
|
|
v6Available bool // whether the host supports IPv6
|
|
|
|
|
|
|
|
af AddressFamilies
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
|
|
|
|
func (r *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table, *nftables.Chain, error) {
|
|
|
|
polAccept := nftables.ChainPolicyAccept
|
|
|
|
polAccept := nftables.ChainPolicyAccept
|
|
|
|
table, err := n.getNFTByAddr(dst)
|
|
|
|
table, err := r.getNFTByAddr(dst)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err)
|
|
|
|
return nil, nil, fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
|
|
|
nat, err := createTableIfNotExist(r.conn, table.Proto, "nat")
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
|
|
|
|
return nil, nil, fmt.Errorf("error ensuring nat table: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ensure prerouting chain exists
|
|
|
|
// ensure prerouting chain exists
|
|
|
|
preroutingCh, err := getOrCreateChain(n.conn, chainInfo{
|
|
|
|
preroutingCh, err := getOrCreateChain(r.conn, chainInfo{
|
|
|
|
table: nat,
|
|
|
|
table: nat,
|
|
|
|
name: "PREROUTING",
|
|
|
|
name: "PREROUTING",
|
|
|
|
chainType: nftables.ChainTypeNAT,
|
|
|
|
chainType: nftables.ChainTypeNAT,
|
|
|
|
@ -102,14 +103,14 @@ func (n *nftablesRunner) ensurePreroutingChain(dst netip.Addr) (*nftables.Table,
|
|
|
|
return nat, preroutingCh, nil
|
|
|
|
return nat, preroutingCh, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (n *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
|
|
|
|
func (r *nftablesRunner) AddDNATRule(origDst netip.Addr, dst netip.Addr) error {
|
|
|
|
nat, preroutingCh, err := n.ensurePreroutingChain(dst)
|
|
|
|
nat, preroutingCh, err := r.ensurePreroutingChain(dst)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
rule := dnatRuleForChain(nat, preroutingCh, origDst, dst, nil)
|
|
|
|
rule := dnatRuleForChain(nat, preroutingCh, origDst, dst, nil)
|
|
|
|
n.conn.InsertRule(rule)
|
|
|
|
r.conn.InsertRule(rule)
|
|
|
|
return n.conn.Flush()
|
|
|
|
return r.conn.Flush()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.Addr, meta []byte) *nftables.Rule {
|
|
|
|
func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.Addr, meta []byte) *nftables.Rule {
|
|
|
|
@ -160,12 +161,12 @@ func dnatRuleForChain(t *nftables.Table, ch *nftables.Chain, origDst, dst netip.
|
|
|
|
// TODO (irbekrm): instead of doing this load balance traffic evenly to all
|
|
|
|
// TODO (irbekrm): instead of doing this load balance traffic evenly to all
|
|
|
|
// backend destinations.
|
|
|
|
// backend destinations.
|
|
|
|
// https://github.com/tailscale/tailscale/commit/d37f2f508509c6c35ad724fd75a27685b90b575b#diff-a3bcbcd1ca198799f4f768dc56fea913e1945a6b3ec9dbec89325a84a19a85e7R148-R232
|
|
|
|
// https://github.com/tailscale/tailscale/commit/d37f2f508509c6c35ad724fd75a27685b90b575b#diff-a3bcbcd1ca198799f4f768dc56fea913e1945a6b3ec9dbec89325a84a19a85e7R148-R232
|
|
|
|
func (n *nftablesRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error {
|
|
|
|
func (r *nftablesRunner) DNATWithLoadBalancer(origDst netip.Addr, dsts []netip.Addr) error {
|
|
|
|
return n.AddDNATRule(origDst, dsts[0])
|
|
|
|
return r.AddDNATRule(origDst, dsts[0])
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) error {
|
|
|
|
func (r *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr) error {
|
|
|
|
nat, preroutingCh, err := n.ensurePreroutingChain(dst)
|
|
|
|
nat, preroutingCh, err := r.ensurePreroutingChain(dst)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -197,23 +198,23 @@ func (n *nftablesRunner) DNATNonTailscaleTraffic(tunname string, dst netip.Addr)
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
n.conn.InsertRule(dnatRule)
|
|
|
|
r.conn.InsertRule(dnatRule)
|
|
|
|
return n.conn.Flush()
|
|
|
|
return r.conn.Flush()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
|
|
|
|
func (r *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
|
|
|
|
polAccept := nftables.ChainPolicyAccept
|
|
|
|
polAccept := nftables.ChainPolicyAccept
|
|
|
|
table, err := n.getNFTByAddr(dst)
|
|
|
|
table, err := r.getNFTByAddr(dst)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err)
|
|
|
|
return fmt.Errorf("error setting up nftables for IP family of %v: %w", dst, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
|
|
|
nat, err := createTableIfNotExist(r.conn, table.Proto, "nat")
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error ensuring nat table exists: %w", err)
|
|
|
|
return fmt.Errorf("error ensuring nat table exists: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ensure postrouting chain exists
|
|
|
|
// ensure postrouting chain exists
|
|
|
|
postRoutingCh, err := getOrCreateChain(n.conn, chainInfo{
|
|
|
|
postRoutingCh, err := getOrCreateChain(r.conn, chainInfo{
|
|
|
|
table: nat,
|
|
|
|
table: nat,
|
|
|
|
name: "POSTROUTING",
|
|
|
|
name: "POSTROUTING",
|
|
|
|
chainType: nftables.ChainTypeNAT,
|
|
|
|
chainType: nftables.ChainTypeNAT,
|
|
|
|
@ -225,7 +226,7 @@ func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
|
|
|
|
return fmt.Errorf("error ensuring postrouting chain: %w", err)
|
|
|
|
return fmt.Errorf("error ensuring postrouting chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
rules, err := n.conn.GetRules(nat, postRoutingCh)
|
|
|
|
rules, err := r.conn.GetRules(nat, postRoutingCh)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error listing rules: %w", err)
|
|
|
|
return fmt.Errorf("error listing rules: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -237,14 +238,14 @@ func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
|
|
|
|
if strings.EqualFold(current, snatRuleFullMatch) {
|
|
|
|
if strings.EqualFold(current, snatRuleFullMatch) {
|
|
|
|
return nil // already exists, do nothing
|
|
|
|
return nil // already exists, do nothing
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if err := n.conn.DelRule(rule); err != nil {
|
|
|
|
if err := r.conn.DelRule(rule); err != nil {
|
|
|
|
return fmt.Errorf("error deleting SNAT rule: %w", err)
|
|
|
|
return fmt.Errorf("error deleting SNAT rule: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
rule := snatRule(nat, postRoutingCh, src, dst, []byte(snatRuleFullMatch))
|
|
|
|
rule := snatRule(nat, postRoutingCh, src, dst, []byte(snatRuleFullMatch))
|
|
|
|
n.conn.AddRule(rule)
|
|
|
|
r.conn.AddRule(rule)
|
|
|
|
return n.conn.Flush()
|
|
|
|
return r.conn.Flush()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ClampMSSToPMTU ensures that all packets with TCP flags (SYN, ACK, RST) set
|
|
|
|
// ClampMSSToPMTU ensures that all packets with TCP flags (SYN, ACK, RST) set
|
|
|
|
@ -266,19 +267,19 @@ func (n *nftablesRunner) EnsureSNATForDst(src, dst netip.Addr) error {
|
|
|
|
// verdicts that would cause no further procesing within that chain. This
|
|
|
|
// verdicts that would cause no further procesing within that chain. This
|
|
|
|
// functionality is currently invoked from outside wgengine (containerboot), so
|
|
|
|
// functionality is currently invoked from outside wgengine (containerboot), so
|
|
|
|
// we don't want to race with wgengine for rule ordering within chains.
|
|
|
|
// we don't want to race with wgengine for rule ordering within chains.
|
|
|
|
func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
|
|
|
|
func (r *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
|
|
|
|
polAccept := nftables.ChainPolicyAccept
|
|
|
|
polAccept := nftables.ChainPolicyAccept
|
|
|
|
table, err := n.getNFTByAddr(addr)
|
|
|
|
table, err := r.getNFTByAddr(addr)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
|
|
|
|
return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
filterTable, err := createTableIfNotExist(n.conn, table.Proto, "filter")
|
|
|
|
filterTable, err := createTableIfNotExist(r.conn, table.Proto, "filter")
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error ensuring filter table: %w", err)
|
|
|
|
return fmt.Errorf("error ensuring filter table: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// ensure ts-clamp chain exists
|
|
|
|
// ensure ts-clamp chain exists
|
|
|
|
fwChain, err := getOrCreateChain(n.conn, chainInfo{
|
|
|
|
fwChain, err := getOrCreateChain(r.conn, chainInfo{
|
|
|
|
table: filterTable,
|
|
|
|
table: filterTable,
|
|
|
|
name: "ts-clamp",
|
|
|
|
name: "ts-clamp",
|
|
|
|
chainType: nftables.ChainTypeFilter,
|
|
|
|
chainType: nftables.ChainTypeFilter,
|
|
|
|
@ -344,8 +345,8 @@ func (n *nftablesRunner) ClampMSSToPMTU(tun string, addr netip.Addr) error {
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
n.conn.AddRule(clampRule)
|
|
|
|
r.conn.AddRule(clampRule)
|
|
|
|
return n.conn.Flush()
|
|
|
|
return r.conn.Flush()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// deleteTableIfExists deletes a nftables table via connection c if it exists
|
|
|
|
// deleteTableIfExists deletes a nftables table via connection c if it exists
|
|
|
|
@ -476,6 +477,14 @@ func getOrCreateChain(c *nftables.Conn, cinfo chainInfo) (*nftables.Chain, error
|
|
|
|
return chain, nil
|
|
|
|
return chain, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type AddressFamilies uint8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const (
|
|
|
|
|
|
|
|
FamilyBoth = FamilyIPv4 | FamilyIPv6
|
|
|
|
|
|
|
|
FamilyIPv4 = AddressFamilies(1 << iota)
|
|
|
|
|
|
|
|
FamilyIPv6
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
// NetfilterRunner abstracts helpers to run netfilter commands. It is
|
|
|
|
// NetfilterRunner abstracts helpers to run netfilter commands. It is
|
|
|
|
// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner.
|
|
|
|
// implemented by linuxfw.IPTablesRunner and linuxfw.NfTablesRunner.
|
|
|
|
type NetfilterRunner interface {
|
|
|
|
type NetfilterRunner interface {
|
|
|
|
@ -499,6 +508,10 @@ type NetfilterRunner interface {
|
|
|
|
// DelChains removes chains added by AddChains.
|
|
|
|
// DelChains removes chains added by AddChains.
|
|
|
|
DelChains() error
|
|
|
|
DelChains() error
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// SetAddressFamilies sets the address families (IPv4, IPv6, or both) that
|
|
|
|
|
|
|
|
// the runner should operate on.
|
|
|
|
|
|
|
|
SetAddressFamilies(AddressFamilies)
|
|
|
|
|
|
|
|
|
|
|
|
// AddBase adds rules reused by different other rules.
|
|
|
|
// AddBase adds rules reused by different other rules.
|
|
|
|
AddBase(tunname string) error
|
|
|
|
AddBase(tunname string) error
|
|
|
|
|
|
|
|
|
|
|
|
@ -521,16 +534,16 @@ type NetfilterRunner interface {
|
|
|
|
// using conntrack.
|
|
|
|
// using conntrack.
|
|
|
|
DelStatefulRule(tunname string) error
|
|
|
|
DelStatefulRule(tunname string) error
|
|
|
|
|
|
|
|
|
|
|
|
// HasIPV6 reports true if the system supports IPv6.
|
|
|
|
// HasIPV6 reports whether the system supports IPv6.
|
|
|
|
HasIPV6() bool
|
|
|
|
HasIPV6() bool
|
|
|
|
|
|
|
|
|
|
|
|
// HasIPV6NAT reports true if the system supports IPv6 NAT.
|
|
|
|
// HasIPV6NAT reports whether the system supports IPv6 NAT.
|
|
|
|
HasIPV6NAT() bool
|
|
|
|
HasIPV6NAT() bool
|
|
|
|
|
|
|
|
|
|
|
|
// HasIPV6Filter reports true if the system supports IPv6 filter tables
|
|
|
|
// HasIPV6Filter reports whether the system supports IPv6 filter tables.
|
|
|
|
// This is only meaningful for iptables implementation, where hosts have
|
|
|
|
// This is only meaningful for iptables implementation, where hosts have
|
|
|
|
// partial ipables support (i.e missing filter table). For nftables
|
|
|
|
// partial ipables support (i.e missing filter table). For the nftables
|
|
|
|
// implementation, this will default to the value of HasIPv6().
|
|
|
|
// implementation, this will default to the value of HasIPV6.
|
|
|
|
HasIPV6Filter() bool
|
|
|
|
HasIPV6Filter() bool
|
|
|
|
|
|
|
|
|
|
|
|
// AddDNATRule adds a rule to the nat/PREROUTING chain to DNAT traffic
|
|
|
|
// AddDNATRule adds a rule to the nat/PREROUTING chain to DNAT traffic
|
|
|
|
@ -645,6 +658,7 @@ func newNfTablesRunnerWithConn(logf logger.Logf, conn *nftables.Conn) *nftablesR
|
|
|
|
nft4: nft4,
|
|
|
|
nft4: nft4,
|
|
|
|
nft6: nft6,
|
|
|
|
nft6: nft6,
|
|
|
|
v6Available: supportsV6,
|
|
|
|
v6Available: supportsV6,
|
|
|
|
|
|
|
|
af: FamilyBoth,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@ -683,8 +697,12 @@ func newLoadDportExpr(destReg uint32) expr.Any {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// HasIPV6 reports true if the system supports IPv6.
|
|
|
|
// HasIPV6 reports true if the system supports IPv6.
|
|
|
|
func (n *nftablesRunner) HasIPV6() bool {
|
|
|
|
func (r *nftablesRunner) HasIPV6() bool {
|
|
|
|
return n.v6Available
|
|
|
|
return r.v6Available
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func (r *nftablesRunner) SetAddressFamilies(af AddressFamilies) {
|
|
|
|
|
|
|
|
r.af = af
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// HasIPV6NAT returns true if the system supports IPv6.
|
|
|
|
// HasIPV6NAT returns true if the system supports IPv6.
|
|
|
|
@ -692,15 +710,15 @@ func (n *nftablesRunner) HasIPV6() bool {
|
|
|
|
// NAT, so no need for a separate IPv6 NAT support check like we do for iptables.
|
|
|
|
// NAT, so no need for a separate IPv6 NAT support check like we do for iptables.
|
|
|
|
// https://tldp.org/HOWTO/Linux+IPv6-HOWTO/ch18s04.html
|
|
|
|
// https://tldp.org/HOWTO/Linux+IPv6-HOWTO/ch18s04.html
|
|
|
|
// https://wiki.nftables.org/wiki-nftables/index.php/Building_and_installing_nftables_from_sources
|
|
|
|
// https://wiki.nftables.org/wiki-nftables/index.php/Building_and_installing_nftables_from_sources
|
|
|
|
func (n *nftablesRunner) HasIPV6NAT() bool {
|
|
|
|
func (r *nftablesRunner) HasIPV6NAT() bool {
|
|
|
|
return n.v6Available
|
|
|
|
return r.v6Available
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// HasIPV6Filter returns true if system supports IPv6. There are no known edge
|
|
|
|
// HasIPV6Filter returns true if system supports IPv6. There are no known edge
|
|
|
|
// cases where nftables running on a host that supports IPv6 would not support
|
|
|
|
// cases where nftables running on a host that supports IPv6 would not support
|
|
|
|
// filter table.
|
|
|
|
// filter table.
|
|
|
|
func (n *nftablesRunner) HasIPV6Filter() bool {
|
|
|
|
func (r *nftablesRunner) HasIPV6Filter() bool {
|
|
|
|
return n.v6Available
|
|
|
|
return r.v6Available
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// findRule iterates through the rules to find the rule with matching expressions.
|
|
|
|
// findRule iterates through the rules to find the rule with matching expressions.
|
|
|
|
@ -808,30 +826,30 @@ func insertLoopbackRule(
|
|
|
|
|
|
|
|
|
|
|
|
// getNFTByAddr returns the nftables with correct IP family
|
|
|
|
// getNFTByAddr returns the nftables with correct IP family
|
|
|
|
// that we will be using for the given address.
|
|
|
|
// that we will be using for the given address.
|
|
|
|
func (n *nftablesRunner) getNFTByAddr(addr netip.Addr) (*nftable, error) {
|
|
|
|
func (r *nftablesRunner) getNFTByAddr(addr netip.Addr) (*nftable, error) {
|
|
|
|
if addr.Is6() && !n.v6Available {
|
|
|
|
if addr.Is6() && !r.v6Available {
|
|
|
|
return nil, fmt.Errorf("nftables for IPv6 are not available on this host")
|
|
|
|
return nil, fmt.Errorf("nftables for IPv6 are not available on this host")
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if addr.Is6() {
|
|
|
|
if addr.Is6() {
|
|
|
|
return n.nft6, nil
|
|
|
|
return r.nft6, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return n.nft4, nil
|
|
|
|
return r.nft4, nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// AddLoopbackRule adds an nftables rule to permit loopback traffic to
|
|
|
|
// AddLoopbackRule adds an nftables rule to permit loopback traffic to
|
|
|
|
// a local Tailscale IP. This rule is added only if it does not already exist.
|
|
|
|
// a local Tailscale IP. This rule is added only if it does not already exist.
|
|
|
|
func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
|
|
|
|
func (r *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
|
|
|
|
nf, err := n.getNFTByAddr(addr)
|
|
|
|
nf, err := r.getNFTByAddr(addr)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
|
|
|
|
return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
|
|
|
|
inputChain, err := getChainFromTable(r.conn, nf.Filter, chainNameInput)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get input chain: %w", err)
|
|
|
|
return fmt.Errorf("get input chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err := insertLoopbackRule(n.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil {
|
|
|
|
if err := insertLoopbackRule(r.conn, nf.Proto, nf.Filter, inputChain, addr); err != nil {
|
|
|
|
return fmt.Errorf("add loopback rule: %w", err)
|
|
|
|
return fmt.Errorf("add loopback rule: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@ -840,13 +858,13 @@ func (n *nftablesRunner) AddLoopbackRule(addr netip.Addr) error {
|
|
|
|
|
|
|
|
|
|
|
|
// DelLoopbackRule removes the nftables rule permitting loopback
|
|
|
|
// DelLoopbackRule removes the nftables rule permitting loopback
|
|
|
|
// traffic to a Tailscale IP.
|
|
|
|
// traffic to a Tailscale IP.
|
|
|
|
func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
|
|
|
|
func (r *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
|
|
|
|
nf, err := n.getNFTByAddr(addr)
|
|
|
|
nf, err := r.getNFTByAddr(addr)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
|
|
|
|
return fmt.Errorf("error setting up nftables for IP family of %v: %w", addr, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
inputChain, err := getChainFromTable(n.conn, nf.Filter, chainNameInput)
|
|
|
|
inputChain, err := getChainFromTable(r.conn, nf.Filter, chainNameInput)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get input chain: %w", err)
|
|
|
|
return fmt.Errorf("get input chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -856,7 +874,7 @@ func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
|
|
|
|
return fmt.Errorf("create loopback rule: %w", err)
|
|
|
|
return fmt.Errorf("create loopback rule: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
existingLoopBackRule, err := findRule(n.conn, loopBackRule)
|
|
|
|
existingLoopBackRule, err := findRule(r.conn, loopBackRule)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("find loop back rule: %w", err)
|
|
|
|
return fmt.Errorf("find loop back rule: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -865,48 +883,48 @@ func (n *nftablesRunner) DelLoopbackRule(addr netip.Addr) error {
|
|
|
|
return nil
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err := n.conn.DelRule(existingLoopBackRule); err != nil {
|
|
|
|
if err := r.conn.DelRule(existingLoopBackRule); err != nil {
|
|
|
|
return fmt.Errorf("delete rule: %w", err)
|
|
|
|
return fmt.Errorf("delete rule: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return n.conn.Flush()
|
|
|
|
return r.conn.Flush()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// getTables returns tables for IP families that this host was determined to
|
|
|
|
// getTables returns tables for IP families that this host was determined to
|
|
|
|
// support (either IPv4 and IPv6 or just IPv4).
|
|
|
|
// support (either IPv4 and IPv6 or just IPv4).
|
|
|
|
func (n *nftablesRunner) getTables() []*nftable {
|
|
|
|
func (r *nftablesRunner) getTables() []*nftable {
|
|
|
|
if n.HasIPV6() {
|
|
|
|
if r.HasIPV6() {
|
|
|
|
return []*nftable{n.nft4, n.nft6}
|
|
|
|
return []*nftable{r.nft4, r.nft6}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return []*nftable{n.nft4}
|
|
|
|
return []*nftable{r.nft4}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// AddChains creates custom Tailscale chains in netfilter via nftables
|
|
|
|
// AddChains creates custom Tailscale chains in netfilter via nftables
|
|
|
|
// if the ts-chain doesn't already exist.
|
|
|
|
// if the ts-chain doesn't already exist.
|
|
|
|
func (n *nftablesRunner) AddChains() error {
|
|
|
|
func (r *nftablesRunner) AddChains() error {
|
|
|
|
polAccept := nftables.ChainPolicyAccept
|
|
|
|
polAccept := nftables.ChainPolicyAccept
|
|
|
|
for _, table := range n.getTables() {
|
|
|
|
for _, table := range r.getTables() {
|
|
|
|
// Create the filter table if it doesn't exist, this table name is the same
|
|
|
|
// Create the filter table if it doesn't exist, this table name is the same
|
|
|
|
// as the name used by iptables-nft and ufw. We install rules into the
|
|
|
|
// as the name used by iptables-nft and ufw. We install rules into the
|
|
|
|
// same conventional table so that `accept` verdicts from our jump
|
|
|
|
// same conventional table so that `accept` verdicts from our jump
|
|
|
|
// chains are conclusive.
|
|
|
|
// chains are conclusive.
|
|
|
|
filter, err := createTableIfNotExist(n.conn, table.Proto, "filter")
|
|
|
|
filter, err := createTableIfNotExist(r.conn, table.Proto, "filter")
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("create table: %w", err)
|
|
|
|
return fmt.Errorf("create table: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
table.Filter = filter
|
|
|
|
table.Filter = filter
|
|
|
|
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
|
|
|
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
|
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
|
|
|
if err = createChainIfNotExist(r.conn, chainInfo{filter, "FORWARD", nftables.ChainTypeFilter, nftables.ChainHookForward, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
|
|
|
return fmt.Errorf("create forward chain: %w", err)
|
|
|
|
return fmt.Errorf("create forward chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
|
|
|
if err = createChainIfNotExist(r.conn, chainInfo{filter, "INPUT", nftables.ChainTypeFilter, nftables.ChainHookInput, nftables.ChainPriorityFilter, &polAccept}); err != nil {
|
|
|
|
return fmt.Errorf("create input chain: %w", err)
|
|
|
|
return fmt.Errorf("create input chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Adding the tailscale chains that contain our rules.
|
|
|
|
// Adding the tailscale chains that contain our rules.
|
|
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
|
|
if err = createChainIfNotExist(r.conn, chainInfo{filter, chainNameForward, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
|
|
return fmt.Errorf("create forward chain: %w", err)
|
|
|
|
return fmt.Errorf("create forward chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if err = createChainIfNotExist(n.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
|
|
if err = createChainIfNotExist(r.conn, chainInfo{filter, chainNameInput, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
|
|
return fmt.Errorf("create input chain: %w", err)
|
|
|
|
return fmt.Errorf("create input chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@ -914,22 +932,22 @@ func (n *nftablesRunner) AddChains() error {
|
|
|
|
// as the name used by iptables-nft and ufw. We install rules into the
|
|
|
|
// as the name used by iptables-nft and ufw. We install rules into the
|
|
|
|
// same conventional table so that `accept` verdicts from our jump
|
|
|
|
// same conventional table so that `accept` verdicts from our jump
|
|
|
|
// chains are conclusive.
|
|
|
|
// chains are conclusive.
|
|
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, "nat")
|
|
|
|
nat, err := createTableIfNotExist(r.conn, table.Proto, "nat")
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("create table: %w", err)
|
|
|
|
return fmt.Errorf("create table: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
table.Nat = nat
|
|
|
|
table.Nat = nat
|
|
|
|
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
|
|
|
// Adding the "conventional chains" that are used by iptables-nft and ufw.
|
|
|
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
|
|
|
|
if err = createChainIfNotExist(r.conn, chainInfo{nat, "POSTROUTING", nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, &polAccept}); err != nil {
|
|
|
|
return fmt.Errorf("create postrouting chain: %w", err)
|
|
|
|
return fmt.Errorf("create postrouting chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Adding the tailscale chain that contains our rules.
|
|
|
|
// Adding the tailscale chain that contains our rules.
|
|
|
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
|
|
if err = createChainIfNotExist(r.conn, chainInfo{nat, chainNamePostrouting, chainTypeRegular, nil, nil, nil}); err != nil {
|
|
|
|
return fmt.Errorf("create postrouting chain: %w", err)
|
|
|
|
return fmt.Errorf("create postrouting chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
return n.conn.Flush()
|
|
|
|
return r.conn.Flush()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// These are dummy chains and tables we create to detect if nftables is
|
|
|
|
// These are dummy chains and tables we create to detect if nftables is
|
|
|
|
@ -945,24 +963,24 @@ const (
|
|
|
|
// createDummyPostroutingChains creates dummy postrouting chains in netfilter
|
|
|
|
// createDummyPostroutingChains creates dummy postrouting chains in netfilter
|
|
|
|
// via netfilter via nftables, as a last resort measure to detect that nftables
|
|
|
|
// via netfilter via nftables, as a last resort measure to detect that nftables
|
|
|
|
// can be used. It cleans up the dummy chains after creation.
|
|
|
|
// can be used. It cleans up the dummy chains after creation.
|
|
|
|
func (n *nftablesRunner) createDummyPostroutingChains() (retErr error) {
|
|
|
|
func (r *nftablesRunner) createDummyPostroutingChains() (retErr error) {
|
|
|
|
polAccept := ptr.To(nftables.ChainPolicyAccept)
|
|
|
|
polAccept := ptr.To(nftables.ChainPolicyAccept)
|
|
|
|
for _, table := range n.getTables() {
|
|
|
|
for _, table := range r.getTables() {
|
|
|
|
nat, err := createTableIfNotExist(n.conn, table.Proto, tsDummyTableName)
|
|
|
|
nat, err := createTableIfNotExist(r.conn, table.Proto, tsDummyTableName)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("create nat table: %w", err)
|
|
|
|
return fmt.Errorf("create nat table: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
defer func(fm nftables.TableFamily) {
|
|
|
|
defer func(fm nftables.TableFamily) {
|
|
|
|
if err := deleteTableIfExists(n.conn, fm, tsDummyTableName); err != nil && retErr == nil {
|
|
|
|
if err := deleteTableIfExists(r.conn, fm, tsDummyTableName); err != nil && retErr == nil {
|
|
|
|
retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err)
|
|
|
|
retErr = fmt.Errorf("delete %q table: %w", tsDummyTableName, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}(table.Proto)
|
|
|
|
}(table.Proto)
|
|
|
|
|
|
|
|
|
|
|
|
table.Nat = nat
|
|
|
|
table.Nat = nat
|
|
|
|
if err = createChainIfNotExist(n.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil {
|
|
|
|
if err = createChainIfNotExist(r.conn, chainInfo{nat, tsDummyChainName, nftables.ChainTypeNAT, nftables.ChainHookPostrouting, nftables.ChainPriorityNATSource, polAccept}); err != nil {
|
|
|
|
return fmt.Errorf("create %q chain: %w", tsDummyChainName, err)
|
|
|
|
return fmt.Errorf("create %q chain: %w", tsDummyChainName, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if err := deleteChainIfExists(n.conn, nat, tsDummyChainName); err != nil {
|
|
|
|
if err := deleteChainIfExists(r.conn, nat, tsDummyChainName); err != nil {
|
|
|
|
return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err)
|
|
|
|
return fmt.Errorf("delete %q chain: %w", tsDummyChainName, err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -990,27 +1008,27 @@ func deleteChainIfExists(c *nftables.Conn, table *nftables.Table, name string) e
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// DelChains removes the custom Tailscale chains from netfilter via nftables.
|
|
|
|
// DelChains removes the custom Tailscale chains from netfilter via nftables.
|
|
|
|
func (n *nftablesRunner) DelChains() error {
|
|
|
|
func (r *nftablesRunner) DelChains() error {
|
|
|
|
for _, table := range n.getTables() {
|
|
|
|
for _, table := range r.getTables() {
|
|
|
|
if err := deleteChainIfExists(n.conn, table.Filter, chainNameForward); err != nil {
|
|
|
|
if err := deleteChainIfExists(r.conn, table.Filter, chainNameForward); err != nil {
|
|
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if err := deleteChainIfExists(n.conn, table.Filter, chainNameInput); err != nil {
|
|
|
|
if err := deleteChainIfExists(r.conn, table.Filter, chainNameInput); err != nil {
|
|
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err := deleteChainIfExists(n.conn, n.nft4.Nat, chainNamePostrouting); err != nil {
|
|
|
|
if err := deleteChainIfExists(r.conn, r.nft4.Nat, chainNamePostrouting); err != nil {
|
|
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if n.HasIPV6NAT() {
|
|
|
|
if r.HasIPV6NAT() {
|
|
|
|
if err := deleteChainIfExists(n.conn, n.nft6.Nat, chainNamePostrouting); err != nil {
|
|
|
|
if err := deleteChainIfExists(r.conn, r.nft6.Nat, chainNamePostrouting); err != nil {
|
|
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
|
|
return fmt.Errorf("delete chain: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err := n.conn.Flush(); err != nil {
|
|
|
|
if err := r.conn.Flush(); err != nil {
|
|
|
|
return fmt.Errorf("flush: %w", err)
|
|
|
|
return fmt.Errorf("flush: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@ -1050,10 +1068,10 @@ func addHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables
|
|
|
|
|
|
|
|
|
|
|
|
// AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING"
|
|
|
|
// AddHooks is adding rules to conventional chains like "FORWARD", "INPUT" and "POSTROUTING"
|
|
|
|
// in tables and jump from those chains to tailscale chains.
|
|
|
|
// in tables and jump from those chains to tailscale chains.
|
|
|
|
func (n *nftablesRunner) AddHooks() error {
|
|
|
|
func (r *nftablesRunner) AddHooks() error {
|
|
|
|
conn := n.conn
|
|
|
|
conn := r.conn
|
|
|
|
|
|
|
|
|
|
|
|
for _, table := range n.getTables() {
|
|
|
|
for _, table := range r.getTables() {
|
|
|
|
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
|
|
|
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
|
|
@ -1104,10 +1122,10 @@ func delHookRule(conn *nftables.Conn, table *nftables.Table, fromChain *nftables
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// DelHooks is deleting the rules added to conventional chains to jump to tailscale chains.
|
|
|
|
// DelHooks is deleting the rules added to conventional chains to jump to tailscale chains.
|
|
|
|
func (n *nftablesRunner) DelHooks(logf logger.Logf) error {
|
|
|
|
func (r *nftablesRunner) DelHooks(logf logger.Logf) error {
|
|
|
|
conn := n.conn
|
|
|
|
conn := r.conn
|
|
|
|
|
|
|
|
|
|
|
|
for _, table := range n.getTables() {
|
|
|
|
for _, table := range r.getTables() {
|
|
|
|
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
|
|
|
inputChain, err := getChainFromTable(conn, table.Filter, "INPUT")
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
|
|
return fmt.Errorf("get INPUT chain: %w", err)
|
|
|
|
@ -1437,23 +1455,23 @@ func removeAcceptOnPortRule(conn *nftables.Conn, table *nftables.Table, chain *n
|
|
|
|
// the specified UDP port, so magicsock can accept incoming connections.
|
|
|
|
// the specified UDP port, so magicsock can accept incoming connections.
|
|
|
|
// network must be either "udp4" or "udp6" - this determines whether the rule
|
|
|
|
// network must be either "udp4" or "udp6" - this determines whether the rule
|
|
|
|
// is added for IPv4 or IPv6.
|
|
|
|
// is added for IPv4 or IPv6.
|
|
|
|
func (n *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error {
|
|
|
|
func (r *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error {
|
|
|
|
var filterTable *nftables.Table
|
|
|
|
var filterTable *nftables.Table
|
|
|
|
switch network {
|
|
|
|
switch network {
|
|
|
|
case "udp4":
|
|
|
|
case "udp4":
|
|
|
|
filterTable = n.nft4.Filter
|
|
|
|
filterTable = r.nft4.Filter
|
|
|
|
case "udp6":
|
|
|
|
case "udp6":
|
|
|
|
filterTable = n.nft6.Filter
|
|
|
|
filterTable = r.nft6.Filter
|
|
|
|
default:
|
|
|
|
default:
|
|
|
|
return fmt.Errorf("unsupported network %s", network)
|
|
|
|
return fmt.Errorf("unsupported network %s", network)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput)
|
|
|
|
inputChain, err := getChainFromTable(r.conn, filterTable, chainNameInput)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get input chain: %v", err)
|
|
|
|
return fmt.Errorf("get input chain: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
err = addAcceptOnPortRule(n.conn, filterTable, inputChain, port)
|
|
|
|
err = addAcceptOnPortRule(r.conn, filterTable, inputChain, port)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("add accept on port rule: %v", err)
|
|
|
|
return fmt.Errorf("add accept on port rule: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -1465,23 +1483,23 @@ func (n *nftablesRunner) AddMagicsockPortRule(port uint16, network string) error
|
|
|
|
// incoming traffic on a particular UDP port.
|
|
|
|
// incoming traffic on a particular UDP port.
|
|
|
|
// network must be either "udp4" or "udp6" - this determines whether the rule
|
|
|
|
// network must be either "udp4" or "udp6" - this determines whether the rule
|
|
|
|
// is removed for IPv4 or IPv6.
|
|
|
|
// is removed for IPv4 or IPv6.
|
|
|
|
func (n *nftablesRunner) DelMagicsockPortRule(port uint16, network string) error {
|
|
|
|
func (r *nftablesRunner) DelMagicsockPortRule(port uint16, network string) error {
|
|
|
|
var filterTable *nftables.Table
|
|
|
|
var filterTable *nftables.Table
|
|
|
|
switch network {
|
|
|
|
switch network {
|
|
|
|
case "udp4":
|
|
|
|
case "udp4":
|
|
|
|
filterTable = n.nft4.Filter
|
|
|
|
filterTable = r.nft4.Filter
|
|
|
|
case "udp6":
|
|
|
|
case "udp6":
|
|
|
|
filterTable = n.nft6.Filter
|
|
|
|
filterTable = r.nft6.Filter
|
|
|
|
default:
|
|
|
|
default:
|
|
|
|
return fmt.Errorf("unsupported network %s", network)
|
|
|
|
return fmt.Errorf("unsupported network %s", network)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
inputChain, err := getChainFromTable(n.conn, filterTable, chainNameInput)
|
|
|
|
inputChain, err := getChainFromTable(r.conn, filterTable, chainNameInput)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get input chain: %v", err)
|
|
|
|
return fmt.Errorf("get input chain: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
err = removeAcceptOnPortRule(n.conn, filterTable, inputChain, port)
|
|
|
|
err = removeAcceptOnPortRule(r.conn, filterTable, inputChain, port)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("add accept on port rule: %v", err)
|
|
|
|
return fmt.Errorf("add accept on port rule: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -1522,12 +1540,14 @@ func addAcceptIncomingPacketRule(conn *nftables.Conn, table *nftables.Table, cha
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// AddBase adds some basic processing rules.
|
|
|
|
// AddBase adds some basic processing rules.
|
|
|
|
func (n *nftablesRunner) AddBase(tunname string) error {
|
|
|
|
func (r *nftablesRunner) AddBase(tunname string) error {
|
|
|
|
if err := n.addBase4(tunname); err != nil {
|
|
|
|
if r.af&FamilyIPv4 != 0 {
|
|
|
|
return fmt.Errorf("add base v4: %w", err)
|
|
|
|
if err := r.addBase4(tunname); err != nil {
|
|
|
|
|
|
|
|
return fmt.Errorf("add base v4: %w", err)
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if n.HasIPV6() {
|
|
|
|
if r.af&FamilyIPv6 != 0 && r.HasIPV6() {
|
|
|
|
if err := n.addBase6(tunname); err != nil {
|
|
|
|
if err := r.addBase6(tunname); err != nil {
|
|
|
|
return fmt.Errorf("add base v6: %w", err)
|
|
|
|
return fmt.Errorf("add base v6: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -1535,41 +1555,41 @@ func (n *nftablesRunner) AddBase(tunname string) error {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// addBase4 adds some basic IPv4 processing rules.
|
|
|
|
// addBase4 adds some basic IPv4 processing rules.
|
|
|
|
func (n *nftablesRunner) addBase4(tunname string) error {
|
|
|
|
func (r *nftablesRunner) addBase4(tunname string) error {
|
|
|
|
conn := n.conn
|
|
|
|
conn := r.conn
|
|
|
|
|
|
|
|
|
|
|
|
inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput)
|
|
|
|
inputChain, err := getChainFromTable(conn, r.nft4.Filter, chainNameInput)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get input chain v4: %v", err)
|
|
|
|
return fmt.Errorf("get input chain v4: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
|
|
|
|
if err = addReturnChromeOSVMRangeRule(conn, r.nft4.Filter, inputChain, tunname); err != nil {
|
|
|
|
return fmt.Errorf("add return chromeos vm range rule v4: %w", err)
|
|
|
|
return fmt.Errorf("add return chromeos vm range rule v4: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
|
|
|
|
if err = addDropCGNATRangeRule(conn, r.nft4.Filter, inputChain, tunname); err != nil {
|
|
|
|
return fmt.Errorf("add drop cgnat range rule v4: %w", err)
|
|
|
|
return fmt.Errorf("add drop cgnat range rule v4: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if err = addAcceptIncomingPacketRule(conn, n.nft4.Filter, inputChain, tunname); err != nil {
|
|
|
|
if err = addAcceptIncomingPacketRule(conn, r.nft4.Filter, inputChain, tunname); err != nil {
|
|
|
|
return fmt.Errorf("add accept incoming packet rule v4: %w", err)
|
|
|
|
return fmt.Errorf("add accept incoming packet rule v4: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
forwardChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameForward)
|
|
|
|
forwardChain, err := getChainFromTable(conn, r.nft4.Filter, chainNameForward)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get forward chain v4: %v", err)
|
|
|
|
return fmt.Errorf("get forward chain v4: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err = addSetSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
|
|
|
if err = addSetSubnetRouteMarkRule(conn, r.nft4.Filter, forwardChain, tunname); err != nil {
|
|
|
|
return fmt.Errorf("add set subnet route mark rule v4: %w", err)
|
|
|
|
return fmt.Errorf("add set subnet route mark rule v4: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err = addMatchSubnetRouteMarkRule(conn, n.nft4.Filter, forwardChain, Accept); err != nil {
|
|
|
|
if err = addMatchSubnetRouteMarkRule(conn, r.nft4.Filter, forwardChain, Accept); err != nil {
|
|
|
|
return fmt.Errorf("add match subnet route mark rule v4: %w", err)
|
|
|
|
return fmt.Errorf("add match subnet route mark rule v4: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
|
|
|
if err = addDropOutgoingPacketFromCGNATRangeRuleWithTunname(conn, r.nft4.Filter, forwardChain, tunname); err != nil {
|
|
|
|
return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err)
|
|
|
|
return fmt.Errorf("add drop outgoing packet from cgnat range rule v4: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err = addAcceptOutgoingPacketRule(conn, n.nft4.Filter, forwardChain, tunname); err != nil {
|
|
|
|
if err = addAcceptOutgoingPacketRule(conn, r.nft4.Filter, forwardChain, tunname); err != nil {
|
|
|
|
return fmt.Errorf("add accept outgoing packet rule v4: %w", err)
|
|
|
|
return fmt.Errorf("add accept outgoing packet rule v4: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@ -1581,31 +1601,31 @@ func (n *nftablesRunner) addBase4(tunname string) error {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// addBase6 adds some basic IPv6 processing rules.
|
|
|
|
// addBase6 adds some basic IPv6 processing rules.
|
|
|
|
func (n *nftablesRunner) addBase6(tunname string) error {
|
|
|
|
func (r *nftablesRunner) addBase6(tunname string) error {
|
|
|
|
conn := n.conn
|
|
|
|
conn := r.conn
|
|
|
|
|
|
|
|
|
|
|
|
inputChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameInput)
|
|
|
|
inputChain, err := getChainFromTable(conn, r.nft6.Filter, chainNameInput)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get input chain v4: %v", err)
|
|
|
|
return fmt.Errorf("get input chain v4: %v", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if err = addAcceptIncomingPacketRule(conn, n.nft6.Filter, inputChain, tunname); err != nil {
|
|
|
|
if err = addAcceptIncomingPacketRule(conn, r.nft6.Filter, inputChain, tunname); err != nil {
|
|
|
|
return fmt.Errorf("add accept incoming packet rule v6: %w", err)
|
|
|
|
return fmt.Errorf("add accept incoming packet rule v6: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
forwardChain, err := getChainFromTable(conn, n.nft6.Filter, chainNameForward)
|
|
|
|
forwardChain, err := getChainFromTable(conn, r.nft6.Filter, chainNameForward)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get forward chain v6: %w", err)
|
|
|
|
return fmt.Errorf("get forward chain v6: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err = addSetSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
|
|
|
|
if err = addSetSubnetRouteMarkRule(conn, r.nft6.Filter, forwardChain, tunname); err != nil {
|
|
|
|
return fmt.Errorf("add set subnet route mark rule v6: %w", err)
|
|
|
|
return fmt.Errorf("add set subnet route mark rule v6: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err = addMatchSubnetRouteMarkRule(conn, n.nft6.Filter, forwardChain, Accept); err != nil {
|
|
|
|
if err = addMatchSubnetRouteMarkRule(conn, r.nft6.Filter, forwardChain, Accept); err != nil {
|
|
|
|
return fmt.Errorf("add match subnet route mark rule v6: %w", err)
|
|
|
|
return fmt.Errorf("add match subnet route mark rule v6: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if err = addAcceptOutgoingPacketRule(conn, n.nft6.Filter, forwardChain, tunname); err != nil {
|
|
|
|
if err = addAcceptOutgoingPacketRule(conn, r.nft6.Filter, forwardChain, tunname); err != nil {
|
|
|
|
return fmt.Errorf("add accept outgoing packet rule v6: %w", err)
|
|
|
|
return fmt.Errorf("add accept outgoing packet rule v6: %w", err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@ -1618,10 +1638,10 @@ func (n *nftablesRunner) addBase6(tunname string) error {
|
|
|
|
|
|
|
|
|
|
|
|
// DelBase empties, but does not remove, custom Tailscale chains from
|
|
|
|
// DelBase empties, but does not remove, custom Tailscale chains from
|
|
|
|
// netfilter via iptables.
|
|
|
|
// netfilter via iptables.
|
|
|
|
func (n *nftablesRunner) DelBase() error {
|
|
|
|
func (r *nftablesRunner) DelBase() error {
|
|
|
|
conn := n.conn
|
|
|
|
conn := r.conn
|
|
|
|
|
|
|
|
|
|
|
|
for _, table := range n.getTables() {
|
|
|
|
for _, table := range r.getTables() {
|
|
|
|
inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
|
|
|
|
inputChain, err := getChainFromTable(conn, table.Filter, chainNameInput)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get input chain: %v", err)
|
|
|
|
return fmt.Errorf("get input chain: %v", err)
|
|
|
|
@ -1699,10 +1719,10 @@ func addMatchSubnetRouteMarkRule(conn *nftables.Conn, table *nftables.Table, cha
|
|
|
|
|
|
|
|
|
|
|
|
// AddSNATRule adds a netfilter rule to SNAT traffic destined for
|
|
|
|
// AddSNATRule adds a netfilter rule to SNAT traffic destined for
|
|
|
|
// local subnets.
|
|
|
|
// local subnets.
|
|
|
|
func (n *nftablesRunner) AddSNATRule() error {
|
|
|
|
func (r *nftablesRunner) AddSNATRule() error {
|
|
|
|
conn := n.conn
|
|
|
|
conn := r.conn
|
|
|
|
|
|
|
|
|
|
|
|
for _, table := range n.getTables() {
|
|
|
|
for _, table := range r.getTables() {
|
|
|
|
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
|
|
|
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get postrouting chain v4: %w", err)
|
|
|
|
return fmt.Errorf("get postrouting chain v4: %w", err)
|
|
|
|
@ -1745,10 +1765,10 @@ func delMatchSubnetRouteMarkMasqRule(conn *nftables.Conn, table *nftables.Table,
|
|
|
|
|
|
|
|
|
|
|
|
// DelSNATRule removes the netfilter rule to SNAT traffic destined for
|
|
|
|
// DelSNATRule removes the netfilter rule to SNAT traffic destined for
|
|
|
|
// local subnets. An error is returned if the rule does not exist.
|
|
|
|
// local subnets. An error is returned if the rule does not exist.
|
|
|
|
func (n *nftablesRunner) DelSNATRule() error {
|
|
|
|
func (r *nftablesRunner) DelSNATRule() error {
|
|
|
|
conn := n.conn
|
|
|
|
conn := r.conn
|
|
|
|
|
|
|
|
|
|
|
|
for _, table := range n.getTables() {
|
|
|
|
for _, table := range r.getTables() {
|
|
|
|
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
|
|
|
chain, err := getChainFromTable(conn, table.Nat, chainNamePostrouting)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get postrouting chain: %w", err)
|
|
|
|
return fmt.Errorf("get postrouting chain: %w", err)
|
|
|
|
@ -1886,11 +1906,11 @@ func makeStatefulRuleExprs(tunname string) []expr.Any {
|
|
|
|
|
|
|
|
|
|
|
|
// AddStatefulRule adds a netfilter rule for stateful packet filtering using
|
|
|
|
// AddStatefulRule adds a netfilter rule for stateful packet filtering using
|
|
|
|
// conntrack.
|
|
|
|
// conntrack.
|
|
|
|
func (n *nftablesRunner) AddStatefulRule(tunname string) error {
|
|
|
|
func (r *nftablesRunner) AddStatefulRule(tunname string) error {
|
|
|
|
conn := n.conn
|
|
|
|
conn := r.conn
|
|
|
|
|
|
|
|
|
|
|
|
exprs := makeStatefulRuleExprs(tunname)
|
|
|
|
exprs := makeStatefulRuleExprs(tunname)
|
|
|
|
for _, table := range n.getTables() {
|
|
|
|
for _, table := range r.getTables() {
|
|
|
|
chain, err := getChainFromTable(conn, table.Filter, chainNameForward)
|
|
|
|
chain, err := getChainFromTable(conn, table.Filter, chainNameForward)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get forward chain: %w", err)
|
|
|
|
return fmt.Errorf("get forward chain: %w", err)
|
|
|
|
@ -1922,11 +1942,11 @@ func (n *nftablesRunner) AddStatefulRule(tunname string) error {
|
|
|
|
|
|
|
|
|
|
|
|
// DelStatefulRule removes the netfilter rule for stateful packet filtering
|
|
|
|
// DelStatefulRule removes the netfilter rule for stateful packet filtering
|
|
|
|
// using conntrack.
|
|
|
|
// using conntrack.
|
|
|
|
func (n *nftablesRunner) DelStatefulRule(tunname string) error {
|
|
|
|
func (r *nftablesRunner) DelStatefulRule(tunname string) error {
|
|
|
|
conn := n.conn
|
|
|
|
conn := r.conn
|
|
|
|
|
|
|
|
|
|
|
|
exprs := makeStatefulRuleExprs(tunname)
|
|
|
|
exprs := makeStatefulRuleExprs(tunname)
|
|
|
|
for _, table := range n.getTables() {
|
|
|
|
for _, table := range r.getTables() {
|
|
|
|
chain, err := getChainFromTable(conn, table.Filter, chainNameForward)
|
|
|
|
chain, err := getChainFromTable(conn, table.Filter, chainNameForward)
|
|
|
|
if err != nil {
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("get forward chain: %w", err)
|
|
|
|
return fmt.Errorf("get forward chain: %w", err)
|
|
|
|
|