Files
microdao-daarion/services/sofiia-console/app/rate_limit.py

170 lines
5.3 KiB
Python

from __future__ import annotations
import math
import os
import time
from dataclasses import dataclass
from typing import Any, Dict, Optional, Protocol
from .logging import log_event
@dataclass
class RateLimitDecision:
    """Outcome of one rate-limit check.

    Produced by the ``consume`` methods of the limiters in this module.
    """

    # True when the bucket held enough tokens and the cost was deducted.
    allowed: bool

    # Whole seconds the caller should wait before retrying; 0 when allowed.
    retry_after_s: int

    # Tokens left in the bucket after this check (never negative).
    remaining_tokens: float
class RateLimiter(Protocol):
    """Structural interface shared by every rate-limiter backend."""

    def consume(
        self,
        key: str,
        rps: float,
        burst: int,
        cost: float = 1.0,
    ) -> RateLimitDecision:
        """Try to spend ``cost`` tokens from the bucket identified by ``key``.

        Args:
            key: Identifier of the bucket to charge.
            rps: Refill rate, in tokens per second.
            burst: Bucket capacity (maximum number of stored tokens).
            cost: Number of tokens this request needs.

        Returns:
            A ``RateLimitDecision`` describing whether the request may proceed.
        """
        ...
class InMemoryRateLimiter:
    """Token-bucket rate limiter that keeps all bucket state in a local dict.

    Each bucket is stored as ``{"tokens": float, "updated_at": float}`` keyed by
    the caller-supplied key. No locking is performed by this class.
    """

    def __init__(self, max_keys: int = 10000) -> None:
        """Create an empty limiter.

        Args:
            max_keys: Bucket count above which old buckets are evicted.
                Values below 1000 are raised to 1000.
        """
        self._max_keys = max(1000, int(max_keys))
        self._buckets: Dict[str, Dict[str, float]] = {}

    def consume(self, key: str, rps: float, burst: int, cost: float = 1.0) -> RateLimitDecision:
        """Refill the bucket for ``key`` and try to take ``cost`` tokens from it.

        Args:
            key: Bucket identifier.
            rps: Refill rate in tokens per second (negative values count as 0).
            burst: Bucket capacity (clamped to at least 1).
            cost: Tokens required by this request (clamped to at least 0.1).

        Returns:
            The decision, including the retry hint and the tokens left.
        """
        refill_rate = max(0.0, float(rps))
        capacity = max(1.0, float(burst))
        price = max(0.1, float(cost))
        stamp = time.monotonic()

        bucket = self._buckets.get(key)
        if bucket:
            level = float(bucket.get("tokens") or 0.0)
            last_seen = float(bucket.get("updated_at") or stamp)
        else:
            # Unknown key: start from a full bucket.
            level, last_seen = capacity, stamp

        if refill_rate > 0:
            idle = max(0.0, stamp - last_seen)
            level = min(capacity, level + idle * refill_rate)

        if level >= price:
            level -= price
            granted, wait_s = True, 0
        elif refill_rate > 0:
            shortfall = max(0.0, price - level)
            granted, wait_s = False, int(math.ceil(shortfall / refill_rate))
        else:
            # A bucket that never refills gets a fixed retry hint.
            granted, wait_s = False, 60

        left = max(0.0, level)
        self._buckets[key] = {"tokens": left, "updated_at": stamp}
        self._gc()
        return RateLimitDecision(allowed=granted, retry_after_s=wait_s, remaining_tokens=left)

    def _gc(self) -> None:
        """Evict the least recently updated half of the buckets once over the limit."""
        if not len(self._buckets) > self._max_keys:
            return

        def _last_update(name: str) -> float:
            state = self._buckets[name] or {}
            return float(state.get("updated_at") or 0.0)

        oldest_first = sorted(self._buckets, key=_last_update)
        drop_count = max(1, len(oldest_first) // 2)
        for name in oldest_first[:drop_count]:
            self._buckets.pop(name, None)

    def reset(self) -> None:
        """Forget every bucket."""
        self._buckets.clear()
class RedisRateLimiter:
    """Token-bucket rate limiter whose bucket state is stored in Redis.

    Every bucket is a Redis hash with the fields ``tokens`` and ``ts``. Refill,
    deduction and expiry are performed by one Lua script so that a check is a
    single atomic server-side operation.

    NOTE(review): the timestamp passed to the script is this process's
    ``time.time()``; if several hosts share a bucket, clock skew between them
    presumably affects refill accuracy — confirm whether that matters here.
    """

    def __init__(self, redis_client: Any, prefix: str = "sofiia:rl:", key_ttl_s: int = 86400) -> None:
        """Bind the limiter to a Redis client.

        Args:
            redis_client: Client object exposing ``eval``, ``delete`` and
                ``scan_iter`` or ``keys``.
            prefix: String prepended to every bucket key; an empty value falls
                back to ``"sofiia:rl:"``.
            key_ttl_s: Expiry applied to a bucket on every check; values below
                60 are raised to 60.
        """
        self._redis = redis_client
        self._prefix = str(prefix or "sofiia:rl:")
        self._key_ttl_s = max(60, int(key_ttl_s))
        # The script returns `tokens` via tostring(): Redis converts Lua numbers
        # in a reply to integers, which would drop the fractional part.
        # HSET is used because HMSET is deprecated.
        self._lua = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local rps = tonumber(ARGV[2])
local burst = tonumber(ARGV[3])
local cost = tonumber(ARGV[4])
local ttl = tonumber(ARGV[5])
local data = redis.call('HMGET', key, 'tokens', 'ts')
local tokens = tonumber(data[1])
local ts = tonumber(data[2])
if tokens == nil then tokens = burst end
if ts == nil then ts = now end
if now > ts then
  tokens = math.min(burst, tokens + ((now - ts) * rps))
end
local allowed = 0
local retry = 0
if tokens >= cost then
  tokens = tokens - cost
  allowed = 1
else
  local needed = cost - tokens
  if rps > 0 then
    retry = math.ceil(needed / rps)
  else
    retry = 60
  end
end
redis.call('HSET', key, 'tokens', tokens, 'ts', now)
redis.call('EXPIRE', key, ttl)
return {allowed, retry, tostring(tokens)}
"""

    def _k(self, key: str) -> str:
        """Return the namespaced Redis key for ``key``."""
        return f"{self._prefix}{key}"

    def consume(self, key: str, rps: float, burst: int, cost: float = 1.0) -> RateLimitDecision:
        """Run the token-bucket script for ``key`` and decode its reply.

        Args:
            key: Bucket identifier (the prefix is added automatically).
            rps: Refill rate in tokens per second (negative values count as 0).
            burst: Bucket capacity (clamped to at least 1).
            cost: Tokens required by this request (clamped to at least 0.1).

        Returns:
            The decision, including the retry hint and the tokens left.

        Raises:
            Whatever the Redis client raises when the script call fails.
        """
        now = time.time()
        rate = max(0.0, float(rps))
        cap = max(1.0, float(burst))
        req = max(0.1, float(cost))
        out = self._redis.eval(
            self._lua,
            1,
            self._k(key),
            now,
            rate,
            cap,
            req,
            self._key_ttl_s,
        )
        # Pad so that an empty or short reply decodes as "denied, 0, 0.0"
        # instead of raising IndexError.
        allowed_raw, retry_raw, tokens_raw = (list(out or ()) + [0, 0, 0.0])[:3]
        allowed = bool(allowed_raw)
        retry_after = int(retry_raw or 0)
        # tokens arrives as a string/bytes from the script; float() accepts
        # both, as well as a plain number from a client double.
        rem = float(tokens_raw or 0.0)
        return RateLimitDecision(allowed=allowed, retry_after_s=retry_after, remaining_tokens=max(0.0, rem))

    def reset(self) -> None:
        """Delete every bucket stored under this limiter's prefix.

        Keys are discovered with the incremental ``scan_iter`` when the client
        offers it, because ``KEYS`` blocks the server while it walks the whole
        keyspace. Clients without ``scan_iter`` fall back to ``keys``.
        """
        pattern = f"{self._prefix}*"
        scan_iter = getattr(self._redis, "scan_iter", None)
        if callable(scan_iter):
            found = scan_iter(match=pattern, count=500)
        else:
            found = self._redis.keys(pattern)

        # Delete in bounded batches so one DEL never carries an unbounded
        # argument list.
        batch = []
        for name in found or ():
            batch.append(name)
            if len(batch) >= 500:
                self._redis.delete(*batch)
                batch = []
        if batch:
            self._redis.delete(*batch)
# Module-level limiter singleton; stays None until get_rate_limiter() builds it.
_RATE_LIMITER: Optional[RateLimiter] = None
def _make_redis_client(redis_url: str) -> Any:
    """Build a Redis client for ``redis_url`` with response decoding disabled.

    The ``redis`` package is imported here rather than at module import time,
    so it is only required when this function is actually called.
    """
    import redis as redis_lib  # type: ignore

    build_client = redis_lib.Redis.from_url
    return build_client(redis_url, decode_responses=False)
def _env_int(name: str, default: int, backend: str) -> int:
    """Read integer env var ``name``; use ``default`` when unset, blank or malformed."""
    raw = (os.getenv(name) or "").strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        # A bad value must not make every limiter lookup raise; report it and
        # continue with the default.
        log_event(
            "rate_limit.config.invalid",
            backend=backend,
            status="degraded",
            error_code="invalid_env_int",
            error=f"{name}={raw!r}"[:180],
        )
        return default


def get_rate_limiter() -> RateLimiter:
    """Return the shared rate limiter, creating it on first use.

    The backend is chosen by ``SOFIIA_RATE_LIMIT_BACKEND`` (default
    ``"inmemory"``). For ``"redis"`` the client is configured from
    ``SOFIIA_REDIS_URL``, ``SOFIIA_RATE_LIMIT_PREFIX`` and
    ``SOFIIA_RATE_LIMIT_KEY_TTL_S``; if building it raises, an in-memory
    limiter is used instead and a fallback event is logged. Any other backend
    value selects the in-memory limiter sized by ``SOFIIA_RATE_LIMIT_MAX_KEYS``.

    Malformed integer settings fall back to their defaults instead of raising.

    Returns:
        The module-level limiter instance.
    """
    global _RATE_LIMITER
    if _RATE_LIMITER is not None:
        return _RATE_LIMITER

    backend = os.getenv("SOFIIA_RATE_LIMIT_BACKEND", "inmemory").strip().lower() or "inmemory"
    if backend == "redis":
        redis_url = os.getenv("SOFIIA_REDIS_URL", "redis://localhost:6379/0").strip()
        prefix = os.getenv("SOFIIA_RATE_LIMIT_PREFIX", "sofiia:rl:").strip() or "sofiia:rl:"
        ttl = _env_int("SOFIIA_RATE_LIMIT_KEY_TTL_S", 86400, backend)
        try:
            # NOTE(review): redis-py clients presumably connect lazily, so this
            # guard likely only catches import/URL errors, not an unreachable
            # server — confirm whether an explicit ping is wanted here.
            client = _make_redis_client(redis_url)
            _RATE_LIMITER = RedisRateLimiter(client, prefix=prefix, key_ttl_s=ttl)
        except Exception as exc:
            _RATE_LIMITER = InMemoryRateLimiter()
            log_event(
                "rate_limit.backend.fallback",
                backend="redis",
                status="degraded",
                error_code="redis_unavailable",
                error=str(exc)[:180],
            )
    else:
        max_keys = _env_int("SOFIIA_RATE_LIMIT_MAX_KEYS", 10000, backend)
        _RATE_LIMITER = InMemoryRateLimiter(max_keys=max_keys)
    return _RATE_LIMITER