"""
mod_nftables.py -- nftables config generation and management.

Generates and applies the routlin-nat and routlin-filter tables (plus the
IPv6-only routlin-ban table when IPv6 addresses are banned), manages the NAT
boot service, and handles the banned_ips IP expansion logic.
"""

import ipaddress
import itertools
import json
import subprocess
import sys
from pathlib import Path

import mod_avahi as avahi
import mod_radius as radius
import mod_shared as shared
import mod_wireguard as wireguard
import mod_validation as validation

NAT_SERVICE_NAME = f"{shared.PRODUCT_NAME}-nat"
NAT_SERVICE_FILE = shared.SYSTEMD_DIR / f"{NAT_SERVICE_NAME}.service"

# Every nftables table this module owns, as (family, name) pairs. Shared by
# delete_our_tables() and show_rules() so the two cannot drift apart.
_OUR_TABLES = (
    ("ip", f"{shared.PRODUCT_NAME}-nat"),
    ("ip", f"{shared.PRODUCT_NAME}-filter"),
    ("ip6", f"{shared.PRODUCT_NAME}-ban"),
)

# Maximum number of set elements one wildcard/range banned_ips pattern may
# expand into before the user is told to use CIDR notation instead.
_MAX_BANNED_EXPANSION = 1024


# ===================================================================
# Rule list helpers
# ===================================================================

def rule_enabled(rules):
    """Return the rules whose ``enabled`` value is exactly ``True``."""
    return [r for r in rules if r.get("enabled") is True]


def rule_disabled(rules):
    """Return the complement of rule_enabled(): missing, false or non-bool ``enabled``."""
    return [r for r in rules if r.get("enabled") is not True]


def expand_protocols(rule):
    """Return list of (protocol, rule, comment_suffix) tuples.

    When protocol is 'both', expands into tcp and udp with suffixes
    ' (tcp)' and ' (udp)' so generated comments are unambiguous.
    """
    proto = rule["protocol"]
    if proto == "both":
        return [("tcp", rule, " (tcp)"), ("udp", rule, " (udp)")]
    return [(proto, rule, "")]


def _active_vlans(data, include_down_wg=False):
    """Return the VLANs that should get firewall rules.

    Non-WireGuard VLANs are always included. WireGuard VLANs are included
    only when their interface is up, or unconditionally when
    include_down_wg is true (used for dry runs).
    """
    return [
        v for v in data["vlans"]
        if not validation.is_wg(v)
        or include_down_wg
        or wireguard.wg_interface_up(validation.derive_interface(v, data))
    ]


def _exception_dst(rule):
    """Return the destination address/subnet of an inter_vlan_exceptions rule."""
    return rule.get("dst_ip_or_subnet") or rule.get("dst_ip", "")


def _exception_port_spec(rule):
    """Return the destination port spec of an inter_vlan_exceptions rule.

    Returns "start-end" for a real range, "port" for a single port, or None
    when the rule does not restrict the destination port.
    """
    min_p = rule.get("dest_port_start") or rule.get("dst_port")
    max_p = rule.get("dest_port_end")
    if min_p and max_p and str(min_p) != str(max_p):
        return f"{min_p}-{max_p}"
    if min_p:
        return str(min_p)
    return None


def _iface_for_address(ip_str, vlan_networks, default):
    """Return the interface whose VLAN network contains ip_str.

    ``vlan_networks`` maps interface name -> network. ``default`` is returned
    when no network matches or ip_str is not a valid IPv4 address.
    """
    try:
        addr = ipaddress.IPv4Address(ip_str)
    except ValueError:
        return default
    for iface, net in vlan_networks.items():
        if addr in net:
            return iface
    return default


# ===================================================================
# Container bridge detection
# ===================================================================

def get_container_bridges():
    """Return the names of all bridge interfaces whose operstate is UP.

    This is not limited to container runtimes: any Linux bridge (Docker,
    Podman, LXC, libvirt, or a bridge the router itself uses for the WAN or
    a VLAN) is returned. Rule generation filters out the router's own
    interfaces via _container_bridges().

    Detection is best-effort: if ``ip`` is missing, times out, fails, or
    prints unexpected output, an empty list is returned.
    """
    try:
        result = subprocess.run(
            ["ip", "-j", "link", "show", "type", "bridge"],
            capture_output=True, text=True, timeout=5,
        )
        if result.returncode != 0:
            return []
        links = json.loads(result.stdout)
        return [link["ifname"] for link in links if link.get("operstate") == "UP"]
    except Exception:
        # Deliberately broad: failing to list bridges must never block
        # applying the firewall; it only omits the container-specific rules.
        return []


def _container_bridges(data):
    """Return UP bridges that are neither the WAN nor a configured VLAN interface.

    Without this filter a bridge-backed VLAN would be treated as a container
    bridge, and the forward chain's "VLAN -> bridge" accept rules would open
    traffic between VLANs; a bridge-backed WAN would get DNS (port 53)
    opened to the internet by the input chain.
    """
    own_ifaces = {validation.derive_interface(v, data) for v in data["vlans"]}
    own_ifaces.add(data["network_interfaces"]["wan_interface"])
    return [b for b in get_container_bridges() if b not in own_ifaces]


# ===================================================================
# banned_ips expansion
# ===================================================================

def _expand_banned_ipv4(ip_str):
    """Convert an IPv4 pattern (CIDR, wildcard, range) to nftables set elements.

    Supported formats:
      CIDR     : "10.0.0.0/8"    -- validated, passed through as-is
      Single   : "10.0.0.5"
      Wildcard : "10.1.*.*"      -- trailing '*' octets become a CIDR suffix
                                    ("10.1.0.0/16"; "*.*.*.*" -> "0.0.0.0/0")
      Range    : "10.0.0.10-20"  -- a last-octet range becomes one interval
                                    element "10.0.0.10-10.0.0.20"
      Mixed    : "10.0-3.5.*"    -- non-trailing ranges/wildcards are
                                    enumerated, one element per combination

    Raises:
        ValueError: malformed pattern, octet out of range, or a pattern that
            would expand to more than _MAX_BANNED_EXPANSION elements.
    """
    if '/' in ip_str:
        ipaddress.IPv4Network(ip_str, strict=False)  # validate only
        return [ip_str]

    parts = ip_str.split('.')
    if len(parts) != 4:
        raise ValueError(f"Invalid IPv4 pattern: {ip_str!r} - expected 4 octets")

    def parse_octet(s):
        """Return the inclusive (lo, hi) value range allowed by one octet token."""
        if s == '*':
            return (0, 255)
        if '-' in s:
            a, b = s.split('-', 1)
            lo, hi = int(a), int(b)
            if not (0 <= lo <= hi <= 255):
                raise ValueError(f"Invalid octet range {s!r} in {ip_str!r}")
            return (lo, hi)
        v = int(s)
        if not 0 <= v <= 255:
            raise ValueError(f"Octet value {v} out of range in {ip_str!r}")
        return (v, v)

    ranges = [parse_octet(p) for p in parts]

    # Trailing full-wildcard octets collapse into a single CIDR suffix.
    trailing = 0
    for lo, hi in reversed(ranges):
        if (lo, hi) != (0, 255):
            break
        trailing += 1
    prefix_len = 32 - 8 * trailing

    # Octets enumerated one value at a time. Without trailing wildcards the
    # last octet becomes an "a-b" interval element, so it does not multiply
    # the number of generated entries and must not count toward the limit.
    enumerated = ranges[:4 - trailing] if trailing else ranges[:3]

    total = 1
    for lo, hi in enumerated:
        total *= hi - lo + 1
    if total > _MAX_BANNED_EXPANSION:
        raise ValueError(
            f"Pattern {ip_str!r} would expand to {total} entries "
            f"(limit {_MAX_BANNED_EXPANSION}). Use CIDR notation instead."
        )

    # product() varies the last octet fastest, i.e. ascending address order.
    combos = itertools.product(*(range(lo, hi + 1) for lo, hi in enumerated))

    if trailing:
        # Pad with zero octets, so "*.*.*.*" yields "0.0.0.0/0".
        return [
            '.'.join(map(str, combo + (0,) * trailing)) + f"/{prefix_len}"
            for combo in combos
        ]

    lo4, hi4 = ranges[3]
    results = []
    for combo in combos:
        base = '.'.join(map(str, combo))
        if lo4 == hi4:
            results.append(f"{base}.{lo4}")
        else:
            results.append(f"{base}.{lo4}-{base}.{hi4}")
    return results


def _expand_banned_ipv6(ip_str):
    """Convert an IPv6 pattern (CIDR, single IP, or trailing-wildcard) to nftables set elements.

    Supported formats:
      Single address : "2a01:4f8:c17:b0f::2"   -- passed through as-is
      CIDR           : "2a01:4f8::/32"         -- passed through as-is
      Wildcard       : "2a01:4f8:c17:*"        -- prefix:* expands to a CIDR
                       "2a01:4f8:c17:b00::*"   -- :: compression is supported

    Each explicit prefix group contributes 16 bits to the resulting prefix
    length. Range notation (e.g. "b00-bff") is not supported for IPv6; use
    CIDR instead.

    Raises:
        ValueError: invalid address/network or unsupported wildcard pattern.
    """
    if '/' in ip_str:
        ipaddress.IPv6Network(ip_str, strict=False)  # validate
        return [ip_str]

    if '*' not in ip_str:
        ipaddress.IPv6Address(ip_str)  # validate single address
        return [ip_str]

    if not ip_str.endswith(':*'):
        raise ValueError(
            f"Unsupported IPv6 wildcard pattern {ip_str!r}. "
            f"Use 'prefix:*' (e.g. '2a01:4f8:c17:*') or CIDR notation. "
            f"Range notation (e.g. 'b00-bff') is not supported for IPv6."
        )

    prefix_part = ip_str[:-2]  # strip trailing ':*'

    if '::' in prefix_part:
        left, right = prefix_part.split('::', 1)
        left_groups = [g for g in left.split(':') if g] if left else []
        right_groups = [g for g in right.split(':') if g] if right else []
        # "- 1" reserves the final group for the wildcard itself.
        zero_count = 8 - len(left_groups) - len(right_groups) - 1
        if zero_count < 0:
            raise ValueError(f"IPv6 wildcard pattern {ip_str!r} has too many groups.")
        groups = left_groups + ['0000'] * zero_count + right_groups
    else:
        groups = [g for g in prefix_part.split(':') if g]

    num_groups = len(groups)
    prefix_bits = num_groups * 16

    if num_groups < 1 or num_groups > 7:
        raise ValueError(
            f"IPv6 wildcard pattern {ip_str!r} must have between 1 and 7 "
            f"prefix groups before the wildcard."
        )

    base = ':'.join(groups) + ':' + ':'.join(['0000'] * (8 - num_groups))
    addr = ipaddress.IPv6Address(base)  # also validates each hex group
    return [f"{addr}/{prefix_bits}"]


def expand_banned_ip(ip_str):
    """Return (family, [nftables_elements]) for a banned_ips entry.

    family is 'ipv4' or 'ipv6'; any entry containing ':' is treated as IPv6.
    """
    if ':' in ip_str:
        return ('ipv6', _expand_banned_ipv6(ip_str))
    return ('ipv4', _expand_banned_ipv4(ip_str))


def banned_ip_sets(data):
    """Return (v4_elements, v6_elements) as flat lists of nftables set element strings.

    Only enabled banned_ips entries are used. Exact duplicates are dropped
    (first occurrence wins, order preserved); overlapping entries are left
    for nft's ``auto-merge`` set option to coalesce.
    """
    v4, v6 = {}, {}
    for entry in rule_enabled(data.get("banned_ips", [])):
        family, elements = expand_banned_ip(entry["ip"])
        target = v4 if family == 'ipv4' else v6
        target.update(dict.fromkeys(elements))
    return list(v4), list(v6)


# ===================================================================
# nftables config generation
# ===================================================================

def _banned_set_lines(set_name, addr_type, elements):
    """Return the nft lines declaring an interval set of banned addresses.

    ``auto-merge`` lets nft coalesce overlapping or adjacent elements (e.g. a
    banned /16 plus a single address inside it). Without it nft rejects
    overlapping interval elements and the whole ruleset fails to load.
    """
    return [
        f"    set {set_name} {{",
        f"        type {addr_type}",
        "        flags interval",
        "        auto-merge",
        f"        elements = {{ {', '.join(elements)} }}",
        "    }",
        "",
    ]


def build_nft_config(data, dry_run=False):
    """Return the complete nftables ruleset text for *data*.

    The ruleset contains:
      * ``table ip <product>-nat``: port forwarding / port wrangling DNAT
        and WAN masquerade.
      * ``table ip <product>-filter``: input (policy drop), forward (policy
        drop) and output (policy accept) chains.
      * ``table ip6 <product>-ban``: only when IPv6 addresses are banned.

    With dry_run=True, WireGuard VLANs are included even when their
    interface is down, so the full intended ruleset can be inspected.
    """
    wan = data["network_interfaces"]["wan_interface"]
    vlans = _active_vlans(data, include_down_wg=dry_run)

    all_fwd = rule_enabled(data.get("port_forwarding", []))
    vlan_by_name = {v["name"]: v for v in vlans}
    # Wrangling rules for VLANs that are inactive or unknown are skipped.
    all_wrngl = [(vlan_by_name[r["vlan"]], r)
                 for r in rule_enabled(data.get("port_wrangling", []))
                 if r.get("vlan") in vlan_by_name]

    # Interface -> network, used to pick the egress interface of forwarded ports.
    vlan_networks = {}
    for v in vlans:
        try:
            net = shared.network_for(v)
            vlan_networks[validation.derive_interface(v, data)] = net
        except (KeyError, ValueError):
            pass

    all_except = rule_enabled(data.get("inter_vlan_exceptions", []))
    banned_v4, banned_v6 = banned_ip_sets(data)
    container_bridges = _container_bridges(data)
    mdns_ifaces = avahi.avahi_interfaces(data) if avahi.avahi_enabled(data) else []

    lines = [
        "# Generated by core.py -- do not edit manually.",
        "# Edit config.json and re-run: sudo python3 core.py --apply",
        "",
    ]

    # ======================================================================
    # NAT table
    # ======================================================================
    lines += [
        f"table ip {shared.PRODUCT_NAME}-nat {{",
        "",
        "    chain prerouting {",
        "        type nat hook prerouting priority dstnat - 10; policy accept;",
        "",
    ]

    if all_fwd:
        lines += ["        # -- Port forwarding (inbound WAN -> LAN host) ---------------", ""]
        for rule in all_fwd:
            for proto, r, suffix in expand_protocols(rule):
                lines += [
                    f"        # {r['description']}{suffix}",
                    f"        iif \"{wan}\" {proto} dport {r['dest_port']} dnat to {r['nat_ip']}:{r['nat_port']}",
                    "",
                ]

    if all_wrngl:
        lines += ["        # -- Port wrangling (redirect VLAN traffic to local host) ----", ""]
        for vlan, rule in all_wrngl:
            iface = validation.derive_interface(vlan, data)
            for proto, r, suffix in expand_protocols(rule):
                lines += [
                    f"        # {r['description']}{suffix}",
                    f"        iif \"{iface}\" {proto} dport {r['dest_port']} ip daddr != {r['redirect_to']} dnat to {r['redirect_to']}",
                    "",
                ]

    lines += [
        "    }",
        "",
        "    chain postrouting {",
        "        type nat hook postrouting priority srcnat; policy accept;",
        "",
        "        # Masquerade all outbound traffic through WAN",
        f"        oif \"{wan}\" masquerade",
        "",
        "    }",
        "",
        "}",
        "",
    ]

    # ======================================================================
    # Filter table
    # ======================================================================
    lines += [f"table ip {shared.PRODUCT_NAME}-filter {{", ""]

    if banned_v4:
        lines += _banned_set_lines("banned_ipv4", "ipv4_addr", banned_v4)

    # -- INPUT chain --------------------------------------------------------
    lines += [
        "    # INPUT -- traffic destined for this machine itself",
        "    chain input {",
        "        type filter hook input priority filter; policy drop;",
        "",
    ]

    if banned_v4:
        lines += [
            "        # Drop banned IPs on WAN inbound",
            f"        iif \"{wan}\" ip saddr @banned_ipv4 drop",
            "",
        ]

    lines += [
        "        # Allow loopback",
        "        iif \"lo\" accept",
        "",
        "        # Allow established/related return traffic",
        "        ct state established,related accept",
        "",
        "        # Allow ICMP (ping) from anywhere",
        "        ip protocol icmp accept",
        "",
    ]

    if mdns_ifaces:
        iface_set = ", ".join(f'"{i}"' for i in mdns_ifaces)
        lines += [
            "        # mDNS (port 5353) -- allow on reflection interfaces for avahi",
            f"        iif {{ {iface_set} }} udp dport 5353 accept",
            "",
        ]

    # RADIUS -- must come BEFORE the broad VLAN accept rules below, otherwise
    # any VLAN host could reach port 1812.
    r_clients = radius.radius_clients(data)
    if r_clients:
        allowed_ips = ", ".join(r["ip"] for r, _ in r_clients)
        lines += [
            "        # RADIUS (port 1812) -- allow only designated authenticators",
            f"        ip saddr {{ {allowed_ips} }} udp dport 1812 accept",
            "        udp dport 1812 drop",
            "",
        ]

    if container_bridges:
        bridge_set = ", ".join(f'"{b}"' for b in container_bridges)
        lines += [
            "        # Allow DNS from container bridge networks (Docker, Podman, etc.)",
            f"        iif {{ {bridge_set} }} meta l4proto {{ tcp, udp }} th dport 53 accept",
            "",
        ]

    lines.append("        # Allow all traffic inbound from any VLAN interface")
    for vlan in vlans:
        lines.append(f"        iif \"{validation.derive_interface(vlan, data)}\" accept # {vlan['name']}")
    lines.append("")

    if all_fwd:
        lines += ["        # Allow inbound WAN access for port-forwarded services", ""]
        for rule in all_fwd:
            for proto, r, suffix in expand_protocols(rule):
                lines += [
                    f"        # {r['description']}{suffix}",
                    f"        iif \"{wan}\" {proto} dport {r['dest_port']} accept",
                    "",
                ]

    # Anything not accepted above hits the chain's "policy drop".
    lines += ["        # Drop all other inbound WAN traffic", "    }", ""]

    # -- FORWARD chain ------------------------------------------------------
    lines += [
        "    # FORWARD -- traffic being routed through this machine",
        "    chain forward {",
        "        type filter hook forward priority filter; policy drop;",
        "",
    ]

    if banned_v4:
        lines += [
            "        # Drop banned IPs on WAN inbound",
            f"        iif \"{wan}\" ip saddr @banned_ipv4 drop",
            "",
        ]

    lines += [
        "        # Allow established/related return traffic",
        "        ct state established,related accept",
        "",
    ]

    lines.append("        # Allow each VLAN -> WAN (outbound internet)")
    for vlan in vlans:
        lines.append(f"        iif \"{validation.derive_interface(vlan, data)}\" oif \"{wan}\" accept # {vlan['name']} -> WAN")
    lines.append("")

    if container_bridges:
        lines.append("        # Allow VLAN -> Docker bridge forwarding")
        for vlan in vlans:
            iface = validation.derive_interface(vlan, data)
            for bridge in container_bridges:
                lines.append(f"        iif \"{iface}\" oif \"{bridge}\" ct state new accept"
                             f" # {vlan['name']} -> {bridge}")
        lines.append("")
        lines += [
            "        # Allow Docker containers -> WAN (outbound internet access)",
            f"        iif != \"{wan}\" oif \"{wan}\" ct state new accept",
            "",
        ]

    # Forwarding mDNS only makes sense between two or more interfaces.
    if len(mdns_ifaces) > 1:
        iface_set = ", ".join(f'"{i}"' for i in mdns_ifaces)
        lines += [
            "        # mDNS forwarding between reflection interfaces for avahi",
            f"        iif {{ {iface_set} }} oif {{ {iface_set} }} udp dport 5353 accept",
            "",
        ]

    if all_except:
        lines += ["        # -- Inter-VLAN exceptions ------------------------------------------", ""]
        for r in all_except:
            src = r["src_ip_or_subnet"]
            dst = _exception_dst(r)
            port_spec = _exception_port_spec(r)
            for proto, _, suffix in expand_protocols(r):
                lines.append(f"        # {r['description']}{suffix}")
                if port_spec is not None:
                    lines.append(f"        ip saddr {src} ip daddr {dst} {proto} dport {port_spec} ct state new accept")
                else:
                    lines.append(f"        ip saddr {src} ip daddr {dst} ip protocol {proto} ct state new accept")
                lines.append("")

    if all_fwd:
        lines += ["        # Allow inbound WAN -> VLAN for active port forwarding rules", ""]
        for rule in all_fwd:
            # Egress is the VLAN whose subnet holds nat_ip; falls back to the WAN.
            iface = _iface_for_address(rule["nat_ip"], vlan_networks, wan)
            for proto, r, suffix in expand_protocols(rule):
                lines += [
                    f"        # {r['description']}{suffix}",
                    f"        iif \"{wan}\" oif \"{iface}\" {proto} dport {r['nat_port']} ip daddr {r['nat_ip']} ct state new accept",
                    "",
                ]

    lines += [
        "    }",
        "",
        "    chain output {",
        "        type filter hook output priority filter; policy accept;",
        "    }",
        "",
        "}",
    ]

    # ======================================================================
    # IPv6 ban table (only when IPv6 addresses are banned)
    # ======================================================================
    if banned_v6:
        lines += ["", f"table ip6 {shared.PRODUCT_NAME}-ban {{", ""]
        lines += _banned_set_lines("banned_ipv6", "ipv6_addr", banned_v6)
        lines += [
            "    chain input {",
            "        type filter hook input priority filter; policy accept;",
            f"        iif \"{wan}\" ip6 saddr @banned_ipv6 drop",
            "    }",
            "",
            "    chain forward {",
            "        type filter hook forward priority filter; policy accept;",
            f"        iif \"{wan}\" ip6 saddr @banned_ipv6 drop",
            "    }",
            "",
            "}",
        ]

    return "\n".join(lines)


# ===================================================================
# nftables apply / disable / status
# ===================================================================

def table_exists(family, name):
    """Return True if ``nft list table <family> <name>`` succeeds."""
    result = subprocess.run(
        ["nft", "list", "table", family, name],
        capture_output=True, text=True,
    )
    return result.returncode == 0


def delete_our_tables():
    """Delete all routlin-owned nftables tables.

    Tables that do not exist are skipped. Returns an error string for the
    first table that fails to delete, or None on success.
    """
    for family, table in _OUR_TABLES:
        if not table_exists(family, table):
            print(f"Table not present, skipping delete: {family} {table}")
            continue
        result = subprocess.run(
            ["nft", "delete", "table", family, table],
            capture_output=True, text=True,
        )
        if result.returncode != 0:
            return f"Failed to delete table {family} {table}: {result.stderr.strip()}"
        print(f"Removed existing table: {family} {table}")
    return None


def apply_nft_config(config_text):
    """Load *config_text* with ``nft -f -``; exit with status 1 if nft rejects it."""
    result = subprocess.run(
        ["nft", "-f", "-"],
        input=config_text, capture_output=True, text=True,
    )
    if result.returncode != 0:
        print("ERROR: nft rejected the ruleset:", file=sys.stderr)
        print(result.stderr, file=sys.stderr)
        sys.exit(1)


def apply_nftables(data, dry_run=False):
    """Build the ruleset from *data* and load it into the kernel.

    With dry_run=True the generated ruleset is printed and nothing changes.
    Otherwise the module's existing tables are deleted, the new ruleset is
    loaded, and a summary of the active rules is printed. Exits the process
    with status 1 if deleting the old tables or loading the new ones fails.
    """
    config = build_nft_config(data, dry_run=dry_run)
    if dry_run:
        print(config)
        return

    active_vlans = _active_vlans(data)
    active_vlan_by_name = {v["name"]: v for v in active_vlans}

    port_forwarding = data.get("port_forwarding", [])
    all_fwd = rule_enabled(port_forwarding)
    all_dis_fwd = rule_disabled(port_forwarding)

    port_wrangling = data.get("port_wrangling", [])
    all_wrngl = [(active_vlan_by_name[r["vlan"]], r)
                 for r in rule_enabled(port_wrangling)
                 if r.get("vlan") in active_vlan_by_name]
    all_dis_wrngl = rule_disabled(port_wrangling)

    all_except = rule_enabled(data.get("inter_vlan_exceptions", []))

    print(f"Applying {len(all_fwd)} port forwarding rule(s), {len(all_dis_fwd)} skipped.")
    print(f"Applying {len(all_wrngl)} port wrangling rule(s), {len(all_dis_wrngl)} skipped.")
    print(f"Applying {len(all_except)} inter-VLAN exception(s).")
    container_bridges = _container_bridges(data)
    if container_bridges:
        print(f"Container bridges: {', '.join(container_bridges)}")
    print()

    # Loading on top of tables that failed to delete would duplicate rules,
    # so stop here instead of continuing with a half-replaced firewall.
    # NOTE(review): delete + load is two steps, so there is a short window
    # with no tables installed; an atomic replace in a single `nft -f`
    # transaction would close it.
    error = delete_our_tables()
    if error:
        print(f"ERROR: {error}", file=sys.stderr)
        sys.exit(1)
    apply_nft_config(config)
    print("nftables rules applied successfully.")

    active_subnets = []
    for v in active_vlans:
        try:
            active_subnets.append(shared.network_for(v))
        except (KeyError, ValueError):
            pass

    def dst_is_active(r):
        """True if the exception's destination lies in an active VLAN (or can't be parsed)."""
        dst = _exception_dst(r)
        try:
            addr = ipaddress.IPv4Address(dst)
        except ValueError:
            try:
                net = ipaddress.IPv4Network(dst, strict=False)
            except ValueError:
                return True
            return any(net.overlaps(s) for s in active_subnets)
        return any(addr in net for net in active_subnets)

    if all_fwd:
        print()
        print("Active port forwarding:")
        for r in all_fwd:
            print(f"  [{r['protocol'].upper():<4}] :{r['dest_port']} -> {r['nat_ip']}:{r['nat_port']} ({r['description']})")

    if all_dis_fwd:
        print()
        print("Skipped port forwarding (disabled):")
        for r in all_dis_fwd:
            print(f"  [{r['protocol'].upper():<4}] :{r['dest_port']} -> {r['nat_ip']}:{r['nat_port']} ({r['description']})")

    if all_wrngl:
        print()
        print("Active port wrangling:")
        for vlan, r in all_wrngl:
            print(f"  [{r['protocol'].upper():<4}] :{r['dest_port']} -> {r['redirect_to']} ({r['description']}) [{vlan['name']}]")

    active_except = [r for r in all_except if dst_is_active(r)]
    if active_except:
        print()
        print("Active inter-VLAN exceptions:")
        for r in active_except:
            port_spec = _exception_port_spec(r)
            port_str = f":{port_spec}" if port_spec is not None else ""
            dst_str = f"{_exception_dst(r)}{port_str}"
            print(f"  [{r['protocol'].upper():<4}] {r['src_ip_or_subnet']} -> {dst_str} ({r['description']})")


def show_rules():
    """Print the live contents of every table this module manages."""
    for family, table in _OUR_TABLES:
        result = subprocess.run(
            ["nft", "list", "table", family, table],
            capture_output=True, text=True,
        )
        if result.returncode == 0:
            print(result.stdout)
        elif family == "ip6":
            # The ban table only exists while IPv6 addresses are banned.
            print(f"[{table}] not found (no IPv6 bans applied)")
        else:
            print(f"[{table}] not found (not yet applied)")


# ===================================================================
# NAT boot service
# ===================================================================

def install_nat_service():
    """Install (or update) and enable the systemd unit that re-applies rules at boot.

    The unit file is rewritten, and systemd reloaded/enabled, only when its
    content differs from what is already on disk.
    """
    script_path = shared.SCRIPT_DIR / "core.py"
    # NOTE(review): Restart=on-failure with Type=oneshot is only accepted by
    # newer systemd releases (v244+); older ones reject the unit.
    service_content = f"""[Unit]
Description=Apply {shared.PRODUCT_NAME} NAT and firewall rules
After=network-online.target docker.service
Wants=network-online.target docker.service

[Service]
Type=oneshot
ExecStart=/usr/bin/python3 {script_path} --apply
RemainAfterExit=yes
Restart=on-failure
RestartSec=5s

[Install]
WantedBy=multi-user.target
"""
    existing = NAT_SERVICE_FILE.read_text() if NAT_SERVICE_FILE.exists() else None
    if existing == service_content:
        print(f"Boot service already up to date: {NAT_SERVICE_FILE}")
        return
    NAT_SERVICE_FILE.write_text(service_content)
    subprocess.run(["systemctl", "daemon-reload"], check=True)
    subprocess.run(["systemctl", "enable", NAT_SERVICE_NAME], check=True)
    if existing is None:
        print(f"Boot service installed and enabled: {NAT_SERVICE_FILE}")
    else:
        print(f"Boot service updated: {NAT_SERVICE_FILE}")


def remove_nat_service():
    """Disable, stop and delete the boot service unit if it is installed.

    systemctl failures are ignored (output captured) so removal is best-effort.
    """
    if not NAT_SERVICE_FILE.exists():
        print(f"Boot service not found, skipping: {NAT_SERVICE_NAME}.service")
        return
    subprocess.run(["systemctl", "disable", "--now", NAT_SERVICE_NAME],
                   capture_output=True, text=True)
    NAT_SERVICE_FILE.unlink()
    subprocess.run(["systemctl", "daemon-reload"], capture_output=True, text=True)
    print(f"Removed boot service: {NAT_SERVICE_NAME}.service")