NCS:
- _collect_worker_caps() fetches capability flags from node-worker /caps
- _derive_capabilities() merges served model types + worker provider flags
- installed_artifacts replaces inventory_only (disk scan with DISK_SCAN_PATHS env)
- New endpoints: /capabilities/caps, /capabilities/installed
Node Worker:
- STT_PROVIDER, TTS_PROVIDER, OCR_PROVIDER, IMAGE_PROVIDER env flags
- /caps endpoint returns capabilities + providers for NCS aggregation
- STT adapter (providers/stt_mlx_whisper.py) — remote + local mode
- TTS adapter (providers/tts_mlx_kokoro.py) — remote + local mode
- OCR handler via vision_prompted (ollama_vision with OCR prompt)
- NATS subjects: node.{id}.stt/tts/ocr/image.request
Router:
- POST /v1/capability/{stt,tts,ocr,image} — capability-based offload routing
- GET /v1/capabilities — global view with capabilities_by_node
- require_fresh_caps(ttl) preflight guard
- find_nodes_with_capability(cap) + load-based node selection
Ops:
- ops/fabric_snapshot.py — full runtime snapshot collector
- ops/fabric_preflight.sh — quick check + snapshot save + diff
- docs/fabric_contract.md — Dev Contract v0.1 (preflight-first)
- tests/test_fabric_contract.py — CI enforcement (6 tests)
Made-with: Cursor
292 lines
11 KiB
Python
"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from typing import Any, Dict
|
|
|
|
import config
|
|
from models import JobRequest, JobResponse, JobError
|
|
from idempotency import IdempotencyStore
|
|
from providers import ollama, ollama_vision
|
|
from providers import stt_mlx_whisper, tts_mlx_kokoro
|
|
import fabric_metrics as fm
|
|
|
|
logger = logging.getLogger("node-worker")

# Response cache keyed by idempotency key; also tracks in-flight duplicates
# so identical concurrent jobs execute only once.
_idem = IdempotencyStore()

# Gates job execution at config.MAX_CONCURRENCY concurrent jobs.
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)

# Set by start(); used by _reply() to publish responses back over NATS.
_nats_client = None

# Number of jobs currently executing (exported via get_metrics()).
_inflight_count: int = 0

# Rolling buffers of recent successful-job latencies in ms,
# trimmed to the last _LATENCY_BUFFER entries.
_latencies_llm: list = []

_latencies_vision: list = []

_LATENCY_BUFFER = 50
|
|
|
|
|
|
async def start(nats_client):
    """Register this worker's NATS subscriptions for all offload job types.

    Stores *nats_client* in the module-level ``_nats_client`` (used later by
    ``_reply``) and subscribes ``_handle_request`` to one request subject per
    supported job type, scoped to this node's lowercased id.
    """
    global _nats_client
    _nats_client = nats_client

    node = config.NODE_ID.lower()
    # One subject per job type: node.{id}.{type}.request
    for job_type in ("llm", "vision", "stt", "tts", "image", "ocr"):
        subject = f"node.{node}.{job_type}.request"
        await nats_client.subscribe(subject, cb=_handle_request)
        logger.info(f"✅ Subscribed: {subject}")
|
|
|
|
|
|
async def _handle_request(msg):
    """Handle one NATS offload request end-to-end.

    Pipeline: payload-size guard → parse/validate → idempotency cache hit →
    deadline check → in-flight dedup → concurrency gate → execute → record
    metrics → reply. Every outcome, including internal errors, is replied to
    the requester when a reply subject is present.
    """
    t0 = time.time()
    try:
        raw = msg.data
        # Reject oversized payloads before attempting to parse them.
        if len(raw) > config.MAX_PAYLOAD_BYTES:
            await _reply(msg, JobResponse(
                status="error",
                node_id=config.NODE_ID,
                error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
            ))
            return

        data = json.loads(raw)
        job = JobRequest(**data)
        # Ensure every job has a trace id; fall back to the job id.
        job.trace_id = job.trace_id or job.job_id

        # Serve a previously computed response for a repeated idempotency key.
        idem_key = job.effective_idem_key()
        cached = _idem.get(idem_key)
        if cached:
            logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
            await _reply(msg, cached)
            return

        # Fail fast if the caller's deadline has already expired; the timeout
        # response is cached so retries with the same key get it too.
        remaining = job.remaining_ms()
        if remaining <= 0:
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="timeout",
                error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"),
            )
            _idem.put(idem_key, resp)
            await _reply(msg, resp)
            return

        # If an identical job is already executing, await its result instead
        # of running it a second time; the duplicate reply is marked cached.
        inflight = await _idem.acquire_inflight(idem_key)
        if inflight is not None:
            try:
                resp = await asyncio.wait_for(inflight, timeout=remaining / 1000.0)
                resp_copy = resp.model_copy()
                resp_copy.cached = True
                await _reply(msg, resp_copy)
            except asyncio.TimeoutError:
                await _reply(msg, JobResponse(
                    job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                    status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"),
                ))
            return

        # Shed load instead of queueing when the concurrency limit is reached.
        # NOTE(review): this peeks at the private _semaphore._value attribute,
        # and the check is racy (a slot may free up between check and acquire)
        # — appears to be an intentional best-effort busy signal; confirm.
        if _semaphore.locked() and _semaphore._value == 0:
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="busy",
                error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"),
            )
            # Release any waiters registered on this key with the busy result.
            _idem.complete_inflight(idem_key, resp)
            await _reply(msg, resp)
            return

        global _inflight_count
        _inflight_count += 1
        fm.set_inflight(_inflight_count)
        try:
            async with _semaphore:
                resp = await _execute(job, remaining)
        finally:
            # Always restore the in-flight gauge, even if _execute raises.
            _inflight_count -= 1
            fm.set_inflight(_inflight_count)

        # Publish the result to the idempotency cache and wake duplicate
        # waiters before replying.
        _idem.put(idem_key, resp)
        _idem.complete_inflight(idem_key, resp)
        resp.latency_ms = int((time.time() - t0) * 1000)

        fm.inc_job(job.required_type, resp.status)
        if resp.status == "ok" and resp.latency_ms > 0:
            fm.observe_latency(job.required_type, resp.model or "?", resp.latency_ms)
            # Keep a bounded rolling window of recent latencies per family.
            buf = _latencies_llm if job.required_type in ("llm", "code") else _latencies_vision
            buf.append(resp.latency_ms)
            if len(buf) > _LATENCY_BUFFER:
                del buf[:len(buf) - _LATENCY_BUFFER]
            # Best-effort, fire-and-forget report to the local NCS.
            _report_latency_async(job.required_type, resp.provider or "ollama", resp.latency_ms)

        await _reply(msg, resp)

    except Exception as e:
        # Catch-all boundary: log with traceback and try to tell the caller.
        logger.exception(f"Worker handler error: {e}")
        try:
            await _reply(msg, JobResponse(
                node_id=config.NODE_ID, status="error",
                error=JobError(code="INTERNAL", message=str(e)[:200]),
            ))
        except Exception:
            # Reply itself failed (e.g. NATS down); nothing more we can do.
            pass
|
|
|
|
|
|
def _error_resp(job: JobRequest, code: str, message: str) -> JobResponse:
    """Build an error JobResponse for *job* with the given code and message."""
    return JobResponse(
        job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
        status="error",
        error=JobError(code=code, message=message),
    )


async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
    """Dispatch *job* to the provider adapter for its required_type.

    The provider call is bounded by min(remaining deadline, 120s), applied
    both as the adapter's own timeout hint and an outer asyncio.wait_for.

    Returns a JobResponse with status "ok", "error", or "timeout"; provider
    exceptions are converted to PROVIDER_ERROR responses rather than raised.
    """
    payload = job.payload
    hints = job.hints
    timeout_s = min(remaining_ms / 1000.0, 120.0)
    # Router-preferred model wins; otherwise fall back to the payload's model.
    model = hints.get("prefer_models", [None])[0] if hints.get("prefer_models") else payload.get("model", "")

    msg_count = len(payload.get("messages", []))
    prompt_chars = len(payload.get("prompt", ""))
    logger.info(
        f"[job.start] job={job.job_id} trace={job.trace_id} "
        f"type={job.required_type} model={model or '?'} "
        f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
    )

    try:
        if job.required_type == "llm":
            result = await asyncio.wait_for(
                ollama.infer(
                    messages=payload.get("messages"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
                    temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "vision":
            result = await asyncio.wait_for(
                ollama_vision.infer(
                    images=payload.get("images"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", 1024),
                    temperature=hints.get("temperature", 0.2),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "stt":
            if config.STT_PROVIDER == "none":
                return _error_resp(job, "NOT_AVAILABLE", "STT not configured on this node")
            result = await asyncio.wait_for(
                stt_mlx_whisper.transcribe(payload), timeout=timeout_s,
            )
        elif job.required_type == "tts":
            if config.TTS_PROVIDER == "none":
                return _error_resp(job, "NOT_AVAILABLE", "TTS not configured on this node")
            result = await asyncio.wait_for(
                tts_mlx_kokoro.synthesize(payload), timeout=timeout_s,
            )
        elif job.required_type == "ocr":
            if config.OCR_PROVIDER == "none":
                return _error_resp(job, "NOT_AVAILABLE", "OCR not configured on this node")
            # OCR rides the vision adapter with an extraction prompt and
            # near-zero temperature for deterministic text output.
            ocr_prompt = payload.get("prompt", "Extract all text from this image. Return JSON: {\"text\": \"...\", \"language\": \"...\"}")
            result = await asyncio.wait_for(
                ollama_vision.infer(
                    images=payload.get("images"),
                    prompt=ocr_prompt,
                    model=model or config.DEFAULT_VISION,
                    system="You are an OCR engine. Extract text precisely. Return valid JSON only.",
                    max_tokens=hints.get("max_tokens", 4096),
                    temperature=0.05,
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
            # Distinguish prompted OCR from plain vision in metrics/replies.
            result["provider"] = "vision_prompted_ocr"
        elif job.required_type == "image":
            return _error_resp(job, "NOT_YET_IMPLEMENTED", "Image adapter pending P3.7")
        else:
            return _error_resp(job, "UNSUPPORTED_TYPE", f"{job.required_type} not supported")

        logger.info(
            f"[job.done] job={job.job_id} status=ok "
            f"provider={result.get('provider')} model={result.get('model')}"
        )
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="ok", provider=result.get("provider", ""), model=result.get("model", ""),
            result=result,
        )

    except asyncio.TimeoutError:
        logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="timeout", error=JobError(code="PROVIDER_TIMEOUT"),
        )
    except Exception as e:
        logger.warning(f"[job.error] job={job.job_id} {e}")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="error", error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]),
        )
|
|
|
|
|
|
def get_metrics() -> Dict[str, Any]:
    """Return a snapshot of current worker load and recent latency samples."""
    snapshot: Dict[str, Any] = {
        "inflight_jobs": _inflight_count,
        "concurrency_limit": config.MAX_CONCURRENCY,
        # No internal queue: jobs over the limit are rejected, never queued.
        "queue_depth": 0,
    }
    # Slicing produces fresh lists, so callers cannot mutate the buffers.
    snapshot["last_latencies_llm"] = _latencies_llm[-_LATENCY_BUFFER:]
    snapshot["last_latencies_vision"] = _latencies_vision[-_LATENCY_BUFFER:]
    return snapshot
|
|
|
|
|
|
def _report_latency_async(req_type: str, runtime: str, latency_ms: int):
    """Fire-and-forget latency report to the local NCS.

    Best-effort only: HTTP errors are swallowed, and calling this outside a
    running event loop silently drops the report.
    """
    import httpx as _httpx  # local import keeps httpx out of module load

    ncs_url = os.getenv("NCS_REPORT_URL", "http://node-capabilities:8099")

    async def _do():
        try:
            async with _httpx.AsyncClient(timeout=1) as c:
                await c.post(f"{ncs_url}/capabilities/report_latency", json={
                    "runtime": runtime, "type": req_type, "latency_ms": latency_ms,
                })
        except Exception:
            # Latency reporting must never break job handling.
            pass

    try:
        # get_running_loop() instead of the deprecated get_event_loop():
        # this is always invoked from within the worker's event loop, and
        # still raises RuntimeError (handled below) when no loop is running.
        asyncio.get_running_loop().create_task(_do())
    except RuntimeError:
        # No running loop (e.g. sync test context): drop the report.
        pass
|
|
|
|
|
|
async def _reply(msg, resp: JobResponse):
    """Publish *resp* to the message's reply subject, if one was provided."""
    if not msg.reply:
        return
    encoded = resp.model_dump_json().encode()
    await _nats_client.publish(msg.reply, encoded)
|