Files
microdao-daarion/services/node-worker/worker.py
Apple c4b94a327d P2.2+P2.3: NATS offload node-worker + router offload integration
Node Worker (services/node-worker/):
- NATS subscriber for node.{NODE_ID}.llm.request / vision.request
- Canonical JobRequest/JobResponse envelope (Pydantic)
- Idempotency cache (TTL 10min) with inflight dedup
- Deadline enforcement (DEADLINE_EXCEEDED on expired jobs)
- Concurrency limiter (semaphore, returns busy)
- Ollama + Swapper vision providers

Router offload (services/router/offload_client.py):
- NATS req/reply with configurable retries
- Circuit breaker per node+type (3 fails/60s → open 120s)
- Concurrency semaphore for remote requests

Model selection (services/router/model_select.py):
- exclude_nodes parameter for circuit-broken nodes
- force_local flag for fallback re-selection
- Integrated circuit breaker state awareness

Router /infer pipeline:
- Remote offload path when NCS selects remote node
- Automatic fallback: exclude failed node → force_local re-select
- Deadline propagation from router to node-worker

Tests: 17 unit tests (idempotency, deadline, circuit breaker)
Docs: ops/offload_routing.md (subjects, envelope, verification)
Made-with: Cursor
2026-02-27 02:44:05 -08:00

185 lines
6.6 KiB
Python

"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
import asyncio
import json
import logging
import time
from typing import Any, Dict
import config
from models import JobRequest, JobResponse, JobError
from idempotency import IdempotencyStore
from providers import ollama, swapper_vision
logger = logging.getLogger("node-worker")
# Module-level idempotency store: holds completed responses (get/put) and
# de-duplicates concurrent identical jobs (acquire_inflight/complete_inflight),
# keyed by JobRequest.effective_idem_key().
_idem = IdempotencyStore()
# Bounds how many jobs execute at once; when every slot is taken,
# _handle_request answers "busy" (CONCURRENCY_LIMIT) instead of queueing.
# NOTE(review): created at import time; on Python < 3.10 asyncio primitives
# bind to the event loop current at construction — confirm the runtime version.
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)
# NATS connection, assigned by start() and used by _reply() to publish replies.
_nats_client: Any = None
async def start(nats_client):
    """Register the request handler on this node's offload subjects.

    Stores ``nats_client`` in the module-level ``_nats_client`` (used later by
    ``_reply`` to publish responses), then subscribes ``_handle_request`` to
    ``node.<id>.llm.request`` and ``node.<id>.vision.request``, where ``<id>``
    is ``config.NODE_ID`` lower-cased.

    Args:
        nats_client: Connected NATS client exposing ``subscribe``/``publish``.
    """
    global _nats_client
    _nats_client = nats_client

    node = config.NODE_ID.lower()
    for job_type in ("llm", "vision"):
        subject = f"node.{node}.{job_type}.request"
        await nats_client.subscribe(subject, cb=_handle_request)
        logger.info(f"✅ Subscribed: {subject}")
async def _handle_request(msg):
    """Handle one offload request received on a node request subject.

    Pipeline: payload-size guard -> parse ``JobRequest`` -> idempotency cache
    lookup -> deadline check -> inflight de-duplication -> concurrency check
    -> ``_execute`` -> cache + reply. Every outcome is answered via ``_reply``
    with a ``JobResponse``. Exceptions are logged and answered with an
    ``INTERNAL`` error; only BaseExceptions such as task cancellation
    propagate (after the inflight slot has been released).

    Args:
        msg: NATS message whose ``data`` is the JSON-encoded ``JobRequest``.
    """
    t0 = time.time()
    job = None  # set once the payload parses, so error replies can carry ids
    try:
        raw = msg.data
        if len(raw) > config.MAX_PAYLOAD_BYTES:
            await _reply(msg, JobResponse(
                status="error",
                node_id=config.NODE_ID,
                error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
            ))
            return

        job = JobRequest(**json.loads(raw))
        job.trace_id = job.trace_id or job.job_id
        idem_key = job.effective_idem_key()

        cached = _idem.get(idem_key)
        if cached:
            logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
            # Flag replays the same way the inflight-dedup path does, without
            # mutating the stored response.
            replay = cached.model_copy()
            replay.cached = True
            await _reply(msg, replay)
            return

        remaining = job.remaining_ms()
        if remaining <= 0:
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="timeout",
                error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"),
            )
            # NOTE(review): this timeout is cached under idem_key, so a retry
            # with the same key receives it until the cache entry expires.
            _idem.put(idem_key, resp)
            await _reply(msg, resp)
            return

        inflight = await _idem.acquire_inflight(idem_key)
        if inflight is not None:
            # An identical job is already running in this worker; share its result.
            await _reply_from_inflight(msg, job, inflight, remaining)
            return

        # This handler now owns the inflight slot for idem_key. It must always
        # complete it, otherwise duplicates block until their own deadline.
        resp = None
        try:
            if _semaphore.locked():
                resp = JobResponse(
                    job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                    status="busy",
                    error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"),
                )
            else:
                async with _semaphore:
                    resp = await _execute(job, remaining)
                # Stamp latency before the response is shared with the cache
                # and with inflight waiters.
                resp.latency_ms = int((time.time() - t0) * 1000)
                _idem.put(idem_key, resp)
        finally:
            if resp is None:
                # _execute raised or the task was cancelled: release waiters
                # with an error instead of leaving them blocked.
                resp = JobResponse(
                    job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                    status="error",
                    error=JobError(code="INTERNAL", message="job aborted before producing a result"),
                )
            _idem.complete_inflight(idem_key, resp)
        await _reply(msg, resp)
    except Exception as e:
        logger.exception(f"Worker handler error: {e}")
        # Echo the request ids when parsing got far enough to know them.
        ids = {} if job is None else {"job_id": job.job_id, "trace_id": job.trace_id or job.job_id}
        try:
            await _reply(msg, JobResponse(
                node_id=config.NODE_ID, status="error",
                error=JobError(code="INTERNAL", message=str(e)[:200]),
                **ids,
            ))
        except Exception:
            # Best effort only; the requester will hit its own timeout.
            logger.warning("Failed to send INTERNAL error reply", exc_info=True)


async def _reply_from_inflight(msg, job, inflight, remaining_ms):
    """Wait for an identical inflight job and relay its result marked as cached.

    ``asyncio.shield`` keeps our timeout from cancelling the shared inflight
    awaitable, which the owning handler and other duplicates still rely on.
    Replies ``INFLIGHT_TIMEOUT`` if the result does not arrive within
    ``remaining_ms``.
    """
    try:
        resp = await asyncio.wait_for(asyncio.shield(inflight), timeout=remaining_ms / 1000.0)
    except asyncio.TimeoutError:
        await _reply(msg, JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"),
        ))
        return
    resp_copy = resp.model_copy()
    resp_copy.cached = True
    await _reply(msg, resp_copy)
async def _execute(job: JobRequest, remaining_ms: int, max_timeout_s: float = 120.0) -> JobResponse:
    """Run one job on the local provider matching its ``required_type``.

    ``llm`` jobs go to ``ollama.infer`` and ``vision`` jobs go to
    ``swapper_vision.infer``. Any other type is answered with
    ``UNSUPPORTED_TYPE``. The provider call is bounded by the smaller of the
    job's remaining deadline and ``max_timeout_s``.

    Model choice: the first entry of ``hints["prefer_models"]`` when that list
    is non-empty, otherwise ``payload["model"]``. ``max_tokens`` and
    ``temperature`` are taken from hints first, then the payload, then a
    per-type default (llm: 2048 / 0.2, vision: 1024 / 0.2).

    Args:
        job: Parsed request. Reads ``payload``, ``hints``, ``required_type``,
            ``job_id`` and ``trace_id``.
        remaining_ms: Milliseconds left before the job's deadline.
        max_timeout_s: Upper bound on the provider timeout, in seconds.

    Returns:
        A ``JobResponse`` with one of these statuses:
        ``ok`` with the provider result attached;
        ``timeout`` with ``PROVIDER_TIMEOUT``;
        ``error`` with ``UNSUPPORTED_TYPE`` or ``PROVIDER_ERROR``.
        Provider exceptions are converted into responses, not raised.
    """
    payload = job.payload or {}
    hints = job.hints or {}
    timeout_s = min(remaining_ms / 1000.0, max_timeout_s)
    preferred = hints.get("prefer_models") or []
    model = preferred[0] if preferred else payload.get("model", "")
    # `or` guards against explicit JSON nulls. len(None) here would raise
    # outside the try below and escape this function.
    msg_count = len(payload.get("messages") or [])
    prompt_chars = len(payload.get("prompt") or "")
    logger.info(
        f"[job.start] job={job.job_id} trace={job.trace_id} "
        f"type={job.required_type} model={model or '?'} "
        f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
    )
    try:
        if job.required_type == "llm":
            result = await asyncio.wait_for(
                ollama.infer(
                    messages=payload.get("messages"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
                    temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "vision":
            result = await asyncio.wait_for(
                swapper_vision.infer(
                    images=payload.get("images"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    # Same hints -> payload -> default precedence as the llm path.
                    max_tokens=hints.get("max_tokens", payload.get("max_tokens", 1024)),
                    temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        else:
            return JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="error",
                error=JobError(code="UNSUPPORTED_TYPE", message=f"{job.required_type} not implemented"),
            )
        logger.info(
            f"[job.done] job={job.job_id} status=ok "
            f"provider={result.get('provider')} model={result.get('model')}"
        )
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="ok", provider=result.get("provider", ""), model=result.get("model", ""),
            result=result,
        )
    except asyncio.TimeoutError:
        logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="timeout", error=JobError(code="PROVIDER_TIMEOUT"),
        )
    except Exception as e:
        logger.warning(f"[job.error] job={job.job_id} {e}")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="error", error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]),
        )
async def _reply(msg, resp: JobResponse):
    """Publish ``resp`` as JSON on the request's reply subject.

    Requests without a reply subject (empty ``msg.reply``) get no response.
    """
    reply_subject = msg.reply
    if not reply_subject:
        return
    body = resp.model_dump_json().encode()
    await _nats_client.publish(reply_subject, body)