Development

This commit is contained in:
Matthew Grotke 2026-06-09 09:54:47 -04:00
parent 49dd4a2cf8
commit e33133df1e
9 changed files with 412 additions and 414 deletions

View file

@ -6,9 +6,11 @@ generation, applying/reloading instances, system service conflict resolution
(systemd-resolved, dnsmasq, chrony, ufw), and DHCP lease display.
"""
import hashlib
import json
import logging
import sqlite3
import subprocess
import time
from datetime import datetime
from pathlib import Path
@ -17,29 +19,303 @@ import mod_wireguard as wireguard
import mod_validation as validation
BLOCKLIST_DIR = shared.SCRIPT_DIR / "blocklists"
DB_FILE = BLOCKLIST_DIR / "domains.db"
LOG_FILE = shared.SCRIPT_DIR / "blocklists.log"
RESOLV_CONF = Path("/etc/resolv.conf")
_log = logging.getLogger("blocklists")
# ===================================================================
# Blocklist management
# ===================================================================
def combo_hash(names):
"""Return a stable 8-char hex hash for a list/set of blocklist names."""
key = ",".join(sorted(names))
return hashlib.sha256(key.encode()).hexdigest()[:8]
def vlan_hosts_file(vlan):
"""Stable per-VLAN hosts file path (always the same regardless of blocklist combo)."""
return BLOCKLIST_DIR / f"for-{vlan['name']}.hosts"
def merged_path(h):
return BLOCKLIST_DIR / f"merged-{h}.conf"
def blocklists_available(data):
"""Return True if at least one merged blocklist file exists on disk."""
combos = set()
"""Return True if at least one per-VLAN hosts file is non-empty."""
for vlan in data.get("vlans", []):
names = vlan.get("use_blocklists", [])
if names:
combos.add(combo_hash(names))
return any(merged_path(h).exists() for h in combos)
if vlan.get("use_blocklists"):
f = vlan_hosts_file(vlan)
if f.exists() and f.stat().st_size > 0:
return True
return False
# ===================================================================
# Blocklist parse / detect
# ===================================================================
def _parse_dnsmasq_format(content):
domains = set()
for ln in content.splitlines():
ln = ln.strip()
if not ln or ln.startswith("#"):
continue
if ln.startswith("local=/"):
domain = ln.removeprefix("local=/").rstrip("/")
if domain:
domains.add(domain)
elif ln.startswith("address=/"):
parts = ln.removeprefix("address=/").split("/")
if parts:
domains.add(parts[0])
return domains
def _parse_hosts_format(content):
domains = set()
for ln in content.splitlines():
ln = ln.strip()
if not ln or ln.startswith("#"):
continue
parts = ln.split()
if len(parts) >= 2:
domains.add(parts[1])
return domains
def _parse_local_format(content):
domains = set()
for ln in content.splitlines():
ln = ln.strip()
if ln and not ln.startswith("#"):
domains.add(ln)
return domains
def _detect_format(content):
for ln in content.splitlines():
ln = ln.strip()
if not ln or ln.startswith("#"):
continue
if ln.startswith("local=/") or ln.startswith("address=/"):
return "dnsmasq"
if ln[0].isdigit():
return "hosts"
return "dnsmasq"
def _parse_blocklist(content, is_local=False):
if is_local:
return _parse_local_format(content)
fmt = _detect_format(content)
if fmt == "dnsmasq":
return _parse_dnsmasq_format(content)
return _parse_hosts_format(content)
# ===================================================================
# Blocklist SQLite
# ===================================================================
def _open_db():
db = sqlite3.connect(DB_FILE)
db.execute("PRAGMA journal_mode=WAL")
db.execute("PRAGMA foreign_keys=ON")
db.executescript("""
CREATE TABLE IF NOT EXISTS blocklists (
id INTEGER PRIMARY KEY,
name TEXT UNIQUE NOT NULL,
mtime REAL,
fetched_at INTEGER,
domain_count INTEGER
);
CREATE TABLE IF NOT EXISTS domains (
domain TEXT NOT NULL,
blocklist_id INTEGER NOT NULL REFERENCES blocklists(id) ON DELETE CASCADE,
PRIMARY KEY (domain, blocklist_id)
);
CREATE INDEX IF NOT EXISTS idx_domains_domain ON domains(domain);
""")
db.commit()
return db
def _get_stored_mtime(db, name):
row = db.execute("SELECT mtime FROM blocklists WHERE name = ?", (name,)).fetchone()
return row[0] if row else None
def _upsert_blocklist(db, name, domains, mtime):
now = int(time.time())
db.execute("""
INSERT INTO blocklists (name, mtime, fetched_at, domain_count)
VALUES (?, ?, ?, ?)
ON CONFLICT(name) DO UPDATE SET
mtime = excluded.mtime,
fetched_at = excluded.fetched_at,
domain_count = excluded.domain_count
""", (name, mtime, now, len(domains)))
bl_id = db.execute("SELECT id FROM blocklists WHERE name = ?", (name,)).fetchone()[0]
db.execute("DELETE FROM domains WHERE blocklist_id = ?", (bl_id,))
db.executemany("INSERT INTO domains (domain, blocklist_id) VALUES (?, ?)",
((d, bl_id) for d in domains))
db.commit()
def _query_merged_domains(db, names):
placeholders = ",".join("?" * len(names))
rows = db.execute(f"""
SELECT DISTINCT d.domain
FROM domains d
JOIN blocklists b ON d.blocklist_id = b.id
WHERE b.name IN ({placeholders})
ORDER BY d.domain
""", list(names)).fetchall()
return [r[0] for r in rows]
# ===================================================================
# Blocklist merge and SIGHUP
# ===================================================================
def _build_merged_hosts(domains, bl_names):
lines = [
"# Generated by core.py -- do not edit manually.",
f"# Blocklists: {', '.join(sorted(bl_names))}",
f"# Domains: {len(domains):,}",
"",
]
for domain in domains:
lines.append(f"0.0.0.0 {domain}")
return "\n".join(lines) + "\n"
def setup_blocklist_logging(general):
"""Configure file + stdout logging for blocklists logger."""
max_kb = general.get("log_max_kb", 1024)
errors_only = general.get("log_errors_only", False)
try:
if LOG_FILE.exists() and LOG_FILE.stat().st_size > max_kb * 1024:
LOG_FILE.write_text("")
if not LOG_FILE.exists():
LOG_FILE.touch()
file_handler = logging.FileHandler(LOG_FILE)
except PermissionError:
print(f"WARNING: Cannot write to {LOG_FILE} -- run with sudo.")
file_handler = None
level = logging.ERROR if errors_only else logging.INFO
handlers = [logging.StreamHandler()]
if file_handler:
handlers.insert(0, file_handler)
logging.basicConfig(
level=level,
format="%(asctime)s %(levelname)-8s %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
handlers=handlers,
force=True,
)
def update_blocklist_hosts(data):
"""Parse downloaded/local blocklist files, upsert into SQLite, and write
per-VLAN hosts files (0.0.0.0 format). Always writes every VLAN's file
(empty if no blocklists assigned) so addn-hosts= in the dnsmasq conf is
always valid and SIGHUP can update it without a restart.
Returns True on full success, False if any fetch/parse failed.
"""
BLOCKLIST_DIR.mkdir(exist_ok=True)
db = _open_db()
bl_library = {bl["name"]: bl for bl in data.get("dns_blocking", {}).get("blocklists", [])}
needed = set()
for vlan in data.get("vlans", []):
needed.update(vlan.get("use_blocklists", []))
changed = set()
any_fail = False
for name in needed:
if name not in bl_library:
_log.warning(f"Blocklist '{name}' referenced by a VLAN but not defined -- skipping")
continue
entry = bl_library[name]
is_local = entry.get("bl_type") == "local"
save_as = entry.get("save_as", "")
try:
path = BLOCKLIST_DIR / save_as if save_as else None
if not path or not path.exists():
_log.warning(f"'{name}': file not found ({path}) -- skipping")
any_fail = True
continue
current_mtime = path.stat().st_mtime
except Exception as e:
_log.error(f"Failed to stat '{name}': {e}")
any_fail = True
continue
if current_mtime == _get_stored_mtime(db, name):
_log.info(f"Unchanged: '{name}' -- skipping")
continue
try:
raw = path.read_text("utf-8", errors="ignore")
except Exception as e:
_log.error(f"Failed to read '{name}': {e}")
any_fail = True
continue
domains = _parse_blocklist(raw, is_local=is_local)
_upsert_blocklist(db, name, domains, current_mtime)
_log.info(f"Updated '{name}': {len(domains):,} domains")
changed.add(name)
active_vlan_names = set()
for vlan in data.get("vlans", []):
vlan_name = vlan["name"]
active_vlan_names.add(vlan_name)
bl_names = [n for n in vlan.get("use_blocklists", []) if n in bl_library]
hosts_file = vlan_hosts_file(vlan)
if not bl_names:
if not hosts_file.exists():
hosts_file.write_text("")
continue
if not changed.intersection(bl_names) and hosts_file.exists():
_log.info(f"VLAN '{vlan_name}' blocklists unchanged -- skipping rewrite")
continue
domains = _query_merged_domains(db, bl_names)
hosts_file.write_text(_build_merged_hosts(domains, bl_names))
_log.info(f"VLAN '{vlan_name}': wrote {len(domains):,} domains from [{', '.join(sorted(bl_names))}]")
for f in BLOCKLIST_DIR.glob("for-*.hosts"):
vlan_name = f.stem.removeprefix("for-")
if vlan_name not in active_vlan_names:
f.unlink()
_log.info(f"Removed stale hosts file: {f.name}")
db.close()
return not any_fail
def sighup_all_instances():
"""Send SIGHUP to every active dnsmasq-routlin-* instance to reload addn-hosts
files without restarting. No DNS or DHCP interruption."""
result = subprocess.run(
["systemctl", "list-units", "--state=active", "--no-legend", "--plain",
f"dnsmasq-{shared.PRODUCT_NAME}-*.service"],
capture_output=True, text=True,
)
units = [line.split()[0] for line in result.stdout.splitlines() if line.strip()]
if not units:
print(" No active dnsmasq instances found.")
return
for unit in units:
r = subprocess.run(["systemctl", "kill", "--signal=SIGHUP", unit],
capture_output=True, text=True)
if r.returncode == 0:
print(f" Reloaded: {unit}")
else:
print(f" WARNING: Failed to SIGHUP {unit}: {r.stderr.strip()}")
# ===================================================================
@ -95,12 +371,7 @@ def build_vlan_dnsmasq_conf(vlan, data, iface):
opts = shared.resolve_vlan_options(vlan)
gateway = opts["gateway"]
bl_names = vlan.get("use_blocklists", [])
bl_file = None
if bl_names:
p = merged_path(combo_hash(bl_names))
if p.exists():
bl_file = p
hosts_file = vlan_hosts_file(vlan)
L = [
"# Generated by core.py -- do not edit manually.",
@ -216,14 +487,12 @@ def build_vlan_dnsmasq_conf(vlan, data, iface):
for o in overrides:
L += [f"# {o['description']}", f"address=/{o['host']}/{o['ip']}", ""]
if bl_file:
if hosts_file.exists():
L += [
"# -- Blocklist ------------------------------------------------------",
f"conf-file={bl_file}",
f"addn-hosts={hosts_file}",
"",
]
elif bl_names:
L += ["# Blocklist not yet downloaded -- run: sudo python3 dns-blocklists.py", ""]
return "\n".join(L)
@ -419,6 +688,9 @@ def apply_dnsmasq_instances(data, dry_run=False, start_if_needed=True):
if not dry_run:
shared.DNSMASQ_CONF_DIR.mkdir(exist_ok=True)
print("Updating blocklist hosts files ======================================")
update_blocklist_hosts(data)
print()
disable_system_dnsmasq(data)
print()