add nft-update-addresses as package & with a NixOS module
parent
0b0e008fdb
commit
142e970bdc
@ -0,0 +1,6 @@
|
||||
{
|
||||
imports = [
|
||||
# files
|
||||
./nft-update-addresses.nix
|
||||
];
|
||||
}
|
@ -0,0 +1,186 @@
|
||||
{
|
||||
config,
|
||||
lib,
|
||||
pkgs,
|
||||
...
|
||||
}:
|
||||
let
|
||||
servName = "nft-update-addresses";
|
||||
cfg = config.services.${servName};
|
||||
settingsFormat = pkgs.formats.json { };
|
||||
mkDisableOption = desc: lib.mkEnableOption desc // { default = true; };
|
||||
# output options values
|
||||
configFile = pkgs.writeTextFile {
|
||||
name = "${servName}.json";
|
||||
text = builtins.toJSON cfg.settings; # TODO can otherwise not easily check the file for errors
|
||||
checkPhase = ''
|
||||
${lib.getExe cfg.package} --check-config --config-file "$out"
|
||||
'';
|
||||
};
|
||||
staticDefs = builtins.readFile (
|
||||
pkgs.runCommandLocal "${servName}.nftables" { } ''
|
||||
${lib.getExe cfg.package} --output-set-definitions --config-file ${configFile} > $out
|
||||
''
|
||||
);
|
||||
in
|
||||
{
|
||||
|
||||
options.services.${servName} = {
|
||||
|
||||
enable = lib.mkEnableOption "${servName} service";
|
||||
|
||||
package = lib.mkPackageOption pkgs (lib.singleton servName) { };
|
||||
|
||||
settings = lib.mkOption {
|
||||
# TODO link to docu
|
||||
description = "Configuration for ${servName}";
|
||||
type = settingsFormat.type;
|
||||
default = {
|
||||
nftTable = "nixos-fw";
|
||||
};
|
||||
example.interfaces = {
|
||||
wan0 = { };
|
||||
lan0.ports.tcp = {
|
||||
exposed = [
|
||||
{
|
||||
dest = "aa:bb:cc:dd:ee:ff";
|
||||
port = 80;
|
||||
}
|
||||
{
|
||||
dest = "aa:bb:cc:00:11:22";
|
||||
port = 80;
|
||||
}
|
||||
];
|
||||
forwarded = [
|
||||
{
|
||||
dest = "aabb-ccdd-eeff";
|
||||
lanPort = 80;
|
||||
wanPort = 80;
|
||||
}
|
||||
{
|
||||
dest = "aa.bbcc.0011.22";
|
||||
lanPort = 80;
|
||||
wanPort = 8080;
|
||||
}
|
||||
];
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
includeStaticDefinitions = mkDisableOption ''inclusion of static definitions from {option}`services.${servName}.nftablesStaticDefinitions` into the nftables config'';
|
||||
|
||||
configurationFile = lib.mkOption {
|
||||
description = "Path to configuration file used by ${servName}.";
|
||||
type = lib.types.path; # needs to be available at build time
|
||||
readOnly = true;
|
||||
default = configFile;
|
||||
defaultText = lib.literalExpression "# content as generated from config.services.${servName}.settings";
|
||||
};
|
||||
|
||||
nftablesStaticDefinitions = lib.mkOption {
|
||||
description = ''
|
||||
Static definitions provided by ${servName} when called with given configuration.
|
||||
|
||||
When {option}`services.${servName}.includeStaticDefinitions (which is default),
|
||||
these will be already included in your nftables setup.
|
||||
Otherwise, you can use the value of this output option as you prefer.
|
||||
'';
|
||||
readOnly = true;
|
||||
default = staticDefs;
|
||||
defaultText = lib.literalExpression "# as provided by ${servName}";
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
config = lib.mkIf cfg.enable {
|
||||
|
||||
assertions = [
|
||||
{
|
||||
assertion = cfg.enable -> config.networking.nftables.enable;
|
||||
message = "${servName} requires nftables to be configured";
|
||||
}
|
||||
# TODO assert for port duplications
|
||||
];
|
||||
|
||||
networking.nftables.tables.${cfg.settings.nftTable}.content = lib.mkIf cfg.includeStaticDefinitions staticDefs;
|
||||
|
||||
systemd.services.${servName} = {
|
||||
description = "IPv6 prefix updater for subnet & NAT rules for nftables router setup";
|
||||
after = [
|
||||
"nftables.service"
|
||||
"network.target"
|
||||
];
|
||||
partOf = lib.singleton "nftables.service";
|
||||
requisite = lib.singleton "nftables.service";
|
||||
wantedBy = lib.singleton "multi-user.target";
|
||||
upheldBy = lib.singleton "systemd-networkd.service";
|
||||
restartIfChanged = true;
|
||||
restartTriggers = config.systemd.services.nftables.restartTriggers;
|
||||
serviceConfig = {
|
||||
# Service
|
||||
Type = "notify-reload";
|
||||
ExecStart = lib.singleton "${lib.getExe cfg.package} ${
|
||||
lib.cli.toGNUCommandLineShell { } {
|
||||
config-file = configFile;
|
||||
ip-command = "${pkgs.iproute2}/bin/ip";
|
||||
nft-command = lib.getExe pkgs.nftables;
|
||||
}
|
||||
}";
|
||||
RestartSec = "250ms";
|
||||
RestartSteps = 3;
|
||||
RestartMaxDelaySec = "3s";
|
||||
TimeoutSec = "10s";
|
||||
Restart = "always";
|
||||
NotifyAccess = "all"; # bash script opens subprocesses in pipes
|
||||
# Paths
|
||||
ProtectProc = "noaccess";
|
||||
ProcSubset = "pid";
|
||||
CapabilityBoundingSet = [
|
||||
"CAP_BPF" # nft is compiled to bpf
|
||||
"CAP_IPC_LOCK" # ?
|
||||
"CAP_KILL" # ?
|
||||
"CAP_NET_ADMIN"
|
||||
];
|
||||
# Security
|
||||
NoNewPrivileges = true;
|
||||
# Process
|
||||
KeyringMode = "private";
|
||||
OOMScoreAdjust = 10;
|
||||
# Scheduling
|
||||
Nice = -2;
|
||||
CPUSchedulingPolicy = "fifo";
|
||||
# Sandboxing
|
||||
ProtectSystem = "strict";
|
||||
ProtectHome = true;
|
||||
PrivateTmp = true;
|
||||
PrivateDevices = true;
|
||||
PrivateNetwork = false; # breaks nftables
|
||||
PrivateIPC = true;
|
||||
PrivateUsers = false; # breaks nftables
|
||||
ProtectClock = true;
|
||||
ProtectKernelTunables = true;
|
||||
ProtectKernelModules = true; # are already loaded
|
||||
ProtectKernelLogs = true;
|
||||
ProtectControlGroups = true;
|
||||
#RestrictAddressFamilies = [
|
||||
# # ?
|
||||
# "AF_INET"
|
||||
# "AF_INET6"
|
||||
# #"AF_NETLINK"
|
||||
#];
|
||||
RestrictNamespaces = true;
|
||||
RestrictSUIDSGID = true;
|
||||
#SystemCallFilter = "@basic-io @ipc @network-io @signal @timer" # definitly will break that
|
||||
#SystemCallLog = "~"; # for debugging; should lock all system calls made
|
||||
# Resource Control
|
||||
CPUQuota = "50%";
|
||||
# TODO test to gather real values
|
||||
MemoryLow = "8M";
|
||||
MemoryHigh = "32M";
|
||||
MemoryMax = "128M";
|
||||
};
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
}
|
@ -0,0 +1,727 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import (
|
||||
ABC,
|
||||
abstractmethod,
|
||||
)
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from collections.abc import (
|
||||
Mapping,
|
||||
Sequence,
|
||||
)
|
||||
from enum import (
|
||||
Enum,
|
||||
Flag,
|
||||
auto,
|
||||
)
|
||||
from functools import cached_property
|
||||
import io
|
||||
from ipaddress import (
|
||||
IPv4Interface,
|
||||
IPv6Interface,
|
||||
IPv6Network,
|
||||
)
|
||||
from itertools import chain
|
||||
import json
|
||||
import logging
|
||||
from logging.handlers import SysLogHandler
|
||||
import os
|
||||
from pathlib import Path
|
||||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
import threading
|
||||
from threading import (
|
||||
RLock,
|
||||
Timer,
|
||||
)
|
||||
import traceback
|
||||
from typing import (
|
||||
Any,
|
||||
Iterable,
|
||||
NewType,
|
||||
NoReturn,
|
||||
Protocol,
|
||||
TypeAlias,
|
||||
TypeGuard,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from attrs import (
|
||||
define,
|
||||
field,
|
||||
)
|
||||
from systemd import daemon # type: ignore[import-untyped]
|
||||
from systemd.journal import JournalHandler # type: ignore[import-untyped]
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def raise_and_exit(args: Any) -> None:
|
||||
Timer(0.01, os._exit, args=(1,)).start()
|
||||
raise args[0]
|
||||
|
||||
|
||||
# ensure exceptions in any thread brings the program down
|
||||
# important for proper error detection via tests & in random cases in real world
|
||||
|
||||
threading.excepthook = raise_and_exit
|
||||
|
||||
|
||||
JsonVal: TypeAlias = Union["JsonObj", "JsonList", str, int, bool]
|
||||
JsonList: TypeAlias = Sequence[JsonVal]
|
||||
JsonObj: TypeAlias = Mapping[str, JsonVal]
|
||||
|
||||
T = TypeVar("T", contravariant=True)
|
||||
|
||||
MACAddress = NewType("MACAddress", str)
|
||||
# format: aabbccddeeff (lower-case, without separators)
|
||||
NftProtocol = NewType("NftProtocol", str) # e.g. tcp, udp, …
|
||||
Port = NewType("Port", int)
|
||||
IfName = NewType("IfName", str)
|
||||
NftTable = NewType("NftTable", str)
|
||||
|
||||
|
||||
def to_mac(mac_str: str) -> MACAddress:
|
||||
eui48 = re.sub(r"[.:_-]", "", mac_str.lower())
|
||||
if not is_mac(eui48):
|
||||
raise ValueError(f"invalid MAC address / EUI48: {mac_str}")
|
||||
return MACAddress(eui48)
|
||||
|
||||
|
||||
def is_mac(mac_str: str) -> TypeGuard[MACAddress]:
|
||||
return re.match(r"^[0-9a-f]{12}$", mac_str) != None
|
||||
|
||||
|
||||
def to_port(port_str: str | int) -> Port:
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"invalid port number: {port_str}") from exc
|
||||
if not is_port(port):
|
||||
raise ValueError(f"invalid port number: {port_str}")
|
||||
return Port(port)
|
||||
|
||||
|
||||
def is_port(port: int) -> TypeGuard[Port]:
|
||||
return 0 < port < 65536
|
||||
|
||||
|
||||
def slaac_eui48(prefix: IPv6Network, eui48: MACAddress) -> IPv6Interface:
|
||||
if prefix.prefixlen > 64:
|
||||
raise ValueError(
|
||||
f"a SLAAC IPv6 address requires a CIDR of at least /64, got {prefix}"
|
||||
)
|
||||
eui64 = eui48[0:6] + "fffe" + eui48[6:]
|
||||
modified = hex(int(eui64[0:2], 16) ^ 2)[2:].zfill(2) + eui64[2:]
|
||||
euil = int(modified, 16)
|
||||
return IPv6Interface(f"{prefix[euil].compressed}/{prefix.prefixlen}")
|
||||
|
||||
|
||||
class UpdateHandler(Protocol[T]):
|
||||
def update(self, data: T) -> None:
|
||||
...
|
||||
|
||||
def update_stack(self, data: Sequence[T]) -> None:
|
||||
...
|
||||
|
||||
|
||||
class UpdateStackHandler(UpdateHandler[T], ABC):
|
||||
def update(self, data: T) -> None:
|
||||
return self._update_stack((data,))
|
||||
|
||||
def update_stack(self, data: Sequence[T]) -> None:
|
||||
if len(data) <= 0:
|
||||
logger.warning(
|
||||
f"[bug, please report upstream] received empty data in update_stack. Traceback:\n{''.join(traceback.format_stack())}"
|
||||
)
|
||||
return
|
||||
return self._update_stack(data)
|
||||
|
||||
@abstractmethod
|
||||
def _update_stack(self, data: Sequence[T]) -> None:
|
||||
...
|
||||
|
||||
|
||||
class IgnoreHandler(UpdateStackHandler[object]):
|
||||
def _update_stack(self, data: Sequence[object]) -> None:
|
||||
return
|
||||
|
||||
|
||||
@define(
|
||||
kw_only=True,
|
||||
slots=False,
|
||||
)
|
||||
class UpdateBurstHandler(UpdateStackHandler[T]):
|
||||
burst_interval: float
|
||||
handler: Sequence[UpdateHandler[T]]
|
||||
__lock: RLock = field(factory=RLock)
|
||||
__updates: list[T] = field(factory=list)
|
||||
__timer: Timer | None = None
|
||||
|
||||
def _update_stack(self, data: Sequence[T]) -> None:
|
||||
with self.__lock:
|
||||
self.__updates.extend(data)
|
||||
self.__refresh_timer()
|
||||
|
||||
def __refresh_timer(self) -> None:
|
||||
with self.__lock:
|
||||
if self.__timer is not None:
|
||||
# try to cancel
|
||||
# not a problem if timer already elapsed but before processing really started
|
||||
# because due to using locks when accessing updates
|
||||
self.__timer.cancel()
|
||||
self.__timer = Timer(
|
||||
interval=self.burst_interval,
|
||||
function=self.__process_updates,
|
||||
)
|
||||
self.__timer.start()
|
||||
|
||||
def __process_updates(self) -> None:
|
||||
with self.__lock:
|
||||
self.__timer = None
|
||||
if not self.__updates:
|
||||
return
|
||||
updates = self.__updates
|
||||
self.__updates = []
|
||||
for handler in self.handler:
|
||||
handler.update_stack(updates)
|
||||
|
||||
|
||||
class IpFlag(Flag):
|
||||
dynamic = auto()
|
||||
mngtmpaddr = auto()
|
||||
noprefixroute = auto()
|
||||
temporary = auto()
|
||||
tentiative = auto()
|
||||
|
||||
@staticmethod
|
||||
def parse_str(flags_str: Sequence[str], ignore_unknown: bool = True) -> IpFlag:
|
||||
flags = IpFlag(0)
|
||||
for flag in flags_str:
|
||||
flag = flag.lower()
|
||||
member = IpFlag.__members__.get(flag)
|
||||
if member is not None:
|
||||
flags |= member
|
||||
elif not ignore_unknown:
|
||||
raise Exception(f"Unrecognized IpFlag: {flag}")
|
||||
return flags
|
||||
|
||||
|
||||
IP_MON_PATTERN = re.compile(
|
||||
r"""(?x)^
|
||||
(?P<deleted>[Dd]eleted\s+)?
|
||||
(?P<ifindex>\d+):\s+
|
||||
(?P<ifname>\S+)\s+
|
||||
(?P<type>inet6?)\s+
|
||||
(?P<ip>\S+)\s+
|
||||
#(?:metric\s+\S+\s+)? # sometimes possible
|
||||
#(?:brd\s+\S+\s+)? # broadcast IP on inet
|
||||
(?:\S+\s+\S+\s+)* # abstracted irrelevant attributes
|
||||
(?:scope\s+(?P<scope>\S+)\s+)
|
||||
(?P<flags>(?:(\S+)\s)*) # (single spaces required for parser below to work correctly)
|
||||
(?:\S+)? # random interface name repetition on inet
|
||||
[\\]\s+
|
||||
.* # lifetimes which are not interesting yet
|
||||
$"""
|
||||
)
|
||||
|
||||
|
||||
@define(
|
||||
frozen=True,
|
||||
kw_only=True,
|
||||
)
|
||||
class IpAddressUpdate:
|
||||
deleted: bool
|
||||
ifindex: int
|
||||
ifname: IfName
|
||||
ip: IPv4Interface | IPv6Interface
|
||||
scope: str
|
||||
flags: IpFlag
|
||||
|
||||
@staticmethod
|
||||
def parse_line(line: str) -> IpAddressUpdate:
|
||||
m = IP_MON_PATTERN.search(line)
|
||||
if not m:
|
||||
raise Exception(f"Could not parse ip monitor output: {line!r}")
|
||||
grp = m.groupdict()
|
||||
ip_type: type[IPv4Interface | IPv6Interface] = (
|
||||
IPv6Interface if grp["type"] == "inet6" else IPv4Interface
|
||||
)
|
||||
try:
|
||||
ip = ip_type(grp["ip"])
|
||||
except ValueError as e:
|
||||
raise Exception(
|
||||
f"Could not parse ip monitor output, invalid IP: {grp['ip']!r}"
|
||||
) from e
|
||||
flags = IpFlag.parse_str(grp["flags"].strip().split(" "))
|
||||
return IpAddressUpdate(
|
||||
deleted=grp["deleted"] != None,
|
||||
ifindex=int(grp["ifindex"]),
|
||||
ifname=IfName(grp["ifname"]),
|
||||
ip=ip,
|
||||
scope=grp["scope"],
|
||||
flags=flags,
|
||||
)
|
||||
|
||||
|
||||
def monitor_ip(
|
||||
ip_cmd: list[str],
|
||||
handler: UpdateHandler[IpAddressUpdate],
|
||||
) -> NoReturn:
|
||||
proc = subprocess.Popen(
|
||||
ip_cmd + ["-o", "monitor", "address"],
|
||||
stdout=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
# initial kickoff (AFTER starting monitoring, to not miss any update)
|
||||
logger.info("kickoff IP monitoring with current data")
|
||||
res = subprocess.run(
|
||||
ip_cmd + ["-o", "address", "show"],
|
||||
check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
for line in res.stdout.splitlines(keepends=False):
|
||||
line = line.rstrip()
|
||||
if line == "":
|
||||
continue
|
||||
update = IpAddressUpdate.parse_line(line)
|
||||
logger.debug(f"pass IP update: {update!r}")
|
||||
handler.update(update)
|
||||
logger.info("loading kickoff finished; start regular monitoring")
|
||||
while True:
|
||||
rc = proc.poll()
|
||||
if rc != None:
|
||||
# flush stdout for easier debugging
|
||||
logger.error("Last stdout of monitor process:")
|
||||
logger.error(proc.stdout.read()) # type: ignore[union-attr]
|
||||
raise Exception(f"Monitor process crashed with returncode {rc}")
|
||||
line = proc.stdout.readline().rstrip() # type: ignore[union-attr]
|
||||
if not line:
|
||||
continue
|
||||
logger.info("IP change detected")
|
||||
update = IpAddressUpdate.parse_line(line)
|
||||
logger.debug(f"pass IP update: {update!r}")
|
||||
handler.update(update)
|
||||
|
||||
|
||||
class InterfaceUpdateHandler(UpdateStackHandler[IpAddressUpdate]):
|
||||
# TODO regularly check (i.e. 1 hour) if stored lists are still correct
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: InterfaceConfig,
|
||||
nft_handler: UpdateHandler[NftUpdate],
|
||||
) -> None:
|
||||
self.nft_handler = nft_handler
|
||||
self.lock = RLock()
|
||||
self.config = config
|
||||
self.ipv4Addrs = list[IPv4Interface]()
|
||||
self.ipv6Addrs = list[IPv6Interface]()
|
||||
|
||||
def _update_stack(self, data: Sequence[IpAddressUpdate]) -> None:
|
||||
nft_updates = tuple(
|
||||
chain.from_iterable(self.__gen_updates(single) for single in data)
|
||||
)
|
||||
if len(nft_updates) <= 0:
|
||||
return
|
||||
self.nft_handler.update_stack(nft_updates)
|
||||
|
||||
def __gen_updates(self, data: IpAddressUpdate) -> Iterable[NftUpdate]:
|
||||
if data.ifname != self.config.ifname:
|
||||
return
|
||||
if data.ip.is_link_local:
|
||||
logger.debug(
|
||||
f"{self.config.ifname}: ignore change for IP {data.ip} because link-local"
|
||||
)
|
||||
return
|
||||
if IpFlag.temporary in data.flags:
|
||||
logger.debug(
|
||||
f"{self.config.ifname}: ignore change for IP {data.ip} because temporary"
|
||||
)
|
||||
return # ignore IPv6 privacy extension addresses
|
||||
if IpFlag.tentiative in data.flags:
|
||||
logger.debug(
|
||||
f"{self.config.ifname}: ignore change for IP {data.ip} because tentiative"
|
||||
)
|
||||
return # ignore (yet) tentiative addresses
|
||||
logger.debug(f"{self.config.ifname}: process change of IP {data.ip}")
|
||||
with self.lock:
|
||||
ip_list: list[IPv4Interface] | list[IPv6Interface] = (
|
||||
self.ipv6Addrs if isinstance(data.ip, IPv6Interface) else self.ipv4Addrs
|
||||
)
|
||||
if data.deleted != (data.ip in ip_list):
|
||||
return # no change required
|
||||
if data.deleted:
|
||||
logger.info(f"{self.config.ifname}: deleted IP {data.ip}")
|
||||
ip_list.remove(data.ip) # type: ignore[arg-type]
|
||||
else:
|
||||
logger.info(f"{self.config.ifname}: discovered IP {data.ip}")
|
||||
ip_list.append(data.ip) # type: ignore[arg-type]
|
||||
set_prefix = f"{self.config.ifname}v{data.ip.version}"
|
||||
op = NftValueOperation.if_deleted(data.deleted)
|
||||
yield NftUpdate(
|
||||
obj_type="set",
|
||||
obj_name=f"{set_prefix}net",
|
||||
operation=op,
|
||||
values=(data.ip.network.compressed,),
|
||||
)
|
||||
link_local_space = IPv6Network("fc00::/7") # because ip.is_private is wrong
|
||||
if data.ip in link_local_space:
|
||||
logger.debug(
|
||||
f"{self.config.ifname}: only updated {set_prefix}net for changes in fc00::/7"
|
||||
)
|
||||
return
|
||||
yield NftUpdate(
|
||||
obj_type="set",
|
||||
obj_name=f"{set_prefix}addr",
|
||||
operation=op,
|
||||
values=(data.ip.ip.compressed,),
|
||||
)
|
||||
if data.ip.version != 6:
|
||||
return
|
||||
slaacs = {mac: slaac_eui48(data.ip.network, mac) for mac in self.config.macs}
|
||||
for mac in self.config.macs:
|
||||
yield NftUpdate(
|
||||
obj_type="set",
|
||||
obj_name=f"{set_prefix}_{mac}",
|
||||
operation=NftValueOperation.REPLACE,
|
||||
values=(slaacs[mac].ip.compressed,),
|
||||
)
|
||||
for proto in self.config.protocols:
|
||||
yield NftUpdate(
|
||||
obj_type="set",
|
||||
obj_name=f"{set_prefix}exp{proto.protocol}",
|
||||
operation=NftValueOperation.REPLACE,
|
||||
values=tuple(
|
||||
f"{slaacs[mac].ip.compressed} . {port}"
|
||||
for mac, portList in proto.exposed.items()
|
||||
for port in portList
|
||||
),
|
||||
)
|
||||
yield NftUpdate(
|
||||
obj_type="map",
|
||||
obj_name=f"{set_prefix}dnat{proto.protocol}",
|
||||
operation=NftValueOperation.REPLACE,
|
||||
values=tuple(
|
||||
f"{wan} : {slaacs[mac].ip.compressed} . {lan}"
|
||||
for mac, portMap in proto.forwarded.items()
|
||||
for wan, lan in portMap.items()
|
||||
),
|
||||
)
|
||||
|
||||
def gen_set_definitions(self) -> str:
|
||||
output = []
|
||||
for ip_v in [4, 6]:
|
||||
addr_type = f"ipv{ip_v}_addr"
|
||||
set_prefix = f"{self.config.ifname}v{ip_v}"
|
||||
output.append(gen_set_def("set", f"{set_prefix}addr", addr_type))
|
||||
output.append(gen_set_def("set", f"{set_prefix}net", addr_type, "interval"))
|
||||
if ip_v != 6:
|
||||
continue
|
||||
for mac in self.config.macs:
|
||||
output.append(gen_set_def("set", f"{set_prefix}_{mac}", addr_type))
|
||||
for proto in self.config.protocols:
|
||||
output.append(
|
||||
gen_set_def(
|
||||
"set",
|
||||
f"{set_prefix}exp{proto.protocol}",
|
||||
f"{addr_type} . inet_service",
|
||||
)
|
||||
)
|
||||
output.append(
|
||||
gen_set_def(
|
||||
"map",
|
||||
f"{set_prefix}dnat{proto.protocol}",
|
||||
f"inet_service : {addr_type} . inet_service",
|
||||
)
|
||||
)
|
||||
return "\n".join(output)
|
||||
|
||||
|
||||
def gen_set_def(
|
||||
set_type: str,
|
||||
name: str,
|
||||
data_type: str,
|
||||
flags: str | None = None,
|
||||
elements: Sequence[str] = tuple(),
|
||||
) -> str:
|
||||
return "\n".join(
|
||||
line
|
||||
for line in (
|
||||
f"{set_type} {name} " + "{",
|
||||
f" type {data_type}",
|
||||
f" flags {flags}" if flags is not None else None,
|
||||
" elements = { " + ", ".join(elements) + " }"
|
||||
if len(elements) > 0
|
||||
else None,
|
||||
"}",
|
||||
)
|
||||
if line is not None
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
class NftValueOperation(Enum):
|
||||
ADD = auto()
|
||||
DELETE = auto()
|
||||
REPLACE = auto()
|
||||
|
||||
@staticmethod
|
||||
def if_deleted(b: bool) -> NftValueOperation:
|
||||
return NftValueOperation.DELETE if b else NftValueOperation.ADD
|
||||
|
||||
|
||||
@define(
|
||||
frozen=True,
|
||||
kw_only=True,
|
||||
)
|
||||
class NftUpdate:
|
||||
obj_type: str
|
||||
obj_name: str
|
||||
operation: NftValueOperation
|
||||
values: Sequence[str]
|
||||
|
||||
def to_script(self, table: NftTable) -> str:
|
||||
lines = []
|
||||
# inet family is the only which supports shared IPv4 & IPv6 entries
|
||||
obj_id = f"inet {table} {self.obj_name}"
|
||||
if self.operation == NftValueOperation.REPLACE:
|
||||
lines.append(f"flush {self.obj_type} {obj_id}")
|
||||
if len(self.values) > 0:
|
||||
op_str = "destroy" if self.operation == NftValueOperation.DELETE else "add"
|
||||
values_str = ", ".join(self.values)
|
||||
lines.append(f"{op_str} element {obj_id} {{ {values_str} }}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
class NftUpdateHandler(UpdateStackHandler[NftUpdate]):
|
||||
def __init__(
|
||||
self,
|
||||
update_cmd: Sequence[str],
|
||||
table: NftTable,
|
||||
handler: UpdateHandler[None],
|
||||
) -> None:
|
||||
self.update_cmd = update_cmd
|
||||
self.table = table
|
||||
self.handler = handler
|
||||
|
||||
def _update_stack(self, data: Sequence[NftUpdate]) -> None:
|
||||
logger.debug("compile stacked updates for nftables")
|
||||
script = "\n".join(
|
||||
map(
|
||||
lambda u: u.to_script(table=self.table),
|
||||
data,
|
||||
)
|
||||
)
|
||||
logger.debug(f"pass updates to nftables:\n{script}")
|
||||
subprocess.run(
|
||||
list(self.update_cmd) + ["-f", "-"],
|
||||
input=script,
|
||||
check=True,
|
||||
text=True,
|
||||
)
|
||||
self.handler.update(None)
|
||||
|
||||
|
||||
class SystemdHandler(UpdateHandler[object]):
|
||||
def update(self, data: object) -> None:
|
||||
# TODO improve status updates
|
||||
# daemon.notify("READY=1\nSTATUS=Updated successfully.\n")
|
||||
daemon.notify("READY=1\n")
|
||||
|
||||
def update_stack(self, data: Sequence[object]) -> None:
|
||||
self.update(None)
|
||||
|
||||
|
||||
@define(
|
||||
frozen=True,
|
||||
kw_only=True,
|
||||
)
|
||||
class ProtocolConfig:
|
||||
protocol: NftProtocol
|
||||
exposed: Mapping[MACAddress, Sequence[Port]]
|
||||
"only when direct public IPs are available"
|
||||
forwarded: Mapping[MACAddress, Mapping[Port, Port]] # wan -> lan
|
||||
"i.e. DNAT"
|
||||
|
||||
@staticmethod
|
||||
def from_json(protocol: str, obj: JsonObj) -> ProtocolConfig:
|
||||
assert set(obj.keys()) <= set(("exposed", "forwarded"))
|
||||
exposed_raw = obj.get("exposed")
|
||||
exposed = defaultdict[MACAddress, list[Port]](list)
|
||||
if exposed_raw is not None:
|
||||
assert isinstance(exposed_raw, Sequence)
|
||||
for fwd in exposed_raw:
|
||||
assert isinstance(fwd, Mapping)
|
||||
dest = to_mac(fwd["dest"]) # type: ignore[arg-type]
|
||||
port = to_port(fwd["port"]) # type: ignore[arg-type]
|
||||
exposed[dest].append(port)
|
||||
forwarded_raw = obj.get("forwarded")
|
||||
forwarded = defaultdict[MACAddress, dict[Port, Port]](dict)
|
||||
if forwarded_raw is not None:
|
||||
assert isinstance(forwarded_raw, Sequence)
|
||||
for smap in forwarded_raw:
|
||||
assert isinstance(smap, Mapping)
|
||||
dest = to_mac(smap["dest"]) # type: ignore[arg-type]
|
||||
wanPort = to_port(smap["wanPort"]) # type: ignore[arg-type]
|
||||
lanPort = to_port(smap["lanPort"]) # type: ignore[arg-type]
|
||||
forwarded[dest][wanPort] = lanPort
|
||||
return ProtocolConfig(
|
||||
protocol=NftProtocol(protocol),
|
||||
exposed=exposed,
|
||||
forwarded=forwarded,
|
||||
)
|
||||
|
||||
|
||||
@define(
|
||||
frozen=True,
|
||||
kw_only=True,
|
||||
)
|
||||
class InterfaceConfig:
|
||||
ifname: IfName
|
||||
protocols: Sequence[ProtocolConfig]
|
||||
|
||||
@cached_property
|
||||
def macs(self) -> Sequence[MACAddress]:
|
||||
return tuple(
|
||||
set(
|
||||
chain(
|
||||
(mac for proto in self.protocols for mac in proto.exposed.keys()),
|
||||
(mac for proto in self.protocols for mac in proto.forwarded.keys()),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_json(ifname: str, obj: JsonObj) -> InterfaceConfig:
|
||||
assert set(obj.keys()) <= set(("ports",))
|
||||
ports = obj.get("ports")
|
||||
assert ports == None or isinstance(ports, Mapping)
|
||||
return InterfaceConfig(
|
||||
ifname=IfName(ifname),
|
||||
protocols=tuple()
|
||||
if ports == None
|
||||
else tuple(
|
||||
ProtocolConfig.from_json(proto, cast(JsonObj, proto_cfg))
|
||||
for proto, proto_cfg in ports.items() # type: ignore[union-attr]
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@define(
|
||||
frozen=True,
|
||||
kw_only=True,
|
||||
)
|
||||
class AppConfig:
|
||||
nft_table: NftTable
|
||||
interfaces: Sequence[InterfaceConfig]
|
||||
|
||||
@staticmethod
|
||||
def from_json(obj: JsonObj) -> AppConfig:
|
||||
assert set(obj.keys()) <= set(("interfaces", "nftTable"))
|
||||
nft_table = obj["nftTable"]
|
||||
assert isinstance(nft_table, str)
|
||||
interfaces = obj["interfaces"]
|
||||
assert isinstance(interfaces, Mapping)
|
||||
return AppConfig(
|
||||
nft_table=NftTable(nft_table),
|
||||
interfaces=tuple(
|
||||
InterfaceConfig.from_json(ifname, cast(JsonObj, if_cfg))
|
||||
for ifname, if_cfg in interfaces.items()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def read_config_file(path: Path) -> AppConfig:
|
||||
with path.open("r") as fh:
|
||||
json_data = json.load(fh)
|
||||
logger.debug(repr(json_data))
|
||||
return AppConfig.from_json(json_data)
|
||||
|
||||
|
||||
LOG_LEVEL_MAP = {
|
||||
"critical": logging.CRITICAL,
|
||||
"error": logging.ERROR,
|
||||
"warning": logging.WARNING,
|
||||
"info": logging.INFO,
|
||||
"debug": logging.DEBUG,
|
||||
}
|
||||
|
||||
|
||||
def _gen_if_updater(
|
||||
configs: Sequence[InterfaceConfig], nft_updater: UpdateHandler[NftUpdate]
|
||||
) -> Sequence[InterfaceUpdateHandler]:
|
||||
return tuple(
|
||||
InterfaceUpdateHandler(
|
||||
config=if_cfg,
|
||||
nft_handler=nft_updater,
|
||||
)
|
||||
for if_cfg in configs
|
||||
)
|
||||
|
||||
|
||||
def static_part_generation(config: AppConfig) -> None:
|
||||
dummy = IgnoreHandler()
|
||||
if_updater = _gen_if_updater(config.interfaces, dummy)
|
||||
for if_up in if_updater:
|
||||
print(if_up.gen_set_definitions())
|
||||
|
||||
|
||||
def service_execution(args: argparse.Namespace, config: AppConfig) -> NoReturn:
|
||||
nft_updater = NftUpdateHandler(
|
||||
table=config.nft_table,
|
||||
update_cmd=shlex.split(args.nft_command),
|
||||
handler=SystemdHandler(),
|
||||
)
|
||||
if_updater = _gen_if_updater(config.interfaces, nft_updater)
|
||||
burst_handler = UpdateBurstHandler[IpAddressUpdate](
|
||||
burst_interval=0.5,
|
||||
handler=if_updater,
|
||||
)
|
||||
monitor_ip(shlex.split(args.ip_command), burst_handler)
|
||||
|
||||
|
||||
def setup_logging(args: Any) -> None:
|
||||
systemd_service = os.environ.get("INVOCATION_ID") and Path("/dev/log").exists()
|
||||
if systemd_service:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(JournalHandler(SYSLOG_IDENTIFIER="nft-update-addresses"))
|
||||
else:
|
||||
logging.basicConfig() # get output to stdout/stderr
|
||||
logger.setLevel(LOG_LEVEL_MAP[args.log_level])
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("-c", "--config-file", required=True)
|
||||
parser.add_argument("--check-config", action="store_true")
|
||||
parser.add_argument("--output-set-definitions", action="store_true")
|
||||
parser.add_argument("--ip-command", default="/usr/bin/env ip")
|
||||
parser.add_argument("--nft-command", default="/usr/bin/env nft")
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--log-level",
|
||||
default="error",
|
||||
choices=LOG_LEVEL_MAP.keys(),
|
||||
help="Log level for outputs to stdout/stderr (ignored when launched in a systemd service)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
setup_logging(args)
|
||||
config = read_config_file(Path(args.config_file))
|
||||
if args.check_config:
|
||||
return
|
||||
if args.output_set_definitions:
|
||||
return static_part_generation(config)
|
||||
service_execution(args, config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in New Issue