P2.2+P2.3: NATS offload node-worker + router offload integration
Node Worker (services/node-worker/):
- NATS subscriber for node.{NODE_ID}.llm.request / vision.request
- Canonical JobRequest/JobResponse envelope (Pydantic)
- Idempotency cache (TTL 10min) with inflight dedup
- Deadline enforcement (DEADLINE_EXCEEDED on expired jobs)
- Concurrency limiter (semaphore, returns busy)
- Ollama + Swapper vision providers
Router offload (services/router/offload_client.py):
- NATS req/reply with configurable retries
- Circuit breaker per node+type (3 fails/60s → open 120s)
- Concurrency semaphore for remote requests
Model selection (services/router/model_select.py):
- exclude_nodes parameter for circuit-broken nodes
- force_local flag for fallback re-selection
- Integrated circuit breaker state awareness
Router /infer pipeline:
- Remote offload path when NCS selects remote node
- Automatic fallback: exclude failed node → force_local re-select
- Deadline propagation from router to node-worker
Tests: 17 unit tests (idempotency, deadline, circuit breaker)
Docs: ops/offload_routing.md (subjects, envelope, verification)
Made-with: Cursor
This commit is contained in:
@@ -51,11 +51,13 @@ try:
|
||||
import capabilities_client
|
||||
import global_capabilities_client
|
||||
from model_select import select_model_for_agent, ModelSelection, CLOUD_PROVIDERS as NCS_CLOUD_PROVIDERS
|
||||
import offload_client
|
||||
NCS_AVAILABLE = True
|
||||
except ImportError:
|
||||
NCS_AVAILABLE = False
|
||||
capabilities_client = None # type: ignore[assignment]
|
||||
global_capabilities_client = None # type: ignore[assignment]
|
||||
offload_client = None # type: ignore[assignment]
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -1707,6 +1709,76 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
f"provider={provider} model={model}"
|
||||
)
|
||||
|
||||
# =========================================================================
|
||||
# REMOTE OFFLOAD (if model selected on remote node)
|
||||
# =========================================================================
|
||||
nats_client_available = nc is not None and nats_available
|
||||
if (ncs_selection and ncs_selection.via_nats and not ncs_selection.local
|
||||
and nats_client_available and offload_client and nc):
|
||||
infer_timeout = int(os.getenv("ROUTER_INFER_TIMEOUT_MS", "25000"))
|
||||
import uuid as _uuid
|
||||
job_payload = {
|
||||
"job_id": str(_uuid.uuid4()),
|
||||
"trace_id": str(_uuid.uuid4()),
|
||||
"actor_agent_id": request_agent_id or agent_id,
|
||||
"target_agent_id": agent_id,
|
||||
"required_type": ncs_selection.model_type if ncs_selection.model_type != "code" else "llm",
|
||||
"deadline_ts": int(time.time() * 1000) + infer_timeout,
|
||||
"idempotency_key": str(_uuid.uuid4()),
|
||||
"payload": {
|
||||
"prompt": request.prompt,
|
||||
"messages": [{"role": "system", "content": system_prompt}] if system_prompt else [],
|
||||
"model": ncs_selection.name,
|
||||
"max_tokens": request.max_tokens or 2048,
|
||||
"temperature": request.temperature or 0.2,
|
||||
},
|
||||
"hints": {"prefer_models": [ncs_selection.name]},
|
||||
}
|
||||
if request.images:
|
||||
job_payload["payload"]["images"] = request.images
|
||||
job_payload["required_type"] = "vision"
|
||||
job_payload["payload"]["messages"].append({"role": "user", "content": request.prompt})
|
||||
|
||||
offload_resp = await offload_client.offload_infer(
|
||||
nats_client=nc,
|
||||
node_id=ncs_selection.node,
|
||||
required_type=job_payload["required_type"],
|
||||
job_payload=job_payload,
|
||||
timeout_ms=infer_timeout,
|
||||
)
|
||||
if offload_resp and offload_resp.get("status") == "ok":
|
||||
result_text = offload_resp.get("result", {}).get("text", "")
|
||||
return InferResponse(
|
||||
response=result_text,
|
||||
model=f"{offload_resp.get('model', ncs_selection.name)}@{ncs_selection.node}",
|
||||
backend=f"nats-offload:{ncs_selection.node}",
|
||||
tokens_used=offload_resp.get("result", {}).get("eval_count", 0),
|
||||
)
|
||||
else:
|
||||
offload_status = offload_resp.get("status", "none") if offload_resp else "no_reply"
|
||||
logger.warning(
|
||||
f"[fallback] offload to {ncs_selection.node} failed ({offload_status}) "
|
||||
f"→ re-selecting with exclude={ncs_selection.node}, force_local"
|
||||
)
|
||||
try:
|
||||
gcaps = await global_capabilities_client.get_global_capabilities()
|
||||
ncs_selection = await select_model_for_agent(
|
||||
agent_id, agent_config, router_config, gcaps, request.model,
|
||||
exclude_nodes={ncs_selection.node}, force_local=True,
|
||||
)
|
||||
if ncs_selection and ncs_selection.name:
|
||||
provider = ncs_selection.provider
|
||||
model = ncs_selection.name
|
||||
llm_profile = router_config.get("llm_profiles", {}).get(default_llm, {})
|
||||
if ncs_selection.base_url and provider == "ollama":
|
||||
llm_profile = {**llm_profile, "base_url": ncs_selection.base_url}
|
||||
logger.info(
|
||||
f"[fallback.reselect] → local node={ncs_selection.node} "
|
||||
f"model={model} provider={provider}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[fallback.reselect] error: {e}; proceeding with static")
|
||||
|
||||
# =========================================================================
|
||||
# VISION PROCESSING (if images present)
|
||||
# =========================================================================
|
||||
|
||||
@@ -9,7 +9,7 @@ Scaling: works with 1 node or 150+. No static node lists.
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger("model_select")
|
||||
|
||||
@@ -110,6 +110,7 @@ def profile_requirements(
|
||||
def select_best_model(
|
||||
reqs: ProfileRequirements,
|
||||
capabilities: Dict[str, Any],
|
||||
exclude_nodes: Optional[Set[str]] = None,
|
||||
) -> Optional[ModelSelection]:
|
||||
"""Choose the best served model from global (multi-node) capabilities.
|
||||
|
||||
@@ -117,18 +118,25 @@ def select_best_model(
|
||||
1. Prefer list matches (local first, then remote)
|
||||
2. Best candidate by size (local first, then remote)
|
||||
3. None → caller should try static fallback
|
||||
|
||||
exclude_nodes: set of node_ids to skip (e.g. circuit-broken nodes).
|
||||
"""
|
||||
served = capabilities.get("served_models", [])
|
||||
if not served:
|
||||
return None
|
||||
|
||||
exclude = exclude_nodes or set()
|
||||
|
||||
search_types = [reqs.required_type]
|
||||
if reqs.required_type == "code":
|
||||
search_types.append("llm")
|
||||
if reqs.required_type == "llm":
|
||||
search_types.append("code")
|
||||
|
||||
candidates = [m for m in served if m.get("type") in search_types]
|
||||
candidates = [
|
||||
m for m in served
|
||||
if m.get("type") in search_types and m.get("node", "") not in exclude
|
||||
]
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
@@ -218,15 +226,21 @@ async def select_model_for_agent(
|
||||
router_cfg: Dict[str, Any],
|
||||
capabilities: Optional[Dict[str, Any]],
|
||||
request_model: Optional[str] = None,
|
||||
exclude_nodes: Optional[Set[str]] = None,
|
||||
force_local: bool = False,
|
||||
) -> ModelSelection:
|
||||
"""Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default."""
|
||||
"""Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default.
|
||||
|
||||
exclude_nodes: skip these nodes (circuit-broken). Used on fallback re-selection.
|
||||
force_local: prefer local-only models (fallback after remote failure).
|
||||
"""
|
||||
profile = resolve_effective_profile(
|
||||
agent_id, agent_cfg, router_cfg, request_model,
|
||||
)
|
||||
|
||||
reqs = profile_requirements(profile, agent_cfg, router_cfg)
|
||||
|
||||
if reqs.required_type == "cloud_llm":
|
||||
if reqs.required_type == "cloud_llm" and not force_local:
|
||||
static = static_fallback(profile, router_cfg)
|
||||
if static:
|
||||
static.fallback_reason = ""
|
||||
@@ -236,14 +250,31 @@ async def select_model_for_agent(
|
||||
)
|
||||
return static
|
||||
|
||||
excl = set(exclude_nodes) if exclude_nodes else set()
|
||||
try:
|
||||
from offload_client import get_unavailable_nodes
|
||||
cb_nodes = get_unavailable_nodes(reqs.required_type)
|
||||
excl |= cb_nodes
|
||||
if cb_nodes:
|
||||
logger.info(f"[select] circuit-broken nodes for {reqs.required_type}: {cb_nodes}")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if capabilities and capabilities.get("served_models"):
|
||||
sel = select_best_model(reqs, capabilities)
|
||||
sel = select_best_model(reqs, capabilities, exclude_nodes=excl)
|
||||
if force_local and sel and not sel.local:
|
||||
sel = select_best_model(
|
||||
reqs, capabilities,
|
||||
exclude_nodes=excl | {n.get("node", "") for n in capabilities.get("served_models", []) if not n.get("local")},
|
||||
)
|
||||
if sel:
|
||||
logger.info(
|
||||
f"[select] agent={agent_id} profile={profile} → "
|
||||
f"{'NCS' if sel.local else 'REMOTE'} "
|
||||
f"{'LOCAL' if sel.local else 'REMOTE'} "
|
||||
f"node={sel.node} runtime={sel.runtime} "
|
||||
f"model={sel.name} caps_age={sel.caps_age_s}s"
|
||||
f"{' (force_local)' if force_local else ''}"
|
||||
f"{' (excluded: ' + ','.join(excl) + ')' if excl else ''}"
|
||||
)
|
||||
return sel
|
||||
logger.warning(
|
||||
|
||||
153
services/router/offload_client.py
Normal file
153
services/router/offload_client.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""NATS offload client — sends inference requests to remote nodes with
|
||||
circuit breaker, retries, and deadline enforcement."""
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, Literal, Optional, Set
|
||||
|
||||
logger = logging.getLogger("offload_client")

# Circuit-breaker tuning, overridable via environment:
# trip after CB_FAILS failures within a rolling CB_WINDOW_S-second window,
# then hold the circuit open for CB_OPEN_S seconds.
CB_FAILS = int(os.getenv("ROUTER_OFFLOAD_CB_FAILS", "3"))
CB_WINDOW_S = int(os.getenv("ROUTER_OFFLOAD_CB_WINDOW_S", "60"))
CB_OPEN_S = int(os.getenv("ROUTER_OFFLOAD_CB_OPEN_S", "120"))
# Extra attempts after the first request (retried only on transient failures).
MAX_RETRIES = int(os.getenv("ROUTER_OFFLOAD_RETRIES", "1"))
# Upper bound on concurrent in-flight remote offload requests.
MAX_CONCURRENCY = int(os.getenv("ROUTER_OFFLOAD_MAX_CONCURRENCY_REMOTE", "8"))

# Created lazily by _get_semaphore(); limits concurrent remote requests.
_semaphore: Optional[asyncio.Semaphore] = None
# Breaker state keyed by "node_id:req_type" →
# {"failures": [unix_ts, ...], "open_until": unix_ts (0 = closed)}.
_breakers: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _get_semaphore() -> asyncio.Semaphore:
    """Return the shared offload semaphore, creating it on first call.

    Subsequent calls always return the same instance, sized MAX_CONCURRENCY.
    """
    global _semaphore
    sem = _semaphore
    if sem is None:
        sem = asyncio.Semaphore(MAX_CONCURRENCY)
        _semaphore = sem
    return sem
|
||||
|
||||
|
||||
def _breaker_key(node_id: str, req_type: str) -> str:
|
||||
return f"{node_id}:{req_type}"
|
||||
|
||||
|
||||
def is_node_available(node_id: str, req_type: str) -> bool:
    """Return True when the circuit for (node_id, req_type) permits a request.

    The circuit is considered closed (available) when:
      - no breaker state exists for this node/type, or
      - a previously opened circuit's open window has expired, or
      - fewer than CB_FAILS failures fall inside the rolling CB_WINDOW_S window.

    Side effect: prunes expired failure timestamps from the breaker state.
    """
    state = _breakers.get(_breaker_key(node_id, req_type))
    if not state:
        return True

    # Sample the clock exactly once: the original compared open_until against
    # separate time.time() calls, so a tick between them could make both the
    # "still open" and the "expired" checks evaluate inconsistently.
    now = time.time()
    open_until = state.get("open_until", 0)
    if open_until > now:
        # Circuit is open — reject until the open window elapses.
        return False
    if open_until > 0:
        # Open window has expired — allow traffic again (half-open behavior;
        # stale failures are older than CB_OPEN_S > CB_WINDOW_S anyway).
        return True

    # Closed circuit: count only failures inside the rolling window.
    window_start = now - CB_WINDOW_S
    recent = [t for t in state.get("failures", []) if t > window_start]
    state["failures"] = recent
    return len(recent) < CB_FAILS
|
||||
|
||||
|
||||
def record_failure(node_id: str, req_type: str):
    """Record one failure for (node_id, req_type); trip the breaker at CB_FAILS.

    Failures older than CB_WINDOW_S are pruned on every call, so only a burst
    of CB_FAILS failures inside the window opens the circuit (for CB_OPEN_S).
    """
    key = _breaker_key(node_id, req_type)
    state = _breakers.setdefault(key, {"failures": [], "open_until": 0})
    # One clock sample: the original appended one time.time() and pruned with a
    # second, later one. Computing both from the same instant keeps the new
    # timestamp and the pruning window consistent, and needs only one rebuild.
    now = time.time()
    recent = [t for t in state["failures"] if t > now - CB_WINDOW_S]
    recent.append(now)
    state["failures"] = recent
    if len(recent) >= CB_FAILS:
        state["open_until"] = now + CB_OPEN_S
        logger.warning(f"Circuit OPEN: {key} ({len(recent)} failures in {CB_WINDOW_S}s, open for {CB_OPEN_S}s)")
|
||||
|
||||
|
||||
def record_success(node_id: str, req_type: str):
    """Reset breaker state for (node_id, req_type) after a successful reply."""
    entry = _breakers.get(_breaker_key(node_id, req_type))
    if not entry:
        return
    entry["failures"] = []
    entry["open_until"] = 0
|
||||
|
||||
|
||||
def get_unavailable_nodes(req_type: str) -> Set[str]:
    """Return node_ids whose circuit for req_type is currently open."""
    suffix = f":{req_type}"
    blocked: Set[str] = set()
    for key in _breakers:
        if not key.endswith(suffix):
            continue
        node_id = key.rsplit(":", 1)[0]
        if not is_node_available(node_id, req_type):
            blocked.add(node_id)
    return blocked
|
||||
|
||||
|
||||
async def offload_infer(
    nats_client,
    node_id: str,
    required_type: Literal["llm", "vision", "stt", "tts"],
    job_payload: Dict[str, Any],
    timeout_ms: int = 25000,
) -> Optional[Dict[str, Any]]:
    """Send inference job to remote node via NATS request/reply.

    Returns parsed JobResponse dict or None on total failure.
    Retries on transient errors (timeout, busy). Does NOT retry on provider errors.

    Args:
        nats_client: connected NATS client exposing an async ``request()``.
        node_id: target node id; lowercased when building the subject.
        required_type: job class; selects the per-node subject suffix.
        job_payload: canonical JobRequest envelope, JSON-serialized as-is.
        timeout_ms: overall deadline budget shared across ALL attempts —
            each failed attempt subtracts its elapsed time before a retry.
    """
    subject = f"node.{node_id.lower()}.{required_type}.request"
    payload_bytes = json.dumps(job_payload).encode()
    sem = _get_semaphore()

    # 1 initial attempt + MAX_RETRIES retries, all sharing one deadline budget.
    for attempt in range(1 + MAX_RETRIES):
        timeout_s = timeout_ms / 1000.0
        if timeout_s <= 0:
            # Budget consumed by earlier attempts — give up without sending.
            logger.warning(f"[offload] deadline exhausted before attempt {attempt}")
            return None

        t0 = time.time()
        try:
            # Semaphore caps concurrent in-flight remote requests; held only
            # for the request/reply round-trip, not for the bookkeeping below.
            async with sem:
                logger.info(
                    f"[offload.start] node={node_id} subj={subject} "
                    f"timeout={timeout_ms}ms attempt={attempt}"
                )
                msg = await nats_client.request(subject, payload_bytes, timeout=timeout_s)
                resp = json.loads(msg.data)

            latency = int((time.time() - t0) * 1000)
            status = resp.get("status", "error")

            if status == "ok":
                record_success(node_id, required_type)
                logger.info(
                    f"[offload.done] node={node_id} status=ok latency={latency}ms "
                    f"provider={resp.get('provider')} model={resp.get('model')} "
                    f"cached={resp.get('cached', False)}"
                )
                return resp

            # Transient node-side statuses: retry WITHOUT counting a breaker
            # failure, charging elapsed time against the remaining budget.
            if status in ("timeout", "busy") and attempt < MAX_RETRIES:
                elapsed = int((time.time() - t0) * 1000)
                timeout_ms -= elapsed
                logger.warning(f"[offload.retry] node={node_id} status={status} → retry {attempt+1}")
                continue

            # Non-retryable (provider error) or retries exhausted: count it
            # against the breaker and surface the error response to the caller.
            record_failure(node_id, required_type)
            logger.warning(
                f"[offload.fail] node={node_id} status={status} "
                f"error={resp.get('error', {}).get('code', '?')}"
            )
            return resp

        except asyncio.TimeoutError:
            # No reply within timeout_s — counts as a breaker failure even
            # when a retry follows. (nats-py's TimeoutError subclasses
            # asyncio.TimeoutError, so this catches NATS request timeouts.)
            record_failure(node_id, required_type)
            elapsed = int((time.time() - t0) * 1000)
            timeout_ms -= elapsed
            if attempt < MAX_RETRIES:
                logger.warning(f"[offload.timeout] node={node_id} {elapsed}ms → retry {attempt+1}")
                continue
            logger.warning(f"[offload.timeout] node={node_id} all retries exhausted")
            return None

        except Exception as e:
            # Anything else (connection loss, malformed JSON reply, …) is
            # terminal for this job — fail the breaker so the router can
            # fall back, and return None (no parseable response).
            record_failure(node_id, required_type)
            logger.warning(f"[offload.error] node={node_id} {e}")
            return None

    return None
|
||||
Reference in New Issue
Block a user