feat(matrix-bridge-dagi): add rate limiting (H1) and metrics (H3)
H1 — InMemoryRateLimiter (sliding window, no Redis):
- Per-room: RATE_LIMIT_ROOM_RPM (default 20/min)
- Per-sender: RATE_LIMIT_SENDER_RPM (default 10/min)
- Room checked before sender — sender quota not charged on room block
- Blocked messages: audit matrix.rate_limited + on_rate_limited callback
- reset() for ops/test, stats() exposed in /health
H3 — Extended Prometheus metrics:
- matrix_bridge_rate_limited_total{room_id,agent_id,limit_type}
- matrix_bridge_send_duration_seconds histogram (invoke was already there)
- matrix_bridge_invoke_duration_seconds buckets tuned for LLM latency
- matrix_bridge_rate_limiter_active_rooms/senders gauges
- on_invoke_latency + on_send_latency callbacks wired in ingress loop
16 new tests: rate limiter unit (13) + ingress integration (3)
Total: 65 passed
Made-with: Cursor
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""
|
||||
Matrix Ingress + Egress Loop — Phase M1.4
|
||||
Matrix Ingress + Egress Loop — Phase M1.4 + H1/H3
|
||||
|
||||
Polls Matrix /sync for new messages, invokes DAGI Router for mapped rooms,
|
||||
sends agent replies back to Matrix, writes audit events to sofiia-console.
|
||||
@@ -7,11 +7,12 @@ sends agent replies back to Matrix, writes audit events to sofiia-console.
|
||||
Pipeline:
|
||||
sync_poll() → extract_room_messages()
|
||||
→ for each message:
|
||||
1. dedupe (mark_seen)
|
||||
2. audit: matrix.message.received
|
||||
3. invoke DAGI Router (/v1/agents/{agent_id}/infer)
|
||||
4. send_text() reply to Matrix room
|
||||
5. audit: matrix.agent.replied | matrix.error
|
||||
1. rate_limit check (room + sender) ← H1
|
||||
2. dedupe (mark_seen)
|
||||
3. audit: matrix.message.received
|
||||
4. invoke DAGI Router (timed → on_invoke_latency) ← H3
|
||||
5. send_text() reply (timed → on_send_latency) ← H3
|
||||
6. audit: matrix.agent.replied | matrix.error
|
||||
|
||||
Graceful shutdown via asyncio.Event.
|
||||
"""
|
||||
@@ -24,6 +25,7 @@ from typing import Any, Callable, Dict, List, Optional
|
||||
import httpx
|
||||
|
||||
from .matrix_client import MatrixClient
|
||||
from .rate_limit import InMemoryRateLimiter
|
||||
from .room_mapping import RoomMappingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -32,10 +34,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
_MAX_RETRY_BACKOFF = 60.0
|
||||
_INIT_RETRY_BACKOFF = 2.0
|
||||
_ROUTER_TIMEOUT_S = 45.0 # Router may call DeepSeek/Mistral
|
||||
_ROUTER_TIMEOUT_S = 45.0
|
||||
_AUDIT_TIMEOUT_S = 5.0
|
||||
_REPLY_TEXT_MAX = 4000 # Matrix message cap (chars)
|
||||
_ERROR_REPLY_TEXT = "⚠️ Тимчасова помилка. Спробуйте ще раз."
|
||||
_REPLY_TEXT_MAX = 4000
|
||||
|
||||
|
||||
# ── Router invoke ──────────────────────────────────────────────────────────────
|
||||
@@ -50,7 +51,7 @@ async def _invoke_router(
|
||||
) -> str:
|
||||
"""
|
||||
POST /v1/agents/{agent_id}/infer — returns response text string.
|
||||
Field: response['response'] (confirmed from NODA1 test).
|
||||
Field confirmed as 'response' on NODA1.
|
||||
Raises httpx.HTTPError on failure.
|
||||
"""
|
||||
url = f"{router_url.rstrip('/')}/v1/agents/{agent_id}/infer"
|
||||
@@ -66,7 +67,6 @@ async def _invoke_router(
|
||||
resp = await http_client.post(url, json=payload, timeout=_ROUTER_TIMEOUT_S)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
# Extract text — field confirmed as 'response'
|
||||
text = (
|
||||
data.get("response")
|
||||
or data.get("text")
|
||||
@@ -95,10 +95,7 @@ async def _write_audit(
|
||||
duration_ms: Optional[int] = None,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Fire-and-forget audit write to sofiia-console internal endpoint.
|
||||
Never raises — logs warning on failure.
|
||||
"""
|
||||
"""Fire-and-forget audit write. Never raises."""
|
||||
if not console_url or not internal_token:
|
||||
return
|
||||
try:
|
||||
@@ -131,12 +128,15 @@ async def _write_audit(
|
||||
|
||||
class MatrixIngressLoop:
|
||||
"""
|
||||
Drives Matrix sync-poll → router-invoke → Matrix send_text pipeline.
|
||||
Drives Matrix sync-poll → rate-check → router-invoke → Matrix send_text.
|
||||
|
||||
Usage:
|
||||
loop = MatrixIngressLoop(...)
|
||||
stop_event = asyncio.Event()
|
||||
await loop.run(stop_event)
|
||||
Metric callbacks (all optional, called synchronously):
|
||||
on_message_received(room_id, agent_id)
|
||||
on_message_replied(room_id, agent_id, status)
|
||||
on_gateway_error(error_type)
|
||||
on_rate_limited(room_id, agent_id, limit_type) ← H1
|
||||
on_invoke_latency(agent_id, duration_seconds) ← H3
|
||||
on_send_latency(agent_id, duration_seconds) ← H3
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -149,9 +149,13 @@ class MatrixIngressLoop:
|
||||
room_map: RoomMappingConfig,
|
||||
sofiia_console_url: str = "",
|
||||
sofiia_internal_token: str = "",
|
||||
rate_limiter: Optional[InMemoryRateLimiter] = None,
|
||||
on_message_received: Optional[Callable[[str, str], None]] = None,
|
||||
on_message_replied: Optional[Callable[[str, str, str], None]] = None,
|
||||
on_gateway_error: Optional[Callable[[str], None]] = None,
|
||||
on_rate_limited: Optional[Callable[[str, str, str], None]] = None,
|
||||
on_invoke_latency: Optional[Callable[[str, float], None]] = None,
|
||||
on_send_latency: Optional[Callable[[str, float], None]] = None,
|
||||
) -> None:
|
||||
self._hs_url = matrix_homeserver_url
|
||||
self._token = matrix_access_token
|
||||
@@ -161,9 +165,13 @@ class MatrixIngressLoop:
|
||||
self._room_map = room_map
|
||||
self._console_url = sofiia_console_url
|
||||
self._internal_token = sofiia_internal_token
|
||||
self._rate_limiter = rate_limiter
|
||||
self._on_message_received = on_message_received
|
||||
self._on_message_replied = on_message_replied
|
||||
self._on_gateway_error = on_gateway_error
|
||||
self._on_rate_limited = on_rate_limited
|
||||
self._on_invoke_latency = on_invoke_latency
|
||||
self._on_send_latency = on_send_latency
|
||||
self._next_batch: Optional[str] = None
|
||||
|
||||
@property
|
||||
@@ -171,7 +179,6 @@ class MatrixIngressLoop:
|
||||
return self._next_batch
|
||||
|
||||
async def run(self, stop_event: asyncio.Event) -> None:
|
||||
"""Main loop until stop_event is set."""
|
||||
backoff = _INIT_RETRY_BACKOFF
|
||||
logger.info(
|
||||
"Matrix ingress/egress loop started | hs=%s node=%s mappings=%d",
|
||||
@@ -239,7 +246,27 @@ class MatrixIngressLoop:
|
||||
if not text:
|
||||
return
|
||||
|
||||
# Dedupe — mark seen before any IO (prevents double-process on retry)
|
||||
# ── H1: Rate limit check ───────────────────────────────────────────────
|
||||
if self._rate_limiter is not None:
|
||||
allowed, limit_type = self._rate_limiter.check(room_id=room_id, sender=sender)
|
||||
if not allowed:
|
||||
logger.warning(
|
||||
"Rate limited: room=%s sender=%s limit_type=%s event=%s",
|
||||
room_id, sender, limit_type, event_id,
|
||||
)
|
||||
if self._on_rate_limited:
|
||||
self._on_rate_limited(room_id, agent_id, limit_type or "unknown")
|
||||
await _write_audit(
|
||||
http_client, self._console_url, self._internal_token,
|
||||
event="matrix.rate_limited",
|
||||
agent_id=agent_id, node_id=self._node_id,
|
||||
room_id=room_id, event_id=event_id,
|
||||
status="error", error_code=f"rate_limit_{limit_type}",
|
||||
data={"sender": sender, "limit_type": limit_type},
|
||||
)
|
||||
return
|
||||
|
||||
# Dedupe — mark seen before any IO
|
||||
client.mark_seen(event_id)
|
||||
|
||||
logger.info(
|
||||
@@ -250,7 +277,6 @@ class MatrixIngressLoop:
|
||||
if self._on_message_received:
|
||||
self._on_message_received(room_id, agent_id)
|
||||
|
||||
# Audit: received
|
||||
await _write_audit(
|
||||
http_client, self._console_url, self._internal_token,
|
||||
event="matrix.message.received",
|
||||
@@ -260,12 +286,13 @@ class MatrixIngressLoop:
|
||||
data={"sender": sender, "text_len": len(text)},
|
||||
)
|
||||
|
||||
# Session ID: stable per room (allows memory context across messages)
|
||||
session_id = f"matrix:{room_id.replace('!', '').replace(':', '_')}"
|
||||
|
||||
# ── H3: Invoke with latency measurement ───────────────────────────────
|
||||
t0 = time.monotonic()
|
||||
reply_text: Optional[str] = None
|
||||
invoke_ok = False
|
||||
invoke_duration_s: float = 0.0
|
||||
|
||||
try:
|
||||
reply_text = await _invoke_router(
|
||||
@@ -277,14 +304,20 @@ class MatrixIngressLoop:
|
||||
session_id=session_id,
|
||||
)
|
||||
invoke_ok = True
|
||||
duration_ms = int((time.monotonic() - t0) * 1000)
|
||||
invoke_duration_s = time.monotonic() - t0
|
||||
duration_ms = int(invoke_duration_s * 1000)
|
||||
|
||||
if self._on_invoke_latency:
|
||||
self._on_invoke_latency(agent_id, invoke_duration_s)
|
||||
|
||||
logger.info(
|
||||
"Router invoke ok: agent=%s event=%s reply_len=%d duration=%dms",
|
||||
agent_id, event_id, len(reply_text or ""), duration_ms,
|
||||
)
|
||||
|
||||
except httpx.HTTPStatusError as exc:
|
||||
duration_ms = int((time.monotonic() - t0) * 1000)
|
||||
invoke_duration_s = time.monotonic() - t0
|
||||
duration_ms = int(invoke_duration_s * 1000)
|
||||
logger.error(
|
||||
"Router HTTP %d for agent=%s event=%s duration=%dms",
|
||||
exc.response.status_code, agent_id, event_id, duration_ms,
|
||||
@@ -301,7 +334,8 @@ class MatrixIngressLoop:
|
||||
)
|
||||
|
||||
except (httpx.ConnectError, httpx.TimeoutException) as exc:
|
||||
duration_ms = int((time.monotonic() - t0) * 1000)
|
||||
invoke_duration_s = time.monotonic() - t0
|
||||
duration_ms = int(invoke_duration_s * 1000)
|
||||
logger.error(
|
||||
"Router network error agent=%s event=%s: %s duration=%dms",
|
||||
agent_id, event_id, exc, duration_ms,
|
||||
@@ -318,7 +352,8 @@ class MatrixIngressLoop:
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
duration_ms = int((time.monotonic() - t0) * 1000)
|
||||
invoke_duration_s = time.monotonic() - t0
|
||||
duration_ms = int(invoke_duration_s * 1000)
|
||||
logger.error(
|
||||
"Unexpected router error agent=%s event=%s: %s",
|
||||
agent_id, event_id, exc,
|
||||
@@ -334,24 +369,25 @@ class MatrixIngressLoop:
|
||||
duration_ms=duration_ms,
|
||||
)
|
||||
|
||||
# ── Egress: send reply back to Matrix ──────────────────────────────────
|
||||
if not invoke_ok:
|
||||
# No reply on error in M1 — just audit (avoids spam in room)
|
||||
return
|
||||
|
||||
if not reply_text:
|
||||
logger.warning("Empty reply from router for agent=%s event=%s", agent_id, event_id)
|
||||
return
|
||||
|
||||
# Truncate if needed
|
||||
# ── H3: Send with latency measurement ─────────────────────────────────
|
||||
send_text = reply_text[:_REPLY_TEXT_MAX]
|
||||
txn_id = MatrixClient.make_txn_id(room_id, event_id)
|
||||
|
||||
send_t0 = time.monotonic()
|
||||
try:
|
||||
await client.send_text(room_id, send_text, txn_id)
|
||||
send_duration_ms = int((time.monotonic() - send_t0) * 1000)
|
||||
send_duration_s = time.monotonic() - send_t0
|
||||
send_duration_ms = int(send_duration_s * 1000)
|
||||
|
||||
if self._on_send_latency:
|
||||
self._on_send_latency(agent_id, send_duration_s)
|
||||
if self._on_message_replied:
|
||||
self._on_message_replied(room_id, agent_id, "ok")
|
||||
|
||||
@@ -365,7 +401,7 @@ class MatrixIngressLoop:
|
||||
data={
|
||||
"reply_len": len(send_text),
|
||||
"truncated": len(reply_text) > _REPLY_TEXT_MAX,
|
||||
"router_duration_ms": duration_ms,
|
||||
"router_duration_ms": int(invoke_duration_s * 1000),
|
||||
},
|
||||
)
|
||||
logger.info(
|
||||
@@ -374,7 +410,8 @@ class MatrixIngressLoop:
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
send_duration_ms = int((time.monotonic() - send_t0) * 1000)
|
||||
send_duration_s = time.monotonic() - send_t0
|
||||
send_duration_ms = int(send_duration_s * 1000)
|
||||
logger.error(
|
||||
"Failed to send Matrix reply agent=%s event=%s: %s",
|
||||
agent_id, event_id, exc,
|
||||
|
||||
@@ -33,6 +33,7 @@ except ImportError: # pragma: no cover
|
||||
|
||||
from .config import BridgeConfig, load_config
|
||||
from .ingress import MatrixIngressLoop
|
||||
from .rate_limit import InMemoryRateLimiter
|
||||
from .room_mapping import RoomMappingConfig, parse_room_map
|
||||
|
||||
logging.basicConfig(
|
||||
@@ -41,7 +42,7 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger("matrix-bridge-dagi")
|
||||
|
||||
# ── Prometheus metrics ────────────────────────────────────────────────────────
|
||||
# ── Prometheus metrics (H3) ───────────────────────────────────────────────────
|
||||
if _PROM_OK:
|
||||
_messages_received = Counter(
|
||||
"matrix_bridge_messages_received_total",
|
||||
@@ -53,28 +54,49 @@ if _PROM_OK:
|
||||
"Total agent replies sent to Matrix",
|
||||
["room_id", "agent_id", "status"],
|
||||
)
|
||||
_messages_rate_limited = Counter(
|
||||
"matrix_bridge_rate_limited_total",
|
||||
"Messages dropped by rate limiter",
|
||||
["room_id", "agent_id", "limit_type"],
|
||||
)
|
||||
_gateway_errors = Counter(
|
||||
"matrix_bridge_gateway_errors_total",
|
||||
"Errors calling DAGI gateway",
|
||||
"Errors by stage (sync, invoke, send, audit)",
|
||||
["error_type"],
|
||||
)
|
||||
_invoke_latency = Histogram(
|
||||
"matrix_bridge_invoke_duration_seconds",
|
||||
"Duration of DAGI invoke call",
|
||||
"Latency of DAGI Router infer call",
|
||||
["agent_id"],
|
||||
buckets=[0.5, 1.0, 2.0, 5.0, 10.0, 20.0, 45.0],
|
||||
)
|
||||
_send_latency = Histogram(
|
||||
"matrix_bridge_send_duration_seconds",
|
||||
"Latency of Matrix send_text call",
|
||||
["agent_id"],
|
||||
buckets=[0.05, 0.1, 0.25, 0.5, 1.0, 2.0, 5.0],
|
||||
)
|
||||
_bridge_up = Gauge(
|
||||
"matrix_bridge_up",
|
||||
"1 if bridge started successfully",
|
||||
)
|
||||
_rate_limiter_active_rooms = Gauge(
|
||||
"matrix_bridge_rate_limiter_active_rooms",
|
||||
"Rooms with activity in the current rate-limit window",
|
||||
)
|
||||
_rate_limiter_active_senders = Gauge(
|
||||
"matrix_bridge_rate_limiter_active_senders",
|
||||
"Senders with activity in the current rate-limit window",
|
||||
)
|
||||
|
||||
# ── Startup state ─────────────────────────────────────────────────────────────
|
||||
_START_TIME = time.monotonic()
|
||||
_cfg: Optional[BridgeConfig] = None
|
||||
_config_error: Optional[str] = None
|
||||
_matrix_reachable: Optional[bool] = None # probed at startup
|
||||
_gateway_reachable: Optional[bool] = None # probed at startup
|
||||
_matrix_reachable: Optional[bool] = None
|
||||
_gateway_reachable: Optional[bool] = None
|
||||
_room_map: Optional[RoomMappingConfig] = None
|
||||
_rate_limiter: Optional[InMemoryRateLimiter] = None
|
||||
_ingress_task: Optional[asyncio.Task] = None
|
||||
_ingress_stop: Optional[asyncio.Event] = None
|
||||
|
||||
@@ -93,7 +115,8 @@ async def _probe_url(url: str, timeout: float = 5.0) -> bool:
|
||||
# ── Lifespan ──────────────────────────────────────────────────────────────────
|
||||
@asynccontextmanager
|
||||
async def lifespan(app_: Any):
|
||||
global _cfg, _config_error, _matrix_reachable, _gateway_reachable, _room_map
|
||||
global _cfg, _config_error, _matrix_reachable, _gateway_reachable
|
||||
global _room_map, _rate_limiter
|
||||
try:
|
||||
_cfg = load_config()
|
||||
|
||||
@@ -103,6 +126,16 @@ async def lifespan(app_: Any):
|
||||
_cfg.bridge_allowed_agents,
|
||||
)
|
||||
|
||||
# H1: Rate limiter (inmemory, per config)
|
||||
_rate_limiter = InMemoryRateLimiter(
|
||||
room_rpm=_cfg.rate_limit_room_rpm,
|
||||
sender_rpm=_cfg.rate_limit_sender_rpm,
|
||||
)
|
||||
logger.info(
|
||||
"✅ Rate limiter: room_rpm=%d sender_rpm=%d",
|
||||
_cfg.rate_limit_room_rpm, _cfg.rate_limit_sender_rpm,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"✅ matrix-bridge-dagi started | node=%s build=%s homeserver=%s "
|
||||
"room=%s agents=%s mappings=%d",
|
||||
@@ -147,6 +180,25 @@ async def lifespan(app_: Any):
|
||||
room_id=room_id, agent_id=agent_id, status=status
|
||||
).inc()
|
||||
|
||||
def _on_rate_limited(room_id: str, agent_id: str, limit_type: str) -> None:
|
||||
if _PROM_OK:
|
||||
_messages_rate_limited.labels(
|
||||
room_id=room_id, agent_id=agent_id, limit_type=limit_type
|
||||
).inc()
|
||||
# Update active room/sender gauges from limiter stats
|
||||
if _rate_limiter is not None:
|
||||
stats = _rate_limiter.stats()
|
||||
_rate_limiter_active_rooms.set(stats["active_rooms"])
|
||||
_rate_limiter_active_senders.set(stats["active_senders"])
|
||||
|
||||
def _on_invoke_latency(agent_id: str, duration_s: float) -> None:
|
||||
if _PROM_OK:
|
||||
_invoke_latency.labels(agent_id=agent_id).observe(duration_s)
|
||||
|
||||
def _on_send_latency(agent_id: str, duration_s: float) -> None:
|
||||
if _PROM_OK:
|
||||
_send_latency.labels(agent_id=agent_id).observe(duration_s)
|
||||
|
||||
ingress = MatrixIngressLoop(
|
||||
matrix_homeserver_url=_cfg.matrix_homeserver_url,
|
||||
matrix_access_token=_cfg.matrix_access_token,
|
||||
@@ -156,9 +208,13 @@ async def lifespan(app_: Any):
|
||||
room_map=_room_map,
|
||||
sofiia_console_url=_cfg.sofiia_console_url,
|
||||
sofiia_internal_token=_cfg.sofiia_internal_token,
|
||||
rate_limiter=_rate_limiter,
|
||||
on_message_received=_on_msg,
|
||||
on_message_replied=_on_replied,
|
||||
on_gateway_error=_on_gw_error,
|
||||
on_rate_limited=_on_rate_limited,
|
||||
on_invoke_latency=_on_invoke_latency,
|
||||
on_send_latency=_on_send_latency,
|
||||
)
|
||||
_ingress_task = asyncio.create_task(
|
||||
ingress.run(_ingress_stop),
|
||||
@@ -233,6 +289,7 @@ async def health() -> Dict[str, Any]:
|
||||
"gateway_reachable": _gateway_reachable,
|
||||
"mappings_count": _room_map.total_mappings if _room_map else 0,
|
||||
"config_ok": True,
|
||||
"rate_limiter": _rate_limiter.stats() if _rate_limiter else None,
|
||||
}
|
||||
|
||||
|
||||
|
||||
111
services/matrix-bridge-dagi/app/rate_limit.py
Normal file
111
services/matrix-bridge-dagi/app/rate_limit.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
In-memory rate limiter for matrix-bridge-dagi (H1)
|
||||
|
||||
Sliding window algorithm — no external dependencies (no Redis needed for M1).
|
||||
Thread-safe within a single asyncio event loop.
|
||||
|
||||
Two independent limiters per message:
|
||||
- room limiter: max N messages per room per minute
|
||||
- sender limiter: max N messages per sender per minute
|
||||
|
||||
Usage:
|
||||
rl = InMemoryRateLimiter(room_rpm=20, sender_rpm=10)
|
||||
allowed, limit_type = rl.check(room_id="!abc:server", sender="@user:server")
|
||||
if not allowed:
|
||||
# reject, audit matrix.rate_limited
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict, deque
|
||||
from typing import Deque, Dict, Optional, Tuple
|
||||
|
||||
|
||||
class InMemoryRateLimiter:
    """
    Sliding-window rate limiter, independent per room and per sender.

    Windows are pruned lazily on each check — no background task needed.
    Each bucket stores timestamps (float, from time.monotonic()) of accepted
    events within the window.

    Fix vs. previous revision: buckets whose events have all aged out of the
    window are now deleted during stats() — the defaultdict used to keep an
    empty deque forever for every room_id/sender ever seen, growing without
    bound in a long-running process.
    """

    # Length of the sliding window in seconds (1 minute → limits are "per minute").
    _WINDOW_S: float = 60.0

    def __init__(self, room_rpm: int = 20, sender_rpm: int = 10) -> None:
        """
        Args:
            room_rpm: max accepted messages per room per minute (>= 1).
            sender_rpm: max accepted messages per sender per minute (>= 1).

        Raises:
            ValueError: if either limit is below 1.
        """
        if room_rpm < 1 or sender_rpm < 1:
            raise ValueError("RPM limits must be >= 1")
        self._room_rpm = room_rpm
        self._sender_rpm = sender_rpm

        # buckets: key → deque of accepted timestamps (oldest first)
        self._room_buckets: Dict[str, Deque[float]] = defaultdict(deque)
        self._sender_buckets: Dict[str, Deque[float]] = defaultdict(deque)

    @property
    def room_rpm(self) -> int:
        return self._room_rpm

    @property
    def sender_rpm(self) -> int:
        return self._sender_rpm

    def check(self, room_id: str, sender: str) -> Tuple[bool, Optional[str]]:
        """
        Check if message is allowed.

        Returns:
            (True, None) — allowed
            (False, "room") — room limit exceeded
            (False, "sender") — sender limit exceeded

        Room is checked first; if room limit hit, sender bucket is NOT updated
        to avoid penalising user's quota for messages already blocked.
        """
        now = time.monotonic()
        cutoff = now - self._WINDOW_S

        # Check room
        room_bucket = self._room_buckets[room_id]
        self._prune(room_bucket, cutoff)
        if len(room_bucket) >= self._room_rpm:
            return False, "room"

        # Check sender
        sender_bucket = self._sender_buckets[sender]
        self._prune(sender_bucket, cutoff)
        if len(sender_bucket) >= self._sender_rpm:
            return False, "sender"

        # Both allowed — record the acceptance in both buckets
        room_bucket.append(now)
        sender_bucket.append(now)
        return True, None

    def reset(self, room_id: Optional[str] = None, sender: Optional[str] = None) -> None:
        """Clear buckets — useful for tests or manual ops reset.

        Either argument may be omitted; calling with neither is a no-op.
        """
        if room_id is not None:
            self._room_buckets.pop(room_id, None)
        if sender is not None:
            self._sender_buckets.pop(sender, None)

    def stats(self) -> dict:
        """Snapshot of current limiter state (for /health or ops).

        Also performs housekeeping: stale buckets (no events inside the
        window) are removed so the key space cannot grow without bound.
        Returned values are identical to the previous lazy-count behaviour.
        """
        cutoff = time.monotonic() - self._WINDOW_S
        self._sweep(self._room_buckets, cutoff)
        self._sweep(self._sender_buckets, cutoff)
        return {
            "room_rpm_limit": self._room_rpm,
            "sender_rpm_limit": self._sender_rpm,
            # After the sweep every remaining bucket has >= 1 in-window event.
            "active_rooms": len(self._room_buckets),
            "active_senders": len(self._sender_buckets),
        }

    @staticmethod
    def _prune(bucket: Deque[float], cutoff: float) -> None:
        """Remove timestamps older than the window (in place, oldest first)."""
        while bucket and bucket[0] <= cutoff:
            bucket.popleft()

    @staticmethod
    def _sweep(buckets: Dict[str, Deque[float]], cutoff: float) -> None:
        """Prune every bucket and delete the ones left empty.

        Iterates over a snapshot of the keys because entries are deleted
        during iteration.
        """
        for key in list(buckets):
            bucket = buckets[key]
            while bucket and bucket[0] <= cutoff:
                bucket.popleft()
            if not bucket:
                del buckets[key]
|
||||
@@ -50,6 +50,14 @@ MSG_EVENT = {
|
||||
}
|
||||
|
||||
|
||||
import sys
|
||||
_BRIDGE = Path(__file__).parent.parent / "services" / "matrix-bridge-dagi"
|
||||
if str(_BRIDGE) not in sys.path:
|
||||
sys.path.insert(0, str(_BRIDGE))
|
||||
|
||||
from app.rate_limit import InMemoryRateLimiter # noqa: E402
|
||||
|
||||
|
||||
def _make_loop(**kwargs) -> MatrixIngressLoop:
|
||||
room_map = parse_room_map(ROOM_MAP_STR, ALLOWED)
|
||||
defaults = dict(
|
||||
@@ -446,3 +454,188 @@ def test_loop_metric_callbacks_fire():
|
||||
assert replied[0][2] == "ok"
|
||||
|
||||
run(_inner())
|
||||
|
||||
|
||||
# ── H1: Rate limit integration ────────────────────────────────────────────────
|
||||
|
||||
def test_rate_limiter_blocks_invoke():
    """When room rate limit exceeded, router must NOT be invoked."""
    async def _inner():
        router_calls = [0]  # mutable counter, shared with the fake HTTP post
        rate_limited = []   # collects limit_type values from the callback

        # room_rpm=1 → the second message in the same room must be blocked
        rl = InMemoryRateLimiter(room_rpm=1, sender_rpm=100)

        stop = asyncio.Event()
        loop = _make_loop(
            rate_limiter=rl,
            on_rate_limited=lambda r, a, lt: rate_limited.append(lt),
        )

        # Two events from same room
        event2 = {**MSG_EVENT, "event_id": "$event2:server"}
        call_count = [0]

        async def fake_sync_poll(**kwargs):
            # First poll delivers both events; second poll stops the loop.
            call_count[0] += 1
            if call_count[0] > 1:
                stop.set()
                return {"next_batch": "end", "rooms": {}}
            return _fake_sync([MSG_EVENT, event2])

        def fake_extract(sync_resp, room_id):
            # Mirrors MatrixClient.extract_room_messages: m.room.message
            # events from the room timeline, excluding the bot's own sends.
            events = sync_resp.get("rooms", {}).get("join", {}).get(room_id, {}).get("timeline", {}).get("events", [])
            return [e for e in events if e.get("type") == "m.room.message" and e.get("sender") != BOT_USER]

        async def fake_http_post(url, *, json=None, headers=None, timeout=None):
            # NOTE: `json` here is the request payload kwarg (shadows the module).
            if "/infer" in url:
                router_calls[0] += 1
                return _ok_response("reply")
            return _audit_response()

        with patch("app.ingress.MatrixClient") as MockClient:
            mock_mc = AsyncMock()
            # Async context-manager protocol for `async with MatrixClient(...)`.
            mock_mc.__aenter__ = AsyncMock(return_value=mock_mc)
            mock_mc.__aexit__ = AsyncMock(return_value=False)
            mock_mc.sync_poll = fake_sync_poll
            mock_mc.join_room = AsyncMock()
            mock_mc.mark_seen = MagicMock()
            mock_mc.extract_room_messages = fake_extract
            mock_mc.send_text = AsyncMock(return_value={"event_id": "$r"})
            MockClient.return_value = mock_mc
            MockClient.make_txn_id = lambda r, e: f"txn_{e}"

            with patch("app.ingress.httpx.AsyncClient") as MockHTTP:
                mock_http = AsyncMock()
                mock_http.__aenter__ = AsyncMock(return_value=mock_http)
                mock_http.__aexit__ = AsyncMock(return_value=False)
                mock_http.post = fake_http_post
                MockHTTP.return_value = mock_http

                # Safety timeout so a broken stop condition can't hang pytest.
                await asyncio.wait_for(loop.run(stop), timeout=3.0)

        # First message passes, second blocked
        assert router_calls[0] == 1
        assert len(rate_limited) == 1
        assert rate_limited[0] == "room"

    run(_inner())
|
||||
|
||||
|
||||
def test_rate_limiter_audit_event_on_block():
    """Blocked message must produce matrix.rate_limited audit event."""
    async def _inner():
        audit_events = []  # names of audit events seen by the fake console endpoint

        # room_rpm=1 → second message in the room is blocked and audited
        rl = InMemoryRateLimiter(room_rpm=1, sender_rpm=100)
        stop = asyncio.Event()
        loop = _make_loop(rate_limiter=rl)

        event2 = {**MSG_EVENT, "event_id": "$event2:server"}
        call_count = [0]

        async def fake_sync_poll(**kwargs):
            # First poll delivers both events; second poll stops the loop.
            call_count[0] += 1
            if call_count[0] > 1:
                stop.set()
                return {"next_batch": "end", "rooms": {}}
            return _fake_sync([MSG_EVENT, event2])

        def fake_extract(sync_resp, room_id):
            # m.room.message events from the room timeline, minus the bot's own.
            events = sync_resp.get("rooms", {}).get("join", {}).get(room_id, {}).get("timeline", {}).get("events", [])
            return [e for e in events if e.get("type") == "m.room.message" and e.get("sender") != BOT_USER]

        async def fake_http_post(url, *, json=None, headers=None, timeout=None):
            # `json` is the request payload kwarg (shadows the module).
            if "/audit/internal" in url:
                # Record which audit event name was written.
                audit_events.append(json.get("event") if json else "unknown")
            if "/infer" in url:
                return _ok_response("reply")
            return _audit_response()

        with patch("app.ingress.MatrixClient") as MockClient:
            mock_mc = AsyncMock()
            # Async context-manager protocol for `async with MatrixClient(...)`.
            mock_mc.__aenter__ = AsyncMock(return_value=mock_mc)
            mock_mc.__aexit__ = AsyncMock(return_value=False)
            mock_mc.sync_poll = fake_sync_poll
            mock_mc.join_room = AsyncMock()
            mock_mc.mark_seen = MagicMock()
            mock_mc.extract_room_messages = fake_extract
            mock_mc.send_text = AsyncMock(return_value={"event_id": "$r"})
            MockClient.return_value = mock_mc
            MockClient.make_txn_id = lambda r, e: f"txn_{e}"

            with patch("app.ingress.httpx.AsyncClient") as MockHTTP:
                mock_http = AsyncMock()
                mock_http.__aenter__ = AsyncMock(return_value=mock_http)
                mock_http.__aexit__ = AsyncMock(return_value=False)
                mock_http.post = fake_http_post
                MockHTTP.return_value = mock_http

                # Safety timeout so a broken stop condition can't hang pytest.
                await asyncio.wait_for(loop.run(stop), timeout=3.0)

        assert "matrix.rate_limited" in audit_events

    run(_inner())
|
||||
|
||||
|
||||
def test_latency_callbacks_fire():
    """on_invoke_latency and on_send_latency must be called with agent_id and float."""
    async def _inner():
        invoke_latencies = []  # (agent_id, seconds) per router invoke
        send_latencies = []    # (agent_id, seconds) per Matrix send

        stop = asyncio.Event()
        loop = _make_loop(
            on_invoke_latency=lambda a, d: invoke_latencies.append((a, d)),
            on_send_latency=lambda a, d: send_latencies.append((a, d)),
        )

        call_count = [0]

        async def fake_sync_poll(**kwargs):
            # First poll delivers one message; second poll stops the loop.
            call_count[0] += 1
            if call_count[0] > 1:
                stop.set()
                return {"next_batch": "end", "rooms": {}}
            return _fake_sync([MSG_EVENT])

        def fake_extract(sync_resp, room_id):
            # m.room.message events from the room timeline, minus the bot's own.
            events = sync_resp.get("rooms", {}).get("join", {}).get(room_id, {}).get("timeline", {}).get("events", [])
            return [e for e in events if e.get("type") == "m.room.message" and e.get("sender") != BOT_USER]

        async def fake_http_post(url, *, json=None, headers=None, timeout=None):
            # Successful router reply so both invoke and send paths run.
            if "/infer" in url:
                return _ok_response("test reply")
            return _audit_response()

        with patch("app.ingress.MatrixClient") as MockClient:
            mock_mc = AsyncMock()
            # Async context-manager protocol for `async with MatrixClient(...)`.
            mock_mc.__aenter__ = AsyncMock(return_value=mock_mc)
            mock_mc.__aexit__ = AsyncMock(return_value=False)
            mock_mc.sync_poll = fake_sync_poll
            mock_mc.join_room = AsyncMock()
            mock_mc.mark_seen = MagicMock()
            mock_mc.extract_room_messages = fake_extract
            mock_mc.send_text = AsyncMock(return_value={"event_id": "$r"})
            MockClient.return_value = mock_mc
            MockClient.make_txn_id = lambda r, e: f"txn_{e}"

            with patch("app.ingress.httpx.AsyncClient") as MockHTTP:
                mock_http = AsyncMock()
                mock_http.__aenter__ = AsyncMock(return_value=mock_http)
                mock_http.__aexit__ = AsyncMock(return_value=False)
                mock_http.post = fake_http_post
                MockHTTP.return_value = mock_http

                # Safety timeout so a broken stop condition can't hang pytest.
                await asyncio.wait_for(loop.run(stop), timeout=3.0)

        # Exactly one message → exactly one latency sample per callback.
        assert len(invoke_latencies) == 1
        assert invoke_latencies[0][0] == "sofiia"
        assert isinstance(invoke_latencies[0][1], float)
        assert invoke_latencies[0][1] >= 0

        assert len(send_latencies) == 1
        assert send_latencies[0][0] == "sofiia"
        assert isinstance(send_latencies[0][1], float)

    run(_inner())
|
||||
|
||||
169
tests/test_matrix_bridge_rate_limit.py
Normal file
169
tests/test_matrix_bridge_rate_limit.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Tests for services/matrix-bridge-dagi/app/rate_limit.py (H1)
|
||||
|
||||
Coverage:
|
||||
- basic allow / room limit / sender limit
|
||||
- independent room and sender counters
|
||||
- sliding window prune (old events don't block)
|
||||
- reset() clears buckets
|
||||
- stats() reflects live state
|
||||
- constructor validation
|
||||
"""
|
||||
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
_BRIDGE = Path(__file__).parent.parent / "services" / "matrix-bridge-dagi"
|
||||
if str(_BRIDGE) not in sys.path:
|
||||
sys.path.insert(0, str(_BRIDGE))
|
||||
|
||||
from app.rate_limit import InMemoryRateLimiter # noqa: E402
|
||||
|
||||
ROOM = "!room1:server"
|
||||
ROOM2 = "!room2:server"
|
||||
SENDER = "@alice:server"
|
||||
SENDER2 = "@bob:server"
|
||||
|
||||
|
||||
def test_allows_first_message():
    """A fresh limiter admits the very first message with no limit_type."""
    limiter = InMemoryRateLimiter(room_rpm=5, sender_rpm=5)
    ok, reason = limiter.check(ROOM, SENDER)
    assert ok is True
    assert reason is None
|
||||
|
||||
|
||||
def test_room_limit_blocks_at_threshold():
    """The (room_rpm + 1)-th message in a room is rejected with 'room'."""
    limiter = InMemoryRateLimiter(room_rpm=3, sender_rpm=100)
    assert all(limiter.check(ROOM, SENDER)[0] for _ in range(3))
    # 4th message in the same room — even from a different sender — is blocked
    ok, reason = limiter.check(ROOM, SENDER2)
    assert ok is False
    assert reason == "room"
|
||||
|
||||
|
||||
def test_sender_limit_blocks_at_threshold():
    """The (sender_rpm + 1)-th message from one sender is rejected with 'sender'."""
    limiter = InMemoryRateLimiter(room_rpm=100, sender_rpm=2)
    assert limiter.check(ROOM, SENDER)[0]
    assert limiter.check(ROOM2, SENDER)[0]
    # 3rd message from the same sender, even in yet another room, is blocked
    ok, reason = limiter.check("!room3:server", SENDER)
    assert ok is False
    assert reason == "sender"
|
||||
|
||||
|
||||
def test_room_checked_before_sender():
    """When both limits would trip, the reported limit_type is 'room'."""
    limiter = InMemoryRateLimiter(room_rpm=1, sender_rpm=1)
    limiter.check(ROOM, SENDER)  # consumes both the room and the sender quota
    ok, reason = limiter.check(ROOM, SENDER)
    assert not ok
    assert reason == "room"
|
||||
|
||||
|
||||
def test_independent_rooms_dont_interfere():
    """Filling one room's quota must not block messages in another room.

    Fix: the previous version unpacked limit_type but never asserted it,
    leaving an unused local and a weaker check than intended.
    """
    rl = InMemoryRateLimiter(room_rpm=2, sender_rpm=100)
    rl.check(ROOM, SENDER)
    rl.check(ROOM, SENDER)
    # room1 full — room2 still ok
    allowed, limit_type = rl.check(ROOM2, SENDER)
    assert allowed is True
    assert limit_type is None
|
||||
|
||||
|
||||
def test_independent_senders_dont_interfere():
    """Exhausting one sender's quota must not affect another sender."""
    limiter = InMemoryRateLimiter(room_rpm=100, sender_rpm=1)
    limiter.check(ROOM, SENDER)
    # alice's quota is spent; bob in the same room is unaffected
    ok, _ = limiter.check(ROOM, SENDER2)
    assert ok is True
|
||||
|
||||
|
||||
def test_window_prune_allows_after_expiry(monkeypatch):
    """Events older than 60s should not count against the limit."""
    limiter = InMemoryRateLimiter(room_rpm=2, sender_rpm=100)
    # Exhaust the room quota and confirm the block
    limiter.check(ROOM, SENDER)
    limiter.check(ROOM, SENDER)
    blocked, reason = limiter.check(ROOM, SENDER2)
    assert blocked is False
    assert reason == "room"

    # Jump the clock 61 seconds past the window
    frozen = time.monotonic()
    monkeypatch.setattr(time, "monotonic", lambda: frozen + 61.0)

    # Old events have expired — the room accepts again
    ok, _ = limiter.check(ROOM, SENDER2)
    assert ok is True
|
||||
|
||||
|
||||
def test_reset_room_clears_bucket():
    """reset(room_id=...) frees that room's quota immediately."""
    limiter = InMemoryRateLimiter(room_rpm=1, sender_rpm=100)
    limiter.check(ROOM, SENDER)
    assert limiter.check(ROOM, SENDER2) == (False, "room")

    limiter.reset(room_id=ROOM)
    ok, _ = limiter.check(ROOM, SENDER2)
    assert ok is True
|
||||
|
||||
|
||||
def test_reset_sender_clears_bucket():
    """reset(sender=...) frees that sender's quota immediately."""
    limiter = InMemoryRateLimiter(room_rpm=100, sender_rpm=1)
    limiter.check(ROOM, SENDER)
    assert limiter.check(ROOM2, SENDER) == (False, "sender")

    limiter.reset(sender=SENDER)
    ok, _ = limiter.check(ROOM2, SENDER)
    assert ok is True
|
||||
|
||||
|
||||
def test_stats_reflects_active_buckets():
    """stats() reports the configured limits and the live bucket counts."""
    limiter = InMemoryRateLimiter(room_rpm=10, sender_rpm=10)
    limiter.check(ROOM, SENDER)
    limiter.check(ROOM2, SENDER2)

    snapshot = limiter.stats()
    assert snapshot["active_rooms"] == 2
    assert snapshot["active_senders"] == 2
    assert snapshot["room_rpm_limit"] == 10
    assert snapshot["sender_rpm_limit"] == 10
|
||||
|
||||
|
||||
def test_stats_stale_buckets_not_counted(monkeypatch):
    """Buckets whose events all expired must not show up as active."""
    limiter = InMemoryRateLimiter(room_rpm=10, sender_rpm=10)
    limiter.check(ROOM, SENDER)

    # Jump the clock past the 60s window
    frozen = time.monotonic()
    monkeypatch.setattr(time, "monotonic", lambda: frozen + 61.0)

    snapshot = limiter.stats()
    assert snapshot["active_rooms"] == 0
    assert snapshot["active_senders"] == 0
|
||||
|
||||
|
||||
def test_constructor_validates_limits():
    """RPM values below 1 must be rejected at construction time."""
    import pytest

    for bad_room, bad_sender in ((0, 5), (5, -1)):
        with pytest.raises(ValueError):
            InMemoryRateLimiter(room_rpm=bad_room, sender_rpm=bad_sender)
|
||||
|
||||
|
||||
def test_sender_bucket_not_charged_when_room_blocked():
    """When room blocks, sender quota must not decrease."""
    limiter = InMemoryRateLimiter(room_rpm=1, sender_rpm=2)
    limiter.check(ROOM, SENDER)  # room now 1/1, sender 1/2
    # Both of these are room-blocked and must leave the sender bucket alone
    limiter.check(ROOM, SENDER)
    limiter.check(ROOM, SENDER)

    # In a fresh room the sender still has its second slot available
    ok, reason = limiter.check(ROOM2, SENDER)
    assert ok is True  # sender only consumed 1/2 of its quota
|
||||
Reference in New Issue
Block a user