Node Worker (services/node-worker/):
- NATS subscriber for node.{NODE_ID}.llm.request / vision.request
- Canonical JobRequest/JobResponse envelope (Pydantic)
- Idempotency cache (TTL 10min) with inflight dedup
- Deadline enforcement (DEADLINE_EXCEEDED on expired jobs)
- Concurrency limiter (semaphore, returns busy)
- Ollama + Swapper vision providers
Router offload (services/router/offload_client.py):
- NATS req/reply with configurable retries
- Circuit breaker per node+type (3 fails/60s → open 120s)
- Concurrency semaphore for remote requests
Model selection (services/router/model_select.py):
- exclude_nodes parameter for circuit-broken nodes
- force_local flag for fallback re-selection
- Integrated circuit breaker state awareness
Router /infer pipeline:
- Remote offload path when NCS selects remote node
- Automatic fallback: exclude failed node → force_local re-select
- Deadline propagation from router to node-worker
Tests: 17 unit tests (idempotency, deadline, circuit breaker)
Docs: ops/offload_routing.md (subjects, envelope, verification)
Made-with: Cursor
185 lines
6.6 KiB
Python
"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import time
|
|
from typing import Any, Dict
|
|
|
|
import config
|
|
from models import JobRequest, JobResponse, JobError
|
|
from idempotency import IdempotencyStore
|
|
from providers import ollama, swapper_vision
|
|
|
|
logger = logging.getLogger("node-worker")

# Module-wide singletons shared by every request handler in this worker.
_idem = IdempotencyStore()  # idempotency cache + inflight dedup (TTL lives inside the store)
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)  # caps concurrent provider calls
_nats_client = None  # set by start(); used by _reply() to publish responses
async def start(nats_client):
    """Wire this worker onto NATS: subscribe to its llm/vision request subjects.

    Stores the client in the module global so _reply() can publish responses.
    """
    global _nats_client
    _nats_client = nats_client

    node = config.NODE_ID.lower()
    for job_type in ("llm", "vision"):
        subject = f"node.{node}.{job_type}.request"
        await nats_client.subscribe(subject, cb=_handle_request)
        logger.info(f"✅ Subscribed: {subject}")
async def _handle_request(msg):
    """Handle one inbound job request from NATS.

    Pipeline: size guard -> parse/validate -> idempotency-cache hit ->
    deadline check -> inflight dedup -> concurrency gate -> execute.
    Every outcome (including internal errors) is answered via _reply().

    Fixes vs. previous version:
    - If an exception escaped after this coroutine became the inflight owner,
      the inflight future was never completed, leaving deduplicated peers
      hanging on it until their own timeout. The except path now completes
      the inflight entry with the error response.
    - The busy check used the private ``_semaphore._value``; ``locked()``
      alone already means no slot is available.
    """
    t0 = time.time()
    job = None            # parsed JobRequest, once available
    idem_key = None       # effective idempotency key, once computed
    owns_inflight = False  # True while we hold an unresolved inflight entry
    try:
        raw = msg.data
        # Reject oversized payloads before even attempting to parse them.
        if len(raw) > config.MAX_PAYLOAD_BYTES:
            await _reply(msg, JobResponse(
                status="error",
                node_id=config.NODE_ID,
                error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
            ))
            return

        data = json.loads(raw)
        job = JobRequest(**data)
        job.trace_id = job.trace_id or job.job_id

        # Serve a previously computed response for the same idempotency key.
        idem_key = job.effective_idem_key()
        cached = _idem.get(idem_key)
        if cached:
            logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
            await _reply(msg, cached)
            return

        remaining = job.remaining_ms()
        if remaining <= 0:
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="timeout",
                error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"),
            )
            # Cache the verdict so retries of the same key are answered cheaply.
            _idem.put(idem_key, resp)
            await _reply(msg, resp)
            return

        # None => we are now the inflight owner. A future => another coroutine
        # is already executing this key; piggyback on its result instead.
        inflight = await _idem.acquire_inflight(idem_key)
        if inflight is not None:
            try:
                resp = await asyncio.wait_for(inflight, timeout=remaining / 1000.0)
                # Copy before flagging as cached so the owner's object is untouched.
                resp_copy = resp.model_copy()
                resp_copy.cached = True
                await _reply(msg, resp_copy)
            except asyncio.TimeoutError:
                await _reply(msg, JobResponse(
                    job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                    status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"),
                ))
            return
        owns_inflight = True

        # locked() is True exactly when all MAX_CONCURRENCY slots are taken;
        # no need to peek at the private Semaphore._value attribute.
        if _semaphore.locked():
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="busy",
                error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"),
            )
            _idem.complete_inflight(idem_key, resp)
            owns_inflight = False
            await _reply(msg, resp)
            return

        async with _semaphore:
            resp = await _execute(job, remaining)

        # Set latency before caching so cache/inflight consumers see it too.
        resp.latency_ms = int((time.time() - t0) * 1000)
        _idem.put(idem_key, resp)
        _idem.complete_inflight(idem_key, resp)
        owns_inflight = False
        await _reply(msg, resp)

    except Exception as e:
        logger.exception(f"Worker handler error: {e}")
        ids = {"job_id": job.job_id, "trace_id": job.trace_id} if job is not None else {}
        err_resp = JobResponse(
            node_id=config.NODE_ID, status="error",
            error=JobError(code="INTERNAL", message=str(e)[:200]),
            **ids,
        )
        # Release any peers deduplicated onto our inflight future; otherwise
        # they would block until their own deadline expires.
        if owns_inflight and idem_key is not None:
            try:
                _idem.complete_inflight(idem_key, err_resp)
            except Exception:
                logger.warning(f"Failed to complete inflight for {idem_key}")
        try:
            await _reply(msg, err_resp)
        except Exception:
            pass  # best effort: the NATS connection may already be gone
async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
    """Run the job against the provider selected by ``job.required_type``.

    Args:
        job: validated request envelope.
        remaining_ms: milliseconds left before the caller's deadline.

    Returns:
        A JobResponse with status ok / timeout / error — never raises.

    Fix vs. previous version: the vision branch now resolves max_tokens and
    temperature the same way as the llm branch (hints -> payload -> default)
    instead of ignoring the payload; defaults (1024 / 0.2) are unchanged.
    """
    payload = job.payload
    hints = job.hints
    # Never run past the deadline, hard-capped at 120s per provider call.
    timeout_s = min(remaining_ms / 1000.0, 120.0)
    # Router model hints win over the payload's own model field.
    prefer = hints.get("prefer_models")
    model = prefer[0] if prefer else payload.get("model", "")

    msg_count = len(payload.get("messages", []))
    prompt_chars = len(payload.get("prompt", ""))
    logger.info(
        f"[job.start] job={job.job_id} trace={job.trace_id} "
        f"type={job.required_type} model={model or '?'} "
        f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
    )

    try:
        # Build the provider coroutine, then run it under one shared wait_for.
        if job.required_type == "llm":
            provider_coro = ollama.infer(
                messages=payload.get("messages"),
                prompt=payload.get("prompt", ""),
                model=model,
                system=payload.get("system", ""),
                max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
                temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                timeout_s=timeout_s,
            )
        elif job.required_type == "vision":
            provider_coro = swapper_vision.infer(
                images=payload.get("images"),
                prompt=payload.get("prompt", ""),
                model=model,
                system=payload.get("system", ""),
                # Same precedence as the llm branch; vision default stays 1024.
                max_tokens=hints.get("max_tokens", payload.get("max_tokens", 1024)),
                temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                timeout_s=timeout_s,
            )
        else:
            return JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="error",
                error=JobError(code="UNSUPPORTED_TYPE", message=f"{job.required_type} not implemented"),
            )

        result = await asyncio.wait_for(provider_coro, timeout=timeout_s)

        logger.info(
            f"[job.done] job={job.job_id} status=ok "
            f"provider={result.get('provider')} model={result.get('model')}"
        )
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="ok", provider=result.get("provider", ""), model=result.get("model", ""),
            result=result,
        )

    except asyncio.TimeoutError:
        logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="timeout", error=JobError(code="PROVIDER_TIMEOUT"),
        )
    except Exception as e:
        logger.warning(f"[job.error] job={job.job_id} {e}")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="error", error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]),
        )
async def _reply(msg, resp: JobResponse):
    """Publish the response to the message's reply subject, if it has one."""
    reply_subject = msg.reply
    if not reply_subject:
        # Fire-and-forget request: nothing to answer.
        return
    await _nats_client.publish(reply_subject, resp.model_dump_json().encode())