164 lines
5.3 KiB
Python
164 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
import os
|
|
import time
|
|
import json
|
|
from collections import OrderedDict
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Optional, Protocol
|
|
|
|
from .logging import log_event
|
|
|
|
|
|
@dataclass
class ReplayEntry:
    """A stored response that the idempotency stores hand back on a repeat key.

    The stores treat this as opaque data: they never interpret the fields, they
    only persist them (the Redis backend serializes all four fields to JSON).
    """

    # Identifier supplied by the caller; presumably the originating message id.
    message_id: str
    # Response payload to replay; must be JSON-serializable for the Redis backend.
    response_body: Dict[str, Any]
    # Caller-supplied creation time; units not enforced here (TODO confirm epoch seconds).
    created_at: float
    # Caller-supplied node identifier; presumably the node that produced the response.
    node_id: str
|
|
|
|
|
|
class IdempotencyStore(Protocol):
    """Structural (static-typing) interface implemented by the store backends.

    A backend only needs compatible ``get``/``set`` methods to satisfy it; no
    inheritance from this class is required.
    """

    def get(self, key: str) -> Optional[ReplayEntry]:
        """Return the entry recorded under *key*, or None when there is none."""

    def set(self, key: str, entry: ReplayEntry) -> None:
        """Record *entry* under *key*."""
|
|
|
|
|
|
class InMemoryIdempotencyStore:
    """Process-local idempotency store with TTL expiry and an LRU size bound.

    Each record lives for ``ttl_seconds`` after its most recent :meth:`set`;
    reads do not extend that deadline. When more than ``max_size`` records are
    held, the least recently used ones (touched by ``get``/``set``) are evicted
    first. Deadlines are measured with :func:`time.monotonic`.
    """

    def __init__(self, ttl_seconds: int = 900, max_size: int = 5000) -> None:
        # Floors protect against zero/negative/too-small configuration values.
        self._ttl_seconds = max(60, int(ttl_seconds))
        self._max_size = max(100, int(max_size))
        # key -> {"expires_at": monotonic deadline, "entry": ReplayEntry},
        # ordered from least to most recently used.
        self._values: "OrderedDict[str, Dict[str, Any]]" = OrderedDict()

    @staticmethod
    def _is_expired(record: Optional[Dict[str, Any]], now: float) -> bool:
        """Return True if *record*'s deadline is at or before *now*."""
        return float((record or {}).get("expires_at", 0.0)) <= now

    def _cleanup(self, now: Optional[float] = None) -> None:
        """Evict expired records from the least-recently-used end.

        The sweep stops at the first live record. Because ``get`` moves records
        to the most-recently-used end without changing their deadline, dict
        order is not deadline order, so expired records can survive behind a
        live one. This is a cheap amortized pass, not an exact purge; ``get``
        re-checks the deadline of the record it returns and ``size`` uses
        ``_purge_expired``.
        """
        ts = time.monotonic() if now is None else now
        while self._values:
            oldest_key = next(iter(self._values))
            if not self._is_expired(self._values[oldest_key], ts):
                break
            self._values.popitem(last=False)

    def _purge_expired(self, now: float) -> None:
        """Remove every expired record (full O(n) scan)."""
        stale = [k for k, rec in self._values.items() if self._is_expired(rec, now)]
        for k in stale:
            del self._values[k]

    def get(self, key: str) -> Optional[ReplayEntry]:
        """Return the live entry stored under *key*, or None if missing or expired.

        A hit is moved to the most-recently-used position (LRU touch); its
        expiry deadline is not extended.
        """
        now = time.monotonic()
        self._cleanup(now)
        hit = self._values.get(key)
        if hit is None:
            return None
        if self._is_expired(hit, now):
            # The front-only sweep misses records reordered by earlier reads,
            # so the deadline of the requested record must be checked here.
            del self._values[key]
            return None
        # Touch key to preserve LRU behavior.
        self._values.move_to_end(key, last=True)
        entry = hit.get("entry")
        return entry if isinstance(entry, ReplayEntry) else None

    def set(self, key: str, entry: ReplayEntry) -> None:
        """Store *entry* under *key* with a fresh TTL, evicting LRU records if full."""
        now = time.monotonic()
        self._cleanup(now)
        self._values[key] = {
            "expires_at": now + self._ttl_seconds,
            "entry": entry,
        }
        self._values.move_to_end(key, last=True)
        while len(self._values) > self._max_size:
            self._values.popitem(last=False)

    # Debug/test helpers
    def size(self) -> int:
        """Return the number of live (non-expired) records."""
        self._purge_expired(time.monotonic())
        return len(self._values)

    def delete(self, key: str) -> None:
        """Remove *key* if present; missing keys are ignored."""
        self._values.pop(key, None)

    def reset(self) -> None:
        """Remove all records."""
        self._values.clear()
|
|
|
|
|
|
class RedisIdempotencyStore:
    """Idempotency store backed by a Redis client (redis-py style API).

    Entries are stored as JSON strings under ``prefix + key`` with a
    server-side expiry (``ex=``) of ``ttl_seconds``, floored at 60.
    """

    def __init__(self, redis_client: Any, ttl_seconds: int = 900, prefix: str = "sofiia:idem:") -> None:
        self._redis = redis_client
        self._ttl_seconds = max(60, int(ttl_seconds))
        # An empty/None prefix falls back to the default namespace.
        self._prefix = str(prefix or "sofiia:idem:")

    def _k(self, key: str) -> str:
        """Return the namespaced Redis key for *key*."""
        return f"{self._prefix}{key}"

    def get(self, key: str) -> Optional[ReplayEntry]:
        """Return the entry stored under *key*, or None if absent or malformed.

        Any payload that cannot be decoded into a ReplayEntry is treated as a
        miss instead of raising: invalid JSON, a non-object top level, a
        non-object ``response_body``, or a non-numeric ``created_at``.
        """
        raw = self._redis.get(self._k(key))
        if raw is None:
            return None
        if isinstance(raw, bytes):
            raw = raw.decode("utf-8", errors="ignore")
        try:
            payload = json.loads(str(raw))
        except (ValueError, RecursionError):
            # JSONDecodeError is a ValueError; pathological nesting raises RecursionError.
            return None
        if not isinstance(payload, dict):
            return None

        response_body = payload.get("response_body") or {}
        if not isinstance(response_body, dict):
            # dict() on a string/list would raise instead of reporting a miss.
            return None
        try:
            created_at = float(payload.get("created_at") or 0.0)
        except (TypeError, ValueError, OverflowError):
            return None

        return ReplayEntry(
            message_id=str(payload.get("message_id") or ""),
            response_body=dict(response_body),
            created_at=created_at,
            node_id=str(payload.get("node_id") or ""),
        )

    def set(self, key: str, entry: ReplayEntry) -> None:
        """Serialize *entry* to JSON and store it under *key* with the TTL.

        Raises TypeError (from ``json.dumps``) if ``entry.response_body`` is not
        JSON-serializable.
        """
        payload = {
            "message_id": entry.message_id,
            "response_body": entry.response_body,
            "created_at": float(entry.created_at),
            "node_id": entry.node_id,
        }
        self._redis.set(self._k(key), json.dumps(payload, ensure_ascii=True), ex=self._ttl_seconds)

    # Debug/test helpers
    def delete(self, key: str) -> None:
        """Delete the Redis key for *key*."""
        self._redis.delete(self._k(key))

    def reset(self) -> None:
        """Delete every key under this store's prefix.

        NOTE: uses KEYS, which scans the whole keyspace and blocks the server;
        acceptable for a test helper, not for production traffic.
        """
        keys = self._redis.keys(f"{self._prefix}*")
        if keys:
            self._redis.delete(*keys)
|
|
|
|
|
|
# Lazily initialised process-wide store: populated on the first call to
# get_idempotency_store() and returned unchanged on every later call.
_STORE: Optional[IdempotencyStore] = None
|
|
|
|
|
|
def _make_redis_client(redis_url: str) -> Any:
    """Create a redis-py client for *redis_url* and verify the server responds.

    ``Redis.from_url`` is lazy and does not open a connection, so without the
    ``ping()`` an unreachable server would only surface later, as errors on
    individual get/set calls. Pinging here makes connection problems raise at
    construction time, inside the ``try`` in ``get_idempotency_store`` that
    falls back to the in-memory store.

    NOTE(review): no socket timeout is configured, so pinging a black-holed
    host may block until the OS connect timeout; consider
    ``socket_connect_timeout`` if that matters for startup.

    Raises:
        ImportError: if the optional ``redis`` package is not installed.
        Exception: any redis-py error from URL parsing or the connectivity check.
    """
    import redis  # type: ignore  # optional dependency, imported only when selected

    client = redis.Redis.from_url(redis_url, decode_responses=False)
    client.ping()
    return client
|
|
|
|
|
|
def _env_int(*names: str, default: int) -> int:
    """Return the first non-blank env var among *names* parsed as int, else *default*.

    Earlier names take precedence. Blank values are treated as unset, matching
    how the backend and prefix settings fall back on empty strings. A non-blank
    value that is not an integer raises ValueError so misconfiguration surfaces.
    """
    for name in names:
        raw = os.getenv(name, "").strip()
        if raw:
            return int(raw)
    return default


def _build_inmemory_store(ttl: int) -> InMemoryIdempotencyStore:
    """Build the in-memory store using ``SOFIIA_IDEMPOTENCY_MAX`` (default 5000)."""
    max_size = _env_int("SOFIIA_IDEMPOTENCY_MAX", default=5000)
    return InMemoryIdempotencyStore(ttl_seconds=ttl, max_size=max_size)


def get_idempotency_store() -> IdempotencyStore:
    """Return the process-wide idempotency store, creating it on the first call.

    Configuration is read from the environment only on that first call:

    - ``SOFIIA_IDEMPOTENCY_BACKEND``: ``"redis"`` selects Redis; any other value
      (default ``"inmemory"``) selects the in-memory store.
    - TTL: ``SOFIIA_IDEMPOTENCY_TTL_S``, else ``CHAT_IDEMPOTENCY_TTL_SEC``, else 900.
    - ``SOFIIA_IDEMPOTENCY_MAX``: in-memory capacity, default 5000.
    - ``SOFIIA_REDIS_URL`` / ``SOFIIA_REDIS_PREFIX``: Redis URL and key prefix.

    If building the Redis store raises, an in-memory store is used instead and
    an ``idempotency.backend.fallback`` event is logged.
    """
    global _STORE
    if _STORE is not None:
        return _STORE

    ttl = _env_int("SOFIIA_IDEMPOTENCY_TTL_S", "CHAT_IDEMPOTENCY_TTL_SEC", default=900)
    backend = os.getenv("SOFIIA_IDEMPOTENCY_BACKEND", "inmemory").strip().lower() or "inmemory"

    if backend != "redis":
        _STORE = _build_inmemory_store(ttl)
        return _STORE

    redis_url = os.getenv("SOFIIA_REDIS_URL", "redis://localhost:6379/0").strip()
    prefix = os.getenv("SOFIIA_REDIS_PREFIX", "sofiia:idem:").strip() or "sofiia:idem:"
    try:
        client = _make_redis_client(redis_url)
        _STORE = RedisIdempotencyStore(client, ttl_seconds=ttl, prefix=prefix)
    except Exception as exc:
        # Degrade to a process-local store rather than failing the caller.
        _STORE = _build_inmemory_store(ttl)
        log_event(
            "idempotency.backend.fallback",
            backend="redis",
            status="degraded",
            error_code="redis_unavailable",
            error=str(exc)[:180],
        )
    return _STORE
|
|
|