"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects.

Each request message is a JSON-encoded ``JobRequest``. The worker deduplicates
by idempotency key, enforces the caller's deadline and a local concurrency
limit, dispatches to the provider for ``required_type`` and replies on
``msg.reply`` with a JSON-encoded ``JobResponse``.
"""
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, Optional

import config
from models import JobRequest, JobResponse, JobError
from idempotency import IdempotencyStore
from providers import ollama, ollama_vision
from providers import stt_mlx_whisper, tts_mlx_kokoro
import fabric_metrics as fm

logger = logging.getLogger("node-worker")

_idem = IdempotencyStore()
# One permit per concurrently executing job; a request that finds no free
# permit is answered "busy" instead of waiting (see _handle_request).
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)
_nats_client = None  # set by start(); used by _reply()
_inflight_count: int = 0  # jobs between the busy check and the end of _execute
_latencies_llm: list = []  # recent successful llm/code latencies in ms
_latencies_vision: list = []  # recent successful latencies (ms) for every other type
_LATENCY_BUFFER = 50  # max samples kept in each latency list
# Strong references to fire-and-forget tasks. The event loop only holds weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_background_tasks: set = set()


async def start(nats_client):
    """Subscribe ``_handle_request`` to every per-node request subject.

    Args:
        nats_client: connected NATS client; stored for ``_reply`` and used to
            create the subscriptions.
    """
    global _nats_client
    _nats_client = nats_client
    nid = config.NODE_ID.lower()
    subjects = [
        f"node.{nid}.llm.request",
        f"node.{nid}.vision.request",
        f"node.{nid}.stt.request",
        f"node.{nid}.tts.request",
        f"node.{nid}.image.request",
        f"node.{nid}.ocr.request",
    ]
    for subj in subjects:
        await nats_client.subscribe(subj, cb=_handle_request)
        logger.info(f"✅ Subscribed: {subj}")


def _job_failure(job: JobRequest, status: str, code: str, message: Optional[str] = None) -> JobResponse:
    """Build a non-ok ``JobResponse`` addressed to *job*.

    ``message`` is forwarded to ``JobError`` only when given, so ``JobError``'s
    own default applies for responses that carry just an error code.
    """
    error = JobError(code=code) if message is None else JobError(code=code, message=message)
    return JobResponse(
        job_id=job.job_id,
        trace_id=job.trace_id,
        node_id=config.NODE_ID,
        status=status,
        error=error,
    )


async def _handle_request(msg):
    """Process one job request message and reply with a ``JobResponse``.

    Flow: size check -> parse -> idempotency cache -> deadline check -> join an
    identical in-flight job or claim it -> concurrency check -> execute ->
    cache, metrics, reply.

    Unexpected exceptions are logged and answered with an ``INTERNAL`` error.
    If this handler had claimed the idempotency key, the in-flight entry is
    completed with that error. Duplicates waiting on it are then released
    instead of hanging until their own deadline.
    """
    global _inflight_count
    t0 = time.monotonic()  # monotonic: latency must not jump with wall-clock changes
    job = None
    idem_key = None
    owns_inflight = False  # True while this handler must complete the in-flight entry
    try:
        raw = msg.data
        if len(raw) > config.MAX_PAYLOAD_BYTES:
            await _reply(msg, JobResponse(
                status="error",
                node_id=config.NODE_ID,
                error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
            ))
            return

        job = JobRequest(**json.loads(raw))
        job.trace_id = job.trace_id or job.job_id
        idem_key = job.effective_idem_key()

        cached = _idem.get(idem_key)
        if cached:
            logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
            await _reply(msg, cached)
            return

        remaining = job.remaining_ms()
        if remaining <= 0:
            resp = _job_failure(job, "timeout", "DEADLINE_EXCEEDED", "deadline already passed")
            _idem.put(idem_key, resp)
            await _reply(msg, resp)
            return

        inflight = await _idem.acquire_inflight(idem_key)
        if inflight is not None:
            # An identical job is already running: reuse its result.
            await _await_inflight(msg, job, inflight, remaining)
            return
        owns_inflight = True

        # Reject instead of queueing when every permit is taken.
        if _semaphore.locked() and _semaphore._value == 0:
            resp = _job_failure(job, "busy", "CONCURRENCY_LIMIT", f"max {config.MAX_CONCURRENCY}")
            owns_inflight = False
            _idem.complete_inflight(idem_key, resp)
            await _reply(msg, resp)
            return

        _inflight_count += 1
        fm.set_inflight(_inflight_count)
        try:
            async with _semaphore:
                resp = await _execute(job, remaining)
        finally:
            _inflight_count -= 1
            fm.set_inflight(_inflight_count)

        _idem.put(idem_key, resp)
        owns_inflight = False  # cleared first so a failing complete is never retried
        _idem.complete_inflight(idem_key, resp)

        resp.latency_ms = int((time.monotonic() - t0) * 1000)
        # Metrics are best-effort: a metrics failure must not turn a finished
        # job into an INTERNAL error reply.
        try:
            _record_metrics(job.required_type, resp)
        except Exception:
            logger.exception(f"[job.metrics] failed to record metrics job={job.job_id}")
        await _reply(msg, resp)
    except Exception as e:
        logger.exception(f"Worker handler error: {e}")
        try:
            error = JobError(code="INTERNAL", message=str(e)[:200])
            if job is not None:
                err_resp = JobResponse(
                    job_id=job.job_id,
                    trace_id=job.trace_id,
                    node_id=config.NODE_ID,
                    status="error",
                    error=error,
                )
            else:
                err_resp = JobResponse(node_id=config.NODE_ID, status="error", error=error)
            if owns_inflight:
                # Release duplicates waiting on this key.
                try:
                    _idem.complete_inflight(idem_key, err_resp)
                except Exception:
                    logger.exception(f"[job.inflight] failed to release idem={idem_key}")
            await _reply(msg, err_resp)
        except Exception as reply_err:
            logger.warning(f"Worker failed to send error reply: {reply_err}")


async def _await_inflight(msg, job: JobRequest, inflight, remaining_ms: int) -> None:
    """Reply to a duplicate request with the result of the identical running job.

    ``inflight`` is the awaitable returned by ``IdempotencyStore.acquire_inflight``.
    It is presumably shared by every duplicate (TODO confirm in idempotency.py).
    ``asyncio.wait_for`` cancels what it awaits when it times out, so the
    awaitable is shielded. Otherwise one timed-out duplicate would cancel the
    result for the owner and every other waiter.
    """
    try:
        resp = await asyncio.wait_for(asyncio.shield(inflight), timeout=remaining_ms / 1000.0)
    except asyncio.TimeoutError:
        await _reply(msg, _job_failure(job, "timeout", "INFLIGHT_TIMEOUT"))
        return
    # Copy so the shared response object is not marked as cached.
    resp_copy = resp.model_copy()
    resp_copy.cached = True
    await _reply(msg, resp_copy)


def _record_metrics(req_type: str, resp: JobResponse) -> None:
    """Count the job and, for successful jobs, record and report its latency."""
    fm.inc_job(req_type, resp.status)
    if resp.status != "ok" or resp.latency_ms <= 0:
        return
    fm.observe_latency(req_type, resp.model or "?", resp.latency_ms)
    # llm/code share one buffer; every other type (vision, stt, tts, ocr, ...)
    # is recorded in the vision buffer.
    buf = _latencies_llm if req_type in ("llm", "code") else _latencies_vision
    buf.append(resp.latency_ms)
    if len(buf) > _LATENCY_BUFFER:
        del buf[:len(buf) - _LATENCY_BUFFER]
    _report_latency_async(req_type, resp.provider or "ollama", resp.latency_ms)


async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
    """Run *job* on the provider that matches ``job.required_type``.

    Each provider call is bounded by ``min(remaining_ms, 120 s)``. Provider
    timeouts and provider exceptions become ``timeout``/``error`` responses.

    Args:
        job: parsed request; ``payload`` and ``hints`` are read with ``.get``.
        remaining_ms: time left until the caller's deadline.

    Returns:
        ``status="ok"`` with the provider result, or a ``timeout``/``error``
        response describing why the job did not run or failed.
    """
    payload = job.payload
    hints = job.hints
    timeout_s = min(remaining_ms / 1000.0, 120.0)
    prefer_models = hints.get("prefer_models")
    model = prefer_models[0] if prefer_models else payload.get("model", "")
    # `or` guards: an explicit null must not crash this log line, which runs
    # outside the provider try-block below.
    msg_count = len(payload.get("messages") or [])
    prompt_chars = len(payload.get("prompt") or "")
    logger.info(
        f"[job.start] job={job.job_id} trace={job.trace_id} "
        f"type={job.required_type} model={model or '?'} "
        f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
    )
    try:
        if job.required_type == "llm":
            result = await asyncio.wait_for(
                ollama.infer(
                    messages=payload.get("messages"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
                    temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "vision":
            result = await asyncio.wait_for(
                ollama_vision.infer(
                    images=payload.get("images"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", 1024),
                    temperature=hints.get("temperature", 0.2),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "stt":
            if config.STT_PROVIDER == "none":
                return _job_failure(job, "error", "NOT_AVAILABLE", "STT not configured on this node")
            result = await asyncio.wait_for(
                stt_mlx_whisper.transcribe(payload),
                timeout=timeout_s,
            )
        elif job.required_type == "tts":
            if config.TTS_PROVIDER == "none":
                return _job_failure(job, "error", "NOT_AVAILABLE", "TTS not configured on this node")
            result = await asyncio.wait_for(
                tts_mlx_kokoro.synthesize(payload),
                timeout=timeout_s,
            )
        elif job.required_type == "ocr":
            if config.OCR_PROVIDER == "none":
                return _job_failure(job, "error", "NOT_AVAILABLE", "OCR not configured on this node")
            # OCR is implemented by prompting the vision model.
            ocr_prompt = payload.get(
                "prompt",
                "Extract all text from this image. Return JSON: {\"text\": \"...\", \"language\": \"...\"}",
            )
            result = await asyncio.wait_for(
                ollama_vision.infer(
                    images=payload.get("images"),
                    prompt=ocr_prompt,
                    model=model or config.DEFAULT_VISION,
                    system="You are an OCR engine. Extract text precisely. Return valid JSON only.",
                    max_tokens=hints.get("max_tokens", 4096),
                    temperature=0.05,
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
            result["provider"] = "vision_prompted_ocr"
        elif job.required_type == "image":
            return _job_failure(job, "error", "NOT_YET_IMPLEMENTED", "Image adapter pending P3.7")
        else:
            return _job_failure(job, "error", "UNSUPPORTED_TYPE", f"{job.required_type} not supported")

        logger.info(
            f"[job.done] job={job.job_id} status=ok "
            f"provider={result.get('provider')} model={result.get('model')}"
        )
        return JobResponse(
            job_id=job.job_id,
            trace_id=job.trace_id,
            node_id=config.NODE_ID,
            status="ok",
            provider=result.get("provider", ""),
            model=result.get("model", ""),
            result=result,
        )
    except asyncio.TimeoutError:
        logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
        return _job_failure(job, "timeout", "PROVIDER_TIMEOUT")
    except Exception as e:
        logger.warning(f"[job.error] job={job.job_id} {e}")
        return _job_failure(job, "error", "PROVIDER_ERROR", str(e)[:300])


def get_metrics() -> Dict[str, Any]:
    """Return a snapshot of worker load and recent successful-job latencies."""
    return {
        "inflight_jobs": _inflight_count,
        "concurrency_limit": config.MAX_CONCURRENCY,
        # Always 0: over-limit requests are rejected as busy, not queued.
        "queue_depth": 0,
        # Slicing already returns a new list, so callers get a snapshot.
        "last_latencies_llm": _latencies_llm[-_LATENCY_BUFFER:],
        "last_latencies_vision": _latencies_vision[-_LATENCY_BUFFER:],
    }


def _report_latency_async(req_type: str, runtime: str, latency_ms: int):
    """Fire-and-forget latency report to local NCS.

    POSTs ``{"runtime", "type", "latency_ms"}`` to
    ``$NCS_REPORT_URL/capabilities/report_latency`` (default
    ``http://node-capabilities:8099``) with a 1 s client timeout.

    This is best-effort: a missing ``httpx``, no running event loop or any
    HTTP failure is swallowed, so reporting can never affect the job reply.
    """
    try:
        import httpx as _httpx
    except ImportError:
        logger.debug("httpx not installed; skipping latency report")
        return
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return  # no running loop to schedule the report on
    ncs_url = os.getenv("NCS_REPORT_URL", "http://node-capabilities:8099")

    async def _do():
        try:
            async with _httpx.AsyncClient(timeout=1) as c:
                await c.post(f"{ncs_url}/capabilities/report_latency", json={
                    "runtime": runtime,
                    "type": req_type,
                    "latency_ms": latency_ms,
                })
        except Exception as exc:
            # Best-effort: NCS being unreachable must not surface anywhere.
            logger.debug(f"latency report failed: {exc}")

    task = loop.create_task(_do())
    # Keep a strong reference until the task finishes (see _background_tasks).
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


async def _reply(msg, resp: JobResponse):
    """Publish *resp* as JSON on the message's reply subject, if it has one."""
    if msg.reply:
        await _nats_client.publish(msg.reply, resp.model_dump_json().encode())