#!/usr/bin/env python3
"""Adaptive, resource-governed scan scheduler.

Goals (per user spec): use the maximum resources available without contention,
always leaving the system enough to run, and watch/adjust dynamically.

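- MEMORY: each scan runs in its own systemd scope with MemoryMax set adaptively to
  min(PER_SCAN_CAP, max(PER_SCAN_MIN, 0.8 x currently-available)). Another scan is
  admitted only while total committed memory stays within (total - OS_RESERVE); the one
  exception is that a scan is always started when nothing is in flight. A single
  decompression bomb is bounded by its scope.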
- CONCURRENCY: the scheduler launches another scan only while free memory stays above
  the OS reserve; it scales up when there's room and backs off when tight.
- TIME: a progress-watchdog reads each scan's progress file (p.json). A scan whose stage
  has not advanced for STALL_SECS is SIGKILLed (this catches MuPDF C-decoder hangs that an
  in-process deadline cannot interrupt), as is any scan that runs longer than HARD_SECS
  overall. Killed/stalled scans are recorded, never block.
"""
import json, os, sys, time, subprocess, tempfile, shutil, threading, signal

# Command line: <root> [<root> ...] <out>.  Every argument except the last is a directory
# tree searched for PDFs; the last is the JSON-lines results file (appended to, and read
# back at start-up for resume support).
# NOTE: the argument count is not validated here. With a single argument ROOTS is empty;
# with no arguments OUT is sys.argv[0], i.e. the script path itself.
ROOTS = sys.argv[1:-1]
OUT   = sys.argv[-1]
# Scanner script, launched once per PDF as: python3 SC <pdf> <result json> <progress json>
SC    = "/tmp/pdf_scanner.py"

# Memory figures are in MB, matching the MB values derived from /proc/meminfo.
OS_RESERVE_MB  = 1536          # never touch this much — leave it for the system
PER_SCAN_CAP_MB = 3000         # absolute per-scan ceiling (a bomb can't exceed this)
PER_SCAN_MIN_MB = 1024         # floor for a scan's grant; also the free headroom (above the OS reserve) required before adding a concurrent scan
# A scan ADVANCING through engines is never killed by the stall check, however long it
# runs (honour "allow max"); only HARD_SECS below caps total wall time. A scan STUCK on
# one engine for STALL_SECS is treated as a genuine hang.
# Generous so a legitimately slow engine (OCR on a large page) under memory/CPU share
# is not cut off; a true C-decoder hang never advances and is still caught.
STALL_SECS     = 120
HARD_SECS      = 900           # absolute per-scan wall ceiling (safety net)
POLL           = 1.0           # seconds between watchdog polls and between idle scheduler passes

def _meminfo():
    try:
        with open('/proc/meminfo') as f:
            return {l.split(':')[0]: int(l.split()[1]) // 1024 for l in f}
    except Exception:
        return {}
def avail_mb():
    """Return the current MemAvailable figure in MB, or 0 when it cannot be read."""
    info = _meminfo()
    if 'MemAvailable' in info:
        return info['MemAvailable']
    return 0
def total_mb():
    """Return the MemTotal figure in MB, falling back to 8000 when it cannot be read."""
    info = _meminfo()
    if 'MemTotal' in info:
        return info['MemTotal']
    return 8000

def find_pdfs(roots):
    """Return a sorted, de-duplicated list of PDF paths found under *roots*.

    Each root is walked recursively with os.walk; a file qualifies when its
    name ends in ".pdf", compared case-insensitively.
    """
    unique = set()
    for root in roots:
        for dirpath, _subdirs, filenames in os.walk(root):
            unique.update(
                os.path.join(dirpath, name)
                for name in filenames
                if name.lower().endswith('.pdf')
            )
    return sorted(unique)

# Shared accounting between the scheduler loop in main() and the run_one() worker
# threads. _lock guards every update of _committed_mb.
_lock = threading.Lock()
_committed_mb = 0      # sum of MemoryMax granted to in-flight scans

def _read_stage(progress_path, fallback):
    """Return the "stage" field of the scan's progress JSON, or *fallback* if unreadable."""
    try:
        with open(progress_path) as f:
            return json.load(f).get("stage")
    except Exception:
        # Missing, half-written or malformed progress file: report "no change".
        return fallback


def _watch_scan(proc, progress_path):
    """Poll *proc* until it exits on its own or trips a watchdog limit.

    Returns ``(killed, last_stage)``. ``killed`` is None when the process
    exited by itself, "stalled" when the reported stage did not change for more
    than STALL_SECS, or "hard_timeout" when total elapsed time exceeded
    HARD_SECS. The process is NOT killed here; the caller does that.
    """
    started = time.time()
    last_stage = None
    last_change = started
    while proc.poll() is None:
        stage = _read_stage(progress_path, last_stage)
        now = time.time()
        if stage != last_stage:
            last_stage = stage
            last_change = now
        if now - last_change > STALL_SECS:
            return "stalled", last_stage
        if now - started > HARD_SECS:
            return "hard_timeout", last_stage
        time.sleep(POLL)
    return None, last_stage


def _kill_scan(proc, unit):
    """SIGKILL the scan's process group, stop its systemd scope, then reap it (best effort)."""
    try:
        os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
    except Exception:
        pass  # already gone
    try:
        subprocess.run(["systemctl", "stop", unit + ".scope"], capture_output=True, timeout=10)
    except Exception:
        pass  # scope already gone, or systemctl unavailable / slow
    proc.wait()


def _collect_result(result_path, rec, killed, last_stage):
    """Fill *rec* in place from the scanner's result JSON, or record why there is none."""
    if not os.path.exists(result_path):
        rec["status"] = killed or "no_output"
        rec["last_stage"] = last_stage
        return
    try:
        with open(result_path) as f:
            x = json.load(f)
        rec["status"] = "ok"
        for k in ("threat_score", "risk_level", "deception_score", "structural_score",
                  "verdict_driver", "has_exec_vector"):
            rec[k] = x.get(k)
        rec["incomplete"] = x.get("scan_incomplete", False)
        rec["last_stage"] = last_stage
        # per-domain capability flags for the at-scale study (encompassing the 5 articles)
        inds = x.get("indicators", []) or []
        keys = " | ".join(str(i.get("key", "")) for i in inds if isinstance(i, dict))
        structure = x.get("structure", {}) or {}
        ro = structure.get("reading_order_analysis", {}) or {}
        ocr = structure.get("ocr_text_layer_integrity", {}) or {}
        acc = structure.get("accessibility_tree_forensics", {}) or {}
        tu = structure.get("tounicode_analysis", {}) or {}
        mr = structure.get("metadata_reconciliation", {}) or {}
        sf = x.get("signature_forensics", {}) or {}
        rec["vap"] = ("Value/Appearance" in keys) or ("Appearance Stream Renders Blank" in keys)
        rec["parser_disagree"] = "Differential Parsing" in keys
        rec["e44_multicol"] = bool(ro.get("multi_column_pages"))
        rec["e45_ocr_mismatch"] = bool(ocr.get("mismatch_pages"))
        rec["e46_has_struct"] = bool(acc.get("has_struct_tree"))
        rec["e46_alt"] = (acc.get("alt_text_count", 0) or 0) > 0
        rec["tounicode_remap"] = bool(tu.get("suspicious_mappings") or tu.get("remap_pages"))
        rec["meta_desync"] = (mr.get("discrepancy_years", 0) or 0) > 0 or bool(mr.get("info_xmp_conflict"))
        rec["signed"] = (sf.get("signature_count", 0) or 0) > 0
        rec["has_js"] = ("/JavaScript" in keys) or ("/JS" in keys)
        rec["xfa"] = bool((structure.get("xfa_formcalc", {}) or {}).get("xfa_detected"))
    except Exception:
        # Unparsable or unexpectedly shaped result file.
        rec["status"] = "bad_json"
        rec.setdefault("last_stage", last_stage)


def run_one(path, results):
    """Scan one PDF in its own memory-capped systemd scope and append a record to *results*.

    Intended to run in a worker thread. The scan is launched through
    ``systemd-run --scope`` with MemoryMax set to an adaptive grant, watched via
    its progress file, SIGKILLed if it stalls or exceeds the wall ceiling, and
    its result JSON is condensed into a flat record.

    Args:
        path: PDF file to scan.
        results: shared list; exactly one record dict is appended per call.

    The record always carries "file", "status" and "grant_mb". Status is one of
    "ok", "bad_json", "no_output", "stalled", "hard_timeout", "launch_error" or
    "internal_error" (an unexpected failure inside this function). The grant is
    removed from _committed_mb and the temp directory deleted on every exit path.
    """
    global _committed_mb
    # adaptive per-scan memory grant
    grant = min(PER_SCAN_CAP_MB, max(PER_SCAN_MIN_MB, int(0.8 * avail_mb())))
    with _lock:
        _committed_mb += grant
    rec = {"file": path, "status": None, "grant_mb": grant}
    workdir = None
    proc = None
    unit = "govscan_" + str(abs(hash(path)) % 10**9)
    grant_held = True
    try:
        workdir = tempfile.mkdtemp(prefix="gov_")
        rj = os.path.join(workdir, "r.json")
        pj = os.path.join(workdir, "p.json")
        try:
            rec["size"] = os.path.getsize(path)
        except Exception:
            pass  # size is informational only
        cmd = ["systemd-run", "--scope", "-q", "--unit=" + unit,
               "-p", f"MemoryMax={grant}M", "-p", "MemorySwapMax=0", "-p", "OOMPolicy=continue",
               "nice", "-n", "15", "python3", SC, path, rj, pj]
        try:
            # start_new_session=True performs setsid() in the child. It replaces
            # preexec_fn=os.setsid, which is not safe when the parent has other
            # threads running -- and this function runs in a worker thread.
            proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
                                    start_new_session=True)
        except Exception as e:
            rec["status"] = "launch_error"
            rec["err"] = repr(e)[:160]
            return

        # progress-watchdog
        killed, last_stage = _watch_scan(proc, pj)
        if killed:
            _kill_scan(proc, unit)

        # The process is gone: hand the grant back before parsing the result.
        with _lock:
            _committed_mb -= grant
        grant_held = False

        _collect_result(rj, rec, killed, last_stage)
    except Exception as e:
        # Worker-thread boundary: never leave a scan running unaccounted for,
        # and record the failure instead of losing the file silently.
        if proc is not None and proc.poll() is None:
            _kill_scan(proc, unit)
        rec["status"] = "internal_error"
        rec["err"] = repr(e)[:160]
    finally:
        if grant_held:
            with _lock:
                _committed_mb -= grant
        if workdir is not None:
            shutil.rmtree(workdir, ignore_errors=True)
        results.append(rec)

def main():
    """Scan every PDF under ROOTS, appending one JSON line per file to OUT.

    Files already named in an existing OUT are skipped, so an interrupted run
    can be resumed. Scans run in worker threads (run_one); a further scan is
    started only when the committed grants plus the next grant fit within
    (MemTotal - OS_RESERVE_MB) AND real available memory leaves at least
    PER_SCAN_MIN_MB above the OS reserve. One scan is always allowed when
    nothing is in flight. Progress is printed every 25 records.
    """
    files = find_pdfs(ROOTS)
    # resume support
    done = set()
    if os.path.exists(OUT):
        with open(OUT) as prev:
            for line in prev:
                try:
                    done.add(json.loads(line)["file"])
                except Exception:
                    pass  # blank, truncated or foreign line: ignore it
    files = [f for f in files if f not in done]
    total = len(files)
    idx = 0
    print(f"[governor] {total} to scan (OS_RESERVE={OS_RESERVE_MB}M, stall={STALL_SECS}s)", flush=True)
    results = []
    threads = []
    written = 0
    t_start = time.time()
    # The context manager guarantees the output file is flushed and closed even
    # if the loop is interrupted or raises.
    with open(OUT, "a") as fh:
        while idx < total or threads:
            # reap finished
            for t in threads[:]:
                if not t.is_alive():
                    t.join()
                    threads.remove(t)
            while results:
                rec = results.pop()
                fh.write(json.dumps(rec) + "\n")
                fh.flush()  # keep the file current so a crash loses nothing already scanned
                written += 1
                if written % 25 == 0:
                    print(f"[{written}/{total}] inflight={len(threads)} avail={avail_mb()}M committed={_committed_mb}M {int(time.time()-t_start)}s", flush=True)
            # Launch another only if it fits. Gate on COMMITTED grants (immediate) so we
            # never over-subscribe before a freshly-launched scan's memory ramps, AND require
            # real available headroom above the OS reserve. Always allow at least one in flight.
            if idx < total:
                # One sample of available memory feeds both checks, so they cannot
                # disagree about what they saw.
                avail = avail_mb()
                usable = total_mb() - OS_RESERVE_MB
                with _lock:
                    committed = _committed_mb
                next_grant = min(PER_SCAN_CAP_MB, max(PER_SCAN_MIN_MB, int(0.8 * avail)))
                room_committed = (committed + next_grant) <= usable
                room_real = (avail - OS_RESERVE_MB) >= PER_SCAN_MIN_MB
                if not threads or (room_committed and room_real):
                    th = threading.Thread(target=run_one, args=(files[idx], results), daemon=True)
                    th.start()
                    threads.append(th)
                    idx += 1
                    time.sleep(2.5)   # let the new scan's memory ramp before the next decision
                    continue
            time.sleep(POLL)
        # drain records appended after the last in-loop write
        while results:
            rec = results.pop()
            fh.write(json.dumps(rec) + "\n")
            written += 1
    print(f"[governor] DONE {written} in {int(time.time()-t_start)}s", flush=True)

# Start the scheduler only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
