"""Router experience event bus (Phase-1). Collects inference outcome events, applies sampling + dedup, then persists to JetStream and Postgres in async background worker. """ from __future__ import annotations import asyncio from datetime import datetime, timezone import json import logging import os import random import re import time import uuid from collections import OrderedDict from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple try: import asyncpg except ImportError: # pragma: no cover - runtime dependency in container asyncpg = None try: from agent_metrics import ( inc_experience_db_insert, inc_experience_dedup_dropped, inc_experience_published, inc_experience_sampled, ) except Exception: # pragma: no cover - keep router resilient def inc_experience_published(*args: Any, **kwargs: Any) -> None: # type: ignore[override] return None def inc_experience_db_insert(*args: Any, **kwargs: Any) -> None: # type: ignore[override] return None def inc_experience_dedup_dropped(*args: Any, **kwargs: Any) -> None: # type: ignore[override] return None def inc_experience_sampled(*args: Any, **kwargs: Any) -> None: # type: ignore[override] return None logger = logging.getLogger("experience_bus") @dataclass class ExperienceDecision: keep: bool reason: str class ExperienceBus: def __init__(self) -> None: self.enabled = os.getenv("EXPERIENCE_BUS_ENABLED", "true").lower() in {"1", "true", "yes"} self.node_id = os.getenv("NODE_ID", "NODA1") self.ok_sample_pct = float(os.getenv("EXPERIENCE_OK_SAMPLE_PCT", "10")) self.latency_spike_ms = int(os.getenv("EXPERIENCE_LATENCY_SPIKE_MS", "5000")) self.dedup_window_s = int(os.getenv("EXPERIENCE_DEDUP_WINDOW_SECONDS", "900")) self.dedup_max_keys = int(os.getenv("EXPERIENCE_DEDUP_MAX_KEYS", "20000")) self.queue_max = int(os.getenv("EXPERIENCE_QUEUE_MAX", "2000")) self.publish_timeout_s = float(os.getenv("EXPERIENCE_PUBLISH_TIMEOUT_MS", "800") or 800) / 1000.0 self.db_timeout_s = float(os.getenv("EXPERIENCE_DB_TIMEOUT_MS", "1200") or 1200) / 1000.0 self.subject_prefix = os.getenv("EXPERIENCE_SUBJECT_PREFIX", "agent.experience.v1") self.stream_name = os.getenv("EXPERIENCE_STREAM_NAME", "EXPERIENCE") self.enable_nats = os.getenv("EXPERIENCE_ENABLE_NATS", "true").lower() in {"1", "true", "yes"} self.enable_db = os.getenv("EXPERIENCE_ENABLE_DB", "true").lower() in {"1", "true", "yes"} self.db_dsn = os.getenv("EXPERIENCE_DATABASE_URL") or os.getenv("DATABASE_URL") self._queue: asyncio.Queue[Optional[Dict[str, Any]]] = asyncio.Queue(maxsize=self.queue_max) self._worker_task: Optional[asyncio.Task[Any]] = None self._running = False self._dedup_lock = asyncio.Lock() self._dedup: "OrderedDict[str, float]" = OrderedDict() self._pool: Optional[Any] = None self._nc: Any = None self._js: Any = None async def start(self, nats_client: Any = None) -> None: if not self.enabled: logger.info("ExperienceBus disabled by env") return if self._running: return if self.enable_db: await self._init_db() if self.enable_nats and nats_client is not None: await self.set_nats_client(nats_client) self._running = True self._worker_task = asyncio.create_task(self._worker(), name="experience-bus-worker") logger.info( "ExperienceBus started (db=%s nats=%s queue_max=%s sample_ok=%s%% dedup_window=%ss)", bool(self._pool), bool(self._js or self._nc), self.queue_max, self.ok_sample_pct, self.dedup_window_s, ) async def stop(self) -> None: if not self._running: return self._running = False try: self._queue.put_nowait(None) except asyncio.QueueFull: pass if self._worker_task is not None: try: await asyncio.wait_for(self._worker_task, timeout=5.0) except Exception: self._worker_task.cancel() self._worker_task = None if self._pool is not None: try: await self._pool.close() except Exception as e: # pragma: no cover logger.debug("ExperienceBus pool close error: %s", e) self._pool = None self._js = None self._nc = None logger.info("ExperienceBus stopped") async def set_nats_client(self, nats_client: Any) -> None: if not self.enabled or not self.enable_nats: return self._nc = nats_client if self._nc is None: self._js = None return try: self._js = self._nc.jetstream() await self._ensure_stream() except Exception as e: self._js = None logger.warning("ExperienceBus JetStream unavailable: %s", e) async def capture(self, event: Dict[str, Any]) -> None: """Apply sampling/dedup and enqueue for async persistence.""" if not self.enabled or not self._running: return decision = await self._decide(event) if not decision.keep: if decision.reason == "dedup": inc_experience_dedup_dropped(source="router") inc_experience_sampled(source="router", decision="out", reason=decision.reason) return inc_experience_sampled(source="router", decision="in", reason=decision.reason) try: self._queue.put_nowait(event) except asyncio.QueueFull: inc_experience_sampled(source="router", decision="out", reason="queue_full") logger.warning("ExperienceBus queue full; dropping event") async def _init_db(self) -> None: if asyncpg is None: logger.warning("ExperienceBus DB disabled: asyncpg not installed") return if not self.db_dsn: logger.warning("ExperienceBus DB disabled: DATABASE_URL missing") return try: self._pool = await asyncpg.create_pool(self.db_dsn, min_size=1, max_size=3) except Exception as e: self._pool = None logger.warning("ExperienceBus DB pool init failed: %s", e) async def _ensure_stream(self) -> None: if self._js is None: return subjects = [f"{self.subject_prefix}.>"] try: info = await self._js.stream_info(self.stream_name) stream_subjects = set(getattr(info.config, "subjects", []) or []) if not stream_subjects.intersection(subjects): logger.warning( "ExperienceBus stream '%s' exists without subject %s; keeping as-is", self.stream_name, subjects[0], ) return except Exception: pass try: await self._js.add_stream(name=self.stream_name, subjects=subjects) logger.info("ExperienceBus stream ensured: %s subjects=%s", self.stream_name, subjects) except Exception as e: logger.warning("ExperienceBus stream ensure failed: %s", e) async def _decide(self, event: Dict[str, Any]) -> ExperienceDecision: result = event.get("result") or {} llm = event.get("llm") or {} ok = bool(result.get("ok")) latency_ms = int(llm.get("latency_ms") or 0) if not ok: sample_reason = "error" elif latency_ms >= self.latency_spike_ms: sample_reason = "latency_spike" else: if random.random() * 100.0 >= self.ok_sample_pct: return ExperienceDecision(False, "ok_sample_out") sample_reason = "ok_sample_in" dedup_key = self._dedup_key(event) now = time.monotonic() async with self._dedup_lock: self._prune_dedup(now) seen_at = self._dedup.get(dedup_key) if seen_at is not None and (now - seen_at) < self.dedup_window_s: return ExperienceDecision(False, "dedup") self._dedup[dedup_key] = now self._dedup.move_to_end(dedup_key, last=True) while len(self._dedup) > self.dedup_max_keys: self._dedup.popitem(last=False) return ExperienceDecision(True, sample_reason) def _prune_dedup(self, now: float) -> None: if not self._dedup: return threshold = now - self.dedup_window_s while self._dedup: _, ts = next(iter(self._dedup.items())) if ts >= threshold: break self._dedup.popitem(last=False) def _dedup_key(self, event: Dict[str, Any]) -> str: result = event.get("result") or {} return "|".join( [ str(event.get("agent_id") or ""), str(event.get("task_type") or ""), str(event.get("inputs_hash") or ""), "1" if bool(result.get("ok")) else "0", str(result.get("error_class") or ""), ] ) async def _worker(self) -> None: while True: event = await self._queue.get() if event is None: self._queue.task_done() break try: await self._persist_event(event) except Exception as e: # pragma: no cover logger.warning("ExperienceBus persist error: %s", e) finally: self._queue.task_done() async def _persist_event(self, event: Dict[str, Any]) -> None: await self._publish_nats(event) await self._insert_db(event) async def _publish_nats(self, event: Dict[str, Any]) -> None: subject = f"{self.subject_prefix}.{event.get('agent_id', 'unknown')}" payload = json.dumps(event, ensure_ascii=False).encode("utf-8") msg_id = str(event.get("event_id") or "").strip() headers = {"Nats-Msg-Id": msg_id} if msg_id else None if self._js is not None: try: await asyncio.wait_for( self._js.publish(subject, payload, headers=headers), timeout=self.publish_timeout_s, ) inc_experience_published(source="router", transport="jetstream", status="ok") return except Exception as e: inc_experience_published(source="router", transport="jetstream", status="error") logger.debug("ExperienceBus JetStream publish failed: %s", e) if self._nc is not None: try: await asyncio.wait_for( self._nc.publish(subject, payload, headers=headers), timeout=self.publish_timeout_s, ) await asyncio.wait_for(self._nc.flush(), timeout=self.publish_timeout_s) inc_experience_published(source="router", transport="core", status="ok") return except Exception as e: inc_experience_published(source="router", transport="core", status="error") logger.debug("ExperienceBus core NATS publish failed: %s", e) inc_experience_published(source="router", transport="none", status="skipped") async def _insert_db(self, event: Dict[str, Any]) -> None: if self._pool is None: inc_experience_db_insert(source="router", status="skipped") return payload_json = json.dumps(event, ensure_ascii=False) llm = event.get("llm") or {} result = event.get("result") or {} event_uuid = _as_uuid(event.get("event_id")) event_ts = _as_timestamptz(event.get("ts")) query = """ INSERT INTO agent_experience_events ( event_id, ts, node_id, source, agent_id, task_type, request_id, channel, inputs_hash, provider, model, profile, latency_ms, tokens_in, tokens_out, ok, error_class, error_msg_redacted, http_status, raw ) VALUES ( $1::uuid, $2::timestamptz, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20::jsonb ) ON CONFLICT (event_id) DO NOTHING """ try: async with self._pool.acquire() as conn: await asyncio.wait_for( conn.execute( query, event_uuid, event_ts, event.get("node_id"), event.get("source"), event.get("agent_id"), event.get("task_type"), event.get("request_id"), event.get("channel", "unknown"), event.get("inputs_hash"), llm.get("provider", "unknown"), llm.get("model", "unknown"), llm.get("profile"), int(llm.get("latency_ms") or 0), _as_int_or_none(llm.get("tokens_in")), _as_int_or_none(llm.get("tokens_out")), bool(result.get("ok")), result.get("error_class"), result.get("error_msg_redacted"), int(result.get("http_status") or 0), payload_json, ), timeout=self.db_timeout_s, ) inc_experience_db_insert(source="router", status="ok") except Exception as e: inc_experience_db_insert(source="router", status="error") logger.debug("ExperienceBus DB insert failed: %s", e) def redact_error_message(value: Optional[str]) -> Optional[str]: if value is None: return None text = str(value) text = re.sub(r"(?i)(authorization\s*:\s*bearer)\s+[A-Za-z0-9._-]+", r"\1 [redacted]", text) text = re.sub(r"(?i)(api[_-]?key|token|password|secret)\s*[:=]\s*[^\s,;]+", r"\1=[redacted]", text) text = re.sub(r"\b[A-Za-z0-9_\-]{24,}\b", "[redacted]", text) text = re.sub(r"\s+", " ", text).strip() if len(text) > 300: return text[:300] return text def normalize_input_for_hash(text: str) -> str: value = re.sub(r"\s+", " ", (text or "").strip()).lower() return value[:4000] def _as_int_or_none(value: Any) -> Optional[int]: try: if value is None: return None return int(value) except Exception: return None def _as_uuid(value: Any) -> uuid.UUID: try: return uuid.UUID(str(value)) except Exception: return uuid.uuid4() def _as_timestamptz(value: Any) -> datetime: if isinstance(value, datetime): return value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc) try: parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00")) return parsed if parsed.tzinfo is not None else parsed.replace(tzinfo=timezone.utc) except Exception: return datetime.now(timezone.utc)