Development
This commit is contained in:
parent
8303eb5397
commit
a4eb431f22
11 changed files with 744 additions and 1 deletions
65
routlin/check_captive_users.py
Executable file
65
routlin/check_captive_users.py
Executable file
|
|
@ -0,0 +1,65 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
check_captive_users.py -- Expire captive portal sessions.
|
||||
|
||||
Runs every 5 minutes (systemd timer installed by core.py --apply).
|
||||
Queries .client-credentials for sessions past their expiry time,
|
||||
deletes them, and appends disallow commands to .captive-queue so
|
||||
do_captive_queue.sh removes the corresponding nftables entries.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
SCRIPT_DIR = Path(__file__).parent
|
||||
DB_FILE = SCRIPT_DIR / ".client-credentials"
|
||||
QUEUE_FILE = SCRIPT_DIR / ".captive-queue"
|
||||
|
||||
|
||||
def main():
|
||||
if not DB_FILE.exists():
|
||||
return
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(DB_FILE)
|
||||
conn.row_factory = sqlite3.Row
|
||||
except Exception as e:
|
||||
print(f"check_captive_users: cannot open {DB_FILE}: {e}", file=sys.stderr)
|
||||
return
|
||||
|
||||
now = int(time.time())
|
||||
|
||||
try:
|
||||
expired_ips = [
|
||||
row["ip"]
|
||||
for row in conn.execute(
|
||||
"SELECT ip FROM sessions WHERE expires_at IS NOT NULL AND expires_at <= ?",
|
||||
(now,),
|
||||
)
|
||||
]
|
||||
except sqlite3.OperationalError:
|
||||
conn.close()
|
||||
return
|
||||
|
||||
if not expired_ips:
|
||||
conn.close()
|
||||
return
|
||||
|
||||
conn.execute(
|
||||
"DELETE FROM sessions WHERE expires_at IS NOT NULL AND expires_at <= ?",
|
||||
(now,),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
lines = "".join(f"disallow {ip}\n" for ip in expired_ips)
|
||||
with open(QUEUE_FILE, "a") as f:
|
||||
f.write(lines)
|
||||
|
||||
print(f"check_captive_users: queued disallow for {len(expired_ips)} expired session(s).")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -97,6 +97,7 @@ from pathlib import Path
|
|||
|
||||
import health as health
|
||||
import mod_avahi as avahi
|
||||
import mod_captive as captive
|
||||
import mod_dnsmasq as dnsmasq
|
||||
import mod_metrics as metrics
|
||||
import mod_networkd as networkd
|
||||
|
|
@ -816,6 +817,15 @@ def cmd_apply(data, dry_run=False):
|
|||
avahi.disable_avahi()
|
||||
print()
|
||||
|
||||
print("Captive portal ==============================================")
|
||||
if captive.captive_portal_enabled(data):
|
||||
timers.install_captive_timers()
|
||||
print("Captive portal enabled - timers installed.")
|
||||
else:
|
||||
timers.remove_captive_timers()
|
||||
print("No captive portal VLANs - timers removed.")
|
||||
print()
|
||||
|
||||
print("Done.")
|
||||
|
||||
healthy, status = health.run_and_write(data)
|
||||
|
|
|
|||
17
routlin/mod_captive.py
Normal file
17
routlin/mod_captive.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""
|
||||
mod_captive.py -- Captive portal state and path constants.
|
||||
"""
|
||||
|
||||
import mod_shared as shared
|
||||
|
||||
CAPTIVE_QUEUE_FILE = shared.SCRIPT_DIR / ".captive-queue"
|
||||
CAPTIVE_DB_FILE = shared.SCRIPT_DIR / ".client-credentials"
|
||||
|
||||
# nftables table and set that hold authenticated client IPs
|
||||
CAPTIVE_NFT_FAMILY = "inet"
|
||||
CAPTIVE_NFT_TABLE = "filter"
|
||||
CAPTIVE_NFT_SET = "captive_allowed"
|
||||
|
||||
|
||||
def captive_portal_enabled(data):
|
||||
return any(v.get("restricted_vlan") == "c" for v in data.get("vlans", []))
|
||||
|
|
@ -22,6 +22,16 @@ HEALTH_TIMER_FILE = shared.SYSTEMD_DIR / f"{HEALTH_TIMER_NAME}.timer"
|
|||
HEALTH_TIMER_SVC_FILE = shared.SYSTEMD_DIR / f"{HEALTH_TIMER_NAME}.service"
|
||||
HEALTH_TIMER_INTERVAL_SEC = 300
|
||||
|
||||
CAPTIVE_QUEUE_TIMER_NAME = f"{shared.PRODUCT_NAME}-captive-queue"
|
||||
CAPTIVE_QUEUE_TIMER_FILE = shared.SYSTEMD_DIR / f"{CAPTIVE_QUEUE_TIMER_NAME}.timer"
|
||||
CAPTIVE_QUEUE_TIMER_SVC_FILE = shared.SYSTEMD_DIR / f"{CAPTIVE_QUEUE_TIMER_NAME}.service"
|
||||
CAPTIVE_QUEUE_TIMER_INTERVAL = 10
|
||||
|
||||
CAPTIVE_CHECK_TIMER_NAME = f"{shared.PRODUCT_NAME}-captive-check"
|
||||
CAPTIVE_CHECK_TIMER_FILE = shared.SYSTEMD_DIR / f"{CAPTIVE_CHECK_TIMER_NAME}.timer"
|
||||
CAPTIVE_CHECK_TIMER_SVC_FILE = shared.SYSTEMD_DIR / f"{CAPTIVE_CHECK_TIMER_NAME}.service"
|
||||
CAPTIVE_CHECK_TIMER_INTERVAL = 300
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Blocklist timer
|
||||
|
|
@ -212,3 +222,30 @@ def install_maint_timer(data):
|
|||
subprocess.run(["systemctl"] + verb.split() + [f"{MAINT_TIMER_NAME}.timer"],
|
||||
capture_output=True, text=True)
|
||||
print(f"Timer {MAINT_TIMER_NAME}.timer enabled (runs every {interval}).")
|
||||
|
||||
|
||||
# ===================================================================
|
||||
# Captive portal timers
|
||||
# ===================================================================
|
||||
|
||||
def install_captive_timers():
|
||||
install_interval_timers(
|
||||
names=[CAPTIVE_QUEUE_TIMER_NAME, CAPTIVE_CHECK_TIMER_NAME],
|
||||
timer_files=[CAPTIVE_QUEUE_TIMER_FILE, CAPTIVE_CHECK_TIMER_FILE],
|
||||
svc_files=[CAPTIVE_QUEUE_TIMER_SVC_FILE, CAPTIVE_CHECK_TIMER_SVC_FILE],
|
||||
descriptions=["Captive portal queue processor", "Captive portal session expiry checker"],
|
||||
exec_starts=[
|
||||
f"/bin/bash {shared.SCRIPT_DIR / 'do_captive_queue.sh'}",
|
||||
f"/usr/bin/python3 {shared.SCRIPT_DIR / 'check_captive_users.py'}",
|
||||
],
|
||||
interval_secs=[CAPTIVE_QUEUE_TIMER_INTERVAL, CAPTIVE_CHECK_TIMER_INTERVAL],
|
||||
)
|
||||
|
||||
|
||||
def remove_captive_timers():
|
||||
remove_timers(
|
||||
names=[CAPTIVE_QUEUE_TIMER_NAME, CAPTIVE_CHECK_TIMER_NAME],
|
||||
timer_files=[CAPTIVE_QUEUE_TIMER_FILE, CAPTIVE_CHECK_TIMER_FILE],
|
||||
svc_files=[CAPTIVE_QUEUE_TIMER_SVC_FILE, CAPTIVE_CHECK_TIMER_SVC_FILE],
|
||||
daemon_reload=True,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue