"""
Sofiia Supervisor — State Backend

Supports:
  - redis:  production (requires redis-py)
  - memory: in-process dict (testing / single-instance dev)

Redis schema:
  run:{run_id}         → JSON (RunRecord without events)
  run:{run_id}:events  → Redis list of JSON RunEvent
  TTL: RUN_TTL_SEC (default 24h)
"""

from __future__ import annotations

import json
import logging
from abc import ABC, abstractmethod
from typing import List, Optional

from .config import settings
from .models import RunEvent, RunRecord, RunStatus

logger = logging.getLogger(__name__)

# Statuses from which a run can no longer be cancelled; shared by every
# backend's ``cancel_run`` so the rule is defined in one place.
_TERMINAL_STATUSES = (
    RunStatus.SUCCEEDED,
    RunStatus.FAILED,
    RunStatus.CANCELLED,
)


class StateBackend(ABC):
    """Abstract storage interface for runs and their event streams."""

    @abstractmethod
    async def save_run(self, run: RunRecord) -> None:
        """Persist ``run``, replacing any record stored under its ``run_id``."""
        ...

    @abstractmethod
    async def get_run(self, run_id: str) -> Optional[RunRecord]:
        """Return the run stored under ``run_id``, or ``None`` if absent."""
        ...

    @abstractmethod
    async def append_event(self, run_id: str, event: RunEvent) -> None:
        """Append ``event`` to the event list of ``run_id``."""
        ...

    @abstractmethod
    async def get_events(self, run_id: str) -> List[RunEvent]:
        """Return the events recorded for ``run_id`` in insertion order."""
        ...

    @abstractmethod
    async def cancel_run(self, run_id: str) -> bool:
        """Mark the run as cancelled; return ``True`` only if it changed."""
        ...


# ─── In-memory backend (testing/dev) ─────────────────────────────────────────

class MemoryStateBackend(StateBackend):
    """Keeps runs and events in plain dicts inside the current process."""

    def __init__(self):
        self._runs: dict[str, RunRecord] = {}
        self._events: dict[str, list[RunEvent]] = {}

    async def save_run(self, run: RunRecord) -> None:
        """Store ``run`` under its ``run_id`` (the object itself, not a copy)."""
        self._runs[run.run_id] = run

    async def get_run(self, run_id: str) -> Optional[RunRecord]:
        """Return the stored run object, or ``None`` if unknown.

        Events added through ``append_event`` live in a separate dict and
        are not attached to the returned record.
        """
        return self._runs.get(run_id)

    async def append_event(self, run_id: str, event: RunEvent) -> None:
        """Append ``event`` to the list kept for ``run_id``."""
        self._events.setdefault(run_id, []).append(event)

    async def get_events(self, run_id: str) -> List[RunEvent]:
        """Return a shallow copy of the events recorded for ``run_id``."""
        return list(self._events.get(run_id, []))

    async def cancel_run(self, run_id: str) -> bool:
        """Set the run's status to CANCELLED.

        Returns ``False`` when the run is unknown or already in a terminal
        status, ``True`` when the status was changed.
        """
        run = self._runs.get(run_id)
        if run is None:
            return False
        if run.status in _TERMINAL_STATUSES:
            return False
        run.status = RunStatus.CANCELLED
        return True


# ─── Redis backend (production) ──────────────────────────────────────────────

class RedisStateBackend(StateBackend):
    """Stores each run as a JSON string and its events as a Redis list."""

    def __init__(self):
        # Created on first use by ``_client``.
        self._redis = None

    async def _client(self):
        """Return the cached Redis client, creating it on first call.

        ``redis.asyncio`` is imported here rather than at module level, so
        importing this module does not itself require redis-py.

        Raises:
            Exception: whatever the import or client creation raised; the
                error is logged and re-raised unchanged.
        """
        if self._redis is None:
            try:
                import redis.asyncio as aioredis

                self._redis = await aioredis.from_url(
                    settings.REDIS_URL,
                    decode_responses=True,
                )
            except Exception as e:
                logger.error("Redis connection error: %s", e)
                raise
        return self._redis

    def _run_key(self, run_id: str) -> str:
        """Return the Redis key holding the run's JSON record."""
        return f"run:{run_id}"

    def _events_key(self, run_id: str) -> str:
        """Return the Redis key holding the run's event list."""
        return f"run:{run_id}:events"

    async def save_run(self, run: RunRecord) -> None:
        """Write ``run`` (without its events) with a TTL of RUN_TTL_SEC."""
        r = await self._client()
        # Store run without events (events stored separately in list)
        data = run.model_dump(exclude={"events"})
        await r.setex(
            self._run_key(run.run_id),
            settings.RUN_TTL_SEC,
            # default=str stringifies any value json cannot encode natively.
            json.dumps(data, default=str),
        )

    async def get_run(self, run_id: str) -> Optional[RunRecord]:
        """Load the run and attach the events stored in its event list.

        Returns ``None`` when the key is missing or empty, or when loading
        the events or rebuilding the record fails (the failure is logged).
        """
        r = await self._client()
        raw = await r.get(self._run_key(run_id))
        if not raw:
            return None
        try:
            data = json.loads(raw)
            events = await self.get_events(run_id)
            data["events"] = [e.model_dump() for e in events]
            return RunRecord(**data)
        except Exception as e:
            logger.error("Deserialise run %s: %s", run_id, e)
            return None

    async def append_event(self, run_id: str, event: RunEvent) -> None:
        """Push ``event`` onto the run's list and reset the list's TTL."""
        r = await self._client()
        key = self._events_key(run_id)
        await r.rpush(key, json.dumps(event.model_dump(), default=str))
        # Every append resets the list's expiry to RUN_TTL_SEC.
        await r.expire(key, settings.RUN_TTL_SEC)

    async def get_events(self, run_id: str) -> List[RunEvent]:
        """Return all events stored for ``run_id`` in list order.

        Entries that cannot be decoded into a ``RunEvent`` are skipped so
        one corrupt item does not hide the rest, but each skip is logged
        instead of being dropped silently.
        """
        r = await self._client()
        raw_list = await r.lrange(self._events_key(run_id), 0, -1)
        events: List[RunEvent] = []
        for index, raw in enumerate(raw_list):
            try:
                events.append(RunEvent(**json.loads(raw)))
            except Exception as exc:
                logger.warning(
                    "Skipping undecodable event %d for run %s: %s",
                    index,
                    run_id,
                    exc,
                )
        return events

    async def cancel_run(self, run_id: str) -> bool:
        """Set the run's status to CANCELLED and save it.

        Returns ``False`` when the run cannot be loaded or is already in a
        terminal status, ``True`` after the cancelled record was written.
        The record is read and then written back in separate calls, with
        no transaction around them.
        """
        run = await self.get_run(run_id)
        if run is None:
            return False
        if run.status in _TERMINAL_STATUSES:
            return False
        run.status = RunStatus.CANCELLED
        await self.save_run(run)
        return True


# ─── Factory ─────────────────────────────────────────────────────────────────

def create_state_backend() -> StateBackend:
    """Build the backend selected by ``settings.STATE_BACKEND``.

    The value ``"redis"`` selects ``RedisStateBackend``; any other value
    falls back to ``MemoryStateBackend``.
    """
    if settings.STATE_BACKEND == "redis":
        logger.info("Using Redis state backend")
        return RedisStateBackend()
    logger.info("Using in-memory state backend")
    return MemoryStateBackend()