P2.2+P2.3: NATS offload node-worker + router offload integration

Node Worker (services/node-worker/):
- NATS subscriber for node.{NODE_ID}.llm.request / vision.request
- Canonical JobRequest/JobResponse envelope (Pydantic)
- Idempotency cache (TTL 10min) with inflight dedup
- Deadline enforcement (DEADLINE_EXCEEDED on expired jobs)
- Concurrency limiter (semaphore, returns busy)
- Ollama + Swapper vision providers

Router offload (services/router/offload_client.py):
- NATS req/reply with configurable retries
- Circuit breaker per node+type (3 fails/60s → open 120s)
- Concurrency semaphore for remote requests

Model selection (services/router/model_select.py):
- exclude_nodes parameter for circuit-broken nodes
- force_local flag for fallback re-selection
- Integrated circuit breaker state awareness

Router /infer pipeline:
- Remote offload path when NCS selects remote node
- Automatic fallback: exclude failed node → force_local re-select
- Deadline propagation from router to node-worker

Tests: 17 unit tests (idempotency, deadline, circuit breaker)
Docs: ops/offload_routing.md (subjects, envelope, verification)
Made-with: Cursor
Author: Apple
Date: 2026-02-27 02:44:05 -08:00
Commit: c4b94a327d (parent a92c424845)
19 changed files with 1075 additions and 6 deletions

@@ -133,6 +133,30 @@ services:
      - dagi-network
    restart: unless-stopped
  node-worker:
    build:
      context: ./services/node-worker
      dockerfile: Dockerfile
    container_name: node-worker-node2
    ports:
      - "127.0.0.1:8109:8109"
    extra_hosts:
      - "host.docker.internal:host-gateway"
    environment:
      - NODE_ID=noda2
      - NATS_URL=nats://dagi-nats:4222
      - OLLAMA_BASE_URL=http://host.docker.internal:11434
      - SWAPPER_URL=http://swapper-service:8890
      - NODE_DEFAULT_LLM=qwen3:14b
      - NODE_DEFAULT_VISION=llava:13b
      - NODE_WORKER_MAX_CONCURRENCY=2
    depends_on:
      - dagi-nats
      - swapper-service
    networks:
      - dagi-network
    restart: unless-stopped
  sofiia-console:
    build:
      context: ./services/sofiia-console

ops/offload_routing.md (new file, 144 lines)

@@ -0,0 +1,144 @@
# NATS Offload Routing — Operations Guide
## Architecture
```
Router (NODA1/NODA2)
├── model_select.py → selects best model from global capabilities pool
│ (local first, remote if needed, circuit breaker aware)
├── offload_client.py → NATS req/reply to remote node-worker
│ (retries, deadlines, circuit breaker)
└── global_capabilities_client.py → scatter-gather discovery
node.*.capabilities.get
```
## NATS Subjects
| Subject | Direction | Description |
|---------|-----------|-------------|
| `node.{node_id}.capabilities.get` | req/reply | NCS capabilities query |
| `node.{node_id}.llm.request` | req/reply | LLM inference offload |
| `node.{node_id}.vision.request` | req/reply | Vision inference offload |
| `node.{node_id}.stt.request` | req/reply | STT (scaffold) |
| `node.{node_id}.tts.request` | req/reply | TTS (scaffold) |
## Job Request Envelope
```json
{
"job_id": "uuid",
"trace_id": "uuid",
"actor_agent_id": "sofiia",
"target_agent_id": "helion",
"required_type": "llm",
"deadline_ts": 1740000000000,
"idempotency_key": "uuid",
"payload": {
"messages": [{"role": "user", "content": "..."}],
"prompt": "...",
"model": "qwen3:14b",
"max_tokens": 2048,
"temperature": 0.2
},
"hints": {
"prefer_models": ["qwen3:14b"]
}
}
```
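On the sending side the envelope takes only a few lines to build. `build_job_request` below is a hypothetical helper (not part of the codebase), shown to make the `deadline_ts` convention concrete: the deadline is absolute Unix milliseconds, so every hop can compute remaining budget without knowing when the job started.

```python
import time
import uuid

def build_job_request(prompt: str, timeout_ms: int = 25_000) -> dict:
    """Build a minimal JobRequest envelope (hypothetical client-side helper)."""
    now_ms = int(time.time() * 1000)
    return {
        "job_id": str(uuid.uuid4()),
        "trace_id": str(uuid.uuid4()),
        "required_type": "llm",
        # absolute deadline: now + total budget, in Unix milliseconds
        "deadline_ts": now_ms + timeout_ms,
        "idempotency_key": str(uuid.uuid4()),
        "payload": {"prompt": prompt, "max_tokens": 2048, "temperature": 0.2},
        "hints": {},
    }
```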
## Job Response
```json
{
"job_id": "uuid",
"trace_id": "uuid",
"node_id": "noda2",
"status": "ok|busy|timeout|error",
"provider": "ollama",
"model": "qwen3:14b",
"latency_ms": 5500,
"result": {"text": "..."},
"error": null,
"cached": false
}
```
## Circuit Breaker
Per `node_id:required_type`. Opens after 3 failures in 60s, stays open 120s.
| Env | Default | Description |
|-----|---------|-------------|
| `ROUTER_OFFLOAD_CB_FAILS` | 3 | Failures to trip |
| `ROUTER_OFFLOAD_CB_WINDOW_S` | 60 | Failure window |
| `ROUTER_OFFLOAD_CB_OPEN_S` | 120 | Open duration |
## Idempotency
Node-worker caches `idempotency_key → response` for 10 minutes.
Duplicate requests return cached response immediately (< 10ms).
Inflight dedup prevents parallel execution of same job.
## Deadline Enforcement
`deadline_ts` is absolute Unix milliseconds. If already expired on arrival,
node-worker returns `status=timeout` + `error.code=DEADLINE_EXCEEDED`.
During inference, `asyncio.wait_for` enforces remaining time.
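The pattern reduces to a small helper — a sketch of the idea above, not the worker's actual code:

```python
import asyncio
import time

async def run_with_deadline(coro_factory, deadline_ts: int):
    """Enforce an absolute Unix-ms deadline (sketch).

    Returns ("timeout", None) if the deadline has already passed or the
    work exceeds the remaining budget, otherwise ("ok", result).
    """
    remaining_ms = deadline_ts - int(time.time() * 1000)
    if remaining_ms <= 0:
        return ("timeout", None)  # DEADLINE_EXCEEDED on arrival
    try:
        result = await asyncio.wait_for(coro_factory(), timeout=remaining_ms / 1000.0)
        return ("ok", result)
    except asyncio.TimeoutError:
        return ("timeout", None)  # ran out of budget mid-inference
```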
## Fallback Chain
1. NCS selects remote model (e.g. noda2)
2. Router sends NATS offload request
3. If remote fails (timeout/busy/error):
- Record circuit breaker failure
- Re-select with `exclude_nodes={failed}` + `force_local=True`
- Execute locally
## Environment Variables
### Router
| Env | Default | Description |
|-----|---------|-------------|
| `ROUTER_INFER_TIMEOUT_MS` | 25000 | Total inference deadline |
| `ROUTER_OFFLOAD_RETRIES` | 1 | NATS retry on transient |
| `ROUTER_OFFLOAD_MAX_CONCURRENCY_REMOTE` | 8 | Max parallel offloads |
### Node Worker
| Env | Default | Description |
|-----|---------|-------------|
| `NODE_ID` | noda2 | Canonical node identifier |
| `NATS_URL` | nats://dagi-nats:4222 | NATS server |
| `OLLAMA_BASE_URL` | http://host.docker.internal:11434 | Ollama API |
| `SWAPPER_URL` | http://swapper-service:8890 | Swapper API |
| `NODE_DEFAULT_LLM` | qwen3:14b | Default LLM model |
| `NODE_DEFAULT_VISION` | llava:13b | Default vision model |
| `NODE_WORKER_MAX_CONCURRENCY` | 2 | Max parallel inferences |
## Verification Commands
```bash
# Node Worker health
curl -s http://localhost:8109/healthz | jq .
# Direct NATS LLM offload
nats request node.noda2.llm.request '{"job_id":"test","required_type":"llm","payload":{"prompt":"hi","max_tokens":10}}'
# Idempotency test (same job_id)
nats request node.noda2.llm.request '{"job_id":"test","idempotency_key":"test","required_type":"llm","payload":{"prompt":"hi"}}' # → cached=true
# Deadline expired test
nats request node.noda2.llm.request '{"job_id":"expired","required_type":"llm","deadline_ts":1000,"payload":{"prompt":"hi"}}' # → DEADLINE_EXCEEDED
# Router logs (selection + offload)
docker logs dagi-router-node2 2>&1 | grep -E '\[select\]|\[offload\]|\[fallback\]'
```
## Adding New Nodes
1. Deploy `node-capabilities` + `node-worker` with unique `NODE_ID`
2. Connect NATS (leafnode or cluster member)
3. Router auto-discovers via `node.*.capabilities.get` scatter-gather
4. No config changes on Router needed
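Once capability replies arrive, merging them into the global pool is a pure transformation. This sketch assumes each node replies with `{"node": ..., "served_models": [...]}` — the exact reply shape is an assumption, not taken from `global_capabilities_client`:

```python
def merge_capabilities(replies: list[dict]) -> dict:
    """Merge per-node capability replies into one global pool (sketch).

    Tags every served model with its node id so the selector can tell
    local candidates from remote ones.
    """
    served = []
    for reply in replies:
        node = reply.get("node", "")
        for model in reply.get("served_models", []):
            served.append({**model, "node": node})
    return {"served_models": served}
```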

@@ -0,0 +1,7 @@
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8109
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8109"]

@@ -0,0 +1,12 @@
"""Node-worker configuration from environment."""
import os
NODE_ID = os.getenv("NODE_ID", "noda2")
NATS_URL = os.getenv("NATS_URL", "nats://dagi-nats:4222")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
SWAPPER_URL = os.getenv("SWAPPER_URL", "http://swapper-service:8890")
DEFAULT_LLM = os.getenv("NODE_DEFAULT_LLM", "qwen3:14b")
DEFAULT_VISION = os.getenv("NODE_DEFAULT_VISION", "llava:13b")
MAX_CONCURRENCY = int(os.getenv("NODE_WORKER_MAX_CONCURRENCY", "2"))
MAX_PAYLOAD_BYTES = int(os.getenv("NODE_WORKER_MAX_PAYLOAD_BYTES", str(1024 * 1024)))
PORT = int(os.getenv("PORT", "8109"))

@@ -0,0 +1,62 @@
"""Idempotency cache + inflight dedup for job execution."""
import asyncio
import logging
import time
from typing import Dict, Optional, Tuple
from models import JobResponse
logger = logging.getLogger("idempotency")
CACHE_TTL = 600 # 10 min for successful results
TIMEOUT_TTL = 30 # 30s for timeout results
class IdempotencyStore:
    def __init__(self, max_size: int = 10_000):
        self._cache: Dict[str, Tuple[JobResponse, float]] = {}
        self._inflight: Dict[str, asyncio.Future] = {}
        self._max_size = max_size

    def get(self, key: str) -> Optional[JobResponse]:
        entry = self._cache.get(key)
        if not entry:
            return None
        resp, expires = entry
        if time.time() > expires:
            self._cache.pop(key, None)
            return None
        cached = resp.model_copy()
        cached.cached = True
        return cached

    def put(self, key: str, resp: JobResponse):
        ttl = TIMEOUT_TTL if resp.status == "timeout" else CACHE_TTL
        self._cache[key] = (resp, time.time() + ttl)
        self._evict_if_needed()

    def _evict_if_needed(self):
        if len(self._cache) <= self._max_size:
            return
        now = time.time()
        expired = [k for k, (_, exp) in self._cache.items() if now > exp]
        for k in expired:
            self._cache.pop(k, None)
        if len(self._cache) > self._max_size:
            oldest = sorted(self._cache, key=lambda k: self._cache[k][1])
            for k in oldest[:len(self._cache) - self._max_size]:
                self._cache.pop(k, None)

    async def acquire_inflight(self, key: str) -> Optional[asyncio.Future]:
        """If another coroutine is already processing this key, return its future.
        Otherwise register this coroutine as the processor and return None."""
        if key in self._inflight:
            return self._inflight[key]
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._inflight[key] = fut
        return None

    def complete_inflight(self, key: str, resp: JobResponse):
        fut = self._inflight.pop(key, None)
        if fut and not fut.done():
            fut.set_result(resp)

View File

@@ -0,0 +1,53 @@
"""Node Worker — NATS offload executor for cross-node inference."""
import logging
import os
from fastapi import FastAPI
import config
import worker
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("node-worker")
app = FastAPI(title="Node Worker", version="1.0.0")
_nats_client = None
@app.get("/healthz")
async def healthz():
    connected = _nats_client is not None and _nats_client.is_connected
    return {
        "status": "ok" if connected else "degraded",
        "node_id": config.NODE_ID,
        "nats_connected": connected,
        "max_concurrency": config.MAX_CONCURRENCY,
    }

@app.on_event("startup")
async def startup():
    global _nats_client
    try:
        import nats as nats_lib
        _nats_client = await nats_lib.connect(config.NATS_URL)
        logger.info(f"✅ NATS connected: {config.NATS_URL}")
        await worker.start(_nats_client)
        logger.info(f"✅ Node Worker ready: node={config.NODE_ID} concurrency={config.MAX_CONCURRENCY}")
    except Exception as e:
        logger.error(f"❌ Startup failed: {e}")

@app.on_event("shutdown")
async def shutdown():
    if _nats_client:
        try:
            await _nats_client.close()
        except Exception:
            pass

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=config.PORT)

View File

@@ -0,0 +1,47 @@
"""Canonical job envelope for cross-node inference offload."""
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
import time
import uuid
def _ulid() -> str:
    # uuid4 stand-in; swap in a real ULID generator if sortable ids matter
    return str(uuid.uuid4())

class JobRequest(BaseModel):
    job_id: str = Field(default_factory=_ulid)
    trace_id: str = ""
    actor_agent_id: str = ""
    target_agent_id: str = ""
    required_type: Literal["llm", "vision", "stt", "tts"] = "llm"
    deadline_ts: int = 0
    idempotency_key: str = ""
    payload: Dict[str, Any] = Field(default_factory=dict)
    hints: Dict[str, Any] = Field(default_factory=dict)

    def remaining_ms(self) -> int:
        if self.deadline_ts <= 0:
            return 30_000
        return max(0, self.deadline_ts - int(time.time() * 1000))

    def effective_idem_key(self) -> str:
        return self.idempotency_key or self.job_id

class JobError(BaseModel):
    code: str = "UNKNOWN"
    message: str = ""

class JobResponse(BaseModel):
    job_id: str = ""
    trace_id: str = ""
    node_id: str = ""
    status: Literal["ok", "busy", "timeout", "error"] = "ok"
    provider: str = ""
    model: str = ""
    latency_ms: int = 0
    result: Optional[Dict[str, Any]] = None
    error: Optional[JobError] = None
    cached: bool = False

@@ -0,0 +1,81 @@
"""Ollama LLM provider for node-worker."""
import logging
from typing import Any, Dict, List, Optional
import httpx
from config import OLLAMA_BASE_URL, DEFAULT_LLM
logger = logging.getLogger("provider.ollama")
async def infer(
    messages: Optional[List[Dict[str, str]]] = None,
    prompt: str = "",
    model: str = "",
    system: str = "",
    max_tokens: int = 2048,
    temperature: float = 0.2,
    timeout_s: float = 25.0,
) -> Dict[str, Any]:
    model = model or DEFAULT_LLM
    if messages:
        return await _chat(messages, model, max_tokens, temperature, timeout_s)
    return await _generate(prompt, system, model, max_tokens, temperature, timeout_s)

async def _chat(
    messages: List[Dict[str, str]],
    model: str,
    max_tokens: int,
    temperature: float,
    timeout_s: float,
) -> Dict[str, Any]:
    async with httpx.AsyncClient(timeout=timeout_s) as c:
        resp = await c.post(
            f"{OLLAMA_BASE_URL}/api/chat",
            json={
                "model": model,
                "messages": messages,
                "stream": False,
                "options": {"num_predict": max_tokens, "temperature": temperature},
            },
        )
        resp.raise_for_status()
        data = resp.json()
    return {
        "text": data.get("message", {}).get("content", ""),
        "model": model,
        "provider": "ollama",
        "eval_count": data.get("eval_count", 0),
    }

async def _generate(
    prompt: str,
    system: str,
    model: str,
    max_tokens: int,
    temperature: float,
    timeout_s: float,
) -> Dict[str, Any]:
    async with httpx.AsyncClient(timeout=timeout_s) as c:
        resp = await c.post(
            f"{OLLAMA_BASE_URL}/api/generate",
            json={
                "model": model,
                "prompt": prompt,
                "system": system,
                "stream": False,
                "options": {"num_predict": max_tokens, "temperature": temperature},
            },
        )
        resp.raise_for_status()
        data = resp.json()
    return {
        "text": data.get("response", ""),
        "model": model,
        "provider": "ollama",
        "eval_count": data.get("eval_count", 0),
    }

@@ -0,0 +1,42 @@
"""Swapper vision provider for node-worker."""
import logging
from typing import Any, Dict, List, Optional
import httpx
from config import SWAPPER_URL, DEFAULT_VISION
logger = logging.getLogger("provider.swapper_vision")
async def infer(
    images: Optional[List[str]] = None,
    prompt: str = "",
    model: str = "",
    system: str = "",
    max_tokens: int = 1024,
    temperature: float = 0.2,
    timeout_s: float = 60.0,
) -> Dict[str, Any]:
    model = model or DEFAULT_VISION
    payload: Dict[str, Any] = {
        "model": model,
        "prompt": prompt or "Describe this image.",
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    if images:
        payload["images"] = images
    if system:
        payload["system"] = system
    async with httpx.AsyncClient(timeout=timeout_s) as c:
        resp = await c.post(f"{SWAPPER_URL}/vision", json=payload)
        resp.raise_for_status()
        data = resp.json()
    return {
        "text": data.get("text", data.get("response", "")),
        "model": model,
        "provider": "swapper_vision",
    }

@@ -0,0 +1,5 @@
fastapi>=0.110.0
uvicorn>=0.29.0
httpx>=0.27.0
nats-py>=2.7.0
pydantic>=2.5.0

@@ -0,0 +1,184 @@
"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
import asyncio
import json
import logging
import time
from typing import Any, Dict
import config
from models import JobRequest, JobResponse, JobError
from idempotency import IdempotencyStore
from providers import ollama, swapper_vision
logger = logging.getLogger("node-worker")
_idem = IdempotencyStore()
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)
_nats_client = None
async def start(nats_client):
    global _nats_client
    _nats_client = nats_client
    subjects = [
        f"node.{config.NODE_ID.lower()}.llm.request",
        f"node.{config.NODE_ID.lower()}.vision.request",
    ]
    for subj in subjects:
        await nats_client.subscribe(subj, cb=_handle_request)
        logger.info(f"✅ Subscribed: {subj}")

async def _handle_request(msg):
    t0 = time.time()
    try:
        raw = msg.data
        if len(raw) > config.MAX_PAYLOAD_BYTES:
            await _reply(msg, JobResponse(
                status="error",
                node_id=config.NODE_ID,
                error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
            ))
            return
        data = json.loads(raw)
        job = JobRequest(**data)
        job.trace_id = job.trace_id or job.job_id
        idem_key = job.effective_idem_key()
        cached = _idem.get(idem_key)
        if cached:
            logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
            await _reply(msg, cached)
            return
        remaining = job.remaining_ms()
        if remaining <= 0:
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="timeout",
                error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"),
            )
            _idem.put(idem_key, resp)
            await _reply(msg, resp)
            return
        inflight = await _idem.acquire_inflight(idem_key)
        if inflight is not None:
            try:
                resp = await asyncio.wait_for(inflight, timeout=remaining / 1000.0)
                resp_copy = resp.model_copy()
                resp_copy.cached = True
                await _reply(msg, resp_copy)
            except asyncio.TimeoutError:
                await _reply(msg, JobResponse(
                    job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                    status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"),
                ))
            return
        if _semaphore.locked():
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="busy",
                error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"),
            )
            _idem.complete_inflight(idem_key, resp)
            await _reply(msg, resp)
            return
        async with _semaphore:
            resp = await _execute(job, remaining)
            _idem.put(idem_key, resp)
            _idem.complete_inflight(idem_key, resp)
            resp.latency_ms = int((time.time() - t0) * 1000)
            await _reply(msg, resp)
    except Exception as e:
        logger.exception(f"Worker handler error: {e}")
        try:
            await _reply(msg, JobResponse(
                node_id=config.NODE_ID, status="error",
                error=JobError(code="INTERNAL", message=str(e)[:200]),
            ))
        except Exception:
            pass

async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
    payload = job.payload
    hints = job.hints
    timeout_s = min(remaining_ms / 1000.0, 120.0)
    model = hints.get("prefer_models", [None])[0] if hints.get("prefer_models") else payload.get("model", "")
    msg_count = len(payload.get("messages", []))
    prompt_chars = len(payload.get("prompt", ""))
    logger.info(
        f"[job.start] job={job.job_id} trace={job.trace_id} "
        f"type={job.required_type} model={model or '?'} "
        f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
    )
    try:
        if job.required_type == "llm":
            result = await asyncio.wait_for(
                ollama.infer(
                    messages=payload.get("messages"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
                    temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "vision":
            result = await asyncio.wait_for(
                swapper_vision.infer(
                    images=payload.get("images"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", 1024),
                    temperature=hints.get("temperature", 0.2),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        else:
            return JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="error",
                error=JobError(code="UNSUPPORTED_TYPE", message=f"{job.required_type} not implemented"),
            )
        logger.info(
            f"[job.done] job={job.job_id} status=ok "
            f"provider={result.get('provider')} model={result.get('model')}"
        )
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="ok", provider=result.get("provider", ""), model=result.get("model", ""),
            result=result,
        )
    except asyncio.TimeoutError:
        logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="timeout", error=JobError(code="PROVIDER_TIMEOUT"),
        )
    except Exception as e:
        logger.warning(f"[job.error] job={job.job_id} {e}")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="error", error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]),
        )

async def _reply(msg, resp: JobResponse):
    if msg.reply:
        await _nats_client.publish(msg.reply, resp.model_dump_json().encode())

@@ -51,11 +51,13 @@ try:
    import capabilities_client
    import global_capabilities_client
    from model_select import select_model_for_agent, ModelSelection, CLOUD_PROVIDERS as NCS_CLOUD_PROVIDERS
    import offload_client
    NCS_AVAILABLE = True
except ImportError:
    NCS_AVAILABLE = False
    capabilities_client = None  # type: ignore[assignment]
    global_capabilities_client = None  # type: ignore[assignment]
    offload_client = None  # type: ignore[assignment]
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -1707,6 +1709,76 @@ async def agent_infer(agent_id: str, request: InferRequest):
        f"provider={provider} model={model}"
    )
    # =========================================================================
    # REMOTE OFFLOAD (if model selected on remote node)
    # =========================================================================
    nats_client_available = nc is not None and nats_available
    if (ncs_selection and ncs_selection.via_nats and not ncs_selection.local
            and nats_client_available and offload_client and nc):
        infer_timeout = int(os.getenv("ROUTER_INFER_TIMEOUT_MS", "25000"))
        import uuid as _uuid
        job_payload = {
            "job_id": str(_uuid.uuid4()),
            "trace_id": str(_uuid.uuid4()),
            "actor_agent_id": request_agent_id or agent_id,
            "target_agent_id": agent_id,
            "required_type": ncs_selection.model_type if ncs_selection.model_type != "code" else "llm",
            "deadline_ts": int(time.time() * 1000) + infer_timeout,
            "idempotency_key": str(_uuid.uuid4()),
            "payload": {
                "prompt": request.prompt,
                # empty messages → node-worker takes the generate path (prompt + system);
                # a system-only messages list would silently drop the prompt
                "messages": [],
                "system": system_prompt or "",
                "model": ncs_selection.name,
                "max_tokens": request.max_tokens or 2048,
                "temperature": request.temperature or 0.2,
            },
            "hints": {"prefer_models": [ncs_selection.name]},
        }
        if request.images:
            job_payload["payload"]["images"] = request.images
            job_payload["required_type"] = "vision"
            job_payload["payload"]["messages"].append({"role": "user", "content": request.prompt})
        offload_resp = await offload_client.offload_infer(
            nats_client=nc,
            node_id=ncs_selection.node,
            required_type=job_payload["required_type"],
            job_payload=job_payload,
            timeout_ms=infer_timeout,
        )
        if offload_resp and offload_resp.get("status") == "ok":
            result_text = offload_resp.get("result", {}).get("text", "")
            return InferResponse(
                response=result_text,
                model=f"{offload_resp.get('model', ncs_selection.name)}@{ncs_selection.node}",
                backend=f"nats-offload:{ncs_selection.node}",
                tokens_used=offload_resp.get("result", {}).get("eval_count", 0),
            )
        else:
            offload_status = offload_resp.get("status", "none") if offload_resp else "no_reply"
            logger.warning(
                f"[fallback] offload to {ncs_selection.node} failed ({offload_status}) "
                f"→ re-selecting with exclude={ncs_selection.node}, force_local"
            )
            try:
                gcaps = await global_capabilities_client.get_global_capabilities()
                ncs_selection = await select_model_for_agent(
                    agent_id, agent_config, router_config, gcaps, request.model,
                    exclude_nodes={ncs_selection.node}, force_local=True,
                )
                if ncs_selection and ncs_selection.name:
                    provider = ncs_selection.provider
                    model = ncs_selection.name
                    llm_profile = router_config.get("llm_profiles", {}).get(default_llm, {})
                    if ncs_selection.base_url and provider == "ollama":
                        llm_profile = {**llm_profile, "base_url": ncs_selection.base_url}
                    logger.info(
                        f"[fallback.reselect] → local node={ncs_selection.node} "
                        f"model={model} provider={provider}"
                    )
            except Exception as e:
                logger.warning(f"[fallback.reselect] error: {e}; proceeding with static")
    # =========================================================================
    # VISION PROCESSING (if images present)
    # =========================================================================

@@ -9,7 +9,7 @@ Scaling: works with 1 node or 150+. No static node lists.
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
logger = logging.getLogger("model_select")
@@ -110,6 +110,7 @@ def profile_requirements(
def select_best_model(
    reqs: ProfileRequirements,
    capabilities: Dict[str, Any],
    exclude_nodes: Optional[Set[str]] = None,
) -> Optional[ModelSelection]:
    """Choose the best served model from global (multi-node) capabilities.
@@ -117,18 +118,25 @@ def select_best_model(
    1. Prefer list matches (local first, then remote)
    2. Best candidate by size (local first, then remote)
    3. None → caller should try static fallback
    exclude_nodes: set of node_ids to skip (e.g. circuit-broken nodes).
    """
    served = capabilities.get("served_models", [])
    if not served:
        return None
    exclude = exclude_nodes or set()
    search_types = [reqs.required_type]
    if reqs.required_type == "code":
        search_types.append("llm")
    if reqs.required_type == "llm":
        search_types.append("code")
    candidates = [
        m for m in served
        if m.get("type") in search_types and m.get("node", "") not in exclude
    ]
    if not candidates:
        return None
@@ -218,15 +226,21 @@ async def select_model_for_agent(
    router_cfg: Dict[str, Any],
    capabilities: Optional[Dict[str, Any]],
    request_model: Optional[str] = None,
    exclude_nodes: Optional[Set[str]] = None,
    force_local: bool = False,
) -> ModelSelection:
    """Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default.
    exclude_nodes: skip these nodes (circuit-broken). Used on fallback re-selection.
    force_local: prefer local-only models (fallback after remote failure).
    """
    profile = resolve_effective_profile(
        agent_id, agent_cfg, router_cfg, request_model,
    )
    reqs = profile_requirements(profile, agent_cfg, router_cfg)
    if reqs.required_type == "cloud_llm" and not force_local:
        static = static_fallback(profile, router_cfg)
        if static:
            static.fallback_reason = ""
@@ -236,14 +250,31 @@ async def select_model_for_agent(
            )
            return static
    excl = set(exclude_nodes) if exclude_nodes else set()
    try:
        from offload_client import get_unavailable_nodes
        cb_nodes = get_unavailable_nodes(reqs.required_type)
        excl |= cb_nodes
        if cb_nodes:
            logger.info(f"[select] circuit-broken nodes for {reqs.required_type}: {cb_nodes}")
    except ImportError:
        pass
    if capabilities and capabilities.get("served_models"):
        sel = select_best_model(reqs, capabilities, exclude_nodes=excl)
        if force_local and sel and not sel.local:
            sel = select_best_model(
                reqs, capabilities,
                exclude_nodes=excl | {n.get("node", "") for n in capabilities.get("served_models", []) if not n.get("local")},
            )
        if sel:
            logger.info(
                f"[select] agent={agent_id} profile={profile} "
                f"{'LOCAL' if sel.local else 'REMOTE'} "
                f"node={sel.node} runtime={sel.runtime} "
                f"model={sel.name} caps_age={sel.caps_age_s}s"
                f"{' (force_local)' if force_local else ''}"
                f"{' (excluded: ' + ','.join(excl) + ')' if excl else ''}"
            )
            return sel
    logger.warning(
@@ -0,0 +1,153 @@
"""NATS offload client — sends inference requests to remote nodes with
circuit breaker, retries, and deadline enforcement."""
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, Literal, Optional, Set
logger = logging.getLogger("offload_client")
CB_FAILS = int(os.getenv("ROUTER_OFFLOAD_CB_FAILS", "3"))
CB_WINDOW_S = int(os.getenv("ROUTER_OFFLOAD_CB_WINDOW_S", "60"))
CB_OPEN_S = int(os.getenv("ROUTER_OFFLOAD_CB_OPEN_S", "120"))
MAX_RETRIES = int(os.getenv("ROUTER_OFFLOAD_RETRIES", "1"))
MAX_CONCURRENCY = int(os.getenv("ROUTER_OFFLOAD_MAX_CONCURRENCY_REMOTE", "8"))
_semaphore: Optional[asyncio.Semaphore] = None
_breakers: Dict[str, Dict[str, Any]] = {}
def _get_semaphore() -> asyncio.Semaphore:
    global _semaphore
    if _semaphore is None:
        _semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
    return _semaphore

def _breaker_key(node_id: str, req_type: str) -> str:
    return f"{node_id}:{req_type}"

def is_node_available(node_id: str, req_type: str) -> bool:
    key = _breaker_key(node_id, req_type)
    state = _breakers.get(key)
    if not state:
        return True
    open_until = state.get("open_until", 0)
    if open_until > time.time():
        return False
    if open_until > 0 and open_until <= time.time():
        return True
    now = time.time()
    window_start = now - CB_WINDOW_S
    recent = [t for t in state.get("failures", []) if t > window_start]
    state["failures"] = recent
    return len(recent) < CB_FAILS

def record_failure(node_id: str, req_type: str):
    key = _breaker_key(node_id, req_type)
    state = _breakers.setdefault(key, {"failures": [], "open_until": 0})
    state["failures"].append(time.time())
    window_start = time.time() - CB_WINDOW_S
    recent = [t for t in state["failures"] if t > window_start]
    state["failures"] = recent
    if len(recent) >= CB_FAILS:
        state["open_until"] = time.time() + CB_OPEN_S
        logger.warning(f"Circuit OPEN: {key} ({len(recent)} failures in {CB_WINDOW_S}s, open for {CB_OPEN_S}s)")

def record_success(node_id: str, req_type: str):
    key = _breaker_key(node_id, req_type)
    state = _breakers.get(key)
    if state:
        state["failures"] = []
        state["open_until"] = 0

def get_unavailable_nodes(req_type: str) -> Set[str]:
    result = set()
    for key, state in _breakers.items():
        if not key.endswith(f":{req_type}"):
            continue
        nid = key.rsplit(":", 1)[0]
        if not is_node_available(nid, req_type):
            result.add(nid)
    return result

async def offload_infer(
    nats_client,
    node_id: str,
    required_type: Literal["llm", "vision", "stt", "tts"],
    job_payload: Dict[str, Any],
    timeout_ms: int = 25000,
) -> Optional[Dict[str, Any]]:
    """Send inference job to remote node via NATS request/reply.
    Returns parsed JobResponse dict or None on total failure.
    Retries on transient errors (timeout, busy). Does NOT retry on provider errors.
    """
    subject = f"node.{node_id.lower()}.{required_type}.request"
    payload_bytes = json.dumps(job_payload).encode()
    sem = _get_semaphore()
    for attempt in range(1 + MAX_RETRIES):
        timeout_s = timeout_ms / 1000.0
        if timeout_s <= 0:
            logger.warning(f"[offload] deadline exhausted before attempt {attempt}")
            return None
        t0 = time.time()
        try:
            async with sem:
                logger.info(
                    f"[offload.start] node={node_id} subj={subject} "
                    f"timeout={timeout_ms}ms attempt={attempt}"
                )
                msg = await nats_client.request(subject, payload_bytes, timeout=timeout_s)
                resp = json.loads(msg.data)
                latency = int((time.time() - t0) * 1000)
                status = resp.get("status", "error")
                if status == "ok":
                    record_success(node_id, required_type)
                    logger.info(
                        f"[offload.done] node={node_id} status=ok latency={latency}ms "
                        f"provider={resp.get('provider')} model={resp.get('model')} "
                        f"cached={resp.get('cached', False)}"
                    )
                    return resp
                if status in ("timeout", "busy") and attempt < MAX_RETRIES:
                    elapsed = int((time.time() - t0) * 1000)
                    timeout_ms -= elapsed
                    logger.warning(f"[offload.retry] node={node_id} status={status} → retry {attempt+1}")
                    continue
                record_failure(node_id, required_type)
                logger.warning(
                    f"[offload.fail] node={node_id} status={status} "
                    f"error={resp.get('error', {}).get('code', '?')}"
                )
                return resp
        except asyncio.TimeoutError:
            record_failure(node_id, required_type)
            elapsed = int((time.time() - t0) * 1000)
            timeout_ms -= elapsed
            if attempt < MAX_RETRIES:
                logger.warning(f"[offload.timeout] node={node_id} {elapsed}ms → retry {attempt+1}")
                continue
            logger.warning(f"[offload.timeout] node={node_id} all retries exhausted")
            return None
        except Exception as e:
            record_failure(node_id, required_type)
            logger.warning(f"[offload.error] node={node_id} {e}")
            return None
    return None

@@ -0,0 +1,34 @@
"""Tests for node-worker deadline handling."""
import sys
import os
import time
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services", "node-worker"))
from models import JobRequest
def test_remaining_ms_with_future_deadline():
    job = JobRequest(deadline_ts=int(time.time() * 1000) + 10000)
    rem = job.remaining_ms()
    assert 9000 <= rem <= 10000

def test_remaining_ms_with_past_deadline():
    job = JobRequest(deadline_ts=int(time.time() * 1000) - 5000)
    assert job.remaining_ms() == 0

def test_remaining_ms_no_deadline():
    job = JobRequest(deadline_ts=0)
    assert job.remaining_ms() == 30_000

def test_effective_idem_key_uses_idempotency_key():
    job = JobRequest(job_id="j1", idempotency_key="custom-key")
    assert job.effective_idem_key() == "custom-key"

def test_effective_idem_key_falls_back_to_job_id():
    job = JobRequest(job_id="j2", idempotency_key="")
    assert job.effective_idem_key() == "j2"

@@ -0,0 +1,66 @@
"""Tests for node-worker idempotency store."""
import asyncio
import sys
import os
import time
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services", "node-worker"))
from models import JobResponse, JobError
from idempotency import IdempotencyStore
def test_put_and_get():
    store = IdempotencyStore()
    resp = JobResponse(job_id="j1", status="ok", node_id="n1", provider="ollama", model="qwen3:14b")
    store.put("key1", resp)
    cached = store.get("key1")
    assert cached is not None
    assert cached.status == "ok"
    assert cached.cached is True
    assert cached.job_id == "j1"

def test_miss():
    store = IdempotencyStore()
    assert store.get("nonexistent") is None

def test_ttl_expiry():
    store = IdempotencyStore()
    resp = JobResponse(job_id="j2", status="ok", node_id="n1")
    store.put("key2", resp)
    store._cache["key2"] = (resp, time.time() - 1)
    assert store.get("key2") is None

def test_timeout_shorter_ttl():
    store = IdempotencyStore()
    resp = JobResponse(job_id="j3", status="timeout", error=JobError(code="TO"))
    store.put("key3", resp)
    _, expires = store._cache["key3"]
    assert expires - time.time() < 35  # timeout TTL ≈ 30s

def test_inflight_dedup():
    store = IdempotencyStore()

    async def run():
        fut1 = await store.acquire_inflight("key4")
        assert fut1 is None  # first caller gets None (becomes processor)
        fut2 = await store.acquire_inflight("key4")
        assert fut2 is not None  # second caller gets future to wait on
        resp = JobResponse(job_id="j4", status="ok")
        store.complete_inflight("key4", resp)
        result = await asyncio.wait_for(fut2, timeout=1.0)
        assert result.status == "ok"

    asyncio.run(run())

def test_evict_on_max_size():
    store = IdempotencyStore(max_size=3)
    for i in range(5):
        store.put(f"k{i}", JobResponse(job_id=f"j{i}", status="ok"))
    assert len(store._cache) <= 3

@@ -0,0 +1,52 @@
"""Tests for offload_client circuit breaker."""
import sys
import os
import time
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services", "router"))
import offload_client
def setup_function():
    offload_client._breakers.clear()

def test_node_available_by_default():
    assert offload_client.is_node_available("noda99", "llm") is True

def test_node_unavailable_after_threshold():
    for _ in range(offload_client.CB_FAILS):
        offload_client.record_failure("noda99", "llm")
    assert offload_client.is_node_available("noda99", "llm") is False

def test_node_available_after_open_expires():
    for _ in range(offload_client.CB_FAILS):
        offload_client.record_failure("noda99", "llm")
    offload_client._breakers["noda99:llm"]["open_until"] = time.time() - 1
    assert offload_client.is_node_available("noda99", "llm") is True

def test_success_resets_breaker():
    for _ in range(offload_client.CB_FAILS):
        offload_client.record_failure("noda99", "llm")
    offload_client._breakers["noda99:llm"]["open_until"] = time.time() - 1
    offload_client.record_success("noda99", "llm")
    assert offload_client.is_node_available("noda99", "llm") is True

def test_get_unavailable_nodes():
    for _ in range(offload_client.CB_FAILS):
        offload_client.record_failure("bad_node", "llm")
    unavail = offload_client.get_unavailable_nodes("llm")
    assert "bad_node" in unavail
    assert offload_client.get_unavailable_nodes("vision") == set()

def test_different_types_independent():
    for _ in range(offload_client.CB_FAILS):
        offload_client.record_failure("noda99", "llm")
    assert offload_client.is_node_available("noda99", "llm") is False
    assert offload_client.is_node_available("noda99", "vision") is True