#!/usr/bin/env python3
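"""Keep nftables sets and maps in sync with the IP addresses of local interfaces.

Summary of the code below: the script follows `ip monitor address`, filters the
address changes of the configured interfaces, derives SLAAC (EUI-64) addresses for
configured MACs from announced IPv6 prefixes, and writes the resulting elements into
per-interface nftables sets and DNAT maps. It can also print the static set
definitions and validate its JSON configuration file.
"""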

from __future__ import annotations

from abc import (
    ABC,
    abstractmethod,
)
import argparse
from collections import defaultdict
from collections.abc import (
    Mapping,
    Sequence,
)
from enum import (
    Enum,
    Flag,
    auto,
)
from functools import cached_property
import io
from ipaddress import (
    IPv4Interface,
    IPv6Interface,
    IPv6Network,
)
from itertools import chain
import json
import logging
from logging.handlers import SysLogHandler
import os
from pathlib import Path
import re
import shlex
import subprocess
import threading
from threading import (
    RLock,
    Timer,
)
import traceback
from typing import (
    Any,
    Iterable,
    NewType,
    NoReturn,
    Protocol,
    TypeAlias,
    TypeGuard,
    TypeVar,
    Union,
    cast,
)

from attrs import (
    define,
    field,
)
from systemd import daemon  # type: ignore[import-untyped]
from systemd.journal import JournalHandler  # type: ignore[import-untyped]


logger = logging.getLogger(__name__)


def raise_and_exit(args: Any) -> None:
    # force the whole process to exit shortly afterwards, so the re-raised
    # exception still gets reported first
    Timer(0.01, os._exit, args=(1,)).start()
    raise args[0]


# ensure exceptions in any thread bring the program down
# important for proper error detection via tests & in random real-world cases

threading.excepthook = raise_and_exit


JsonVal: TypeAlias = Union["JsonObj", "JsonList", str, int, bool]
JsonList: TypeAlias = Sequence[JsonVal]
JsonObj: TypeAlias = Mapping[str, JsonVal]

T = TypeVar("T", contravariant=True)

MACAddress = NewType("MACAddress", str)
# format: aabbccddeeff (lower-case, without separators)
NftProtocol = NewType("NftProtocol", str)  # e.g. tcp, udp, …
Port = NewType("Port", int)
IfName = NewType("IfName", str)
NftTable = NewType("NftTable", str)


def to_mac(mac_str: str) -> MACAddress:
    eui48 = re.sub(r"[.:_-]", "", mac_str.lower())
    if not is_mac(eui48):
        raise ValueError(f"invalid MAC address / EUI48: {mac_str}")
    return MACAddress(eui48)


def is_mac(mac_str: str) -> TypeGuard[MACAddress]:
    return re.match(r"^[0-9a-f]{12}$", mac_str) is not None


def to_port(port_str: str | int) -> Port:
    try:
        port = int(port_str)
    except ValueError as exc:
        raise ValueError(f"invalid port number: {port_str}") from exc
    if not is_port(port):
        raise ValueError(f"invalid port number: {port_str}")
    return Port(port)


def is_port(port: int) -> TypeGuard[Port]:
    return 0 < port < 65536
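

# SLAAC / EUI-64 derivation used below, worked example (not executed at runtime):
# MAC aabbccddeeff -> EUI-64 "aabbcc" + "fffe" + "ddeeff", universal/local bit
# flipped -> a8bbccfffeddeeff; combined with the prefix 2001:db8::/64 this gives
# 2001:db8::a8bb:ccff:fedd:eeff, which is exactly what slaac_eui48() computes.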
def slaac_eui48(prefix: IPv6Network, eui48: MACAddress) -> IPv6Interface:
    if prefix.prefixlen > 64:
        raise ValueError(
            f"a SLAAC IPv6 address requires a CIDR of at least /64, got {prefix}"
        )
    eui64 = eui48[0:6] + "fffe" + eui48[6:]
    modified = hex(int(eui64[0:2], 16) ^ 2)[2:].zfill(2) + eui64[2:]
    euil = int(modified, 16)
    return IPv6Interface(f"{prefix[euil].compressed}/{prefix.prefixlen}")


class UpdateHandler(Protocol[T]):
    def update(self, data: T) -> None:
        ...

    def update_stack(self, data: Sequence[T]) -> None:
        ...


class UpdateStackHandler(UpdateHandler[T], ABC):
    def update(self, data: T) -> None:
        return self._update_stack((data,))

    def update_stack(self, data: Sequence[T]) -> None:
        if len(data) <= 0:
            logger.warning(
                f"[bug, please report upstream] received empty data in update_stack. Traceback:\n{''.join(traceback.format_stack())}"
            )
            return
        return self._update_stack(data)

    @abstractmethod
    def _update_stack(self, data: Sequence[T]) -> None:
        ...


class IgnoreHandler(UpdateStackHandler[object]):
    def _update_stack(self, data: Sequence[object]) -> None:
        return
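

# Collects updates that arrive in quick succession and forwards them as one batch
# once burst_interval seconds have passed without a new update (each update restarts
# the timer), so nftables is not reloaded for every single address change.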
@define(
    kw_only=True,
    slots=False,
)
class UpdateBurstHandler(UpdateStackHandler[T]):
    burst_interval: float
    handler: Sequence[UpdateHandler[T]]
    __lock: RLock = field(factory=RLock)
    __updates: list[T] = field(factory=list)
    __timer: Timer | None = None

    def _update_stack(self, data: Sequence[T]) -> None:
        with self.__lock:
            self.__updates.extend(data)
            self.__refresh_timer()

    def __refresh_timer(self) -> None:
        with self.__lock:
            if self.__timer is not None:
                # try to cancel the pending timer;
                # it is fine if it already elapsed but processing has not started
                # yet, because the updates are only accessed under the lock
                self.__timer.cancel()
            self.__timer = Timer(
                interval=self.burst_interval,
                function=self.__process_updates,
            )
            self.__timer.start()

    def __process_updates(self) -> None:
        with self.__lock:
            self.__timer = None
            if not self.__updates:
                return
            updates = self.__updates
            self.__updates = []
            for handler in self.handler:
                handler.update_stack(updates)


class IpFlag(Flag):
    dynamic = auto()
    mngtmpaddr = auto()
    noprefixroute = auto()
    temporary = auto()
    tentative = auto()

    @staticmethod
    def parse_str(flags_str: Sequence[str], ignore_unknown: bool = True) -> IpFlag:
        flags = IpFlag(0)
        for flag in flags_str:
            flag = flag.lower()
            member = IpFlag.__members__.get(flag)
            if member is not None:
                flags |= member
            elif not ignore_unknown:
                raise Exception(f"Unrecognized IpFlag: {flag}")
        return flags
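

# Example of a line this pattern is meant to parse (illustrative, whitespace shortened):
#   2: eth0    inet6 2001:db8::1/64 scope global dynamic mngtmpaddr noprefixroute \       valid_lft 86400sec preferred_lft 14400sec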
IP_MON_PATTERN = re.compile(
    r"""(?x)^
    (?P<deleted>[Dd]eleted\s+)?
    (?P<ifindex>\d+):\s+
    (?P<ifname>\S+)\s+
    (?P<type>inet6?)\s+
    (?P<ip>\S+)\s+
    #(?:metric\s+\S+\s+)? # sometimes possible
    #(?:brd\s+\S+\s+)? # broadcast IP on inet
    (?:\S+\s+\S+\s+)* # abstracted irrelevant attributes
    (?:scope\s+(?P<scope>\S+)\s+)
    (?P<flags>(?:(\S+)\s)*) # (single spaces required for parser below to work correctly)
    (?:\S+)? # random interface name repetition on inet
    [\\]\s+
    .* # lifetimes which are not interesting yet
    $"""
)


@define(
    frozen=True,
    kw_only=True,
)
class IpAddressUpdate:
    deleted: bool
    ifindex: int
    ifname: IfName
    ip: IPv4Interface | IPv6Interface
    scope: str
    flags: IpFlag

    @staticmethod
    def parse_line(line: str) -> IpAddressUpdate:
        m = IP_MON_PATTERN.search(line)
        if not m:
            raise Exception(f"Could not parse ip monitor output: {line!r}")
        grp = m.groupdict()
        ip_type: type[IPv4Interface | IPv6Interface] = (
            IPv6Interface if grp["type"] == "inet6" else IPv4Interface
        )
        try:
            ip = ip_type(grp["ip"])
        except ValueError as e:
            raise Exception(
                f"Could not parse ip monitor output, invalid IP: {grp['ip']!r}"
            ) from e
        flags = IpFlag.parse_str(grp["flags"].strip().split(" "))
        return IpAddressUpdate(
            deleted=grp["deleted"] is not None,
            ifindex=int(grp["ifindex"]),
            ifname=IfName(grp["ifname"]),
            ip=ip,
            scope=grp["scope"],
            flags=flags,
        )


def monitor_ip(
    ip_cmd: list[str],
    handler: UpdateHandler[IpAddressUpdate],
) -> NoReturn:
    proc = subprocess.Popen(
        ip_cmd + ["-o", "monitor", "address"],
        stdout=subprocess.PIPE,
        text=True,
    )
    # initial kickoff (AFTER starting monitoring, to not miss any update)
    logger.info("kickoff IP monitoring with current data")
    res = subprocess.run(
        ip_cmd + ["-o", "address", "show"],
        check=True,
        stdout=subprocess.PIPE,
        text=True,
    )
    for line in res.stdout.splitlines(keepends=False):
        line = line.rstrip()
        if line == "":
            continue
        update = IpAddressUpdate.parse_line(line)
        logger.debug(f"pass IP update: {update!r}")
        handler.update(update)
    logger.info("loading kickoff finished; start regular monitoring")
    while True:
        rc = proc.poll()
        if rc is not None:
            # flush stdout for easier debugging
            logger.error("Last stdout of monitor process:")
            logger.error(proc.stdout.read())  # type: ignore[union-attr]
            raise Exception(f"Monitor process crashed with returncode {rc}")
        line = proc.stdout.readline().rstrip()  # type: ignore[union-attr]
        if not line:
            continue
        logger.info("IP change detected")
        update = IpAddressUpdate.parse_line(line)
        logger.debug(f"pass IP update: {update!r}")
        handler.update(update)
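

# Per interface, the handler below maintains these nftables objects (a summary of
# __gen_updates() and gen_set_definitions() further down):
#   {ifname}v{4,6}net      set of announced prefixes/networks
#   {ifname}v{4,6}addr     set of host addresses (excluding link-local/ULA/temporary)
#   {ifname}v6_{mac}       one set per configured MAC holding its SLAAC address
#   {ifname}v6exp{proto}   set of "address . port" pairs exposed directly
#   {ifname}v6dnat{proto}  map of "wanPort : address . lanPort" entries for DNAT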
class InterfaceUpdateHandler(UpdateStackHandler[IpAddressUpdate]):
    # TODO regularly check (e.g. every hour) whether the stored lists are still correct

    def __init__(
        self,
        config: InterfaceConfig,
        nft_handler: UpdateHandler[NftUpdate],
    ) -> None:
        self.nft_handler = nft_handler
        self.lock = RLock()
        self.config = config
        self.ipv4Addrs = list[IPv4Interface]()
        self.ipv6Addrs = list[IPv6Interface]()

    def _update_stack(self, data: Sequence[IpAddressUpdate]) -> None:
        nft_updates = tuple(
            chain.from_iterable(self.__gen_updates(single) for single in data)
        )
        if len(nft_updates) <= 0:
            return
        self.nft_handler.update_stack(nft_updates)

    def __gen_updates(self, data: IpAddressUpdate) -> Iterable[NftUpdate]:
        if data.ifname != self.config.ifname:
            return
        if data.ip.is_link_local:
            logger.debug(
                f"{self.config.ifname}: ignore change for IP {data.ip} because link-local"
            )
            return
        if IpFlag.temporary in data.flags:
            logger.debug(
                f"{self.config.ifname}: ignore change for IP {data.ip} because temporary"
            )
            return  # ignore IPv6 privacy extension addresses
        if IpFlag.tentative in data.flags:
            logger.debug(
                f"{self.config.ifname}: ignore change for IP {data.ip} because tentative"
            )
            return  # ignore addresses that are still tentative
        logger.debug(f"{self.config.ifname}: process change of IP {data.ip}")
        with self.lock:
            ip_list: list[IPv4Interface] | list[IPv6Interface] = (
                self.ipv6Addrs if isinstance(data.ip, IPv6Interface) else self.ipv4Addrs
            )
            if data.deleted != (data.ip in ip_list):
                return  # no change required
            if data.deleted:
                logger.info(f"{self.config.ifname}: deleted IP {data.ip}")
                ip_list.remove(data.ip)  # type: ignore[arg-type]
            else:
                logger.info(f"{self.config.ifname}: discovered IP {data.ip}")
                ip_list.append(data.ip)  # type: ignore[arg-type]
        set_prefix = f"{self.config.ifname}v{data.ip.version}"
        op = NftValueOperation.if_deleted(data.deleted)
        yield NftUpdate(
            obj_type="set",
            obj_name=f"{set_prefix}net",
            operation=op,
            values=(data.ip.network.compressed,),
        )
        ula_space = IPv6Network("fc00::/7")  # unique local addresses; ip.is_private matches more than this
        if data.ip in ula_space:
            logger.debug(
                f"{self.config.ifname}: only updated {set_prefix}net for changes in fc00::/7"
            )
            return
        yield NftUpdate(
            obj_type="set",
            obj_name=f"{set_prefix}addr",
            operation=op,
            values=(data.ip.ip.compressed,),
        )
        if data.ip.version != 6:
            return
        slaacs = {mac: slaac_eui48(data.ip.network, mac) for mac in self.config.macs}
        for mac in self.config.macs:
            yield NftUpdate(
                obj_type="set",
                obj_name=f"{set_prefix}_{mac}",
                operation=NftValueOperation.REPLACE,
                values=(slaacs[mac].ip.compressed,),
            )
        for proto in self.config.protocols:
            yield NftUpdate(
                obj_type="set",
                obj_name=f"{set_prefix}exp{proto.protocol}",
                operation=NftValueOperation.REPLACE,
                values=tuple(
                    f"{slaacs[mac].ip.compressed} . {port}"
                    for mac, portList in proto.exposed.items()
                    for port in portList
                ),
            )
            yield NftUpdate(
                obj_type="map",
                obj_name=f"{set_prefix}dnat{proto.protocol}",
                operation=NftValueOperation.REPLACE,
                values=tuple(
                    f"{wan} : {slaacs[mac].ip.compressed} . {lan}"
                    for mac, portMap in proto.forwarded.items()
                    for wan, lan in portMap.items()
                ),
            )

    def gen_set_definitions(self) -> str:
        output = []
        for ip_v in [4, 6]:
            addr_type = f"ipv{ip_v}_addr"
            set_prefix = f"{self.config.ifname}v{ip_v}"
            output.append(gen_set_def("set", f"{set_prefix}addr", addr_type))
            output.append(gen_set_def("set", f"{set_prefix}net", addr_type, "interval"))
            if ip_v != 6:
                continue
            for mac in self.config.macs:
                output.append(gen_set_def("set", f"{set_prefix}_{mac}", addr_type))
            for proto in self.config.protocols:
                output.append(
                    gen_set_def(
                        "set",
                        f"{set_prefix}exp{proto.protocol}",
                        f"{addr_type} . inet_service",
                    )
                )
                output.append(
                    gen_set_def(
                        "map",
                        f"{set_prefix}dnat{proto.protocol}",
                        f"inet_service : {addr_type} . inet_service",
                    )
                )
        return "\n".join(output)
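

# Illustration (hypothetical interface "eth0"): gen_set_def("set", "eth0v6addr", "ipv6_addr")
# returns roughly:
#   set eth0v6addr {
#    type ipv6_addr
#   }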
def gen_set_def(
    set_type: str,
    name: str,
    data_type: str,
    flags: str | None = None,
    elements: Sequence[str] = tuple(),
) -> str:
    return "\n".join(
        line
        for line in (
            f"{set_type} {name} " + "{",
            f" type {data_type}",
            f" flags {flags}" if flags is not None else None,
            " elements = { " + ", ".join(elements) + " }"
            if len(elements) > 0
            else None,
            "}",
        )
        if line is not None
    )


class NftValueOperation(Enum):
    ADD = auto()
    DELETE = auto()
    REPLACE = auto()

    @staticmethod
    def if_deleted(b: bool) -> NftValueOperation:
        return NftValueOperation.DELETE if b else NftValueOperation.ADD
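

# Illustration of the script fragments NftUpdate.to_script() produces, assuming a
# hypothetical table "filter" and a set "eth0v6addr" holding "2001:db8::1":
#   REPLACE -> flush set inet filter eth0v6addr
#              add element inet filter eth0v6addr { 2001:db8::1 }
#   ADD     -> add element inet filter eth0v6addr { 2001:db8::1 }
#   DELETE  -> destroy element inet filter eth0v6addr { 2001:db8::1 }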
@define(
    frozen=True,
    kw_only=True,
)
class NftUpdate:
    obj_type: str
    obj_name: str
    operation: NftValueOperation
    values: Sequence[str]

    def to_script(self, table: NftTable) -> str:
        lines = []
        # the inet family is the only one that supports shared IPv4 & IPv6 entries
        obj_id = f"inet {table} {self.obj_name}"
        if self.operation == NftValueOperation.REPLACE:
            lines.append(f"flush {self.obj_type} {obj_id}")
        if len(self.values) > 0:
            op_str = "destroy" if self.operation == NftValueOperation.DELETE else "add"
            values_str = ", ".join(self.values)
            lines.append(f"{op_str} element {obj_id} {{ {values_str} }}")
        return "\n".join(lines)


class NftUpdateHandler(UpdateStackHandler[NftUpdate]):
    def __init__(
        self,
        update_cmd: Sequence[str],
        table: NftTable,
        handler: UpdateHandler[None],
    ) -> None:
        self.update_cmd = update_cmd
        self.table = table
        self.handler = handler

    def _update_stack(self, data: Sequence[NftUpdate]) -> None:
        logger.debug("compile stacked updates for nftables")
        script = "\n".join(
            map(
                lambda u: u.to_script(table=self.table),
                data,
            )
        )
        logger.debug(f"pass updates to nftables:\n{script}")
        subprocess.run(
            list(self.update_cmd) + ["-f", "-"],
            input=script,
            check=True,
            text=True,
        )
        self.handler.update(None)
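

# The handler below signals readiness to systemd via sd_notify after each applied
# update; meaningful when the unit runs with Type=notify.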
class SystemdHandler(UpdateHandler[object]):
    def update(self, data: object) -> None:
        # TODO improve status updates
        # daemon.notify("READY=1\nSTATUS=Updated successfully.\n")
        daemon.notify("READY=1\n")

    def update_stack(self, data: Sequence[object]) -> None:
        self.update(None)


@define(
    frozen=True,
    kw_only=True,
)
class ProtocolConfig:
    protocol: NftProtocol
    exposed: Mapping[MACAddress, Sequence[Port]]
    "only when direct public IPs are available"
    forwarded: Mapping[MACAddress, Mapping[Port, Port]]  # wan -> lan
    "i.e. DNAT"

    @staticmethod
    def from_json(protocol: str, obj: JsonObj) -> ProtocolConfig:
        assert set(obj.keys()) <= set(("exposed", "forwarded"))
        exposed_raw = obj.get("exposed")
        exposed = defaultdict[MACAddress, list[Port]](list)
        if exposed_raw is not None:
            assert isinstance(exposed_raw, Sequence)
            for fwd in exposed_raw:
                assert isinstance(fwd, Mapping)
                dest = to_mac(fwd["dest"])  # type: ignore[arg-type]
                port = to_port(fwd["port"])  # type: ignore[arg-type]
                exposed[dest].append(port)
        forwarded_raw = obj.get("forwarded")
        forwarded = defaultdict[MACAddress, dict[Port, Port]](dict)
        if forwarded_raw is not None:
            assert isinstance(forwarded_raw, Sequence)
            for smap in forwarded_raw:
                assert isinstance(smap, Mapping)
                dest = to_mac(smap["dest"])  # type: ignore[arg-type]
                wanPort = to_port(smap["wanPort"])  # type: ignore[arg-type]
                lanPort = to_port(smap["lanPort"])  # type: ignore[arg-type]
                forwarded[dest][wanPort] = lanPort
        return ProtocolConfig(
            protocol=NftProtocol(protocol),
            exposed=exposed,
            forwarded=forwarded,
        )


@define(
    frozen=True,
    kw_only=True,
)
class InterfaceConfig:
    ifname: IfName
    protocols: Sequence[ProtocolConfig]

    @cached_property
    def macs(self) -> Sequence[MACAddress]:
        return tuple(
            set(
                chain(
                    (mac for proto in self.protocols for mac in proto.exposed.keys()),
                    (mac for proto in self.protocols for mac in proto.forwarded.keys()),
                )
            )
        )

    @staticmethod
    def from_json(ifname: str, obj: JsonObj) -> InterfaceConfig:
        assert set(obj.keys()) <= set(("ports",))
        ports = obj.get("ports")
        assert ports is None or isinstance(ports, Mapping)
        return InterfaceConfig(
            ifname=IfName(ifname),
            protocols=tuple()
            if ports is None
            else tuple(
                ProtocolConfig.from_json(proto, cast(JsonObj, proto_cfg))
                for proto, proto_cfg in ports.items()  # type: ignore[union-attr]
            ),
        )


@define(
    frozen=True,
    kw_only=True,
)
class AppConfig:
    nft_table: NftTable
    interfaces: Sequence[InterfaceConfig]

    @staticmethod
    def from_json(obj: JsonObj) -> AppConfig:
        assert set(obj.keys()) <= set(("interfaces", "nftTable"))
        nft_table = obj["nftTable"]
        assert isinstance(nft_table, str)
        interfaces = obj["interfaces"]
        assert isinstance(interfaces, Mapping)
        return AppConfig(
            nft_table=NftTable(nft_table),
            interfaces=tuple(
                InterfaceConfig.from_json(ifname, cast(JsonObj, if_cfg))
                for ifname, if_cfg in interfaces.items()
            ),
        )
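

# Example configuration accepted by read_config_file() below (hypothetical values,
# derived from the from_json() parsers above):
#   {
#     "nftTable": "filter",
#     "interfaces": {
#       "eth0": {
#         "ports": {
#           "tcp": {
#             "exposed": [{"dest": "aa:bb:cc:dd:ee:ff", "port": 443}],
#             "forwarded": [{"dest": "aa:bb:cc:dd:ee:ff", "wanPort": 8022, "lanPort": 22}]
#           }
#         }
#       }
#     }
#   }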
def read_config_file(path: Path) -> AppConfig:
    with path.open("r") as fh:
        json_data = json.load(fh)
    logger.debug(repr(json_data))
    return AppConfig.from_json(json_data)


LOG_LEVEL_MAP = {
    "critical": logging.CRITICAL,
    "error": logging.ERROR,
    "warning": logging.WARNING,
    "info": logging.INFO,
    "debug": logging.DEBUG,
}


def _gen_if_updater(
    configs: Sequence[InterfaceConfig], nft_updater: UpdateHandler[NftUpdate]
) -> Sequence[InterfaceUpdateHandler]:
    return tuple(
        InterfaceUpdateHandler(
            config=if_cfg,
            nft_handler=nft_updater,
        )
        for if_cfg in configs
    )


def static_part_generation(config: AppConfig) -> None:
    dummy = IgnoreHandler()
    if_updater = _gen_if_updater(config.interfaces, dummy)
    for if_up in if_updater:
        print(if_up.gen_set_definitions())
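

# Runtime pipeline set up below: monitor_ip() feeds IpAddressUpdate objects into an
# UpdateBurstHandler (debounce), which fans out to one InterfaceUpdateHandler per
# configured interface; those emit NftUpdate objects to the NftUpdateHandler, which
# applies them via `nft -f -` and finally notifies systemd through SystemdHandler.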
def service_execution(args: argparse.Namespace, config: AppConfig) -> NoReturn:
    nft_updater = NftUpdateHandler(
        table=config.nft_table,
        update_cmd=shlex.split(args.nft_command),
        handler=SystemdHandler(),
    )
    if_updater = _gen_if_updater(config.interfaces, nft_updater)
    burst_handler = UpdateBurstHandler[IpAddressUpdate](
        burst_interval=0.5,
        handler=if_updater,
    )
    monitor_ip(shlex.split(args.ip_command), burst_handler)


def setup_logging(args: Any) -> None:
    systemd_service = os.environ.get("INVOCATION_ID") and Path("/dev/log").exists()
    if systemd_service:
        logger.setLevel(logging.DEBUG)
        logger.addHandler(JournalHandler(SYSLOG_IDENTIFIER="nft-update-addresses"))
    else:
        logging.basicConfig()  # get output to stdout/stderr
        logger.setLevel(LOG_LEVEL_MAP[args.log_level])
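

# Typical invocations (the script name is illustrative):
#   ./nft-update-addresses.py -c /etc/nft-update-addresses.json --check-config
#   ./nft-update-addresses.py -c /etc/nft-update-addresses.json --output-set-definitions
#   ./nft-update-addresses.py -c /etc/nft-update-addresses.json -l info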
def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config-file", required=True)
    parser.add_argument("--check-config", action="store_true")
    parser.add_argument("--output-set-definitions", action="store_true")
    parser.add_argument("--ip-command", default="/usr/bin/env ip")
    parser.add_argument("--nft-command", default="/usr/bin/env nft")
    parser.add_argument(
        "-l",
        "--log-level",
        default="error",
        choices=LOG_LEVEL_MAP.keys(),
        help="Log level for outputs to stdout/stderr (ignored when launched in a systemd service)",
    )
    args = parser.parse_args()
    setup_logging(args)
    config = read_config_file(Path(args.config_file))
    if args.check_config:
        return
    if args.output_set_definitions:
        return static_part_generation(config)
    service_execution(args, config)


if __name__ == "__main__":
    main()