P2.2+P2.3: NATS offload node-worker + router offload integration
Node Worker (services/node-worker/):
- NATS subscriber for node.{NODE_ID}.llm.request / vision.request
- Canonical JobRequest/JobResponse envelope (Pydantic)
- Idempotency cache (TTL 10min) with inflight dedup
- Deadline enforcement (DEADLINE_EXCEEDED on expired jobs)
- Concurrency limiter (semaphore, returns busy)
- Ollama + Swapper vision providers
Router offload (services/router/offload_client.py):
- NATS req/reply with configurable retries
- Circuit breaker per node+type (3 fails/60s → open 120s)
- Concurrency semaphore for remote requests
Model selection (services/router/model_select.py):
- exclude_nodes parameter for circuit-broken nodes
- force_local flag for fallback re-selection
- Integrated circuit breaker state awareness
Router /infer pipeline:
- Remote offload path when NCS selects remote node
- Automatic fallback: exclude failed node → force_local re-select
- Deadline propagation from router to node-worker
Tests: 17 unit tests (idempotency, deadline, circuit breaker)
Docs: ops/offload_routing.md (subjects, envelope, verification)
Made-with: Cursor
This commit is contained in:
184
services/node-worker/worker.py
Normal file
184
services/node-worker/worker.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
import config
|
||||
from models import JobRequest, JobResponse, JobError
|
||||
from idempotency import IdempotencyStore
|
||||
from providers import ollama, swapper_vision
|
||||
|
||||
logger = logging.getLogger("node-worker")

# Shared module-level worker state.
# Response cache with inflight dedup (see IdempotencyStore for TTL details).
_idem = IdempotencyStore()
# Caps concurrently executing jobs at config.MAX_CONCURRENCY; requests arriving
# while the counter is exhausted are rejected as "busy" in _handle_request.
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)
# NATS connection used for replies; assigned by start() before any subscription
# can deliver a message.
_nats_client = None
||||
async def start(nats_client):
    """Attach the worker to NATS.

    Stores the client for later replies, then subscribes this node to its
    llm and vision request subjects, routing messages to _handle_request.
    """
    global _nats_client
    _nats_client = nats_client

    node = config.NODE_ID.lower()
    for job_type in ("llm", "vision"):
        subj = f"node.{node}.{job_type}.request"
        await nats_client.subscribe(subj, cb=_handle_request)
        logger.info(f"✅ Subscribed: {subj}")
|
||||
|
||||
|
||||
async def _handle_request(msg):
    """Handle one job request delivered over NATS.

    Pipeline: payload-size guard -> parse/validate -> idempotency cache hit ->
    deadline check -> inflight dedup -> concurrency guard -> execute -> reply.
    Every path replies at most once via _reply; unexpected failures produce a
    best-effort INTERNAL error response and never crash the subscriber.
    """
    t0 = time.time()
    try:
        raw = msg.data
        # Reject oversized payloads before spending time parsing them.
        if len(raw) > config.MAX_PAYLOAD_BYTES:
            await _reply(msg, JobResponse(
                status="error",
                node_id=config.NODE_ID,
                error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
            ))
            return

        data = json.loads(raw)
        job = JobRequest(**data)
        job.trace_id = job.trace_id or job.job_id

        # Idempotency: redeliveries of an already-completed job get the
        # cached response back without re-executing.
        idem_key = job.effective_idem_key()
        cached = _idem.get(idem_key)
        if cached:
            logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
            await _reply(msg, cached)
            return

        # Deadline enforcement: refuse jobs whose deadline already passed,
        # and cache the timeout so redeliveries are answered consistently.
        remaining = job.remaining_ms()
        if remaining <= 0:
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="timeout",
                error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"),
            )
            _idem.put(idem_key, resp)
            await _reply(msg, resp)
            return

        # Inflight dedup: if an identical job is already executing, await its
        # result (bounded by our own remaining deadline) instead of running
        # the work twice.
        inflight = await _idem.acquire_inflight(idem_key)
        if inflight is not None:
            try:
                resp = await asyncio.wait_for(inflight, timeout=remaining / 1000.0)
                resp_copy = resp.model_copy()
                resp_copy.cached = True
                await _reply(msg, resp_copy)
            except asyncio.TimeoutError:
                await _reply(msg, JobResponse(
                    job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                    status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"),
                ))
            return

        # Concurrency guard. asyncio.Semaphore.locked() is True exactly when
        # the internal counter is zero, so the original extra probe of the
        # private `_semaphore._value` attribute was redundant (and fragile
        # across Python versions) — dropped.
        # NOTE(review): check-then-acquire is inherently racy between handler
        # tasks (matches the original behavior); worst case a request waits
        # briefly on the semaphore instead of being rejected as busy.
        if _semaphore.locked():
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="busy",
                error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"),
            )
            # Resolve the inflight future so duplicate waiters also see busy.
            _idem.complete_inflight(idem_key, resp)
            await _reply(msg, resp)
            return

        async with _semaphore:
            resp = await _execute(job, remaining)

        _idem.put(idem_key, resp)
        _idem.complete_inflight(idem_key, resp)
        resp.latency_ms = int((time.time() - t0) * 1000)
        await _reply(msg, resp)

    except Exception as e:
        # Last-resort guard: log with traceback and attempt a generic error
        # reply (job_id/trace_id may be unknown if parsing itself failed).
        logger.exception(f"Worker handler error: {e}")
        try:
            await _reply(msg, JobResponse(
                node_id=config.NODE_ID, status="error",
                error=JobError(code="INTERNAL", message=str(e)[:200]),
            ))
        except Exception:
            pass
|
||||
|
||||
|
||||
async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
    """Dispatch the job to its provider and return a JobResponse.

    The provider call is bounded by the job's remaining deadline (capped at
    120s). Returns status ok / timeout / error; never raises.
    """
    payload = job.payload
    hints = job.hints
    # Provider budget: whatever deadline remains, but never more than 120s.
    timeout_s = min(remaining_ms / 1000.0, 120.0)

    # First preferred model from hints wins; otherwise the payload's model.
    preferred = hints.get("prefer_models")
    model = preferred[0] if preferred else payload.get("model", "")

    n_msgs = len(payload.get("messages", []))
    n_prompt_chars = len(payload.get("prompt", ""))
    logger.info(
        f"[job.start] job={job.job_id} trace={job.trace_id} "
        f"type={job.required_type} model={model or '?'} "
        f"msgs={n_msgs} chars={n_prompt_chars} deadline_rem={remaining_ms}ms"
    )

    try:
        # Build the provider coroutine for the requested type, then await it
        # under a single wait_for guard.
        if job.required_type == "llm":
            provider_call = ollama.infer(
                messages=payload.get("messages"),
                prompt=payload.get("prompt", ""),
                model=model,
                system=payload.get("system", ""),
                max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
                temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                timeout_s=timeout_s,
            )
        elif job.required_type == "vision":
            provider_call = swapper_vision.infer(
                images=payload.get("images"),
                prompt=payload.get("prompt", ""),
                model=model,
                system=payload.get("system", ""),
                max_tokens=hints.get("max_tokens", 1024),
                temperature=hints.get("temperature", 0.2),
                timeout_s=timeout_s,
            )
        else:
            return JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="error",
                error=JobError(code="UNSUPPORTED_TYPE", message=f"{job.required_type} not implemented"),
            )

        result = await asyncio.wait_for(provider_call, timeout=timeout_s)

        logger.info(
            f"[job.done] job={job.job_id} status=ok "
            f"provider={result.get('provider')} model={result.get('model')}"
        )
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="ok", provider=result.get("provider", ""), model=result.get("model", ""),
            result=result,
        )

    except asyncio.TimeoutError:
        logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="timeout", error=JobError(code="PROVIDER_TIMEOUT"),
        )
    except Exception as e:
        logger.warning(f"[job.error] job={job.job_id} {e}")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="error", error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]),
        )
|
||||
|
||||
|
||||
async def _reply(msg, resp: JobResponse):
|
||||
if msg.reply:
|
||||
await _nats_client.publish(msg.reply, resp.model_dump_json().encode())
|
||||
Reference in New Issue
Block a user