From 9b89ace2fcd8c78f7a2da25204b15353693ed6af Mon Sep 17 00:00:00 2001 From: Apple Date: Mon, 2 Mar 2026 09:24:21 -0800 Subject: [PATCH] feat(sofiia-console): add rate limiting for chat send (per-chat and per-operator) Made-with: Cursor --- services/sofiia-console/app/main.py | 79 +++++++++- services/sofiia-console/app/metrics.py | 11 ++ services/sofiia-console/app/rate_limit.py | 169 ++++++++++++++++++++++ tests/test_sofiia_rate_limit.py | 74 ++++++++++ 4 files changed, 331 insertions(+), 2 deletions(-) create mode 100644 services/sofiia-console/app/rate_limit.py create mode 100644 tests/test_sofiia_rate_limit.py diff --git a/services/sofiia-console/app/main.py b/services/sofiia-console/app/main.py index cf1632bf..e8ea500b 100644 --- a/services/sofiia-console/app/main.py +++ b/services/sofiia-console/app/main.py @@ -61,9 +61,11 @@ from .metrics import ( SOFIIA_SEND_REQUESTS_TOTAL, SOFIIA_IDEMPOTENCY_REPLAYS_TOTAL, SOFIIA_CURSOR_REQUESTS_TOTAL, + SOFIIA_RATE_LIMITED_TOTAL, render_metrics, ) from .idempotency import get_idempotency_store, ReplayEntry +from .rate_limit import get_rate_limiter from .logging import ( configure_sofiia_logger, get_request_id, @@ -86,6 +88,11 @@ _NODE_ID = os.getenv("NODE_ID", os.getenv("HOSTNAME", "noda2")) _rate_buckets: Dict[str, collections.deque] = {} _idempotency_store = get_idempotency_store() +_rate_limiter = get_rate_limiter() +_RL_CHAT_RPS = float(os.getenv("SOFIIA_RL_CHAT_RPS", "1.0")) +_RL_CHAT_BURST = int(os.getenv("SOFIIA_RL_CHAT_BURST", "8")) +_RL_OP_RPS = float(os.getenv("SOFIIA_RL_OP_RPS", "3.0")) +_RL_OP_BURST = int(os.getenv("SOFIIA_RL_OP_BURST", "20")) def _check_rate(key: str, max_calls: int, window_sec: int = 60) -> bool: now = time.monotonic() @@ -98,6 +105,32 @@ def _check_rate(key: str, max_calls: int, window_sec: int = 60) -> bool: return True +def _resolve_operator_id(request: Request, body: "ChatMessageSendBody", request_id: str) -> Tuple[str, bool]: + client_meta = body.client or {} + operator_id = ( + str(client_meta.get("operator_id") or "").strip() + or str(body.user_id or "").strip() + or str(request.headers.get("X-Operator-Id") or "").strip() + ) + if operator_id: + return operator_id[:128], False + client_ip = request.client.host if request.client else "unknown" + fallback = f"ip:{client_ip}" if client_ip else f"req:{request_id}" + return fallback[:128], True + + +def _rate_limited_http(scope: str, retry_after_s: int) -> HTTPException: + retry_s = max(1, int(retry_after_s or 1)) + return HTTPException( + status_code=429, + detail={ + "error": {"code": "rate_limited", "scope": scope}, + "retry_after_s": retry_s, + }, + headers={"Retry-After": str(retry_s)}, + ) + + # ── Voice error rings (repro pack for incident diagnosis) ───────────────────── # Circular buffers: last 5 TTS errors and last 5 LLM errors. # Populated by all voice endpoints. Read by /api/voice/degradation_status. @@ -3311,12 +3344,53 @@ async def api_chat_send_v2(chat_id: str, body: ChatMessageSendBody, request: Req info = _parse_chat_id(chat_id) target_node = ((body.routing or {}).get("force_node_id") or info["node_id"] or "NODA2").upper() target_agent = info["agent_id"] or "sofiia" + operator_id, operator_id_missing = _resolve_operator_id(request, body, request_id) + chat_rl = _rate_limiter.consume(f"rl:chat:{chat_id}", rps=_RL_CHAT_RPS, burst=_RL_CHAT_BURST) + if not chat_rl.allowed: + SOFIIA_RATE_LIMITED_TOTAL.labels(scope="chat").inc() + log_event( + "chat.send.rate_limited", + request_id=request_id, + scope="chat", + chat_id=chat_id, + node_id=target_node, + agent_id=target_agent, + operator_id=operator_id, + operator_id_missing=operator_id_missing, + limit_rps=_RL_CHAT_RPS, + burst=_RL_CHAT_BURST, + retry_after_s=chat_rl.retry_after_s, + status="error", + error_code="rate_limited", + ) + raise _rate_limited_http("chat", chat_rl.retry_after_s) + op_rl = _rate_limiter.consume(f"rl:op:{operator_id}", rps=_RL_OP_RPS, burst=_RL_OP_BURST) + if not op_rl.allowed: + SOFIIA_RATE_LIMITED_TOTAL.labels(scope="operator").inc() + log_event( + "chat.send.rate_limited", + request_id=request_id, + scope="operator", + chat_id=chat_id, + node_id=target_node, + agent_id=target_agent, + operator_id=operator_id, + operator_id_missing=operator_id_missing, + limit_rps=_RL_OP_RPS, + burst=_RL_OP_BURST, + retry_after_s=op_rl.retry_after_s, + status="error", + error_code="rate_limited", + ) + raise _rate_limited_http("operator", op_rl.retry_after_s) log_event( "chat.send", request_id=request_id, chat_id=chat_id, node_id=target_node, agent_id=target_agent, + operator_id=operator_id, + operator_id_missing=operator_id_missing, idempotency_key_hash=(idem_hash or None), replayed=False, status="ok", @@ -3345,7 +3419,7 @@ async def api_chat_send_v2(chat_id: str, body: ChatMessageSendBody, request: Req SOFIIA_SEND_REQUESTS_TOTAL.labels(node_id=target_node).inc() project_id = body.project_id or CHAT_PROJECT_ID session_id = body.session_id or chat_id - user_id = body.user_id or "console_user" + user_id = operator_id title = f"{target_agent} • {target_node} • {info['source']}" await _app_db.upsert_session(chat_id, project_id=CHAT_PROJECT_ID, title=title) @@ -3353,7 +3427,8 @@ async def api_chat_send_v2(chat_id: str, body: ChatMessageSendBody, request: Req metadata: Dict[str, Any] = { "project_id": project_id, "session_id": session_id, - "user_id": user_id, + "user_id": operator_id, + "operator_id": operator_id, "client": "sofiia-console", "chat_id": chat_id, "node_id": target_node, diff --git a/services/sofiia-console/app/metrics.py b/services/sofiia-console/app/metrics.py index 96c7b512..1e697de3 100644 --- a/services/sofiia-console/app/metrics.py +++ b/services/sofiia-console/app/metrics.py @@ -31,6 +31,11 @@ if _PROM_OK: "Total number of cursor pagination requests", ["resource"], ) + SOFIIA_RATE_LIMITED_TOTAL = _PromCounter( + "sofiia_rate_limited_total", + "Total number of requests rejected by rate limiting", + ["scope"], + ) def render_metrics() -> Tuple[bytes, str]: return _prom_generate_latest(), _PROM_CONTENT_TYPE @@ -93,11 +98,17 @@ else: "Total number of cursor pagination requests", ["resource"], ) + SOFIIA_RATE_LIMITED_TOTAL = _FallbackCounter( + "sofiia_rate_limited_total", + "Total number of requests rejected by rate limiting", + ["scope"], + ) _ALL = [ SOFIIA_SEND_REQUESTS_TOTAL, SOFIIA_IDEMPOTENCY_REPLAYS_TOTAL, SOFIIA_CURSOR_REQUESTS_TOTAL, + SOFIIA_RATE_LIMITED_TOTAL, ] def render_metrics() -> Tuple[bytes, str]: diff --git a/services/sofiia-console/app/rate_limit.py b/services/sofiia-console/app/rate_limit.py new file mode 100644 index 00000000..2a65c952 --- /dev/null +++ b/services/sofiia-console/app/rate_limit.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import math +import os +import time +from dataclasses import dataclass +from typing import Any, Dict, Optional, Protocol + +from .logging import log_event + + +@dataclass +class RateLimitDecision: + allowed: bool + retry_after_s: int + remaining_tokens: float + + +class RateLimiter(Protocol): + def consume(self, key: str, rps: float, burst: int, cost: float = 1.0) -> RateLimitDecision: + ... + + +class InMemoryRateLimiter: + def __init__(self, max_keys: int = 10000) -> None: + self._max_keys = max(1000, int(max_keys)) + self._buckets: Dict[str, Dict[str, float]] = {} + + def consume(self, key: str, rps: float, burst: int, cost: float = 1.0) -> RateLimitDecision: + rate = max(0.0, float(rps)) + cap = max(1.0, float(burst)) + req = max(0.1, float(cost)) + now = time.monotonic() + + hit = self._buckets.get(key) + if not hit: + tokens = cap + updated_at = now + else: + tokens = float(hit.get("tokens") or 0.0) + updated_at = float(hit.get("updated_at") or now) + elapsed = max(0.0, now - updated_at) + if rate > 0: + tokens = min(cap, tokens + elapsed * rate) + + allowed = tokens >= req + retry_after = 0 + if allowed: + tokens -= req + else: + needed = max(0.0, req - tokens) + retry_after = int(math.ceil(needed / rate)) if rate > 0 else 60 + + self._buckets[key] = {"tokens": max(0.0, tokens), "updated_at": now} + self._gc() + return RateLimitDecision(allowed=allowed, retry_after_s=retry_after, remaining_tokens=max(0.0, tokens)) + + def _gc(self) -> None: + if len(self._buckets) <= self._max_keys: + return + items = sorted(self._buckets.items(), key=lambda kv: float((kv[1] or {}).get("updated_at") or 0.0)) + to_drop = max(1, len(items) // 2) + for k, _ in items[:to_drop]: + self._buckets.pop(k, None) + + def reset(self) -> None: + self._buckets.clear() + + +class RedisRateLimiter: + def __init__(self, redis_client: Any, prefix: str = "sofiia:rl:", key_ttl_s: int = 86400) -> None: + self._redis = redis_client + self._prefix = str(prefix or "sofiia:rl:") + self._key_ttl_s = max(60, int(key_ttl_s)) + self._lua = """ +local key = KEYS[1] +local now = tonumber(ARGV[1]) +local rps = tonumber(ARGV[2]) +local burst = tonumber(ARGV[3]) +local cost = tonumber(ARGV[4]) +local ttl = tonumber(ARGV[5]) +local data = redis.call('HMGET', key, 'tokens', 'ts') +local tokens = tonumber(data[1]) +local ts = tonumber(data[2]) +if tokens == nil then tokens = burst end +if ts == nil then ts = now end +if now > ts then + tokens = math.min(burst, tokens + ((now - ts) * rps)) +end +local allowed = 0 +local retry = 0 +if tokens >= cost then + tokens = tokens - cost + allowed = 1 +else + local needed = cost - tokens + if rps > 0 then + retry = math.ceil(needed / rps) + else + retry = 60 + end +end +redis.call('HMSET', key, 'tokens', tokens, 'ts', now) +redis.call('EXPIRE', key, ttl) +return {allowed, retry, tokens} +""" + + def _k(self, key: str) -> str: + return f"{self._prefix}{key}" + + def consume(self, key: str, rps: float, burst: int, cost: float = 1.0) -> RateLimitDecision: + now = time.time() + rate = max(0.0, float(rps)) + cap = max(1.0, float(burst)) + req = max(0.1, float(cost)) + out = self._redis.eval( + self._lua, + 1, + self._k(key), + now, + rate, + cap, + req, + self._key_ttl_s, + ) + allowed = bool((out or [0])[0]) + retry_after = int((out or [0, 0])[1] or 0) + rem = float((out or [0, 0, 0.0])[2] or 0.0) + return RateLimitDecision(allowed=allowed, retry_after_s=retry_after, remaining_tokens=max(0.0, rem)) + + def reset(self) -> None: + keys = self._redis.keys(f"{self._prefix}*") + if keys: + self._redis.delete(*keys) + + +_RATE_LIMITER: Optional[RateLimiter] = None + + +def _make_redis_client(redis_url: str) -> Any: + import redis # type: ignore + + return redis.Redis.from_url(redis_url, decode_responses=False) + + +def get_rate_limiter() -> RateLimiter: + global _RATE_LIMITER + if _RATE_LIMITER is None: + backend = os.getenv("SOFIIA_RATE_LIMIT_BACKEND", "inmemory").strip().lower() or "inmemory" + if backend == "redis": + redis_url = os.getenv("SOFIIA_REDIS_URL", "redis://localhost:6379/0").strip() + prefix = os.getenv("SOFIIA_RATE_LIMIT_PREFIX", "sofiia:rl:").strip() or "sofiia:rl:" + ttl = int(os.getenv("SOFIIA_RATE_LIMIT_KEY_TTL_S", "86400")) + try: + client = _make_redis_client(redis_url) + _RATE_LIMITER = RedisRateLimiter(client, prefix=prefix, key_ttl_s=ttl) + except Exception as exc: + _RATE_LIMITER = InMemoryRateLimiter() + log_event( + "rate_limit.backend.fallback", + backend="redis", + status="degraded", + error_code="redis_unavailable", + error=str(exc)[:180], + ) + else: + max_keys = int(os.getenv("SOFIIA_RATE_LIMIT_MAX_KEYS", "10000")) + _RATE_LIMITER = InMemoryRateLimiter(max_keys=max_keys) + return _RATE_LIMITER diff --git a/tests/test_sofiia_rate_limit.py b/tests/test_sofiia_rate_limit.py new file mode 100644 index 00000000..4196c0c3 --- /dev/null +++ b/tests/test_sofiia_rate_limit.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from app.rate_limit import InMemoryRateLimiter + + +def _create_chat(client, agent_id: str, node_id: str, ref: str) -> str: + r = client.post( + "/api/chats", + json={ + "agent_id": agent_id, + "node_id": node_id, + "source": "web", + "external_chat_ref": ref, + }, + ) + assert r.status_code == 200, r.text + return r.json()["chat"]["chat_id"] + + +def test_inmemory_rate_limiter_blocks_burst_exceed(): + rl = InMemoryRateLimiter() + first = rl.consume("rl:test:key", rps=0.001, burst=1) + second = rl.consume("rl:test:key", rps=0.001, burst=1) + + assert first.allowed is True + assert second.allowed is False + assert second.retry_after_s > 0 + + +def test_send_rate_limit_per_chat_returns_429(sofiia_client, sofiia_module, monkeypatch): + async def _fake_infer(base_url, agent_id, text, **kwargs): + return {"response": f"ok:{agent_id}:{text}", "backend": "fake", "model": "fake-model"} + + monkeypatch.setattr(sofiia_module, "infer", _fake_infer) + monkeypatch.setattr(sofiia_module, "_rate_limiter", InMemoryRateLimiter()) + monkeypatch.setattr(sofiia_module, "_RL_CHAT_RPS", 0.001) + monkeypatch.setattr(sofiia_module, "_RL_CHAT_BURST", 1) + monkeypatch.setattr(sofiia_module, "_RL_OP_RPS", 100.0) + monkeypatch.setattr(sofiia_module, "_RL_OP_BURST", 100) + + chat_id = _create_chat(sofiia_client, "sofiia", "NODA2", "rl-chat") + r1 = sofiia_client.post(f"/api/chats/{chat_id}/send", json={"text": "ping-1", "user_id": "op-1"}) + r2 = sofiia_client.post(f"/api/chats/{chat_id}/send", json={"text": "ping-2", "user_id": "op-1"}) + + assert r1.status_code == 200, r1.text + assert r2.status_code == 429, r2.text + body = r2.json() + assert body["detail"]["error"]["code"] == "rate_limited" + assert body["detail"]["error"]["scope"] == "chat" + assert int(r2.headers.get("Retry-After", "0")) >= 1 + + +def test_send_rate_limit_per_operator_returns_429(sofiia_client, sofiia_module, monkeypatch): + async def _fake_infer(base_url, agent_id, text, **kwargs): + return {"response": f"ok:{agent_id}:{text}", "backend": "fake", "model": "fake-model"} + + monkeypatch.setattr(sofiia_module, "infer", _fake_infer) + monkeypatch.setattr(sofiia_module, "_rate_limiter", InMemoryRateLimiter()) + monkeypatch.setattr(sofiia_module, "_RL_CHAT_RPS", 100.0) + monkeypatch.setattr(sofiia_module, "_RL_CHAT_BURST", 100) + monkeypatch.setattr(sofiia_module, "_RL_OP_RPS", 0.001) + monkeypatch.setattr(sofiia_module, "_RL_OP_BURST", 1) + + chat_1 = _create_chat(sofiia_client, "sofiia", "NODA2", "rl-op-1") + chat_2 = _create_chat(sofiia_client, "sofiia", "NODA2", "rl-op-2") + r1 = sofiia_client.post(f"/api/chats/{chat_1}/send", json={"text": "ping-1", "user_id": "operator-1"}) + r2 = sofiia_client.post(f"/api/chats/{chat_2}/send", json={"text": "ping-2", "user_id": "operator-1"}) + + assert r1.status_code == 200, r1.text + assert r2.status_code == 429, r2.text + body = r2.json() + assert body["detail"]["error"]["code"] == "rate_limited" + assert body["detail"]["error"]["scope"] == "operator" + assert int(r2.headers.get("Retry-After", "0")) >= 1