P2.2+P2.3: NATS offload node-worker + router offload integration

Node Worker (services/node-worker/):
- NATS subscriber for node.{NODE_ID}.llm.request / vision.request
- Canonical JobRequest/JobResponse envelope (Pydantic); see the caller sketch after this list
- Idempotency cache (TTL 10min) with inflight dedup
- Deadline enforcement (DEADLINE_EXCEEDED on expired jobs)
- Concurrency limiter (semaphore, returns busy)
- Ollama + Swapper vision providers
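
A minimal caller sketch (not part of this commit's diff), assuming the nats-py request/reply API; the subject pattern and envelope fields mirror worker.py and models.py below, and "noda2" is the NODE_ID default from config.py:

import asyncio
import json
import time
import nats

async def offload_llm(prompt: str, node_id: str = "noda2") -> dict:
    nc = await nats.connect("nats://dagi-nats:4222")
    try:
        req = {
            "required_type": "llm",
            "deadline_ts": int(time.time() * 1000) + 30_000,  # absolute epoch-ms deadline
            "payload": {"prompt": prompt},
        }
        msg = await nc.request(
            f"node.{node_id}.llm.request",
            json.dumps(req).encode(),
            timeout=30.0,
        )
        return json.loads(msg.data)  # JobResponse envelope
    finally:
        await nc.close()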

Router offload (services/router/offload_client.py):
- NATS req/reply with configurable retries
- Circuit breaker per node+type (3 fails/60s → open 120s; sketch after this list)
- Concurrency semaphore for remote requests
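
offload_client.py itself is not in this excerpt; a minimal breaker sketch matching the stated thresholds (3 failures within 60s open the circuit for 120s, keyed by node and request type):

import time
from collections import defaultdict, deque

FAIL_THRESHOLD = 3
FAIL_WINDOW_S = 60.0
OPEN_DURATION_S = 120.0

class CircuitBreaker:
    def __init__(self):
        self._failures = defaultdict(deque)  # (node, type) -> recent failure timestamps
        self._open_until = {}                # (node, type) -> time the circuit closes again

    def is_open(self, node: str, req_type: str) -> bool:
        return time.time() < self._open_until.get((node, req_type), 0.0)

    def record_failure(self, node: str, req_type: str) -> None:
        key, now = (node, req_type), time.time()
        q = self._failures[key]
        q.append(now)
        while q and now - q[0] > FAIL_WINDOW_S:  # keep only failures inside the window
            q.popleft()
        if len(q) >= FAIL_THRESHOLD:
            self._open_until[key] = now + OPEN_DURATION_S
            q.clear()

    def record_success(self, node: str, req_type: str) -> None:
        self._failures.pop((node, req_type), None)
        self._open_until.pop((node, req_type), None)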

Model selection (services/router/model_select.py):
- exclude_nodes parameter for circuit-broken nodes
- force_local flag for fallback re-selection (toy sketch after this list)
- Integrated circuit breaker state awareness
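
model_select.py is not shown here; a toy sketch of the selection contract these bullets imply (candidates are assumed pre-sorted by NCS score, and "local" is a hypothetical placeholder for the local node):

from typing import Iterable, Optional, Set

def select_node(
    candidates: Iterable[str],
    required_type: str,  # kept for signature parity; this toy ignores it
    exclude_nodes: Optional[Set[str]] = None,
    force_local: bool = False,
    local_node: str = "local",
) -> str:
    exclude_nodes = exclude_nodes or set()
    if force_local:
        return local_node  # fallback re-selection pins execution locally
    for node in candidates:
        if node not in exclude_nodes:  # skip circuit-broken nodes
            return node
    return local_node  # nothing eligible remotely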

Router /infer pipeline:
- Remote offload path when NCS selects remote node
- Automatic fallback: exclude failed node → force_local re-select (flow sketch after this list)
- Deadline propagation from router to node-worker
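
A flow sketch of that fallback path, reusing the two sketches above; offload() and run_local() are hypothetical stand-ins for the offload client and the local inference path:

import time

async def infer_with_fallback(payload: dict, candidates: list, breaker, offload, run_local):
    deadline_ts = int(time.time() * 1000) + 30_000  # propagated to the node-worker as-is
    excluded = {n for n in candidates if breaker.is_open(n, "llm")}
    node = select_node(candidates, "llm", exclude_nodes=excluded)
    if node != "local":
        try:
            return await offload(node, payload, deadline_ts)
        except Exception:
            breaker.record_failure(node, "llm")
            # Re-select with the failed node excluded and force_local set,
            # which collapses to the local path in this toy.
            node = select_node(candidates, "llm", exclude_nodes=excluded | {node}, force_local=True)
    return await run_local(payload)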

Tests: 17 unit tests (idempotency, deadline, circuit breaker)
Docs: ops/offload_routing.md (subjects, envelope, verification)
Made-with: Cursor
Author: Apple
Date: 2026-02-27 02:44:05 -08:00
Commit: c4b94a327d (parent: a92c424845)
19 changed files with 1075 additions and 6 deletions

services/node-worker/Dockerfile

@@ -0,0 +1,7 @@
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8109
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8109"]

services/node-worker/config.py

@@ -0,0 +1,12 @@
"""Node-worker configuration from environment."""
import os
NODE_ID = os.getenv("NODE_ID", "noda2")
NATS_URL = os.getenv("NATS_URL", "nats://dagi-nats:4222")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
SWAPPER_URL = os.getenv("SWAPPER_URL", "http://swapper-service:8890")
DEFAULT_LLM = os.getenv("NODE_DEFAULT_LLM", "qwen3:14b")
DEFAULT_VISION = os.getenv("NODE_DEFAULT_VISION", "llava:13b")
MAX_CONCURRENCY = int(os.getenv("NODE_WORKER_MAX_CONCURRENCY", "2"))
MAX_PAYLOAD_BYTES = int(os.getenv("NODE_WORKER_MAX_PAYLOAD_BYTES", str(1024 * 1024)))
PORT = int(os.getenv("PORT", "8109"))

services/node-worker/idempotency.py

@@ -0,0 +1,62 @@
"""Idempotency cache + inflight dedup for job execution."""
import asyncio
import logging
import time
from typing import Dict, Optional, Tuple
from models import JobResponse
logger = logging.getLogger("idempotency")
CACHE_TTL = 600 # 10 min for successful results
TIMEOUT_TTL = 30 # 30s for timeout results
class IdempotencyStore:
def __init__(self, max_size: int = 10_000):
self._cache: Dict[str, Tuple[JobResponse, float]] = {}
self._inflight: Dict[str, asyncio.Future] = {}
self._max_size = max_size
def get(self, key: str) -> Optional[JobResponse]:
entry = self._cache.get(key)
if not entry:
return None
resp, expires = entry
if time.time() > expires:
self._cache.pop(key, None)
return None
cached = resp.model_copy()
cached.cached = True
return cached
def put(self, key: str, resp: JobResponse):
ttl = TIMEOUT_TTL if resp.status == "timeout" else CACHE_TTL
self._cache[key] = (resp, time.time() + ttl)
self._evict_if_needed()
def _evict_if_needed(self):
if len(self._cache) <= self._max_size:
return
now = time.time()
expired = [k for k, (_, exp) in self._cache.items() if now > exp]
for k in expired:
self._cache.pop(k, None)
if len(self._cache) > self._max_size:
oldest = sorted(self._cache, key=lambda k: self._cache[k][1])
for k in oldest[:len(self._cache) - self._max_size]:
self._cache.pop(k, None)
async def acquire_inflight(self, key: str) -> Optional[asyncio.Future]:
"""If another coroutine is already processing this key, return its future.
Otherwise register this coroutine as the processor and return None."""
if key in self._inflight:
return self._inflight[key]
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
self._inflight[key] = fut
return None
def complete_inflight(self, key: str, resp: JobResponse):
fut = self._inflight.pop(key, None)
if fut and not fut.done():
fut.set_result(resp)
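
A usage sketch for the store (in the spirit of the unit tests the commit message mentions; the real tests are not shown in this excerpt):

import asyncio
from idempotency import IdempotencyStore
from models import JobResponse

async def demo():
    store = IdempotencyStore()
    assert await store.acquire_inflight("k1") is None   # first caller becomes the processor
    fut = await store.acquire_inflight("k1")            # concurrent duplicate gets the future
    assert fut is not None
    resp = JobResponse(job_id="j1", status="ok")
    store.put("k1", resp)
    store.complete_inflight("k1", resp)
    assert (await fut).status == "ok"
    assert store.get("k1").cached is True               # later hits are flagged as cached

asyncio.run(demo())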

services/node-worker/main.py

@@ -0,0 +1,53 @@
"""Node Worker — NATS offload executor for cross-node inference."""
import logging
import os
from fastapi import FastAPI
import config
import worker
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("node-worker")
app = FastAPI(title="Node Worker", version="1.0.0")
_nats_client = None
@app.get("/healthz")
async def healthz():
    connected = _nats_client is not None and _nats_client.is_connected
return {
"status": "ok" if connected else "degraded",
"node_id": config.NODE_ID,
"nats_connected": connected,
"max_concurrency": config.MAX_CONCURRENCY,
}
@app.on_event("startup")
async def startup():
global _nats_client
try:
import nats as nats_lib
_nats_client = await nats_lib.connect(config.NATS_URL)
logger.info(f"✅ NATS connected: {config.NATS_URL}")
await worker.start(_nats_client)
logger.info(f"✅ Node Worker ready: node={config.NODE_ID} concurrency={config.MAX_CONCURRENCY}")
except Exception as e:
logger.error(f"❌ Startup failed: {e}")
@app.on_event("shutdown")
async def shutdown():
if _nats_client:
try:
await _nats_client.close()
except Exception:
pass
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=config.PORT)

services/node-worker/models.py

@@ -0,0 +1,47 @@
"""Canonical job envelope for cross-node inference offload."""
from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel, Field
import time
import uuid
def _ulid() -> str:
    # NOTE: despite the name, this returns a UUID4; a real ULID would be time-sortable.
    return str(uuid.uuid4())
class JobRequest(BaseModel):
job_id: str = Field(default_factory=_ulid)
trace_id: str = ""
actor_agent_id: str = ""
target_agent_id: str = ""
required_type: Literal["llm", "vision", "stt", "tts"] = "llm"
deadline_ts: int = 0
idempotency_key: str = ""
payload: Dict[str, Any] = Field(default_factory=dict)
hints: Dict[str, Any] = Field(default_factory=dict)
def remaining_ms(self) -> int:
if self.deadline_ts <= 0:
return 30_000
return max(0, self.deadline_ts - int(time.time() * 1000))
def effective_idem_key(self) -> str:
return self.idempotency_key or self.job_id
class JobError(BaseModel):
code: str = "UNKNOWN"
message: str = ""
class JobResponse(BaseModel):
job_id: str = ""
trace_id: str = ""
node_id: str = ""
status: Literal["ok", "busy", "timeout", "error"] = "ok"
provider: str = ""
model: str = ""
latency_ms: int = 0
result: Optional[Dict[str, Any]] = None
error: Optional[JobError] = None
cached: bool = False
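
A small example of the envelope in use (deadline_ts is an absolute epoch-ms timestamp; remaining_ms() is what the worker checks before executing):

import time
from models import JobRequest

req = JobRequest(
    required_type="llm",
    deadline_ts=int(time.time() * 1000) + 5_000,  # 5s budget
    payload={"prompt": "ping"},
)
assert 0 < req.remaining_ms() <= 5_000
assert req.effective_idem_key() == req.job_id  # falls back to job_id when no explicit key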

services/node-worker/providers/ollama.py

@@ -0,0 +1,81 @@
"""Ollama LLM provider for node-worker."""
import logging
from typing import Any, Dict, List, Optional
import httpx
from config import OLLAMA_BASE_URL, DEFAULT_LLM
logger = logging.getLogger("provider.ollama")
async def infer(
messages: Optional[List[Dict[str, str]]] = None,
prompt: str = "",
model: str = "",
system: str = "",
max_tokens: int = 2048,
temperature: float = 0.2,
timeout_s: float = 25.0,
) -> Dict[str, Any]:
model = model or DEFAULT_LLM
if messages:
return await _chat(messages, model, max_tokens, temperature, timeout_s)
return await _generate(prompt, system, model, max_tokens, temperature, timeout_s)
async def _chat(
messages: List[Dict[str, str]],
model: str,
max_tokens: int,
temperature: float,
timeout_s: float,
) -> Dict[str, Any]:
async with httpx.AsyncClient(timeout=timeout_s) as c:
resp = await c.post(
f"{OLLAMA_BASE_URL}/api/chat",
json={
"model": model,
"messages": messages,
"stream": False,
"options": {"num_predict": max_tokens, "temperature": temperature},
},
)
resp.raise_for_status()
data = resp.json()
return {
"text": data.get("message", {}).get("content", ""),
"model": model,
"provider": "ollama",
"eval_count": data.get("eval_count", 0),
}
async def _generate(
prompt: str,
system: str,
model: str,
max_tokens: int,
temperature: float,
timeout_s: float,
) -> Dict[str, Any]:
async with httpx.AsyncClient(timeout=timeout_s) as c:
resp = await c.post(
f"{OLLAMA_BASE_URL}/api/generate",
json={
"model": model,
"prompt": prompt,
"system": system,
"stream": False,
"options": {"num_predict": max_tokens, "temperature": temperature},
},
)
resp.raise_for_status()
data = resp.json()
return {
"text": data.get("response", ""),
"model": model,
"provider": "ollama",
"eval_count": data.get("eval_count", 0),
}
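
A usage sketch: infer() routes to /api/chat when messages are given and to /api/generate otherwise (requires a reachable Ollama instance at OLLAMA_BASE_URL with the target model pulled):

import asyncio
from providers import ollama

async def demo():
    out = await ollama.infer(
        messages=[{"role": "user", "content": "Say hi in one word."}],
        max_tokens=16,
    )
    print(out["text"], out["eval_count"])

asyncio.run(demo())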

services/node-worker/providers/swapper_vision.py

@@ -0,0 +1,42 @@
"""Swapper vision provider for node-worker."""
import logging
from typing import Any, Dict, List, Optional
import httpx
from config import SWAPPER_URL, DEFAULT_VISION
logger = logging.getLogger("provider.swapper_vision")
async def infer(
images: Optional[List[str]] = None,
prompt: str = "",
model: str = "",
system: str = "",
max_tokens: int = 1024,
temperature: float = 0.2,
timeout_s: float = 60.0,
) -> Dict[str, Any]:
model = model or DEFAULT_VISION
payload: Dict[str, Any] = {
"model": model,
"prompt": prompt or "Describe this image.",
"max_tokens": max_tokens,
"temperature": temperature,
}
if images:
payload["images"] = images
if system:
payload["system"] = system
async with httpx.AsyncClient(timeout=timeout_s) as c:
resp = await c.post(f"{SWAPPER_URL}/vision", json=payload)
resp.raise_for_status()
data = resp.json()
return {
"text": data.get("text", data.get("response", "")),
"model": model,
"provider": "swapper_vision",
}

services/node-worker/requirements.txt

@@ -0,0 +1,5 @@
fastapi>=0.110.0
uvicorn>=0.29.0
httpx>=0.27.0
nats-py>=2.7.0
pydantic>=2.5.0

services/node-worker/worker.py

@@ -0,0 +1,184 @@
"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
import asyncio
import json
import logging
import time
import config
from models import JobRequest, JobResponse, JobError
from idempotency import IdempotencyStore
from providers import ollama, swapper_vision
logger = logging.getLogger("node-worker")
_idem = IdempotencyStore()
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)
_nats_client = None
async def start(nats_client):
global _nats_client
_nats_client = nats_client
subjects = [
f"node.{config.NODE_ID.lower()}.llm.request",
f"node.{config.NODE_ID.lower()}.vision.request",
]
for subj in subjects:
await nats_client.subscribe(subj, cb=_handle_request)
logger.info(f"✅ Subscribed: {subj}")
async def _handle_request(msg):
t0 = time.time()
try:
raw = msg.data
if len(raw) > config.MAX_PAYLOAD_BYTES:
await _reply(msg, JobResponse(
status="error",
node_id=config.NODE_ID,
error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
))
return
data = json.loads(raw)
job = JobRequest(**data)
job.trace_id = job.trace_id or job.job_id
idem_key = job.effective_idem_key()
cached = _idem.get(idem_key)
if cached:
logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
await _reply(msg, cached)
return
remaining = job.remaining_ms()
if remaining <= 0:
resp = JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="timeout",
error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"),
)
_idem.put(idem_key, resp)
await _reply(msg, resp)
return
inflight = await _idem.acquire_inflight(idem_key)
if inflight is not None:
try:
                # Shield the shared future so a timed-out waiter doesn't cancel it for everyone.
                resp = await asyncio.wait_for(asyncio.shield(inflight), timeout=remaining / 1000.0)
resp_copy = resp.model_copy()
resp_copy.cached = True
await _reply(msg, resp_copy)
except asyncio.TimeoutError:
await _reply(msg, JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"),
))
return
        # Semaphore.locked() is True exactly when no permits remain; no need to touch private _value.
        if _semaphore.locked():
resp = JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="busy",
error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"),
)
_idem.complete_inflight(idem_key, resp)
await _reply(msg, resp)
return
        async with _semaphore:
            resp = await _execute(job, remaining)
        # Stamp latency before caching/completing so duplicate waiters see the real value.
        resp.latency_ms = int((time.time() - t0) * 1000)
        _idem.put(idem_key, resp)
        _idem.complete_inflight(idem_key, resp)
        await _reply(msg, resp)
except Exception as e:
logger.exception(f"Worker handler error: {e}")
try:
await _reply(msg, JobResponse(
node_id=config.NODE_ID, status="error",
error=JobError(code="INTERNAL", message=str(e)[:200]),
))
except Exception:
pass
async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
payload = job.payload
hints = job.hints
timeout_s = min(remaining_ms / 1000.0, 120.0)
model = hints.get("prefer_models", [None])[0] if hints.get("prefer_models") else payload.get("model", "")
msg_count = len(payload.get("messages", []))
prompt_chars = len(payload.get("prompt", ""))
logger.info(
f"[job.start] job={job.job_id} trace={job.trace_id} "
f"type={job.required_type} model={model or '?'} "
f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
)
try:
if job.required_type == "llm":
result = await asyncio.wait_for(
ollama.infer(
messages=payload.get("messages"),
prompt=payload.get("prompt", ""),
model=model,
system=payload.get("system", ""),
max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
temperature=hints.get("temperature", payload.get("temperature", 0.2)),
timeout_s=timeout_s,
),
timeout=timeout_s,
)
elif job.required_type == "vision":
result = await asyncio.wait_for(
swapper_vision.infer(
images=payload.get("images"),
prompt=payload.get("prompt", ""),
model=model,
system=payload.get("system", ""),
max_tokens=hints.get("max_tokens", 1024),
temperature=hints.get("temperature", 0.2),
timeout_s=timeout_s,
),
timeout=timeout_s,
)
else:
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="error",
error=JobError(code="UNSUPPORTED_TYPE", message=f"{job.required_type} not implemented"),
)
logger.info(
f"[job.done] job={job.job_id} status=ok "
f"provider={result.get('provider')} model={result.get('model')}"
)
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="ok", provider=result.get("provider", ""), model=result.get("model", ""),
result=result,
)
except asyncio.TimeoutError:
logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="timeout", error=JobError(code="PROVIDER_TIMEOUT"),
)
except Exception as e:
logger.warning(f"[job.error] job={job.job_id} {e}")
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="error", error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]),
)
async def _reply(msg, resp: JobResponse):
if msg.reply:
await _nats_client.publish(msg.reply, resp.model_dump_json().encode())