164 lines
4.6 KiB
Python
164 lines
4.6 KiB
Python
"""
validation.py -- Shared structural validators for core.json fields.

Lives alongside core.py in ~/router/ and is volume-mounted into the
router-dash container at /configs/validation.py. It is importable both by
core.py (on the router host) and by the Flask app (via validate.py, which
adds /configs to sys.path).

Convention: each function accepts a raw string and returns the
normalised valid value, or '' if the input is invalid.
"""
|
|
import ipaddress
|
|
import re
|
|
|
|
# Permitted values for protocol fields.
VALID_PROTOCOLS = set(('tcp', 'udp', 'both'))

# Permitted values for blocklist format fields.
VALID_BLOCKLIST_FORMATS = set(('dnsmasq', 'hosts'))
|
|
|
|
|
|
# ===================================================================
|
|
# IP / CIDR
|
|
# ===================================================================
|
|
|
|
def ip(value):
    """Validate a single IPv4 or IPv6 address.

    Surrounding whitespace is stripped before parsing with
    ``ipaddress.ip_address``. Returns the stripped string when it parses,
    and '' for falsy or unparseable input (CIDR notation is rejected).
    """
    candidate = str(value).strip() if value else ''
    if not candidate:
        return ''
    try:
        ipaddress.ip_address(candidate)
    except ValueError:
        return ''
    return candidate
|
|
|
|
|
|
def ip_or_cidr(value):
    """Validate an IPv4/IPv6 address or CIDR network.

    The stripped input is tried first as a plain address, then as a
    network with ``strict=False`` (so host bits such as ``192.0.2.1/24``
    are tolerated). Returns the stripped string on the first successful
    parse, or '' if the input is falsy or neither parser accepts it.
    """
    if not value:
        return ''
    text = str(value).strip()
    parsers = (
        ipaddress.ip_address,
        lambda s: ipaddress.ip_network(s, strict=False),
    )
    for parse in parsers:
        try:
            parse(text)
        except ValueError:
            continue
        return text
    return ''
|
|
|
|
|
|
# ===================================================================
|
|
# Port
|
|
# ===================================================================
|
|
|
|
def port(value):
    """Return port as string if valid 1-65535, else ''.

    Accepts an ``int`` or a string made only of ASCII decimal digits
    (surrounding whitespace is ignored). The result is normalised through
    ``int``, so leading zeros are dropped ("080" -> "80").

    Signs, decimal points, ranges, embedded letters and bools are rejected
    outright. Non-digit characters are not stripped first: stripping would
    silently turn "-1" into "1", "1-100" into "1100" and 80.0 into "800".
    """
    if isinstance(value, bool):
        # bool subclasses int; True must not be read as port 1.
        return ''
    if isinstance(value, int):
        p = value
    else:
        s = str(value).strip()
        if not re.fullmatch(r'[0-9]+', s):
            return ''
        try:
            p = int(s)
        except ValueError:
            # Only reachable for absurdly long digit strings that exceed
            # Python's int-from-str digit limit (3.11+).
            return ''
    return str(p) if 1 <= p <= 65535 else ''
|
|
|
|
|
|
# ===================================================================
|
|
# Banned-IP pattern
|
|
# ===================================================================
|
|
|
|
def banned_ip(value):
    """
    Validate a banned_ip pattern, returning it stripped or '' if invalid.

    Accepted formats (mirrors core.py expand_banned_ip):
      IPv4:
        Single address     192.0.2.1
        CIDR               192.0.2.0/24
        Wildcard octet     192.0.2.*
        Octet range        192.0.2.10-20
        (combinations that expand to <=1024 entries are accepted)
      IPv6:
        Single address     2001:db8::1
        CIDR               2001:db8::/32
        Trailing wildcard  2001:db8:c17:*

    Any ValueError or TypeError raised by the pattern checker is treated
    as "invalid" and mapped to ''.
    """
    if value:
        candidate = str(value).strip()
        try:
            _check_banned_ip(candidate)
        except (ValueError, TypeError):
            pass
        else:
            return candidate
    return ''
|
|
|
|
|
|
def _check_banned_ip(ip_str):
    """Route a banned_ip pattern to the IPv6 or IPv4 checker.

    Any string containing a colon is handed to the IPv6 checker; all other
    strings go to the IPv4 checker. The chosen checker raises on invalid
    input; nothing is returned.
    """
    checker = _check_banned_ipv6 if ':' in ip_str else _check_banned_ipv4
    checker(ip_str)
|
|
|
|
|
|
def _check_banned_ipv4(ip_str):
|
|
if '/' in ip_str:
|
|
ipaddress.IPv4Network(ip_str, strict=False)
|
|
return
|
|
|
|
parts = ip_str.split('.')
|
|
if len(parts) != 4:
|
|
raise ValueError(f"Expected 4 octets: {ip_str!r}")
|
|
|
|
def parse_octet(s):
|
|
if s == '*':
|
|
return (0, 255)
|
|
if '-' in s:
|
|
a, b = s.split('-', 1)
|
|
lo, hi = int(a), int(b)
|
|
if not (0 <= lo <= hi <= 255):
|
|
raise ValueError(f"Invalid octet range {s!r}")
|
|
return (lo, hi)
|
|
v = int(s)
|
|
if not 0 <= v <= 255:
|
|
raise ValueError(f"Octet {v} out of 0-255")
|
|
return (v, v)
|
|
|
|
ranges = [parse_octet(p) for p in parts]
|
|
|
|
trailing = 0
|
|
for lo, hi in reversed(ranges):
|
|
if lo == 0 and hi == 255:
|
|
trailing += 1
|
|
else:
|
|
break
|
|
|
|
total = 1
|
|
for lo, hi in ranges[:4 - trailing]:
|
|
total *= (hi - lo + 1)
|
|
if total > 1024:
|
|
raise ValueError(f"Pattern expands to {total} entries (limit 1024); use CIDR")
|
|
|
|
|
|
def _check_banned_ipv6(ip_str):
|
|
if '/' in ip_str:
|
|
ipaddress.IPv6Network(ip_str, strict=False)
|
|
return
|
|
if '*' not in ip_str:
|
|
ipaddress.IPv6Address(ip_str)
|
|
return
|
|
if not ip_str.endswith(':*'):
|
|
raise ValueError(f"Unsupported IPv6 wildcard: {ip_str!r}; use 'prefix:*' or CIDR")
|
|
prefix_part = ip_str[:-2]
|
|
if '::' in prefix_part:
|
|
left, right = prefix_part.split('::', 1)
|
|
lg = [g for g in left.split(':') if g] if left else []
|
|
rg = [g for g in right.split(':') if g] if right else []
|
|
zeros = 8 - len(lg) - len(rg) - 1
|
|
if zeros < 0:
|
|
raise ValueError(f"Too many groups in {ip_str!r}")
|
|
groups = lg + ['0000'] * zeros + rg
|
|
else:
|
|
groups = [g for g in prefix_part.split(':') if g]
|
|
if not (1 <= len(groups) <= 7):
|
|
raise ValueError(f"IPv6 wildcard must have 1-7 prefix groups: {ip_str!r}")
|