#!/usr/bin/env python3
"""
check_captive_users.py -- Expire captive portal sessions.

Runs every 5 minutes (systemd timer installed by core.py --apply).
Queries .client-credentials for sessions that are past their own expiry
time, or whose credential (account) has expired, deletes them, and appends
disallow commands to .captive-queue so do_captive_queue.sh removes the
corresponding nftables entries.
"""
|
|
|
|
import sys
|
|
import time
|
|
import sqlite3
|
|
from pathlib import Path
|
|
|
|
# All state files live alongside this script.
SCRIPT_DIR = Path(__file__).parent
# SQLite database holding the `sessions` and `credentials` tables.
DB_FILE = SCRIPT_DIR.joinpath(".client-credentials")
# Append-only command queue consumed by do_captive_queue.sh.
QUEUE_FILE = SCRIPT_DIR.joinpath(".captive-queue")
|
|
|
|
|
|
def _find_expired_ips(conn, now):
    """Return ``(session_timeout_ips, account_expired_ips)`` as of *now*.

    The first list holds IPs of sessions whose own ``expires_at`` has passed.
    The second holds IPs of sessions whose credential has a positive
    ``expires_seconds`` and ``date_set + expires_seconds <= now``. The lists
    may overlap and may contain NULL (``None``) IPs.
    """
    expired_ips = [
        row["ip"]
        for row in conn.execute(
            "SELECT ip FROM sessions WHERE expires_at IS NOT NULL AND expires_at <= ?",
            (now,),
        )
    ]
    account_expired_ips = [
        row["ip"]
        for row in conn.execute(
            """SELECT s.ip FROM sessions s
               JOIN credentials c ON s.credential_id = c.id
               WHERE c.expires_seconds > 0
                 AND (c.date_set + c.expires_seconds) <= ?""",
            (now,),
        )
    ]
    return expired_ips, account_expired_ips


def _delete_expired(conn, now):
    """Delete exactly the sessions ``_find_expired_ips`` reports for *now*."""
    conn.execute(
        "DELETE FROM sessions WHERE expires_at IS NOT NULL AND expires_at <= ?",
        (now,),
    )
    conn.execute(
        """DELETE FROM sessions WHERE id IN (
               SELECT s.id FROM sessions s
               JOIN credentials c ON s.credential_id = c.id
               WHERE c.expires_seconds > 0
                 AND (c.date_set + c.expires_seconds) <= ?)""",
        (now,),
    )


def _expire_sessions(conn, queue_file):
    """Run one expiry pass on an open connection; report failures on stderr."""
    now = int(time.time())

    try:
        # Take the write lock *before* reading, so no session can be inserted
        # or modified between the SELECTs and the DELETEs that mirror them.
        conn.execute("BEGIN IMMEDIATE")
        expired_ips, account_expired_ips = _find_expired_ips(conn, now)
    except sqlite3.OperationalError as e:
        if conn.in_transaction:
            conn.rollback()
        # A missing table means the schema has not been created yet; that is
        # expected on a fresh install, so only report other failures.
        if "no such table" not in str(e):
            print(f"check_captive_users: query failed: {e}", file=sys.stderr)
        return

    if not expired_ips and not account_expired_ips:
        conn.rollback()
        return

    # Rows with a NULL/empty IP have nothing to revoke; they are still deleted.
    all_ips = sorted({ip for ip in expired_ips + account_expired_ips if ip})

    try:
        _delete_expired(conn, now)
        # Queue the revocations *before* committing: if the append fails we
        # roll back and the next run retries, instead of dropping sessions
        # whose nftables entries were never removed. If the commit itself
        # fails after a successful append, the next run re-queues the same
        # IPs (assumes do_captive_queue.sh tolerates a repeated disallow --
        # TODO confirm).
        if all_ips:
            with open(queue_file, "a", encoding="utf-8") as f:
                f.write("".join(f"disallow {ip}\n" for ip in all_ips))
        conn.commit()
    except (sqlite3.Error, OSError) as e:
        conn.rollback()
        print(f"check_captive_users: failed to expire sessions: {e}", file=sys.stderr)
        return

    print(f"check_captive_users: queued disallow for {len(all_ips)} expired session(s) "
          f"({len(expired_ips)} session timeout, {len(account_expired_ips)} account expired).")


def main(db_file=None, queue_file=None):
    """Expire captive-portal sessions and queue their firewall revocations.

    A session is expired when its own ``expires_at`` has passed, or when its
    credential has a positive ``expires_seconds`` and
    ``date_set + expires_seconds`` has passed. Expired sessions are deleted,
    and one ``disallow <ip>`` line per distinct IP is appended to the queue
    file. The deletion and the queue append happen inside one transaction.

    Args:
        db_file: SQLite database path. Defaults to ``DB_FILE``.
        queue_file: Queue file to append to. Defaults to ``QUEUE_FILE``.

    Does nothing if the database file does not exist. SQLite and queue-file
    errors are reported on stderr rather than raised. On such an error the
    database changes are rolled back, so the next run retries.
    """
    db_file = Path(DB_FILE if db_file is None else db_file)
    queue_file = QUEUE_FILE if queue_file is None else queue_file

    if not db_file.exists():
        return

    try:
        conn = sqlite3.connect(db_file)
    except sqlite3.Error as e:
        print(f"check_captive_users: cannot open {db_file}: {e}", file=sys.stderr)
        return
    conn.row_factory = sqlite3.Row

    try:
        _expire_sessions(conn, queue_file)
    finally:
        # Single close point covering every return path and unexpected errors.
        conn.close()
|
|
|
|
|
|
# Script entry point (invoked directly, e.g. by the systemd timer).
if __name__ == "__main__":
    main()
|