P3.2 — Multi-node deployment:
- Added node-worker service to docker-compose.node1.yml (NODE_ID=noda1)
- NCS NODA1 now has NODE_WORKER_URL for metrics collection
- Fixed NODE_ID consistency: router NODA1 uses 'noda1'
- NODA2 node-worker/NCS gets NCS_REPORT_URL for latency reporting

P3.3 — NATS accounts/auth (opt-in config):
- config/nats-server.conf with 3 accounts: SYS, FABRIC, APP
- Per-user topic permissions (router, ncs, node_worker)
- Leafnode listener :7422 with auth
- Not yet activated (requires credential provisioning)

P3.4 — Prometheus counters:
- Router /fabric_metrics: caps_refresh, caps_stale, model_select, offload_total, breaker_state, score_ms histogram
- Node Worker /prom_metrics: jobs_total, inflight gauge, latency_ms histogram
- NCS /prom_metrics: runtime_health, runtime_p50/p95, node_wait_ms
- All bound to 127.0.0.1 (not externally exposed)

Made-with: Cursor
239 lines · 8.5 KiB · Python
"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from typing import Any, Dict
|
|
|
|
import config
|
|
from models import JobRequest, JobResponse, JobError
|
|
from idempotency import IdempotencyStore
|
|
from providers import ollama, swapper_vision
|
|
import fabric_metrics as fm
|
|
|
|
logger = logging.getLogger("node-worker")
|
|
|
|
_idem = IdempotencyStore()
|
|
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)
|
|
_nats_client = None
|
|
_inflight_count: int = 0
|
|
_latencies_llm: list = []
|
|
_latencies_vision: list = []
|
|
_LATENCY_BUFFER = 50
|
|
|
|
|
|
async def start(nats_client):
|
|
global _nats_client
|
|
_nats_client = nats_client
|
|
|
|
subjects = [
|
|
f"node.{config.NODE_ID.lower()}.llm.request",
|
|
f"node.{config.NODE_ID.lower()}.vision.request",
|
|
]
|
|
for subj in subjects:
|
|
await nats_client.subscribe(subj, cb=_handle_request)
|
|
logger.info(f"✅ Subscribed: {subj}")
|
|
|
|
|
|
async def _handle_request(msg):
|
|
t0 = time.time()
|
|
try:
|
|
raw = msg.data
|
|
if len(raw) > config.MAX_PAYLOAD_BYTES:
|
|
await _reply(msg, JobResponse(
|
|
status="error",
|
|
node_id=config.NODE_ID,
|
|
error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
|
|
))
|
|
return
|
|
|
|
data = json.loads(raw)
|
|
job = JobRequest(**data)
|
|
job.trace_id = job.trace_id or job.job_id
|
|
|
|
idem_key = job.effective_idem_key()
|
|
cached = _idem.get(idem_key)
|
|
if cached:
|
|
logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
|
|
await _reply(msg, cached)
|
|
return
|
|
|
|
remaining = job.remaining_ms()
|
|
if remaining <= 0:
|
|
resp = JobResponse(
|
|
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
|
status="timeout",
|
|
error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"),
|
|
)
|
|
_idem.put(idem_key, resp)
|
|
await _reply(msg, resp)
|
|
return
|
|
|
|
inflight = await _idem.acquire_inflight(idem_key)
|
|
if inflight is not None:
|
|
try:
|
|
resp = await asyncio.wait_for(inflight, timeout=remaining / 1000.0)
|
|
resp_copy = resp.model_copy()
|
|
resp_copy.cached = True
|
|
await _reply(msg, resp_copy)
|
|
except asyncio.TimeoutError:
|
|
await _reply(msg, JobResponse(
|
|
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
|
status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"),
|
|
))
|
|
return
|
|
|
|
if _semaphore.locked() and _semaphore._value == 0:
|
|
resp = JobResponse(
|
|
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
|
status="busy",
|
|
error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"),
|
|
)
|
|
_idem.complete_inflight(idem_key, resp)
|
|
await _reply(msg, resp)
|
|
return
|
|
|
|
global _inflight_count
|
|
_inflight_count += 1
|
|
fm.set_inflight(_inflight_count)
|
|
try:
|
|
async with _semaphore:
|
|
resp = await _execute(job, remaining)
|
|
finally:
|
|
_inflight_count -= 1
|
|
fm.set_inflight(_inflight_count)
|
|
|
|
_idem.put(idem_key, resp)
|
|
_idem.complete_inflight(idem_key, resp)
|
|
resp.latency_ms = int((time.time() - t0) * 1000)
|
|
|
|
fm.inc_job(job.required_type, resp.status)
|
|
if resp.status == "ok" and resp.latency_ms > 0:
|
|
fm.observe_latency(job.required_type, resp.model or "?", resp.latency_ms)
|
|
buf = _latencies_llm if job.required_type in ("llm", "code") else _latencies_vision
|
|
buf.append(resp.latency_ms)
|
|
if len(buf) > _LATENCY_BUFFER:
|
|
del buf[:len(buf) - _LATENCY_BUFFER]
|
|
_report_latency_async(job.required_type, resp.provider or "ollama", resp.latency_ms)
|
|
|
|
await _reply(msg, resp)
|
|
|
|
except Exception as e:
|
|
logger.exception(f"Worker handler error: {e}")
|
|
try:
|
|
await _reply(msg, JobResponse(
|
|
node_id=config.NODE_ID, status="error",
|
|
error=JobError(code="INTERNAL", message=str(e)[:200]),
|
|
))
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
|
|
payload = job.payload
|
|
hints = job.hints
|
|
timeout_s = min(remaining_ms / 1000.0, 120.0)
|
|
model = hints.get("prefer_models", [None])[0] if hints.get("prefer_models") else payload.get("model", "")
|
|
|
|
msg_count = len(payload.get("messages", []))
|
|
prompt_chars = len(payload.get("prompt", ""))
|
|
logger.info(
|
|
f"[job.start] job={job.job_id} trace={job.trace_id} "
|
|
f"type={job.required_type} model={model or '?'} "
|
|
f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
|
|
)
|
|
|
|
try:
|
|
if job.required_type == "llm":
|
|
result = await asyncio.wait_for(
|
|
ollama.infer(
|
|
messages=payload.get("messages"),
|
|
prompt=payload.get("prompt", ""),
|
|
model=model,
|
|
system=payload.get("system", ""),
|
|
max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
|
|
temperature=hints.get("temperature", payload.get("temperature", 0.2)),
|
|
timeout_s=timeout_s,
|
|
),
|
|
timeout=timeout_s,
|
|
)
|
|
elif job.required_type == "vision":
|
|
result = await asyncio.wait_for(
|
|
swapper_vision.infer(
|
|
images=payload.get("images"),
|
|
prompt=payload.get("prompt", ""),
|
|
model=model,
|
|
system=payload.get("system", ""),
|
|
max_tokens=hints.get("max_tokens", 1024),
|
|
temperature=hints.get("temperature", 0.2),
|
|
timeout_s=timeout_s,
|
|
),
|
|
timeout=timeout_s,
|
|
)
|
|
else:
|
|
return JobResponse(
|
|
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
|
status="error",
|
|
error=JobError(code="UNSUPPORTED_TYPE", message=f"{job.required_type} not implemented"),
|
|
)
|
|
|
|
logger.info(
|
|
f"[job.done] job={job.job_id} status=ok "
|
|
f"provider={result.get('provider')} model={result.get('model')}"
|
|
)
|
|
return JobResponse(
|
|
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
|
status="ok", provider=result.get("provider", ""), model=result.get("model", ""),
|
|
result=result,
|
|
)
|
|
|
|
except asyncio.TimeoutError:
|
|
logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
|
|
return JobResponse(
|
|
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
|
status="timeout", error=JobError(code="PROVIDER_TIMEOUT"),
|
|
)
|
|
except Exception as e:
|
|
logger.warning(f"[job.error] job={job.job_id} {e}")
|
|
return JobResponse(
|
|
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
|
status="error", error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]),
|
|
)
|
|
|
|
|
|
def get_metrics() -> Dict[str, Any]:
|
|
return {
|
|
"inflight_jobs": _inflight_count,
|
|
"concurrency_limit": config.MAX_CONCURRENCY,
|
|
"queue_depth": 0,
|
|
"last_latencies_llm": list(_latencies_llm[-_LATENCY_BUFFER:]),
|
|
"last_latencies_vision": list(_latencies_vision[-_LATENCY_BUFFER:]),
|
|
}
|
|
|
|
|
|
def _report_latency_async(req_type: str, runtime: str, latency_ms: int):
|
|
"""Fire-and-forget latency report to local NCS."""
|
|
import httpx as _httpx
|
|
|
|
ncs_url = os.getenv("NCS_REPORT_URL", "http://node-capabilities:8099")
|
|
|
|
async def _do():
|
|
try:
|
|
async with _httpx.AsyncClient(timeout=1) as c:
|
|
await c.post(f"{ncs_url}/capabilities/report_latency", json={
|
|
"runtime": runtime, "type": req_type, "latency_ms": latency_ms,
|
|
})
|
|
except Exception:
|
|
pass
|
|
|
|
try:
|
|
asyncio.get_event_loop().create_task(_do())
|
|
except RuntimeError:
|
|
pass
|
|
|
|
|
|
async def _reply(msg, resp: JobResponse):
|
|
if msg.reply:
|
|
await _nats_client.publish(msg.reply, resp.model_dump_json().encode())
|