"""DAGI Router Client.

Sends inference requests from the Bot Gateway to the DAGI Router.
"""
|
|
import logging
|
|
import os
|
|
import time
|
|
import httpx
|
|
from typing import Dict, Any
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Import metrics
# The local `metrics` module is optional: when it is absent, METRICS_AVAILABLE
# stays False and every instrumentation call site is skipped, so the gateway
# keeps working without it.
try:
    from metrics import ROUTER_CALLS_TOTAL, ROUTER_LATENCY, ERRORS_TOTAL
    METRICS_AVAILABLE = True
except ImportError:
    METRICS_AVAILABLE = False
|
|
|
|
# Router configuration from environment
# Base URL of the DAGI Router service and the per-request HTTP timeout (seconds).
ROUTER_BASE_URL = os.getenv("ROUTER_URL", "http://127.0.0.1:9102")
ROUTER_TIMEOUT = float(os.getenv("ROUTER_TIMEOUT", "180.0"))
# Token budgets for different reply styles; send_to_router picks between them
# based on agent id and metadata flags (force_concise / force_detailed / training).
GATEWAY_MAX_TOKENS_DEFAULT = int(os.getenv("GATEWAY_MAX_TOKENS_DEFAULT", "700"))
GATEWAY_MAX_TOKENS_CONCISE = int(os.getenv("GATEWAY_MAX_TOKENS_CONCISE", "220"))
GATEWAY_MAX_TOKENS_TRAINING = int(os.getenv("GATEWAY_MAX_TOKENS_TRAINING", "900"))
# Sampling temperature used unless metadata provides "temperature_override".
GATEWAY_TEMPERATURE_DEFAULT = float(os.getenv("GATEWAY_TEMPERATURE_DEFAULT", "0.4"))
# Per-agent defaults: senpai and helion get tighter budgets (see send_to_router).
GATEWAY_MAX_TOKENS_SENPAI_DEFAULT = int(os.getenv("GATEWAY_MAX_TOKENS_SENPAI_DEFAULT", "320"))
GATEWAY_MAX_TOKENS_HELION_DEFAULT = int(os.getenv("GATEWAY_MAX_TOKENS_HELION_DEFAULT", "240"))
GATEWAY_MAX_TOKENS_DETAILED = int(os.getenv("GATEWAY_MAX_TOKENS_DETAILED", "900"))
|
|
|
|
|
|
def _apply_runtime_communication_guardrails(system_prompt: str, metadata: Dict[str, Any]) -> str:
|
|
"""Apply global communication constraints for all agents in Telegram flows."""
|
|
if not system_prompt:
|
|
return system_prompt
|
|
|
|
lang_label = (metadata or {}).get("preferred_response_language_label") or "user language"
|
|
guardrail = (
|
|
"\n\n[GLOBAL COMMUNICATION POLICY]\n"
|
|
"1) Do not introduce yourself by name in every message.\n"
|
|
"2) Do not add repetitive generic closers like 'how can I help' unless user explicitly asks.\n"
|
|
"3) Continue the dialog naturally from context.\n"
|
|
f"4) Respond in {lang_label}, matching the user's latest language.\n"
|
|
)
|
|
return system_prompt + guardrail
|
|
|
|
|
|
def _apply_agent_style_guardrails(agent_id: str, system_prompt: str) -> str:
|
|
"""Apply lightweight runtime style constraints for specific agents."""
|
|
if not system_prompt:
|
|
return system_prompt
|
|
|
|
if agent_id == "nutra":
|
|
nutra_guardrail = (
|
|
"\n\n[STYLE LOCK - NUTRA]\n"
|
|
"Always write in first-person singular and feminine form.\n"
|
|
"Use feminine wording in Ukrainian/Russian (e.g., 'я підготувала', 'я готова', "
|
|
"'я зрозуміла').\n"
|
|
"Never switch to masculine forms (e.g., 'понял', 'готов').\n"
|
|
)
|
|
return system_prompt + nutra_guardrail
|
|
|
|
return system_prompt
|
|
|
|
|
|
async def send_to_router(body: Dict[str, Any]) -> Dict[str, Any]:
    """Send an inference request to the DAGI Router and normalize its reply.

    Args:
        body: Gateway request payload. Recognized keys:
            - "agent": target agent id (defaults to "devtools").
            - "message": user prompt text.
            - "metadata": request metadata; mutated in place (correlation id
              is normalized under both "request_id" and "trace_id", and
              "agent_id" is stamped in) before being forwarded to the router.
            - "context": optional dict carrying "system_prompt" and "images".
            - "system_prompt": takes precedence over context["system_prompt"].

    Returns:
        Dict with "ok": True, a nested "data" payload (text + attachment
        fields), plus flattened top-level copies of the same fields for
        backward compatibility with older consumers.

    Raises:
        httpx.TimeoutException: when the router does not answer within
            ROUTER_TIMEOUT seconds.
        httpx.HTTPError: on transport failures or non-2xx responses.
    """
    _start = time.time()

    agent_id = body.get("agent", "devtools")
    message = body.get("message", "")
    # `or {}` (not a .get default) guards against keys explicitly set to None,
    # which would otherwise crash the .get() calls below; the helper
    # _apply_runtime_communication_guardrails defends the same way.
    metadata = body.get("metadata") or {}
    context = body.get("context") or {}

    system_prompt = body.get("system_prompt") or context.get("system_prompt")
    system_prompt = _apply_agent_style_guardrails(agent_id, system_prompt)
    system_prompt = _apply_runtime_communication_guardrails(system_prompt, metadata)

    # Propagate one correlation id under both keys so downstream services can
    # join logs regardless of which key they read.
    request_id = str(metadata.get("request_id") or metadata.get("trace_id") or "").strip()
    if request_id:
        metadata["request_id"] = request_id
        metadata["trace_id"] = request_id

    if system_prompt:
        logger.info("Using system prompt (%d chars) for agent %s", len(system_prompt), agent_id)

    infer_url = f"{ROUTER_BASE_URL}/v1/agents/{agent_id}/infer"
    metadata["agent_id"] = agent_id

    # Token budget selection. Keep defaults moderate to avoid overly long
    # replies while preserving flexibility. Order matters: agent-specific
    # defaults first, then training-group override, then force_detailed as a
    # floor, then force_concise as the final hard cap.
    max_tokens = GATEWAY_MAX_TOKENS_DEFAULT

    # Senpai tends to over-verbose responses in Telegram; use lower default
    # unless the user asked for details.
    if agent_id == "senpai":
        max_tokens = GATEWAY_MAX_TOKENS_SENPAI_DEFAULT
    elif agent_id == "helion":
        max_tokens = min(max_tokens, GATEWAY_MAX_TOKENS_HELION_DEFAULT)

    if metadata.get("is_training_group"):
        max_tokens = GATEWAY_MAX_TOKENS_TRAINING

    if metadata.get("force_detailed"):
        max_tokens = max(max_tokens, GATEWAY_MAX_TOKENS_DETAILED)

    if metadata.get("force_concise"):
        max_tokens = min(max_tokens, GATEWAY_MAX_TOKENS_CONCISE)

    infer_body = {
        "prompt": message,
        "system_prompt": system_prompt,
        "metadata": metadata,
        "max_tokens": max_tokens,
        "temperature": float(metadata.get("temperature_override", GATEWAY_TEMPERATURE_DEFAULT)),
    }

    images = context.get("images") or []
    if images:
        infer_body["images"] = images
        logger.info("Including %d image(s) in request", len(images))

    if metadata.get("provider"):
        infer_body["provider_override"] = metadata["provider"]

    logger.info(
        "Sending to Router (%s): agent=%s, provider=%s, has_images=%s, prompt_len=%d, max_tokens=%d",
        infer_url,
        agent_id,
        metadata.get("provider", "default"),
        bool(images),
        len(message),
        max_tokens,
    )

    try:
        async with httpx.AsyncClient(timeout=ROUTER_TIMEOUT) as client:
            headers = {"X-Request-Id": request_id} if request_id else None
            response = await client.post(infer_url, json=infer_body, headers=headers)
            response.raise_for_status()

        # Safe after the client closes: httpx reads the body during post().
        result = response.json()

        latency = time.time() - _start
        if METRICS_AVAILABLE:
            ROUTER_CALLS_TOTAL.labels(status="success").inc()
            ROUTER_LATENCY.observe(latency)

        logger.info("Router response in %.2fs", latency)

        # The router may answer under "response" or the legacy "text" key.
        text = result.get("response", result.get("text", ""))
        attachments = {
            "image_base64": result.get("image_base64"),
            "file_base64": result.get("file_base64"),
            "file_name": result.get("file_name"),
            "file_mime": result.get("file_mime"),
        }
        # Nested "data" is the current contract; the flattened top-level
        # copies keep older callers working.
        return {
            "ok": True,
            "data": {"text": text, **attachments},
            "response": text,
            "model": result.get("model"),
            "backend": result.get("backend"),
            **attachments,
        }

    except httpx.TimeoutException as e:
        if METRICS_AVAILABLE:
            ROUTER_CALLS_TOTAL.labels(status="timeout").inc()
            ERRORS_TOTAL.labels(type="timeout", source="router").inc()
        logger.error("Router request timeout after %.2fs: %s", time.time() - _start, e)
        raise

    except httpx.HTTPError as e:
        # TimeoutException is an HTTPError subclass, so the narrower handler
        # above must come first.
        if METRICS_AVAILABLE:
            ROUTER_CALLS_TOTAL.labels(status="error").inc()
            ERRORS_TOTAL.labels(type="http_error", source="router").inc()
        logger.error("Router request failed: %s", e)
        raise
|