"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects.""" import asyncio import json import logging import time from typing import Any, Dict import config from models import JobRequest, JobResponse, JobError from idempotency import IdempotencyStore from providers import ollama, swapper_vision logger = logging.getLogger("node-worker") _idem = IdempotencyStore() _semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY) _nats_client = None async def start(nats_client): global _nats_client _nats_client = nats_client subjects = [ f"node.{config.NODE_ID.lower()}.llm.request", f"node.{config.NODE_ID.lower()}.vision.request", ] for subj in subjects: await nats_client.subscribe(subj, cb=_handle_request) logger.info(f"✅ Subscribed: {subj}") async def _handle_request(msg): t0 = time.time() try: raw = msg.data if len(raw) > config.MAX_PAYLOAD_BYTES: await _reply(msg, JobResponse( status="error", node_id=config.NODE_ID, error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"), )) return data = json.loads(raw) job = JobRequest(**data) job.trace_id = job.trace_id or job.job_id idem_key = job.effective_idem_key() cached = _idem.get(idem_key) if cached: logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}") await _reply(msg, cached) return remaining = job.remaining_ms() if remaining <= 0: resp = JobResponse( job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, status="timeout", error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"), ) _idem.put(idem_key, resp) await _reply(msg, resp) return inflight = await _idem.acquire_inflight(idem_key) if inflight is not None: try: resp = await asyncio.wait_for(inflight, timeout=remaining / 1000.0) resp_copy = resp.model_copy() resp_copy.cached = True await _reply(msg, resp_copy) except asyncio.TimeoutError: await _reply(msg, JobResponse( job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"), )) return if _semaphore.locked() and _semaphore._value == 0: resp = JobResponse( job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, status="busy", error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"), ) _idem.complete_inflight(idem_key, resp) await _reply(msg, resp) return async with _semaphore: resp = await _execute(job, remaining) _idem.put(idem_key, resp) _idem.complete_inflight(idem_key, resp) resp.latency_ms = int((time.time() - t0) * 1000) await _reply(msg, resp) except Exception as e: logger.exception(f"Worker handler error: {e}") try: await _reply(msg, JobResponse( node_id=config.NODE_ID, status="error", error=JobError(code="INTERNAL", message=str(e)[:200]), )) except Exception: pass async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse: payload = job.payload hints = job.hints timeout_s = min(remaining_ms / 1000.0, 120.0) model = hints.get("prefer_models", [None])[0] if hints.get("prefer_models") else payload.get("model", "") msg_count = len(payload.get("messages", [])) prompt_chars = len(payload.get("prompt", "")) logger.info( f"[job.start] job={job.job_id} trace={job.trace_id} " f"type={job.required_type} model={model or '?'} " f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms" ) try: if job.required_type == "llm": result = await asyncio.wait_for( ollama.infer( messages=payload.get("messages"), prompt=payload.get("prompt", ""), model=model, system=payload.get("system", ""), max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)), temperature=hints.get("temperature", payload.get("temperature", 0.2)), timeout_s=timeout_s, ), timeout=timeout_s, ) elif job.required_type == "vision": result = await asyncio.wait_for( swapper_vision.infer( images=payload.get("images"), prompt=payload.get("prompt", ""), model=model, system=payload.get("system", ""), max_tokens=hints.get("max_tokens", 1024), temperature=hints.get("temperature", 0.2), timeout_s=timeout_s, ), timeout=timeout_s, ) else: return JobResponse( job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, status="error", error=JobError(code="UNSUPPORTED_TYPE", message=f"{job.required_type} not implemented"), ) logger.info( f"[job.done] job={job.job_id} status=ok " f"provider={result.get('provider')} model={result.get('model')}" ) return JobResponse( job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, status="ok", provider=result.get("provider", ""), model=result.get("model", ""), result=result, ) except asyncio.TimeoutError: logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s") return JobResponse( job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, status="timeout", error=JobError(code="PROVIDER_TIMEOUT"), ) except Exception as e: logger.warning(f"[job.error] job={job.job_id} {e}") return JobResponse( job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, status="error", error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]), ) async def _reply(msg, resp: JobResponse): if msg.reply: await _nats_client.publish(msg.reply, resp.model_dump_json().encode())