P2.2+P2.3: NATS offload node-worker + router offload integration

Node Worker (services/node-worker/):
- NATS subscriber for node.{NODE_ID}.llm.request / vision.request
- Canonical JobRequest/JobResponse envelope (Pydantic)
- Idempotency cache (TTL 10min) with inflight dedup
- Deadline enforcement (DEADLINE_EXCEEDED on expired jobs)
- Concurrency limiter (semaphore, returns busy)
- Providers: Ollama (LLM), Swapper (vision)

Router offload (services/router/offload_client.py):
- NATS req/reply with configurable retries
- Circuit breaker per node+type (3 fails/60s → open 120s)
- Concurrency semaphore for remote requests

Model selection (services/router/model_select.py):
- exclude_nodes parameter for circuit-broken nodes
- force_local flag for fallback re-selection
- Integrated circuit breaker state awareness

Router /infer pipeline:
- Remote offload path when NCS selects remote node
- Automatic fallback: exclude failed node → force_local re-select
- Deadline propagation from router to node-worker

Tests: 17 unit tests (idempotency, deadline, circuit breaker)
Docs: ops/offload_routing.md (subjects, envelope, verification)
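
Quick verification sketch (assumes nats-py, the default NODE_ID "noda2", and
NATS reachable at dagi-nats:4222; ops/offload_routing.md has the full steps):

    import asyncio, json, time, uuid
    import nats

    async def probe():
        nc = await nats.connect("nats://dagi-nats:4222")
        job = {
            "job_id": str(uuid.uuid4()),
            "required_type": "llm",
            "deadline_ts": int(time.time() * 1000) + 25_000,  # 25s deadline
            "payload": {"prompt": "ping", "max_tokens": 16},
        }
        msg = await nc.request("node.noda2.llm.request",
                               json.dumps(job).encode(), timeout=30)
        print(json.loads(msg.data)["status"])  # expect "ok"
        await nc.close()

    asyncio.run(probe())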
Made-with: Cursor
Apple · 2026-02-27 02:44:05 -08:00
parent a92c424845 · commit c4b94a327d
19 changed files with 1075 additions and 6 deletions

services/node-worker/Dockerfile

@@ -0,0 +1,7 @@
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8109
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8109"]

services/node-worker/config.py

@@ -0,0 +1,12 @@
"""Node-worker configuration from environment."""
import os
NODE_ID = os.getenv("NODE_ID", "noda2")
NATS_URL = os.getenv("NATS_URL", "nats://dagi-nats:4222")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
SWAPPER_URL = os.getenv("SWAPPER_URL", "http://swapper-service:8890")
DEFAULT_LLM = os.getenv("NODE_DEFAULT_LLM", "qwen3:14b")
DEFAULT_VISION = os.getenv("NODE_DEFAULT_VISION", "llava:13b")
MAX_CONCURRENCY = int(os.getenv("NODE_WORKER_MAX_CONCURRENCY", "2"))
MAX_PAYLOAD_BYTES = int(os.getenv("NODE_WORKER_MAX_PAYLOAD_BYTES", str(1024 * 1024)))
PORT = int(os.getenv("PORT", "8109"))

services/node-worker/idempotency.py

@@ -0,0 +1,62 @@
"""Idempotency cache + inflight dedup for job execution."""
import asyncio
import logging
import time
from typing import Dict, Optional, Tuple
from models import JobResponse
logger = logging.getLogger("idempotency")
CACHE_TTL = 600 # 10 min for successful results
TIMEOUT_TTL = 30 # 30s for timeout results
class IdempotencyStore:
def __init__(self, max_size: int = 10_000):
self._cache: Dict[str, Tuple[JobResponse, float]] = {}
self._inflight: Dict[str, asyncio.Future] = {}
self._max_size = max_size
def get(self, key: str) -> Optional[JobResponse]:
entry = self._cache.get(key)
if not entry:
return None
resp, expires = entry
if time.time() > expires:
self._cache.pop(key, None)
return None
cached = resp.model_copy()
cached.cached = True
return cached
def put(self, key: str, resp: JobResponse):
ttl = TIMEOUT_TTL if resp.status == "timeout" else CACHE_TTL
self._cache[key] = (resp, time.time() + ttl)
self._evict_if_needed()
def _evict_if_needed(self):
if len(self._cache) <= self._max_size:
return
now = time.time()
expired = [k for k, (_, exp) in self._cache.items() if now > exp]
for k in expired:
self._cache.pop(k, None)
if len(self._cache) > self._max_size:
oldest = sorted(self._cache, key=lambda k: self._cache[k][1])
for k in oldest[:len(self._cache) - self._max_size]:
self._cache.pop(k, None)
async def acquire_inflight(self, key: str) -> Optional[asyncio.Future]:
"""If another coroutine is already processing this key, return its future.
Otherwise register this coroutine as the processor and return None."""
if key in self._inflight:
return self._inflight[key]
fut: asyncio.Future = asyncio.get_running_loop().create_future()
self._inflight[key] = fut
return None
def complete_inflight(self, key: str, resp: JobResponse):
fut = self._inflight.pop(key, None)
if fut and not fut.done():
fut.set_result(resp)
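
Usage sketch of the store's contract (hypothetical handler; the real wiring
lives in worker.py below):

    store = IdempotencyStore()

    async def handle(key: str, run_job) -> JobResponse:
        if (hit := store.get(key)) is not None:
            return hit                        # replay a cached result
        if (fut := await store.acquire_inflight(key)) is not None:
            return await fut                  # another coroutine owns the job
        resp = await run_job()                # we won the race: execute
        store.put(key, resp)                  # cache for later replays
        store.complete_inflight(key, resp)    # wake any duplicate waiters
        return resp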

services/node-worker/main.py

@@ -0,0 +1,53 @@
"""Node Worker — NATS offload executor for cross-node inference."""
import logging
import os
from fastapi import FastAPI
import config
import worker
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("node-worker")
app = FastAPI(title="Node Worker", version="1.0.0")
_nats_client = None
@app.get("/healthz")
async def healthz():
connected = _nats_client is not None and _nats_client.is_connected
return {
"status": "ok" if connected else "degraded",
"node_id": config.NODE_ID,
"nats_connected": connected,
"max_concurrency": config.MAX_CONCURRENCY,
}
@app.on_event("startup")
async def startup():
global _nats_client
try:
import nats as nats_lib
_nats_client = await nats_lib.connect(config.NATS_URL)
logger.info(f"✅ NATS connected: {config.NATS_URL}")
await worker.start(_nats_client)
logger.info(f"✅ Node Worker ready: node={config.NODE_ID} concurrency={config.MAX_CONCURRENCY}")
except Exception as e:
logger.error(f"❌ Startup failed: {e}")
@app.on_event("shutdown")
async def shutdown():
if _nats_client:
try:
await _nats_client.close()
except Exception:
pass
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=config.PORT)

services/node-worker/models.py

@@ -0,0 +1,47 @@
"""Canonical job envelope for cross-node inference offload."""
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
import time
import uuid
def _ulid() -> str:
    # uuid4 stands in for a true ULID; the envelope only needs uniqueness
    return str(uuid.uuid4())
class JobRequest(BaseModel):
job_id: str = Field(default_factory=_ulid)
trace_id: str = ""
actor_agent_id: str = ""
target_agent_id: str = ""
required_type: Literal["llm", "vision", "stt", "tts"] = "llm"
deadline_ts: int = 0
idempotency_key: str = ""
payload: Dict[str, Any] = Field(default_factory=dict)
hints: Dict[str, Any] = Field(default_factory=dict)
def remaining_ms(self) -> int:
if self.deadline_ts <= 0:
return 30_000
return max(0, self.deadline_ts - int(time.time() * 1000))
def effective_idem_key(self) -> str:
return self.idempotency_key or self.job_id
class JobError(BaseModel):
code: str = "UNKNOWN"
message: str = ""
class JobResponse(BaseModel):
job_id: str = ""
trace_id: str = ""
node_id: str = ""
status: Literal["ok", "busy", "timeout", "error"] = "ok"
provider: str = ""
model: str = ""
latency_ms: int = 0
result: Optional[Dict[str, Any]] = None
error: Optional[JobError] = None
cached: bool = False
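
Illustrative envelopes as they travel over NATS (all values invented for the
example):

    request = {
        "job_id": "7d9c4a1e-5b2f-4c8e-9f1a-0d3b6e7a8c9d",
        "required_type": "llm",
        "deadline_ts": 1772188800000,         # absolute deadline, epoch ms
        "idempotency_key": "agent42:prompt-hash",
        "payload": {"prompt": "Summarize the report.", "max_tokens": 256},
        "hints": {"prefer_models": ["qwen3:14b"]},
    }
    response = {
        "job_id": "7d9c4a1e-5b2f-4c8e-9f1a-0d3b6e7a8c9d",
        "node_id": "noda2", "status": "ok",
        "provider": "ollama", "model": "qwen3:14b",
        "latency_ms": 1840,
        "result": {"text": "...", "eval_count": 212},
    }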

services/node-worker/providers/ollama.py

@@ -0,0 +1,81 @@
"""Ollama LLM provider for node-worker."""
import logging
from typing import Any, Dict, List, Optional
import httpx
from config import OLLAMA_BASE_URL, DEFAULT_LLM
logger = logging.getLogger("provider.ollama")
async def infer(
messages: Optional[List[Dict[str, str]]] = None,
prompt: str = "",
model: str = "",
system: str = "",
max_tokens: int = 2048,
temperature: float = 0.2,
timeout_s: float = 25.0,
) -> Dict[str, Any]:
model = model or DEFAULT_LLM
if messages:
    # guard: a system-only message list would otherwise drop the prompt
    if prompt and not any(m.get("role") == "user" for m in messages):
        messages = messages + [{"role": "user", "content": prompt}]
    return await _chat(messages, model, max_tokens, temperature, timeout_s)
return await _generate(prompt, system, model, max_tokens, temperature, timeout_s)
async def _chat(
messages: List[Dict[str, str]],
model: str,
max_tokens: int,
temperature: float,
timeout_s: float,
) -> Dict[str, Any]:
async with httpx.AsyncClient(timeout=timeout_s) as c:
resp = await c.post(
f"{OLLAMA_BASE_URL}/api/chat",
json={
"model": model,
"messages": messages,
"stream": False,
"options": {"num_predict": max_tokens, "temperature": temperature},
},
)
resp.raise_for_status()
data = resp.json()
return {
"text": data.get("message", {}).get("content", ""),
"model": model,
"provider": "ollama",
"eval_count": data.get("eval_count", 0),
}
async def _generate(
prompt: str,
system: str,
model: str,
max_tokens: int,
temperature: float,
timeout_s: float,
) -> Dict[str, Any]:
async with httpx.AsyncClient(timeout=timeout_s) as c:
resp = await c.post(
f"{OLLAMA_BASE_URL}/api/generate",
json={
"model": model,
"prompt": prompt,
"system": system,
"stream": False,
"options": {"num_predict": max_tokens, "temperature": temperature},
},
)
resp.raise_for_status()
data = resp.json()
return {
"text": data.get("response", ""),
"model": model,
"provider": "ollama",
"eval_count": data.get("eval_count", 0),
}

services/node-worker/providers/swapper_vision.py

@@ -0,0 +1,42 @@
"""Swapper vision provider for node-worker."""
import logging
from typing import Any, Dict, List, Optional
import httpx
from config import SWAPPER_URL, DEFAULT_VISION
logger = logging.getLogger("provider.swapper_vision")
async def infer(
images: Optional[List[str]] = None,
prompt: str = "",
model: str = "",
system: str = "",
max_tokens: int = 1024,
temperature: float = 0.2,
timeout_s: float = 60.0,
) -> Dict[str, Any]:
model = model or DEFAULT_VISION
payload: Dict[str, Any] = {
"model": model,
"prompt": prompt or "Describe this image.",
"max_tokens": max_tokens,
"temperature": temperature,
}
if images:
payload["images"] = images
if system:
payload["system"] = system
async with httpx.AsyncClient(timeout=timeout_s) as c:
resp = await c.post(f"{SWAPPER_URL}/vision", json=payload)
resp.raise_for_status()
data = resp.json()
return {
"text": data.get("text", data.get("response", "")),
"model": model,
"provider": "swapper_vision",
}

services/node-worker/requirements.txt

@@ -0,0 +1,5 @@
fastapi>=0.110.0
uvicorn>=0.29.0
httpx>=0.27.0
nats-py>=2.7.0
pydantic>=2.5.0

services/node-worker/worker.py

@@ -0,0 +1,184 @@
"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
import asyncio
import json
import logging
import time
from typing import Any, Dict
import config
from models import JobRequest, JobResponse, JobError
from idempotency import IdempotencyStore
from providers import ollama, swapper_vision
logger = logging.getLogger("node-worker")
_idem = IdempotencyStore()
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)
_nats_client = None
async def start(nats_client):
global _nats_client
_nats_client = nats_client
subjects = [
f"node.{config.NODE_ID.lower()}.llm.request",
f"node.{config.NODE_ID.lower()}.vision.request",
]
for subj in subjects:
await nats_client.subscribe(subj, cb=_handle_request)
logger.info(f"✅ Subscribed: {subj}")
async def _handle_request(msg):
t0 = time.time()
try:
raw = msg.data
if len(raw) > config.MAX_PAYLOAD_BYTES:
await _reply(msg, JobResponse(
status="error",
node_id=config.NODE_ID,
error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
))
return
data = json.loads(raw)
job = JobRequest(**data)
job.trace_id = job.trace_id or job.job_id
idem_key = job.effective_idem_key()
cached = _idem.get(idem_key)
if cached:
logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
await _reply(msg, cached)
return
remaining = job.remaining_ms()
if remaining <= 0:
resp = JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="timeout",
error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"),
)
_idem.put(idem_key, resp)
await _reply(msg, resp)
return
inflight = await _idem.acquire_inflight(idem_key)
if inflight is not None:
try:
resp = await asyncio.wait_for(inflight, timeout=remaining / 1000.0)
resp_copy = resp.model_copy()
resp_copy.cached = True
await _reply(msg, resp_copy)
except asyncio.TimeoutError:
await _reply(msg, JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"),
))
return
if _semaphore.locked():  # every slot taken; shed load so the caller can fail over
resp = JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="busy",
error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"),
)
_idem.complete_inflight(idem_key, resp)
await _reply(msg, resp)
return
async with _semaphore:
resp = await _execute(job, remaining)
_idem.put(idem_key, resp)
_idem.complete_inflight(idem_key, resp)
resp.latency_ms = int((time.time() - t0) * 1000)
await _reply(msg, resp)
except Exception as e:
logger.exception(f"Worker handler error: {e}")
try:
await _reply(msg, JobResponse(
node_id=config.NODE_ID, status="error",
error=JobError(code="INTERNAL", message=str(e)[:200]),
))
except Exception:
pass
async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
payload = job.payload
hints = job.hints
timeout_s = min(remaining_ms / 1000.0, 120.0)
model = hints.get("prefer_models", [None])[0] if hints.get("prefer_models") else payload.get("model", "")
msg_count = len(payload.get("messages", []))
prompt_chars = len(payload.get("prompt", ""))
logger.info(
f"[job.start] job={job.job_id} trace={job.trace_id} "
f"type={job.required_type} model={model or '?'} "
f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
)
try:
if job.required_type == "llm":
result = await asyncio.wait_for(
ollama.infer(
messages=payload.get("messages"),
prompt=payload.get("prompt", ""),
model=model,
system=payload.get("system", ""),
max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
temperature=hints.get("temperature", payload.get("temperature", 0.2)),
timeout_s=timeout_s,
),
timeout=timeout_s,
)
elif job.required_type == "vision":
result = await asyncio.wait_for(
swapper_vision.infer(
images=payload.get("images"),
prompt=payload.get("prompt", ""),
model=model,
system=payload.get("system", ""),
max_tokens=hints.get("max_tokens", 1024),
temperature=hints.get("temperature", 0.2),
timeout_s=timeout_s,
),
timeout=timeout_s,
)
else:
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="error",
error=JobError(code="UNSUPPORTED_TYPE", message=f"{job.required_type} not implemented"),
)
logger.info(
f"[job.done] job={job.job_id} status=ok "
f"provider={result.get('provider')} model={result.get('model')}"
)
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="ok", provider=result.get("provider", ""), model=result.get("model", ""),
result=result,
)
except asyncio.TimeoutError:
logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="timeout", error=JobError(code="PROVIDER_TIMEOUT"),
)
except Exception as e:
logger.warning(f"[job.error] job={job.job_id} {e}")
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="error", error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]),
)
async def _reply(msg, resp: JobResponse):
if msg.reply:
await _nats_client.publish(msg.reply, resp.model_dump_json().encode())

services/router (main service file, /infer pipeline)

@@ -51,11 +51,13 @@ try:
import capabilities_client
import global_capabilities_client
from model_select import select_model_for_agent, ModelSelection, CLOUD_PROVIDERS as NCS_CLOUD_PROVIDERS
import offload_client
NCS_AVAILABLE = True
except ImportError:
NCS_AVAILABLE = False
capabilities_client = None # type: ignore[assignment]
global_capabilities_client = None # type: ignore[assignment]
offload_client = None # type: ignore[assignment]
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -1707,6 +1709,76 @@ async def agent_infer(agent_id: str, request: InferRequest):
f"provider={provider} model={model}"
)
# =========================================================================
# REMOTE OFFLOAD (if model selected on remote node)
# =========================================================================
nats_client_available = nc is not None and nats_available
if (ncs_selection and ncs_selection.via_nats and not ncs_selection.local
and nats_client_available and offload_client and nc):
infer_timeout = int(os.getenv("ROUTER_INFER_TIMEOUT_MS", "25000"))
import uuid as _uuid
job_payload = {
"job_id": str(_uuid.uuid4()),
"trace_id": str(_uuid.uuid4()),
"actor_agent_id": request_agent_id or agent_id,
"target_agent_id": agent_id,
"required_type": ncs_selection.model_type if ncs_selection.model_type != "code" else "llm",
"deadline_ts": int(time.time() * 1000) + infer_timeout,
"idempotency_key": str(_uuid.uuid4()),
"payload": {
"prompt": request.prompt,
"messages": [{"role": "system", "content": system_prompt}] if system_prompt else [],
"model": ncs_selection.name,
"max_tokens": request.max_tokens or 2048,
"temperature": request.temperature or 0.2,
},
"hints": {"prefer_models": [ncs_selection.name]},
}
if request.images:
job_payload["payload"]["images"] = request.images
job_payload["required_type"] = "vision"
job_payload["payload"]["messages"].append({"role": "user", "content": request.prompt})
offload_resp = await offload_client.offload_infer(
nats_client=nc,
node_id=ncs_selection.node,
required_type=job_payload["required_type"],
job_payload=job_payload,
timeout_ms=infer_timeout,
)
if offload_resp and offload_resp.get("status") == "ok":
result_text = offload_resp.get("result", {}).get("text", "")
return InferResponse(
response=result_text,
model=f"{offload_resp.get('model', ncs_selection.name)}@{ncs_selection.node}",
backend=f"nats-offload:{ncs_selection.node}",
tokens_used=offload_resp.get("result", {}).get("eval_count", 0),
)
else:
offload_status = offload_resp.get("status", "none") if offload_resp else "no_reply"
logger.warning(
f"[fallback] offload to {ncs_selection.node} failed ({offload_status}) "
f"→ re-selecting with exclude={ncs_selection.node}, force_local"
)
try:
gcaps = await global_capabilities_client.get_global_capabilities()
ncs_selection = await select_model_for_agent(
agent_id, agent_config, router_config, gcaps, request.model,
exclude_nodes={ncs_selection.node}, force_local=True,
)
if ncs_selection and ncs_selection.name:
provider = ncs_selection.provider
model = ncs_selection.name
llm_profile = router_config.get("llm_profiles", {}).get(default_llm, {})
if ncs_selection.base_url and provider == "ollama":
llm_profile = {**llm_profile, "base_url": ncs_selection.base_url}
logger.info(
f"[fallback.reselect] → local node={ncs_selection.node} "
f"model={model} provider={provider}"
)
except Exception as e:
logger.warning(f"[fallback.reselect] error: {e}; proceeding with static")
# =========================================================================
# VISION PROCESSING (if images present)
# =========================================================================

services/router/model_select.py

@@ -9,7 +9,7 @@ Scaling: works with 1 node or 150+. No static node lists.
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
logger = logging.getLogger("model_select")
@@ -110,6 +110,7 @@ def profile_requirements(
def select_best_model(
reqs: ProfileRequirements,
capabilities: Dict[str, Any],
exclude_nodes: Optional[Set[str]] = None,
) -> Optional[ModelSelection]:
"""Choose the best served model from global (multi-node) capabilities.
@@ -117,18 +118,25 @@ def select_best_model(
1. Prefer list matches (local first, then remote)
2. Best candidate by size (local first, then remote)
3. None → caller should try static fallback
exclude_nodes: set of node_ids to skip (e.g. circuit-broken nodes).
"""
served = capabilities.get("served_models", [])
if not served:
return None
exclude = exclude_nodes or set()
search_types = [reqs.required_type]
if reqs.required_type == "code":
search_types.append("llm")
if reqs.required_type == "llm":
search_types.append("code")
candidates = [
m for m in served
if m.get("type") in search_types and m.get("node", "") not in exclude
]
if not candidates:
return None
@@ -218,15 +226,21 @@ async def select_model_for_agent(
router_cfg: Dict[str, Any],
capabilities: Optional[Dict[str, Any]],
request_model: Optional[str] = None,
exclude_nodes: Optional[Set[str]] = None,
force_local: bool = False,
) -> ModelSelection:
"""Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default."""
"""Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default.
exclude_nodes: skip these nodes (circuit-broken). Used on fallback re-selection.
force_local: prefer local-only models (fallback after remote failure).
"""
profile = resolve_effective_profile(
agent_id, agent_cfg, router_cfg, request_model,
)
reqs = profile_requirements(profile, agent_cfg, router_cfg)
if reqs.required_type == "cloud_llm":
if reqs.required_type == "cloud_llm" and not force_local:
static = static_fallback(profile, router_cfg)
if static:
static.fallback_reason = ""
@@ -236,14 +250,31 @@ async def select_model_for_agent(
)
return static
excl = set(exclude_nodes) if exclude_nodes else set()
try:
from offload_client import get_unavailable_nodes
cb_nodes = get_unavailable_nodes(reqs.required_type)
excl |= cb_nodes
if cb_nodes:
logger.info(f"[select] circuit-broken nodes for {reqs.required_type}: {cb_nodes}")
except ImportError:
pass
if capabilities and capabilities.get("served_models"):
sel = select_best_model(reqs, capabilities, exclude_nodes=excl)
if force_local and sel and not sel.local:
sel = select_best_model(
reqs, capabilities,
exclude_nodes=excl | {n.get("node", "") for n in capabilities.get("served_models", []) if not n.get("local")},
)
if sel:
logger.info(
f"[select] agent={agent_id} profile={profile}"
f"{'NCS' if sel.local else 'REMOTE'} "
f"{'LOCAL' if sel.local else 'REMOTE'} "
f"node={sel.node} runtime={sel.runtime} "
f"model={sel.name} caps_age={sel.caps_age_s}s"
f"{' (force_local)' if force_local else ''}"
f"{' (excluded: ' + ','.join(excl) + ')' if excl else ''}"
)
return sel
logger.warning(

services/router/offload_client.py

@@ -0,0 +1,153 @@
"""NATS offload client — sends inference requests to remote nodes with
circuit breaker, retries, and deadline enforcement."""
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, Literal, Optional, Set
logger = logging.getLogger("offload_client")
CB_FAILS = int(os.getenv("ROUTER_OFFLOAD_CB_FAILS", "3"))
CB_WINDOW_S = int(os.getenv("ROUTER_OFFLOAD_CB_WINDOW_S", "60"))
CB_OPEN_S = int(os.getenv("ROUTER_OFFLOAD_CB_OPEN_S", "120"))
MAX_RETRIES = int(os.getenv("ROUTER_OFFLOAD_RETRIES", "1"))
MAX_CONCURRENCY = int(os.getenv("ROUTER_OFFLOAD_MAX_CONCURRENCY_REMOTE", "8"))
_semaphore: Optional[asyncio.Semaphore] = None
_breakers: Dict[str, Dict[str, Any]] = {}
def _get_semaphore() -> asyncio.Semaphore:
global _semaphore
if _semaphore is None:
_semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
return _semaphore
def _breaker_key(node_id: str, req_type: str) -> str:
return f"{node_id}:{req_type}"
def is_node_available(node_id: str, req_type: str) -> bool:
key = _breaker_key(node_id, req_type)
state = _breakers.get(key)
if not state:
return True
open_until = state.get("open_until", 0)
if open_until > time.time():
return False
if open_until > 0 and open_until <= time.time():
return True  # open window elapsed; half-open, allow probe traffic
now = time.time()
window_start = now - CB_WINDOW_S
recent = [t for t in state.get("failures", []) if t > window_start]
state["failures"] = recent
return len(recent) < CB_FAILS
def record_failure(node_id: str, req_type: str):
key = _breaker_key(node_id, req_type)
state = _breakers.setdefault(key, {"failures": [], "open_until": 0})
state["failures"].append(time.time())
window_start = time.time() - CB_WINDOW_S
recent = [t for t in state["failures"] if t > window_start]
state["failures"] = recent
if len(recent) >= CB_FAILS:
state["open_until"] = time.time() + CB_OPEN_S
logger.warning(f"Circuit OPEN: {key} ({len(recent)} failures in {CB_WINDOW_S}s, open for {CB_OPEN_S}s)")
def record_success(node_id: str, req_type: str):
key = _breaker_key(node_id, req_type)
state = _breakers.get(key)
if state:
state["failures"] = []
state["open_until"] = 0
def get_unavailable_nodes(req_type: str) -> Set[str]:
result = set()
for key, state in _breakers.items():
if not key.endswith(f":{req_type}"):
continue
nid = key.rsplit(":", 1)[0]
if not is_node_available(nid, req_type):
result.add(nid)
return result
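
# Worked example of the breaker life cycle with the defaults above
# (CB_FAILS=3, CB_WINDOW_S=60, CB_OPEN_S=120); times relative to t=0:
#
#   t=0, 20, 50   record_failure("noda2", "llm")   # third failure inside 60s
#   t=50..170     is_node_available() -> False     # circuit open for 120s
#   t>170         is_node_available() -> True      # half-open: probes allowed
#
# A single record_success() clears the failure window and closes the breaker.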
async def offload_infer(
nats_client,
node_id: str,
required_type: Literal["llm", "vision", "stt", "tts"],
job_payload: Dict[str, Any],
timeout_ms: int = 25000,
) -> Optional[Dict[str, Any]]:
"""Send inference job to remote node via NATS request/reply.
Returns parsed JobResponse dict or None on total failure.
Retries on transient errors (timeout, busy). Does NOT retry on provider errors.
"""
subject = f"node.{node_id.lower()}.{required_type}.request"
payload_bytes = json.dumps(job_payload).encode()
sem = _get_semaphore()
for attempt in range(1 + MAX_RETRIES):
timeout_s = timeout_ms / 1000.0
if timeout_s <= 0:
logger.warning(f"[offload] deadline exhausted before attempt {attempt}")
return None
t0 = time.time()
try:
async with sem:
logger.info(
f"[offload.start] node={node_id} subj={subject} "
f"timeout={timeout_ms}ms attempt={attempt}"
)
msg = await nats_client.request(subject, payload_bytes, timeout=timeout_s)
resp = json.loads(msg.data)
latency = int((time.time() - t0) * 1000)
status = resp.get("status", "error")
if status == "ok":
record_success(node_id, required_type)
logger.info(
f"[offload.done] node={node_id} status=ok latency={latency}ms "
f"provider={resp.get('provider')} model={resp.get('model')} "
f"cached={resp.get('cached', False)}"
)
return resp
if status in ("timeout", "busy") and attempt < MAX_RETRIES:
elapsed = int((time.time() - t0) * 1000)
timeout_ms -= elapsed
logger.warning(f"[offload.retry] node={node_id} status={status} → retry {attempt+1}")
continue
record_failure(node_id, required_type)
logger.warning(
f"[offload.fail] node={node_id} status={status} "
f"error={resp.get('error', {}).get('code', '?')}"
)
return resp
except asyncio.TimeoutError:
record_failure(node_id, required_type)
elapsed = int((time.time() - t0) * 1000)
timeout_ms -= elapsed
if attempt < MAX_RETRIES:
logger.warning(f"[offload.timeout] node={node_id} {elapsed}ms → retry {attempt+1}")
continue
logger.warning(f"[offload.timeout] node={node_id} all retries exhausted")
return None
except Exception as e:
record_failure(node_id, required_type)
logger.warning(f"[offload.error] node={node_id} {e}")
return None
return None
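
Caller-side sketch (illustrative glue, not part of this commit; the real call
site is the /infer pipeline above):

    async def try_remote(nc, selection, job_payload):
        # skip nodes whose breaker is currently open
        if not is_node_available(selection.node, "llm"):
            return None
        resp = await offload_infer(
            nats_client=nc, node_id=selection.node,
            required_type="llm", job_payload=job_payload,
            timeout_ms=25_000,
        )
        # treat busy/timeout/error like a miss so the caller can fall back
        return resp if resp and resp.get("status") == "ok" else None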