"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict

import config
from models import JobRequest, JobResponse, JobError
from idempotency import IdempotencyStore
from providers import ollama, swapper_vision
import fabric_metrics as fm

logger = logging.getLogger("node-worker")

_idem = IdempotencyStore()
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)
_nats_client = None
_inflight_count: int = 0
_latencies_llm: list = []
_latencies_vision: list = []
_LATENCY_BUFFER = 50
# Upper bound (seconds) for a single provider call, regardless of the job deadline.
_MAX_PROVIDER_TIMEOUT_S = 120.0
# Strong references to fire-and-forget tasks: the event loop only keeps weak
# references, so an unreferenced task can be garbage-collected mid-execution.
_background_tasks: set = set()


async def start(nats_client):
    """Remember the NATS client and subscribe the request handler.

    Subscribes ``_handle_request`` to this node's LLM and vision request
    subjects (``node.<NODE_ID lowercased>.<type>.request``).
    """
    global _nats_client
    _nats_client = nats_client
    node = config.NODE_ID.lower()
    subjects = [
        f"node.{node}.llm.request",
        f"node.{node}.vision.request",
    ]
    for subj in subjects:
        await nats_client.subscribe(subj, cb=_handle_request)
        logger.info(f"✅ Subscribed: {subj}")


async def _handle_request(msg):
    """Process one job request message and reply on ``msg.reply``.

    Pipeline: payload size check -> parse -> idempotency cache lookup ->
    deadline check -> in-flight de-duplication -> concurrency check ->
    execute -> cache result, record metrics, reply.

    Any ``Exception`` is logged and answered with an ``INTERNAL`` error
    response; if this handler owned the in-flight slot for the job, the slot
    is completed with that error so duplicate requests waiting on it are
    released instead of hanging until their deadline.
    """
    global _inflight_count
    t0 = time.monotonic()
    job = None
    idem_key = None
    owns_inflight = False  # True while duplicates may be waiting on our result
    try:
        raw = msg.data
        if len(raw) > config.MAX_PAYLOAD_BYTES:
            await _reply(msg, JobResponse(
                status="error",
                node_id=config.NODE_ID,
                error=JobError(code="PAYLOAD_TOO_LARGE",
                               message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
            ))
            return

        data = json.loads(raw)
        job = JobRequest(**data)
        job.trace_id = job.trace_id or job.job_id
        idem_key = job.effective_idem_key()

        cached = _idem.get(idem_key)
        if cached:
            logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
            await _reply(msg, cached)
            return

        remaining = job.remaining_ms()
        if remaining <= 0:
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="timeout",
                error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"),
            )
            # NOTE(review): caching this means a retry with the same idempotency
            # key gets this timeout even if it carries a fresh deadline — confirm intended.
            _idem.put(idem_key, resp)
            await _reply(msg, resp)
            return

        inflight = await _idem.acquire_inflight(idem_key)
        if inflight is not None:
            # Another handler is already running this job: relay its result.
            await _reply_from_inflight(msg, job, inflight, remaining)
            return
        owns_inflight = True

        # Nothing awaits between this check and ``async with _semaphore`` below,
        # and an unlocked asyncio.Semaphore is acquired without suspending, so
        # the public ``locked()`` check is race-free here.
        if _semaphore.locked():
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="busy",
                error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"),
            )
            # Release waiters but do not cache: "busy" is transient.
            _idem.complete_inflight(idem_key, resp)
            owns_inflight = False
            await _reply(msg, resp)
            return

        _inflight_count += 1
        fm.set_inflight(_inflight_count)
        try:
            async with _semaphore:
                resp = await _execute(job, remaining)
        finally:
            _inflight_count -= 1
            fm.set_inflight(_inflight_count)

        _idem.put(idem_key, resp)
        _idem.complete_inflight(idem_key, resp)
        owns_inflight = False
        resp.latency_ms = int((time.monotonic() - t0) * 1000)
        _record_metrics(job, resp)
        await _reply(msg, resp)
    except Exception as e:
        logger.exception(f"Worker handler error: {e}")
        err_resp = _internal_error_response(job, str(e)[:200])
        if owns_inflight:
            # Without this, duplicates blocked on our in-flight slot would
            # wait until their own deadline expires.
            try:
                _idem.complete_inflight(idem_key, err_resp)
            except Exception:
                logger.exception("Failed to release in-flight slot")
        try:
            await _reply(msg, err_resp)
        except Exception:
            logger.debug("Failed to send error reply", exc_info=True)


def _internal_error_response(job, message: str) -> JobResponse:
    """Build an ``INTERNAL`` error response, tagged with the job's ids when known."""
    ids: Dict[str, Any] = {}
    if job is not None:
        ids["job_id"] = job.job_id
        if job.trace_id:
            ids["trace_id"] = job.trace_id
    return JobResponse(
        node_id=config.NODE_ID,
        status="error",
        error=JobError(code="INTERNAL", message=message),
        **ids,
    )


async def _reply_from_inflight(msg, job: JobRequest, inflight, remaining_ms: int):
    """Wait up to ``remaining_ms`` for a duplicate in-flight job and relay its result.

    The relayed response is a copy with ``cached=True``; on timeout an
    ``INFLIGHT_TIMEOUT`` response is sent instead.
    """
    try:
        # wait_for cancels what it awaits on timeout; shield keeps that from
        # cancelling the shared in-flight future for the owner and other waiters.
        resp = await asyncio.wait_for(asyncio.shield(inflight), timeout=remaining_ms / 1000.0)
    except asyncio.TimeoutError:
        await _reply(msg, JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="timeout",
            error=JobError(code="INFLIGHT_TIMEOUT"),
        ))
        return
    resp_copy = resp.model_copy()
    resp_copy.cached = True
    await _reply(msg, resp_copy)


def _record_metrics(job: JobRequest, resp: JobResponse):
    """Count the finished job; for successful jobs also record its latency.

    Latency goes to fabric metrics, the bounded in-process buffers exposed by
    ``get_metrics`` (llm/code share one buffer, everything else uses the
    vision buffer), and a best-effort report to the local NCS.
    """
    fm.inc_job(job.required_type, resp.status)
    if resp.status != "ok" or resp.latency_ms <= 0:
        return
    fm.observe_latency(job.required_type, resp.model or "?", resp.latency_ms)
    buf = _latencies_llm if job.required_type in ("llm", "code") else _latencies_vision
    buf.append(resp.latency_ms)
    if len(buf) > _LATENCY_BUFFER:
        del buf[:len(buf) - _LATENCY_BUFFER]
    _report_latency_async(job.required_type, resp.provider or "ollama", resp.latency_ms)


async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
    """Run the job on the provider for its ``required_type``.

    Supports ``llm`` (ollama) and ``vision`` (swapper_vision). The provider
    call is bounded by the job's remaining deadline, capped at
    ``_MAX_PROVIDER_TIMEOUT_S``. Provider timeouts and errors are converted
    into ``timeout``/``error`` responses rather than raised.
    """
    payload = job.payload or {}
    hints = job.hints or {}
    timeout_s = min(remaining_ms / 1000.0, _MAX_PROVIDER_TIMEOUT_S)
    preferred = hints.get("prefer_models")
    model = preferred[0] if preferred else payload.get("model", "")
    # ``or`` guards against keys that are present but explicitly null.
    msg_count = len(payload.get("messages") or [])
    prompt_chars = len(payload.get("prompt") or "")
    logger.info(
        f"[job.start] job={job.job_id} trace={job.trace_id} "
        f"type={job.required_type} model={model or '?'} "
        f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
    )
    try:
        if job.required_type == "llm":
            result = await asyncio.wait_for(
                ollama.infer(
                    messages=payload.get("messages"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
                    temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "vision":
            result = await asyncio.wait_for(
                swapper_vision.infer(
                    images=payload.get("images"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", 1024),
                    temperature=hints.get("temperature", 0.2),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        else:
            return JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="error",
                error=JobError(code="UNSUPPORTED_TYPE", message=f"{job.required_type} not implemented"),
            )

        logger.info(
            f"[job.done] job={job.job_id} status=ok "
            f"provider={result.get('provider')} model={result.get('model')}"
        )
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="ok",
            provider=result.get("provider", ""),
            model=result.get("model", ""),
            result=result,
        )
    except asyncio.TimeoutError:
        logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="timeout",
            error=JobError(code="PROVIDER_TIMEOUT"),
        )
    except Exception as e:
        logger.warning(f"[job.error] job={job.job_id} {e}")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="error",
            error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]),
        )


def get_metrics() -> Dict[str, Any]:
    """Return a snapshot of worker load and recent successful-job latencies (ms)."""
    return {
        "inflight_jobs": _inflight_count,
        "concurrency_limit": config.MAX_CONCURRENCY,
        "queue_depth": 0,
        "last_latencies_llm": list(_latencies_llm[-_LATENCY_BUFFER:]),
        "last_latencies_vision": list(_latencies_vision[-_LATENCY_BUFFER:]),
    }


def _report_latency_async(req_type: str, runtime: str, latency_ms: int):
    """Fire-and-forget latency report to local NCS.

    Best effort: does nothing when no event loop is running, and any failure
    inside the report (including ``httpx`` being unavailable) is swallowed so
    telemetry can never affect the job reply.
    """
    ncs_url = os.getenv("NCS_REPORT_URL", "http://node-capabilities:8099")

    async def _do():
        try:
            import httpx as _httpx
            async with _httpx.AsyncClient(timeout=1) as c:
                await c.post(f"{ncs_url}/capabilities/report_latency", json={
                    "runtime": runtime,
                    "type": req_type,
                    "latency_ms": latency_ms,
                })
        except Exception:
            pass

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return  # no running loop: skip rather than create an un-awaited coroutine
    task = loop.create_task(_do())
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


async def _reply(msg, resp: JobResponse):
    """Publish ``resp`` as JSON to the message's reply subject, if it has one."""
    if msg.reply:
        await _nats_client.publish(msg.reply, resp.model_dump_json().encode())