Files
microdao-daarion/services/node-worker/worker.py
Apple 9a36020316 P3.5-P3.7: 2-layer inventory, capability routing, STT/TTS adapters, Dev Contract
NCS:
- _collect_worker_caps() fetches capability flags from node-worker /caps
- _derive_capabilities() merges served model types + worker provider flags
- installed_artifacts replaces inventory_only (disk scan with DISK_SCAN_PATHS env)
- New endpoints: /capabilities/caps, /capabilities/installed

Node Worker:
- STT_PROVIDER, TTS_PROVIDER, OCR_PROVIDER, IMAGE_PROVIDER env flags
- /caps endpoint returns capabilities + providers for NCS aggregation
- STT adapter (providers/stt_mlx_whisper.py) — remote + local mode
- TTS adapter (providers/tts_mlx_kokoro.py) — remote + local mode
- OCR handler via vision_prompted (ollama_vision with OCR prompt)
- NATS subjects: node.{id}.stt/tts/ocr/image.request

Router:
- POST /v1/capability/{stt,tts,ocr,image} — capability-based offload routing
- GET /v1/capabilities — global view with capabilities_by_node
- require_fresh_caps(ttl) preflight guard
- find_nodes_with_capability(cap) + load-based node selection

Ops:
- ops/fabric_snapshot.py — full runtime snapshot collector
- ops/fabric_preflight.sh — quick check + snapshot save + diff
- docs/fabric_contract.md — Dev Contract v0.1 (preflight-first)
- tests/test_fabric_contract.py — CI enforcement (6 tests)

Made-with: Cursor
2026-02-27 05:24:09 -08:00

292 lines
11 KiB
Python

"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict
import config
from models import JobRequest, JobResponse, JobError
from idempotency import IdempotencyStore
from providers import ollama, ollama_vision
from providers import stt_mlx_whisper, tts_mlx_kokoro
import fabric_metrics as fm
logger = logging.getLogger("node-worker")

# Idempotency store: caches finished responses by key (get/put) and tracks
# requests that are still running (acquire_inflight/complete_inflight).
_idem = IdempotencyStore()
# Caps how many jobs run _execute() at once; requests arriving while every
# permit is taken are answered with status="busy" instead of queueing.
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)
# NATS connection installed by start(); _reply() publishes through it.
_nats_client = None
# Jobs currently admitted for execution (mirrored to fabric_metrics).
_inflight_count: int = 0

# Rolling windows of the most recent successful job latencies (ms), trimmed
# to at most _LATENCY_BUFFER entries each.
_LATENCY_BUFFER = 50
_latencies_llm: list = []
_latencies_vision: list = []
async def start(nats_client):
    """Remember the NATS client and subscribe this node's request subjects.

    Subscribes ``_handle_request`` to ``node.{node_id}.{type}.request`` for
    every job type this worker routes (llm, vision, stt, tts, image, ocr),
    where ``node_id`` is ``config.NODE_ID`` lower-cased.

    Args:
        nats_client: connected NATS client; also stored module-wide so that
            ``_reply`` can publish responses.
    """
    # NOTE(review): in nats-py, a subscription's async callback is, as far as
    # I know, awaited sequentially per subscription, so jobs on the same
    # subject may run one at a time regardless of MAX_CONCURRENCY. Confirm
    # against the nats-py version in use before relying on the semaphore.
    global _nats_client
    _nats_client = nats_client
    node = config.NODE_ID.lower()
    for job_type in ("llm", "vision", "stt", "tts", "image", "ocr"):
        subject = f"node.{node}.{job_type}.request"
        await nats_client.subscribe(subject, cb=_handle_request)
        logger.info(f"✅ Subscribed: {subject}")
async def _handle_request(msg):
    """Handle one NATS offload request end-to-end and reply on ``msg.reply``.

    Steps:
    1. Size guard.
    2. Parse the ``JobRequest``.
    3. Return an idempotency-cache hit if there is one.
    4. Check the deadline.
    5. De-duplicate against an in-flight request with the same key.
    6. Admit against the concurrency limit, or reject as "busy".
    7. Run ``_execute``.
    8. Cache the result, update metrics, and reply.

    Args:
        msg: NATS message whose ``data`` is a JSON-encoded ``JobRequest``.

    Does not raise ``Exception``. Unexpected failures are logged and answered
    with a best-effort ``INTERNAL`` error response.
    """
    global _inflight_count
    t0 = time.time()
    job = None
    idem_key = None
    # True while this handler owns the in-flight entry for idem_key and is
    # responsible for resolving it. If it were left unresolved, duplicate
    # requests would wait until their own deadline.
    owns_inflight = False
    try:
        raw = msg.data
        if len(raw) > config.MAX_PAYLOAD_BYTES:
            await _reply(msg, JobResponse(
                status="error",
                node_id=config.NODE_ID,
                error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
            ))
            return

        data = json.loads(raw)
        job = JobRequest(**data)
        job.trace_id = job.trace_id or job.job_id
        idem_key = job.effective_idem_key()

        cached = _idem.get(idem_key)
        if cached:
            logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
            await _reply(msg, cached)
            return

        remaining = job.remaining_ms()
        if remaining <= 0:
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="timeout",
                error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"),
            )
            _idem.put(idem_key, resp)
            await _reply(msg, resp)
            return

        # Non-None: another request with this idempotency key is already in
        # flight. Wait for its result instead of executing the job twice.
        inflight = await _idem.acquire_inflight(idem_key)
        if inflight is not None:
            try:
                # wait_for cancels what it awaits on timeout. The in-flight
                # result is presumably shared with the owner and other waiters,
                # so shield it: one waiter timing out must not cancel it.
                resp = await asyncio.wait_for(asyncio.shield(inflight), timeout=remaining / 1000.0)
                resp_copy = resp.model_copy()
                resp_copy.cached = True
                await _reply(msg, resp_copy)
            except asyncio.TimeoutError:
                await _reply(msg, JobResponse(
                    job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                    status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"),
                ))
            return
        owns_inflight = True

        # Every permit is taken: reject rather than queue. Nothing awaits
        # between this check and the acquire below, so the decision cannot
        # go stale before the semaphore is entered.
        if _semaphore.locked():
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="busy",
                error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"),
            )
            owns_inflight = False
            _idem.complete_inflight(idem_key, resp)
            await _reply(msg, resp)
            return

        _inflight_count += 1
        fm.set_inflight(_inflight_count)
        try:
            async with _semaphore:
                resp = await _execute(job, remaining)
        finally:
            _inflight_count -= 1
            fm.set_inflight(_inflight_count)

        # Stamp latency before the response is cached or handed to in-flight
        # waiters, so every consumer of the stored response sees it.
        resp.latency_ms = int((time.time() - t0) * 1000)
        _idem.put(idem_key, resp)
        owns_inflight = False
        _idem.complete_inflight(idem_key, resp)

        fm.inc_job(job.required_type, resp.status)
        if resp.status == "ok" and resp.latency_ms > 0:
            fm.observe_latency(job.required_type, resp.model or "?", resp.latency_ms)
            if job.required_type in ("llm", "code"):
                buf = _latencies_llm
            elif job.required_type in ("vision", "ocr"):
                buf = _latencies_vision
            else:
                # Keep STT/TTS timings out of the vision window; otherwise
                # they would show up in get_metrics()["last_latencies_vision"].
                buf = None
            if buf is not None:
                buf.append(resp.latency_ms)
                if len(buf) > _LATENCY_BUFFER:
                    del buf[:len(buf) - _LATENCY_BUFFER]
            _report_latency_async(job.required_type, resp.provider or "ollama", resp.latency_ms)
        await _reply(msg, resp)
    except Exception as e:
        logger.exception(f"Worker handler error: {e}")
        try:
            # Echo the job ids when the request was parsed far enough to have them.
            ids = {"job_id": job.job_id, "trace_id": job.trace_id} if job is not None else {}
            err_resp = JobResponse(
                **ids,
                node_id=config.NODE_ID, status="error",
                error=JobError(code="INTERNAL", message=str(e)[:200]),
            )
            if owns_inflight:
                # Release duplicate waiters instead of leaving the key in flight.
                owns_inflight = False
                _idem.complete_inflight(idem_key, err_resp)
            await _reply(msg, err_resp)
        except Exception:
            logger.debug("failed to deliver INTERNAL error response", exc_info=True)
def _job_error_response(job: JobRequest, status: str, code: str, message=None) -> JobResponse:
    """Build a non-ok JobResponse stamped with ``job``'s ids and this node's id.

    ``message`` is passed to ``JobError`` only when given. Otherwise
    ``JobError``'s own default applies, as in ``JobError(code="PROVIDER_TIMEOUT")``.
    """
    if message is None:
        error = JobError(code=code)
    else:
        error = JobError(code=code, message=message)
    return JobResponse(
        job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
        status=status, error=error,
    )


async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
    """Run ``job`` on the provider that matches ``job.required_type``.

    Args:
        job: parsed request. ``payload`` and ``hints`` are read with ``.get``
            (dict-like). Hint values take precedence over payload values.
        remaining_ms: milliseconds left until the job's deadline.

    Returns:
        A ``JobResponse`` with one of these outcomes:

        - status "ok", carrying the provider ``result`` dict;
        - status "timeout" (``PROVIDER_TIMEOUT``) when the provider exceeds
          the time budget;
        - status "error" with ``NOT_AVAILABLE``, ``NOT_YET_IMPLEMENTED``,
          ``UNSUPPORTED_TYPE`` or ``PROVIDER_ERROR``.

        Provider exceptions are converted into a response, not raised.
    """
    payload = job.payload
    hints = job.hints
    # Never wait past the caller's deadline, and at most 120 s per provider call.
    timeout_s = min(remaining_ms / 1000.0, 120.0)
    preferred = hints.get("prefer_models")
    model = preferred[0] if preferred else payload.get("model", "")
    # `or` also covers explicit nulls. len(None) would raise here, outside the
    # try below, and escape to the caller as an unhandled error.
    msg_count = len(payload.get("messages") or [])
    prompt_chars = len(payload.get("prompt") or "")
    logger.info(
        f"[job.start] job={job.job_id} trace={job.trace_id} "
        f"type={job.required_type} model={model or '?'} "
        f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
    )
    try:
        if job.required_type == "llm":
            result = await asyncio.wait_for(
                ollama.infer(
                    messages=payload.get("messages"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
                    temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "vision":
            # Same hint -> payload -> default precedence as the llm branch.
            result = await asyncio.wait_for(
                ollama_vision.infer(
                    images=payload.get("images"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", payload.get("max_tokens", 1024)),
                    temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "stt":
            if config.STT_PROVIDER == "none":
                return _job_error_response(job, "error", "NOT_AVAILABLE", "STT not configured on this node")
            result = await asyncio.wait_for(
                stt_mlx_whisper.transcribe(payload), timeout=timeout_s,
            )
        elif job.required_type == "tts":
            if config.TTS_PROVIDER == "none":
                return _job_error_response(job, "error", "NOT_AVAILABLE", "TTS not configured on this node")
            result = await asyncio.wait_for(
                tts_mlx_kokoro.synthesize(payload), timeout=timeout_s,
            )
        elif job.required_type == "ocr":
            if config.OCR_PROVIDER == "none":
                return _job_error_response(job, "error", "NOT_AVAILABLE", "OCR not configured on this node")
            # OCR is done by prompting the vision model. An empty or null
            # prompt falls back to the default extraction prompt.
            ocr_prompt = payload.get("prompt") or "Extract all text from this image. Return JSON: {\"text\": \"...\", \"language\": \"...\"}"
            result = await asyncio.wait_for(
                ollama_vision.infer(
                    images=payload.get("images"),
                    prompt=ocr_prompt,
                    model=model or config.DEFAULT_VISION,
                    system="You are an OCR engine. Extract text precisely. Return valid JSON only.",
                    max_tokens=hints.get("max_tokens", 4096),
                    temperature=0.05,
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
            result["provider"] = "vision_prompted_ocr"
        elif job.required_type == "image":
            return _job_error_response(job, "error", "NOT_YET_IMPLEMENTED", "Image adapter pending P3.7")
        else:
            return _job_error_response(
                job, "error", "UNSUPPORTED_TYPE", f"{job.required_type} not supported",
            )
        logger.info(
            f"[job.done] job={job.job_id} status=ok "
            f"provider={result.get('provider')} model={result.get('model')}"
        )
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="ok", provider=result.get("provider", ""), model=result.get("model", ""),
            result=result,
        )
    except asyncio.TimeoutError:
        logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
        return _job_error_response(job, "timeout", "PROVIDER_TIMEOUT")
    except Exception as e:
        logger.warning(f"[job.error] job={job.job_id} {e}")
        return _job_error_response(job, "error", "PROVIDER_ERROR", str(e)[:300])
def get_metrics() -> Dict[str, Any]:
    """Return a point-in-time snapshot of this worker's load.

    Contents:

    - the current in-flight job count;
    - the configured concurrency limit;
    - a queue depth, hard-coded to 0 because this module keeps no job queue
      (requests over the limit are rejected as "busy");
    - fresh copies of the recent llm and vision latency windows, in ms.
    """
    window_start = -_LATENCY_BUFFER
    # Slicing already yields new lists, so callers cannot mutate the buffers.
    return dict(
        inflight_jobs=_inflight_count,
        concurrency_limit=config.MAX_CONCURRENCY,
        queue_depth=0,
        last_latencies_llm=_latencies_llm[window_start:],
        last_latencies_vision=_latencies_vision[window_start:],
    )
def _report_latency_async(req_type: str, runtime: str, latency_ms: int):
"""Fire-and-forget latency report to local NCS."""
import httpx as _httpx
ncs_url = os.getenv("NCS_REPORT_URL", "http://node-capabilities:8099")
async def _do():
try:
async with _httpx.AsyncClient(timeout=1) as c:
await c.post(f"{ncs_url}/capabilities/report_latency", json={
"runtime": runtime, "type": req_type, "latency_ms": latency_ms,
})
except Exception:
pass
try:
asyncio.get_event_loop().create_task(_do())
except RuntimeError:
pass
async def _reply(msg, resp: JobResponse):
    """Publish ``resp`` as JSON bytes to ``msg``'s reply subject.

    Does nothing when the request carried no reply subject. The client used
    is the module-level one installed by ``start()``.
    """
    reply_subject = msg.reply
    if not reply_subject:
        return
    body = resp.model_dump_json().encode()
    await _nats_client.publish(reply_subject, body)