"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict

import config
from models import JobRequest, JobResponse, JobError
from idempotency import IdempotencyStore
from providers import ollama, ollama_vision
from providers import stt_mlx_whisper, tts_mlx_kokoro
from providers import stt_memory_service, tts_memory_service
import fabric_metrics as fm

logger = logging.getLogger("node-worker")

_idem = IdempotencyStore()
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)

# Voice-dedicated semaphores — independent from generic MAX_CONCURRENCY.
# Prevents voice requests from starving generic inference and vice versa.
_voice_sem_tts: asyncio.Semaphore = asyncio.Semaphore(config.VOICE_MAX_CONCURRENT_TTS)
_voice_sem_llm: asyncio.Semaphore = asyncio.Semaphore(config.VOICE_MAX_CONCURRENT_LLM)
_voice_sem_stt: asyncio.Semaphore = asyncio.Semaphore(config.VOICE_MAX_CONCURRENT_STT)
_VOICE_SEMAPHORES = {
    "voice.tts": _voice_sem_tts,
    "voice.llm": _voice_sem_llm,
    "voice.stt": _voice_sem_stt,
}

_nats_client = None
_inflight_count: int = 0
_voice_inflight: Dict[str, int] = {"voice.tts": 0, "voice.llm": 0, "voice.stt": 0}
_latencies_llm: list = []
_latencies_vision: list = []
_LATENCY_BUFFER = 50

# Set of subjects that use the voice handler path
_VOICE_SUBJECTS: set = set()

# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_background_tasks: set = set()


def _job_response(job: JobRequest, status: str, code: str, message=None) -> JobResponse:
    """Build a non-ok JobResponse carrying *job*'s ids and this node's id.

    ``message`` is only forwarded to ``JobError`` when given, so JobError's own
    default applies otherwise (matching calls that pass just ``code``).
    """
    error = JobError(code=code) if message is None else JobError(code=code, message=message)
    return JobResponse(
        job_id=job.job_id,
        trace_id=job.trace_id,
        node_id=config.NODE_ID,
        status=status,
        error=error,
    )


def _voice_limit(cap_key: str):
    """Return the configured concurrency limit for a voice cap key, or "?"."""
    return {
        "voice.tts": config.VOICE_MAX_CONCURRENT_TTS,
        "voice.llm": config.VOICE_MAX_CONCURRENT_LLM,
        "voice.stt": config.VOICE_MAX_CONCURRENT_STT,
    }.get(cap_key, "?")


def _make_voice_handler(voice_sem: asyncio.Semaphore, cap_key: str):
    """Return a NATS callback bound to one voice capability's semaphore and key."""
    async def _voice_handler(msg):
        await _handle_voice_request(msg, voice_sem=voice_sem, cap_key=cap_key)
    return _voice_handler


async def start(nats_client):
    """Register this worker's subscriptions on *nats_client*.

    Generic subjects ``node.{nid}.{type}.request`` are always subscribed and
    routed to ``_handle_request``. Voice HA subjects
    ``node.{nid}.voice.{tts|llm|stt}.request`` are routed to
    ``_handle_voice_request`` with their own semaphore. The TTS/STT voice
    subjects are skipped when the matching provider is configured as "none".
    """
    global _nats_client
    _nats_client = nats_client
    nid = config.NODE_ID.lower()

    # Generic subjects (unchanged — backward compatible)
    subjects = [
        f"node.{nid}.llm.request",
        f"node.{nid}.vision.request",
        f"node.{nid}.stt.request",
        f"node.{nid}.tts.request",
        f"node.{nid}.image.request",
        f"node.{nid}.ocr.request",
    ]
    for subj in subjects:
        await nats_client.subscribe(subj, cb=_handle_request)
        logger.info(f"✅ Subscribed: {subj}")

    # Voice HA subjects — separate semaphores, own metrics, own deadlines.
    # Only subscribe if the relevant provider is configured (preflight-first).
    voice_subjects_to_caps = {
        f"node.{nid}.voice.tts.request": ("tts", _voice_sem_tts, "voice.tts"),
        f"node.{nid}.voice.llm.request": ("llm", _voice_sem_llm, "voice.llm"),
        f"node.{nid}.voice.stt.request": ("stt", _voice_sem_stt, "voice.stt"),
    }
    for subj, (required_cap, sem, cap_key) in voice_subjects_to_caps.items():
        if required_cap == "tts" and config.TTS_PROVIDER == "none":
            logger.info(f"⏭ Skipping {subj}: TTS_PROVIDER=none")
            continue
        if required_cap == "stt" and config.STT_PROVIDER == "none":
            logger.info(f"⏭ Skipping {subj}: STT_PROVIDER=none")
            continue
        # LLM always available on this node
        _VOICE_SUBJECTS.add(subj)
        await nats_client.subscribe(subj, cb=_make_voice_handler(sem, cap_key))
        logger.info(f"✅ Voice subscribed: {subj}")


async def _handle_request(msg):
    """Handle one generic ``node.{nid}.{type}.request`` message.

    Steps:
    1. Check payload size.
    2. Parse the ``JobRequest``.
    3. Look up the idempotency cache.
    4. Check the deadline.
    5. Deduplicate in-flight work by waiting on an identical running job.
    6. Check concurrency and reply "busy" instead of queueing.
    7. Execute the job.
    8. Cache the result, record metrics and reply.

    Any unexpected exception is answered with an INTERNAL error.
    """
    global _inflight_count
    t0 = time.monotonic()
    job = None
    try:
        raw = msg.data
        if len(raw) > config.MAX_PAYLOAD_BYTES:
            await _reply(msg, JobResponse(
                status="error", node_id=config.NODE_ID,
                error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
            ))
            return

        data = json.loads(raw)
        job = JobRequest(**data)
        job.trace_id = job.trace_id or job.job_id
        idem_key = job.effective_idem_key()

        cached = _idem.get(idem_key)
        if cached:
            logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
            await _reply(msg, cached)
            return

        remaining = job.remaining_ms()
        if remaining <= 0:
            resp = _job_response(job, "timeout", "DEADLINE_EXCEEDED", "deadline already passed")
            _idem.put(idem_key, resp)
            await _reply(msg, resp)
            return

        inflight = await _idem.acquire_inflight(idem_key)
        if inflight is not None:
            # An identical job is already running: piggy-back on its result.
            # shield() keeps a timeout here from cancelling the awaited
            # future, which may be shared with the owner and other waiters.
            try:
                resp = await asyncio.wait_for(asyncio.shield(inflight), timeout=remaining / 1000.0)
                resp_copy = resp.model_copy()
                resp_copy.cached = True
                await _reply(msg, resp_copy)
            except asyncio.TimeoutError:
                await _reply(msg, _job_response(job, "timeout", "INFLIGHT_TIMEOUT"))
            return

        # This handler now owns the in-flight slot for idem_key. Every exit
        # path must resolve it, otherwise duplicates waiting on it (above)
        # hang until their own deadline.
        resolved = False
        try:
            # locked() == "cannot be acquired immediately": reject rather than queue.
            if _semaphore.locked():
                resp = _job_response(job, "busy", "CONCURRENCY_LIMIT", f"max {config.MAX_CONCURRENCY}")
                _idem.complete_inflight(idem_key, resp)
                resolved = True
                await _reply(msg, resp)
                return

            _inflight_count += 1
            fm.set_inflight(_inflight_count)
            try:
                async with _semaphore:
                    resp = await _execute(job, remaining)
            finally:
                _inflight_count -= 1
                fm.set_inflight(_inflight_count)

            _idem.put(idem_key, resp)
            _idem.complete_inflight(idem_key, resp)
            resolved = True
        finally:
            if not resolved:
                # Unexpected exception or cancellation: still release waiters.
                _idem.complete_inflight(idem_key, _job_response(
                    job, "error", "INTERNAL", "request aborted before completion",
                ))

        resp.latency_ms = int((time.monotonic() - t0) * 1000)
        fm.inc_job(job.required_type, resp.status)
        if resp.status == "ok" and resp.latency_ms > 0:
            fm.observe_latency(job.required_type, resp.model or "?", resp.latency_ms)
            buf = _latencies_llm if job.required_type in ("llm", "code") else _latencies_vision
            buf.append(resp.latency_ms)
            if len(buf) > _LATENCY_BUFFER:
                del buf[:len(buf) - _LATENCY_BUFFER]
            _report_latency_async(job.required_type, resp.provider or "ollama", resp.latency_ms)
        await _reply(msg, resp)
    except Exception as e:
        logger.exception(f"Worker handler error: {e}")
        # Keep correlation ids when the request was parsed far enough.
        ids = {"job_id": job.job_id, "trace_id": job.trace_id} if job is not None else {}
        try:
            await _reply(msg, JobResponse(
                node_id=config.NODE_ID, status="error",
                error=JobError(code="INTERNAL", message=str(e)[:200]),
                **ids,
            ))
        except Exception:
            logger.warning("Failed to send INTERNAL error reply", exc_info=True)


async def _handle_voice_request(msg, voice_sem: asyncio.Semaphore, cap_key: str):
    """Voice-dedicated handler: separate semaphore, metrics, retry hints.

    Maps voice.{tts|llm|stt} to the same _execute() but with:
      - Own concurrency limit (VOICE_MAX_CONCURRENT_{TTS|LLM|STT})
      - TOO_BUSY includes retry_after_ms hint (client can retry immediately elsewhere)
      - Voice-specific Prometheus labels (type=voice.tts, etc.)
      - WARNING log on fallback (contract: no silent fallback)

    Unlike ``_handle_request``, this path does not consult the idempotency
    store.
    """
    t0 = time.monotonic()
    # Extract the base type for _execute (voice.tts → tts)
    base_type = cap_key.split(".")[-1]  # "tts", "llm", "stt"
    job = None
    try:
        raw = msg.data
        if len(raw) > config.MAX_PAYLOAD_BYTES:
            await _reply(msg, JobResponse(
                node_id=config.NODE_ID, status="error",
                error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
            ))
            return

        data = json.loads(raw)
        job = JobRequest(**data)
        job.trace_id = job.trace_id or job.job_id

        remaining = job.remaining_ms()
        if remaining <= 0:
            await _reply(msg, _job_response(job, "timeout", "DEADLINE_EXCEEDED"))
            return

        # Voice concurrency check — TOO_BUSY includes retry hint.
        # locked() == "cannot be acquired immediately": reject rather than queue.
        if voice_sem.locked():
            logger.warning(
                "[voice.busy] cap=%s node=%s — all %s slots occupied. "
                "WARNING: request turned away, Router should failover.",
                cap_key, config.NODE_ID, _voice_limit(cap_key),
            )
            fm.inc_voice_job(cap_key, "busy")
            await _reply(msg, JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="busy",
                error=JobError(
                    code="TOO_BUSY",
                    message=f"voice {cap_key} at capacity",
                    details={"retry_after_ms": 500, "cap": cap_key},
                ),
            ))
            return

        _voice_inflight[cap_key] = _voice_inflight.get(cap_key, 0) + 1
        fm.set_voice_inflight(cap_key, _voice_inflight[cap_key])
        try:
            async with voice_sem:
                # Route to _execute with the base type
                job.required_type = base_type
                resp = await _execute(job, remaining)
        finally:
            _voice_inflight[cap_key] = max(0, _voice_inflight.get(cap_key, 1) - 1)
            fm.set_voice_inflight(cap_key, _voice_inflight[cap_key])

        resp.latency_ms = int((time.monotonic() - t0) * 1000)
        fm.inc_voice_job(cap_key, resp.status)
        if resp.status == "ok" and resp.latency_ms > 0:
            fm.observe_voice_latency(cap_key, resp.latency_ms)

        # Contract: log WARNING on any non-ok voice result
        if resp.status != "ok":
            logger.warning(
                "[voice.fallback] cap=%s node=%s status=%s error=%s trace=%s",
                cap_key, config.NODE_ID, resp.status,
                resp.error.code if resp.error else "?", job.trace_id,
            )

        await _reply(msg, resp)
    except Exception as e:
        logger.exception(f"Voice handler error cap={cap_key}: {e}")
        fm.inc_voice_job(cap_key, "error")
        # Keep correlation ids when the request was parsed far enough.
        ids = {"job_id": job.job_id, "trace_id": job.trace_id} if job is not None else {}
        try:
            await _reply(msg, JobResponse(
                node_id=config.NODE_ID, status="error",
                error=JobError(code="INTERNAL", message=str(e)[:200]),
                **ids,
            ))
        except Exception:
            logger.warning("Failed to send voice INTERNAL error reply", exc_info=True)


async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
    """Run *job* on the local provider selected by ``job.required_type``.

    Every provider call is bounded by ``min(remaining_ms, 120 s)``. Provider
    timeouts become a "timeout"/PROVIDER_TIMEOUT response and other provider
    exceptions become an "error"/PROVIDER_ERROR response. Unconfigured or
    unsupported types return an "error" response without calling a provider.
    """
    payload = job.payload
    hints = job.hints
    timeout_s = min(remaining_ms / 1000.0, 120.0)
    preferred = hints.get("prefer_models")
    model = preferred[0] if preferred else payload.get("model", "")

    # `or` guards against explicit nulls ("prompt": null) in the payload, which
    # would otherwise crash len() before the provider try-block below.
    msg_count = len(payload.get("messages") or [])
    prompt_chars = len(payload.get("prompt") or "")
    logger.info(
        f"[job.start] job={job.job_id} trace={job.trace_id} "
        f"type={job.required_type} model={model or '?'} "
        f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
    )

    try:
        if job.required_type == "llm":
            result = await asyncio.wait_for(
                ollama.infer(
                    messages=payload.get("messages"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
                    temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "vision":
            result = await asyncio.wait_for(
                ollama_vision.infer(
                    images=payload.get("images"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", 1024),
                    temperature=hints.get("temperature", 0.2),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "stt":
            if config.STT_PROVIDER == "none":
                return _job_response(job, "error", "NOT_AVAILABLE", "STT not configured on this node")
            if config.STT_PROVIDER == "memory_service":
                result = await asyncio.wait_for(
                    stt_memory_service.transcribe(payload), timeout=timeout_s,
                )
            else:
                result = await asyncio.wait_for(
                    stt_mlx_whisper.transcribe(payload), timeout=timeout_s,
                )
        elif job.required_type == "tts":
            if config.TTS_PROVIDER == "none":
                return _job_response(job, "error", "NOT_AVAILABLE", "TTS not configured on this node")
            if config.TTS_PROVIDER == "memory_service":
                result = await asyncio.wait_for(
                    tts_memory_service.synthesize(payload), timeout=timeout_s,
                )
            else:
                result = await asyncio.wait_for(
                    tts_mlx_kokoro.synthesize(payload), timeout=timeout_s,
                )
        elif job.required_type == "ocr":
            if config.OCR_PROVIDER == "none":
                return _job_response(job, "error", "NOT_AVAILABLE", "OCR not configured on this node")
            # OCR is implemented by prompting the vision model for text.
            ocr_prompt = payload.get("prompt", "Extract all text from this image. Return JSON: {\"text\": \"...\", \"language\": \"...\"}")
            result = await asyncio.wait_for(
                ollama_vision.infer(
                    images=payload.get("images"),
                    prompt=ocr_prompt,
                    model=model or config.DEFAULT_VISION,
                    system="You are an OCR engine. Extract text precisely. Return valid JSON only.",
                    max_tokens=hints.get("max_tokens", 4096),
                    temperature=0.05,
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
            result["provider"] = "vision_prompted_ocr"
        elif job.required_type == "image":
            return _job_response(job, "error", "NOT_YET_IMPLEMENTED", "Image adapter pending P3.7")
        else:
            return _job_response(job, "error", "UNSUPPORTED_TYPE", f"{job.required_type} not supported")

        logger.info(
            f"[job.done] job={job.job_id} status=ok "
            f"provider={result.get('provider')} model={result.get('model')}"
        )
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="ok",
            provider=result.get("provider", ""),
            model=result.get("model", ""),
            result=result,
        )
    except asyncio.TimeoutError:
        logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
        return _job_response(job, "timeout", "PROVIDER_TIMEOUT")
    except Exception as e:
        logger.warning(f"[job.error] job={job.job_id} {e}")
        return _job_response(job, "error", "PROVIDER_ERROR", str(e)[:300])


def get_metrics() -> Dict[str, Any]:
    """Return a snapshot of generic-path load and recent latency samples."""
    return {
        "inflight_jobs": _inflight_count,
        "concurrency_limit": config.MAX_CONCURRENCY,
        "queue_depth": 0,  # requests are rejected as busy, never queued
        "last_latencies_llm": list(_latencies_llm[-_LATENCY_BUFFER:]),
        "last_latencies_vision": list(_latencies_vision[-_LATENCY_BUFFER:]),
    }


def _report_latency_async(req_type: str, runtime: str, latency_ms: int):
    """Fire-and-forget latency report to local NCS.

    Never raises. The report is advisory, so failures are only logged at
    debug level and never affect the job reply. Failures include: no running
    loop, httpx being unavailable, and HTTP errors.
    """
    ncs_url = os.getenv("NCS_REPORT_URL", "http://node-capabilities:8099")

    async def _do():
        try:
            # Imported here so an import failure stays inside this swallowed task.
            import httpx as _httpx
            async with _httpx.AsyncClient(timeout=1) as c:
                await c.post(f"{ncs_url}/capabilities/report_latency", json={
                    "runtime": runtime,
                    "type": req_type,
                    "latency_ms": latency_ms,
                })
        except Exception:
            logger.debug("NCS latency report failed", exc_info=True)

    try:
        # get_running_loop() raises before _do() is created, so no
        # "coroutine was never awaited" warning when there is no loop.
        task = asyncio.get_running_loop().create_task(_do())
    except RuntimeError:
        return
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)


async def _reply(msg, resp: JobResponse):
    """Publish *resp* as JSON on the message's reply subject, if it has one."""
    if msg.reply:
        await _nats_client.publish(msg.reply, resp.model_dump_json().encode())