# Source: microdao-daarion/services/router/experience_bus.py (447 lines, 16 KiB, Python)
"""Router experience event bus (Phase-1).
Collects inference outcome events, applies sampling + dedup, then
persists to JetStream and Postgres in async background worker.
"""
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
import json
import logging
import os
import random
import re
import time
import uuid
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
# asyncpg is only available inside the runtime container; when it is missing
# the bus degrades to NATS-only persistence (see _init_db).
try:
    import asyncpg
except ImportError:  # pragma: no cover - runtime dependency in container
    asyncpg = None
# Metrics are best-effort: if agent_metrics cannot be imported, every counter
# becomes a no-op so telemetry failures can never take down the router.
try:
    from agent_metrics import (
        inc_experience_db_insert,
        inc_experience_dedup_dropped,
        inc_experience_published,
        inc_experience_sampled,
    )
except Exception:  # pragma: no cover - keep router resilient
    def inc_experience_published(*args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        return None

    def inc_experience_db_insert(*args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        return None

    def inc_experience_dedup_dropped(*args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        return None

    def inc_experience_sampled(*args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        return None

logger = logging.getLogger("experience_bus")
@dataclass
class ExperienceDecision:
    """Outcome of the sampling/dedup gate for one experience event."""

    # True when the event should be enqueued for persistence.
    keep: bool
    # Why it was kept or dropped, e.g. "error", "latency_spike",
    # "ok_sample_in", "ok_sample_out", "dedup".
    reason: str
class ExperienceBus:
    """Collects router inference-outcome events and persists them asynchronously.

    capture() runs a keep/drop decision (errors and latency spikes are always
    kept, successes are sampled at ok_sample_pct, repeats are suppressed within
    dedup_window_s) and enqueues surviving events. A single background worker
    then publishes each event to NATS JetStream (falling back to core NATS)
    and inserts it into Postgres via asyncpg. Backend failures are logged and
    counted but never propagated to callers.
    """

    def __init__(self) -> None:
        # All tunables are environment-driven so deployments can reconfigure
        # the bus without code changes.
        self.enabled = os.getenv("EXPERIENCE_BUS_ENABLED", "true").lower() in {"1", "true", "yes"}
        self.node_id = os.getenv("NODE_ID", "NODA1")
        # Percentage (0-100) of successful, non-spike events to keep.
        self.ok_sample_pct = float(os.getenv("EXPERIENCE_OK_SAMPLE_PCT", "10"))
        # Latency at or above this threshold is always kept.
        self.latency_spike_ms = int(os.getenv("EXPERIENCE_LATENCY_SPIKE_MS", "5000"))
        self.dedup_window_s = int(os.getenv("EXPERIENCE_DEDUP_WINDOW_SECONDS", "900"))
        self.dedup_max_keys = int(os.getenv("EXPERIENCE_DEDUP_MAX_KEYS", "20000"))
        self.queue_max = int(os.getenv("EXPERIENCE_QUEUE_MAX", "2000"))
        # Timeouts are configured in milliseconds but stored in seconds; the
        # "or <default>" guards an explicitly-empty env value.
        self.publish_timeout_s = float(os.getenv("EXPERIENCE_PUBLISH_TIMEOUT_MS", "800") or 800) / 1000.0
        self.db_timeout_s = float(os.getenv("EXPERIENCE_DB_TIMEOUT_MS", "1200") or 1200) / 1000.0
        self.subject_prefix = os.getenv("EXPERIENCE_SUBJECT_PREFIX", "agent.experience.v1")
        self.stream_name = os.getenv("EXPERIENCE_STREAM_NAME", "EXPERIENCE")
        self.enable_nats = os.getenv("EXPERIENCE_ENABLE_NATS", "true").lower() in {"1", "true", "yes"}
        self.enable_db = os.getenv("EXPERIENCE_ENABLE_DB", "true").lower() in {"1", "true", "yes"}
        self.db_dsn = os.getenv("EXPERIENCE_DATABASE_URL") or os.getenv("DATABASE_URL")
        # Bounded hand-off between capture() and the worker; None is the
        # shutdown sentinel pushed by stop().
        self._queue: asyncio.Queue[Optional[Dict[str, Any]]] = asyncio.Queue(maxsize=self.queue_max)
        self._worker_task: Optional[asyncio.Task[Any]] = None
        self._running = False
        # Insertion-ordered map of dedup key -> monotonic timestamp, pruned
        # by window age and capped at dedup_max_keys (oldest evicted first).
        self._dedup_lock = asyncio.Lock()
        self._dedup: "OrderedDict[str, float]" = OrderedDict()
        self._pool: Optional[Any] = None  # asyncpg pool, None while DB is unavailable
        self._nc: Any = None  # core NATS client
        self._js: Any = None  # JetStream context derived from _nc

    async def start(self, nats_client: Any = None) -> None:
        """Initialise the enabled backends and launch the worker task.

        Safe to call repeatedly; a second call while running is a no-op.
        """
        if not self.enabled:
            logger.info("ExperienceBus disabled by env")
            return
        if self._running:
            return
        if self.enable_db:
            await self._init_db()
        if self.enable_nats and nats_client is not None:
            await self.set_nats_client(nats_client)
        self._running = True
        self._worker_task = asyncio.create_task(self._worker(), name="experience-bus-worker")
        logger.info(
            "ExperienceBus started (db=%s nats=%s queue_max=%s sample_ok=%s%% dedup_window=%ss)",
            bool(self._pool),
            bool(self._js or self._nc),
            self.queue_max,
            self.ok_sample_pct,
            self.dedup_window_s,
        )

    async def stop(self) -> None:
        """Ask the worker to drain, then release the DB pool and NATS refs."""
        if not self._running:
            return
        self._running = False
        try:
            # None is the worker's shutdown sentinel. If the queue is full the
            # sentinel is dropped and the worker is cancelled below after the
            # 5-second grace period instead.
            self._queue.put_nowait(None)
        except asyncio.QueueFull:
            pass
        if self._worker_task is not None:
            try:
                await asyncio.wait_for(self._worker_task, timeout=5.0)
            except Exception:
                self._worker_task.cancel()
            self._worker_task = None
        if self._pool is not None:
            try:
                await self._pool.close()
            except Exception as e:  # pragma: no cover
                logger.debug("ExperienceBus pool close error: %s", e)
            self._pool = None
        self._js = None
        self._nc = None
        logger.info("ExperienceBus stopped")

    async def set_nats_client(self, nats_client: Any) -> None:
        """Adopt a (possibly None) NATS client and derive a JetStream context.

        Falls back to _js=None (core-NATS-only publishing) when JetStream
        setup fails.
        """
        if not self.enabled or not self.enable_nats:
            return
        self._nc = nats_client
        if self._nc is None:
            self._js = None
            return
        try:
            self._js = self._nc.jetstream()
            await self._ensure_stream()
        except Exception as e:
            self._js = None
            logger.warning("ExperienceBus JetStream unavailable: %s", e)

    async def capture(self, event: Dict[str, Any]) -> None:
        """Apply sampling/dedup and enqueue for async persistence.

        Never blocks: a full queue drops the event (counted as queue_full).
        """
        if not self.enabled or not self._running:
            return
        decision = await self._decide(event)
        if not decision.keep:
            if decision.reason == "dedup":
                inc_experience_dedup_dropped(source="router")
            inc_experience_sampled(source="router", decision="out", reason=decision.reason)
            return
        inc_experience_sampled(source="router", decision="in", reason=decision.reason)
        try:
            self._queue.put_nowait(event)
        except asyncio.QueueFull:
            inc_experience_sampled(source="router", decision="out", reason="queue_full")
            logger.warning("ExperienceBus queue full; dropping event")

    async def _init_db(self) -> None:
        """Create a small asyncpg pool; leave _pool=None on any failure."""
        if asyncpg is None:
            logger.warning("ExperienceBus DB disabled: asyncpg not installed")
            return
        if not self.db_dsn:
            logger.warning("ExperienceBus DB disabled: DATABASE_URL missing")
            return
        try:
            self._pool = await asyncpg.create_pool(self.db_dsn, min_size=1, max_size=3)
        except Exception as e:
            self._pool = None
            logger.warning("ExperienceBus DB pool init failed: %s", e)

    async def _ensure_stream(self) -> None:
        """Make sure the JetStream stream exists with our subject wildcard.

        An existing stream is never reconfigured, even if its subjects do not
        cover ours; that case is only logged.
        """
        if self._js is None:
            return
        subjects = [f"{self.subject_prefix}.>"]
        try:
            info = await self._js.stream_info(self.stream_name)
            stream_subjects = set(getattr(info.config, "subjects", []) or [])
            if not stream_subjects.intersection(subjects):
                logger.warning(
                    "ExperienceBus stream '%s' exists without subject %s; keeping as-is",
                    self.stream_name,
                    subjects[0],
                )
            # Stream already exists; nothing more to do.
            return
        except Exception:
            # stream_info failed (e.g. stream not found) -> try to create it.
            pass
        try:
            await self._js.add_stream(name=self.stream_name, subjects=subjects)
            logger.info("ExperienceBus stream ensured: %s subjects=%s", self.stream_name, subjects)
        except Exception as e:
            logger.warning("ExperienceBus stream ensure failed: %s", e)

    async def _decide(self, event: Dict[str, Any]) -> ExperienceDecision:
        """Decide keep/drop: errors and latency spikes always kept, successes
        sampled at ok_sample_pct, then a per-key dedup window applied."""
        result = event.get("result") or {}
        llm = event.get("llm") or {}
        ok = bool(result.get("ok"))
        latency_ms = int(llm.get("latency_ms") or 0)
        if not ok:
            sample_reason = "error"
        elif latency_ms >= self.latency_spike_ms:
            sample_reason = "latency_spike"
        else:
            # Successful, fast events are sampled probabilistically.
            if random.random() * 100.0 >= self.ok_sample_pct:
                return ExperienceDecision(False, "ok_sample_out")
            sample_reason = "ok_sample_in"
        dedup_key = self._dedup_key(event)
        now = time.monotonic()
        async with self._dedup_lock:
            self._prune_dedup(now)
            seen_at = self._dedup.get(dedup_key)
            if seen_at is not None and (now - seen_at) < self.dedup_window_s:
                return ExperienceDecision(False, "dedup")
            self._dedup[dedup_key] = now
            self._dedup.move_to_end(dedup_key, last=True)
            # Hard cap on memory: evict oldest entries beyond dedup_max_keys.
            while len(self._dedup) > self.dedup_max_keys:
                self._dedup.popitem(last=False)
        return ExperienceDecision(True, sample_reason)

    def _prune_dedup(self, now: float) -> None:
        """Drop dedup entries older than the window. Caller holds _dedup_lock."""
        if not self._dedup:
            return
        threshold = now - self.dedup_window_s
        while self._dedup:
            # Entries are in insertion order, so the first is the oldest.
            _, ts = next(iter(self._dedup.items()))
            if ts >= threshold:
                break
            self._dedup.popitem(last=False)

    def _dedup_key(self, event: Dict[str, Any]) -> str:
        """Build the dedup identity: agent, task, input hash, ok flag, error class."""
        result = event.get("result") or {}
        return "|".join(
            [
                str(event.get("agent_id") or ""),
                str(event.get("task_type") or ""),
                str(event.get("inputs_hash") or ""),
                "1" if bool(result.get("ok")) else "0",
                str(result.get("error_class") or ""),
            ]
        )

    async def _worker(self) -> None:
        """Drain the queue forever; a None sentinel (from stop()) ends the loop."""
        while True:
            event = await self._queue.get()
            if event is None:
                self._queue.task_done()
                break
            try:
                await self._persist_event(event)
            except Exception as e:  # pragma: no cover
                logger.warning("ExperienceBus persist error: %s", e)
            finally:
                self._queue.task_done()

    async def _persist_event(self, event: Dict[str, Any]) -> None:
        """Publish to NATS, then insert into Postgres; each step is best-effort."""
        await self._publish_nats(event)
        await self._insert_db(event)

    async def _publish_nats(self, event: Dict[str, Any]) -> None:
        """Publish the event JSON, preferring JetStream, then core NATS.

        Failures only emit metrics/debug logs; the event is never re-raised.
        """
        subject = f"{self.subject_prefix}.{event.get('agent_id', 'unknown')}"
        payload = json.dumps(event, ensure_ascii=False).encode("utf-8")
        # Nats-Msg-Id lets JetStream deduplicate redelivered events server-side.
        msg_id = str(event.get("event_id") or "").strip()
        headers = {"Nats-Msg-Id": msg_id} if msg_id else None
        if self._js is not None:
            try:
                await asyncio.wait_for(
                    self._js.publish(subject, payload, headers=headers),
                    timeout=self.publish_timeout_s,
                )
                inc_experience_published(source="router", transport="jetstream", status="ok")
                return
            except Exception as e:
                inc_experience_published(source="router", transport="jetstream", status="error")
                logger.debug("ExperienceBus JetStream publish failed: %s", e)
        if self._nc is not None:
            try:
                await asyncio.wait_for(
                    self._nc.publish(subject, payload, headers=headers),
                    timeout=self.publish_timeout_s,
                )
                # Core NATS is fire-and-forget; flush to push it onto the wire.
                await asyncio.wait_for(self._nc.flush(), timeout=self.publish_timeout_s)
                inc_experience_published(source="router", transport="core", status="ok")
                return
            except Exception as e:
                inc_experience_published(source="router", transport="core", status="error")
                logger.debug("ExperienceBus core NATS publish failed: %s", e)
        inc_experience_published(source="router", transport="none", status="skipped")

    async def _insert_db(self, event: Dict[str, Any]) -> None:
        """Insert the event into agent_experience_events (idempotent on event_id)."""
        if self._pool is None:
            inc_experience_db_insert(source="router", status="skipped")
            return
        payload_json = json.dumps(event, ensure_ascii=False)
        llm = event.get("llm") or {}
        result = event.get("result") or {}
        event_uuid = _as_uuid(event.get("event_id"))
        event_ts = _as_timestamptz(event.get("ts"))
        query = """
            INSERT INTO agent_experience_events (
                event_id,
                ts,
                node_id,
                source,
                agent_id,
                task_type,
                request_id,
                channel,
                inputs_hash,
                provider,
                model,
                profile,
                latency_ms,
                tokens_in,
                tokens_out,
                ok,
                error_class,
                error_msg_redacted,
                http_status,
                raw
            ) VALUES (
                $1::uuid,
                $2::timestamptz,
                $3,
                $4,
                $5,
                $6,
                $7,
                $8,
                $9,
                $10,
                $11,
                $12,
                $13,
                $14,
                $15,
                $16,
                $17,
                $18,
                $19,
                $20::jsonb
            )
            ON CONFLICT (event_id) DO NOTHING
        """
        try:
            async with self._pool.acquire() as conn:
                await asyncio.wait_for(
                    conn.execute(
                        query,
                        event_uuid,
                        event_ts,
                        event.get("node_id"),
                        event.get("source"),
                        event.get("agent_id"),
                        event.get("task_type"),
                        event.get("request_id"),
                        event.get("channel", "unknown"),
                        event.get("inputs_hash"),
                        llm.get("provider", "unknown"),
                        llm.get("model", "unknown"),
                        llm.get("profile"),
                        int(llm.get("latency_ms") or 0),
                        _as_int_or_none(llm.get("tokens_in")),
                        _as_int_or_none(llm.get("tokens_out")),
                        bool(result.get("ok")),
                        result.get("error_class"),
                        result.get("error_msg_redacted"),
                        int(result.get("http_status") or 0),
                        payload_json,
                    ),
                    timeout=self.db_timeout_s,
                )
            inc_experience_db_insert(source="router", status="ok")
        except Exception as e:
            inc_experience_db_insert(source="router", status="error")
            logger.debug("ExperienceBus DB insert failed: %s", e)
def redact_error_message(value: Optional[str]) -> Optional[str]:
    """Scrub credentials from an error string and cap it at 300 characters.

    Bearer tokens, key/token/password/secret assignments, and long opaque
    token-like runs are replaced with "[redacted]"; whitespace is collapsed.
    None passes through unchanged.
    """
    if value is None:
        return None
    cleaned = str(value)
    # Authorization header with a bearer credential.
    cleaned = re.sub(r"(?i)(authorization\s*:\s*bearer)\s+[A-Za-z0-9._-]+", r"\1 [redacted]", cleaned)
    # "key: value" / "key=value" style secrets.
    cleaned = re.sub(r"(?i)(api[_-]?key|token|password|secret)\s*[:=]\s*[^\s,;]+", r"\1=[redacted]", cleaned)
    # Long runs that look like raw tokens or hashes.
    cleaned = re.sub(r"\b[A-Za-z0-9_\-]{24,}\b", "[redacted]", cleaned)
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    return cleaned if len(cleaned) <= 300 else cleaned[:300]
def normalize_input_for_hash(text: str) -> str:
    """Canonicalise text for hashing: trim, collapse whitespace, lowercase, cap at 4000."""
    stripped = (text or "").strip()
    collapsed = re.sub(r"\s+", " ", stripped).lower()
    return collapsed[:4000]
def _as_int_or_none(value: Any) -> Optional[int]:
try:
if value is None:
return None
return int(value)
except Exception:
return None
def _as_uuid(value: Any) -> uuid.UUID:
try:
return uuid.UUID(str(value))
except Exception:
return uuid.uuid4()
def _as_timestamptz(value: Any) -> datetime:
if isinstance(value, datetime):
return value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
try:
parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00"))
return parsed if parsed.tzinfo is not None else parsed.replace(tzinfo=timezone.utc)
except Exception:
return datetime.now(timezone.utc)