414 lines
13 KiB
Python
414 lines
13 KiB
Python
"""Gateway experience event publisher/store (Phase-4).

Best-effort, fail-open telemetry for the gateway webhook flow:
- publish to JetStream subject agent.experience.v1.<agent_id>
- optional DB append-only insert into agent_experience_events
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Dict, Optional
|
|
|
|
try:
|
|
import asyncpg
|
|
except ImportError: # pragma: no cover
|
|
asyncpg = None # type: ignore[assignment]
|
|
|
|
try:
|
|
import nats
|
|
except ImportError: # pragma: no cover
|
|
nats = None # type: ignore[assignment]
|
|
|
|
try:
|
|
from metrics import GATEWAY_EXPERIENCE_PUBLISHED_TOTAL
|
|
METRICS_AVAILABLE = True
|
|
except Exception: # pragma: no cover
|
|
METRICS_AVAILABLE = False
|
|
GATEWAY_EXPERIENCE_PUBLISHED_TOTAL = None # type: ignore[assignment]
|
|
|
|
|
|
logger = logging.getLogger("gateway.experience_bus")
|
|
|
|
|
|
def _metric_publish(status: str) -> None:
    """Bump the published-events counter for ``status`` when metrics are loaded.

    Does nothing if the metrics module could not be imported.
    """
    counter = GATEWAY_EXPERIENCE_PUBLISHED_TOTAL
    if not METRICS_AVAILABLE or counter is None:
        return
    counter.labels(status=status).inc()
|
|
|
|
|
|
class GatewayExperienceBus:
    """Best-effort, fail-open sink for gateway experience events.

    Events are published to JetStream under ``<subject_prefix>.<agent_id>``
    and, when a DSN is configured, appended to ``agent_experience_events``.
    The same connection pool also backs a few short-timeout read helpers.
    Backend failures are logged at debug level and reported through return
    values instead of being raised, so telemetry cannot break the caller.

    All configuration is read from environment variables in ``__init__``;
    connections are opened lazily by ``_ensure_clients``.
    """

    # Environment values (lower-cased) that switch a feature flag on.
    _TRUTHY = frozenset({"1", "true", "yes"})

    def __init__(self) -> None:
        """Read configuration from the environment; nothing is connected yet."""
        self.enabled = self._env_flag("EXPERIENCE_BUS_ENABLED")
        self.enable_nats = self._env_flag("EXPERIENCE_ENABLE_NATS")
        self.enable_db = self._env_flag("EXPERIENCE_ENABLE_DB")

        self.node_id = os.getenv("NODE_ID", "NODA1")
        self.nats_url = os.getenv("NATS_URL", "nats://nats:4222")
        self.stream_name = os.getenv("EXPERIENCE_STREAM_NAME", "EXPERIENCE")
        self.subject_prefix = os.getenv("EXPERIENCE_SUBJECT_PREFIX", "agent.experience.v1")
        # Timeouts are configured in milliseconds and stored in seconds.
        self.publish_timeout_s = float(os.getenv("EXPERIENCE_PUBLISH_TIMEOUT_MS", "800") or 800) / 1000.0
        self.db_timeout_s = float(os.getenv("EXPERIENCE_DB_TIMEOUT_MS", "1200") or 1200) / 1000.0

        self.db_dsn = os.getenv("EXPERIENCE_DATABASE_URL") or os.getenv("DATABASE_URL")

        # Guards lazy client creation in _ensure_clients.
        self._lock = asyncio.Lock()
        self._nc: Any = None  # NATS connection
        self._js: Any = None  # JetStream context derived from _nc
        self._pool: Any = None  # asyncpg pool
        self._stream_ensured = False

    @classmethod
    def _env_flag(cls, name: str, default: str = "true") -> bool:
        """Return True when env var ``name`` (or ``default``) is a truthy token."""
        return os.getenv(name, default).lower() in cls._TRUTHY

    async def capture(self, event: Dict[str, Any]) -> None:
        """Publish ``event`` to NATS and insert it into the DB, best effort.

        The publish metric is recorded as ``ok`` when at least one sink
        accepted the event and ``err`` otherwise. No-op when the bus is
        disabled.
        """
        if not self.enabled:
            return

        try:
            await self._ensure_clients()
        except Exception as e:  # pragma: no cover
            logger.debug("gateway experience ensure clients failed: %s", e)

        nats_ok = await self._publish_nats(event)
        db_ok = await self._insert_db(event)

        _metric_publish("ok" if (nats_ok or db_ok) else "err")

    async def _db_ready(self) -> bool:
        """Return True when DB reads are enabled and a pool is available."""
        if not self.enabled or not self.enable_db:
            return False
        try:
            await self._ensure_clients()
        except Exception:
            return False
        return self._pool is not None

    async def _fetchrow(self, query: str, *args: Any, timeout_s: float) -> Any:
        """Run ``query`` on a pooled connection and return the first row or None.

        NOTE(review): ``timeout_s`` bounds only the statement; waiting for a
        free pool connection in ``acquire()`` is not covered by it.
        """
        async with self._pool.acquire() as conn:
            return await asyncio.wait_for(conn.fetchrow(query, *args), timeout=timeout_s)

    async def get_anti_silent_tuning_lesson(
        self,
        *,
        reason: str,
        chat_type: str,
        timeout_s: float = 0.04,
    ) -> Optional[Dict[str, Any]]:
        """Lookup active anti-silent tuning lesson for (reason, chat_type).

        Returns the newest non-expired lesson's raw payload as a dict, or
        None on miss, timeout, error, or when the stored payload is not a
        JSON object. Fail-open by design.
        """
        if not await self._db_ready():
            return None

        trigger = f"reason={reason};chat_type={chat_type}"
        query = """
            SELECT raw
            FROM agent_lessons
            WHERE COALESCE(raw->>'lesson_type', '') = 'anti_silent_tuning'
              AND trigger = $1
              AND (
                    NULLIF(COALESCE(raw->>'expires_at', ''), '') IS NULL
                    OR (raw->>'expires_at')::timestamptz > now()
                  )
            ORDER BY ts DESC
            LIMIT 1
        """
        try:
            row = await self._fetchrow(query, trigger, timeout_s=timeout_s)
            if row is None:
                return None
            raw = row.get("raw")
            if isinstance(raw, str):
                # The column may arrive as JSON text rather than a decoded value.
                raw = json.loads(raw)
            # Only a JSON object satisfies the declared return type.
            return raw if isinstance(raw, dict) else None
        except Exception:
            return None

    async def get_agent_access_policy(
        self,
        *,
        agent_id: str,
        timeout_s: float = 0.04,
    ) -> Optional[Dict[str, Any]]:
        """Lookup access policy row for an agent. Returns None on miss/errors.

        Boolean columns are coerced with ``bool`` and limit/window columns
        with ``int``, where NULL becomes 0.
        """
        if not await self._db_ready():
            return None

        query = """
            SELECT
                agent_id,
                enabled,
                public_active,
                requires_whitelist,
                user_global_limit,
                user_global_window_seconds,
                user_agent_limit,
                user_agent_window_seconds,
                group_agent_limit,
                group_agent_window_seconds
            FROM agent_access_policies
            WHERE agent_id = $1
            LIMIT 1
        """
        flag_columns = ("enabled", "public_active", "requires_whitelist")
        int_columns = (
            "user_global_limit",
            "user_global_window_seconds",
            "user_agent_limit",
            "user_agent_window_seconds",
            "group_agent_limit",
            "group_agent_window_seconds",
        )
        try:
            row = await self._fetchrow(query, agent_id, timeout_s=timeout_s)
            if row is None:
                return None
            policy: Dict[str, Any] = {"agent_id": row.get("agent_id")}
            for column in flag_columns:
                policy[column] = bool(row.get(column))
            for column in int_columns:
                policy[column] = int(row.get(column) or 0)
            return policy
        except Exception:
            return None

    async def is_allowlisted(
        self,
        *,
        platform: str,
        platform_user_id: str,
        agent_id: str,
        timeout_s: float = 0.04,
    ) -> bool:
        """Return True when (platform, user, agent) exists in allowlist.

        Returns False when the lookup is disabled, times out, or fails.
        """
        if not await self._db_ready():
            return False

        query = """
            SELECT 1
            FROM agent_allowlist
            WHERE platform = $1
              AND platform_user_id = $2
              AND agent_id = $3
            LIMIT 1
        """
        try:
            row = await self._fetchrow(
                query, platform, platform_user_id, agent_id, timeout_s=timeout_s
            )
            return row is not None
        except Exception:
            return False

    async def close(self) -> None:
        """Close the DB pool and NATS connection, ignoring shutdown errors."""
        if self._pool is not None:
            try:
                await self._pool.close()
            except Exception:
                # Best-effort shutdown: a failed close must not block teardown.
                pass
            self._pool = None

        if self._nc is not None:
            try:
                await self._nc.close()
            except Exception:
                pass
            self._nc = None
            self._js = None

        # A later reconnect must re-check the stream instead of trusting
        # state that belonged to the closed connection.
        self._stream_ensured = False

    async def _ensure_clients(self) -> None:
        """Lazily create the NATS/JetStream client and DB pool under the lock.

        Connection failures are logged and leave the corresponding client as
        None, so the attempt is repeated on the next call.
        """
        async with self._lock:
            if self.enable_nats and self._nc is None and nats is not None:
                try:
                    self._nc = await nats.connect(self.nats_url)
                    self._js = self._nc.jetstream()
                except Exception as e:
                    logger.debug("gateway experience nats connect failed: %s", e)
                    self._nc = None
                    self._js = None

            if self.enable_db and self._pool is None and asyncpg is not None and self.db_dsn:
                try:
                    self._pool = await asyncpg.create_pool(self.db_dsn, min_size=1, max_size=2)
                except Exception as e:
                    logger.debug("gateway experience db pool failed: %s", e)
                    self._pool = None

            if self._js is not None and not self._stream_ensured:
                await self._ensure_stream()

    async def _ensure_stream(self) -> None:
        """Make sure the JetStream stream exists, creating it when missing."""
        if self._js is None:
            return
        subjects = [f"{self.subject_prefix}.>"]
        try:
            await self._js.stream_info(self.stream_name)
            self._stream_ensured = True
            return
        except Exception:
            # Any lookup failure is treated as "not there yet"; try to create.
            pass

        try:
            await self._js.add_stream(name=self.stream_name, subjects=subjects)
            self._stream_ensured = True
        except Exception as e:
            logger.debug("gateway experience ensure stream failed: %s", e)

    async def _publish_nats(self, event: Dict[str, Any]) -> bool:
        """Publish ``event`` as JSON to ``<subject_prefix>.<agent_id>``.

        A non-empty ``event_id`` is sent as the ``Nats-Msg-Id`` header.
        Returns True on success and False when NATS is disabled, not
        connected, or the publish (including serialisation) fails.
        """
        if not self.enable_nats:
            return False
        if self._js is None:
            return False

        try:
            # Serialisation stays inside the try so that an event that is not
            # JSON-encodable cannot raise out of capture().
            agent_token = event.get("agent_id") or "unknown"
            subject = f"{self.subject_prefix}.{agent_token}"
            payload = json.dumps(event, ensure_ascii=False).encode("utf-8")
            msg_id = str(event.get("event_id") or "").strip()
            headers = {"Nats-Msg-Id": msg_id} if msg_id else None

            await asyncio.wait_for(
                self._js.publish(subject, payload, headers=headers),
                timeout=self.publish_timeout_s,
            )
            return True
        except Exception as e:
            logger.debug("gateway experience nats publish failed: %s", e)
            return False

    async def _insert_db(self, event: Dict[str, Any]) -> bool:
        """Append ``event`` to ``agent_experience_events``.

        Rows whose ``event_id`` already exists are skipped by the
        ``ON CONFLICT`` clause. Returns True when the statement ran without
        error and False when the DB is disabled, unavailable, or the insert
        fails.

        NOTE(review): ``db_timeout_s`` bounds only the statement; waiting for
        a free pool connection in ``acquire()`` is not covered by it.
        """
        if not self.enable_db:
            return False
        if self._pool is None:
            return False

        llm = event.get("llm") or {}
        result = event.get("result") or {}

        query = """
            INSERT INTO agent_experience_events (
                event_id,
                ts,
                node_id,
                source,
                agent_id,
                task_type,
                request_id,
                channel,
                inputs_hash,
                provider,
                model,
                profile,
                latency_ms,
                tokens_in,
                tokens_out,
                ok,
                error_class,
                error_msg_redacted,
                http_status,
                raw
            ) VALUES (
                $1::uuid,
                $2::timestamptz,
                $3,
                $4,
                $5,
                $6,
                $7,
                $8,
                $9,
                $10,
                $11,
                $12,
                $13,
                $14,
                $15,
                $16,
                $17,
                $18,
                $19,
                $20::jsonb
            )
            ON CONFLICT (event_id) DO NOTHING
        """

        try:
            payload_json = json.dumps(event, ensure_ascii=False)
            async with self._pool.acquire() as conn:
                await asyncio.wait_for(
                    conn.execute(
                        query,
                        _as_uuid(event.get("event_id")),
                        _as_timestamptz(event.get("ts")),
                        event.get("node_id", self.node_id),
                        event.get("source", "gateway"),
                        event.get("agent_id"),
                        event.get("task_type", "webhook"),
                        event.get("request_id"),
                        event.get("channel", "telegram"),
                        event.get("inputs_hash"),
                        llm.get("provider", "gateway"),
                        llm.get("model", "gateway"),
                        llm.get("profile"),
                        int(llm.get("latency_ms") or 0),
                        _as_int_or_none(llm.get("tokens_in")),
                        _as_int_or_none(llm.get("tokens_out")),
                        bool(result.get("ok")),
                        result.get("error_class"),
                        result.get("error_msg_redacted"),
                        int(result.get("http_status") or 0),
                        payload_json,
                    ),
                    timeout=self.db_timeout_s,
                )
            return True
        except Exception as e:
            logger.debug("gateway experience db insert failed: %s", e)
            return False
|
|
|
|
|
|
def _as_int_or_none(value: Any) -> Optional[int]:
|
|
try:
|
|
if value is None:
|
|
return None
|
|
return int(value)
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def _as_uuid(value: Any) -> uuid.UUID:
|
|
try:
|
|
return uuid.UUID(str(value))
|
|
except Exception:
|
|
return uuid.uuid4()
|
|
|
|
|
|
def _as_timestamptz(value: Any) -> datetime:
|
|
if isinstance(value, datetime):
|
|
return value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
|
|
try:
|
|
parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00"))
|
|
return parsed if parsed.tzinfo is not None else parsed.replace(tzinfo=timezone.utc)
|
|
except Exception:
|
|
return datetime.now(timezone.utc)
|
|
|
|
|
|
# Module-wide bus instance; stays None until the accessor below is first called.
_gateway_bus_singleton: Optional[GatewayExperienceBus] = None


def get_gateway_experience_bus() -> GatewayExperienceBus:
    """Return the shared ``GatewayExperienceBus``, constructing it on first use."""
    global _gateway_bus_singleton
    bus = _gateway_bus_singleton
    if bus is None:
        bus = GatewayExperienceBus()
        _gateway_bus_singleton = bus
    return bus
|