add nft-update-addresses as package & with a NixOS module
parent
0b0e008fdb
commit
142e970bdc
@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
imports = [
|
||||||
|
# files
|
||||||
|
./nft-update-addresses.nix
|
||||||
|
];
|
||||||
|
}
|
@ -0,0 +1,186 @@
|
|||||||
|
{
|
||||||
|
config,
|
||||||
|
lib,
|
||||||
|
pkgs,
|
||||||
|
...
|
||||||
|
}:
|
||||||
|
let
|
||||||
|
servName = "nft-update-addresses";
|
||||||
|
cfg = config.services.${servName};
|
||||||
|
settingsFormat = pkgs.formats.json { };
|
||||||
|
mkDisableOption = desc: lib.mkEnableOption desc // { default = true; };
|
||||||
|
# output options values
|
||||||
|
configFile = pkgs.writeTextFile {
|
||||||
|
name = "${servName}.json";
|
||||||
|
text = builtins.toJSON cfg.settings; # TODO can otherwise not easily check the file for errors
|
||||||
|
checkPhase = ''
|
||||||
|
${lib.getExe cfg.package} --check-config --config-file "$out"
|
||||||
|
'';
|
||||||
|
};
|
||||||
|
staticDefs = builtins.readFile (
|
||||||
|
pkgs.runCommandLocal "${servName}.nftables" { } ''
|
||||||
|
${lib.getExe cfg.package} --output-set-definitions --config-file ${configFile} > $out
|
||||||
|
''
|
||||||
|
);
|
||||||
|
in
|
||||||
|
{
|
||||||
|
|
||||||
|
options.services.${servName} = {
|
||||||
|
|
||||||
|
enable = lib.mkEnableOption "${servName} service";
|
||||||
|
|
||||||
|
package = lib.mkPackageOption pkgs (lib.singleton servName) { };
|
||||||
|
|
||||||
|
settings = lib.mkOption {
|
||||||
|
# TODO link to docu
|
||||||
|
description = "Configuration for ${servName}";
|
||||||
|
type = settingsFormat.type;
|
||||||
|
default = {
|
||||||
|
nftTable = "nixos-fw";
|
||||||
|
};
|
||||||
|
example.interfaces = {
|
||||||
|
wan0 = { };
|
||||||
|
lan0.ports.tcp = {
|
||||||
|
exposed = [
|
||||||
|
{
|
||||||
|
dest = "aa:bb:cc:dd:ee:ff";
|
||||||
|
port = 80;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
dest = "aa:bb:cc:00:11:22";
|
||||||
|
port = 80;
|
||||||
|
}
|
||||||
|
];
|
||||||
|
forwarded = [
|
||||||
|
{
|
||||||
|
dest = "aabb-ccdd-eeff";
|
||||||
|
lanPort = 80;
|
||||||
|
wanPort = 80;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
dest = "aa.bbcc.0011.22";
|
||||||
|
lanPort = 80;
|
||||||
|
wanPort = 8080;
|
||||||
|
}
|
||||||
|
];
|
||||||
|
};
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
includeStaticDefinitions = mkDisableOption ''inclusion of static definitions from {option}`services.${servName}.nftablesStaticDefinitions` into the nftables config'';
|
||||||
|
|
||||||
|
configurationFile = lib.mkOption {
|
||||||
|
description = "Path to configuration file used by ${servName}.";
|
||||||
|
type = lib.types.path; # needs to be available at build time
|
||||||
|
readOnly = true;
|
||||||
|
default = configFile;
|
||||||
|
defaultText = lib.literalExpression "# content as generated from config.services.${servName}.settings";
|
||||||
|
};
|
||||||
|
|
||||||
|
nftablesStaticDefinitions = lib.mkOption {
|
||||||
|
description = ''
|
||||||
|
Static definitions provided by ${servName} when called with given configuration.
|
||||||
|
|
||||||
|
When {option}`services.${servName}.includeStaticDefinitions (which is default),
|
||||||
|
these will be already included in your nftables setup.
|
||||||
|
Otherwise, you can use the value of this output option as you prefer.
|
||||||
|
'';
|
||||||
|
readOnly = true;
|
||||||
|
default = staticDefs;
|
||||||
|
defaultText = lib.literalExpression "# as provided by ${servName}";
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
config = lib.mkIf cfg.enable {
|
||||||
|
|
||||||
|
assertions = [
|
||||||
|
{
|
||||||
|
assertion = cfg.enable -> config.networking.nftables.enable;
|
||||||
|
message = "${servName} requires nftables to be configured";
|
||||||
|
}
|
||||||
|
# TODO assert for port duplications
|
||||||
|
];
|
||||||
|
|
||||||
|
networking.nftables.tables.${cfg.settings.nftTable}.content = lib.mkIf cfg.includeStaticDefinitions staticDefs;
|
||||||
|
|
||||||
|
systemd.services.${servName} = {
|
||||||
|
description = "IPv6 prefix updater for subnet & NAT rules for nftables router setup";
|
||||||
|
after = [
|
||||||
|
"nftables.service"
|
||||||
|
"network.target"
|
||||||
|
];
|
||||||
|
partOf = lib.singleton "nftables.service";
|
||||||
|
requisite = lib.singleton "nftables.service";
|
||||||
|
wantedBy = lib.singleton "multi-user.target";
|
||||||
|
upheldBy = lib.singleton "systemd-networkd.service";
|
||||||
|
restartIfChanged = true;
|
||||||
|
restartTriggers = config.systemd.services.nftables.restartTriggers;
|
||||||
|
serviceConfig = {
|
||||||
|
# Service
|
||||||
|
Type = "notify-reload";
|
||||||
|
ExecStart = lib.singleton "${lib.getExe cfg.package} ${
|
||||||
|
lib.cli.toGNUCommandLineShell { } {
|
||||||
|
config-file = configFile;
|
||||||
|
ip-command = "${pkgs.iproute2}/bin/ip";
|
||||||
|
nft-command = lib.getExe pkgs.nftables;
|
||||||
|
}
|
||||||
|
}";
|
||||||
|
RestartSec = "250ms";
|
||||||
|
RestartSteps = 3;
|
||||||
|
RestartMaxDelaySec = "3s";
|
||||||
|
TimeoutSec = "10s";
|
||||||
|
Restart = "always";
|
||||||
|
NotifyAccess = "all"; # bash script opens subprocesses in pipes
|
||||||
|
# Paths
|
||||||
|
ProtectProc = "noaccess";
|
||||||
|
ProcSubset = "pid";
|
||||||
|
CapabilityBoundingSet = [
|
||||||
|
"CAP_BPF" # nft is compiled to bpf
|
||||||
|
"CAP_IPC_LOCK" # ?
|
||||||
|
"CAP_KILL" # ?
|
||||||
|
"CAP_NET_ADMIN"
|
||||||
|
];
|
||||||
|
# Security
|
||||||
|
NoNewPrivileges = true;
|
||||||
|
# Process
|
||||||
|
KeyringMode = "private";
|
||||||
|
OOMScoreAdjust = 10;
|
||||||
|
# Scheduling
|
||||||
|
Nice = -2;
|
||||||
|
CPUSchedulingPolicy = "fifo";
|
||||||
|
# Sandboxing
|
||||||
|
ProtectSystem = "strict";
|
||||||
|
ProtectHome = true;
|
||||||
|
PrivateTmp = true;
|
||||||
|
PrivateDevices = true;
|
||||||
|
PrivateNetwork = false; # breaks nftables
|
||||||
|
PrivateIPC = true;
|
||||||
|
PrivateUsers = false; # breaks nftables
|
||||||
|
ProtectClock = true;
|
||||||
|
ProtectKernelTunables = true;
|
||||||
|
ProtectKernelModules = true; # are already loaded
|
||||||
|
ProtectKernelLogs = true;
|
||||||
|
ProtectControlGroups = true;
|
||||||
|
#RestrictAddressFamilies = [
|
||||||
|
# # ?
|
||||||
|
# "AF_INET"
|
||||||
|
# "AF_INET6"
|
||||||
|
# #"AF_NETLINK"
|
||||||
|
#];
|
||||||
|
RestrictNamespaces = true;
|
||||||
|
RestrictSUIDSGID = true;
|
||||||
|
#SystemCallFilter = "@basic-io @ipc @network-io @signal @timer" # definitly will break that
|
||||||
|
#SystemCallLog = "~"; # for debugging; should lock all system calls made
|
||||||
|
# Resource Control
|
||||||
|
CPUQuota = "50%";
|
||||||
|
# TODO test to gather real values
|
||||||
|
MemoryLow = "8M";
|
||||||
|
MemoryHigh = "32M";
|
||||||
|
MemoryMax = "128M";
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
@ -0,0 +1,727 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import (
|
||||||
|
ABC,
|
||||||
|
abstractmethod,
|
||||||
|
)
|
||||||
|
import argparse
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import (
|
||||||
|
Mapping,
|
||||||
|
Sequence,
|
||||||
|
)
|
||||||
|
from enum import (
|
||||||
|
Enum,
|
||||||
|
Flag,
|
||||||
|
auto,
|
||||||
|
)
|
||||||
|
from functools import cached_property
|
||||||
|
import io
|
||||||
|
from ipaddress import (
|
||||||
|
IPv4Interface,
|
||||||
|
IPv6Interface,
|
||||||
|
IPv6Network,
|
||||||
|
)
|
||||||
|
from itertools import chain
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from logging.handlers import SysLogHandler
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import re
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
from threading import (
|
||||||
|
RLock,
|
||||||
|
Timer,
|
||||||
|
)
|
||||||
|
import traceback
|
||||||
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Iterable,
|
||||||
|
NewType,
|
||||||
|
NoReturn,
|
||||||
|
Protocol,
|
||||||
|
TypeAlias,
|
||||||
|
TypeGuard,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
|
from attrs import (
|
||||||
|
define,
|
||||||
|
field,
|
||||||
|
)
|
||||||
|
from systemd import daemon # type: ignore[import-untyped]
|
||||||
|
from systemd.journal import JournalHandler # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def raise_and_exit(args: Any) -> None:
|
||||||
|
Timer(0.01, os._exit, args=(1,)).start()
|
||||||
|
raise args[0]
|
||||||
|
|
||||||
|
|
||||||
|
# ensure exceptions in any thread brings the program down
|
||||||
|
# important for proper error detection via tests & in random cases in real world
|
||||||
|
|
||||||
|
threading.excepthook = raise_and_exit
|
||||||
|
|
||||||
|
|
||||||
|
JsonVal: TypeAlias = Union["JsonObj", "JsonList", str, int, bool]
|
||||||
|
JsonList: TypeAlias = Sequence[JsonVal]
|
||||||
|
JsonObj: TypeAlias = Mapping[str, JsonVal]
|
||||||
|
|
||||||
|
T = TypeVar("T", contravariant=True)
|
||||||
|
|
||||||
|
MACAddress = NewType("MACAddress", str)
|
||||||
|
# format: aabbccddeeff (lower-case, without separators)
|
||||||
|
NftProtocol = NewType("NftProtocol", str) # e.g. tcp, udp, …
|
||||||
|
Port = NewType("Port", int)
|
||||||
|
IfName = NewType("IfName", str)
|
||||||
|
NftTable = NewType("NftTable", str)
|
||||||
|
|
||||||
|
|
||||||
|
def to_mac(mac_str: str) -> MACAddress:
|
||||||
|
eui48 = re.sub(r"[.:_-]", "", mac_str.lower())
|
||||||
|
if not is_mac(eui48):
|
||||||
|
raise ValueError(f"invalid MAC address / EUI48: {mac_str}")
|
||||||
|
return MACAddress(eui48)
|
||||||
|
|
||||||
|
|
||||||
|
def is_mac(mac_str: str) -> TypeGuard[MACAddress]:
|
||||||
|
return re.match(r"^[0-9a-f]{12}$", mac_str) != None
|
||||||
|
|
||||||
|
|
||||||
|
def to_port(port_str: str | int) -> Port:
|
||||||
|
try:
|
||||||
|
port = int(port_str)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise ValueError(f"invalid port number: {port_str}") from exc
|
||||||
|
if not is_port(port):
|
||||||
|
raise ValueError(f"invalid port number: {port_str}")
|
||||||
|
return Port(port)
|
||||||
|
|
||||||
|
|
||||||
|
def is_port(port: int) -> TypeGuard[Port]:
|
||||||
|
return 0 < port < 65536
|
||||||
|
|
||||||
|
|
||||||
|
def slaac_eui48(prefix: IPv6Network, eui48: MACAddress) -> IPv6Interface:
|
||||||
|
if prefix.prefixlen > 64:
|
||||||
|
raise ValueError(
|
||||||
|
f"a SLAAC IPv6 address requires a CIDR of at least /64, got {prefix}"
|
||||||
|
)
|
||||||
|
eui64 = eui48[0:6] + "fffe" + eui48[6:]
|
||||||
|
modified = hex(int(eui64[0:2], 16) ^ 2)[2:].zfill(2) + eui64[2:]
|
||||||
|
euil = int(modified, 16)
|
||||||
|
return IPv6Interface(f"{prefix[euil].compressed}/{prefix.prefixlen}")
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateHandler(Protocol[T]):
|
||||||
|
def update(self, data: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def update_stack(self, data: Sequence[T]) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateStackHandler(UpdateHandler[T], ABC):
|
||||||
|
def update(self, data: T) -> None:
|
||||||
|
return self._update_stack((data,))
|
||||||
|
|
||||||
|
def update_stack(self, data: Sequence[T]) -> None:
|
||||||
|
if len(data) <= 0:
|
||||||
|
logger.warning(
|
||||||
|
f"[bug, please report upstream] received empty data in update_stack. Traceback:\n{''.join(traceback.format_stack())}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
return self._update_stack(data)
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _update_stack(self, data: Sequence[T]) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class IgnoreHandler(UpdateStackHandler[object]):
|
||||||
|
def _update_stack(self, data: Sequence[object]) -> None:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@define(
|
||||||
|
kw_only=True,
|
||||||
|
slots=False,
|
||||||
|
)
|
||||||
|
class UpdateBurstHandler(UpdateStackHandler[T]):
|
||||||
|
burst_interval: float
|
||||||
|
handler: Sequence[UpdateHandler[T]]
|
||||||
|
__lock: RLock = field(factory=RLock)
|
||||||
|
__updates: list[T] = field(factory=list)
|
||||||
|
__timer: Timer | None = None
|
||||||
|
|
||||||
|
def _update_stack(self, data: Sequence[T]) -> None:
|
||||||
|
with self.__lock:
|
||||||
|
self.__updates.extend(data)
|
||||||
|
self.__refresh_timer()
|
||||||
|
|
||||||
|
def __refresh_timer(self) -> None:
|
||||||
|
with self.__lock:
|
||||||
|
if self.__timer is not None:
|
||||||
|
# try to cancel
|
||||||
|
# not a problem if timer already elapsed but before processing really started
|
||||||
|
# because due to using locks when accessing updates
|
||||||
|
self.__timer.cancel()
|
||||||
|
self.__timer = Timer(
|
||||||
|
interval=self.burst_interval,
|
||||||
|
function=self.__process_updates,
|
||||||
|
)
|
||||||
|
self.__timer.start()
|
||||||
|
|
||||||
|
def __process_updates(self) -> None:
|
||||||
|
with self.__lock:
|
||||||
|
self.__timer = None
|
||||||
|
if not self.__updates:
|
||||||
|
return
|
||||||
|
updates = self.__updates
|
||||||
|
self.__updates = []
|
||||||
|
for handler in self.handler:
|
||||||
|
handler.update_stack(updates)
|
||||||
|
|
||||||
|
|
||||||
|
class IpFlag(Flag):
|
||||||
|
dynamic = auto()
|
||||||
|
mngtmpaddr = auto()
|
||||||
|
noprefixroute = auto()
|
||||||
|
temporary = auto()
|
||||||
|
tentiative = auto()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_str(flags_str: Sequence[str], ignore_unknown: bool = True) -> IpFlag:
|
||||||
|
flags = IpFlag(0)
|
||||||
|
for flag in flags_str:
|
||||||
|
flag = flag.lower()
|
||||||
|
member = IpFlag.__members__.get(flag)
|
||||||
|
if member is not None:
|
||||||
|
flags |= member
|
||||||
|
elif not ignore_unknown:
|
||||||
|
raise Exception(f"Unrecognized IpFlag: {flag}")
|
||||||
|
return flags
|
||||||
|
|
||||||
|
|
||||||
|
IP_MON_PATTERN = re.compile(
|
||||||
|
r"""(?x)^
|
||||||
|
(?P<deleted>[Dd]eleted\s+)?
|
||||||
|
(?P<ifindex>\d+):\s+
|
||||||
|
(?P<ifname>\S+)\s+
|
||||||
|
(?P<type>inet6?)\s+
|
||||||
|
(?P<ip>\S+)\s+
|
||||||
|
#(?:metric\s+\S+\s+)? # sometimes possible
|
||||||
|
#(?:brd\s+\S+\s+)? # broadcast IP on inet
|
||||||
|
(?:\S+\s+\S+\s+)* # abstracted irrelevant attributes
|
||||||
|
(?:scope\s+(?P<scope>\S+)\s+)
|
||||||
|
(?P<flags>(?:(\S+)\s)*) # (single spaces required for parser below to work correctly)
|
||||||
|
(?:\S+)? # random interface name repetition on inet
|
||||||
|
[\\]\s+
|
||||||
|
.* # lifetimes which are not interesting yet
|
||||||
|
$"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@define(
|
||||||
|
frozen=True,
|
||||||
|
kw_only=True,
|
||||||
|
)
|
||||||
|
class IpAddressUpdate:
|
||||||
|
deleted: bool
|
||||||
|
ifindex: int
|
||||||
|
ifname: IfName
|
||||||
|
ip: IPv4Interface | IPv6Interface
|
||||||
|
scope: str
|
||||||
|
flags: IpFlag
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_line(line: str) -> IpAddressUpdate:
|
||||||
|
m = IP_MON_PATTERN.search(line)
|
||||||
|
if not m:
|
||||||
|
raise Exception(f"Could not parse ip monitor output: {line!r}")
|
||||||
|
grp = m.groupdict()
|
||||||
|
ip_type: type[IPv4Interface | IPv6Interface] = (
|
||||||
|
IPv6Interface if grp["type"] == "inet6" else IPv4Interface
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
ip = ip_type(grp["ip"])
|
||||||
|
except ValueError as e:
|
||||||
|
raise Exception(
|
||||||
|
f"Could not parse ip monitor output, invalid IP: {grp['ip']!r}"
|
||||||
|
) from e
|
||||||
|
flags = IpFlag.parse_str(grp["flags"].strip().split(" "))
|
||||||
|
return IpAddressUpdate(
|
||||||
|
deleted=grp["deleted"] != None,
|
||||||
|
ifindex=int(grp["ifindex"]),
|
||||||
|
ifname=IfName(grp["ifname"]),
|
||||||
|
ip=ip,
|
||||||
|
scope=grp["scope"],
|
||||||
|
flags=flags,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def monitor_ip(
|
||||||
|
ip_cmd: list[str],
|
||||||
|
handler: UpdateHandler[IpAddressUpdate],
|
||||||
|
) -> NoReturn:
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
ip_cmd + ["-o", "monitor", "address"],
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
# initial kickoff (AFTER starting monitoring, to not miss any update)
|
||||||
|
logger.info("kickoff IP monitoring with current data")
|
||||||
|
res = subprocess.run(
|
||||||
|
ip_cmd + ["-o", "address", "show"],
|
||||||
|
check=True,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
for line in res.stdout.splitlines(keepends=False):
|
||||||
|
line = line.rstrip()
|
||||||
|
if line == "":
|
||||||
|
continue
|
||||||
|
update = IpAddressUpdate.parse_line(line)
|
||||||
|
logger.debug(f"pass IP update: {update!r}")
|
||||||
|
handler.update(update)
|
||||||
|
logger.info("loading kickoff finished; start regular monitoring")
|
||||||
|
while True:
|
||||||
|
rc = proc.poll()
|
||||||
|
if rc != None:
|
||||||
|
# flush stdout for easier debugging
|
||||||
|
logger.error("Last stdout of monitor process:")
|
||||||
|
logger.error(proc.stdout.read()) # type: ignore[union-attr]
|
||||||
|
raise Exception(f"Monitor process crashed with returncode {rc}")
|
||||||
|
line = proc.stdout.readline().rstrip() # type: ignore[union-attr]
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
logger.info("IP change detected")
|
||||||
|
update = IpAddressUpdate.parse_line(line)
|
||||||
|
logger.debug(f"pass IP update: {update!r}")
|
||||||
|
handler.update(update)
|
||||||
|
|
||||||
|
|
||||||
|
class InterfaceUpdateHandler(UpdateStackHandler[IpAddressUpdate]):
|
||||||
|
# TODO regularly check (i.e. 1 hour) if stored lists are still correct
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: InterfaceConfig,
|
||||||
|
nft_handler: UpdateHandler[NftUpdate],
|
||||||
|
) -> None:
|
||||||
|
self.nft_handler = nft_handler
|
||||||
|
self.lock = RLock()
|
||||||
|
self.config = config
|
||||||
|
self.ipv4Addrs = list[IPv4Interface]()
|
||||||
|
self.ipv6Addrs = list[IPv6Interface]()
|
||||||
|
|
||||||
|
def _update_stack(self, data: Sequence[IpAddressUpdate]) -> None:
|
||||||
|
nft_updates = tuple(
|
||||||
|
chain.from_iterable(self.__gen_updates(single) for single in data)
|
||||||
|
)
|
||||||
|
if len(nft_updates) <= 0:
|
||||||
|
return
|
||||||
|
self.nft_handler.update_stack(nft_updates)
|
||||||
|
|
||||||
|
def __gen_updates(self, data: IpAddressUpdate) -> Iterable[NftUpdate]:
|
||||||
|
if data.ifname != self.config.ifname:
|
||||||
|
return
|
||||||
|
if data.ip.is_link_local:
|
||||||
|
logger.debug(
|
||||||
|
f"{self.config.ifname}: ignore change for IP {data.ip} because link-local"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if IpFlag.temporary in data.flags:
|
||||||
|
logger.debug(
|
||||||
|
f"{self.config.ifname}: ignore change for IP {data.ip} because temporary"
|
||||||
|
)
|
||||||
|
return # ignore IPv6 privacy extension addresses
|
||||||
|
if IpFlag.tentiative in data.flags:
|
||||||
|
logger.debug(
|
||||||
|
f"{self.config.ifname}: ignore change for IP {data.ip} because tentiative"
|
||||||
|
)
|
||||||
|
return # ignore (yet) tentiative addresses
|
||||||
|
logger.debug(f"{self.config.ifname}: process change of IP {data.ip}")
|
||||||
|
with self.lock:
|
||||||
|
ip_list: list[IPv4Interface] | list[IPv6Interface] = (
|
||||||
|
self.ipv6Addrs if isinstance(data.ip, IPv6Interface) else self.ipv4Addrs
|
||||||
|
)
|
||||||
|
if data.deleted != (data.ip in ip_list):
|
||||||
|
return # no change required
|
||||||
|
if data.deleted:
|
||||||
|
logger.info(f"{self.config.ifname}: deleted IP {data.ip}")
|
||||||
|
ip_list.remove(data.ip) # type: ignore[arg-type]
|
||||||
|
else:
|
||||||
|
logger.info(f"{self.config.ifname}: discovered IP {data.ip}")
|
||||||
|
ip_list.append(data.ip) # type: ignore[arg-type]
|
||||||
|
set_prefix = f"{self.config.ifname}v{data.ip.version}"
|
||||||
|
op = NftValueOperation.if_deleted(data.deleted)
|
||||||
|
yield NftUpdate(
|
||||||
|
obj_type="set",
|
||||||
|
obj_name=f"{set_prefix}net",
|
||||||
|
operation=op,
|
||||||
|
values=(data.ip.network.compressed,),
|
||||||
|
)
|
||||||
|
link_local_space = IPv6Network("fc00::/7") # because ip.is_private is wrong
|
||||||
|
if data.ip in link_local_space:
|
||||||
|
logger.debug(
|
||||||
|
f"{self.config.ifname}: only updated {set_prefix}net for changes in fc00::/7"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
yield NftUpdate(
|
||||||
|
obj_type="set",
|
||||||
|
obj_name=f"{set_prefix}addr",
|
||||||
|
operation=op,
|
||||||
|
values=(data.ip.ip.compressed,),
|
||||||
|
)
|
||||||
|
if data.ip.version != 6:
|
||||||
|
return
|
||||||
|
slaacs = {mac: slaac_eui48(data.ip.network, mac) for mac in self.config.macs}
|
||||||
|
for mac in self.config.macs:
|
||||||
|
yield NftUpdate(
|
||||||
|
obj_type="set",
|
||||||
|
obj_name=f"{set_prefix}_{mac}",
|
||||||
|
operation=NftValueOperation.REPLACE,
|
||||||
|
values=(slaacs[mac].ip.compressed,),
|
||||||
|
)
|
||||||
|
for proto in self.config.protocols:
|
||||||
|
yield NftUpdate(
|
||||||
|
obj_type="set",
|
||||||
|
obj_name=f"{set_prefix}exp{proto.protocol}",
|
||||||
|
operation=NftValueOperation.REPLACE,
|
||||||
|
values=tuple(
|
||||||
|
f"{slaacs[mac].ip.compressed} . {port}"
|
||||||
|
for mac, portList in proto.exposed.items()
|
||||||
|
for port in portList
|
||||||
|
),
|
||||||
|
)
|
||||||
|
yield NftUpdate(
|
||||||
|
obj_type="map",
|
||||||
|
obj_name=f"{set_prefix}dnat{proto.protocol}",
|
||||||
|
operation=NftValueOperation.REPLACE,
|
||||||
|
values=tuple(
|
||||||
|
f"{wan} : {slaacs[mac].ip.compressed} . {lan}"
|
||||||
|
for mac, portMap in proto.forwarded.items()
|
||||||
|
for wan, lan in portMap.items()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def gen_set_definitions(self) -> str:
|
||||||
|
output = []
|
||||||
|
for ip_v in [4, 6]:
|
||||||
|
addr_type = f"ipv{ip_v}_addr"
|
||||||
|
set_prefix = f"{self.config.ifname}v{ip_v}"
|
||||||
|
output.append(gen_set_def("set", f"{set_prefix}addr", addr_type))
|
||||||
|
output.append(gen_set_def("set", f"{set_prefix}net", addr_type, "interval"))
|
||||||
|
if ip_v != 6:
|
||||||
|
continue
|
||||||
|
for mac in self.config.macs:
|
||||||
|
output.append(gen_set_def("set", f"{set_prefix}_{mac}", addr_type))
|
||||||
|
for proto in self.config.protocols:
|
||||||
|
output.append(
|
||||||
|
gen_set_def(
|
||||||
|
"set",
|
||||||
|
f"{set_prefix}exp{proto.protocol}",
|
||||||
|
f"{addr_type} . inet_service",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
output.append(
|
||||||
|
gen_set_def(
|
||||||
|
"map",
|
||||||
|
f"{set_prefix}dnat{proto.protocol}",
|
||||||
|
f"inet_service : {addr_type} . inet_service",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return "\n".join(output)
|
||||||
|
|
||||||
|
|
||||||
|
def gen_set_def(
|
||||||
|
set_type: str,
|
||||||
|
name: str,
|
||||||
|
data_type: str,
|
||||||
|
flags: str | None = None,
|
||||||
|
elements: Sequence[str] = tuple(),
|
||||||
|
) -> str:
|
||||||
|
return "\n".join(
|
||||||
|
line
|
||||||
|
for line in (
|
||||||
|
f"{set_type} {name} " + "{",
|
||||||
|
f" type {data_type}",
|
||||||
|
f" flags {flags}" if flags is not None else None,
|
||||||
|
" elements = { " + ", ".join(elements) + " }"
|
||||||
|
if len(elements) > 0
|
||||||
|
else None,
|
||||||
|
"}",
|
||||||
|
)
|
||||||
|
if line is not None
|
||||||
|
)
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NftValueOperation(Enum):
|
||||||
|
ADD = auto()
|
||||||
|
DELETE = auto()
|
||||||
|
REPLACE = auto()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def if_deleted(b: bool) -> NftValueOperation:
|
||||||
|
return NftValueOperation.DELETE if b else NftValueOperation.ADD
|
||||||
|
|
||||||
|
|
||||||
|
@define(
|
||||||
|
frozen=True,
|
||||||
|
kw_only=True,
|
||||||
|
)
|
||||||
|
class NftUpdate:
|
||||||
|
obj_type: str
|
||||||
|
obj_name: str
|
||||||
|
operation: NftValueOperation
|
||||||
|
values: Sequence[str]
|
||||||
|
|
||||||
|
def to_script(self, table: NftTable) -> str:
|
||||||
|
lines = []
|
||||||
|
# inet family is the only which supports shared IPv4 & IPv6 entries
|
||||||
|
obj_id = f"inet {table} {self.obj_name}"
|
||||||
|
if self.operation == NftValueOperation.REPLACE:
|
||||||
|
lines.append(f"flush {self.obj_type} {obj_id}")
|
||||||
|
if len(self.values) > 0:
|
||||||
|
op_str = "destroy" if self.operation == NftValueOperation.DELETE else "add"
|
||||||
|
values_str = ", ".join(self.values)
|
||||||
|
lines.append(f"{op_str} element {obj_id} {{ {values_str} }}")
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
class NftUpdateHandler(UpdateStackHandler[NftUpdate]):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
update_cmd: Sequence[str],
|
||||||
|
table: NftTable,
|
||||||
|
handler: UpdateHandler[None],
|
||||||
|
) -> None:
|
||||||
|
self.update_cmd = update_cmd
|
||||||
|
self.table = table
|
||||||
|
self.handler = handler
|
||||||
|
|
||||||
|
def _update_stack(self, data: Sequence[NftUpdate]) -> None:
|
||||||
|
logger.debug("compile stacked updates for nftables")
|
||||||
|
script = "\n".join(
|
||||||
|
map(
|
||||||
|
lambda u: u.to_script(table=self.table),
|
||||||
|
data,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
logger.debug(f"pass updates to nftables:\n{script}")
|
||||||
|
subprocess.run(
|
||||||
|
list(self.update_cmd) + ["-f", "-"],
|
||||||
|
input=script,
|
||||||
|
check=True,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
self.handler.update(None)
|
||||||
|
|
||||||
|
|
||||||
|
class SystemdHandler(UpdateHandler[object]):
|
||||||
|
def update(self, data: object) -> None:
|
||||||
|
# TODO improve status updates
|
||||||
|
# daemon.notify("READY=1\nSTATUS=Updated successfully.\n")
|
||||||
|
daemon.notify("READY=1\n")
|
||||||
|
|
||||||
|
def update_stack(self, data: Sequence[object]) -> None:
|
||||||
|
self.update(None)
|
||||||
|
|
||||||
|
|
||||||
|
@define(
|
||||||
|
frozen=True,
|
||||||
|
kw_only=True,
|
||||||
|
)
|
||||||
|
class ProtocolConfig:
|
||||||
|
protocol: NftProtocol
|
||||||
|
exposed: Mapping[MACAddress, Sequence[Port]]
|
||||||
|
"only when direct public IPs are available"
|
||||||
|
forwarded: Mapping[MACAddress, Mapping[Port, Port]] # wan -> lan
|
||||||
|
"i.e. DNAT"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(protocol: str, obj: JsonObj) -> ProtocolConfig:
|
||||||
|
assert set(obj.keys()) <= set(("exposed", "forwarded"))
|
||||||
|
exposed_raw = obj.get("exposed")
|
||||||
|
exposed = defaultdict[MACAddress, list[Port]](list)
|
||||||
|
if exposed_raw is not None:
|
||||||
|
assert isinstance(exposed_raw, Sequence)
|
||||||
|
for fwd in exposed_raw:
|
||||||
|
assert isinstance(fwd, Mapping)
|
||||||
|
dest = to_mac(fwd["dest"]) # type: ignore[arg-type]
|
||||||
|
port = to_port(fwd["port"]) # type: ignore[arg-type]
|
||||||
|
exposed[dest].append(port)
|
||||||
|
forwarded_raw = obj.get("forwarded")
|
||||||
|
forwarded = defaultdict[MACAddress, dict[Port, Port]](dict)
|
||||||
|
if forwarded_raw is not None:
|
||||||
|
assert isinstance(forwarded_raw, Sequence)
|
||||||
|
for smap in forwarded_raw:
|
||||||
|
assert isinstance(smap, Mapping)
|
||||||
|
dest = to_mac(smap["dest"]) # type: ignore[arg-type]
|
||||||
|
wanPort = to_port(smap["wanPort"]) # type: ignore[arg-type]
|
||||||
|
lanPort = to_port(smap["lanPort"]) # type: ignore[arg-type]
|
||||||
|
forwarded[dest][wanPort] = lanPort
|
||||||
|
return ProtocolConfig(
|
||||||
|
protocol=NftProtocol(protocol),
|
||||||
|
exposed=exposed,
|
||||||
|
forwarded=forwarded,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@define(
|
||||||
|
frozen=True,
|
||||||
|
kw_only=True,
|
||||||
|
)
|
||||||
|
class InterfaceConfig:
|
||||||
|
ifname: IfName
|
||||||
|
protocols: Sequence[ProtocolConfig]
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def macs(self) -> Sequence[MACAddress]:
|
||||||
|
return tuple(
|
||||||
|
set(
|
||||||
|
chain(
|
||||||
|
(mac for proto in self.protocols for mac in proto.exposed.keys()),
|
||||||
|
(mac for proto in self.protocols for mac in proto.forwarded.keys()),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(ifname: str, obj: JsonObj) -> InterfaceConfig:
|
||||||
|
assert set(obj.keys()) <= set(("ports",))
|
||||||
|
ports = obj.get("ports")
|
||||||
|
assert ports == None or isinstance(ports, Mapping)
|
||||||
|
return InterfaceConfig(
|
||||||
|
ifname=IfName(ifname),
|
||||||
|
protocols=tuple()
|
||||||
|
if ports == None
|
||||||
|
else tuple(
|
||||||
|
ProtocolConfig.from_json(proto, cast(JsonObj, proto_cfg))
|
||||||
|
for proto, proto_cfg in ports.items() # type: ignore[union-attr]
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@define(
|
||||||
|
frozen=True,
|
||||||
|
kw_only=True,
|
||||||
|
)
|
||||||
|
class AppConfig:
|
||||||
|
nft_table: NftTable
|
||||||
|
interfaces: Sequence[InterfaceConfig]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_json(obj: JsonObj) -> AppConfig:
|
||||||
|
assert set(obj.keys()) <= set(("interfaces", "nftTable"))
|
||||||
|
nft_table = obj["nftTable"]
|
||||||
|
assert isinstance(nft_table, str)
|
||||||
|
interfaces = obj["interfaces"]
|
||||||
|
assert isinstance(interfaces, Mapping)
|
||||||
|
return AppConfig(
|
||||||
|
nft_table=NftTable(nft_table),
|
||||||
|
interfaces=tuple(
|
||||||
|
InterfaceConfig.from_json(ifname, cast(JsonObj, if_cfg))
|
||||||
|
for ifname, if_cfg in interfaces.items()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def read_config_file(path: Path) -> AppConfig:
|
||||||
|
with path.open("r") as fh:
|
||||||
|
json_data = json.load(fh)
|
||||||
|
logger.debug(repr(json_data))
|
||||||
|
return AppConfig.from_json(json_data)
|
||||||
|
|
||||||
|
|
||||||
|
LOG_LEVEL_MAP = {
|
||||||
|
"critical": logging.CRITICAL,
|
||||||
|
"error": logging.ERROR,
|
||||||
|
"warning": logging.WARNING,
|
||||||
|
"info": logging.INFO,
|
||||||
|
"debug": logging.DEBUG,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _gen_if_updater(
|
||||||
|
configs: Sequence[InterfaceConfig], nft_updater: UpdateHandler[NftUpdate]
|
||||||
|
) -> Sequence[InterfaceUpdateHandler]:
|
||||||
|
return tuple(
|
||||||
|
InterfaceUpdateHandler(
|
||||||
|
config=if_cfg,
|
||||||
|
nft_handler=nft_updater,
|
||||||
|
)
|
||||||
|
for if_cfg in configs
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def static_part_generation(config: AppConfig) -> None:
|
||||||
|
dummy = IgnoreHandler()
|
||||||
|
if_updater = _gen_if_updater(config.interfaces, dummy)
|
||||||
|
for if_up in if_updater:
|
||||||
|
print(if_up.gen_set_definitions())
|
||||||
|
|
||||||
|
|
||||||
|
def service_execution(args: argparse.Namespace, config: AppConfig) -> NoReturn:
|
||||||
|
nft_updater = NftUpdateHandler(
|
||||||
|
table=config.nft_table,
|
||||||
|
update_cmd=shlex.split(args.nft_command),
|
||||||
|
handler=SystemdHandler(),
|
||||||
|
)
|
||||||
|
if_updater = _gen_if_updater(config.interfaces, nft_updater)
|
||||||
|
burst_handler = UpdateBurstHandler[IpAddressUpdate](
|
||||||
|
burst_interval=0.5,
|
||||||
|
handler=if_updater,
|
||||||
|
)
|
||||||
|
monitor_ip(shlex.split(args.ip_command), burst_handler)
|
||||||
|
|
||||||
|
|
||||||
|
def setup_logging(args: Any) -> None:
|
||||||
|
systemd_service = os.environ.get("INVOCATION_ID") and Path("/dev/log").exists()
|
||||||
|
if systemd_service:
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
logger.addHandler(JournalHandler(SYSLOG_IDENTIFIER="nft-update-addresses"))
|
||||||
|
else:
|
||||||
|
logging.basicConfig() # get output to stdout/stderr
|
||||||
|
logger.setLevel(LOG_LEVEL_MAP[args.log_level])
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("-c", "--config-file", required=True)
|
||||||
|
parser.add_argument("--check-config", action="store_true")
|
||||||
|
parser.add_argument("--output-set-definitions", action="store_true")
|
||||||
|
parser.add_argument("--ip-command", default="/usr/bin/env ip")
|
||||||
|
parser.add_argument("--nft-command", default="/usr/bin/env nft")
|
||||||
|
parser.add_argument(
|
||||||
|
"-l",
|
||||||
|
"--log-level",
|
||||||
|
default="error",
|
||||||
|
choices=LOG_LEVEL_MAP.keys(),
|
||||||
|
help="Log level for outputs to stdout/stderr (ignored when launched in a systemd service)",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
setup_logging(args)
|
||||||
|
config = read_config_file(Path(args.config_file))
|
||||||
|
if args.check_config:
|
||||||
|
return
|
||||||
|
if args.output_set_definitions:
|
||||||
|
return static_part_generation(config)
|
||||||
|
service_execution(args, config)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
Loading…
Reference in New Issue