364 lines
11 KiB
Python
364 lines
11 KiB
Python
#!/usr/bin/env python3
"""
dns-blocklists.py -- Download and merge DNS blocklists defined in config.json.

Reads the blocklists library from config.json, downloads (or, for local lists,
reads) every blocklist referenced by at least one VLAN, and upserts the parsed
domains into a SQLite database (blocklists/domains.db). A blocklist whose content
hash matches the stored hash is still fetched, but is not re-parsed or re-stored.

Merged per-combo conf files (one per distinct set of blocklists used by a VLAN)
are only rewritten when a constituent blocklist changed or the file is missing;
merged files no longer used by any VLAN are deleted.

Sends SIGHUP to each running dnsmasq instance so it reloads without restarting.
NOTE(review): the dnsmasq man page states that SIGHUP does not re-read the
configuration file / conf-file includes (only hosts-style files and
--servers-file) -- confirm how merged-*.conf is loaded so new entries take effect.

Usage:
    sudo python3 dns-blocklists.py
"""
|
|
|
|
import hashlib
import ipaddress
import json
import logging
import os
import sqlite3
import subprocess
import sys
import time
import urllib.error
import urllib.request
from pathlib import Path
|
|
|
|
# Used to build the systemd unit pattern dnsmasq-<PRODUCT_NAME>-*.service.
PRODUCT_NAME = "routlin"

# Every path below is anchored at the directory containing this script.
SCRIPT_DIR = Path(__file__).parent
CONFIG_FILE = SCRIPT_DIR / "config.json"

# Holds local blocklist files (save_as), the SQLite DB and the merged-*.conf output.
BLOCKLIST_DIR = SCRIPT_DIR / "blocklists"
DB_FILE = BLOCKLIST_DIR / "domains.db"
LOG_FILE = SCRIPT_DIR / "dns-blocklists.log"

# Module logger; stays None until setup_logging() runs, so anything that calls
# log.info()/log.error() must run after it.
log = None
|
|
|
|
|
|
def _chown_to_script_dir_owner(path):
    """Best-effort: give *path* the same uid/gid as SCRIPT_DIR.

    The script runs as root, so files it creates would otherwise be
    root-owned; copying the script directory's ownership keeps them in the
    hands of that directory's owner.

    Only OSError (missing file, permission problems, unsupported filesystem)
    is ignored: ownership is cosmetic here and must never abort an update,
    but genuine programming errors should still surface.
    """
    try:
        st = SCRIPT_DIR.stat()
        os.chown(path, st.st_uid, st.st_gid)
    except OSError:
        pass
|
|
|
|
|
|
def setup_logging(max_kb, errors_only):
    """Configure the module-level ``log`` logger (stdout, plus LOG_FILE when writable).

    LOG_FILE is emptied once it has grown past *max_kb* KiB, created if it is
    missing, and chowned to the owner of SCRIPT_DIR. If the log file cannot be
    prepared for any filesystem reason, a warning is printed and logging
    continues on stdout only -- a logging problem must not stop the update.

    Args:
        max_kb: size limit in KiB; a larger LOG_FILE is truncated before use.
        errors_only: if true, only ERROR and above are emitted; otherwise INFO.
    """
    global log
    file_handler = None
    try:
        if LOG_FILE.exists() and LOG_FILE.stat().st_size > max_kb * 1024:
            LOG_FILE.write_text("")
        if not LOG_FILE.exists():
            LOG_FILE.touch()
        # Idempotent and cheap, so it is applied on every run.
        _chown_to_script_dir_owner(LOG_FILE)
        file_handler = logging.FileHandler(LOG_FILE)
    except PermissionError:
        print(f"WARNING: Cannot write to {LOG_FILE} (permission denied). "
              f"Run with sudo or fix ownership: sudo chown $USER {LOG_FILE}")
    except OSError as e:
        # Read-only filesystem, disk full, LOG_FILE is a directory, ...
        print(f"WARNING: Cannot write to {LOG_FILE} ({e}). Logging to stdout only.")

    level = logging.ERROR if errors_only else logging.INFO
    handlers = [logging.StreamHandler(sys.stdout)]
    if file_handler:
        # File handler first, then stdout.
        handlers.insert(0, file_handler)
    logging.basicConfig(
        level=level,
        format="%(asctime)s %(levelname)-8s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=handlers,
    )
    log = logging.getLogger("dns-blocklists")
|
|
|
|
|
|
def die(msg):
    """Report *msg* on stderr as ``ERROR: <msg>`` and terminate with exit status 1."""
    sys.stderr.write(f"ERROR: {msg}\n")
    raise SystemExit(1)
|
|
|
|
|
|
def check_root():
    """Abort via die() unless the effective user ID is 0 (root)."""
    if os.geteuid() == 0:
        return
    die("This script must be run as root (sudo).")
|
|
|
|
|
|
def load_config():
    """Load config.json and check that it defines at least one VLAN.

    Returns:
        The parsed JSON object (a dict).

    Exits via die() when the file is missing, is not valid UTF-8 JSON, is not
    a JSON object, or has no non-empty ``"vlans"`` entry.
    """
    try:
        with open(CONFIG_FILE, encoding="utf-8") as f:
            data = json.load(f)
    except FileNotFoundError:
        die(f"Config file not found: {CONFIG_FILE}")
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        die(f"Invalid JSON in {CONFIG_FILE}: {e}")

    if not isinstance(data, dict):
        die(f"{CONFIG_FILE} must contain a JSON object at the top level.")
    if not data.get("vlans"):
        die("No vlans defined in config.json.")
    return data
|
|
|
|
|
|
def combo_hash(names):
    """Return an 8-hex-char fingerprint of a set of blocklist names.

    The names are sorted before hashing, so the result is independent of
    their order.
    """
    canonical = ",".join(sorted(names)).encode()
    digest = hashlib.sha256(canonical).hexdigest()
    return digest[:8]
|
|
|
|
|
|
def merged_path(h):
    """Return the path of the merged dnsmasq conf file for combo hash *h*."""
    filename = "merged-{}.conf".format(h)
    return BLOCKLIST_DIR / filename
|
|
|
|
|
|
# Parse / detect ======================================================
|
|
|
|
def parse_dnsmasq_format(content):
    """Extract domains from dnsmasq ``local=`` / ``address=`` directives.

    Supported forms (dnsmasq allows several domains per directive):

    * ``local=/dom1/[dom2/...]``
    * ``address=/dom1/[dom2/...]/[addr]`` -- the final field is the answer
      address, not a domain.

    Blank lines, ``#`` comment lines, other directives and empty domain
    fields (e.g. ``address=//0.0.0.0``) are ignored.

    Returns:
        A set of domain strings.
    """
    domains = set()
    for ln in content.splitlines():
        ln = ln.strip()
        if not ln or ln.startswith("#"):
            continue
        if ln.startswith("local=/"):
            fields = ln.removeprefix("local=/").split("/")
        elif ln.startswith("address=/"):
            fields = ln.removeprefix("address=/").split("/")
            # Drop the trailing address field; a malformed line with no '/'
            # after the domain is still treated as a bare domain.
            if len(fields) > 1:
                fields = fields[:-1]
        else:
            continue
        domains.update(f.strip() for f in fields if f.strip())
    return domains
|
|
|
|
|
|
def parse_hosts_format(content):
    """Extract hostnames from hosts-file formatted *content*.

    Each data line is ``<address> <hostname> [<hostname> ...] [# comment]``.
    Every hostname after the address is collected, and text from ``#`` onwards
    is ignored. Lines with no hostname are skipped.

    Loopback/placeholder names that hosts-style blocklists commonly carry in
    their header (``127.0.0.1 localhost``, ``0.0.0.0 0.0.0.0``, ...) are
    excluded, since they are not blocking targets.

    Returns:
        A set of hostname strings.
    """
    not_blockable = {
        "localhost",
        "localhost.localdomain",
        "local",
        "broadcasthost",
        "ip6-localhost",
        "ip6-loopback",
        "0.0.0.0",
    }
    domains = set()
    for ln in content.splitlines():
        # Hostnames cannot contain '#', so everything after it is a comment.
        parts = ln.partition("#")[0].split()
        if len(parts) >= 2:
            domains.update(parts[1:])
    return domains - not_blockable
|
|
|
|
|
|
def parse_local_format(content):
    """Extract domains from a hand-maintained local list.

    One or more whitespace-separated domains per line. Text from ``#`` onwards
    is a comment, so both full-line and trailing comments are allowed.
    Splitting on whitespace guarantees that no returned entry contains a space,
    which would otherwise render as a malformed ``local=/.../`` line.

    Returns:
        A set of domain strings.
    """
    domains = set()
    for ln in content.splitlines():
        domains.update(ln.partition("#")[0].split())
    return domains
|
|
|
|
|
|
def detect_format(content):
    """Guess whether downloaded *content* is a dnsmasq list or a hosts file.

    Only the first non-blank, non-comment line is inspected:

    * starts with ``local=/`` or ``address=/`` -> ``"dnsmasq"``
    * first whitespace-separated token is an IPv4/IPv6 address -> ``"hosts"``
    * anything else, including empty/comment-only content -> ``"dnsmasq"``

    Matching a real IP address rather than "starts with a digit" correctly
    classifies hosts files whose first entry is IPv6 (``::1 localhost``).
    Those were previously routed to the dnsmasq parser and yielded no domains.
    """
    for ln in content.splitlines():
        ln = ln.strip()
        if not ln or ln.startswith("#"):
            continue
        if ln.startswith(("local=/", "address=/")):
            return "dnsmasq"
        try:
            ipaddress.ip_address(ln.split()[0])
        except ValueError:
            return "dnsmasq"
        return "hosts"
    return "dnsmasq"
|
|
|
|
|
|
def parse_blocklist(content, is_local=False):
    """Turn raw blocklist text into a set of domains.

    Local lists always go through the local (one entry per line) parser.
    Downloaded lists are classified by detect_format(): a "dnsmasq" result
    selects the dnsmasq parser, and any other result the hosts-file parser.
    """
    if is_local:
        return parse_local_format(content)
    parser = (
        parse_dnsmasq_format
        if detect_format(content) == "dnsmasq"
        else parse_hosts_format
    )
    return parser(content)
|
|
|
|
|
|
def content_hash(content):
    """Return the hex SHA-256 digest of *content* encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(content.encode())
    return hasher.hexdigest()
|
|
|
|
|
|
# SQLite ==============================================================
|
|
|
|
def open_db():
    """Open the database at DB_FILE, creating the schema if it does not exist.

    Enables WAL journaling and foreign-key enforcement. Foreign keys must be
    on for SQLite to honour the ON DELETE CASCADE from blocklists to domains.
    Ensures the ``blocklists`` and ``domains`` tables and the
    ``idx_domains_domain`` index exist.

    Returns:
        The open sqlite3 connection.
    """
    conn = sqlite3.connect(DB_FILE)
    for pragma in ("PRAGMA journal_mode=WAL", "PRAGMA foreign_keys=ON"):
        conn.execute(pragma)
    schema = """
        CREATE TABLE IF NOT EXISTS blocklists (
            id INTEGER PRIMARY KEY,
            name TEXT UNIQUE NOT NULL,
            content_hash TEXT,
            fetched_at INTEGER,
            domain_count INTEGER
        );
        CREATE TABLE IF NOT EXISTS domains (
            domain TEXT NOT NULL,
            blocklist_id INTEGER NOT NULL REFERENCES blocklists(id) ON DELETE CASCADE,
            PRIMARY KEY (domain, blocklist_id)
        );
        CREATE INDEX IF NOT EXISTS idx_domains_domain ON domains(domain);
    """
    conn.executescript(schema)
    conn.commit()
    return conn
|
|
|
|
|
|
def get_stored_hash(db, name):
    """Return the stored content hash for blocklist *name*, or None if it is unknown."""
    cursor = db.execute("SELECT content_hash FROM blocklists WHERE name = ?", (name,))
    match = cursor.fetchone()
    if match is None:
        return None
    return match[0]
|
|
|
|
|
|
def upsert_blocklist(db, name, domains, raw_hash):
    """Insert or refresh blocklist *name* and replace its domain rows, then commit.

    The blocklists row is created or updated with *raw_hash*, the current Unix
    time and ``len(domains)``. Every existing domain row for the list is then
    deleted and *domains* is inserted in its place.
    """
    fetched_at = int(time.time())
    row = (name, raw_hash, fetched_at, len(domains))
    db.execute("""
        INSERT INTO blocklists (name, content_hash, fetched_at, domain_count)
        VALUES (?, ?, ?, ?)
        ON CONFLICT(name) DO UPDATE SET
            content_hash = excluded.content_hash,
            fetched_at = excluded.fetched_at,
            domain_count = excluded.domain_count
    """, row)
    (list_id,) = db.execute("SELECT id FROM blocklists WHERE name = ?", (name,)).fetchone()
    db.execute("DELETE FROM domains WHERE blocklist_id = ?", (list_id,))
    pairs = [(domain, list_id) for domain in domains]
    db.executemany("INSERT INTO domains (domain, blocklist_id) VALUES (?, ?)", pairs)
    db.commit()
|
|
|
|
|
|
def query_merged_domains(db, names):
    """Return the sorted, de-duplicated domains of every blocklist in *names*.

    Names that have no row in ``blocklists`` contribute nothing.
    """
    count = len(names)
    params = list(names)
    marks = ", ".join(["?"] * count)
    sql = (
        "SELECT DISTINCT d.domain "
        "FROM domains d "
        "JOIN blocklists b ON d.blocklist_id = b.id "
        f"WHERE b.name IN ({marks}) "
        "ORDER BY d.domain"
    )
    return [domain for (domain,) in db.execute(sql, params).fetchall()]
|
|
|
|
|
|
# Conf file output ====================================================
|
|
|
|
def build_merged_conf(domains, bl_names):
    """Render the text of a merged dnsmasq conf file.

    The output is a commented header (source blocklists and domain count),
    a blank line, then one ``local=/<domain>/`` line per entry of *domains*
    in the given order. Lines are newline-joined with no trailing newline
    after the last entry.
    """
    combo = ", ".join(sorted(bl_names))
    out = [
        "# Generated by dns-blocklists.py -- do not edit manually.",
        f"# Blocklist combination: {combo}",
        f"# Merged: {len(domains):,} unique domains.",
        "#",
        "# Blocks domain and all subdomains via local=/domain/ syntax.",
        "",
    ]
    out.extend(map("local=/{}/".format, domains))
    return "\n".join(out)
|
|
|
|
|
|
# Fetch ===============================================================
|
|
|
|
def fetch_community(entry):
    """Download the blocklist at ``entry["url"]`` and return it as text.

    Uses a 30-second timeout and a fixed User-Agent. The body is decoded as
    UTF-8, silently dropping undecodable bytes. Network/HTTP errors raised by
    urllib propagate to the caller.
    """
    request = urllib.request.Request(
        entry["url"],
        headers={"User-Agent": "dns-blocklists.py/1.0"},
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        payload = response.read()
    return payload.decode("utf-8", errors="ignore")
|
|
|
|
|
|
def read_local(entry):
    """Return the text of a local blocklist stored under BLOCKLIST_DIR.

    Args:
        entry: the blocklist's config dict. Its ``save_as`` key names the file
            relative to BLOCKLIST_DIR.

    Returns:
        The file's contents, decoded as UTF-8 explicitly so the result does not
        depend on the process locale. Returns ``""`` when ``save_as`` is missing
        or empty.

    Raises:
        OSError / UnicodeDecodeError if the file cannot be read or decoded.
    """
    save_as = entry.get("save_as", "")
    if not save_as:
        return ""
    return (BLOCKLIST_DIR / save_as).read_text(encoding="utf-8")
|
|
|
|
|
|
# Main update =========================================================
|
|
|
|
def _vlan_combos(vlans):
    """Return {combo_hash: frozenset(names)} for each VLAN using at least one blocklist."""
    combos = {}
    for vlan in vlans:
        names = frozenset(vlan.get("use_blocklists", []))
        if names:
            combos[combo_hash(names)] = names
    return combos


def _refresh_blocklist(db, name, entry):
    """Fetch or read one blocklist and store it in *db* if its content changed.

    Returns:
        ``(ok, changed)``. *ok* is False when the list could not be refreshed:
        it is undefined, the fetch/read failed, or a community list parsed to
        zero domains. *changed* is True only when new domains were written.
        On failure the previously stored domains for *name* are left untouched.
    """
    if entry is None:
        log.error(f"Blocklist '{name}' is used by a VLAN but not defined in dns_blocking.blocklists")
        return False, False

    is_local = entry.get("bl_type") == "local"
    try:
        raw = read_local(entry) if is_local else fetch_community(entry)
    except Exception as e:
        # Best-effort per list: log and let the remaining lists proceed.
        log.error(f"Failed to fetch '{name}': {e}")
        return False, False

    h = content_hash(raw)
    if h == get_stored_hash(db, name):
        log.info(f"Unchanged: '{name}' -- skipping")
        return True, False

    domains = parse_blocklist(raw, is_local=is_local)
    if not domains and not is_local:
        # An empty or unparseable download (error page, truncated body,
        # unsupported format) would otherwise wipe every stored domain.
        log.error(f"No domains parsed from '{name}' -- keeping previously stored domains")
        return False, False

    upsert_blocklist(db, name, domains, h)
    log.info(f"Updated '{name}': {len(domains):,} domains")
    return True, True


def _atomic_write_text(path, text):
    """Replace *path* with *text* via a sibling temp file, so readers never see a partial file."""
    # The ".tmp" suffix keeps the temp file out of the "merged-*.conf" glob.
    tmp = path.with_name(path.name + ".tmp")
    tmp.write_text(text, encoding="utf-8")
    os.replace(tmp, path)


def _write_merged_files(db, combos, changed):
    """(Re)write merged-<hash>.conf for each combo whose lists changed or whose file is missing."""
    for h, names in combos.items():
        path = merged_path(h)
        if not changed.intersection(names) and path.exists():
            log.info(f"Combo [{h}] unchanged -- skipping rewrite")
            continue
        domains = query_merged_domains(db, names)
        _atomic_write_text(path, build_merged_conf(domains, names))
        log.info(f"Merged [{h}] ({', '.join(sorted(names))}): {len(domains):,} unique domains")


def _remove_stale_merged_files(active_hashes):
    """Delete merged-*.conf files whose combo hash is no longer used by any VLAN."""
    for f in BLOCKLIST_DIR.glob("merged-*.conf"):
        if f.stem.removeprefix("merged-") not in active_hashes:
            f.unlink()
            log.info(f"Removed stale merged file: {f.name}")


def update_blocklists(data):
    """Refresh every VLAN-referenced blocklist and regenerate the merged conf files.

    Args:
        data: parsed config.json. Reads ``data["vlans"][*]["use_blocklists"]`` and
            ``data["dns_blocking"]["blocklists"]`` (each entry needs a ``name``).

    Returns:
        True if every referenced blocklist was refreshed or confirmed unchanged.
        False if any was undefined, failed to download/read, or (for community
        lists) parsed to zero domains. Merged files are regenerated either way
        from whatever is stored in the database.
    """
    BLOCKLIST_DIR.mkdir(exist_ok=True)
    _chown_to_script_dir_owner(BLOCKLIST_DIR)

    bl_library = {bl["name"]: bl for bl in data.get("dns_blocking", {}).get("blocklists", [])}
    combos = _vlan_combos(data["vlans"])
    needed = set().union(*combos.values())

    changed = set()
    any_fail = False
    db = open_db()
    try:
        # Sorted for a deterministic processing/log order.
        for name in sorted(needed):
            ok, did_change = _refresh_blocklist(db, name, bl_library.get(name))
            if not ok:
                any_fail = True
            if did_change:
                changed.add(name)
        _write_merged_files(db, combos, changed)
        _remove_stale_merged_files(set(combos))
    finally:
        db.close()
    return not any_fail
|
|
|
|
|
|
def reload_dnsmasq_instances():
    """Send SIGHUP to every active ``dnsmasq-<PRODUCT_NAME>-*`` systemd unit.

    Units are discovered with ``systemctl list-units`` and signalled with
    ``systemctl kill --signal=SIGHUP``. Progress and failures are printed to
    stdout and nothing is raised: a missing ``systemctl`` or a failed unit
    listing is reported as a warning and the function returns.

    NOTE(review): the dnsmasq man page documents that SIGHUP clears the cache
    and re-reads hosts-style files and ``--servers-file``, but does not re-read
    the main configuration or ``conf-file`` includes. If merged-*.conf is pulled
    in with ``conf-file=``, new ``local=/.../`` entries may only take effect
    after a restart -- confirm against the generated dnsmasq configuration.
    """
    try:
        result = subprocess.run(
            ["systemctl", "list-units", "--state=active", "--no-legend", "--plain",
             f"dnsmasq-{PRODUCT_NAME}-*.service"],
            capture_output=True, text=True,
        )
    except FileNotFoundError:
        print(" WARNING: systemctl not found -- cannot reload dnsmasq instances.")
        return
    if result.returncode != 0:
        # Without this check a failed listing would be misreported as
        # "No active dnsmasq instances found."
        print(f" WARNING: Could not list dnsmasq units: {result.stderr.strip()}")
        return

    # With --plain/--no-legend the unit name is the first column.
    units = [line.split()[0] for line in result.stdout.splitlines() if line.strip()]
    if not units:
        print(" No active dnsmasq instances found.")
        return
    for unit in units:
        r = subprocess.run(["systemctl", "kill", "--signal=SIGHUP", unit],
                           capture_output=True, text=True)
        if r.returncode == 0:
            print(f" Reloaded: {unit}")
        else:
            print(f" WARNING: Failed to reload {unit}: {r.stderr.strip()}")
|
|
|
|
|
|
def main():
    """Entry point: require root, load config, set up logging, update lists, reload dnsmasq.

    When any blocklist fails to update, dnsmasq is not signalled and the
    process exits with status 1, so cron/systemd can detect the failure.
    """
    check_root()
    data = load_config()
    general = data.get("dns_blocking", {}).get("general", {})
    setup_logging(
        general.get("log_max_kb", 1024),
        general.get("log_errors_only", False),
    )

    print("Updating blocklists =================================================")
    success = update_blocklists(data)
    print()

    if not success:
        print("WARNING: Blocklist update had errors -- skipping reload.")
        # Merged files may still have been rewritten from the lists that did
        # update, so don't claim they are unchanged.
        print(" Failed blocklists keep their previously stored domains.")
        sys.exit(1)

    print("Reloading dnsmasq instances =========================================")
    reload_dnsmasq_instances()


if __name__ == "__main__":
    main()
|