170 lines
5.3 KiB
Python
170 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
import math
|
|
import os
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any, Dict, Optional, Protocol
|
|
|
|
from .logging import log_event
|
|
|
|
|
|
@dataclass()
class RateLimitDecision:
    """Result of one ``consume`` call on a token-bucket rate limiter.

    Fields are listed in constructor order.
    """

    # True when the request fit in the bucket and its cost was deducted.
    allowed: bool
    # Suggested wait in whole seconds before retrying; the limiters in this
    # module leave it at 0 for allowed requests.
    retry_after_s: int
    # Tokens left in the bucket after the call; the limiters in this module
    # clamp it to be non-negative.
    remaining_tokens: float
|
|
|
|
|
|
class RateLimiter(Protocol):
    """Structural interface implemented by the rate-limiter backends in this module."""

    def consume(self, key: str, rps: float, burst: int, cost: float = 1.0) -> RateLimitDecision:
        """Try to spend ``cost`` tokens from the bucket identified by ``key``.

        ``rps`` is the refill rate in tokens per second and ``burst`` is the
        bucket capacity. Implementations report the outcome as a
        ``RateLimitDecision``.
        """
|
|
|
|
|
|
class InMemoryRateLimiter:
|
|
def __init__(self, max_keys: int = 10000) -> None:
|
|
self._max_keys = max(1000, int(max_keys))
|
|
self._buckets: Dict[str, Dict[str, float]] = {}
|
|
|
|
def consume(self, key: str, rps: float, burst: int, cost: float = 1.0) -> RateLimitDecision:
|
|
rate = max(0.0, float(rps))
|
|
cap = max(1.0, float(burst))
|
|
req = max(0.1, float(cost))
|
|
now = time.monotonic()
|
|
|
|
hit = self._buckets.get(key)
|
|
if not hit:
|
|
tokens = cap
|
|
updated_at = now
|
|
else:
|
|
tokens = float(hit.get("tokens") or 0.0)
|
|
updated_at = float(hit.get("updated_at") or now)
|
|
elapsed = max(0.0, now - updated_at)
|
|
if rate > 0:
|
|
tokens = min(cap, tokens + elapsed * rate)
|
|
|
|
allowed = tokens >= req
|
|
retry_after = 0
|
|
if allowed:
|
|
tokens -= req
|
|
else:
|
|
needed = max(0.0, req - tokens)
|
|
retry_after = int(math.ceil(needed / rate)) if rate > 0 else 60
|
|
|
|
self._buckets[key] = {"tokens": max(0.0, tokens), "updated_at": now}
|
|
self._gc()
|
|
return RateLimitDecision(allowed=allowed, retry_after_s=retry_after, remaining_tokens=max(0.0, tokens))
|
|
|
|
def _gc(self) -> None:
|
|
if len(self._buckets) <= self._max_keys:
|
|
return
|
|
items = sorted(self._buckets.items(), key=lambda kv: float((kv[1] or {}).get("updated_at") or 0.0))
|
|
to_drop = max(1, len(items) // 2)
|
|
for k, _ in items[:to_drop]:
|
|
self._buckets.pop(k, None)
|
|
|
|
def reset(self) -> None:
|
|
self._buckets.clear()
|
|
|
|
|
|
class RedisRateLimiter:
    """Token-bucket rate limiter whose buckets live in Redis hashes.

    Each bucket is a hash at ``<prefix><key>`` with fields ``tokens`` and ``ts``.
    ``ts`` holds the caller's ``time.time()`` at the last update. Refill and
    spend run inside one Lua script, so each ``consume`` is a single atomic
    operation on the Redis server. Every touched key gets a TTL so idle
    buckets disappear.
    """

    # Number of keys removed per DEL command in reset().
    _RESET_BATCH = 500

    def __init__(self, redis_client: Any, prefix: str = "sofiia:rl:", key_ttl_s: int = 86400) -> None:
        """
        Args:
            redis_client: A redis-py style client. ``consume`` uses ``eval``;
                ``reset`` uses ``scan_iter`` (or ``keys`` if absent) and ``delete``.
            prefix: Namespace prepended to every bucket key; falsy -> default.
            key_ttl_s: Expiry applied to a bucket on every update; minimum 60s.
        """
        self._redis = redis_client
        self._prefix = str(prefix or "sofiia:rl:")
        self._key_ttl_s = max(60, int(key_ttl_s))
        # KEYS[1] = bucket key; ARGV = now, rps, burst, cost, ttl.
        # Returns {allowed (0/1), retry seconds, remaining tokens as a string}.
        self._lua = """
local key = KEYS[1]
local now = tonumber(ARGV[1])
local rps = tonumber(ARGV[2])
local burst = tonumber(ARGV[3])
local cost = tonumber(ARGV[4])
local ttl = tonumber(ARGV[5])
local data = redis.call('HMGET', key, 'tokens', 'ts')
local tokens = tonumber(data[1])
local ts = tonumber(data[2])
if tokens == nil then tokens = burst end
if ts == nil then ts = now end
if now > ts then
  tokens = math.min(burst, tokens + ((now - ts) * rps))
end
local allowed = 0
local retry = 0
if tokens >= cost then
  tokens = tokens - cost
  allowed = 1
else
  local needed = cost - tokens
  if rps > 0 then
    retry = math.ceil(needed / rps)
  else
    retry = 60
  end
end
redis.call('HSET', key, 'tokens', tokens, 'ts', now)
redis.call('EXPIRE', key, ttl)
-- Redis truncates Lua numbers to integers in replies; send the fractional
-- balance as a string so the client sees the real remaining tokens.
return {allowed, retry, tostring(tokens)}
"""

    def _k(self, key: str) -> str:
        """Return the namespaced Redis key for bucket ``key``."""
        return f"{self._prefix}{key}"

    def consume(self, key: str, rps: float, burst: int, cost: float = 1.0) -> RateLimitDecision:
        """Atomically refill the bucket for ``key`` and try to spend ``cost`` tokens.

        Args:
            key: Bucket identifier (namespaced with the prefix).
            rps: Refill rate in tokens per second; negative values act as 0.
            burst: Bucket capacity; values below 1 are treated as 1.
            cost: Tokens this request needs; values below 0.1 are treated as 0.1.

        Returns:
            A ``RateLimitDecision``. ``retry_after_s`` is 60 when denied with ``rps`` 0.

        Raises:
            Whatever the client's ``eval`` raises (e.g. connection errors).
        """
        # Wall-clock (not monotonic) time, because the timestamp is shared
        # between processes through Redis.
        # NOTE(review): assumes application hosts keep their clocks in sync.
        now = time.time()
        rate = max(0.0, float(rps))
        cap = max(1.0, float(burst))
        req = max(0.1, float(cost))
        out = self._redis.eval(
            self._lua,
            1,
            self._k(key),
            now,
            rate,
            cap,
            req,
            self._key_ttl_s,
        )
        # Pad a missing/short reply so parsing never raises IndexError.
        reply = list(out or [])
        reply.extend([0] * (3 - len(reply)))
        allowed = bool(reply[0])
        retry_after = int(reply[1] or 0)
        # The balance comes back as a string (bytes with decode_responses=False);
        # float() accepts str, bytes and plain numbers alike.
        remaining = float(reply[2] or 0.0)
        return RateLimitDecision(allowed=allowed, retry_after_s=retry_after, remaining_tokens=max(0.0, remaining))

    def reset(self) -> None:
        """Delete every bucket under this limiter's prefix.

        Uses incremental SCAN rather than KEYS, so the Redis server is never
        blocked walking the whole keyspace. Falls back to ``keys`` for clients
        that do not provide ``scan_iter``.
        """
        pattern = f"{self._prefix}*"
        scan_iter = getattr(self._redis, "scan_iter", None)
        if scan_iter is None:
            keys = list(self._redis.keys(pattern) or [])
            if keys:
                self._redis.delete(*keys)
            return
        batch = []
        for redis_key in scan_iter(match=pattern, count=self._RESET_BATCH):
            batch.append(redis_key)
            if len(batch) >= self._RESET_BATCH:
                self._redis.delete(*batch)
                batch = []
        if batch:
            self._redis.delete(*batch)
|
|
|
|
|
|
# Process-wide limiter singleton; None until get_rate_limiter() builds it on first call.
_RATE_LIMITER: Optional[RateLimiter] = None
|
|
|
|
|
|
def _make_redis_client(redis_url: str) -> Any:
    """Build a redis-py client for ``redis_url``.

    The ``redis`` package is imported inside the function, so it is only needed
    when this function is actually called. Replies are left as raw bytes
    (``decode_responses=False``).
    """
    from redis import Redis  # type: ignore

    return Redis.from_url(redis_url, decode_responses=False)
|
|
|
|
|
|
def get_rate_limiter() -> RateLimiter:
    """Return the process-wide rate limiter, creating it on the first call.

    Environment settings (read once, on the first call):

    * ``SOFIIA_RATE_LIMIT_BACKEND``: ``"redis"`` selects Redis; any other value
      (default ``"inmemory"``) selects the in-memory backend.
    * ``SOFIIA_REDIS_URL``, ``SOFIIA_RATE_LIMIT_PREFIX``,
      ``SOFIIA_RATE_LIMIT_KEY_TTL_S``: Redis backend settings.
    * ``SOFIIA_RATE_LIMIT_MAX_KEYS``: bucket cap for the in-memory limiter. It
      also applies when Redis is unreachable and the code falls back to memory.

    If the Redis client cannot be created or does not answer ``PING``, an
    in-memory limiter is used instead and ``rate_limit.backend.fallback`` is
    logged.

    Raises:
        ValueError: if an integer setting is not a valid integer.
    """
    global _RATE_LIMITER
    if _RATE_LIMITER is not None:
        return _RATE_LIMITER

    backend = os.getenv("SOFIIA_RATE_LIMIT_BACKEND", "inmemory").strip().lower() or "inmemory"
    # Parsed up front so the Redis fallback honours the same bucket cap.
    max_keys = int(os.getenv("SOFIIA_RATE_LIMIT_MAX_KEYS", "10000"))

    if backend != "redis":
        _RATE_LIMITER = InMemoryRateLimiter(max_keys=max_keys)
        return _RATE_LIMITER

    redis_url = os.getenv("SOFIIA_REDIS_URL", "redis://localhost:6379/0").strip()
    prefix = os.getenv("SOFIIA_RATE_LIMIT_PREFIX", "sofiia:rl:").strip() or "sofiia:rl:"
    ttl = int(os.getenv("SOFIIA_RATE_LIMIT_KEY_TTL_S", "86400"))
    try:
        client = _make_redis_client(redis_url)
        # redis-py connects lazily, so from_url() succeeds even when the server
        # is down. PING forces a round trip so an unreachable Redis really does
        # trigger the fallback below. The getattr keeps minimal client doubles
        # without ping() working.
        ping = getattr(client, "ping", None)
        if callable(ping):
            ping()
        _RATE_LIMITER = RedisRateLimiter(client, prefix=prefix, key_ttl_s=ttl)
    except Exception as exc:
        _RATE_LIMITER = InMemoryRateLimiter(max_keys=max_keys)
        log_event(
            "rate_limit.backend.fallback",
            backend="redis",
            status="degraded",
            error_code="redis_unavailable",
            error=str(exc)[:180],
        )
    return _RATE_LIMITER
|