"""Idempotency replay stores (in-memory and Redis) plus a lazily built singleton."""

from __future__ import annotations

import json
import os
import time
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, Optional, Protocol

from .logging import log_event


@dataclass
class ReplayEntry:
    """Record kept per idempotency key.

    Holds the message id, the response body, a creation timestamp and the id
    of the node that produced it.
    """

    message_id: str
    response_body: Dict[str, Any]
    created_at: float
    node_id: str


class IdempotencyStore(Protocol):
    """Structural interface shared by every idempotency backend."""

    def get(self, key: str) -> Optional[ReplayEntry]:
        ...

    def set(self, key: str, entry: ReplayEntry) -> None:
        ...


class InMemoryIdempotencyStore:
    """Process-local store with a per-entry TTL and an LRU size bound.

    Entries live in an ``OrderedDict``; the front is the least recently
    used key.  Timestamps come from ``time.monotonic()``.
    """

    def __init__(self, ttl_seconds: int = 900, max_size: int = 5000) -> None:
        """Create an empty store.

        Args:
            ttl_seconds: Entry lifetime in seconds; values below 60 are
                raised to 60.
            max_size: Maximum number of entries; values below 100 are
                raised to 100.
        """
        self._ttl_seconds = max(60, int(ttl_seconds))
        self._max_size = max(100, int(max_size))
        self._values: "OrderedDict[str, Dict[str, Any]]" = OrderedDict()

    @staticmethod
    def _expires_at(slot: Optional[Dict[str, Any]]) -> float:
        """Return the expiry timestamp of *slot* (0.0 when it is missing)."""
        return float((slot or {}).get("expires_at", 0.0))

    def _cleanup(self, now: Optional[float] = None) -> None:
        """Pop expired entries from the front until a live one is met.

        This is a cheap, best-effort pass: because ``get`` reorders keys,
        an expired entry may sit behind a live one and survive it.  Callers
        that need exactness must check the expiry of the entry they read.
        """
        ts = now if now is not None else time.monotonic()
        while self._values:
            first_key = next(iter(self._values))
            if self._expires_at(self._values[first_key]) > ts:
                break
            self._values.popitem(last=False)

    def get(self, key: str) -> Optional[ReplayEntry]:
        """Return the live entry stored under *key*, or ``None``.

        A hit is moved to the most-recently-used end.  An entry whose TTL
        has elapsed is removed and reported as a miss.
        """
        now = time.monotonic()
        self._cleanup(now)
        hit = self._values.get(key)
        if not hit:
            return None
        # The front-only cleanup cannot be trusted once keys have been
        # touched out of expiry order, so validate this entry explicitly.
        if self._expires_at(hit) <= now:
            self._values.pop(key, None)
            return None
        # Touch key to preserve LRU behavior.
        self._values.move_to_end(key, last=True)
        entry = hit.get("entry")
        return entry if isinstance(entry, ReplayEntry) else None

    def set(self, key: str, entry: ReplayEntry) -> None:
        """Store *entry* under *key* with a fresh TTL, evicting LRU overflow."""
        now = time.monotonic()
        self._cleanup(now)
        self._values[key] = {
            "expires_at": now + self._ttl_seconds,
            "entry": entry,
        }
        self._values.move_to_end(key, last=True)
        while len(self._values) > self._max_size:
            self._values.popitem(last=False)

    # Debug/test helpers
    def size(self) -> int:
        """Return the number of entries that have not expired.

        Performs a full sweep so that expired entries hidden behind live
        ones are dropped and not counted.
        """
        now = time.monotonic()
        expired = [
            key
            for key, slot in self._values.items()
            if self._expires_at(slot) <= now
        ]
        for key in expired:
            del self._values[key]
        return len(self._values)

    def delete(self, key: str) -> None:
        """Remove *key* if present."""
        self._values.pop(key, None)

    def reset(self) -> None:
        """Drop every entry."""
        self._values.clear()


class RedisIdempotencyStore:
    """Store backed by a Redis-like client; entries are JSON strings with a TTL."""

    def __init__(
        self,
        redis_client: Any,
        ttl_seconds: int = 900,
        prefix: str = "sofiia:idem:",
    ) -> None:
        """Wrap *redis_client*.

        Args:
            redis_client: Object exposing ``get``, ``set``, ``delete`` and
                ``keys`` as used below.
            ttl_seconds: Expiry passed to ``set``; values below 60 are
                raised to 60.
            prefix: String prepended to every key; an empty value falls
                back to ``"sofiia:idem:"``.
        """
        self._redis = redis_client
        self._ttl_seconds = max(60, int(ttl_seconds))
        self._prefix = str(prefix or "sofiia:idem:")

    def _k(self, key: str) -> str:
        """Return *key* with the store prefix applied."""
        return f"{self._prefix}{key}"

    def get(self, key: str) -> Optional[ReplayEntry]:
        """Return the entry stored under *key*, or ``None``.

        A missing key, undecodable JSON, a non-object payload, or fields
        that cannot be coerced to the expected types all count as a miss.
        """
        raw = self._redis.get(self._k(key))
        if raw is None:
            return None
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8", errors="ignore")
        try:
            payload = json.loads(str(raw))
        except (ValueError, RecursionError):
            return None
        if not isinstance(payload, dict):
            return None
        try:
            return ReplayEntry(
                message_id=str(payload.get("message_id") or ""),
                response_body=dict(payload.get("response_body") or {}),
                created_at=float(payload.get("created_at") or 0.0),
                node_id=str(payload.get("node_id") or ""),
            )
        except (TypeError, ValueError):
            # Valid JSON but wrong field shapes: treat like any other
            # corrupted payload instead of raising to the caller.
            return None

    def set(self, key: str, entry: ReplayEntry) -> None:
        """Serialize *entry* to JSON and write it with the configured TTL."""
        payload = {
            "message_id": entry.message_id,
            "response_body": entry.response_body,
            "created_at": float(entry.created_at),
            "node_id": entry.node_id,
        }
        self._redis.set(
            self._k(key),
            json.dumps(payload, ensure_ascii=True),
            ex=self._ttl_seconds,
        )

    # Debug/test helpers
    def delete(self, key: str) -> None:
        """Delete the prefixed *key*."""
        self._redis.delete(self._k(key))

    def reset(self) -> None:
        """Delete every key matching ``<prefix>*`` as reported by ``keys``."""
        keys = self._redis.keys(f"{self._prefix}*")
        if keys:
            self._redis.delete(*keys)


_STORE: Optional[IdempotencyStore] = None


def _make_redis_client(redis_url: str) -> Any:
    """Build a redis client for *redis_url* (``redis`` is imported lazily)."""
    import redis  # type: ignore

    return redis.Redis.from_url(redis_url, decode_responses=False)


def _make_inmemory_store(ttl: int) -> InMemoryIdempotencyStore:
    """Build the in-memory store, sized from ``SOFIIA_IDEMPOTENCY_MAX``."""
    max_size = int(os.getenv("SOFIIA_IDEMPOTENCY_MAX", "5000"))
    return InMemoryIdempotencyStore(ttl_seconds=ttl, max_size=max_size)


def get_idempotency_store() -> IdempotencyStore:
    """Return the module-wide store, creating it on first use.

    Environment variables read on first call:
        SOFIIA_IDEMPOTENCY_TTL_S / CHAT_IDEMPOTENCY_TTL_SEC: TTL in seconds
            (the first takes precedence; default ``900``).
        SOFIIA_IDEMPOTENCY_BACKEND: ``redis`` or anything else for
            in-memory (default ``inmemory``).
        SOFIIA_REDIS_URL, SOFIIA_REDIS_PREFIX: Redis settings.
        SOFIIA_IDEMPOTENCY_MAX: in-memory capacity (default ``5000``).

    Raises:
        ValueError: If a numeric variable does not parse as an integer.
    """
    global _STORE
    if _STORE is not None:
        return _STORE

    ttl = int(
        os.getenv(
            "SOFIIA_IDEMPOTENCY_TTL_S",
            os.getenv("CHAT_IDEMPOTENCY_TTL_SEC", "900"),
        )
    )
    backend = (
        os.getenv("SOFIIA_IDEMPOTENCY_BACKEND", "inmemory").strip().lower()
        or "inmemory"
    )
    if backend == "redis":
        redis_url = os.getenv("SOFIIA_REDIS_URL", "redis://localhost:6379/0").strip()
        prefix = os.getenv("SOFIIA_REDIS_PREFIX", "sofiia:idem:").strip() or "sofiia:idem:"
        try:
            # NOTE(review): only errors raised while importing redis or
            # building the client reach the fallback below; whether
            # from_url contacts the server at this point should be
            # confirmed against the redis-py documentation.
            client = _make_redis_client(redis_url)
            _STORE = RedisIdempotencyStore(client, ttl_seconds=ttl, prefix=prefix)
        except Exception as exc:
            _STORE = _make_inmemory_store(ttl)
            log_event(
                "idempotency.backend.fallback",
                backend="redis",
                status="degraded",
                error_code="redis_unavailable",
                error=str(exc)[:180],
            )
    else:
        _STORE = _make_inmemory_store(ttl)
    return _STORE