diff --git a/docker-compose.node2-sofiia.yml b/docker-compose.node2-sofiia.yml index 3bfd2ce4..61e72e1e 100644 --- a/docker-compose.node2-sofiia.yml +++ b/docker-compose.node2-sofiia.yml @@ -133,6 +133,30 @@ services: - dagi-network restart: unless-stopped + node-worker: + build: + context: ./services/node-worker + dockerfile: Dockerfile + container_name: node-worker-node2 + ports: + - "127.0.0.1:8109:8109" + extra_hosts: + - "host.docker.internal:host-gateway" + environment: + - NODE_ID=noda2 + - NATS_URL=nats://dagi-nats:4222 + - OLLAMA_BASE_URL=http://host.docker.internal:11434 + - SWAPPER_URL=http://swapper-service:8890 + - NODE_DEFAULT_LLM=qwen3:14b + - NODE_DEFAULT_VISION=llava:13b + - NODE_WORKER_MAX_CONCURRENCY=2 + depends_on: + - dagi-nats + - swapper-service + networks: + - dagi-network + restart: unless-stopped + sofiia-console: build: context: ./services/sofiia-console diff --git a/ops/offload_routing.md b/ops/offload_routing.md new file mode 100644 index 00000000..f29dae19 --- /dev/null +++ b/ops/offload_routing.md @@ -0,0 +1,144 @@ +# NATS Offload Routing — Operations Guide + +## Architecture + +``` + Router (NODA1/NODA2) + │ + ├── model_select.py → selects best model from global capabilities pool + │ (local first, remote if needed, circuit breaker aware) + │ + ├── offload_client.py → NATS req/reply to remote node-worker + │ (retries, deadlines, circuit breaker) + │ + └── global_capabilities_client.py → scatter-gather discovery + node.*.capabilities.get +``` + +## NATS Subjects + +| Subject | Direction | Description | +|---------|-----------|-------------| +| `node.{node_id}.capabilities.get` | req/reply | NCS capabilities query | +| `node.{node_id}.llm.request` | req/reply | LLM inference offload | +| `node.{node_id}.vision.request` | req/reply | Vision inference offload | +| `node.{node_id}.stt.request` | req/reply | STT (scaffold) | +| `node.{node_id}.tts.request` | req/reply | TTS (scaffold) | + +## Job Request Envelope + +```json +{ + "job_id": 
"uuid", + "trace_id": "uuid", + "actor_agent_id": "sofiia", + "target_agent_id": "helion", + "required_type": "llm", + "deadline_ts": 1740000000000, + "idempotency_key": "uuid", + "payload": { + "messages": [{"role": "user", "content": "..."}], + "prompt": "...", + "model": "qwen3:14b", + "max_tokens": 2048, + "temperature": 0.2 + }, + "hints": { + "prefer_models": ["qwen3:14b"] + } +} +``` + +## Job Response + +```json +{ + "job_id": "uuid", + "trace_id": "uuid", + "node_id": "noda2", + "status": "ok|busy|timeout|error", + "provider": "ollama", + "model": "qwen3:14b", + "latency_ms": 5500, + "result": {"text": "..."}, + "error": null, + "cached": false +} +``` + +## Circuit Breaker + +Per `node_id:required_type`. Opens after 3 failures in 60s, stays open 120s. + +| Env | Default | Description | +|-----|---------|-------------| +| `ROUTER_OFFLOAD_CB_FAILS` | 3 | Failures to trip | +| `ROUTER_OFFLOAD_CB_WINDOW_S` | 60 | Failure window | +| `ROUTER_OFFLOAD_CB_OPEN_S` | 120 | Open duration | + +## Idempotency + +Node-worker caches `idempotency_key → response` for 10 minutes. +Duplicate requests return cached response immediately (< 10ms). +Inflight dedup prevents parallel execution of same job. + +## Deadline Enforcement + +`deadline_ts` is absolute Unix milliseconds. If already expired on arrival, +node-worker returns `status=timeout` + `error.code=DEADLINE_EXCEEDED`. +During inference, `asyncio.wait_for` enforces remaining time. + +## Fallback Chain + +1. NCS selects remote model (e.g. noda2) +2. Router sends NATS offload request +3. 
If remote fails (timeout/busy/error): + - Record circuit breaker failure + - Re-select with `exclude_nodes={failed}` + `force_local=True` + - Execute locally + +## Environment Variables + +### Router +| Env | Default | Description | +|-----|---------|-------------| +| `ROUTER_INFER_TIMEOUT_MS` | 25000 | Total inference deadline | +| `ROUTER_OFFLOAD_RETRIES` | 1 | NATS retry on transient | +| `ROUTER_OFFLOAD_MAX_CONCURRENCY_REMOTE` | 8 | Max parallel offloads | + +### Node Worker +| Env | Default | Description | +|-----|---------|-------------| +| `NODE_ID` | noda2 | Canonical node identifier | +| `NATS_URL` | nats://dagi-nats:4222 | NATS server | +| `OLLAMA_BASE_URL` | http://host.docker.internal:11434 | Ollama API | +| `SWAPPER_URL` | http://swapper-service:8890 | Swapper API | +| `NODE_DEFAULT_LLM` | qwen3:14b | Default LLM model | +| `NODE_DEFAULT_VISION` | llava:13b | Default vision model | +| `NODE_WORKER_MAX_CONCURRENCY` | 2 | Max parallel inferences | + +## Verification Commands + +```bash +# Node Worker health +curl -s http://localhost:8109/healthz | jq . + +# Direct NATS LLM offload +nats request node.noda2.llm.request '{"job_id":"test","required_type":"llm","payload":{"prompt":"hi","max_tokens":10}}' + +# Idempotency test (same job_id) +nats request node.noda2.llm.request '{"job_id":"test","idempotency_key":"test","required_type":"llm","payload":{"prompt":"hi"}}' # → cached=true + +# Deadline expired test +nats request node.noda2.llm.request '{"job_id":"expired","required_type":"llm","deadline_ts":1000,"payload":{"prompt":"hi"}}' # → DEADLINE_EXCEEDED + +# Router logs (selection + offload) +docker logs dagi-router-node2 2>&1 | grep -E '\[select\]|\[offload\]|\[fallback\]' +``` + +## Adding New Nodes + +1. Deploy `node-capabilities` + `node-worker` with unique `NODE_ID` +2. Connect NATS (leafnode or cluster member) +3. Router auto-discovers via `node.*.capabilities.get` scatter-gather +4. 
No config changes on Router needed diff --git a/services/node-worker/Dockerfile b/services/node-worker/Dockerfile new file mode 100644 index 00000000..bfb7fe70 --- /dev/null +++ b/services/node-worker/Dockerfile @@ -0,0 +1,7 @@ +FROM python:3.11-slim +WORKDIR /app +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt +COPY . . +EXPOSE 8109 +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8109"] diff --git a/services/node-worker/__init__.py b/services/node-worker/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/node-worker/config.py b/services/node-worker/config.py new file mode 100644 index 00000000..959a774b --- /dev/null +++ b/services/node-worker/config.py @@ -0,0 +1,12 @@ +"""Node-worker configuration from environment.""" +import os + +NODE_ID = os.getenv("NODE_ID", "noda2") +NATS_URL = os.getenv("NATS_URL", "nats://dagi-nats:4222") +OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") +SWAPPER_URL = os.getenv("SWAPPER_URL", "http://swapper-service:8890") +DEFAULT_LLM = os.getenv("NODE_DEFAULT_LLM", "qwen3:14b") +DEFAULT_VISION = os.getenv("NODE_DEFAULT_VISION", "llava:13b") +MAX_CONCURRENCY = int(os.getenv("NODE_WORKER_MAX_CONCURRENCY", "2")) +MAX_PAYLOAD_BYTES = int(os.getenv("NODE_WORKER_MAX_PAYLOAD_BYTES", str(1024 * 1024))) +PORT = int(os.getenv("PORT", "8109")) diff --git a/services/node-worker/idempotency.py b/services/node-worker/idempotency.py new file mode 100644 index 00000000..9cad285c --- /dev/null +++ b/services/node-worker/idempotency.py @@ -0,0 +1,62 @@ +"""Idempotency cache + inflight dedup for job execution.""" +import asyncio +import logging +import time +from typing import Dict, Optional, Tuple + +from models import JobResponse + +logger = logging.getLogger("idempotency") + +CACHE_TTL = 600 # 10 min for successful results +TIMEOUT_TTL = 30 # 30s for timeout results + + +class IdempotencyStore: + def __init__(self, max_size: int = 10_000): + 
self._cache: Dict[str, Tuple[JobResponse, float]] = {} + self._inflight: Dict[str, asyncio.Future] = {} + self._max_size = max_size + + def get(self, key: str) -> Optional[JobResponse]: + entry = self._cache.get(key) + if not entry: + return None + resp, expires = entry + if time.time() > expires: + self._cache.pop(key, None) + return None + cached = resp.model_copy() + cached.cached = True + return cached + + def put(self, key: str, resp: JobResponse): + ttl = TIMEOUT_TTL if resp.status == "timeout" else CACHE_TTL + self._cache[key] = (resp, time.time() + ttl) + self._evict_if_needed() + + def _evict_if_needed(self): + if len(self._cache) <= self._max_size: + return + now = time.time() + expired = [k for k, (_, exp) in self._cache.items() if now > exp] + for k in expired: + self._cache.pop(k, None) + if len(self._cache) > self._max_size: + oldest = sorted(self._cache, key=lambda k: self._cache[k][1]) + for k in oldest[:len(self._cache) - self._max_size]: + self._cache.pop(k, None) + + async def acquire_inflight(self, key: str) -> Optional[asyncio.Future]: + """If another coroutine is already processing this key, return its future. 
+ Otherwise register this coroutine as the processor and return None.""" + if key in self._inflight: + return self._inflight[key] + fut: asyncio.Future = asyncio.get_event_loop().create_future() + self._inflight[key] = fut + return None + + def complete_inflight(self, key: str, resp: JobResponse): + fut = self._inflight.pop(key, None) + if fut and not fut.done(): + fut.set_result(resp) diff --git a/services/node-worker/main.py b/services/node-worker/main.py new file mode 100644 index 00000000..f85b3211 --- /dev/null +++ b/services/node-worker/main.py @@ -0,0 +1,53 @@ +"""Node Worker — NATS offload executor for cross-node inference.""" +import logging +import os + +from fastapi import FastAPI + +import config +import worker + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger("node-worker") + +app = FastAPI(title="Node Worker", version="1.0.0") + +_nats_client = None + + +@app.get("/healthz") +async def healthz(): + connected = _nats_client is not None and _nats_client.is_connected if _nats_client else False + return { + "status": "ok" if connected else "degraded", + "node_id": config.NODE_ID, + "nats_connected": connected, + "max_concurrency": config.MAX_CONCURRENCY, + } + + +@app.on_event("startup") +async def startup(): + global _nats_client + try: + import nats as nats_lib + _nats_client = await nats_lib.connect(config.NATS_URL) + logger.info(f"✅ NATS connected: {config.NATS_URL}") + await worker.start(_nats_client) + logger.info(f"✅ Node Worker ready: node={config.NODE_ID} concurrency={config.MAX_CONCURRENCY}") + except Exception as e: + logger.error(f"❌ Startup failed: {e}") + + +@app.on_event("shutdown") +async def shutdown(): + if _nats_client: + try: + await _nats_client.close() + except Exception: + pass + + +if __name__ == "__main__": + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=config.PORT) diff --git a/services/node-worker/models.py b/services/node-worker/models.py new file mode 100644 index 00000000..b24ce36f --- /dev/null 
+++ b/services/node-worker/models.py @@ -0,0 +1,47 @@ +"""Canonical job envelope for cross-node inference offload.""" +from typing import Any, Dict, List, Literal, Optional +from pydantic import BaseModel, Field +import time +import uuid + + +def _ulid() -> str: + return str(uuid.uuid4()) + + +class JobRequest(BaseModel): + job_id: str = Field(default_factory=_ulid) + trace_id: str = "" + actor_agent_id: str = "" + target_agent_id: str = "" + required_type: Literal["llm", "vision", "stt", "tts"] = "llm" + deadline_ts: int = 0 + idempotency_key: str = "" + payload: Dict[str, Any] = Field(default_factory=dict) + hints: Dict[str, Any] = Field(default_factory=dict) + + def remaining_ms(self) -> int: + if self.deadline_ts <= 0: + return 30_000 + return max(0, self.deadline_ts - int(time.time() * 1000)) + + def effective_idem_key(self) -> str: + return self.idempotency_key or self.job_id + + +class JobError(BaseModel): + code: str = "UNKNOWN" + message: str = "" + + +class JobResponse(BaseModel): + job_id: str = "" + trace_id: str = "" + node_id: str = "" + status: Literal["ok", "busy", "timeout", "error"] = "ok" + provider: str = "" + model: str = "" + latency_ms: int = 0 + result: Optional[Dict[str, Any]] = None + error: Optional[JobError] = None + cached: bool = False diff --git a/services/node-worker/providers/__init__.py b/services/node-worker/providers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/node-worker/providers/ollama.py b/services/node-worker/providers/ollama.py new file mode 100644 index 00000000..af9646fc --- /dev/null +++ b/services/node-worker/providers/ollama.py @@ -0,0 +1,81 @@ +"""Ollama LLM provider for node-worker.""" +import logging +from typing import Any, Dict, List, Optional + +import httpx + +from config import OLLAMA_BASE_URL, DEFAULT_LLM + +logger = logging.getLogger("provider.ollama") + + +async def infer( + messages: Optional[List[Dict[str, str]]] = None, + prompt: str = "", + model: str = "", + system: 
str = "", + max_tokens: int = 2048, + temperature: float = 0.2, + timeout_s: float = 25.0, +) -> Dict[str, Any]: + model = model or DEFAULT_LLM + + if messages: + return await _chat(messages, model, max_tokens, temperature, timeout_s) + return await _generate(prompt, system, model, max_tokens, temperature, timeout_s) + + +async def _chat( + messages: List[Dict[str, str]], + model: str, + max_tokens: int, + temperature: float, + timeout_s: float, +) -> Dict[str, Any]: + async with httpx.AsyncClient(timeout=timeout_s) as c: + resp = await c.post( + f"{OLLAMA_BASE_URL}/api/chat", + json={ + "model": model, + "messages": messages, + "stream": False, + "options": {"num_predict": max_tokens, "temperature": temperature}, + }, + ) + resp.raise_for_status() + data = resp.json() + return { + "text": data.get("message", {}).get("content", ""), + "model": model, + "provider": "ollama", + "eval_count": data.get("eval_count", 0), + } + + +async def _generate( + prompt: str, + system: str, + model: str, + max_tokens: int, + temperature: float, + timeout_s: float, +) -> Dict[str, Any]: + async with httpx.AsyncClient(timeout=timeout_s) as c: + resp = await c.post( + f"{OLLAMA_BASE_URL}/api/generate", + json={ + "model": model, + "prompt": prompt, + "system": system, + "stream": False, + "options": {"num_predict": max_tokens, "temperature": temperature}, + }, + ) + resp.raise_for_status() + data = resp.json() + return { + "text": data.get("response", ""), + "model": model, + "provider": "ollama", + "eval_count": data.get("eval_count", 0), + } diff --git a/services/node-worker/providers/swapper_vision.py b/services/node-worker/providers/swapper_vision.py new file mode 100644 index 00000000..c6720ed0 --- /dev/null +++ b/services/node-worker/providers/swapper_vision.py @@ -0,0 +1,42 @@ +"""Swapper vision provider for node-worker.""" +import logging +from typing import Any, Dict, List, Optional + +import httpx + +from config import SWAPPER_URL, DEFAULT_VISION + +logger = 
logging.getLogger("provider.swapper_vision") + + +async def infer( + images: Optional[List[str]] = None, + prompt: str = "", + model: str = "", + system: str = "", + max_tokens: int = 1024, + temperature: float = 0.2, + timeout_s: float = 60.0, +) -> Dict[str, Any]: + model = model or DEFAULT_VISION + + payload: Dict[str, Any] = { + "model": model, + "prompt": prompt or "Describe this image.", + "max_tokens": max_tokens, + "temperature": temperature, + } + if images: + payload["images"] = images + if system: + payload["system"] = system + + async with httpx.AsyncClient(timeout=timeout_s) as c: + resp = await c.post(f"{SWAPPER_URL}/vision", json=payload) + resp.raise_for_status() + data = resp.json() + return { + "text": data.get("text", data.get("response", "")), + "model": model, + "provider": "swapper_vision", + } diff --git a/services/node-worker/requirements.txt b/services/node-worker/requirements.txt new file mode 100644 index 00000000..4a7c9644 --- /dev/null +++ b/services/node-worker/requirements.txt @@ -0,0 +1,5 @@ +fastapi>=0.110.0 +uvicorn>=0.29.0 +httpx>=0.27.0 +nats-py>=2.7.0 +pydantic>=2.5.0 diff --git a/services/node-worker/worker.py b/services/node-worker/worker.py new file mode 100644 index 00000000..a0fd2a0a --- /dev/null +++ b/services/node-worker/worker.py @@ -0,0 +1,184 @@ +"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects.""" +import asyncio +import json +import logging +import time +from typing import Any, Dict + +import config +from models import JobRequest, JobResponse, JobError +from idempotency import IdempotencyStore +from providers import ollama, swapper_vision + +logger = logging.getLogger("node-worker") + +_idem = IdempotencyStore() +_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY) +_nats_client = None + + +async def start(nats_client): + global _nats_client + _nats_client = nats_client + + subjects = [ + f"node.{config.NODE_ID.lower()}.llm.request", + 
f"node.{config.NODE_ID.lower()}.vision.request", + ] + for subj in subjects: + await nats_client.subscribe(subj, cb=_handle_request) + logger.info(f"✅ Subscribed: {subj}") + + +async def _handle_request(msg): + t0 = time.time() + try: + raw = msg.data + if len(raw) > config.MAX_PAYLOAD_BYTES: + await _reply(msg, JobResponse( + status="error", + node_id=config.NODE_ID, + error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"), + )) + return + + data = json.loads(raw) + job = JobRequest(**data) + job.trace_id = job.trace_id or job.job_id + + idem_key = job.effective_idem_key() + cached = _idem.get(idem_key) + if cached: + logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}") + await _reply(msg, cached) + return + + remaining = job.remaining_ms() + if remaining <= 0: + resp = JobResponse( + job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, + status="timeout", + error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"), + ) + _idem.put(idem_key, resp) + await _reply(msg, resp) + return + + inflight = await _idem.acquire_inflight(idem_key) + if inflight is not None: + try: + resp = await asyncio.wait_for(inflight, timeout=remaining / 1000.0) + resp_copy = resp.model_copy() + resp_copy.cached = True + await _reply(msg, resp_copy) + except asyncio.TimeoutError: + await _reply(msg, JobResponse( + job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, + status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"), + )) + return + + if _semaphore.locked() and _semaphore._value == 0: + resp = JobResponse( + job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, + status="busy", + error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"), + ) + _idem.complete_inflight(idem_key, resp) + await _reply(msg, resp) + return + + async with _semaphore: + resp = await _execute(job, remaining) + + _idem.put(idem_key, resp) + 
_idem.complete_inflight(idem_key, resp) + resp.latency_ms = int((time.time() - t0) * 1000) + await _reply(msg, resp) + + except Exception as e: + logger.exception(f"Worker handler error: {e}") + try: + await _reply(msg, JobResponse( + node_id=config.NODE_ID, status="error", + error=JobError(code="INTERNAL", message=str(e)[:200]), + )) + except Exception: + pass + + +async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse: + payload = job.payload + hints = job.hints + timeout_s = min(remaining_ms / 1000.0, 120.0) + model = hints.get("prefer_models", [None])[0] if hints.get("prefer_models") else payload.get("model", "") + + msg_count = len(payload.get("messages", [])) + prompt_chars = len(payload.get("prompt", "")) + logger.info( + f"[job.start] job={job.job_id} trace={job.trace_id} " + f"type={job.required_type} model={model or '?'} " + f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms" + ) + + try: + if job.required_type == "llm": + result = await asyncio.wait_for( + ollama.infer( + messages=payload.get("messages"), + prompt=payload.get("prompt", ""), + model=model, + system=payload.get("system", ""), + max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)), + temperature=hints.get("temperature", payload.get("temperature", 0.2)), + timeout_s=timeout_s, + ), + timeout=timeout_s, + ) + elif job.required_type == "vision": + result = await asyncio.wait_for( + swapper_vision.infer( + images=payload.get("images"), + prompt=payload.get("prompt", ""), + model=model, + system=payload.get("system", ""), + max_tokens=hints.get("max_tokens", 1024), + temperature=hints.get("temperature", 0.2), + timeout_s=timeout_s, + ), + timeout=timeout_s, + ) + else: + return JobResponse( + job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, + status="error", + error=JobError(code="UNSUPPORTED_TYPE", message=f"{job.required_type} not implemented"), + ) + + logger.info( + f"[job.done] job={job.job_id} status=ok " + 
f"provider={result.get('provider')} model={result.get('model')}" + ) + return JobResponse( + job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, + status="ok", provider=result.get("provider", ""), model=result.get("model", ""), + result=result, + ) + + except asyncio.TimeoutError: + logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s") + return JobResponse( + job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, + status="timeout", error=JobError(code="PROVIDER_TIMEOUT"), + ) + except Exception as e: + logger.warning(f"[job.error] job={job.job_id} {e}") + return JobResponse( + job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID, + status="error", error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]), + ) + + +async def _reply(msg, resp: JobResponse): + if msg.reply: + await _nats_client.publish(msg.reply, resp.model_dump_json().encode()) diff --git a/services/router/main.py b/services/router/main.py index 8ea94a40..c96893fc 100644 --- a/services/router/main.py +++ b/services/router/main.py @@ -51,11 +51,13 @@ try: import capabilities_client import global_capabilities_client from model_select import select_model_for_agent, ModelSelection, CLOUD_PROVIDERS as NCS_CLOUD_PROVIDERS + import offload_client NCS_AVAILABLE = True except ImportError: NCS_AVAILABLE = False capabilities_client = None # type: ignore[assignment] global_capabilities_client = None # type: ignore[assignment] + offload_client = None # type: ignore[assignment] logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -1707,6 +1709,76 @@ async def agent_infer(agent_id: str, request: InferRequest): f"provider={provider} model={model}" ) + # ========================================================================= + # REMOTE OFFLOAD (if model selected on remote node) + # ========================================================================= + nats_client_available = nc is not None and nats_available + if (ncs_selection 
and ncs_selection.via_nats and not ncs_selection.local + and nats_client_available and offload_client and nc): + infer_timeout = int(os.getenv("ROUTER_INFER_TIMEOUT_MS", "25000")) + import uuid as _uuid + job_payload = { + "job_id": str(_uuid.uuid4()), + "trace_id": str(_uuid.uuid4()), + "actor_agent_id": request_agent_id or agent_id, + "target_agent_id": agent_id, + "required_type": ncs_selection.model_type if ncs_selection.model_type != "code" else "llm", + "deadline_ts": int(time.time() * 1000) + infer_timeout, + "idempotency_key": str(_uuid.uuid4()), + "payload": { + "prompt": request.prompt, + "messages": [{"role": "system", "content": system_prompt}] if system_prompt else [], + "model": ncs_selection.name, + "max_tokens": request.max_tokens or 2048, + "temperature": request.temperature or 0.2, + }, + "hints": {"prefer_models": [ncs_selection.name]}, + } + if request.images: + job_payload["payload"]["images"] = request.images + job_payload["required_type"] = "vision" + job_payload["payload"]["messages"].append({"role": "user", "content": request.prompt}) + + offload_resp = await offload_client.offload_infer( + nats_client=nc, + node_id=ncs_selection.node, + required_type=job_payload["required_type"], + job_payload=job_payload, + timeout_ms=infer_timeout, + ) + if offload_resp and offload_resp.get("status") == "ok": + result_text = offload_resp.get("result", {}).get("text", "") + return InferResponse( + response=result_text, + model=f"{offload_resp.get('model', ncs_selection.name)}@{ncs_selection.node}", + backend=f"nats-offload:{ncs_selection.node}", + tokens_used=offload_resp.get("result", {}).get("eval_count", 0), + ) + else: + offload_status = offload_resp.get("status", "none") if offload_resp else "no_reply" + logger.warning( + f"[fallback] offload to {ncs_selection.node} failed ({offload_status}) " + f"→ re-selecting with exclude={ncs_selection.node}, force_local" + ) + try: + gcaps = await global_capabilities_client.get_global_capabilities() + 
ncs_selection = await select_model_for_agent( + agent_id, agent_config, router_config, gcaps, request.model, + exclude_nodes={ncs_selection.node}, force_local=True, + ) + if ncs_selection and ncs_selection.name: + provider = ncs_selection.provider + model = ncs_selection.name + llm_profile = router_config.get("llm_profiles", {}).get(default_llm, {}) + if ncs_selection.base_url and provider == "ollama": + llm_profile = {**llm_profile, "base_url": ncs_selection.base_url} + logger.info( + f"[fallback.reselect] → local node={ncs_selection.node} " + f"model={model} provider={provider}" + ) + except Exception as e: + logger.warning(f"[fallback.reselect] error: {e}; proceeding with static") + # ========================================================================= # VISION PROCESSING (if images present) # ========================================================================= diff --git a/services/router/model_select.py b/services/router/model_select.py index b784916b..7d9a173f 100644 --- a/services/router/model_select.py +++ b/services/router/model_select.py @@ -9,7 +9,7 @@ Scaling: works with 1 node or 150+. No static node lists. import logging import time from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Set logger = logging.getLogger("model_select") @@ -110,6 +110,7 @@ def profile_requirements( def select_best_model( reqs: ProfileRequirements, capabilities: Dict[str, Any], + exclude_nodes: Optional[Set[str]] = None, ) -> Optional[ModelSelection]: """Choose the best served model from global (multi-node) capabilities. @@ -117,18 +118,25 @@ def select_best_model( 1. Prefer list matches (local first, then remote) 2. Best candidate by size (local first, then remote) 3. None → caller should try static fallback + + exclude_nodes: set of node_ids to skip (e.g. circuit-broken nodes). 
""" served = capabilities.get("served_models", []) if not served: return None + exclude = exclude_nodes or set() + search_types = [reqs.required_type] if reqs.required_type == "code": search_types.append("llm") if reqs.required_type == "llm": search_types.append("code") - candidates = [m for m in served if m.get("type") in search_types] + candidates = [ + m for m in served + if m.get("type") in search_types and m.get("node", "") not in exclude + ] if not candidates: return None @@ -218,15 +226,21 @@ async def select_model_for_agent( router_cfg: Dict[str, Any], capabilities: Optional[Dict[str, Any]], request_model: Optional[str] = None, + exclude_nodes: Optional[Set[str]] = None, + force_local: bool = False, ) -> ModelSelection: - """Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default.""" + """Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default. + + exclude_nodes: skip these nodes (circuit-broken). Used on fallback re-selection. + force_local: prefer local-only models (fallback after remote failure). 
+ """ profile = resolve_effective_profile( agent_id, agent_cfg, router_cfg, request_model, ) reqs = profile_requirements(profile, agent_cfg, router_cfg) - if reqs.required_type == "cloud_llm": + if reqs.required_type == "cloud_llm" and not force_local: static = static_fallback(profile, router_cfg) if static: static.fallback_reason = "" @@ -236,14 +250,31 @@ async def select_model_for_agent( ) return static + excl = set(exclude_nodes) if exclude_nodes else set() + try: + from offload_client import get_unavailable_nodes + cb_nodes = get_unavailable_nodes(reqs.required_type) + excl |= cb_nodes + if cb_nodes: + logger.info(f"[select] circuit-broken nodes for {reqs.required_type}: {cb_nodes}") + except ImportError: + pass + if capabilities and capabilities.get("served_models"): - sel = select_best_model(reqs, capabilities) + sel = select_best_model(reqs, capabilities, exclude_nodes=excl) + if force_local and sel and not sel.local: + sel = select_best_model( + reqs, capabilities, + exclude_nodes=excl | {n.get("node", "") for n in capabilities.get("served_models", []) if not n.get("local")}, + ) if sel: logger.info( f"[select] agent={agent_id} profile={profile} → " - f"{'NCS' if sel.local else 'REMOTE'} " + f"{'LOCAL' if sel.local else 'REMOTE'} " f"node={sel.node} runtime={sel.runtime} " f"model={sel.name} caps_age={sel.caps_age_s}s" + f"{' (force_local)' if force_local else ''}" + f"{' (excluded: ' + ','.join(excl) + ')' if excl else ''}" ) return sel logger.warning( diff --git a/services/router/offload_client.py b/services/router/offload_client.py new file mode 100644 index 00000000..906ae2e2 --- /dev/null +++ b/services/router/offload_client.py @@ -0,0 +1,153 @@ +"""NATS offload client — sends inference requests to remote nodes with +circuit breaker, retries, and deadline enforcement.""" +import asyncio +import json +import logging +import os +import time +from typing import Any, Dict, Literal, Optional, Set + +logger = logging.getLogger("offload_client") + +CB_FAILS 
= int(os.getenv("ROUTER_OFFLOAD_CB_FAILS", "3"))  # failures within the window needed to trip the breaker
+CB_WINDOW_S = int(os.getenv("ROUTER_OFFLOAD_CB_WINDOW_S", "60"))  # sliding failure-count window, seconds
+CB_OPEN_S = int(os.getenv("ROUTER_OFFLOAD_CB_OPEN_S", "120"))  # how long a tripped breaker stays open, seconds
+MAX_RETRIES = int(os.getenv("ROUTER_OFFLOAD_RETRIES", "1"))  # extra attempts after the first (transient errors only)
+MAX_CONCURRENCY = int(os.getenv("ROUTER_OFFLOAD_MAX_CONCURRENCY_REMOTE", "8"))  # cap on simultaneous remote offloads
+
+_semaphore: Optional[asyncio.Semaphore] = None  # created lazily so it binds to the running event loop
+_breakers: Dict[str, Dict[str, Any]] = {}  # "node_id:req_type" -> {"failures": [ts, ...], "open_until": ts}
+
+
+def _get_semaphore() -> asyncio.Semaphore:
+    global _semaphore
+    if _semaphore is None:
+        _semaphore = asyncio.Semaphore(MAX_CONCURRENCY)  # module-wide gate shared by all offload_infer calls
+    return _semaphore
+
+
+def _breaker_key(node_id: str, req_type: str) -> str:
+    return f"{node_id}:{req_type}"  # per-(node, capability) breaker state key
+
+
+def is_node_available(node_id: str, req_type: str) -> bool:
+    key = _breaker_key(node_id, req_type)
+    state = _breakers.get(key)
+    if not state:
+        return True  # no recorded failures: available
+    open_until = state.get("open_until", 0)
+    if open_until > time.time():
+        return False  # breaker currently open
+    if open_until > 0 and open_until <= time.time():
+        return True  # open period elapsed: allow traffic again until a failure re-trips it
+    now = time.time()
+    window_start = now - CB_WINDOW_S
+    recent = [t for t in state.get("failures", []) if t > window_start]
+    state["failures"] = recent  # prune stale failures as a side effect
+    return len(recent) < CB_FAILS
+
+
+def record_failure(node_id: str, req_type: str):
+    key = _breaker_key(node_id, req_type)
+    state = _breakers.setdefault(key, {"failures": [], "open_until": 0})
+    state["failures"].append(time.time())
+    window_start = time.time() - CB_WINDOW_S
+    recent = [t for t in state["failures"] if t > window_start]
+    state["failures"] = recent  # keep only failures inside the window
+    if len(recent) >= CB_FAILS:
+        state["open_until"] = time.time() + CB_OPEN_S  # trip the breaker
+        logger.warning(f"Circuit OPEN: {key} ({len(recent)} failures in {CB_WINDOW_S}s, open for {CB_OPEN_S}s)")
+
+
+def record_success(node_id: str, req_type: str):
+    key = _breaker_key(node_id, req_type)
+    state = _breakers.get(key)
+    if state:
+        state["failures"] = []  # a success fully resets the breaker
+        state["open_until"] = 0
+
+
+def get_unavailable_nodes(req_type: str) -> Set[str]:
+    result = set()
+    for key, state in _breakers.items():  # NOTE(review): state is unused in this loop body
+        if not key.endswith(f":{req_type}"):
+            continue
+        nid = key.rsplit(":", 1)[0]  # split from the right in case node_id itself contains ":"
+        if not is_node_available(nid, req_type):
+            result.add(nid)
+    return result
+
+
+async def offload_infer(
+    nats_client,
+    node_id: str,
+    required_type: Literal["llm", "vision", "stt", "tts"],
+    job_payload: Dict[str, Any],
+    timeout_ms: int = 25000,
+) -> Optional[Dict[str, Any]]:
+    """Send inference job to remote node via NATS request/reply.
+
+    Returns parsed JobResponse dict or None on total failure.
+    Retries on transient errors (timeout, busy). Does NOT retry on provider errors.
+    """
+    subject = f"node.{node_id.lower()}.{required_type}.request"
+    payload_bytes = json.dumps(job_payload).encode()
+    sem = _get_semaphore()
+
+    for attempt in range(1 + MAX_RETRIES):
+        timeout_s = timeout_ms / 1000.0  # remaining budget; shrunk after each failed attempt
+        if timeout_s <= 0:
+            logger.warning(f"[offload] deadline exhausted before attempt {attempt}")
+            return None
+
+        t0 = time.time()
+        try:
+            async with sem:
+                logger.info(
+                    f"[offload.start] node={node_id} subj={subject} "
+                    f"timeout={timeout_ms}ms attempt={attempt}"
+                )
+                msg = await nats_client.request(subject, payload_bytes, timeout=timeout_s)
+                resp = json.loads(msg.data)
+
+                latency = int((time.time() - t0) * 1000)
+                status = resp.get("status", "error")
+
+                if status == "ok":
+                    record_success(node_id, required_type)
+                    logger.info(
+                        f"[offload.done] node={node_id} status=ok latency={latency}ms "
+                        f"provider={resp.get('provider')} model={resp.get('model')} "
+                        f"cached={resp.get('cached', False)}"
+                    )
+                    return resp
+
+                if status in ("timeout", "busy") and attempt < MAX_RETRIES:
+                    elapsed = int((time.time() - t0) * 1000)
+                    timeout_ms -= elapsed  # charge the failed attempt against the overall budget
+                    logger.warning(f"[offload.retry] node={node_id} status={status} → retry {attempt+1}")
+                    continue
+
+                record_failure(node_id, required_type)  # non-retryable, or retries spent: counts toward the breaker
+                logger.warning(
+                    f"[offload.fail] node={node_id} status={status} "
+                    f"error={resp.get('error', {}).get('code', '?')}"
+                )
+                return resp
+
+        except asyncio.TimeoutError:  # NOTE(review): assumes the NATS client's timeout error is asyncio.TimeoutError-compatible — confirm
+            record_failure(node_id, required_type)
+            elapsed = int((time.time() - t0) * 1000)
+            timeout_ms -= elapsed
+            if attempt < MAX_RETRIES:
+                logger.warning(f"[offload.timeout] node={node_id} {elapsed}ms → retry {attempt+1}")
+                continue
+            logger.warning(f"[offload.timeout] node={node_id} all retries exhausted")
+            return None
+
+        except Exception as e:
+            record_failure(node_id, required_type)
+            logger.warning(f"[offload.error] node={node_id} {e}")
+            return None
+
+    return None
diff --git a/tests/test_node_worker_deadline.py b/tests/test_node_worker_deadline.py
new file mode 100644
index 00000000..6708ae35
--- /dev/null
+++ b/tests/test_node_worker_deadline.py
@@ -0,0 +1,34 @@
+"""Tests for node-worker deadline handling."""
+import sys
+import os
+import time
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services", "node-worker"))
+
+from models import JobRequest
+
+
+def test_remaining_ms_with_future_deadline():
+    job = JobRequest(deadline_ts=int(time.time() * 1000) + 10000)
+    rem = job.remaining_ms()
+    assert 9000 <= rem <= 10500  # loose bounds absorb test-runner jitter
+
+
+def test_remaining_ms_with_past_deadline():
+    job = JobRequest(deadline_ts=int(time.time() * 1000) - 5000)
+    assert job.remaining_ms() == 0  # expired deadline clamps to zero, never negative
+
+
+def test_remaining_ms_no_deadline():
+    job = JobRequest(deadline_ts=0)
+    assert job.remaining_ms() == 30_000  # deadline_ts=0 means "no deadline": default budget
+
+
+def test_effective_idem_key_uses_idempotency_key():
+    job = JobRequest(job_id="j1", idempotency_key="custom-key")
+    assert job.effective_idem_key() == "custom-key"
+
+
+def test_effective_idem_key_falls_back_to_job_id():
+    job = JobRequest(job_id="j2", idempotency_key="")
+    assert job.effective_idem_key() == "j2"  # empty key falls back to job_id
diff --git a/tests/test_node_worker_idempotency.py b/tests/test_node_worker_idempotency.py
new file mode 100644
index 00000000..024d1122
--- /dev/null
+++ b/tests/test_node_worker_idempotency.py
@@ -0,0 +1,66 @@
+"""Tests for node-worker idempotency store."""
+import asyncio
+import sys
+import os
+import time
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services", "node-worker"))
+
+from models import JobResponse, JobError
+from idempotency import IdempotencyStore
+
+
+def test_put_and_get():
+    store = IdempotencyStore()
+    resp = JobResponse(job_id="j1", status="ok", node_id="n1", provider="ollama", model="qwen3:14b")
+    store.put("key1", resp)
+    cached = store.get("key1")
+    assert cached is not None
+    assert cached.status == "ok"
+    assert cached.cached is True  # replayed responses are flagged as cached
+    assert cached.job_id == "j1"
+
+
+def test_miss():
+    store = IdempotencyStore()
+    assert store.get("nonexistent") is None
+
+
+def test_ttl_expiry():
+    store = IdempotencyStore()
+    resp = JobResponse(job_id="j2", status="ok", node_id="n1")
+    store.put("key2", resp)
+    store._cache["key2"] = (resp, time.time() - 1)  # back-date the expiry to force a miss
+    assert store.get("key2") is None
+
+
+def test_timeout_shorter_ttl():
+    store = IdempotencyStore()
+    resp = JobResponse(job_id="j3", status="timeout", error=JobError(code="TO"))
+    store.put("key3", resp)
+    _, expires = store._cache["key3"]
+    assert expires - time.time() < 35  # timeout TTL ≈ 30s
+
+
+def test_inflight_dedup():
+    store = IdempotencyStore()
+
+    async def run():
+        fut1 = await store.acquire_inflight("key4")
+        assert fut1 is None  # first caller gets None (becomes processor)
+        fut2 = await store.acquire_inflight("key4")
+        assert fut2 is not None  # second caller gets future to wait on
+
+        resp = JobResponse(job_id="j4", status="ok")
+        store.complete_inflight("key4", resp)
+        result = await asyncio.wait_for(fut2, timeout=1.0)
+        assert result.status == "ok"
+
+    asyncio.run(run())
+
+
+def test_evict_on_max_size():
+    store = IdempotencyStore(max_size=3)
+    for i in range(5):
+        store.put(f"k{i}", JobResponse(job_id=f"j{i}", status="ok"))
+    assert len(store._cache) <= 3  # entries evicted once max_size is exceeded
diff --git a/tests/test_offload_circuit_breaker.py b/tests/test_offload_circuit_breaker.py
new file mode 100644
index 00000000..409a191f
--- /dev/null
+++ b/tests/test_offload_circuit_breaker.py
@@ -0,0 +1,52 @@
+"""Tests for offload_client circuit breaker."""
+import sys
+import os
+import time
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services", "router"))
+
+import offload_client
+
+
+def setup_function():
+    offload_client._breakers.clear()  # reset shared module-level breaker state between tests
+
+
+def test_node_available_by_default():
+    assert offload_client.is_node_available("noda99", "llm") is True
+
+
+def test_node_unavailable_after_threshold():
+    for _ in range(offload_client.CB_FAILS):
+        offload_client.record_failure("noda99", "llm")
+    assert offload_client.is_node_available("noda99", "llm") is False
+
+
+def test_node_available_after_open_expires():
+    for _ in range(offload_client.CB_FAILS):
+        offload_client.record_failure("noda99", "llm")
+    offload_client._breakers["noda99:llm"]["open_until"] = time.time() - 1  # force the open window to lapse
+    assert offload_client.is_node_available("noda99", "llm") is True
+
+
+def test_success_resets_breaker():
+    for _ in range(offload_client.CB_FAILS):
+        offload_client.record_failure("noda99", "llm")
+    offload_client._breakers["noda99:llm"]["open_until"] = time.time() - 1
+    offload_client.record_success("noda99", "llm")  # success clears failures and open_until
+    assert offload_client.is_node_available("noda99", "llm") is True
+
+
+def test_get_unavailable_nodes():
+    for _ in range(offload_client.CB_FAILS):
+        offload_client.record_failure("bad_node", "llm")
+    unavail = offload_client.get_unavailable_nodes("llm")
+    assert "bad_node" in unavail
+    assert offload_client.get_unavailable_nodes("vision") == set()  # breaker is scoped per capability type
+
+
+def test_different_types_independent():
+    for _ in range(offload_client.CB_FAILS):
+        offload_client.record_failure("noda99", "llm")
+    assert offload_client.is_node_available("noda99", "llm") is False
+    assert offload_client.is_node_available("noda99", "vision") is True