P2.2+P2.3: NATS offload node-worker + router offload integration

Node Worker (services/node-worker/):
- NATS subscriber for node.{NODE_ID}.llm.request / vision.request
- Canonical JobRequest/JobResponse envelope (Pydantic)
- Idempotency cache (TTL 10min) with inflight dedup
- Deadline enforcement (DEADLINE_EXCEEDED on expired jobs)
- Concurrency limiter (semaphore, returns busy)
- Ollama + Swapper vision providers

Router offload (services/router/offload_client.py):
- NATS req/reply with configurable retries
- Circuit breaker per node+type (3 fails/60s → open 120s)
- Concurrency semaphore for remote requests

Model selection (services/router/model_select.py):
- exclude_nodes parameter for circuit-broken nodes
- force_local flag for fallback re-selection
- Integrated circuit breaker state awareness

Router /infer pipeline:
- Remote offload path when NCS selects remote node
- Automatic fallback: exclude failed node → force_local re-select
- Deadline propagation from router to node-worker

Tests: 17 unit tests (idempotency, deadline, circuit breaker)
Docs: ops/offload_routing.md (subjects, envelope, verification)
Made-with: Cursor
This commit is contained in:
Apple
2026-02-27 02:44:05 -08:00
parent a92c424845
commit c4b94a327d
19 changed files with 1075 additions and 6 deletions

View File

@@ -51,11 +51,13 @@ try:
import capabilities_client
import global_capabilities_client
from model_select import select_model_for_agent, ModelSelection, CLOUD_PROVIDERS as NCS_CLOUD_PROVIDERS
import offload_client
NCS_AVAILABLE = True
except ImportError:
NCS_AVAILABLE = False
capabilities_client = None # type: ignore[assignment]
global_capabilities_client = None # type: ignore[assignment]
offload_client = None # type: ignore[assignment]
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -1707,6 +1709,76 @@ async def agent_infer(agent_id: str, request: InferRequest):
f"provider={provider} model={model}"
)
# =========================================================================
# REMOTE OFFLOAD (if model selected on remote node)
# =========================================================================
nats_client_available = nc is not None and nats_available
if (ncs_selection and ncs_selection.via_nats and not ncs_selection.local
and nats_client_available and offload_client and nc):
infer_timeout = int(os.getenv("ROUTER_INFER_TIMEOUT_MS", "25000"))
import uuid as _uuid
job_payload = {
"job_id": str(_uuid.uuid4()),
"trace_id": str(_uuid.uuid4()),
"actor_agent_id": request_agent_id or agent_id,
"target_agent_id": agent_id,
"required_type": ncs_selection.model_type if ncs_selection.model_type != "code" else "llm",
"deadline_ts": int(time.time() * 1000) + infer_timeout,
"idempotency_key": str(_uuid.uuid4()),
"payload": {
"prompt": request.prompt,
"messages": [{"role": "system", "content": system_prompt}] if system_prompt else [],
"model": ncs_selection.name,
"max_tokens": request.max_tokens or 2048,
"temperature": request.temperature or 0.2,
},
"hints": {"prefer_models": [ncs_selection.name]},
}
if request.images:
job_payload["payload"]["images"] = request.images
job_payload["required_type"] = "vision"
job_payload["payload"]["messages"].append({"role": "user", "content": request.prompt})
offload_resp = await offload_client.offload_infer(
nats_client=nc,
node_id=ncs_selection.node,
required_type=job_payload["required_type"],
job_payload=job_payload,
timeout_ms=infer_timeout,
)
if offload_resp and offload_resp.get("status") == "ok":
result_text = offload_resp.get("result", {}).get("text", "")
return InferResponse(
response=result_text,
model=f"{offload_resp.get('model', ncs_selection.name)}@{ncs_selection.node}",
backend=f"nats-offload:{ncs_selection.node}",
tokens_used=offload_resp.get("result", {}).get("eval_count", 0),
)
else:
offload_status = offload_resp.get("status", "none") if offload_resp else "no_reply"
logger.warning(
f"[fallback] offload to {ncs_selection.node} failed ({offload_status}) "
f"→ re-selecting with exclude={ncs_selection.node}, force_local"
)
try:
gcaps = await global_capabilities_client.get_global_capabilities()
ncs_selection = await select_model_for_agent(
agent_id, agent_config, router_config, gcaps, request.model,
exclude_nodes={ncs_selection.node}, force_local=True,
)
if ncs_selection and ncs_selection.name:
provider = ncs_selection.provider
model = ncs_selection.name
llm_profile = router_config.get("llm_profiles", {}).get(default_llm, {})
if ncs_selection.base_url and provider == "ollama":
llm_profile = {**llm_profile, "base_url": ncs_selection.base_url}
logger.info(
f"[fallback.reselect] → local node={ncs_selection.node} "
f"model={model} provider={provider}"
)
except Exception as e:
logger.warning(f"[fallback.reselect] error: {e}; proceeding with static")
# =========================================================================
# VISION PROCESSING (if images present)
# =========================================================================

View File

@@ -9,7 +9,7 @@ Scaling: works with 1 node or 150+. No static node lists.
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Set
logger = logging.getLogger("model_select")
@@ -110,6 +110,7 @@ def profile_requirements(
def select_best_model(
reqs: ProfileRequirements,
capabilities: Dict[str, Any],
exclude_nodes: Optional[Set[str]] = None,
) -> Optional[ModelSelection]:
"""Choose the best served model from global (multi-node) capabilities.
@@ -117,18 +118,25 @@ def select_best_model(
1. Prefer list matches (local first, then remote)
2. Best candidate by size (local first, then remote)
3. None → caller should try static fallback
exclude_nodes: set of node_ids to skip (e.g. circuit-broken nodes).
"""
served = capabilities.get("served_models", [])
if not served:
return None
exclude = exclude_nodes or set()
search_types = [reqs.required_type]
if reqs.required_type == "code":
search_types.append("llm")
if reqs.required_type == "llm":
search_types.append("code")
candidates = [m for m in served if m.get("type") in search_types]
candidates = [
m for m in served
if m.get("type") in search_types and m.get("node", "") not in exclude
]
if not candidates:
return None
@@ -218,15 +226,21 @@ async def select_model_for_agent(
router_cfg: Dict[str, Any],
capabilities: Optional[Dict[str, Any]],
request_model: Optional[str] = None,
exclude_nodes: Optional[Set[str]] = None,
force_local: bool = False,
) -> ModelSelection:
"""Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default."""
"""Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default.
exclude_nodes: skip these nodes (circuit-broken). Used on fallback re-selection.
force_local: prefer local-only models (fallback after remote failure).
"""
profile = resolve_effective_profile(
agent_id, agent_cfg, router_cfg, request_model,
)
reqs = profile_requirements(profile, agent_cfg, router_cfg)
if reqs.required_type == "cloud_llm":
if reqs.required_type == "cloud_llm" and not force_local:
static = static_fallback(profile, router_cfg)
if static:
static.fallback_reason = ""
@@ -236,14 +250,31 @@ async def select_model_for_agent(
)
return static
excl = set(exclude_nodes) if exclude_nodes else set()
try:
from offload_client import get_unavailable_nodes
cb_nodes = get_unavailable_nodes(reqs.required_type)
excl |= cb_nodes
if cb_nodes:
logger.info(f"[select] circuit-broken nodes for {reqs.required_type}: {cb_nodes}")
except ImportError:
pass
if capabilities and capabilities.get("served_models"):
sel = select_best_model(reqs, capabilities)
sel = select_best_model(reqs, capabilities, exclude_nodes=excl)
if force_local and sel and not sel.local:
sel = select_best_model(
reqs, capabilities,
exclude_nodes=excl | {n.get("node", "") for n in capabilities.get("served_models", []) if not n.get("local")},
)
if sel:
logger.info(
f"[select] agent={agent_id} profile={profile}"
f"{'NCS' if sel.local else 'REMOTE'} "
f"{'LOCAL' if sel.local else 'REMOTE'} "
f"node={sel.node} runtime={sel.runtime} "
f"model={sel.name} caps_age={sel.caps_age_s}s"
f"{' (force_local)' if force_local else ''}"
f"{' (excluded: ' + ','.join(excl) + ')' if excl else ''}"
)
return sel
logger.warning(

View File

@@ -0,0 +1,153 @@
"""NATS offload client — sends inference requests to remote nodes with
circuit breaker, retries, and deadline enforcement."""
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, Literal, Optional, Set

logger = logging.getLogger("offload_client")

# Circuit-breaker tuning (all overridable via environment):
# the breaker opens after CB_FAILS failures within a CB_WINDOW_S-second
# rolling window, and stays open for CB_OPEN_S seconds.
CB_FAILS = int(os.getenv("ROUTER_OFFLOAD_CB_FAILS", "3"))
CB_WINDOW_S = int(os.getenv("ROUTER_OFFLOAD_CB_WINDOW_S", "60"))
CB_OPEN_S = int(os.getenv("ROUTER_OFFLOAD_CB_OPEN_S", "120"))
# Number of extra attempts after the first request (see offload_infer).
MAX_RETRIES = int(os.getenv("ROUTER_OFFLOAD_RETRIES", "1"))
# Upper bound on concurrent in-flight remote requests.
MAX_CONCURRENCY = int(os.getenv("ROUTER_OFFLOAD_MAX_CONCURRENCY_REMOTE", "8"))

# Lazily-created semaphore guarding remote concurrency (see _get_semaphore).
_semaphore: Optional[asyncio.Semaphore] = None
# Breaker state keyed by "node_id:req_type":
# {"failures": [timestamps...], "open_until": epoch seconds (0 = closed)}.
_breakers: Dict[str, Dict[str, Any]] = {}
def _get_semaphore() -> asyncio.Semaphore:
    """Return the shared remote-concurrency semaphore, creating it on first use."""
    global _semaphore
    if _semaphore is not None:
        return _semaphore
    _semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
    return _semaphore
def _breaker_key(node_id: str, req_type: str) -> str:
return f"{node_id}:{req_type}"
def is_node_available(node_id: str, req_type: str) -> bool:
    """Return True when the breaker for (node_id, req_type) is not blocking.

    An open breaker (``open_until`` in the future) blocks. A breaker whose
    open period has elapsed allows traffic again (half-open probe). Otherwise
    the failure list is pruned to the rolling window and the node is
    available while fewer than CB_FAILS failures remain.
    """
    state = _breakers.get(_breaker_key(node_id, req_type))
    if not state:
        # No recorded failures at all → available.
        return True
    open_until = state.get("open_until", 0)
    if open_until > time.time():
        # Breaker is currently open.
        return False
    if open_until > 0:
        # Open period elapsed: let requests through again.
        return True
    # Closed breaker: drop failures older than the window, count the rest.
    cutoff = time.time() - CB_WINDOW_S
    fresh = [ts for ts in state.get("failures", []) if ts > cutoff]
    state["failures"] = fresh
    return len(fresh) < CB_FAILS
def record_failure(node_id: str, req_type: str) -> None:
    """Record one failed offload; open the breaker once CB_FAILS accumulate.

    Failures outside the CB_WINDOW_S rolling window are pruned on every call,
    so only a dense burst of failures trips the breaker.
    """
    key = _breaker_key(node_id, req_type)
    state = _breakers.setdefault(key, {"failures": [], "open_until": 0})
    state["failures"].append(time.time())
    # Keep only failures inside the rolling window.
    cutoff = time.time() - CB_WINDOW_S
    kept = [ts for ts in state["failures"] if ts > cutoff]
    state["failures"] = kept
    if len(kept) >= CB_FAILS:
        state["open_until"] = time.time() + CB_OPEN_S
        logger.warning(f"Circuit OPEN: {key} ({len(kept)} failures in {CB_WINDOW_S}s, open for {CB_OPEN_S}s)")
def record_success(node_id: str, req_type: str) -> None:
    """Reset breaker state for (node_id, req_type) after a successful offload."""
    state = _breakers.get(_breaker_key(node_id, req_type))
    if state:
        # Clear the failure history and close the breaker.
        state["failures"] = []
        state["open_until"] = 0
def get_unavailable_nodes(req_type: str) -> Set[str]:
    """Return the node ids whose circuit breaker blocks *req_type* requests.

    Used by model selection to exclude circuit-broken nodes up front.
    """
    suffix = f":{req_type}"
    blocked: Set[str] = set()
    for key in _breakers:
        if not key.endswith(suffix):
            continue
        node_id = key.rsplit(":", 1)[0]
        if not is_node_available(node_id, req_type):
            blocked.add(node_id)
    return blocked
async def offload_infer(
    nats_client,
    node_id: str,
    required_type: Literal["llm", "vision", "stt", "tts"],
    job_payload: Dict[str, Any],
    timeout_ms: int = 25000,
) -> Optional[Dict[str, Any]]:
    """Send inference job to remote node via NATS request/reply.

    Returns parsed JobResponse dict or None on total failure.
    Retries on transient errors (timeout, busy). Does NOT retry on provider errors.

    Args:
        nats_client: connected NATS client exposing an async ``request()``.
        node_id: target node id; lowercased when building the subject.
        required_type: job kind; selects the node-side subject suffix.
        job_payload: JSON-serializable JobRequest envelope.
        timeout_ms: total budget across all attempts — each retry is charged
            the elapsed time of the previous attempt.
    """
    subject = f"node.{node_id.lower()}.{required_type}.request"
    payload_bytes = json.dumps(job_payload).encode()
    sem = _get_semaphore()
    # attempt 0 is the initial request; up to MAX_RETRIES extra attempts follow.
    for attempt in range(1 + MAX_RETRIES):
        timeout_s = timeout_ms / 1000.0
        if timeout_s <= 0:
            # Earlier attempts consumed the whole budget — give up.
            logger.warning(f"[offload] deadline exhausted before attempt {attempt}")
            return None
        t0 = time.time()
        try:
            # Semaphore caps the number of concurrent remote requests.
            async with sem:
                logger.info(
                    f"[offload.start] node={node_id} subj={subject} "
                    f"timeout={timeout_ms}ms attempt={attempt}"
                )
                msg = await nats_client.request(subject, payload_bytes, timeout=timeout_s)
                resp = json.loads(msg.data)
                latency = int((time.time() - t0) * 1000)
                status = resp.get("status", "error")
                if status == "ok":
                    # Success resets the circuit breaker for this node+type.
                    record_success(node_id, required_type)
                    logger.info(
                        f"[offload.done] node={node_id} status=ok latency={latency}ms "
                        f"provider={resp.get('provider')} model={resp.get('model')} "
                        f"cached={resp.get('cached', False)}"
                    )
                    return resp
                if status in ("timeout", "busy") and attempt < MAX_RETRIES:
                    # Transient node-side condition: charge the elapsed time to
                    # the budget and retry WITHOUT recording a breaker failure.
                    elapsed = int((time.time() - t0) * 1000)
                    timeout_ms -= elapsed
                    logger.warning(f"[offload.retry] node={node_id} status={status} → retry {attempt+1}")
                    continue
                # Non-retryable error (or retries exhausted): count it against
                # the breaker and return the node's error response to the caller.
                record_failure(node_id, required_type)
                logger.warning(
                    f"[offload.fail] node={node_id} status={status} "
                    f"error={resp.get('error', {}).get('code', '?')}"
                )
                return resp
        except asyncio.TimeoutError:
            # No reply within timeout_s: counts as a breaker failure even
            # when a retry follows, since the node never answered.
            record_failure(node_id, required_type)
            elapsed = int((time.time() - t0) * 1000)
            timeout_ms -= elapsed
            if attempt < MAX_RETRIES:
                logger.warning(f"[offload.timeout] node={node_id} {elapsed}ms → retry {attempt+1}")
                continue
            logger.warning(f"[offload.timeout] node={node_id} all retries exhausted")
            return None
        except Exception as e:
            # Anything else (connection error, malformed JSON reply, ...) is
            # treated as non-retryable: record the failure and bail out.
            record_failure(node_id, required_type)
            logger.warning(f"[offload.error] node={node_id} {e}")
            return None
    return None