#!/usr/bin/env python3 from __future__ import annotations from abc import ( ABC, abstractmethod, ) import argparse from collections import defaultdict from collections.abc import ( Mapping, Sequence, ) from enum import ( Enum, Flag, auto, ) from functools import cached_property import io from ipaddress import ( IPv4Interface, IPv6Interface, IPv6Network, ) from itertools import chain import json import logging from logging.handlers import SysLogHandler import os from pathlib import Path import re import shlex import subprocess import threading from threading import ( RLock, Timer, ) import traceback from typing import ( Any, Iterable, NewType, NoReturn, Protocol, TypeAlias, TypeGuard, TypeVar, Union, cast, ) from attrs import ( define, field, ) from systemd import daemon # type: ignore[import-untyped] from systemd.journal import JournalHandler # type: ignore[import-untyped] logger = logging.getLogger(__name__) def raise_and_exit(args: Any) -> None: Timer(0.01, os._exit, args=(1,)).start() raise args[0] # ensure exceptions in any thread brings the program down # important for proper error detection via tests & in random cases in real world threading.excepthook = raise_and_exit JsonVal: TypeAlias = Union["JsonObj", "JsonList", str, int, bool] JsonList: TypeAlias = Sequence[JsonVal] JsonObj: TypeAlias = Mapping[str, JsonVal] T = TypeVar("T", contravariant=True) MACAddress = NewType("MACAddress", str) # format: aabbccddeeff (lower-case, without separators) NftProtocol = NewType("NftProtocol", str) # e.g. 
tcp, udp, … Port = NewType("Port", int) IfName = NewType("IfName", str) NftTable = NewType("NftTable", str) def to_mac(mac_str: str) -> MACAddress: eui48 = re.sub(r"[.:_-]", "", mac_str.lower()) if not is_mac(eui48): raise ValueError(f"invalid MAC address / EUI48: {mac_str}") return MACAddress(eui48) def is_mac(mac_str: str) -> TypeGuard[MACAddress]: return re.match(r"^[0-9a-f]{12}$", mac_str) != None def to_port(port_str: str | int) -> Port: try: port = int(port_str) except ValueError as exc: raise ValueError(f"invalid port number: {port_str}") from exc if not is_port(port): raise ValueError(f"invalid port number: {port_str}") return Port(port) def is_port(port: int) -> TypeGuard[Port]: return 0 < port < 65536 def slaac_eui48(prefix: IPv6Network, eui48: MACAddress) -> IPv6Interface: if prefix.prefixlen > 64: raise ValueError( f"a SLAAC IPv6 address requires a CIDR of at least /64, got {prefix}" ) eui64 = eui48[0:6] + "fffe" + eui48[6:] modified = hex(int(eui64[0:2], 16) ^ 2)[2:].zfill(2) + eui64[2:] euil = int(modified, 16) return IPv6Interface(f"{prefix[euil].compressed}/{prefix.prefixlen}") class UpdateHandler(Protocol[T]): def update(self, data: T) -> None: ... def update_stack(self, data: Sequence[T]) -> None: ... class UpdateStackHandler(UpdateHandler[T], ABC): def update(self, data: T) -> None: return self._update_stack((data,)) def update_stack(self, data: Sequence[T]) -> None: if len(data) <= 0: logger.warning( f"[bug, please report upstream] received empty data in update_stack. Traceback:\n{''.join(traceback.format_stack())}" ) return return self._update_stack(data) @abstractmethod def _update_stack(self, data: Sequence[T]) -> None: ... 
class IgnoreHandler(UpdateStackHandler[object]):
    """Handler that discards every update it receives."""

    def _update_stack(self, data: Sequence[object]) -> None:
        return


@define(
    kw_only=True,
    slots=False,
)
class UpdateBurstHandler(UpdateStackHandler[T]):
    """Collect updates and forward them as one stack after a quiet period.

    Every incoming update (re)starts a timer of ``burst_interval`` seconds.
    When the timer fires, all updates collected so far are passed to each
    entry of ``handler`` via ``update_stack``.
    """

    burst_interval: float
    handler: Sequence[UpdateHandler[T]]
    __lock: RLock = field(factory=RLock)
    __updates: list[T] = field(factory=list)
    __timer: Timer | None = None

    def _update_stack(self, data: Sequence[T]) -> None:
        """Queue *data* and postpone processing by another burst interval."""
        with self.__lock:
            self.__updates.extend(data)
            self.__refresh_timer()

    def __refresh_timer(self) -> None:
        """Replace the pending timer (if any) with a freshly started one."""
        with self.__lock:
            if self.__timer is not None:
                # Try to cancel the pending timer. It is not a problem if it
                # already elapsed and its processing is about to start,
                # because every access to the queued updates is guarded by
                # the lock.
                self.__timer.cancel()
            self.__timer = Timer(
                interval=self.burst_interval,
                function=self.__process_updates,
            )
            self.__timer.start()

    def __process_updates(self) -> None:
        """Hand all queued updates to the configured handlers."""
        with self.__lock:
            self.__timer = None
            if not self.__updates:
                return
            updates = self.__updates
            self.__updates = []
            # dispatched while holding the lock, so stacks are delivered one
            # at a time and in the order they were collected
            for handler in self.handler:
                handler.update_stack(updates)


class IpFlag(Flag):
    """Address flags extracted from the output of the ``ip`` command."""

    dynamic = auto()
    mngtmpaddr = auto()
    noprefixroute = auto()
    temporary = auto()
    tentative = auto()
    # Misspelled former name, kept as an alias for backward compatibility.
    # With only the misspelled member, the flag "tentative" was never matched
    # by parse_str and tentative addresses were not ignored.
    tentiative = tentative

    @staticmethod
    def parse_str(flags_str: Sequence[str], ignore_unknown: bool = True) -> IpFlag:
        """Combine the given flag names (case-insensitive) into one ``IpFlag``.

        :param flags_str: individual flag names.
        :param ignore_unknown: skip names without a matching member instead
            of raising.
        :raises Exception: on an unknown name if *ignore_unknown* is false.
        """
        flags = IpFlag(0)
        for flag in flags_str:
            flag = flag.lower()
            member = IpFlag.__members__.get(flag)
            if member is not None:
                flags |= member
            elif not ignore_unknown:
                raise Exception(f"Unrecognized IpFlag: {flag}")
        return flags


# Matches one line as fed to IpAddressUpdate.parse_line, i.e. the output of
# "ip -o address show" and "ip -o monitor address".
# The group names correspond to the keys read in IpAddressUpdate.parse_line.
IP_MON_PATTERN = re.compile(
    r"""(?x)^
    (?P<deleted>[Dd]eleted\s+)?
    (?P<ifindex>\d+):\s+
    (?P<ifname>\S+)\s+
    (?P<type>inet6?)\s+
    (?P<ip>\S+)\s+
    #(?:metric\s+\S+\s+)? # sometimes possible
    #(?:brd\s+\S+\s+)? # broadcast IP on inet
    (?:\S+\s+\S+\s+)* # abstracted irrelevant attributes
    (?:scope\s+(?P<scope>\S+)\s+)
    (?P<flags>(?:(\S+)\s)*) # (single spaces required for parser below to work correctly)
    (?:\S+)? # random interface name repetition on inet
    [\\]\s+
    .* # lifetimes which are not interesting yet
    $"""
)


@define(
    frozen=True,
    kw_only=True,
)
class IpAddressUpdate:
    """A single address change (or current address) of an interface."""

    deleted: bool
    ifindex: int
    ifname: IfName
    ip: IPv4Interface | IPv6Interface
    scope: str
    flags: IpFlag

    @staticmethod
    def parse_line(line: str) -> IpAddressUpdate:
        """Parse one output line of ``ip`` (see ``IP_MON_PATTERN``).

        :raises Exception: if the line does not match the pattern or contains
            an invalid IP address.
        """
        m = IP_MON_PATTERN.search(line)
        if not m:
            raise Exception(f"Could not parse ip monitor output: {line!r}")
        grp = m.groupdict()
        ip_type: type[IPv4Interface | IPv6Interface] = (
            IPv6Interface if grp["type"] == "inet6" else IPv4Interface
        )
        try:
            ip = ip_type(grp["ip"])
        except ValueError as e:
            raise Exception(
                f"Could not parse ip monitor output, invalid IP: {grp['ip']!r}"
            ) from e
        flags = IpFlag.parse_str(grp["flags"].strip().split(" "))
        return IpAddressUpdate(
            # the optional group is None when the line has no "Deleted" prefix
            deleted=grp["deleted"] is not None,
            ifindex=int(grp["ifindex"]),
            ifname=IfName(grp["ifname"]),
            ip=ip,
            scope=grp["scope"],
            flags=flags,
        )


def monitor_ip(
    ip_cmd: list[str],
    handler: UpdateHandler[IpAddressUpdate],
) -> NoReturn:
    """Feed the current addresses and all following changes to *handler*.

    Never returns normally.

    :param ip_cmd: command line prefix used to invoke ``ip``.
    :param handler: receives one ``IpAddressUpdate`` per parsed line.
    :raises Exception: when the monitor process terminates or a line cannot
        be parsed.
    :raises subprocess.CalledProcessError: when listing the current addresses
        fails.
    """
    proc = subprocess.Popen(
        ip_cmd + ["-o", "monitor", "address"],
        stdout=subprocess.PIPE,
        text=True,
    )
    monitor_out = proc.stdout
    assert monitor_out is not None  # guaranteed by stdout=subprocess.PIPE
    # initial kickoff (AFTER starting monitoring, to not miss any update)
    logger.info("kickoff IP monitoring with current data")
    res = subprocess.run(
        ip_cmd + ["-o", "address", "show"],
        check=True,
        stdout=subprocess.PIPE,
        text=True,
    )
    for line in res.stdout.splitlines(keepends=False):
        line = line.rstrip()
        if line == "":
            continue
        update = IpAddressUpdate.parse_line(line)
        logger.debug(f"pass IP update: {update!r}")
        handler.update(update)
    logger.info("loading kickoff finished; start regular monitoring")
    while True:
        rc = proc.poll()
        if rc is not None:
            # flush stdout for easier debugging
            logger.error("Last stdout of monitor process:")
            logger.error(monitor_out.read())
            raise Exception(f"Monitor process crashed with returncode {rc}")
        line = monitor_out.readline().rstrip()
        if not line:
            continue
        logger.info("IP change detected")
        update = IpAddressUpdate.parse_line(line)
        logger.debug(f"pass IP update: {update!r}")
        handler.update(update)


class InterfaceUpdateHandler(UpdateStackHandler[IpAddressUpdate]):
    """Translate address updates of one interface into nftables updates.

    Keeps track of the addresses currently known for the configured
    interface and only emits updates for actual changes.
    """

    # TODO regularly check (i.e. 1 hour) if stored lists are still correct

    def __init__(
        self,
        config: InterfaceConfig,
        nft_handler: UpdateHandler[NftUpdate],
    ) -> None:
        """
        :param config: interface to watch and the ports configured for it.
        :param nft_handler: receives the generated ``NftUpdate`` objects.
        """
        self.nft_handler = nft_handler
        self.lock = RLock()
        self.config = config
        self.ipv4Addrs = list[IPv4Interface]()
        self.ipv6Addrs = list[IPv6Interface]()

    def _update_stack(self, data: Sequence[IpAddressUpdate]) -> None:
        """Forward all nftables updates resulting from *data* as one stack."""
        nft_updates = tuple(
            chain.from_iterable(self.__gen_updates(single) for single in data)
        )
        if not nft_updates:
            return
        self.nft_handler.update_stack(nft_updates)

    def __gen_updates(self, data: IpAddressUpdate) -> Iterable[NftUpdate]:
        """Yield the nftables updates caused by a single address update.

        Updates of other interfaces, of link-local, temporary or tentative
        addresses, and updates which do not change the known address list
        yield nothing.
        """
        if data.ifname != self.config.ifname:
            return
        if data.ip.is_link_local:
            logger.debug(
                f"{self.config.ifname}: ignore change for IP {data.ip} because link-local"
            )
            return
        if IpFlag.temporary in data.flags:
            logger.debug(
                f"{self.config.ifname}: ignore change for IP {data.ip} because temporary"
            )
            return  # ignore IPv6 privacy extension addresses
        if IpFlag.tentative in data.flags:
            logger.debug(
                f"{self.config.ifname}: ignore change for IP {data.ip} because tentative"
            )
            return  # ignore (yet) tentative addresses
        logger.debug(f"{self.config.ifname}: process change of IP {data.ip}")
        with self.lock:
            ip_list: list[IPv4Interface] | list[IPv6Interface] = (
                self.ipv6Addrs if isinstance(data.ip, IPv6Interface) else self.ipv4Addrs
            )
            # deleting an unknown or adding an already known address
            if data.deleted != (data.ip in ip_list):
                return  # no change required
            if data.deleted:
                logger.info(f"{self.config.ifname}: deleted IP {data.ip}")
                ip_list.remove(data.ip)  # type: ignore[arg-type]
            else:
                logger.info(f"{self.config.ifname}: discovered IP {data.ip}")
                ip_list.append(data.ip)  # type: ignore[arg-type]
        set_prefix = f"{self.config.ifname}v{data.ip.version}"
        op = NftValueOperation.if_deleted(data.deleted)
        yield NftUpdate(
            obj_type="set",
            obj_name=f"{set_prefix}net",
            operation=op,
            values=(data.ip.network.compressed,),
        )
        unique_local_space = IPv6Network("fc00::/7")  # because ip.is_private is wrong
        if data.ip in unique_local_space:
            logger.debug(
                f"{self.config.ifname}: only updated {set_prefix}net for changes in fc00::/7"
            )
            return
        yield NftUpdate(
            obj_type="set",
            obj_name=f"{set_prefix}addr",
            operation=op,
            values=(data.ip.ip.compressed,),
        )
        if data.ip.version != 6:
            return
        # NOTE(review): slaac_eui48 raises ValueError for prefixes longer than
        # /64, and the REPLACE updates below are emitted for deleted addresses
        # as well -- confirm both are intended.
        slaacs = {mac: slaac_eui48(data.ip.network, mac) for mac in self.config.macs}
        for mac in self.config.macs:
            yield NftUpdate(
                obj_type="set",
                obj_name=f"{set_prefix}_{mac}",
                operation=NftValueOperation.REPLACE,
                values=(slaacs[mac].ip.compressed,),
            )
        for proto in self.config.protocols:
            yield NftUpdate(
                obj_type="set",
                obj_name=f"{set_prefix}exp{proto.protocol}",
                operation=NftValueOperation.REPLACE,
                values=tuple(
                    f"{slaacs[mac].ip.compressed} . {port}"
                    for mac, portList in proto.exposed.items()
                    for port in portList
                ),
            )
            yield NftUpdate(
                obj_type="map",
                obj_name=f"{set_prefix}dnat{proto.protocol}",
                operation=NftValueOperation.REPLACE,
                values=tuple(
                    f"{wan} : {slaacs[mac].ip.compressed} . {lan}"
                    for mac, portMap in proto.forwarded.items()
                    for wan, lan in portMap.items()
                ),
            )

    def gen_set_definitions(self) -> str:
        """Return empty definitions of all sets and maps this handler updates."""
        output = []
        for ip_v in [4, 6]:
            addr_type = f"ipv{ip_v}_addr"
            set_prefix = f"{self.config.ifname}v{ip_v}"
            output.append(gen_set_def("set", f"{set_prefix}addr", addr_type))
            output.append(gen_set_def("set", f"{set_prefix}net", addr_type, "interval"))
            if ip_v != 6:
                continue
            # per-MAC and per-protocol objects are only maintained for IPv6
            for mac in self.config.macs:
                output.append(gen_set_def("set", f"{set_prefix}_{mac}", addr_type))
            for proto in self.config.protocols:
                output.append(
                    gen_set_def(
                        "set",
                        f"{set_prefix}exp{proto.protocol}",
                        f"{addr_type} . inet_service",
                    )
                )
                output.append(
                    gen_set_def(
                        "map",
                        f"{set_prefix}dnat{proto.protocol}",
                        f"inet_service : {addr_type} . inet_service",
                    )
                )
        return "\n".join(output)


def gen_set_def(
    set_type: str,
    name: str,
    data_type: str,
    flags: str | None = None,
    elements: Sequence[str] = tuple(),
) -> str:
    """Render the definition of a set or map as text.

    :param set_type: keyword opening the definition (``set`` or ``map``).
    :param name: name of the object.
    :param data_type: value for the ``type`` line.
    :param flags: value for the ``flags`` line; omitted when ``None``.
    :param elements: initial elements; the line is omitted when empty.
    """
    return "\n".join(
        line
        for line in (
            f"{set_type} {name} " + "{",
            f" type {data_type}",
            f" flags {flags}" if flags is not None else None,
            " elements = { " + ", ".join(elements) + " }"
            if len(elements) > 0
            else None,
            "}",
        )
        if line is not None
    )


class NftValueOperation(Enum):
    """How the values of an ``NftUpdate`` are applied to its object."""

    ADD = auto()
    DELETE = auto()
    REPLACE = auto()

    @staticmethod
    def if_deleted(b: bool) -> NftValueOperation:
        """Return ``DELETE`` if *b* is true, otherwise ``ADD``."""
        return NftValueOperation.DELETE if b else NftValueOperation.ADD


@define(
    frozen=True,
    kw_only=True,
)
class NftUpdate:
    """A change of the elements of one nftables set or map."""

    obj_type: str
    obj_name: str
    operation: NftValueOperation
    values: Sequence[str]

    def to_script(self, table: NftTable) -> str:
        """Render this update as commands for the given table.

        ``REPLACE`` flushes the object before adding the values, ``DELETE``
        uses ``destroy element`` and everything else ``add element``. Without
        values, no element command is generated.
        """
        lines = []
        # inet family is the only which supports shared IPv4 & IPv6 entries
        obj_id = f"inet {table} {self.obj_name}"
        if self.operation == NftValueOperation.REPLACE:
            lines.append(f"flush {self.obj_type} {obj_id}")
        if len(self.values) > 0:
            op_str = "destroy" if self.operation == NftValueOperation.DELETE else "add"
            values_str = ", ".join(self.values)
            lines.append(f"{op_str} element {obj_id} {{ {values_str} }}")
        return "\n".join(lines)


class NftUpdateHandler(UpdateStackHandler[NftUpdate]):
    """Apply stacks of ``NftUpdate`` objects by running the update command."""

    def __init__(
        self,
        update_cmd: Sequence[str],
        table: NftTable,
        handler: UpdateHandler[None],
    ) -> None:
        """
        :param update_cmd: command line prefix; ``-f -`` is appended and the
            script is passed on stdin.
        :param table: table the updated objects live in.
        :param handler: notified with ``None`` after each successful run.
        """
        self.update_cmd = update_cmd
        self.table = table
        self.handler = handler

    def _update_stack(self, data: Sequence[NftUpdate]) -> None:
        """Compile *data* into one script and run the update command on it.

        :raises subprocess.CalledProcessError: if the command fails.
        """
        logger.debug("compile stacked updates for nftables")
        script = "\n".join(update.to_script(table=self.table) for update in data)
        logger.debug(f"pass updates to nftables:\n{script}")
        subprocess.run(
            list(self.update_cmd) + ["-f", "-"],
            input=script,
            check=True,
            text=True,
        )
        self.handler.update(None)


class SystemdHandler(UpdateHandler[object]):
    """Send ``READY=1`` via ``daemon.notify`` on every update."""

    def update(self, data: object) -> None:
        # TODO improve status updates
        # daemon.notify("READY=1\nSTATUS=Updated successfully.\n")
        daemon.notify("READY=1\n")

    def update_stack(self, data: Sequence[object]) -> None:
        self.update(None)


@define(
    frozen=True,
    kw_only=True,
)
class ProtocolConfig:
    """Exposed and forwarded ports of one protocol, grouped by destination MAC."""

    protocol: NftProtocol
    exposed: Mapping[MACAddress, Sequence[Port]]
    "only when direct public IPs are available"
    forwarded: Mapping[MACAddress, Mapping[Port, Port]]  # wan -> lan
    "i.e. DNAT"

    @staticmethod
    def from_json(protocol: str, obj: JsonObj) -> ProtocolConfig:
        """Build the configuration of *protocol* from its JSON object.

        Accepts the optional keys ``exposed`` (entries with ``dest`` and
        ``port``) and ``forwarded`` (entries with ``dest``, ``wanPort`` and
        ``lanPort``).

        :raises AssertionError: on unknown keys or unexpected value types.
        :raises ValueError: on invalid MAC addresses or port numbers.
        """
        assert set(obj.keys()) <= set(("exposed", "forwarded"))
        exposed_raw = obj.get("exposed")
        exposed = defaultdict[MACAddress, list[Port]](list)
        if exposed_raw is not None:
            assert isinstance(exposed_raw, Sequence)
            for fwd in exposed_raw:
                assert isinstance(fwd, Mapping)
                dest = to_mac(fwd["dest"])  # type: ignore[arg-type]
                port = to_port(fwd["port"])  # type: ignore[arg-type]
                exposed[dest].append(port)
        forwarded_raw = obj.get("forwarded")
        forwarded = defaultdict[MACAddress, dict[Port, Port]](dict)
        if forwarded_raw is not None:
            assert isinstance(forwarded_raw, Sequence)
            for smap in forwarded_raw:
                assert isinstance(smap, Mapping)
                dest = to_mac(smap["dest"])  # type: ignore[arg-type]
                wanPort = to_port(smap["wanPort"])  # type: ignore[arg-type]
                lanPort = to_port(smap["lanPort"])  # type: ignore[arg-type]
                forwarded[dest][wanPort] = lanPort
        return ProtocolConfig(
            protocol=NftProtocol(protocol),
            exposed=exposed,
            forwarded=forwarded,
        )


@define(
    frozen=True,
    kw_only=True,
)
class InterfaceConfig:
    """Configuration of one interface: its name and per-protocol port setup."""

    ifname: IfName
    protocols: Sequence[ProtocolConfig]

    @cached_property
    def macs(self) -> Sequence[MACAddress]:
        """All MAC addresses referenced by any protocol, without duplicates.

        Sorted, so that everything derived from it (e.g. the generated set
        definitions) has a reproducible order.
        """
        return tuple(
            sorted(
                {
                    mac
                    for proto in self.protocols
                    for mac in chain(proto.exposed.keys(), proto.forwarded.keys())
                }
            )
        )

    @staticmethod
    def from_json(ifname: str, obj: JsonObj) -> InterfaceConfig:
        """Build the configuration of *ifname* from its JSON object.

        The only accepted key is the optional ``ports``, mapping protocol
        names to the objects understood by ``ProtocolConfig.from_json``.

        :raises AssertionError: on unknown keys or unexpected value types.
        """
        assert set(obj.keys()) <= set(("ports",))
        ports = obj.get("ports")
        assert ports is None or isinstance(ports, Mapping)
        return InterfaceConfig(
            ifname=IfName(ifname),
            protocols=tuple()
            if ports is None
            else tuple(
                ProtocolConfig.from_json(proto, cast(JsonObj, proto_cfg))
                for proto, proto_cfg in ports.items()
            ),
        )


@define(
    frozen=True,
    kw_only=True,
)
class AppConfig:
    """Whole application configuration as read from the config file."""

    nft_table: NftTable
    interfaces: Sequence[InterfaceConfig]

    @staticmethod
    def from_json(obj: JsonObj) -> AppConfig:
        """Build the configuration from the top-level JSON object.

        Requires the keys ``nftTable`` (string) and ``interfaces`` (mapping
        of interface names to interface objects).

        :raises AssertionError: on unknown keys or unexpected value types.
        :raises KeyError: if a required key is missing.
        """
        assert set(obj.keys()) <= set(("interfaces", "nftTable"))
        nft_table = obj["nftTable"]
        assert isinstance(nft_table, str)
        interfaces = obj["interfaces"]
        assert isinstance(interfaces, Mapping)
        return AppConfig(
            nft_table=NftTable(nft_table),
            interfaces=tuple(
                InterfaceConfig.from_json(ifname, cast(JsonObj, if_cfg))
                for ifname, if_cfg in interfaces.items()
            ),
        )


def read_config_file(path: Path) -> AppConfig:
    """Load and parse the JSON configuration file at *path*."""
    with path.open("r") as fh:
        json_data = json.load(fh)
    logger.debug(repr(json_data))
    return AppConfig.from_json(json_data)


# names accepted by --log-level and the logging levels they select
LOG_LEVEL_MAP = {
    "critical": logging.CRITICAL,
    "error": logging.ERROR,
    "warning": logging.WARNING,
    "info": logging.INFO,
    "debug": logging.DEBUG,
}


def _gen_if_updater(
    configs: Sequence[InterfaceConfig], nft_updater: UpdateHandler[NftUpdate]
) -> Sequence[InterfaceUpdateHandler]:
    """Create one ``InterfaceUpdateHandler`` per interface configuration."""
    return tuple(
        InterfaceUpdateHandler(
            config=if_cfg,
            nft_handler=nft_updater,
        )
        for if_cfg in configs
    )


def static_part_generation(config: AppConfig) -> None:
    """Print the set and map definitions of all configured interfaces."""
    dummy = IgnoreHandler()  # definitions are generated without any updates
    if_updater = _gen_if_updater(config.interfaces, dummy)
    for if_up in if_updater:
        print(if_up.gen_set_definitions())


def service_execution(args: argparse.Namespace, config: AppConfig) -> NoReturn:
    """Wire up all handlers and monitor address changes; never returns.

    Updates flow from the monitor through a burst handler (0.5 s interval)
    to the interface handlers, then to the nftables handler, which notifies
    systemd after each applied script.
    """
    nft_updater = NftUpdateHandler(
        table=config.nft_table,
        update_cmd=shlex.split(args.nft_command),
        handler=SystemdHandler(),
    )
    if_updater = _gen_if_updater(config.interfaces, nft_updater)
    burst_handler = UpdateBurstHandler[IpAddressUpdate](
        burst_interval=0.5,
        handler=if_updater,
    )
    monitor_ip(shlex.split(args.ip_command), burst_handler)


def setup_logging(args: Any) -> None:
    """Configure logging for the journal or for stdout/stderr.

    When ``INVOCATION_ID`` is set and ``/dev/log`` exists, everything down to
    debug level goes to the journal. Otherwise ``logging.basicConfig`` is
    used with the level selected by ``args.log_level``.
    """
    systemd_service = os.environ.get("INVOCATION_ID") and Path("/dev/log").exists()
    if systemd_service:
        logger.setLevel(logging.DEBUG)
        logger.addHandler(JournalHandler(SYSLOG_IDENTIFIER="nft-update-addresses"))
    else:
        logging.basicConfig()  # get output to stdout/stderr
        logger.setLevel(LOG_LEVEL_MAP[args.log_level])


def main() -> None:
    """Parse the command line and run the selected mode.

    ``--check-config`` only loads the configuration,
    ``--output-set-definitions`` prints the set definitions, and without
    either the monitoring service is started.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config-file", required=True)
    parser.add_argument("--check-config", action="store_true")
    parser.add_argument("--output-set-definitions", action="store_true")
    parser.add_argument("--ip-command", default="/usr/bin/env ip")
    parser.add_argument("--nft-command", default="/usr/bin/env nft")
    parser.add_argument(
        "-l",
        "--log-level",
        default="error",
        choices=LOG_LEVEL_MAP.keys(),
        help="Log level for outputs to stdout/stderr (ignored when launched in a systemd service)",
    )
    args = parser.parse_args()
    setup_logging(args)
    config = read_config_file(Path(args.config_file))
    if args.check_config:
        return
    if args.output_set_definitions:
        return static_part_generation(config)
    service_execution(args, config)


if __name__ == "__main__":
    main()