""" policy_store — M6.0: Persistent room-node override store. SQLite-backed store that allows operators to dynamically set a preferred node (NODA1, NODA2, …) for any Matrix room without redeploying the bridge. Resolution layer (in NodePolicy.resolve): 1. explicit node=X kwarg (highest priority) 2. dynamic store override ← this module 3. static BRIDGE_ROOM_NODE_MAP env 4. BRIDGE_DEFAULT_NODE (lowest priority) All DB operations are synchronous/blocking. Call via asyncio.to_thread in async contexts to avoid blocking the event loop. Security: - operator identity is stored as SHA-256[:16] (no PII verbatim) - room_id values validated against basic Matrix ID format by callers - SQLite WAL mode, PRAGMA synchronous=NORMAL for durability+speed """ from __future__ import annotations import datetime import glob as _glob import hashlib import json as _json import logging import os as _os import sqlite3 import time from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Tuple POLICY_SNAPSHOT_VERSION = 1 POLICY_IMPORT_MODE_MERGE = "merge" POLICY_IMPORT_MODE_REPLACE = "replace" logger = logging.getLogger(__name__) _DDL = """ CREATE TABLE IF NOT EXISTS room_node_overrides ( room_id TEXT PRIMARY KEY, node_id TEXT NOT NULL, updated_at INTEGER NOT NULL, updated_by_hash TEXT NOT NULL ); """ _IDX_TS = """ CREATE INDEX IF NOT EXISTS idx_rno_updated_at ON room_node_overrides (updated_at DESC); """ # M6.1: Dynamic mixed room agent overrides _DDL_AGENT = """ CREATE TABLE IF NOT EXISTS room_agent_overrides ( room_id TEXT PRIMARY KEY, agents_csv TEXT NOT NULL, default_agent TEXT, updated_at INTEGER NOT NULL, updated_by_hash TEXT NOT NULL ); """ _IDX_AGENT_TS = """ CREATE INDEX IF NOT EXISTS idx_rao_updated_at ON room_agent_overrides (updated_at DESC); """ # M8.2: HA persistence tables _DDL_STICKY = """ CREATE TABLE IF NOT EXISTS sticky_node_cache ( key TEXT PRIMARY KEY, node_id TEXT NOT NULL, expires_at INTEGER NOT NULL, updated_at INTEGER NOT NULL ); """ _DDL_NODE_HEALTH = """ CREATE TABLE IF NOT EXISTS node_health_state ( node_id TEXT PRIMARY KEY, ewma_latency_s REAL, consecutive_failures INTEGER NOT NULL DEFAULT 0, updated_at INTEGER NOT NULL ); """ # M10.2: Policy change history table _DDL_POLICY_CHANGES = """ CREATE TABLE IF NOT EXISTS policy_changes ( id INTEGER PRIMARY KEY AUTOINCREMENT, applied_at INTEGER NOT NULL, verb TEXT NOT NULL DEFAULT '', mode TEXT NOT NULL DEFAULT '', source_file TEXT NOT NULL DEFAULT '', sender_hash TEXT NOT NULL DEFAULT '', diff_summary TEXT NOT NULL DEFAULT '', is_destructive INTEGER NOT NULL DEFAULT 0, node_added INTEGER NOT NULL DEFAULT 0, node_updated INTEGER NOT NULL DEFAULT 0, node_deleted INTEGER NOT NULL DEFAULT 0, agent_added INTEGER NOT NULL DEFAULT 0, agent_updated INTEGER NOT NULL DEFAULT 0, agent_deleted INTEGER NOT NULL DEFAULT 0 ); """ _IDX_POLICY_CHANGES_TS = """ CREATE INDEX IF NOT EXISTS idx_pc_applied_at ON policy_changes (applied_at DESC); """ _POLICY_HISTORY_DEFAULT_LIMIT = 100 # Maximum number of entries returned by list_* (safety cap) _LIST_HARD_LIMIT = 100 # M9.1: Import diff result dataclass _SAMPLE_KEYS_MAX = 5 @dataclass class ImportDiff: """ Result of compute_import_diff — what would change if a snapshot were imported. Used to build a preview reply and confirm binding hash (M9.1). """ node_added: int = 0 node_updated: int = 0 node_deleted: int = 0 agent_added: int = 0 agent_updated: int = 0 agent_deleted: int = 0 sample_keys: List[str] = field(default_factory=list) # up to _SAMPLE_KEYS_MAX is_replace: bool = False def total_changes(self) -> int: return ( self.node_added + self.node_updated + self.node_deleted + self.agent_added + self.agent_updated + self.agent_deleted ) def is_destructive(self) -> bool: """True if any existing data would be deleted.""" return self.node_deleted > 0 or self.agent_deleted > 0 # M10.2: Policy change history entry @dataclass class PolicyChange: """A single recorded policy apply event (import or restore).""" id: int applied_at: int # unix timestamp verb: str # e.g. "policy.import", "policy.restore" mode: str # "merge" or "replace" source_file: str # snapshot filename (basename only) sender_hash: str # truncated hash of operator sender_id diff_summary: str # human-readable change summary string is_destructive: bool # True if any deletions occurred node_added: int node_updated: int node_deleted: int agent_added: int agent_updated: int agent_deleted: int def when_str(self) -> str: """Human-readable UTC timestamp.""" return datetime.datetime.fromtimestamp( self.applied_at, datetime.timezone.utc ).strftime("%Y-%m-%d %H:%M") def changes_short(self) -> str: """Compact change summary, e.g. '+2n -1n +1a'.""" parts = [] if self.node_added: parts.append(f"+{self.node_added}n") if self.node_updated: parts.append(f"~{self.node_updated}n") if self.node_deleted: parts.append(f"-{self.node_deleted}n") if self.agent_added: parts.append(f"+{self.agent_added}a") if self.agent_updated: parts.append(f"~{self.agent_updated}a") if self.agent_deleted: parts.append(f"-{self.agent_deleted}a") return " ".join(parts) or "no changes" # M10.0: Auto-backup + prune result _AUTOBACKUP_PREFIX = "policy-autobackup-" _EXPORT_GLOB = "policy-*.json" _PRUNE_SAMPLE_MAX = 5 @dataclass class PruneResult: """Result of prune_exports — what was (or would be) pruned (M10.0).""" files_to_delete: List[str] # basenames of matching expired files total_bytes: int # approximate bytes freed (or to be freed) oldest_mtime: Optional[float] = None # oldest mtime among files to delete @property def count(self) -> int: return len(self.files_to_delete) def sample_filenames(self, n: int = _PRUNE_SAMPLE_MAX) -> List[str]: return sorted(self.files_to_delete)[:n] def _hash_sender(sender: str) -> str: """Partial SHA-256 of sender Matrix ID (non-reversible, no PII stored raw).""" return hashlib.sha256(sender.encode("utf-8")).hexdigest()[:16] class PolicyStore: """ Lightweight synchronous SQLite wrapper for room→node overrides. Usage pattern (async callers): override = await asyncio.to_thread(store.get_override, room_id) await asyncio.to_thread(store.set_override, room_id, "NODA2", sender) """ def __init__(self, db_path: str) -> None: self._db_path = db_path self._conn: Optional[sqlite3.Connection] = None # ── Lifecycle ────────────────────────────────────────────────────────────── def open(self) -> None: """Open (or create) the SQLite DB and apply schema.""" Path(self._db_path).parent.mkdir(parents=True, exist_ok=True) self._conn = sqlite3.connect( self._db_path, check_same_thread=False, isolation_level=None, # autocommit ) self._conn.execute("PRAGMA journal_mode=WAL") self._conn.execute("PRAGMA synchronous=NORMAL") self._conn.execute(_DDL) self._conn.execute(_IDX_TS) self._conn.execute(_DDL_AGENT) self._conn.execute(_IDX_AGENT_TS) # M8.2: HA persistence tables self._conn.execute(_DDL_STICKY) self._conn.execute(_DDL_NODE_HEALTH) # M10.2: Policy change history self._conn.execute(_DDL_POLICY_CHANGES) self._conn.execute(_IDX_POLICY_CHANGES_TS) logger.info("PolicyStore opened: %s", self._db_path) def close(self) -> None: """Close the SQLite connection.""" if self._conn: try: self._conn.close() except Exception: # noqa: BLE001 pass finally: self._conn = None # ── CRUD ─────────────────────────────────────────────────────────────────── def get_override(self, room_id: str) -> Optional[str]: """Return the stored node_id for room_id, or None if not set.""" self._require_open() row = self._conn.execute( # type: ignore[union-attr] "SELECT node_id FROM room_node_overrides WHERE room_id = ?", (room_id,), ).fetchone() return row[0] if row else None def set_override(self, room_id: str, node_id: str, updated_by: str) -> None: """Upsert a room→node override.""" self._require_open() self._conn.execute( # type: ignore[union-attr] """ INSERT INTO room_node_overrides (room_id, node_id, updated_at, updated_by_hash) VALUES (?, ?, ?, ?) ON CONFLICT(room_id) DO UPDATE SET node_id = excluded.node_id, updated_at = excluded.updated_at, updated_by_hash = excluded.updated_by_hash """, (room_id, node_id, int(time.time()), _hash_sender(updated_by)), ) def delete_override(self, room_id: str) -> bool: """Remove override for room_id. Returns True if a row was deleted.""" self._require_open() cursor = self._conn.execute( # type: ignore[union-attr] "DELETE FROM room_node_overrides WHERE room_id = ?", (room_id,), ) return cursor.rowcount > 0 def list_overrides(self, limit: int = 10) -> List[Tuple[str, str, int]]: """ Return [(room_id, node_id, updated_at), …] ordered by updated_at DESC. Hard-capped at _LIST_HARD_LIMIT regardless of caller's limit. """ self._require_open() cap = min(max(1, limit), _LIST_HARD_LIMIT) rows = self._conn.execute( # type: ignore[union-attr] """ SELECT room_id, node_id, updated_at FROM room_node_overrides ORDER BY updated_at DESC LIMIT ? """, (cap,), ).fetchall() return [(r[0], r[1], r[2]) for r in rows] def count_overrides(self) -> int: """Return total number of override rows in the DB.""" self._require_open() row = self._conn.execute( "SELECT COUNT(*) FROM room_node_overrides" ).fetchone() return int(row[0]) if row else 0 # ── Properties ───────────────────────────────────────────────────────────── @property def db_path(self) -> str: return self._db_path @property def is_open(self) -> bool: return self._conn is not None # ── M6.1: Room agent overrides ───────────────────────────────────────────── def get_agent_override( self, room_id: str ) -> Optional[Tuple[List[str], Optional[str]]]: """ Return (agents_list, default_agent_or_None) for room_id, or None if no override exists. """ self._require_open() row = self._conn.execute( # type: ignore[union-attr] "SELECT agents_csv, default_agent FROM room_agent_overrides WHERE room_id = ?", (room_id,), ).fetchone() if row is None: return None agents = [a.strip() for a in row[0].split(",") if a.strip()] return agents, (row[1] or None) def set_agent_override( self, room_id: str, agents: List[str], default_agent: Optional[str], updated_by: str, ) -> None: """Upsert a room agent override (sorted, deduplicated agents_csv).""" self._require_open() agents_csv = ",".join(sorted(set(agents))) self._conn.execute( # type: ignore[union-attr] """ INSERT INTO room_agent_overrides (room_id, agents_csv, default_agent, updated_at, updated_by_hash) VALUES (?, ?, ?, ?, ?) ON CONFLICT(room_id) DO UPDATE SET agents_csv = excluded.agents_csv, default_agent = excluded.default_agent, updated_at = excluded.updated_at, updated_by_hash = excluded.updated_by_hash """, (room_id, agents_csv, default_agent, int(time.time()), _hash_sender(updated_by)), ) def delete_agent_override(self, room_id: str) -> bool: """Remove agent override for room_id. Returns True if deleted.""" self._require_open() cursor = self._conn.execute( # type: ignore[union-attr] "DELETE FROM room_agent_overrides WHERE room_id = ?", (room_id,), ) return cursor.rowcount > 0 def add_agent_to_room( self, room_id: str, agent: str, updated_by: str ) -> Tuple[List[str], Optional[str]]: """ Add agent to room override, creating it if it doesn't exist. Returns the new (agents, default_agent) state. """ self._require_open() existing = self.get_agent_override(room_id) if existing: agents, default = existing if agent not in agents: agents = sorted(set(agents) | {agent}) self.set_agent_override(room_id, agents, default, updated_by) return agents, default else: self.set_agent_override(room_id, [agent], agent, updated_by) return [agent], agent def remove_agent_from_room( self, room_id: str, agent: str, updated_by: str ) -> Tuple[bool, Optional[str]]: """ Remove agent from room override. Returns (removed: bool, error_message_or_None). If the last agent is removed, the entire override is deleted. """ self._require_open() existing = self.get_agent_override(room_id) if not existing: return False, "No agent override set for this room" agents, default = existing if agent not in agents: return False, f"Agent `{agent}` not in override list" agents = [a for a in agents if a != agent] if not agents: self.delete_agent_override(room_id) return True, None new_default = default if default != agent else agents[0] self.set_agent_override(room_id, agents, new_default, updated_by) return True, None def list_agent_overrides( self, limit: int = 10 ) -> List[Tuple[str, List[str], Optional[str], int]]: """ Return [(room_id, agents_list, default_agent, updated_at), …] ordered by updated_at DESC. """ self._require_open() cap = min(max(1, limit), _LIST_HARD_LIMIT) rows = self._conn.execute( # type: ignore[union-attr] """ SELECT room_id, agents_csv, default_agent, updated_at FROM room_agent_overrides ORDER BY updated_at DESC LIMIT ? """, (cap,), ).fetchall() return [ (r[0], [a.strip() for a in r[1].split(",") if a.strip()], r[2] or None, r[3]) for r in rows ] def count_agent_overrides(self) -> int: """Return total number of agent override rows.""" self._require_open() row = self._conn.execute( "SELECT COUNT(*) FROM room_agent_overrides" ).fetchone() return int(row[0]) if row else 0 # ── M8.2: HA persistence — sticky node cache ────────────────────────────── def upsert_sticky(self, key: str, node_id: str, expires_at_unix: int) -> None: """Persist a sticky routing entry. Idempotent (upsert by key).""" assert self._conn, "Store not open" now = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) self._conn.execute( """INSERT INTO sticky_node_cache (key, node_id, expires_at, updated_at) VALUES (?, ?, ?, ?) ON CONFLICT(key) DO UPDATE SET node_id=excluded.node_id, expires_at=excluded.expires_at, updated_at=excluded.updated_at""", (key, node_id, expires_at_unix, now), ) def delete_sticky(self, key: str) -> bool: """Remove a sticky entry. Returns True if it existed.""" assert self._conn, "Store not open" cur = self._conn.execute( "DELETE FROM sticky_node_cache WHERE key=?", (key,) ) return cur.rowcount > 0 def load_sticky_entries(self) -> List[Tuple[str, str, int]]: """ Return all non-expired sticky entries as (key, node_id, expires_at_unix). Callers filter by monotonic time; here we compare against unix now. """ assert self._conn, "Store not open" now = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) rows = self._conn.execute( "SELECT key, node_id, expires_at FROM sticky_node_cache WHERE expires_at > ?", (now,), ).fetchall() return [(r[0], r[1], int(r[2])) for r in rows] def prune_sticky_expired(self) -> int: """Remove all expired sticky entries. Returns count removed.""" assert self._conn, "Store not open" now = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) cur = self._conn.execute( "DELETE FROM sticky_node_cache WHERE expires_at <= ?", (now,) ) return cur.rowcount # ── M8.2: HA persistence — node health state ────────────────────────────── def upsert_node_health( self, node_id: str, ewma_latency_s: Optional[float], consecutive_failures: int, ) -> None: """Persist node health snapshot. Idempotent (upsert by node_id).""" assert self._conn, "Store not open" now = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) self._conn.execute( """INSERT INTO node_health_state (node_id, ewma_latency_s, consecutive_failures, updated_at) VALUES (?, ?, ?, ?) ON CONFLICT(node_id) DO UPDATE SET ewma_latency_s=excluded.ewma_latency_s, consecutive_failures=excluded.consecutive_failures, updated_at=excluded.updated_at""", (node_id, ewma_latency_s, consecutive_failures, now), ) def load_node_health(self, max_age_s: int = 600) -> Optional[Dict[str, Any]]: """ Load node health snapshot if all rows are fresh enough (updated_at >= now - max_age_s). Returns None if no rows or snapshot is stale. Returns dict: {node_id: {ewma_latency_s, consecutive_failures, updated_at}} """ assert self._conn, "Store not open" now = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) cutoff = now - max_age_s rows = self._conn.execute( """SELECT node_id, ewma_latency_s, consecutive_failures, updated_at FROM node_health_state""", ).fetchall() if not rows: return None result: Dict[str, Any] = {} for node_id, ewma, consec, updated_at in rows: if int(updated_at) < cutoff: logger.debug( "HA: node health snapshot for %s is stale (age=%ds > max=%ds) — ignoring", node_id, now - int(updated_at), max_age_s, ) return None # Any stale node → discard whole snapshot result[node_id] = { "ewma_latency_s": ewma, "consecutive_failures": int(consec), "updated_at": int(updated_at), } return result if result else None # ── M6.2: Snapshot export / import ──────────────────────────────────────── # ── M10.2: Policy change history ────────────────────────────────────────── def record_policy_change( self, verb: str, mode: str, source_file: str, sender_hash: str, diff_summary: str, is_destructive: bool, node_added: int, node_updated: int, node_deleted: int, agent_added: int, agent_updated: int, agent_deleted: int, history_limit: int = _POLICY_HISTORY_DEFAULT_LIMIT, ) -> int: """ Insert a policy apply event into the history table and prune old rows. history_limit=0 means keep all rows (no pruning). Returns the id of the inserted row. """ self._require_open() cur = self._conn.execute( # type: ignore[union-attr] """INSERT INTO policy_changes (applied_at, verb, mode, source_file, sender_hash, diff_summary, is_destructive, node_added, node_updated, node_deleted, agent_added, agent_updated, agent_deleted) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""", ( int(time.time()), verb, mode, source_file, sender_hash, diff_summary, int(is_destructive), node_added, node_updated, node_deleted, agent_added, agent_updated, agent_deleted, ), ) row_id: int = cur.lastrowid # type: ignore[assignment] # Prune oldest rows beyond limit if history_limit > 0: self._conn.execute( # type: ignore[union-attr] """DELETE FROM policy_changes WHERE id NOT IN ( SELECT id FROM policy_changes ORDER BY id DESC LIMIT ? )""", (history_limit,), ) logger.debug( "Recorded policy change id=%d verb=%s mode=%s file=%s destr=%s", row_id, verb, mode, source_file, is_destructive, ) return row_id def list_policy_changes(self, limit: int = 10) -> List[PolicyChange]: """ Return the most-recent `limit` policy change records, newest first. Hard cap: min(limit, _LIST_HARD_LIMIT). """ self._require_open() safe_limit = min(max(1, limit), _LIST_HARD_LIMIT) rows = self._conn.execute( # type: ignore[union-attr] """SELECT id, applied_at, verb, mode, source_file, sender_hash, diff_summary, is_destructive, node_added, node_updated, node_deleted, agent_added, agent_updated, agent_deleted FROM policy_changes ORDER BY id DESC LIMIT ?""", (safe_limit,), ).fetchall() return [ PolicyChange( id=r[0], applied_at=r[1], verb=r[2], mode=r[3], source_file=r[4], sender_hash=r[5], diff_summary=r[6], is_destructive=bool(r[7]), node_added=r[8], node_updated=r[9], node_deleted=r[10], agent_added=r[11], agent_updated=r[12], agent_deleted=r[13], ) for r in rows ] def get_policy_changes_count(self) -> int: """Return the total number of recorded policy changes.""" self._require_open() row = self._conn.execute( # type: ignore[union-attr] "SELECT COUNT(*) FROM policy_changes" ).fetchone() return row[0] if row else 0 def get_policy_change_by_id(self, change_id: int) -> Optional["PolicyChange"]: """Return a single PolicyChange by its DB auto-increment id, or None.""" self._require_open() row = self._conn.execute( # type: ignore[union-attr] """SELECT id, applied_at, verb, mode, source_file, sender_hash, diff_summary, is_destructive, node_added, node_updated, node_deleted, agent_added, agent_updated, agent_deleted FROM policy_changes WHERE id = ?""", (change_id,), ).fetchone() if row is None: return None return PolicyChange( id=row[0], applied_at=row[1], verb=row[2], mode=row[3], source_file=row[4], sender_hash=row[5], diff_summary=row[6], is_destructive=bool(row[7]), node_added=row[8], node_updated=row[9], node_deleted=row[10], agent_added=row[11], agent_updated=row[12], agent_deleted=row[13], ) # ── M10.0: Auto-backup + retention prune ────────────────────────────────── def write_autobackup( self, exports_dir: str, sender_hash8: str, nonce: str, ) -> tuple[str, str]: """ Export all current policy to a timestamped autobackup file. Filename: policy-autobackup---.json Returns (file_path, content_hash_prefix[:8]). Non-atomic write is acceptable: file is complete before we return. """ self._require_open() ts = datetime.datetime.now(datetime.timezone.utc).strftime("%Y%m%dT%H%M%SZ") filename = f"{_AUTOBACKUP_PREFIX}{ts}-{sender_hash8[:8]}-{nonce}.json" file_path = _os.path.join(exports_dir, filename) snapshot = self.export_all() content = _json.dumps(snapshot, sort_keys=True, ensure_ascii=True) with open(file_path, "w", encoding="utf-8") as fh: fh.write(content) content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()[:8] logger.debug("Auto-backup written: %s hash=%s", filename, content_hash) return file_path, content_hash def prune_exports( self, exports_dir: str, retention_days: int, dry_run: bool = True, ) -> PruneResult: """ Remove policy export files older than retention_days. Only files matching 'policy-*.json' in exports_dir are considered — never recursing into subdirectories. dry_run=True: compute stats without deleting. dry_run=False: actually delete matching files. Returns PruneResult with filenames, total_bytes, oldest_mtime. """ if retention_days <= 0: return PruneResult(files_to_delete=[], total_bytes=0) cutoff = time.time() - retention_days * 86400 pattern = _os.path.join(exports_dir, _EXPORT_GLOB) to_delete: List[str] = [] total_bytes = 0 oldest_mtime: Optional[float] = None for fpath in sorted(_glob.glob(pattern)): # Safety: only process files directly in exports_dir (no subdirs) if _os.path.dirname(fpath) != _os.path.abspath(exports_dir): continue try: stat = _os.stat(fpath) except OSError: continue if stat.st_mtime < cutoff: basename = _os.path.basename(fpath) to_delete.append(basename) total_bytes += stat.st_size if oldest_mtime is None or stat.st_mtime < oldest_mtime: oldest_mtime = stat.st_mtime if not dry_run: for basename in to_delete: fpath = _os.path.join(exports_dir, basename) try: _os.remove(fpath) logger.info("Pruned policy export: %s", basename) except OSError as exc: logger.warning("Could not prune %s: %s", basename, exc) return PruneResult( files_to_delete=to_delete, total_bytes=total_bytes, oldest_mtime=oldest_mtime, ) def export_all(self) -> Dict[str, Any]: """ Export all overrides as a JSON-serializable snapshot dict. Format (version 1): { "version": 1, "created_at": "Z", "room_node_overrides": [{room_id, node_id, updated_at, updated_by}, ...], "room_agent_overrides": [{room_id, agents, default_agent, updated_at, updated_by}, ...] } """ self._require_open() node_rows = self._conn.execute( # type: ignore[union-attr] "SELECT room_id, node_id, updated_at, updated_by_hash FROM room_node_overrides ORDER BY room_id" ).fetchall() agent_rows = self._conn.execute( # type: ignore[union-attr] """SELECT room_id, agents_csv, default_agent, updated_at, updated_by_hash FROM room_agent_overrides ORDER BY room_id""" ).fetchall() return { "version": POLICY_SNAPSHOT_VERSION, "created_at": datetime.datetime.now(datetime.timezone.utc).isoformat().replace("+00:00", "Z"), "room_node_overrides": [ {"room_id": r[0], "node_id": r[1], "updated_at": r[2], "updated_by": r[3]} for r in node_rows ], "room_agent_overrides": [ { "room_id": r[0], "agents": [a.strip() for a in r[1].split(",") if a.strip()], "default_agent": r[2] or None, "updated_at": r[3], "updated_by": r[4], } for r in agent_rows ], } def compute_import_diff( self, data: Dict[str, Any], mode: str = POLICY_IMPORT_MODE_MERGE, ) -> ImportDiff: """ Compute what would change if data were imported (dry-run, M9.1). Returns an ImportDiff with counts and up to _SAMPLE_KEYS_MAX changed rooms. Non-destructive — never modifies the database. """ if data.get("version") != POLICY_SNAPSHOT_VERSION: raise ValueError(f"Unsupported snapshot version: {data.get('version')!r}") self._require_open() existing_nodes: Dict[str, str] = { r[0]: r[1] for r in self._conn.execute( # type: ignore[union-attr] "SELECT room_id, node_id FROM room_node_overrides" ).fetchall() } existing_agents: Dict[str, str] = { r[0]: r[1] for r in self._conn.execute( # type: ignore[union-attr] "SELECT room_id, agents_csv FROM room_agent_overrides" ).fetchall() } file_nodes: Dict[str, str] = { e["room_id"]: e["node_id"] for e in (data.get("room_node_overrides") or []) if "room_id" in e and "node_id" in e } file_agents: Dict[str, Any] = { e["room_id"]: e for e in (data.get("room_agent_overrides") or []) if "room_id" in e and "agents" in e } node_added = sum(1 for r in file_nodes if r not in existing_nodes) node_updated = sum(1 for r in file_nodes if r in existing_nodes) agent_added = sum(1 for r in file_agents if r not in existing_agents) agent_updated = sum(1 for r in file_agents if r in existing_agents) node_deleted = 0 agent_deleted = 0 if mode == POLICY_IMPORT_MODE_REPLACE: node_deleted = sum(1 for r in existing_nodes if r not in file_nodes) agent_deleted = sum(1 for r in existing_agents if r not in file_agents) # Collect up to _SAMPLE_KEYS_MAX affected rooms (deterministic: sorted) affected: List[str] = [] seen: set[str] = set() for rid in list(file_nodes) + list(file_agents): if rid not in seen: affected.append(rid) seen.add(rid) if mode == POLICY_IMPORT_MODE_REPLACE: for rid in list(existing_nodes) + list(existing_agents): if rid not in seen and (rid not in file_nodes or rid not in file_agents): affected.append(rid) seen.add(rid) sample_keys = sorted(affected)[:_SAMPLE_KEYS_MAX] return ImportDiff( node_added=node_added, node_updated=node_updated, node_deleted=node_deleted, agent_added=agent_added, agent_updated=agent_updated, agent_deleted=agent_deleted, sample_keys=sample_keys, is_replace=(mode == POLICY_IMPORT_MODE_REPLACE), ) def import_snapshot( self, data: Dict[str, Any], mode: str = POLICY_IMPORT_MODE_MERGE, dry_run: bool = True, imported_by: str = "import", ) -> Dict[str, int]: """ Import a policy snapshot. mode=merge: upsert entries from file; never delete existing entries not in file. mode=replace: upsert entries from file AND delete entries in DB not present in file. dry_run=True: compute stats without modifying DB. Returns: { "node_added": N, "node_updated": N, "node_deleted": N, "agent_added": N, "agent_updated": N, "agent_deleted": N, } """ if data.get("version") != POLICY_SNAPSHOT_VERSION: raise ValueError(f"Unsupported snapshot version: {data.get('version')!r}") self._require_open() # ── Current DB state ────────────────────────────────────────────────── existing_nodes: Dict[str, str] = { r[0]: r[1] for r in self._conn.execute( # type: ignore[union-attr] "SELECT room_id, node_id FROM room_node_overrides" ).fetchall() } existing_agents: Dict[str, str] = { r[0]: r[1] for r in self._conn.execute( # type: ignore[union-attr] "SELECT room_id, agents_csv FROM room_agent_overrides" ).fetchall() } # ── Compute deltas ──────────────────────────────────────────────────── file_nodes = { e["room_id"]: e["node_id"] for e in (data.get("room_node_overrides") or []) if "room_id" in e and "node_id" in e } file_agents = { e["room_id"]: e for e in (data.get("room_agent_overrides") or []) if "room_id" in e and "agents" in e } node_added = sum(1 for r in file_nodes if r not in existing_nodes) node_updated = sum(1 for r in file_nodes if r in existing_nodes) agent_added = sum(1 for r in file_agents if r not in existing_agents) agent_updated = sum(1 for r in file_agents if r in existing_agents) node_deleted = 0 agent_deleted = 0 if mode == POLICY_IMPORT_MODE_REPLACE: node_deleted = sum(1 for r in existing_nodes if r not in file_nodes) agent_deleted = sum(1 for r in existing_agents if r not in file_agents) stats = { "node_added": node_added, "node_updated": node_updated, "node_deleted": node_deleted, "agent_added": agent_added, "agent_updated": agent_updated, "agent_deleted": agent_deleted, } if dry_run: return stats # ── Apply changes ───────────────────────────────────────────────────── now = int(time.time()) by_hash = _hash_sender(imported_by) for entry in (data.get("room_node_overrides") or []): rid = entry.get("room_id") nid = entry.get("node_id") if rid and nid: self._conn.execute( # type: ignore[union-attr] """ INSERT INTO room_node_overrides (room_id, node_id, updated_at, updated_by_hash) VALUES (?, ?, ?, ?) ON CONFLICT(room_id) DO UPDATE SET node_id = excluded.node_id, updated_at = excluded.updated_at, updated_by_hash = excluded.updated_by_hash """, (rid, nid, now, by_hash), ) for entry in (data.get("room_agent_overrides") or []): rid = entry.get("room_id") agents = entry.get("agents") or [] def_agent = entry.get("default_agent") or (agents[0] if agents else None) if rid and agents: agents_csv = ",".join(sorted(set(agents))) self._conn.execute( # type: ignore[union-attr] """ INSERT INTO room_agent_overrides (room_id, agents_csv, default_agent, updated_at, updated_by_hash) VALUES (?, ?, ?, ?, ?) ON CONFLICT(room_id) DO UPDATE SET agents_csv = excluded.agents_csv, default_agent = excluded.default_agent, updated_at = excluded.updated_at, updated_by_hash = excluded.updated_by_hash """, (rid, agents_csv, def_agent, now, by_hash), ) if mode == POLICY_IMPORT_MODE_REPLACE: file_node_rooms = set(file_nodes.keys()) file_agent_rooms = set(file_agents.keys()) for room_id in existing_nodes: if room_id not in file_node_rooms: self._conn.execute( # type: ignore[union-attr] "DELETE FROM room_node_overrides WHERE room_id = ?", (room_id,) ) for room_id in existing_agents: if room_id not in file_agent_rooms: self._conn.execute( # type: ignore[union-attr] "DELETE FROM room_agent_overrides WHERE room_id = ?", (room_id,) ) return stats # ── Internal ─────────────────────────────────────────────────────────────── def _require_open(self) -> None: if self._conn is None: raise RuntimeError("PolicyStore is not open — call open() first")