# Source listing: microdao-daarion/gateway-bot/gateway_experience_bus.py
# (original listing metadata: 414 lines, 13 KiB, Python)
"""Gateway experience event publisher/store (Phase-4).
Best-effort, fail-open telemetry for gateway webhook flow:
- publish to JetStream subject agent.experience.v1.<agent_id>
- optional DB append-only insert into agent_experience_events
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Optional
# Optional third-party clients: each sink is silently disabled when its
# driver is not installed (the bus checks for ``None`` before use).
try:
    import asyncpg
except ImportError:  # pragma: no cover
    asyncpg = None  # type: ignore[assignment]

try:
    import nats
except ImportError:  # pragma: no cover
    nats = None  # type: ignore[assignment]

# Prometheus-style counter from the project's ``metrics`` module; any failure
# to import it simply turns metric reporting off.
try:
    from metrics import GATEWAY_EXPERIENCE_PUBLISHED_TOTAL
except Exception:  # pragma: no cover
    METRICS_AVAILABLE = False
    GATEWAY_EXPERIENCE_PUBLISHED_TOTAL = None  # type: ignore[assignment]
else:
    METRICS_AVAILABLE = True

logger = logging.getLogger("gateway.experience_bus")
def _metric_publish(status: str) -> None:
if METRICS_AVAILABLE and GATEWAY_EXPERIENCE_PUBLISHED_TOTAL is not None:
GATEWAY_EXPERIENCE_PUBLISHED_TOTAL.labels(status=status).inc()
class GatewayExperienceBus:
    """Best-effort publisher/store for gateway experience events.

    Each captured event is published to JetStream on
    ``<subject_prefix>.<agent_id>`` and/or appended to the
    ``agent_experience_events`` table. The same asyncpg pool also backs small
    read lookups (anti-silent tuning lessons, access policies, allowlist).

    Public coroutines are fail-open: sink/lookup failures are logged at debug
    level or mapped to a neutral return value instead of being raised.
    Connection attempts are bounded by ``EXPERIENCE_CONNECT_TIMEOUT_MS`` and,
    after a failure, are not retried for ``EXPERIENCE_RECONNECT_BACKOFF_MS``,
    so an unreachable NATS/Postgres does not add latency to every call.
    Configuration is read from the environment once, in ``__init__``.
    """

    _TRUE_VALUES = frozenset({"1", "true", "yes"})

    def __init__(self) -> None:
        self.enabled = self._env_flag("EXPERIENCE_BUS_ENABLED")
        self.enable_nats = self._env_flag("EXPERIENCE_ENABLE_NATS")
        self.enable_db = self._env_flag("EXPERIENCE_ENABLE_DB")
        self.node_id = os.getenv("NODE_ID", "NODA1")
        self.nats_url = os.getenv("NATS_URL", "nats://nats:4222")
        self.stream_name = os.getenv("EXPERIENCE_STREAM_NAME", "EXPERIENCE")
        self.subject_prefix = os.getenv("EXPERIENCE_SUBJECT_PREFIX", "agent.experience.v1")
        self.publish_timeout_s = self._env_float("EXPERIENCE_PUBLISH_TIMEOUT_MS", 800.0) / 1000.0
        self.db_timeout_s = self._env_float("EXPERIENCE_DB_TIMEOUT_MS", 1200.0) / 1000.0
        # Upper bound for one NATS connect, one Postgres connect, or one
        # JetStream stream-management call.
        self.connect_timeout_s = self._env_float("EXPERIENCE_CONNECT_TIMEOUT_MS", 2000.0) / 1000.0
        # After a failed connect/stream setup, skip new attempts for this long.
        self.reconnect_backoff_s = self._env_float("EXPERIENCE_RECONNECT_BACKOFF_MS", 30000.0) / 1000.0
        self.db_dsn = os.getenv("EXPERIENCE_DATABASE_URL") or os.getenv("DATABASE_URL")
        self._lock = asyncio.Lock()
        self._nc: Any = None
        self._js: Any = None
        self._pool: Any = None
        self._stream_ensured = False
        # Event-loop clock readings before which a reconnect is not attempted.
        self._nats_retry_at = 0.0
        self._db_retry_at = 0.0
        self._stream_retry_at = 0.0

    # ------------------------------------------------------------------ config

    @classmethod
    def _env_flag(cls, name: str, default: str = "true") -> bool:
        """Read a boolean env var; "1"/"true"/"yes" (case-insensitive) are true."""
        return os.getenv(name, default).strip().lower() in cls._TRUE_VALUES

    @staticmethod
    def _env_float(name: str, default: float) -> float:
        """Read a non-negative float env var, falling back to *default*.

        Unset, empty, unparsable, negative, NaN or infinite values yield
        *default* instead of raising, so a configuration typo cannot make
        construction of the bus fail.
        """
        raw = os.getenv(name)
        if not raw:
            return default
        try:
            value = float(raw)
        except ValueError:
            return default
        # The chained comparison is False for NaN, negatives and +inf.
        return value if 0.0 <= value < float("inf") else default

    # ----------------------------------------------------------------- capture

    async def capture(self, event: Dict[str, Any]) -> None:
        """Publish *event* to NATS and insert it into the DB, concurrently.

        Sink failures are swallowed. Records ``ok`` in the publish metric when
        at least one sink accepted the event, ``err`` otherwise.

        Args:
            event: JSON-like dict; keys read here include ``agent_id``,
                ``event_id``, ``ts``, ``llm`` and ``result``.
        """
        if not self.enabled:
            return
        try:
            await self._ensure_clients()
        except Exception as e:  # pragma: no cover
            logger.debug("gateway experience ensure clients failed: %s", e)
        # Run both sinks concurrently so worst-case latency is max(), not sum().
        outcomes = await asyncio.gather(
            self._publish_nats(event),
            self._insert_db(event),
            return_exceptions=True,
        )
        _metric_publish("ok" if any(outcome is True for outcome in outcomes) else "err")

    # ----------------------------------------------------------------- lookups

    async def get_anti_silent_tuning_lesson(
        self,
        *,
        reason: str,
        chat_type: str,
        timeout_s: float = 0.04,
    ) -> Optional[Dict[str, Any]]:
        """Look up the newest active anti-silent tuning lesson.

        Matches ``agent_lessons.trigger == "reason=<reason>;chat_type=<chat_type>"``
        with ``raw->>'lesson_type' = 'anti_silent_tuning'`` and an unset or
        future ``raw->>'expires_at'``.

        Returns:
            The lesson's ``raw`` payload as a dict, or None on miss, non-dict
            payload, timeout or any error (fail-open). *timeout_s* covers both
            waiting for a pool connection and the query.
        """
        pool = await self._lookup_pool()
        if pool is None:
            return None
        trigger = f"reason={reason};chat_type={chat_type}"
        query = """
            SELECT raw
            FROM agent_lessons
            WHERE COALESCE(raw->>'lesson_type', '') = 'anti_silent_tuning'
              AND trigger = $1
              AND (
                NULLIF(COALESCE(raw->>'expires_at', ''), '') IS NULL
                OR (raw->>'expires_at')::timestamptz > now()
              )
            ORDER BY ts DESC
            LIMIT 1
        """
        try:
            row = await self._fetchrow(pool, query, trigger, timeout_s=timeout_s)
            if row is None:
                return None
            raw = row.get("raw")
            if isinstance(raw, str):
                raw = json.loads(raw)
            # Only a JSON object honours the declared Dict return type.
            return raw if isinstance(raw, dict) else None
        except Exception:
            return None

    async def get_agent_access_policy(
        self,
        *,
        agent_id: str,
        timeout_s: float = 0.04,
    ) -> Optional[Dict[str, Any]]:
        """Look up the ``agent_access_policies`` row for *agent_id*.

        Returns:
            A dict with booleans for the flags and ints (NULL -> 0) for the
            limits/windows, or None on miss, timeout or any error. *timeout_s*
            covers both waiting for a pool connection and the query.
        """
        pool = await self._lookup_pool()
        if pool is None:
            return None
        query = """
            SELECT
                agent_id,
                enabled,
                public_active,
                requires_whitelist,
                user_global_limit,
                user_global_window_seconds,
                user_agent_limit,
                user_agent_window_seconds,
                group_agent_limit,
                group_agent_window_seconds
            FROM agent_access_policies
            WHERE agent_id = $1
            LIMIT 1
        """
        try:
            row = await self._fetchrow(pool, query, agent_id, timeout_s=timeout_s)
            if row is None:
                return None
            return {
                "agent_id": row.get("agent_id"),
                "enabled": bool(row.get("enabled")),
                "public_active": bool(row.get("public_active")),
                "requires_whitelist": bool(row.get("requires_whitelist")),
                "user_global_limit": int(row.get("user_global_limit") or 0),
                "user_global_window_seconds": int(row.get("user_global_window_seconds") or 0),
                "user_agent_limit": int(row.get("user_agent_limit") or 0),
                "user_agent_window_seconds": int(row.get("user_agent_window_seconds") or 0),
                "group_agent_limit": int(row.get("group_agent_limit") or 0),
                "group_agent_window_seconds": int(row.get("group_agent_window_seconds") or 0),
            }
        except Exception:
            return None

    async def is_allowlisted(
        self,
        *,
        platform: str,
        platform_user_id: str,
        agent_id: str,
        timeout_s: float = 0.04,
    ) -> bool:
        """Return True when (platform, user, agent) exists in ``agent_allowlist``.

        Returns False when the DB is disabled/unavailable, on timeout, or on
        any error. *timeout_s* covers both waiting for a pool connection and
        the query.
        """
        pool = await self._lookup_pool()
        if pool is None:
            return False
        query = """
            SELECT 1
            FROM agent_allowlist
            WHERE platform = $1
              AND platform_user_id = $2
              AND agent_id = $3
            LIMIT 1
        """
        try:
            row = await self._fetchrow(
                pool, query, platform, platform_user_id, agent_id, timeout_s=timeout_s
            )
        except Exception:
            return False
        return row is not None

    async def _lookup_pool(self) -> Any:
        """Return the asyncpg pool for a read lookup, or None if unavailable."""
        if not self.enabled or not self.enable_db:
            return None
        try:
            await self._ensure_clients()
        except Exception:
            return None
        return self._pool

    @staticmethod
    async def _fetchrow(pool: Any, query: str, *args: Any, timeout_s: float) -> Any:
        """Run ``fetchrow`` on a pooled connection, bounded by *timeout_s*.

        The timeout covers waiting for a free pool connection as well as the
        query itself (the pool is created with ``max_size=2``).
        """
        async def _run() -> Any:
            async with pool.acquire() as conn:
                return await conn.fetchrow(query, *args)

        return await asyncio.wait_for(_run(), timeout=timeout_s)

    # --------------------------------------------------------------- lifecycle

    async def close(self) -> None:
        """Close the DB pool and NATS connection; safe to call repeatedly.

        References are cleared before awaiting the close so concurrent callers
        stop using the clients immediately, and the stream flag is reset so a
        later reconnect re-checks the JetStream stream.
        """
        pool, self._pool = self._pool, None
        nc, self._nc = self._nc, None
        self._js = None
        self._stream_ensured = False
        if pool is not None:
            try:
                await pool.close()
            except Exception as e:
                logger.debug("gateway experience db pool close failed: %s", e)
        if nc is not None:
            try:
                await nc.close()
            except Exception as e:
                logger.debug("gateway experience nats close failed: %s", e)

    async def _ensure_clients(self) -> None:
        """Lazily (re)create the NATS/JetStream client and the DB pool.

        Runs under ``_lock``. Each connect is bounded by ``connect_timeout_s``;
        after a failure that client is not retried until
        ``reconnect_backoff_s`` has elapsed.
        """
        async with self._lock:
            now = asyncio.get_running_loop().time()
            # A NATS client that gave up reconnecting stays closed; drop it so
            # it can be replaced instead of failing every publish forever.
            if self._nc is not None and getattr(self._nc, "is_closed", False):
                self._nc = None
                self._js = None
                self._stream_ensured = False
            if (
                self.enable_nats
                and self._nc is None
                and nats is not None
                and now >= self._nats_retry_at
            ):
                try:
                    self._nc = await asyncio.wait_for(
                        nats.connect(self.nats_url), timeout=self.connect_timeout_s
                    )
                    self._js = self._nc.jetstream()
                except Exception as e:
                    logger.debug("gateway experience nats connect failed: %s", e)
                    self._nc = None
                    self._js = None
                    self._nats_retry_at = now + self.reconnect_backoff_s
            if (
                self.enable_db
                and self._pool is None
                and asyncpg is not None
                and self.db_dsn
                and now >= self._db_retry_at
            ):
                try:
                    # ``timeout`` is forwarded to asyncpg.connect for each
                    # connection the pool opens.
                    self._pool = await asyncpg.create_pool(
                        self.db_dsn,
                        min_size=1,
                        max_size=2,
                        timeout=self.connect_timeout_s,
                    )
                except Exception as e:
                    logger.debug("gateway experience db pool failed: %s", e)
                    self._pool = None
                    self._db_retry_at = now + self.reconnect_backoff_s
            if self._js is not None and not self._stream_ensured:
                await self._ensure_stream()

    async def _ensure_stream(self) -> None:
        """Make sure a stream named ``stream_name`` covers ``<subject_prefix>.>``.

        Tries ``stream_info`` first and falls back to ``add_stream``; on
        failure, retries are deferred by ``reconnect_backoff_s``.
        """
        js = self._js
        if js is None:
            return
        now = asyncio.get_running_loop().time()
        if now < self._stream_retry_at:
            return
        try:
            await asyncio.wait_for(js.stream_info(self.stream_name), timeout=self.connect_timeout_s)
            self._stream_ensured = True
            return
        except Exception:
            pass  # Missing stream (or transient error): try to create it below.
        try:
            await asyncio.wait_for(
                js.add_stream(name=self.stream_name, subjects=[f"{self.subject_prefix}.>"]),
                timeout=self.connect_timeout_s,
            )
            self._stream_ensured = True
        except Exception as e:
            logger.debug("gateway experience ensure stream failed: %s", e)
            self._stream_retry_at = now + self.reconnect_backoff_s

    # ------------------------------------------------------------------- sinks

    @staticmethod
    def _subject_token(agent_id: Any) -> str:
        """Turn *agent_id* into a single NATS subject token.

        None/empty becomes ``"unknown"``; characters that would split or
        wildcard the subject (``.``, ``*``, ``>``, whitespace) become ``_``.
        """
        text = "" if agent_id is None else str(agent_id).strip()
        if not text:
            return "unknown"
        return "".join("_" if (ch.isspace() or ch in ".*>") else ch for ch in text)

    @staticmethod
    def _serialize_event(event: Dict[str, Any]) -> str:
        """JSON-encode *event* for both sinks.

        Datetimes are rendered with ``isoformat()`` and sets as lists; any
        other non-JSON value falls back to ``str`` so one odd field does not
        drop the whole telemetry event.
        """
        def _default(obj: Any) -> Any:
            if isinstance(obj, datetime):
                return obj.isoformat()
            if isinstance(obj, (set, frozenset)):
                return list(obj)
            return str(obj)

        return json.dumps(event, ensure_ascii=False, default=_default)

    async def _publish_nats(self, event: Dict[str, Any]) -> bool:
        """Publish *event* to JetStream; return True when the publish completed.

        A non-empty ``event_id`` is sent as the ``Nats-Msg-Id`` header, which
        JetStream uses for message de-duplication. Serialization happens inside
        the ``try`` so a bad payload is reported as a failed publish instead of
        escaping ``capture``.
        """
        js = self._js
        if not self.enable_nats or js is None:
            return False
        subject = f"{self.subject_prefix}.{self._subject_token(event.get('agent_id'))}"
        msg_id = str(event.get("event_id") or "").strip()
        headers = {"Nats-Msg-Id": msg_id} if msg_id else None
        try:
            payload = self._serialize_event(event).encode("utf-8")
            await asyncio.wait_for(
                js.publish(subject, payload, headers=headers),
                timeout=self.publish_timeout_s,
            )
            return True
        except Exception as e:
            logger.debug("gateway experience nats publish failed: %s", e)
            return False

    async def _insert_db(self, event: Dict[str, Any]) -> bool:
        """Append *event* to ``agent_experience_events``; True on success.

        Duplicate ``event_id`` values are ignored (``ON CONFLICT DO NOTHING``).
        Numeric fields that cannot be coerced fall back to 0/NULL instead of
        aborting the insert; ``db_timeout_s`` bounds pool acquire + execute.
        """
        pool = self._pool
        if not self.enable_db or pool is None:
            return False
        llm = event.get("llm")
        if not isinstance(llm, dict):
            llm = {}
        result = event.get("result")
        if not isinstance(result, dict):
            result = {}
        query = """
            INSERT INTO agent_experience_events (
                event_id,
                ts,
                node_id,
                source,
                agent_id,
                task_type,
                request_id,
                channel,
                inputs_hash,
                provider,
                model,
                profile,
                latency_ms,
                tokens_in,
                tokens_out,
                ok,
                error_class,
                error_msg_redacted,
                http_status,
                raw
            ) VALUES (
                $1::uuid,
                $2::timestamptz,
                $3,
                $4,
                $5,
                $6,
                $7,
                $8,
                $9,
                $10,
                $11,
                $12,
                $13,
                $14,
                $15,
                $16,
                $17,
                $18,
                $19,
                $20::jsonb
            )
            ON CONFLICT (event_id) DO NOTHING
        """
        try:
            args = (
                _as_uuid(event.get("event_id")),
                _as_timestamptz(event.get("ts")),
                event.get("node_id", self.node_id),
                event.get("source", "gateway"),
                event.get("agent_id"),
                event.get("task_type", "webhook"),
                event.get("request_id"),
                event.get("channel", "telegram"),
                event.get("inputs_hash"),
                llm.get("provider", "gateway"),
                llm.get("model", "gateway"),
                llm.get("profile"),
                _as_int_or_none(llm.get("latency_ms")) or 0,
                _as_int_or_none(llm.get("tokens_in")),
                _as_int_or_none(llm.get("tokens_out")),
                bool(result.get("ok")),
                result.get("error_class"),
                result.get("error_msg_redacted"),
                _as_int_or_none(result.get("http_status")) or 0,
                self._serialize_event(event),
            )

            async def _run() -> None:
                async with pool.acquire() as conn:
                    await conn.execute(query, *args)

            await asyncio.wait_for(_run(), timeout=self.db_timeout_s)
            return True
        except Exception as e:
            logger.debug("gateway experience db insert failed: %s", e)
            return False
def _as_int_or_none(value: Any) -> Optional[int]:
try:
if value is None:
return None
return int(value)
except Exception:
return None
def _as_uuid(value: Any) -> uuid.UUID:
try:
return uuid.UUID(str(value))
except Exception:
return uuid.uuid4()
def _as_timestamptz(value: Any) -> datetime:
if isinstance(value, datetime):
return value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
try:
parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00"))
return parsed if parsed.tzinfo is not None else parsed.replace(tzinfo=timezone.utc)
except Exception:
return datetime.now(timezone.utc)
# Process-wide bus instance, created lazily by get_gateway_experience_bus().
_gateway_bus_singleton: Optional[GatewayExperienceBus] = None


def get_gateway_experience_bus() -> GatewayExperienceBus:
    """Return the shared GatewayExperienceBus, constructing it on first call."""
    global _gateway_bus_singleton
    bus = _gateway_bus_singleton
    if bus is None:
        bus = GatewayExperienceBus()
        _gateway_bus_singleton = bus
    return bus