feat(runtime): sync experience bus and learner stack into main

This commit is contained in:
Apple
2026-03-05 11:30:17 -08:00
parent edd0427c61
commit ef6ebe3583
22 changed files with 2837 additions and 22 deletions

View File

@@ -0,0 +1,12 @@
# Experience-learner service image: slim Python base, deps cached before code copy.
FROM python:3.11-slim
WORKDIR /app
# Copy requirements first so dependency install is cached across code-only rebuilds.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY main.py .
# Service listens on 9109 (FastAPI /health + /metrics).
EXPOSE 9109
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "9109"]

View File

@@ -0,0 +1,839 @@
from __future__ import annotations
import asyncio
import contextlib
import hashlib
import json
import logging
import os
import random
import re
import time
import uuid
from collections import OrderedDict, deque
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any, Deque, Dict, Optional, Tuple
import asyncpg
import nats
from fastapi import FastAPI, Response
from nats.aio.msg import Msg
from nats.js.api import AckPolicy, ConsumerConfig, DeliverPolicy
from prometheus_client import CONTENT_TYPE_LATEST, Counter, Gauge, generate_latest
# Root logging configuration; level is taken from LOG_LEVEL (default INFO).
logging.basicConfig(
    level=os.getenv("LOG_LEVEL", "INFO").upper(),
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("experience_learner")

# ---- Prometheus metrics (default registry; served via the /metrics endpoint below) ----

# Lessons produced from events; status: ok|err.
LESSONS_EXTRACTED = Counter(
    "lessons_extracted_total",
    "Total lessons extracted from experience events",
    ["status"],
)
# DB insert attempts for lessons; status: ok|conflict|err (see _insert_lesson).
LESSONS_INSERT = Counter(
    "lessons_insert_total",
    "Total lesson insert attempts",
    ["status"],
)
# JetStream messages acknowledged by this consumer.
JS_MESSAGES_ACKED = Counter(
    "js_messages_acked_total",
    "Total JetStream messages acked by learner",
)
# Messages observed with num_delivered > 1 (i.e. redeliveries).
JS_MESSAGES_REDELIVERED = Counter(
    "js_messages_redelivered_total",
    "Total redelivered JetStream messages observed by learner",
)
# Events that passed filtering; reason matches _should_keep reasons.
EVENTS_SELECTED = Counter(
    "experience_learner_events_selected_total",
    "Events selected for learner processing",
    ["reason"],
)
# Events dropped by filtering/dedup; reason: invalid_payload|event_dedup|ok_sample_out.
EVENTS_DROPPED = Counter(
    "experience_learner_events_dropped_total",
    "Events dropped by learner filtering/dedup",
    ["reason"],
)
# Lesson publish attempts to JetStream; status: ok|err|skipped.
LESSON_PUBLISH = Counter(
    "experience_learner_lessons_published_total",
    "Lesson publish attempts to JetStream",
    ["status"],
)
# Outcome of each anti-silent tuning evaluation (see _try_extract_anti_silent_tuning_lesson).
ANTI_SILENT_TUNING_EVALUATED = Counter(
    "experience_learner_anti_silent_tuning_evaluated_total",
    "Anti-silent tuning lesson generation evaluations",
    ["status"],
)
# Liveness gauge: 1 while the consume loop task is active.
CONSUMER_RUNNING = Gauge(
    "experience_learner_consumer_running",
    "1 when learner consumer loop is running",
)
@dataclass
class EventSample:
    """One event observation retained in a sliding-window bucket."""

    ts_mono: float  # monotonic capture time, seconds (time.monotonic())
    ok: bool  # result.ok flag from the source event
    latency_ms: int  # llm.latency_ms from the source event (0 when absent)
class ExperienceLearner:
    """Consumes experience events from JetStream, distills them into lessons,
    and persists/publishes those lessons (Postgres + NATS)."""

    def __init__(self) -> None:
        """Read all configuration from environment variables.

        Raises:
            RuntimeError: when no database DSN is configured.
        """
        # Identity / transport configuration.
        self.node_id = os.getenv("NODE_ID", "NODA1")
        self.nats_url = os.getenv("NATS_URL", "nats://nats:4222")
        self.stream_name = os.getenv("EXPERIENCE_STREAM_NAME", "EXPERIENCE")
        self.subject = os.getenv("EXPERIENCE_SUBJECT", "agent.experience.v1.>")
        self.lesson_subject = os.getenv("LESSON_SUBJECT", "agent.lesson.v1")
        # Durable pull-consumer settings.
        self.durable = os.getenv("EXPERIENCE_DURABLE", "experience-learner-v1")
        self.deliver_policy = os.getenv("EXPERIENCE_DELIVER_POLICY", "all").lower()
        self.ack_wait_s = float(os.getenv("EXPERIENCE_ACK_WAIT_SECONDS", "30"))
        self.max_deliver = int(os.getenv("EXPERIENCE_MAX_DELIVER", "20"))
        self.fetch_batch = int(os.getenv("EXPERIENCE_FETCH_BATCH", "64"))
        self.fetch_timeout_s = float(os.getenv("EXPERIENCE_FETCH_TIMEOUT_SECONDS", "2"))
        # Event selection / bucketing thresholds (see _should_keep and _update_bucket).
        self.window_s = int(os.getenv("EXPERIENCE_WINDOW_SECONDS", "1800"))
        self.ok_sample_pct = float(os.getenv("EXPERIENCE_OK_SAMPLE_PCT", "10"))
        self.latency_spike_ms = int(os.getenv("EXPERIENCE_LATENCY_SPIKE_MS", "5000"))
        self.error_threshold = int(os.getenv("EXPERIENCE_ERROR_THRESHOLD", "3"))
        self.silent_threshold = int(os.getenv("EXPERIENCE_SILENT_THRESHOLD", "5"))
        self.latency_threshold = int(os.getenv("EXPERIENCE_LATENCY_THRESHOLD", "3"))
        # In-memory event dedup cache bounds.
        self.event_dedup_ttl_s = int(os.getenv("EXPERIENCE_EVENT_DEDUP_TTL_SECONDS", "3600"))
        self.event_dedup_max = int(os.getenv("EXPERIENCE_EVENT_DEDUP_MAX", "100000"))
        self.publish_lessons = os.getenv("LESSON_PUBLISH_ENABLED", "true").lower() in {"1", "true", "yes"}
        # Anti-silent tuning knobs; numeric values are clamped to sane ranges.
        self.anti_silent_tuning_enabled = os.getenv("ANTI_SILENT_TUNING_ENABLED", "true").lower() in {"1", "true", "yes"}
        self.anti_silent_window_days = max(1, int(os.getenv("ANTI_SILENT_TUNING_WINDOW_DAYS", "7")))
        self.anti_silent_min_evidence = max(1, int(os.getenv("ANTI_SILENT_TUNING_MIN_EVIDENCE", "20")))
        self.anti_silent_min_score = max(0.0, min(1.0, float(os.getenv("ANTI_SILENT_TUNING_MIN_SCORE", "0.75"))))
        self.anti_silent_weight_retry = max(0.0, min(1.0, float(os.getenv("ANTI_SILENT_TUNING_WEIGHT_RETRY", "0.6"))))
        self.anti_silent_weight_negative = max(0.0, min(1.0, float(os.getenv("ANTI_SILENT_TUNING_WEIGHT_NEGATIVE", "0.3"))))
        self.anti_silent_weight_suppressed = max(0.0, min(1.0, float(os.getenv("ANTI_SILENT_TUNING_WEIGHT_SUPPRESSED", "0.1"))))
        self.anti_silent_ttl_days = max(1, int(os.getenv("ANTI_SILENT_TUNING_TTL_DAYS", "7")))
        # Database DSN: first non-empty of three env vars; required.
        self.db_dsn = (
            os.getenv("LEARNER_DATABASE_URL")
            or os.getenv("EXPERIENCE_DATABASE_URL")
            or os.getenv("DATABASE_URL")
        )
        if not self.db_dsn:
            raise RuntimeError("LEARNER_DATABASE_URL (or EXPERIENCE_DATABASE_URL/DATABASE_URL) is required")
        # Runtime state populated by start().
        self._running = False
        self._task: Optional[asyncio.Task[Any]] = None
        self._nc = None
        self._js = None
        self._sub = None
        self._pool: Optional[asyncpg.Pool] = None
        # event_id -> monotonic timestamp; insertion-ordered for TTL/size pruning.
        self._seen_events: "OrderedDict[str, float]" = OrderedDict()
        # bucket_key -> sliding window of EventSample.
        self._buckets: Dict[str, Deque[EventSample]] = {}
        self._lock = asyncio.Lock()
async def start(self) -> None:
    """Connect to Postgres and NATS, ensure the durable consumer, and launch
    the background consume task. No-op when already running."""
    if self._running:
        return
    self._pool = await asyncpg.create_pool(self.db_dsn, min_size=1, max_size=4)
    self._nc = await nats.connect(self.nats_url)
    self._js = self._nc.jetstream()
    # Consumer must exist before pull_subscribe binds to it.
    await self._ensure_consumer()
    self._sub = await self._js.pull_subscribe(
        self.subject,
        durable=self.durable,
        stream=self.stream_name,
    )
    self._running = True
    CONSUMER_RUNNING.set(1)
    self._task = asyncio.create_task(self._consume_loop(), name="experience-learner")
    logger.info(
        "experience-learner started stream=%s subject=%s durable=%s",
        self.stream_name,
        self.subject,
        self.durable,
    )
async def stop(self) -> None:
    """Cancel the consume loop and release NATS/DB resources. Safe to call
    multiple times; errors during task teardown are suppressed."""
    self._running = False
    CONSUMER_RUNNING.set(0)
    if self._task:
        self._task.cancel()
        # Awaiting a cancelled task raises CancelledError; swallow it.
        with contextlib.suppress(Exception):
            await self._task
        self._task = None
    if self._nc:
        await self._nc.close()
        self._nc = None
        self._js = None
        self._sub = None
    if self._pool:
        await self._pool.close()
        self._pool = None
async def _ensure_consumer(self) -> None:
    """Create the durable JetStream consumer if it does not already exist.

    Existing-consumer errors are detected by substring match on the server
    error message and treated as success; anything else re-raises.
    """
    if self._js is None:
        return
    # Only "all" is honored explicitly; any other value falls back to NEW.
    deliver_policy = DeliverPolicy.ALL if self.deliver_policy == "all" else DeliverPolicy.NEW
    cfg = ConsumerConfig(
        durable_name=self.durable,
        ack_policy=AckPolicy.EXPLICIT,
        ack_wait=self.ack_wait_s,
        max_deliver=self.max_deliver,
        deliver_policy=deliver_policy,
        filter_subject=self.subject,
    )
    try:
        await self._js.add_consumer(self.stream_name, config=cfg)
        logger.info("consumer created durable=%s stream=%s", self.durable, self.stream_name)
    except Exception as exc:
        msg = str(exc).lower()
        if "consumer name already in use" in msg or "consumer already exists" in msg:
            logger.info("consumer exists durable=%s stream=%s", self.durable, self.stream_name)
        else:
            raise
async def _consume_loop(self) -> None:
    """Pull-consume loop: fetch message batches until stop() flips _running."""
    assert self._sub is not None
    while self._running:
        try:
            msgs = await self._sub.fetch(self.fetch_batch, timeout=self.fetch_timeout_s)
        except asyncio.TimeoutError:
            # No messages within fetch_timeout_s — poll again immediately.
            continue
        except Exception as exc:
            # Transient transport failure: back off briefly and retry.
            logger.warning("fetch failed: %s", exc)
            await asyncio.sleep(1.0)
            continue
        for msg in msgs:
            await self._handle_msg(msg)
async def _handle_msg(self, msg: Msg) -> None:
    """Process one JetStream message end-to-end.

    Invalid/duplicate/filtered events are acked (dropped); processed events
    are acked after lesson extraction/insert/publish. Any unexpected error
    naks the message so the server redelivers it (bounded by max_deliver).
    """
    try:
        metadata = getattr(msg, "metadata", None)
        if metadata is not None and getattr(metadata, "num_delivered", 1) > 1:
            JS_MESSAGES_REDELIVERED.inc()
        event = json.loads(msg.data.decode("utf-8", errors="replace"))
        if not isinstance(event, dict):
            EVENTS_DROPPED.labels(reason="invalid_payload").inc()
            await msg.ack()
            JS_MESSAGES_ACKED.inc()
            return
        # Dedup on event_id; note the id is marked seen even if the event is
        # later filtered out by _should_keep.
        event_id = str(event.get("event_id") or "").strip()
        if event_id and await self._seen_event(event_id):
            EVENTS_DROPPED.labels(reason="event_dedup").inc()
            await msg.ack()
            JS_MESSAGES_ACKED.inc()
            return
        keep, reason = self._should_keep(event)
        if not keep:
            EVENTS_DROPPED.labels(reason=reason).inc()
            await msg.ack()
            JS_MESSAGES_ACKED.inc()
            return
        EVENTS_SELECTED.labels(reason=reason).inc()
        lessons = await self._extract_lessons(event)
        if not lessons:
            await msg.ack()
            JS_MESSAGES_ACKED.inc()
            return
        for lesson in lessons:
            LESSONS_EXTRACTED.labels(status="ok").inc()
            insert_status = await self._insert_lesson(lesson)
            LESSONS_INSERT.labels(status=insert_status).inc()
            # Only publish lessons that were newly inserted/updated in the DB.
            if insert_status == "ok" and self.publish_lessons:
                await self._publish_lesson(lesson)
        await msg.ack()
        JS_MESSAGES_ACKED.inc()
    except Exception as exc:
        LESSONS_EXTRACTED.labels(status="err").inc()
        logger.exception("message handling failed: %s", exc)
        # Nak so the message is redelivered; ignore nak failures.
        with contextlib.suppress(Exception):
            await msg.nak()
async def _seen_event(self, event_id: str) -> bool:
now = time.monotonic()
async with self._lock:
self._prune_seen(now)
seen_ts = self._seen_events.get(event_id)
if seen_ts is not None and (now - seen_ts) < self.event_dedup_ttl_s:
return True
self._seen_events[event_id] = now
self._seen_events.move_to_end(event_id, last=True)
while len(self._seen_events) > self.event_dedup_max:
self._seen_events.popitem(last=False)
return False
def _prune_seen(self, now: float) -> None:
threshold = now - self.event_dedup_ttl_s
while self._seen_events:
_, ts = next(iter(self._seen_events.items()))
if ts >= threshold:
break
self._seen_events.popitem(last=False)
def _should_keep(self, event: Dict[str, Any]) -> Tuple[bool, str]:
    """Decide whether an experience event enters the learner pipeline.

    Returns (keep, reason). Failures and anomaly signals are always kept;
    healthy events are probabilistically sampled at ok_sample_pct percent.
    """
    result = event.get("result") or {}
    llm = event.get("llm") or {}
    policy = event.get("policy") or {}
    if not bool(result.get("ok")):
        return True, "error"
    if self._is_anti_silent_gateway_event(event):
        return True, "anti_silent_signal"
    if str(policy.get("sowa_decision") or "").upper() == "SILENT":
        return True, "silent"
    if _as_int(result.get("http_status"), 0) >= 500:
        return True, "http_5xx"
    if _as_int(llm.get("latency_ms"), 0) >= self.latency_spike_ms:
        return True, "latency_spike"
    # Healthy event: keep a random ok_sample_pct% slice for baseline stats.
    sampled_in = random.random() * 100.0 < self.ok_sample_pct
    return (True, "ok_sample_in") if sampled_in else (False, "ok_sample_out")
def _is_anti_silent_gateway_event(self, event: Dict[str, Any]) -> bool:
if not self.anti_silent_tuning_enabled:
return False
if str(event.get("source") or "").strip().lower() != "gateway":
return False
action = str(event.get("anti_silent_action") or "").strip().upper()
if action not in {"ACK_EMITTED", "ACK_SUPPRESSED_COOLDOWN"}:
return False
reason = str((event.get("policy") or {}).get("reason") or "").strip()
template_id = str(event.get("anti_silent_template") or "").strip()
return bool(reason and template_id)
async def _extract_lessons(self, event: Dict[str, Any]) -> list[Dict[str, Any]]:
    """Collect every lesson derivable from one event: first the anti-silent
    tuning lesson (if applicable), then the operational lesson."""
    collected: list[Dict[str, Any]] = []
    extractors = (
        self._try_extract_anti_silent_tuning_lesson,
        self._try_extract_lesson,
    )
    for extractor in extractors:
        lesson = await extractor(event)
        if lesson is not None:
            collected.append(lesson)
    return collected
async def _try_extract_lesson(self, event: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Build an operational lesson for the first category whose sliding-window
    bucket has reached its evidence threshold; None when no category qualifies.

    Note: every matching category's bucket is updated with this event before
    the threshold check, in category order.
    """
    for category in self._lesson_categories(event):
        bucket_key = self._bucket_key(category, event)
        count, ok_rate, p95_latency = self._update_bucket(bucket_key, event)
        if count >= self._threshold_for(category):
            return self._build_lesson(
                category=category,
                event=event,
                count=count,
                ok_rate=ok_rate,
                p95_latency=p95_latency,
            )
    return None
async def _try_extract_anti_silent_tuning_lesson(self, event: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Derive a template-preference lesson from historical anti-silent stats.

    For a qualifying gateway event, queries per-template feedback rates for
    the same (policy reason, chat type) pair, and when the best-scoring
    template has enough evidence and a high enough score, emits a
    time-limited "anti_silent_tuning" lesson recommending that template.
    Every exit path increments ANTI_SILENT_TUNING_EVALUATED with its status.
    Returns None when no lesson can be produced.
    """
    if not self._is_anti_silent_gateway_event(event):
        return None
    if self._pool is None:
        ANTI_SILENT_TUNING_EVALUATED.labels(status="pool_missing").inc()
        return None
    policy = event.get("policy") or {}
    reason = _safe_token(policy.get("reason"))
    chat_type = _safe_token(event.get("chat_type"))
    if not reason or not chat_type:
        ANTI_SILENT_TUNING_EVALUATED.labels(status="missing_fields").inc()
        return None
    stats = await self._anti_silent_stats(reason=reason, chat_type=chat_type)
    if not stats:
        ANTI_SILENT_TUNING_EVALUATED.labels(status="no_data").inc()
        return None
    # Keep only templates with enough samples to trust their score.
    candidates = [item for item in stats if int(item.get("n") or 0) >= self.anti_silent_min_evidence]
    if not candidates:
        ANTI_SILENT_TUNING_EVALUATED.labels(status="insufficient_evidence").inc()
        return None
    # Best: highest score, ties broken by larger sample count.
    best = max(candidates, key=lambda item: (float(item.get("score", 0.0)), int(item.get("n", 0))))
    best_score = float(best.get("score", 0.0))
    if best_score < self.anti_silent_min_score:
        ANTI_SILENT_TUNING_EVALUATED.labels(status="below_score").inc()
        return None
    # Worst: lowest score, ties broken by larger sample count (note -n).
    worst = min(candidates, key=lambda item: (float(item.get("score", 0.0)), -int(item.get("n", 0))))
    best_template = str(best.get("template_id") or "").strip().upper()
    if not best_template:
        ANTI_SILENT_TUNING_EVALUATED.labels(status="bad_template").inc()
        return None
    trigger = f"reason={reason};chat_type={chat_type}"
    action = f"prefer_template={best_template}"
    avoid = ""
    worst_template = str(worst.get("template_id") or "").strip().upper()
    if worst_template and worst_template != best_template:
        avoid = f"avoid_template={worst_template}"
    if not avoid:
        avoid = "avoid_template=none"
    lesson_type = "anti_silent_tuning"
    # Stable key over (type, trigger, action) so the DB upsert replaces the
    # previous recommendation for the same pair.
    lesson_key_raw = "|".join([lesson_type, trigger, action])
    lesson_key = hashlib.sha256(lesson_key_raw.encode("utf-8")).hexdigest()
    now_dt = datetime.now(timezone.utc)
    expires_at = (now_dt + timedelta(days=self.anti_silent_ttl_days)).isoformat().replace("+00:00", "Z")
    evidence = {
        "n_best": int(best.get("n") or 0),
        "score_best": round(best_score, 6),
        "retry_rate": round(float(best.get("retry_rate", 0.0)), 6),
        "negative_rate": round(float(best.get("negative_rate", 0.0)), 6),
        "suppressed_rate": round(float(best.get("suppressed_rate", 0.0)), 6),
        "window_days": self.anti_silent_window_days,
        "weights": {
            "retry": self.anti_silent_weight_retry,
            "negative": self.anti_silent_weight_negative,
            "suppressed": self.anti_silent_weight_suppressed,
        },
        "candidates": stats,
    }
    signals = {
        "policy_reason": reason,
        "chat_type": chat_type,
        "lesson_type": lesson_type,
        "trigger_kind": "anti_silent_ack_template",
    }
    lesson: Dict[str, Any] = {
        "lesson_id": str(uuid.uuid4()),
        "lesson_key": lesson_key,
        "lesson_type": lesson_type,
        "ts": now_dt.isoformat().replace("+00:00", "Z"),
        "expires_at": expires_at,
        "scope": "global",
        "agent_id": None,
        "task_type": "webhook",
        "trigger": trigger,
        "action": action,
        "avoid": avoid,
        "signals": signals,
        "evidence": evidence,
    }
    # Shallow self-snapshot; "raw" is stored as the full lesson payload.
    lesson["raw"] = dict(lesson)
    ANTI_SILENT_TUNING_EVALUATED.labels(status="ok").inc()
    return lesson
async def _anti_silent_stats(self, *, reason: str, chat_type: str) -> list[Dict[str, Any]]:
    """Aggregate per-template feedback rates for one (reason, chat_type) pair.

    Queries agent_experience_events over the configured window and returns a
    list of dicts with template_id, n, retry/negative/suppressed rates, and a
    composite score in [0, 1] (1.0 = no negative signals). Returns [] when the
    pool is missing or the query fails (logged as a warning).
    """
    if self._pool is None:
        return []
    query = """
        SELECT
            COALESCE(raw->>'anti_silent_template', '') AS template_id,
            COUNT(*)::int AS n,
            AVG(
                CASE
                    WHEN COALESCE(raw->'feedback'->>'user_signal', 'none') = 'retry' THEN 1.0
                    ELSE 0.0
                END
            )::float8 AS retry_rate,
            AVG(
                CASE
                    WHEN COALESCE(raw->'feedback'->>'user_signal', 'none') = 'negative' THEN 1.0
                    ELSE 0.0
                END
            )::float8 AS negative_rate,
            AVG(
                CASE
                    WHEN COALESCE(raw->>'anti_silent_action', '') = 'ACK_SUPPRESSED_COOLDOWN' THEN 1.0
                    ELSE 0.0
                END
            )::float8 AS suppressed_rate
        FROM agent_experience_events
        WHERE source = 'gateway'
          AND ts >= (now() - ($1::int * interval '1 day'))
          AND COALESCE(raw->'policy'->>'reason', '') = $2
          AND COALESCE(raw->>'chat_type', 'unknown') = $3
          AND COALESCE(raw->>'anti_silent_action', '') IN ('ACK_EMITTED', 'ACK_SUPPRESSED_COOLDOWN')
          AND COALESCE(raw->>'anti_silent_template', '') <> ''
        GROUP BY 1
        HAVING COUNT(*) >= $4
    """
    try:
        async with self._pool.acquire() as conn:
            rows = await conn.fetch(
                query,
                self.anti_silent_window_days,
                reason,
                chat_type,
                self.anti_silent_min_evidence,
            )
    except Exception as exc:
        logger.warning("anti-silent stats query failed: %s", exc)
        return []
    results: list[Dict[str, Any]] = []
    for row in rows:
        template_id = str(row.get("template_id") or "").strip().upper()
        if not template_id:
            continue
        n = int(row.get("n") or 0)
        retry_rate = float(row.get("retry_rate") or 0.0)
        negative_rate = float(row.get("negative_rate") or 0.0)
        suppressed_rate = float(row.get("suppressed_rate") or 0.0)
        # Weighted penalty model: each signal rate subtracts from a perfect 1.0.
        score = 1.0 - (
            self.anti_silent_weight_retry * retry_rate
            + self.anti_silent_weight_negative * negative_rate
            + self.anti_silent_weight_suppressed * suppressed_rate
        )
        score = max(0.0, min(1.0, score))
        results.append(
            {
                "template_id": template_id,
                "n": n,
                "retry_rate": retry_rate,
                "negative_rate": negative_rate,
                "suppressed_rate": suppressed_rate,
                "score": score,
            }
        )
    return results
def _lesson_categories(self, event: Dict[str, Any]) -> list[str]:
    """Classify an event into zero or more lesson categories
    (error_repeat, silent_repeat, latency_spike)."""
    result = event.get("result") or {}
    llm = event.get("llm") or {}
    policy = event.get("policy") or {}
    tags: list[str] = []
    if not result.get("ok"):
        tags.append("error_repeat")
    decision = str(policy.get("sowa_decision") or "").upper()
    if decision == "SILENT":
        tags.append("silent_repeat")
    if _as_int(llm.get("latency_ms"), 0) >= self.latency_spike_ms:
        tags.append("latency_spike")
    return tags
def _bucket_key(self, category: str, event: Dict[str, Any]) -> str:
llm = event.get("llm") or {}
result = event.get("result") or {}
policy = event.get("policy") or {}
parts = [
category,
str(event.get("agent_id") or ""),
str(event.get("task_type") or "infer"),
str(result.get("error_class") or ""),
str(policy.get("reason") or ""),
str(llm.get("provider") or ""),
str(llm.get("model") or ""),
str(llm.get("profile") or ""),
]
return "|".join(parts)
def _update_bucket(self, bucket_key: str, event: Dict[str, Any]) -> Tuple[int, Optional[float], Optional[int]]:
    """Append the event to its sliding-window bucket and compute summary stats.

    Returns (count, ok_rate, p95_latency_ms) over samples newer than window_s
    seconds; (0, None, None) only if the window is empty (cannot happen here
    since the sample just appended is always inside the window).
    """
    now = time.monotonic()
    llm = event.get("llm") or {}
    result = event.get("result") or {}
    sample = EventSample(
        ts_mono=now,
        ok=bool(result.get("ok")),
        latency_ms=_as_int(llm.get("latency_ms"), 0),
    )
    bucket = self._buckets.get(bucket_key)
    if bucket is None:
        bucket = deque()
        self._buckets[bucket_key] = bucket
    bucket.append(sample)
    # Drop samples that aged out of the window (oldest first).
    cutoff = now - self.window_s
    while bucket and bucket[0].ts_mono < cutoff:
        bucket.popleft()
    if not bucket:
        return 0, None, None
    count = len(bucket)
    ok_count = sum(1 for item in bucket if item.ok)
    ok_rate = round(ok_count / count, 4) if count > 0 else None
    latencies = sorted(item.latency_ms for item in bucket)
    p95_latency = _p95(latencies)
    return count, ok_rate, p95_latency
def _threshold_for(self, category: str) -> int:
if category == "error_repeat":
return self.error_threshold
if category == "silent_repeat":
return self.silent_threshold
return self.latency_threshold
def _build_lesson(
    self,
    category: str,
    event: Dict[str, Any],
    count: int,
    ok_rate: Optional[float],
    p95_latency: Optional[int],
) -> Dict[str, Any]:
    """Assemble an operational lesson dict for a threshold-crossing bucket.

    trigger/action/avoid texts are fixed per category; lesson_key is a
    sha256 over scope, agent, the three texts, error_class and policy_reason
    so semantically-equal lessons dedup via the DB unique constraint.
    """
    llm = event.get("llm") or {}
    result = event.get("result") or {}
    policy = event.get("policy") or {}
    agent_id = str(event.get("agent_id") or "").strip() or None
    task_type = str(event.get("task_type") or "infer")
    # All free-form fields are sanitized/redacted before storage.
    error_class = _safe_token(result.get("error_class"))
    policy_reason = _safe_token(policy.get("reason"))
    sowa_decision = _safe_token(policy.get("sowa_decision"))
    if category == "silent_repeat":
        trigger = "Frequent SILENT policy outcomes on active conversation flow."
        action = "Use short ACK/CHALLENGE clarification before silencing response."
        avoid = "Avoid immediate SILENT when user intent might target the agent."
    elif category == "latency_spike":
        trigger = "Repeated latency spikes above configured SLA threshold."
        action = "Prefer faster model/profile and reduce expensive tool rounds."
        avoid = "Avoid routing to slow provider/profile for same task pattern."
    else:
        trigger = f"Repeated inference failures of class '{error_class or 'unknown_error'}'."
        action = "Switch to stable provider/profile and constrain optional tool calls."
        avoid = "Avoid blind retries on the same failing route."
    # Lessons are agent-scoped when the event carries an agent_id.
    scope = "agent" if agent_id else "global"
    lesson_key_raw = "|".join(
        [
            scope,
            str(agent_id or ""),
            trigger,
            action,
            avoid,
            str(error_class or ""),
            str(policy_reason or ""),
        ]
    )
    lesson_key = hashlib.sha256(lesson_key_raw.encode("utf-8")).hexdigest()
    lesson = {
        "lesson_id": str(uuid.uuid4()),
        "lesson_key": lesson_key,
        "ts": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
        "scope": scope,
        "agent_id": agent_id,
        "task_type": task_type,
        "trigger": trigger,
        "action": action,
        "avoid": avoid,
        "signals": {
            "policy_reason": policy_reason,
            "policy_decision": sowa_decision,
            "error_class": error_class,
            "provider": _safe_token(llm.get("provider")),
            "model": _safe_token(llm.get("model")),
            "profile": _safe_token(llm.get("profile")),
        },
        "evidence": {
            "count": count,
            "ok_rate": ok_rate,
            "p95_latency_ms": p95_latency,
        },
    }
    # Shallow self-snapshot stored under "raw".
    lesson["raw"] = dict(lesson)
    return lesson
async def _insert_lesson(self, lesson: Dict[str, Any]) -> str:
    """Persist a lesson row; returns "ok", "conflict", or "err".

    Operational lessons use INSERT ... ON CONFLICT DO NOTHING (first writer
    wins, duplicates report "conflict"); anti_silent_tuning lessons upsert so
    the latest recommendation for a lesson_key always replaces the old one.
    """
    if self._pool is None:
        return "err"
    query_insert = """
        INSERT INTO agent_lessons (
            lesson_id,
            lesson_key,
            ts,
            scope,
            agent_id,
            task_type,
            trigger,
            action,
            avoid,
            signals,
            evidence,
            raw
        ) VALUES (
            $1::uuid,
            $2,
            $3::timestamptz,
            $4,
            $5,
            $6,
            $7,
            $8,
            $9,
            $10::jsonb,
            $11::jsonb,
            $12::jsonb
        )
        ON CONFLICT (lesson_key) DO NOTHING
        RETURNING id
    """
    query_tuning_upsert = """
        INSERT INTO agent_lessons (
            lesson_id,
            lesson_key,
            ts,
            scope,
            agent_id,
            task_type,
            trigger,
            action,
            avoid,
            signals,
            evidence,
            raw
        ) VALUES (
            $1::uuid,
            $2,
            $3::timestamptz,
            $4,
            $5,
            $6,
            $7,
            $8,
            $9,
            $10::jsonb,
            $11::jsonb,
            $12::jsonb
        )
        ON CONFLICT (lesson_key) DO UPDATE SET
            ts = EXCLUDED.ts,
            scope = EXCLUDED.scope,
            agent_id = EXCLUDED.agent_id,
            task_type = EXCLUDED.task_type,
            trigger = EXCLUDED.trigger,
            action = EXCLUDED.action,
            avoid = EXCLUDED.avoid,
            signals = EXCLUDED.signals,
            evidence = EXCLUDED.evidence,
            raw = EXCLUDED.raw
        RETURNING id
    """
    try:
        lesson_id = uuid.UUID(str(lesson["lesson_id"]))
        ts_value = _as_timestamptz(lesson["ts"])
        lesson_type = str(lesson.get("lesson_type") or "").strip().lower()
        query = query_tuning_upsert if lesson_type == "anti_silent_tuning" else query_insert
        async with self._pool.acquire() as conn:
            row_id = await conn.fetchval(
                query,
                lesson_id,
                lesson["lesson_key"],
                ts_value,
                lesson["scope"],
                lesson.get("agent_id"),
                lesson["task_type"],
                lesson["trigger"],
                lesson["action"],
                lesson["avoid"],
                json.dumps(lesson["signals"], ensure_ascii=False),
                json.dumps(lesson["evidence"], ensure_ascii=False),
                json.dumps(lesson, ensure_ascii=False),
            )
        # DO NOTHING path returns no row: a duplicate lesson_key already exists.
        if row_id is None:
            return "conflict"
        return "ok"
    except Exception as exc:
        logger.warning("insert lesson failed: %s", exc)
        return "err"
async def _publish_lesson(self, lesson: Dict[str, Any]) -> None:
    """Best-effort publish of a lesson to the lesson subject on JetStream.

    Uses Nats-Msg-Id (the lesson_id) for server-side publish dedup. Failures
    are counted and logged, never raised.
    """
    if self._js is None:
        LESSON_PUBLISH.labels(status="skipped").inc()
        return
    payload = json.dumps(lesson, ensure_ascii=False).encode("utf-8")
    headers = {"Nats-Msg-Id": str(lesson["lesson_id"])}
    try:
        await self._js.publish(self.lesson_subject, payload, headers=headers)
        LESSON_PUBLISH.labels(status="ok").inc()
    except Exception as exc:
        LESSON_PUBLISH.labels(status="err").inc()
        logger.warning("publish lesson failed: %s", exc)
async def health(self) -> Dict[str, Any]:
return {
"status": "ok" if self._running else "starting",
"node_id": self.node_id,
"stream": self.stream_name,
"subject": self.subject,
"durable": self.durable,
"nats_connected": self._nc is not None and self._nc.is_connected,
"db_connected": self._pool is not None,
"running": self._running,
}
def _safe_token(value: Any) -> Optional[str]:
if value is None:
return None
text = str(value)
text = re.sub(r"(?i)bearer\s+[A-Za-z0-9._-]+", "bearer [redacted]", text)
text = re.sub(r"(?i)(api[_-]?key|token|password|secret)\s*[:=]\s*[^\s,;]+", r"\1=[redacted]", text)
text = re.sub(r"[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}", "[redacted-email]", text)
text = re.sub(r"https?://[^\s]+", "[redacted-url]", text)
text = re.sub(r"\s+", " ", text).strip()
return text[:180] if text else None
def _as_int(value: Any, default: int) -> int:
try:
return int(value)
except Exception:
return default
def _as_timestamptz(value: Any) -> datetime:
if isinstance(value, datetime):
return value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
try:
parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00"))
return parsed if parsed.tzinfo is not None else parsed.replace(tzinfo=timezone.utc)
except Exception:
return datetime.now(timezone.utc)
def _p95(sorted_values: list[int]) -> Optional[int]:
if not sorted_values:
return None
idx = int(round(0.95 * (len(sorted_values) - 1)))
return sorted_values[min(max(idx, 0), len(sorted_values) - 1)]
# ASGI application wiring. NOTE: constructing ExperienceLearner here raises
# RuntimeError at import time when no database DSN env var is set.
app = FastAPI(title="Experience Learner")
learner = ExperienceLearner()


# NOTE(review): @app.on_event is the legacy FastAPI lifecycle API — confirm
# whether the project intends to migrate to lifespan handlers.
@app.on_event("startup")
async def startup() -> None:
    """Start the learner when the ASGI app boots."""
    await learner.start()


@app.on_event("shutdown")
async def shutdown() -> None:
    """Stop the learner and release resources on shutdown."""
    await learner.stop()


@app.get("/health")
async def health() -> Dict[str, Any]:
    """Liveness/config snapshot (see ExperienceLearner.health)."""
    return await learner.health()


@app.get("/metrics")
async def metrics() -> Response:
    """Prometheus metrics in text exposition format."""
    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)

View File

@@ -0,0 +1,5 @@
# Runtime dependencies for the experience-learner service (see Dockerfile).
fastapi==0.104.1
uvicorn[standard]==0.24.0
# NATS JetStream client used for the EXPERIENCE stream consumer.
nats-py==2.6.0
asyncpg>=0.29.0
prometheus-client>=0.20.0

View File

@@ -443,6 +443,9 @@ class Database:
) -> Dict[str, Any]:
"""Create or update a user fact (isolated by agent_id)"""
import json
# Normalize NULL to empty string so ON CONFLICT matches (PostgreSQL: NULL != NULL in unique)
_agent_id = (agent_id or "").strip()
_team_id = team_id or ""
# Convert dict to JSON string for asyncpg JSONB
json_value = json.dumps(fact_value_json) if fact_value_json else None
@@ -457,7 +460,7 @@ class Database:
fact_value_json = EXCLUDED.fact_value_json,
updated_at = NOW()
RETURNING *
""", user_id, team_id, agent_id, fact_key, fact_value, json_value)
""", user_id, _team_id, _agent_id, fact_key, fact_value, json_value)
except asyncpg.exceptions.InvalidColumnReferenceError:
# Backward compatibility for DBs that only have UNIQUE(user_id, team_id, fact_key).
row = await conn.fetchrow("""
@@ -470,7 +473,7 @@ class Database:
fact_value_json = EXCLUDED.fact_value_json,
updated_at = NOW()
RETURNING *
""", user_id, team_id, agent_id, fact_key, fact_value, json_value)
""", user_id, _team_id, _agent_id, fact_key, fact_value, json_value)
return dict(row) if row else {}

View File

@@ -209,6 +209,57 @@ if PROMETHEUS_AVAILABLE:
registry=REGISTRY
)
# ==================== EXPERIENCE BUS METRICS ====================
EXPERIENCE_PUBLISHED = Counter(
'experience_published_total',
'Total experience events publish attempts',
['source', 'transport', 'status'], # transport: jetstream|core|none
registry=REGISTRY
)
EXPERIENCE_DB_INSERT = Counter(
'experience_db_insert_total',
'Total experience event DB insert attempts',
['source', 'status'], # status: ok|error|skipped
registry=REGISTRY
)
EXPERIENCE_DEDUP_DROPPED = Counter(
'experience_dedup_dropped_total',
'Total experience events dropped by dedup',
['source'],
registry=REGISTRY
)
EXPERIENCE_SAMPLED = Counter(
'experience_sampled_total',
'Total experience events sampled in/out',
['source', 'decision', 'reason'], # decision: in|out
registry=REGISTRY
)
LESSONS_RETRIEVED = Counter(
'lessons_retrieved_total',
'Total lessons retrieval attempts',
['status'], # status: ok|timeout|err
registry=REGISTRY
)
LESSONS_ATTACHED = Counter(
'lessons_attached_total',
'Total lessons attached buckets',
['count'], # count: 0|1-3|4-7
registry=REGISTRY
)
LESSONS_ATTACH_LATENCY = Histogram(
'lessons_attach_latency_ms',
'Lessons retrieval latency in milliseconds',
buckets=(1, 2, 5, 10, 25, 50, 100, 250, 500, 1000, 2500),
registry=REGISTRY
)
# ==================== METRIC HELPERS ====================
@@ -357,6 +408,61 @@ def track_agent_request(agent_id: str, operation: str):
return decorator
def inc_experience_published(source: str, transport: str, status: str) -> None:
    """Count one experience-event publish attempt (no-op without prometheus)."""
    if not PROMETHEUS_AVAILABLE:
        return
    EXPERIENCE_PUBLISHED.labels(source=source, transport=transport, status=status).inc()
def inc_experience_db_insert(source: str, status: str) -> None:
    """Count one experience-event DB insert attempt (no-op without prometheus)."""
    if not PROMETHEUS_AVAILABLE:
        return
    EXPERIENCE_DB_INSERT.labels(source=source, status=status).inc()
def inc_experience_dedup_dropped(source: str) -> None:
    """Count one event dropped by dedup (no-op without prometheus)."""
    if not PROMETHEUS_AVAILABLE:
        return
    EXPERIENCE_DEDUP_DROPPED.labels(source=source).inc()
def inc_experience_sampled(source: str, decision: str, reason: str) -> None:
    """Count one sampling decision ('in'/'out') (no-op without prometheus)."""
    if not PROMETHEUS_AVAILABLE:
        return
    EXPERIENCE_SAMPLED.labels(source=source, decision=decision, reason=reason).inc()
def inc_lessons_retrieved(status: str) -> None:
    """Count one lesson-retrieval attempt (no-op without prometheus)."""
    if not PROMETHEUS_AVAILABLE:
        return
    LESSONS_RETRIEVED.labels(status=status).inc()
def inc_lessons_attached(count: int) -> None:
    """Record how many lessons were attached, bucketed into coarse label bins.

    Non-int inputs count as 0.
    """
    if not PROMETHEUS_AVAILABLE:
        return
    try:
        n = int(count)
    except Exception:
        n = 0
    if n <= 0:
        bucket = "0"
    elif n <= 3:
        bucket = "1-3"
    else:
        # NOTE(review): every n >= 4 lands in "4-7", including n > 7 —
        # confirm whether a separate "8+" bin was intended.
        bucket = "4-7"
    LESSONS_ATTACHED.labels(count=bucket).inc()
def observe_lessons_attach_latency(latency_ms: float) -> None:
    """Observe lesson-attach latency in ms; silently ignores bad values
    (no-op without prometheus)."""
    if not PROMETHEUS_AVAILABLE:
        return
    try:
        LESSONS_ATTACH_LATENCY.observe(float(latency_ms))
    except Exception:
        return
# ==================== GPU METRICS COLLECTOR ====================
async def collect_gpu_metrics(node_id: str = "node1"):

View File

@@ -0,0 +1,446 @@
"""Router experience event bus (Phase-1).
Collects inference outcome events, applies sampling + dedup, then
persists to JetStream and Postgres in async background worker.
"""
from __future__ import annotations
import asyncio
from datetime import datetime, timezone
import json
import logging
import os
import random
import re
import time
import uuid
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
# Optional dependency: asyncpg is only present in the container runtime.
try:
    import asyncpg
except ImportError:  # pragma: no cover - runtime dependency in container
    asyncpg = None
# Metric helpers degrade to no-ops when agent_metrics cannot be imported,
# so the router keeps working without the metrics module.
try:
    from agent_metrics import (
        inc_experience_db_insert,
        inc_experience_dedup_dropped,
        inc_experience_published,
        inc_experience_sampled,
    )
except Exception:  # pragma: no cover - keep router resilient
    def inc_experience_published(*args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        return None

    def inc_experience_db_insert(*args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        return None

    def inc_experience_dedup_dropped(*args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        return None

    def inc_experience_sampled(*args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        return None
logger = logging.getLogger("experience_bus")
@dataclass
class ExperienceDecision:
    """Outcome of sampling/dedup filtering for a single event."""

    keep: bool  # True when the event should be enqueued for persistence
    reason: str  # why it was kept or dropped (used as a metric label)
class ExperienceBus:
    """Router-side collector: samples/dedups events and persists them
    asynchronously to JetStream and Postgres via a background worker."""

    def __init__(self) -> None:
        """Read all configuration from environment variables; no I/O here."""
        self.enabled = os.getenv("EXPERIENCE_BUS_ENABLED", "true").lower() in {"1", "true", "yes"}
        self.node_id = os.getenv("NODE_ID", "NODA1")
        # Sampling / dedup knobs.
        self.ok_sample_pct = float(os.getenv("EXPERIENCE_OK_SAMPLE_PCT", "10"))
        self.latency_spike_ms = int(os.getenv("EXPERIENCE_LATENCY_SPIKE_MS", "5000"))
        self.dedup_window_s = int(os.getenv("EXPERIENCE_DEDUP_WINDOW_SECONDS", "900"))
        self.dedup_max_keys = int(os.getenv("EXPERIENCE_DEDUP_MAX_KEYS", "20000"))
        self.queue_max = int(os.getenv("EXPERIENCE_QUEUE_MAX", "2000"))
        # Timeouts are configured in ms and converted to seconds here.
        self.publish_timeout_s = float(os.getenv("EXPERIENCE_PUBLISH_TIMEOUT_MS", "800") or 800) / 1000.0
        self.db_timeout_s = float(os.getenv("EXPERIENCE_DB_TIMEOUT_MS", "1200") or 1200) / 1000.0
        # Transport/persistence targets; either sink can be disabled via env.
        self.subject_prefix = os.getenv("EXPERIENCE_SUBJECT_PREFIX", "agent.experience.v1")
        self.stream_name = os.getenv("EXPERIENCE_STREAM_NAME", "EXPERIENCE")
        self.enable_nats = os.getenv("EXPERIENCE_ENABLE_NATS", "true").lower() in {"1", "true", "yes"}
        self.enable_db = os.getenv("EXPERIENCE_ENABLE_DB", "true").lower() in {"1", "true", "yes"}
        self.db_dsn = os.getenv("EXPERIENCE_DATABASE_URL") or os.getenv("DATABASE_URL")
        # Bounded handoff queue; None is the worker shutdown sentinel.
        self._queue: asyncio.Queue[Optional[Dict[str, Any]]] = asyncio.Queue(maxsize=self.queue_max)
        self._worker_task: Optional[asyncio.Task[Any]] = None
        self._running = False
        self._dedup_lock = asyncio.Lock()
        # dedup key -> timestamp, insertion-ordered for pruning.
        self._dedup: "OrderedDict[str, float]" = OrderedDict()
        self._pool: Optional[Any] = None
        self._nc: Any = None
        self._js: Any = None
async def start(self, nats_client: Any = None) -> None:
    """Initialize the enabled sinks and launch the background worker.

    No-op when disabled by env or already running. A NATS client may be
    supplied here or attached later via set_nats_client().
    """
    if not self.enabled:
        logger.info("ExperienceBus disabled by env")
        return
    if self._running:
        return
    if self.enable_db:
        await self._init_db()
    if self.enable_nats and nats_client is not None:
        await self.set_nats_client(nats_client)
    self._running = True
    self._worker_task = asyncio.create_task(self._worker(), name="experience-bus-worker")
    logger.info(
        "ExperienceBus started (db=%s nats=%s queue_max=%s sample_ok=%s%% dedup_window=%ss)",
        bool(self._pool),
        bool(self._js or self._nc),
        self.queue_max,
        self.ok_sample_pct,
        self.dedup_window_s,
    )
async def stop(self) -> None:
    """Stop the worker and release the DB pool and NATS handles.

    Idempotent: a no-op when the bus is not running.  A ``None`` sentinel is
    enqueued so the worker drains already-queued events before exiting; if
    the worker does not finish within 5s it is cancelled.

    Fix: the cancelled worker task is now awaited, so the event loop does
    not emit "Task was destroyed but it is pending" and cleanup in the
    worker's ``finally`` blocks actually runs before we tear down the pool.
    """
    if not self._running:
        return
    self._running = False
    try:
        # Sentinel tells the worker loop to exit after the current item.
        self._queue.put_nowait(None)
    except asyncio.QueueFull:
        # Queue saturated; the worker will be cancelled below instead.
        pass
    if self._worker_task is not None:
        try:
            await asyncio.wait_for(self._worker_task, timeout=5.0)
        except Exception:
            # Timed out (or worker raised): cancel and await the
            # cancellation so the task is fully finalized.
            self._worker_task.cancel()
            with contextlib.suppress(asyncio.CancelledError, Exception):
                await self._worker_task
        self._worker_task = None
    if self._pool is not None:
        try:
            await self._pool.close()
        except Exception as e:  # pragma: no cover
            logger.debug("ExperienceBus pool close error: %s", e)
        self._pool = None
    self._js = None
    self._nc = None
    logger.info("ExperienceBus stopped")
async def set_nats_client(self, nats_client: Any) -> None:
    """Attach (or detach, with ``None``) a NATS client and set up JetStream.

    JetStream setup is best-effort: on any failure the bus keeps the core
    NATS client and simply runs without JetStream.
    """
    if not (self.enabled and self.enable_nats):
        return
    self._nc = nats_client
    if nats_client is None:
        self._js = None
        return
    try:
        self._js = nats_client.jetstream()
        await self._ensure_stream()
    except Exception as e:
        self._js = None
        logger.warning("ExperienceBus JetStream unavailable: %s", e)
async def capture(self, event: Dict[str, Any]) -> None:
    """Apply sampling/dedup and enqueue the event for async persistence.

    Never blocks the caller: when the queue is full the event is counted
    and dropped.
    """
    if not (self.enabled and self._running):
        return
    decision = await self._decide(event)
    if decision.keep:
        inc_experience_sampled(source="router", decision="in", reason=decision.reason)
        try:
            self._queue.put_nowait(event)
        except asyncio.QueueFull:
            # Backpressure policy is drop-with-metrics, not block.
            inc_experience_sampled(source="router", decision="out", reason="queue_full")
            logger.warning("ExperienceBus queue full; dropping event")
        return
    if decision.reason == "dedup":
        inc_experience_dedup_dropped(source="router")
    inc_experience_sampled(source="router", decision="out", reason=decision.reason)
async def _init_db(self) -> None:
    """Create the asyncpg pool; degrade to DB-less operation on any failure.

    Pool is intentionally tiny (1-3 connections) — the bus is a low-volume
    background writer.
    """
    if asyncpg is None:
        logger.warning("ExperienceBus DB disabled: asyncpg not installed")
        return
    if not self.db_dsn:
        logger.warning("ExperienceBus DB disabled: DATABASE_URL missing")
        return
    try:
        self._pool = await asyncpg.create_pool(self.db_dsn, min_size=1, max_size=3)
    except Exception as e:
        self._pool = None
        logger.warning("ExperienceBus DB pool init failed: %s", e)
async def _ensure_stream(self) -> None:
    """Ensure the JetStream stream exists with our subject wildcard.

    An existing stream is never mutated; we only warn when its subjects do
    not cover ours.  All failures are logged and swallowed — the bus then
    degrades to core-NATS / DB-only operation.
    """
    if self._js is None:
        return
    subjects = [f"{self.subject_prefix}.>"]
    try:
        info = await self._js.stream_info(self.stream_name)
        stream_subjects = set(getattr(info.config, "subjects", []) or [])
        # NOTE(review): exact string comparison only — an existing stream
        # whose wildcard merely overlaps ours still triggers this warning.
        if not stream_subjects.intersection(subjects):
            logger.warning(
                "ExperienceBus stream '%s' exists without subject %s; keeping as-is",
                self.stream_name,
                subjects[0],
            )
        # Stream exists (matching or not): do not attempt to (re)create it.
        return
    except Exception:
        # stream_info failed -> stream likely absent; fall through to create.
        pass
    try:
        await self._js.add_stream(name=self.stream_name, subjects=subjects)
        logger.info("ExperienceBus stream ensured: %s subjects=%s", self.stream_name, subjects)
    except Exception as e:
        logger.warning("ExperienceBus stream ensure failed: %s", e)
async def _decide(self, event: Dict[str, Any]) -> ExperienceDecision:
    """Decide whether to keep an event, then dedup kept events in-window.

    Errors and latency spikes are always candidates; successful fast calls
    are sampled at ``ok_sample_pct``.  Candidates are then checked against
    an LRU dedup map under a lock; a key seen within ``dedup_window_s``
    seconds is dropped with reason "dedup".
    """
    result = event.get("result") or {}
    llm = event.get("llm") or {}
    ok = bool(result.get("ok"))
    latency_ms = int(llm.get("latency_ms") or 0)
    if not ok:
        sample_reason = "error"
    elif latency_ms >= self.latency_spike_ms:
        sample_reason = "latency_spike"
    else:
        # Percentage sampling for uneventful successes.
        if random.random() * 100.0 >= self.ok_sample_pct:
            return ExperienceDecision(False, "ok_sample_out")
        sample_reason = "ok_sample_in"
    dedup_key = self._dedup_key(event)
    now = time.monotonic()
    async with self._dedup_lock:
        self._prune_dedup(now)
        seen_at = self._dedup.get(dedup_key)
        if seen_at is not None and (now - seen_at) < self.dedup_window_s:
            return ExperienceDecision(False, "dedup")
        # Record (or refresh) the key; keep insertion order == recency.
        self._dedup[dedup_key] = now
        self._dedup.move_to_end(dedup_key, last=True)
        # Hard cap on memory regardless of the time window.
        while len(self._dedup) > self.dedup_max_keys:
            self._dedup.popitem(last=False)
    return ExperienceDecision(True, sample_reason)
def _prune_dedup(self, now: float) -> None:
if not self._dedup:
return
threshold = now - self.dedup_window_s
while self._dedup:
_, ts = next(iter(self._dedup.items()))
if ts >= threshold:
break
self._dedup.popitem(last=False)
def _dedup_key(self, event: Dict[str, Any]) -> str:
result = event.get("result") or {}
return "|".join(
[
str(event.get("agent_id") or ""),
str(event.get("task_type") or ""),
str(event.get("inputs_hash") or ""),
"1" if bool(result.get("ok")) else "0",
str(result.get("error_class") or ""),
]
)
async def _worker(self) -> None:
    """Drain the queue, persisting each event, until the None sentinel."""
    while True:
        item = await self._queue.get()
        try:
            if item is None:
                # Shutdown sentinel from stop().
                return
            await self._persist_event(item)
        except Exception as e:  # pragma: no cover
            logger.warning("ExperienceBus persist error: %s", e)
        finally:
            self._queue.task_done()
async def _persist_event(self, event: Dict[str, Any]) -> None:
    """Persist one event: NATS publish first (fan-out), then Postgres insert.

    Each step is individually best-effort and never raises to the worker.
    """
    await self._publish_nats(event)
    await self._insert_db(event)
async def _publish_nats(self, event: Dict[str, Any]) -> None:
    """Publish the event to NATS, preferring JetStream over core NATS.

    The ``Nats-Msg-Id`` header enables JetStream server-side dedup keyed on
    ``event_id``.  Every failure path is counted and logged at debug; this
    method never raises.
    """
    subject = f"{self.subject_prefix}.{event.get('agent_id', 'unknown')}"
    payload = json.dumps(event, ensure_ascii=False).encode("utf-8")
    msg_id = str(event.get("event_id") or "").strip()
    headers = {"Nats-Msg-Id": msg_id} if msg_id else None
    if self._js is not None:
        try:
            await asyncio.wait_for(
                self._js.publish(subject, payload, headers=headers),
                timeout=self.publish_timeout_s,
            )
            inc_experience_published(source="router", transport="jetstream", status="ok")
            return
        except Exception as e:
            # Fall through to core NATS on any JetStream failure/timeout.
            inc_experience_published(source="router", transport="jetstream", status="error")
            logger.debug("ExperienceBus JetStream publish failed: %s", e)
    if self._nc is not None:
        try:
            await asyncio.wait_for(
                self._nc.publish(subject, payload, headers=headers),
                timeout=self.publish_timeout_s,
            )
            # Core NATS is fire-and-forget; flush to push it onto the wire.
            await asyncio.wait_for(self._nc.flush(), timeout=self.publish_timeout_s)
            inc_experience_published(source="router", transport="core", status="ok")
            return
        except Exception as e:
            inc_experience_published(source="router", transport="core", status="error")
            logger.debug("ExperienceBus core NATS publish failed: %s", e)
    inc_experience_published(source="router", transport="none", status="skipped")
async def _insert_db(self, event: Dict[str, Any]) -> None:
    """Insert one event row into agent_experience_events (best effort).

    Uses ``ON CONFLICT (event_id) DO NOTHING`` so redelivered/duplicate
    events are idempotent.  The full event is also stored as jsonb in
    ``raw``.  Failures are counted and logged at debug; never raises.
    """
    if self._pool is None:
        inc_experience_db_insert(source="router", status="skipped")
        return
    payload_json = json.dumps(event, ensure_ascii=False)
    llm = event.get("llm") or {}
    result = event.get("result") or {}
    # Defensive coercions: a malformed event must not break the insert.
    event_uuid = _as_uuid(event.get("event_id"))
    event_ts = _as_timestamptz(event.get("ts"))
    query = """
        INSERT INTO agent_experience_events (
            event_id,
            ts,
            node_id,
            source,
            agent_id,
            task_type,
            request_id,
            channel,
            inputs_hash,
            provider,
            model,
            profile,
            latency_ms,
            tokens_in,
            tokens_out,
            ok,
            error_class,
            error_msg_redacted,
            http_status,
            raw
        ) VALUES (
            $1::uuid,
            $2::timestamptz,
            $3,
            $4,
            $5,
            $6,
            $7,
            $8,
            $9,
            $10,
            $11,
            $12,
            $13,
            $14,
            $15,
            $16,
            $17,
            $18,
            $19,
            $20::jsonb
        )
        ON CONFLICT (event_id) DO NOTHING
    """
    try:
        async with self._pool.acquire() as conn:
            # Positional args must match the column order above exactly.
            await asyncio.wait_for(
                conn.execute(
                    query,
                    event_uuid,
                    event_ts,
                    event.get("node_id"),
                    event.get("source"),
                    event.get("agent_id"),
                    event.get("task_type"),
                    event.get("request_id"),
                    event.get("channel", "unknown"),
                    event.get("inputs_hash"),
                    llm.get("provider", "unknown"),
                    llm.get("model", "unknown"),
                    llm.get("profile"),
                    int(llm.get("latency_ms") or 0),
                    _as_int_or_none(llm.get("tokens_in")),
                    _as_int_or_none(llm.get("tokens_out")),
                    bool(result.get("ok")),
                    result.get("error_class"),
                    result.get("error_msg_redacted"),
                    int(result.get("http_status") or 0),
                    payload_json,
                ),
                timeout=self.db_timeout_s,
            )
        inc_experience_db_insert(source="router", status="ok")
    except Exception as e:
        inc_experience_db_insert(source="router", status="error")
        logger.debug("ExperienceBus DB insert failed: %s", e)
def redact_error_message(value: Optional[str]) -> Optional[str]:
    """Scrub credential-looking material from an error message.

    Redacts bearer tokens, ``key=value`` style secrets, and long opaque
    token-like runs, then collapses whitespace and caps at 300 chars.
    ``None`` passes through unchanged.
    """
    if value is None:
        return None
    text = str(value)
    # Authorization: Bearer <token>
    text = re.sub(r"(?i)(authorization\s*:\s*bearer)\s+[A-Za-z0-9._-]+", r"\1 [redacted]", text)
    # api_key= / token: / password= ... style assignments.
    text = re.sub(r"(?i)(api[_-]?key|token|password|secret)\s*[:=]\s*[^\s,;]+", r"\1=[redacted]", text)
    # Any long opaque alphanumeric run (likely a key or hash).
    text = re.sub(r"\b[A-Za-z0-9_\-]{24,}\b", "[redacted]", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text[:300] if len(text) > 300 else text
def normalize_input_for_hash(text: str) -> str:
    """Canonicalize prompt text for hashing: trim, collapse whitespace,
    lowercase, and cap at 4000 chars."""
    collapsed = re.sub(r"\s+", " ", (text or "").strip())
    return collapsed.lower()[:4000]
def _as_int_or_none(value: Any) -> Optional[int]:
try:
if value is None:
return None
return int(value)
except Exception:
return None
def _as_uuid(value: Any) -> uuid.UUID:
try:
return uuid.UUID(str(value))
except Exception:
return uuid.uuid4()
def _as_timestamptz(value: Any) -> datetime:
if isinstance(value, datetime):
return value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
try:
parsed = datetime.fromisoformat(str(value).replace("Z", "+00:00"))
return parsed if parsed.tzinfo is not None else parsed.replace(tzinfo=timezone.utc)
except Exception:
return datetime.now(timezone.utc)

View File

@@ -1,10 +1,12 @@
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import Literal, Optional, Dict, Any, List
from typing import Literal, Optional, Dict, Any, List, Tuple
import asyncio
from collections import OrderedDict
import json
import os
import random as random_module
import re
import yaml
import httpx
@@ -12,6 +14,8 @@ import logging
import hashlib
import hmac
import time # For latency metrics
import uuid
from datetime import datetime, timezone, timedelta
from difflib import SequenceMatcher
# CrewAI Integration
@@ -62,6 +66,34 @@ except ImportError:
global_capabilities_client = None # type: ignore[assignment]
offload_client = None # type: ignore[assignment]
try:
from experience_bus import ExperienceBus, normalize_input_for_hash, redact_error_message
EXPERIENCE_BUS_AVAILABLE = True
except ImportError:
EXPERIENCE_BUS_AVAILABLE = False
ExperienceBus = None # type: ignore[assignment]
try:
import asyncpg
except ImportError:
asyncpg = None # type: ignore[assignment]
try:
from agent_metrics import (
inc_lessons_retrieved,
inc_lessons_attached,
observe_lessons_attach_latency,
)
except Exception:
def inc_lessons_retrieved(*args: Any, **kwargs: Any) -> None:
return None
def inc_lessons_attached(*args: Any, **kwargs: Any) -> None:
return None
def observe_lessons_attach_latency(*args: Any, **kwargs: Any) -> None:
return None
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
NEO4J_NOTIFICATIONS_LOG_LEVEL = os.getenv("NEO4J_NOTIFICATIONS_LOG_LEVEL", "ERROR").strip().upper()
@@ -71,6 +103,29 @@ logging.getLogger("neo4j.notifications").setLevel(_neo4j_notifications_level)
# If auto-router module is unavailable (or loaded later), inference must still work.
SOFIIA_AUTO_ROUTER_AVAILABLE = False
def _parse_agent_id_set(raw_value: Optional[str], default_csv: str = "") -> set[str]:
source = raw_value if (raw_value is not None and str(raw_value).strip() != "") else default_csv
out: set[str] = set()
for part in str(source or "").split(","):
token = part.strip().lower()
if token:
out.add(token)
return out
PLANNED_AGENT_IDS = _parse_agent_id_set(os.getenv("PLANNED_AGENT_IDS"), "aistalk")
DISABLED_AGENT_IDS = _parse_agent_id_set(os.getenv("DISABLED_AGENT_IDS"), "devtools")
def _inactive_agent_state(agent_id: str) -> Optional[str]:
    """Return the lifecycle state for inactive agents.

    "planned" wins over "disabled" when an id appears in both env-driven
    sets; active agents yield ``None``.
    """
    aid = str(agent_id or "").strip().lower()
    if aid in PLANNED_AGENT_IDS:
        return "planned"
    return "disabled" if aid in DISABLED_AGENT_IDS else None
TRUSTED_DOMAINS_CONFIG_PATH = os.getenv("TRUSTED_DOMAINS_CONFIG_PATH", "./trusted_domains.yml")
_trusted_domains_cache: Dict[str, Any] = {"mtime": None, "data": {}}
@@ -894,6 +949,287 @@ def _select_default_llm(agent_id: str, metadata: Dict[str, Any], base_llm: str,
return use_llm
return base_llm
def _safe_json_from_bytes(payload: bytes) -> Dict[str, Any]:
if not payload:
return {}
try:
decoded = payload.decode("utf-8", errors="ignore").strip()
if not decoded:
return {}
value = json.loads(decoded)
if isinstance(value, dict):
return value
except Exception:
return {}
return {}
def _extract_infer_agent_id(path: str) -> Optional[str]:
    """Pull the lowercase agent id out of a ``/v1/agents/{id}/infer`` path.

    Returns ``None`` for non-matching paths or an empty id segment.
    """
    match = _INFER_PATH_RE.match(path or "")
    if match is None:
        return None
    agent = (match.group(1) or "").strip().lower()
    return agent or None
def _infer_channel_from_metadata(metadata: Dict[str, Any]) -> str:
channel = str(
metadata.get("channel")
or metadata.get("channel_type")
or metadata.get("source")
or metadata.get("entrypoint")
or "unknown"
).strip().lower()
if channel in {"telegram", "web", "api"}:
return channel
return "unknown"
def _derive_provider_from_backend_model(backend: str, model: str, profile: Optional[str]) -> str:
    """Infer a provider label for an experience event.

    Precedence: explicit ``provider`` in the llm_profiles config, then a
    backend-name substring match (order matters), then a local-model name
    prefix heuristic; "other" when nothing matches.
    """
    profiles = (router_config or {}).get("llm_profiles", {}) if isinstance(router_config, dict) else {}
    if profile and isinstance(profiles, dict):
        p = profiles.get(profile, {})
        if isinstance(p, dict) and p.get("provider"):
            return str(p.get("provider"))
    b = str(backend or "").lower()
    m = str(model or "").lower()
    # Substring checks on the backend name, most specific vendors first.
    if "mistral" in b:
        return "mistral"
    if "deepseek" in b:
        return "deepseek"
    if "grok" in b:
        return "grok"
    if "anthropic" in b or "claude" in b:
        return "anthropic"
    if "openai" in b:
        return "openai"
    if "glm" in b:
        return "glm"
    if "nats-offload" in b:
        return "remote"
    if "ollama" in b or "local" in b:
        return "local"
    # Known open-weight model families default to local serving.
    if any(m.startswith(prefix) for prefix in ("qwen", "gemma", "mistral", "deepseek", "glm")):
        return "local"
    return "other"
def _resolve_profile_for_event(agent_id: str, req_payload: Dict[str, Any]) -> Optional[str]:
    """Replay router profile selection so the event records the profile used.

    Returns ``None`` when router_config, the agent entry, or its
    ``default_llm`` is unavailable/blank.
    """
    if not isinstance(router_config, dict):
        return None
    metadata = req_payload.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}
    agent_cfg = (router_config.get("agents") or {}).get(agent_id, {})
    if not isinstance(agent_cfg, dict):
        return None
    base_llm = str(agent_cfg.get("default_llm") or "").strip()
    if not base_llm:
        return None
    rules = router_config.get("routing") or []
    if isinstance(rules, list):
        # Apply the same routing-rule overrides the live request path uses.
        return _select_default_llm(agent_id, metadata, base_llm, rules)
    return base_llm
def _lesson_guarded_text(value: Any, max_len: int = 220) -> str:
    """Sanitize lesson text for prompt injection before attaching it.

    Collapses whitespace, rejects text containing any injection-guard
    marker (returns ""), and truncates to *max_len* characters.
    """
    text = re.sub(r"\s+", " ", str(value or "")).strip()
    if not text:
        return ""
    lower = text.lower()
    for marker in LESSONS_INJECTION_GUARDS:
        if marker in lower:
            # Looks like a prompt-injection attempt; refuse to attach it.
            return ""
    return text[:max_len].rstrip() if len(text) > max_len else text
def _decode_lesson_signals(raw: Any) -> Dict[str, Any]:
if isinstance(raw, dict):
return dict(raw)
if isinstance(raw, str):
try:
parsed = json.loads(raw)
if isinstance(parsed, dict):
return parsed
except Exception:
return {}
return {}
def _score_lesson_record(
    row: Dict[str, Any],
    *,
    agent_id: str,
    provider: str,
    model: str,
    profile: str,
    last_error_class: Optional[str],
) -> float:
    """Relevance-score a lesson row against the current request context.

    Weights: agent match +3, matching error class +2, provider/model/
    profile +1 each; minus up to 2 points for age (one point per week,
    capped).  Callers pass provider/model/profile already lowercased.
    """
    score = 0.0
    row_agent_id = str(row.get("agent_id") or "").strip().lower()
    if row_agent_id and row_agent_id == agent_id:
        score += 3.0
    signals = _decode_lesson_signals(row.get("signals"))
    signal_error = str(signals.get("error_class") or "").strip().lower()
    if last_error_class and signal_error and signal_error == last_error_class.lower():
        score += 2.0
    signal_provider = str(signals.get("provider") or "").strip().lower()
    signal_model = str(signals.get("model") or "").strip().lower()
    signal_profile = str(signals.get("profile") or "").strip().lower()
    if provider and signal_provider and signal_provider == provider:
        score += 1.0
    if model and signal_model and signal_model == model:
        score += 1.0
    if profile and signal_profile and signal_profile == profile:
        score += 1.0
    row_ts = row.get("ts")
    if isinstance(row_ts, datetime):
        # Naive timestamps are treated as UTC for the age computation.
        dt = row_ts if row_ts.tzinfo else row_ts.replace(tzinfo=timezone.utc)
        age_hours = max(0.0, (datetime.now(timezone.utc) - dt).total_seconds() / 3600.0)
        score -= min(2.0, age_hours / 168.0)  # down-rank lessons older than ~7 days
    return score
def _render_operational_lessons(lessons: List[Dict[str, Any]], max_chars: int) -> str:
    """Render lessons as a bounded plain-text block for prompt injection.

    Lessons missing any sanitized field are skipped; rendering stops before
    the block would exceed *max_chars*.  Returns "" when nothing usable
    fits (header alone is never emitted).
    """
    if not lessons:
        return ""
    lines = ["Operational Lessons (apply if relevant):"]
    for idx, lesson in enumerate(lessons, start=1):
        trigger = _lesson_guarded_text(lesson.get("trigger"), max_len=220)
        action = _lesson_guarded_text(lesson.get("action"), max_len=220)
        avoid = _lesson_guarded_text(lesson.get("avoid"), max_len=220)
        if not trigger or not action or not avoid:
            continue
        chunk = f"{idx}) Trigger: {trigger}\n Do: {action}\n Avoid: {avoid}"
        # Check the char budget before committing the chunk.
        candidate = "\n".join(lines + [chunk])
        if len(candidate) > max_chars:
            break
        lines.append(chunk)
    if len(lines) <= 1:
        return ""
    return "\n".join(lines)
async def _update_last_infer_signal(agent_id: str, *, ok: bool, error_class: Optional[str], latency_ms: int) -> None:
    """Record the most recent /infer outcome per agent in a TTL'd LRU cache.

    The lessons read path consults this to switch to always-on retrieval
    after an error or latency spike.  Also prunes expired entries and caps
    the cache at 4000 agents.
    """
    key = str(agent_id or "").strip().lower()
    if not key:
        return
    now = time.monotonic()
    async with _lessons_signal_lock:
        _lessons_signal_cache[key] = {
            "ok": bool(ok),
            "error_class": str(error_class or "").strip() or None,
            "latency_ms": int(max(0, latency_ms)),
            "seen_at": now,
        }
        _lessons_signal_cache.move_to_end(key, last=True)
        # Drop entries older than the TTL (floored at 30s).
        threshold = now - max(30, LESSONS_SIGNAL_CACHE_TTL_SECONDS)
        stale_keys = [k for k, v in _lessons_signal_cache.items() if float(v.get("seen_at", 0.0)) < threshold]
        for stale_key in stale_keys:
            _lessons_signal_cache.pop(stale_key, None)
        # Hard cap on tracked agents, evicting least-recently-updated first.
        while len(_lessons_signal_cache) > 4000:
            _lessons_signal_cache.popitem(last=False)
async def _get_last_infer_signal(agent_id: str) -> Optional[Dict[str, Any]]:
    """Return a copy of the agent's last /infer signal, or None.

    Expired entries (older than LESSONS_SIGNAL_CACHE_TTL_SECONDS) are
    evicted eagerly on read.
    """
    key = str(agent_id or "").strip().lower()
    if not key:
        return None
    now = time.monotonic()
    async with _lessons_signal_lock:
        value = _lessons_signal_cache.get(key)
        if not value:
            return None
        age = now - float(value.get("seen_at", 0.0))
        if age > LESSONS_SIGNAL_CACHE_TTL_SECONDS:
            # Expired: evict so the cache stays small.
            _lessons_signal_cache.pop(key, None)
            return None
        # Copy so callers cannot mutate the cached entry.
        return dict(value)
async def _fetch_ranked_lessons(
    *,
    agent_id: str,
    provider: str,
    model: str,
    profile: str,
    last_error_class: Optional[str],
    limit: int,
) -> Tuple[List[Dict[str, Any]], str, int]:
    """Fetch, sanitize, and rank lessons for an agent's /infer request.

    Returns ``(lessons, status, elapsed_ms)`` where status is one of
    "ok", "timeout", or "err".  Up to 50 candidate rows are pulled
    (agent-specific plus global) and re-ranked in process; at most
    ``limit`` (min 1) are returned.
    """
    if lessons_db_pool is None:
        return [], "err", 0
    query = """
        SELECT lesson_key, ts, scope, agent_id, task_type, trigger, action, avoid, signals
        FROM agent_lessons
        WHERE (agent_id = $1 OR agent_id IS NULL)
          AND task_type = 'infer'
        ORDER BY (agent_id = $1) DESC, ts DESC
        LIMIT 50
    """
    started = time.time()
    try:
        async with lessons_db_pool.acquire() as conn:
            # Bounded wait: the attach path must never stall inference.
            rows = await asyncio.wait_for(
                conn.fetch(query, str(agent_id).strip().lower()),
                timeout=LESSONS_ATTACH_TIMEOUT_MS / 1000.0,
            )
    except asyncio.TimeoutError:
        elapsed = max(0, int((time.time() - started) * 1000))
        return [], "timeout", elapsed
    except Exception as e:
        logger.debug("Lessons retrieval failed: %s", e)
        elapsed = max(0, int((time.time() - started) * 1000))
        return [], "err", elapsed
    ranked: List[Tuple[float, datetime, Dict[str, Any]]] = []
    for row in rows:
        row_data = dict(row)
        lesson = {
            "lesson_key": row_data.get("lesson_key"),
            "ts": row_data.get("ts"),
            "scope": row_data.get("scope"),
            "agent_id": row_data.get("agent_id"),
            "task_type": row_data.get("task_type"),
            "trigger": row_data.get("trigger"),
            "action": row_data.get("action"),
            "avoid": row_data.get("avoid"),
            "signals": _decode_lesson_signals(row_data.get("signals")),
        }
        # Drop lessons whose text fails the injection-guard sanitizer.
        if not (
            _lesson_guarded_text(lesson.get("trigger"))
            and _lesson_guarded_text(lesson.get("action"))
            and _lesson_guarded_text(lesson.get("avoid"))
        ):
            continue
        score = _score_lesson_record(
            lesson,
            agent_id=agent_id,
            provider=(provider or "").strip().lower(),
            model=(model or "").strip().lower(),
            profile=(profile or "").strip().lower(),
            last_error_class=last_error_class,
        )
        ts = lesson.get("ts")
        if not isinstance(ts, datetime):
            # Missing timestamp sorts as very old (ties broken by ts).
            ts = datetime.now(timezone.utc) - timedelta(days=365)
        ranked.append((score, ts, lesson))
    # Highest score first; recency breaks ties.
    ranked.sort(key=lambda item: (item[0], item[1]), reverse=True)
    selected = [item[2] for item in ranked[: max(1, limit)]]
    elapsed = max(0, int((time.time() - started) * 1000))
    return selected, "ok", elapsed
app = FastAPI(title="DAARION Router", version="2.0.0")
# Configuration
@@ -907,6 +1243,27 @@ VISION_URL = os.getenv("VISION_URL", "http://host.docker.internal:11434")
OCR_URL = os.getenv("OCR_URL", "http://swapper-service:8890")
DOCUMENT_URL = os.getenv("DOCUMENT_URL", "http://swapper-service:8890")
CITY_SERVICE_URL = os.getenv("CITY_SERVICE_URL", "http://daarion-city-service:7001")
LESSONS_ATTACH_ENABLED = os.getenv("LESSONS_ATTACH_ENABLED", "true").lower() in {"1", "true", "yes"}
LESSONS_ATTACH_MIN = max(1, int(os.getenv("LESSONS_ATTACH_MIN", "3")))
LESSONS_ATTACH_MAX = max(LESSONS_ATTACH_MIN, int(os.getenv("LESSONS_ATTACH_MAX", "7")))
LESSONS_ATTACH_TIMEOUT_MS = max(5, int(os.getenv("LESSONS_ATTACH_TIMEOUT_MS", "25")))
LESSONS_ATTACH_SAMPLE_PCT = max(0.0, min(100.0, float(os.getenv("LESSONS_ATTACH_SAMPLE_PCT", "10"))))
LESSONS_ATTACH_MAX_CHARS = max(400, int(os.getenv("LESSONS_ATTACH_MAX_CHARS", "1200")))
LESSONS_SIGNAL_CACHE_TTL_SECONDS = max(30, int(os.getenv("LESSONS_SIGNAL_CACHE_TTL_SECONDS", "300")))
LESSONS_LATENCY_SPIKE_MS = max(250, int(os.getenv("EXPERIENCE_LATENCY_SPIKE_MS", "5000")))
LESSONS_DATABASE_URL = (
os.getenv("LESSONS_DATABASE_URL")
or os.getenv("EXPERIENCE_DATABASE_URL")
or os.getenv("DATABASE_URL")
)
LESSONS_INJECTION_GUARDS = (
"ignore previous",
"ignore all previous",
"system:",
"developer:",
"```",
)
# CrewAI Routing Configuration
CREWAI_ROUTING_ENABLED = os.getenv("CREWAI_ROUTING_ENABLED", "true").lower() == "true"
@@ -947,6 +1304,12 @@ nats_available = False
# Tool Manager
tool_manager = None
runtime_guard_engine = None
experience_bus = None
lessons_db_pool = None
_lessons_signal_cache: "OrderedDict[str, Dict[str, Any]]" = OrderedDict()
_lessons_signal_lock = asyncio.Lock()
_INFER_PATH_RE = re.compile(r"^/v1/agents/([^/]+)/infer/?$")
# Models
class FilterDecision(BaseModel):
@@ -999,10 +1362,146 @@ def load_router_config():
config = load_config()
router_config = load_router_config()
@app.middleware("http")
async def experience_capture_middleware(request: Request, call_next):
    """Capture /infer outcomes and emit an ExperienceEvent asynchronously.

    Buffers the request and response bodies so the event can be built after
    the handler runs, then re-emits an equivalent Response.  Any capture
    failure is swallowed so it can never affect the client-facing result.
    """
    infer_agent_id = _extract_infer_agent_id(request.url.path)
    if (
        not infer_agent_id
        or request.method.upper() != "POST"
        or not EXPERIENCE_BUS_AVAILABLE
        or experience_bus is None
    ):
        # Not an /infer POST (or the bus is unavailable): pass through untouched.
        return await call_next(request)
    started_at = time.time()
    req_body = await request.body()
    # Replay the already-consumed body to the downstream handler.
    async def _receive() -> Dict[str, Any]:
        return {"type": "http.request", "body": req_body, "more_body": False}
    wrapped_request = Request(request.scope, _receive)
    response = None
    response_body = b""
    status_code = 500
    caught_exc: Optional[Exception] = None
    try:
        response = await call_next(wrapped_request)
        status_code = int(response.status_code)
        # Drain the streaming body so it can be inspected and re-sent.
        chunks: List[bytes] = []
        async for chunk in response.body_iterator:
            chunks.append(chunk)
        response_body = b"".join(chunks)
    except Exception as exc:  # pragma: no cover - defensive capture path
        caught_exc = exc
        status_code = 500
    latency_ms = max(0, int((time.time() - started_at) * 1000))
    try:
        req_payload = _safe_json_from_bytes(req_body)
        resp_payload = _safe_json_from_bytes(response_body)
        metadata = req_payload.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}
        # Only a normalized hash of the prompt is recorded, never raw text.
        prompt = str(req_payload.get("prompt") or "")
        normalized_input = normalize_input_for_hash(prompt)
        inputs_hash = hashlib.sha256(normalized_input.encode("utf-8")).hexdigest()
        profile = _resolve_profile_for_event(infer_agent_id, req_payload)
        profile_cfg = {}
        if profile and isinstance(router_config, dict):
            profile_cfg = (router_config.get("llm_profiles") or {}).get(profile, {}) or {}
        if not isinstance(profile_cfg, dict):
            profile_cfg = {}
        model = str(resp_payload.get("model") or profile_cfg.get("model") or "unknown")
        backend = str(resp_payload.get("backend") or "")
        provider = _derive_provider_from_backend_model(backend, model, profile)
        # Only a combined total is reported upstream; store it as tokens_out.
        tokens_total = resp_payload.get("tokens_used")
        tokens_out = int(tokens_total) if isinstance(tokens_total, int) else None
        request_id = str(
            metadata.get("request_id")
            or metadata.get("trace_id")
            or request.headers.get("x-request-id")
            or ""
        ).strip() or None
        err_class: Optional[str] = None
        err_msg: Optional[str] = None
        detail_obj = resp_payload.get("detail")
        if caught_exc is not None:
            err_class = type(caught_exc).__name__
            err_msg = str(caught_exc)
        elif status_code >= 400:
            # Derive an error class from the FastAPI error payload shape.
            if isinstance(detail_obj, dict):
                err_class = str(detail_obj.get("code") or detail_obj.get("error_class") or f"http_{status_code}")
                err_msg = str(detail_obj.get("message") or detail_obj.get("detail") or json.dumps(detail_obj))
            elif isinstance(detail_obj, str):
                err_class = f"http_{status_code}"
                err_msg = detail_obj
            else:
                err_class = f"http_{status_code}"
                err_msg = f"http_status={status_code}"
        # Feed the lessons-retrieval signal cache for this agent.
        await _update_last_infer_signal(
            infer_agent_id,
            ok=status_code < 400,
            error_class=err_class,
            latency_ms=latency_ms,
        )
        event = {
            "event_id": str(uuid.uuid4()),
            "ts": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
            "node_id": os.getenv("NODE_ID", "NODA1"),
            "source": "router",
            "agent_id": infer_agent_id,
            "request_id": request_id,
            "channel": _infer_channel_from_metadata(metadata),
            "task_type": "infer",
            "inputs_hash": inputs_hash,
            "llm": {
                "provider": provider,
                "model": model,
                "profile": profile,
                "latency_ms": latency_ms,
                "tokens_in": None,
                "tokens_out": tokens_out,
            },
            "result": {
                "ok": status_code < 400,
                "error_class": err_class,
                "error_msg_redacted": redact_error_message(err_msg),
                "http_status": status_code,
            },
        }
        await experience_bus.capture(event)
    except Exception as exp_err:
        logger.debug("Experience capture skipped: %s", exp_err)
    if caught_exc is not None:
        # Re-raise the handler's exception after capture bookkeeping.
        raise caught_exc
    # Rebuild the response; content-length is dropped because the body was
    # fully buffered and will be re-measured by Starlette.
    headers = dict(response.headers) if response is not None else {}
    headers.pop("content-length", None)
    return Response(
        content=response_body,
        status_code=status_code,
        headers=headers,
        media_type=response.media_type if response is not None else "application/json",
        background=response.background if response is not None else None,
    )
@app.on_event("startup")
async def startup_event():
"""Initialize NATS connection and subscriptions"""
global nc, nats_available, http_client, neo4j_driver, neo4j_available, runtime_guard_engine
global nc, nats_available, http_client, neo4j_driver, neo4j_available, runtime_guard_engine, experience_bus, lessons_db_pool
logger.info("🚀 DAGI Router v2.0.0 starting up...")
# Initialize HTTP client
@@ -1041,6 +1540,34 @@ async def startup_event():
logger.warning(f"⚠️ NATS not available: {e}")
logger.warning("⚠️ Running in test mode (HTTP only)")
nats_available = False
# Initialize Experience Bus (Phase-1)
if EXPERIENCE_BUS_AVAILABLE and ExperienceBus is not None:
try:
experience_bus = ExperienceBus()
await experience_bus.start(nats_client=nc if nats_available else None)
logger.info("✅ Experience Bus initialized")
except Exception as e:
experience_bus = None
logger.warning(f"⚠️ Experience Bus init failed: {e}")
# Initialize lessons retrieval pool (Phase-3 read path)
if LESSONS_ATTACH_ENABLED:
if asyncpg is None:
logger.warning("⚠️ Lessons attach enabled but asyncpg is unavailable")
elif not LESSONS_DATABASE_URL:
logger.warning("⚠️ Lessons attach enabled but LESSONS_DATABASE_URL is missing")
else:
try:
lessons_db_pool = await asyncpg.create_pool(
LESSONS_DATABASE_URL,
min_size=1,
max_size=3,
)
logger.info("✅ Lessons DB pool initialized")
except Exception as e:
lessons_db_pool = None
logger.warning(f"⚠️ Lessons DB pool init failed: {e}")
# Initialize Memory Retrieval Pipeline
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval:
@@ -1765,6 +2292,24 @@ async def agent_infer(agent_id: str, request: InferRequest):
"""
logger.info(f"🔀 Inference request for agent: {agent_id}")
logger.info(f"📝 Prompt: {request.prompt[:100]}...")
inactive_state = _inactive_agent_state(agent_id)
if inactive_state is not None:
status_code = 410 if inactive_state == "planned" else 404
logger.info(
"⛔ Agent unavailable by lifecycle state: agent=%s state=%s",
agent_id,
inactive_state,
)
raise HTTPException(
status_code=status_code,
detail={
"code": f"agent_{inactive_state}",
"agent_id": str(agent_id).strip().lower(),
"state": inactive_state,
"message": "Agent is not active in this environment",
},
)
# =========================================================================
# MEMORY RETRIEVAL (v4.0 - Universal for all agents)
@@ -2682,23 +3227,77 @@ async def agent_infer(agent_id: str, request: InferRequest):
# SMART LLM ROUTER WITH AUTO-FALLBACK
# Priority: DeepSeek → Mistral → Grok → Local Ollama
# =========================================================================
lessons_block = ""
lessons_attached_count = 0
if LESSONS_ATTACH_ENABLED and not request.images:
retrieval_always_on = False
retrieval_limit = LESSONS_ATTACH_MIN
last_signal = await _get_last_infer_signal(request_agent_id)
last_error_class = None
if last_signal:
last_error_class = last_signal.get("error_class")
if (not bool(last_signal.get("ok", True))) or int(last_signal.get("latency_ms", 0) or 0) >= LESSONS_LATENCY_SPIKE_MS:
retrieval_always_on = True
retrieval_limit = LESSONS_ATTACH_MAX
should_retrieve = retrieval_always_on or (random_module.random() * 100.0 < LESSONS_ATTACH_SAMPLE_PCT)
if should_retrieve:
lessons_rows, retrieval_status, retrieval_latency_ms = await _fetch_ranked_lessons(
agent_id=request_agent_id,
provider=str(provider or "").strip().lower(),
model=str(model or "").strip().lower(),
profile=str(default_llm or "").strip().lower(),
last_error_class=str(last_error_class or "").strip() or None,
limit=retrieval_limit,
)
inc_lessons_retrieved(status=retrieval_status)
observe_lessons_attach_latency(latency_ms=float(retrieval_latency_ms))
if retrieval_status == "ok" and lessons_rows:
selected_lessons = lessons_rows[:retrieval_limit]
lessons_block = _render_operational_lessons(selected_lessons, LESSONS_ATTACH_MAX_CHARS)
if lessons_block:
lessons_attached_count = len(selected_lessons)
logger.info(
"🧠 lessons_attached=%s agent=%s mode=%s",
lessons_attached_count,
request_agent_id,
"always_on" if retrieval_always_on else "sampled",
)
inc_lessons_attached(count=lessons_attached_count)
# Build messages array once for all providers
messages = []
if system_prompt:
combined_parts: List[str] = [system_prompt]
if memory_brief_text:
enhanced_prompt = f"{system_prompt}\n\n[INTERNAL MEMORY - do NOT repeat to user]\n{memory_brief_text}"
messages.append({"role": "system", "content": enhanced_prompt})
logger.info(f"📝 Added system message with prompt ({len(system_prompt)} chars) + memory ({len(memory_brief_text)} chars)")
else:
messages.append({"role": "system", "content": system_prompt})
logger.info(f"📝 Added system message with prompt ({len(system_prompt)} chars)")
elif memory_brief_text:
messages.append({"role": "system", "content": f"[INTERNAL MEMORY - do NOT repeat to user]\n{memory_brief_text}"})
logger.warning(f"⚠️ No system_prompt! Using only memory brief ({len(memory_brief_text)} chars)")
combined_parts.append(f"[INTERNAL MEMORY - do NOT repeat to user]\n{memory_brief_text}")
if lessons_block:
combined_parts.append(f"[OPERATIONAL LESSONS - INTERNAL]\n{lessons_block}")
enhanced_prompt = "\n\n".join(combined_parts)
messages.append({"role": "system", "content": enhanced_prompt})
logger.info(
"📝 Added system message prompt=%s memory=%s lessons=%s",
len(system_prompt),
len(memory_brief_text or ""),
lessons_attached_count,
)
elif memory_brief_text or lessons_block:
fallback_parts: List[str] = []
if memory_brief_text:
fallback_parts.append(f"[INTERNAL MEMORY - do NOT repeat to user]\n{memory_brief_text}")
if lessons_block:
fallback_parts.append(f"[OPERATIONAL LESSONS - INTERNAL]\n{lessons_block}")
messages.append({"role": "system", "content": "\n\n".join(fallback_parts)})
logger.warning(
"⚠️ No system_prompt! Using fallback context memory=%s lessons=%s",
len(memory_brief_text or ""),
lessons_attached_count,
)
else:
logger.error(f"❌ No system_prompt AND no memory_brief! LLM will have no context!")
logger.error("❌ No system_prompt, memory_brief, or lessons; LLM will have no context")
messages.append({"role": "user", "content": request.prompt})
logger.debug(f"📨 Messages array: {len(messages)} messages, system={len(messages[0].get('content', '')) if messages else 0} chars")
@@ -4555,7 +5154,7 @@ async def sofiia_model_catalog(refresh_ollama: bool = False):
@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup connections on shutdown"""
global neo4j_driver, http_client, nc
global neo4j_driver, http_client, nc, experience_bus, lessons_db_pool
# Close Memory Retrieval
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval:
@@ -4576,3 +5175,17 @@ async def shutdown_event():
if nc:
await nc.close()
logger.info("🔌 NATS connection closed")
if EXPERIENCE_BUS_AVAILABLE and experience_bus:
try:
await experience_bus.stop()
logger.info("🔌 Experience Bus closed")
except Exception as e:
logger.warning(f"⚠️ Experience Bus close error: {e}")
if lessons_db_pool is not None:
try:
await lessons_db_pool.close()
logger.info("🔌 Lessons DB pool closed")
except Exception as e:
logger.warning(f"⚠️ Lessons DB pool close error: {e}")