"""Gateway worker: dequeues job ids from Redis and executes each job via the router service."""
import asyncio
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict
from urllib.parse import quote

import httpx

from .redis_jobs import close_redis, dequeue_job, get_job, update_job, wait_for_redis
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("daarion-gateway-worker")

# Base URL of the router service. ROUTER_BASE_URL wins, ROUTER_URL is the
# fallback, then the built-in default. `or` (rather than getenv's default
# argument) makes a variable that is set but blank fall through too, instead
# of producing an unusable "" base URL.
ROUTER_BASE_URL = os.getenv("ROUTER_BASE_URL") or os.getenv("ROUTER_URL") or "http://router:8000"

# Timeout (seconds) handed to httpx.AsyncClient for each router call. A blank
# value falls back to the default instead of crashing at import on float("").
ROUTER_TIMEOUT_SECONDS = float(os.getenv("ROUTER_WORKER_TIMEOUT") or "60")
|
|
|
|
|
|
def _now() -> str:
|
|
return datetime.now(timezone.utc).isoformat()
|
|
|
|
|
|
async def _call_router(agent_id: str, input_payload: Dict[str, Any], metadata: Dict[str, Any]) -> Dict[str, Any]:
    """POST a job's input to the router's per-agent inference endpoint.

    Args:
        agent_id: Agent identifier. It is percent-encoded into the URL path,
            so it cannot change which router endpoint is hit.
        input_payload: Job input. ``prompt`` (default ``""``) is always sent.
            ``images`` is sent only when present and non-empty.
        metadata: Forwarded as-is; a falsy value is sent as ``{}``.

    Returns:
        A dict with ``response`` (default ``""``), ``model``, ``backend`` and
        ``tokens_used`` copied from the router's JSON reply (``None`` when absent).

    Raises:
        httpx.HTTPStatusError: The router replied with an error status
            (via ``raise_for_status``).
        httpx.HTTPError: Transport errors or timeouts
            (``ROUTER_TIMEOUT_SECONDS``).
        ValueError: The reply body is not a JSON object.
    """
    body: Dict[str, Any] = {
        "prompt": input_payload.get("prompt", ""),
        "metadata": metadata or {},
    }
    images = input_payload.get("images") or []
    if images:
        body["images"] = images

    # Encode every reserved character (safe="") so an id containing "/", "?"
    # or "#" stays a single path segment. Strip a trailing slash from the
    # configured base so the URL never contains "//v1".
    base_url = ROUTER_BASE_URL.rstrip("/")
    encoded_agent = quote(str(agent_id), safe="")
    url = f"{base_url}/v1/agents/{encoded_agent}/infer"

    async with httpx.AsyncClient(timeout=ROUTER_TIMEOUT_SECONDS) as client:
        resp = await client.post(url, json=body)
        resp.raise_for_status()
        data = resp.json()

    if not isinstance(data, dict):
        # Without this check, data.get below would fail with an opaque AttributeError.
        raise ValueError(f"router returned non-object JSON: {type(data).__name__}")

    return {
        "response": data.get("response", ""),
        "model": data.get("model"),
        "backend": data.get("backend"),
        "tokens_used": data.get("tokens_used"),
    }
|
|
|
|
|
|
async def run_once(job_id: str) -> None:
    """Process one job: load it, mark it running, call the router, store the outcome.

    Args:
        job_id: Identifier of the job popped from the queue by the worker loop.

    On success the job gets ``status="done"`` with the router result. Any
    exception raised while calling the router or saving that result is caught.
    The job is then given ``status="failed"`` plus the exception's type and
    message, and the exception is not re-raised.

    Exceptions from ``get_job``, from the initial ``"running"`` update, or
    from the ``"failed"`` update itself propagate to the caller.
    """
    job = await get_job(job_id)
    if not job:
        # Nothing to update: the job record could not be loaded.
        logger.warning("job_missing: %s", job_id)
        return

    # One timestamp per transition, so started_at/updated_at always agree.
    started = _now()
    await update_job(job_id, {"status": "running", "started_at": started, "updated_at": started})

    agent_id = job.get("agent_id")
    input_payload = job.get("input") or {}
    metadata = job.get("metadata") or {}

    try:
        if not agent_id:
            # Otherwise the router would be called at ".../agents/None/infer".
            # Fail the job with an explicit, actionable message instead.
            raise ValueError(f"job {job_id} has no agent_id")
        result = await _call_router(agent_id, input_payload, metadata)
        finished = _now()
        await update_job(
            job_id,
            {
                "status": "done",
                "result": result,
                "error": None,
                "finished_at": finished,
                "updated_at": finished,
            },
        )
        logger.info("job_done: %s agent=%s", job_id, agent_id)
    except Exception as e:
        # Log first: if recording the failure below also raises, the original
        # error has already been written out with its traceback.
        logger.exception("job_failed: %s agent=%s", job_id, agent_id)
        finished = _now()
        await update_job(
            job_id,
            {
                "status": "failed",
                "error": {"type": e.__class__.__name__, "message": str(e)},
                "finished_at": finished,
                "updated_at": finished,
            },
        )
|
|
|
|
|
|
async def worker_loop() -> None:
    """Run forever, pulling job ids off the queue and executing them one by one.

    First awaits ``wait_for_redis(60)``. Then each iteration asks ``dequeue_job``
    for the next id (``block_seconds=10``) and hands any non-empty id to
    ``run_once``. Cancellation is re-raised so the worker can be shut down.
    Every other exception is logged, followed by a one-second pause so a
    persistent fault does not turn into a hot loop.
    """
    await wait_for_redis(60)
    logger.info("worker_started router=%s", ROUTER_BASE_URL)

    while True:
        try:
            next_job_id = await dequeue_job(block_seconds=10)
            if next_job_id:
                await run_once(next_job_id)
        except asyncio.CancelledError:
            raise
        except Exception:
            logger.exception("worker_loop_error")
            await asyncio.sleep(1)
|
|
|
|
|
|
async def _main() -> None:
    """Run the worker loop, then close the Redis connection on the same event loop.

    Cleanup happens inside the loop that ran the worker rather than in a second
    ``asyncio.run``. A brand-new event loop cannot reliably close async clients
    whose connections belong to the previous, already-closed loop.
    NOTE(review): assumes redis_jobs holds loop-bound async connections; confirm.
    """
    try:
        await worker_loop()
    finally:
        try:
            await close_redis()
        except Exception:
            # Best-effort during shutdown, but leave a trace instead of
            # silently discarding the error.
            logger.warning("close_redis_failed", exc_info=True)


if __name__ == "__main__":
    asyncio.run(_main())
|