P2.2+P2.3: NATS offload node-worker + router offload integration
Node Worker (services/node-worker/):
- NATS subscriber for node.{NODE_ID}.llm.request / vision.request
- Canonical JobRequest/JobResponse envelope (Pydantic)
- Idempotency cache (TTL 10min) with inflight dedup
- Deadline enforcement (DEADLINE_EXCEEDED on expired jobs)
- Concurrency limiter (semaphore, returns busy)
- Ollama + Swapper vision providers
Router offload (services/router/offload_client.py):
- NATS req/reply with configurable retries
- Circuit breaker per node+type (3 fails/60s → open 120s)
- Concurrency semaphore for remote requests
Model selection (services/router/model_select.py):
- exclude_nodes parameter for circuit-broken nodes
- force_local flag for fallback re-selection
- Integrated circuit breaker state awareness
Router /infer pipeline:
- Remote offload path when NCS selects remote node
- Automatic fallback: exclude failed node → force_local re-select
- Deadline propagation from router to node-worker
Tests: 17 unit tests (idempotency, deadline, circuit breaker)
Docs: ops/offload_routing.md (subjects, envelope, verification)
Made-with: Cursor
@@ -133,6 +133,30 @@ services:
       - dagi-network
     restart: unless-stopped
 
+  node-worker:
+    build:
+      context: ./services/node-worker
+      dockerfile: Dockerfile
+    container_name: node-worker-node2
+    ports:
+      - "127.0.0.1:8109:8109"
+    extra_hosts:
+      - "host.docker.internal:host-gateway"
+    environment:
+      - NODE_ID=noda2
+      - NATS_URL=nats://dagi-nats:4222
+      - OLLAMA_BASE_URL=http://host.docker.internal:11434
+      - SWAPPER_URL=http://swapper-service:8890
+      - NODE_DEFAULT_LLM=qwen3:14b
+      - NODE_DEFAULT_VISION=llava:13b
+      - NODE_WORKER_MAX_CONCURRENCY=2
+    depends_on:
+      - dagi-nats
+      - swapper-service
+    networks:
+      - dagi-network
+    restart: unless-stopped
+
   sofiia-console:
     build:
       context: ./services/sofiia-console

ops/offload_routing.md (new file)

# NATS Offload Routing — Operations Guide

## Architecture

```
Router (NODA1/NODA2)
│
├── model_select.py → selects best model from global capabilities pool
│       (local first, remote if needed, circuit breaker aware)
│
├── offload_client.py → NATS req/reply to remote node-worker
│       (retries, deadlines, circuit breaker)
│
└── global_capabilities_client.py → scatter-gather discovery
        node.*.capabilities.get
```

## NATS Subjects

| Subject | Direction | Description |
|---------|-----------|-------------|
| `node.{node_id}.capabilities.get` | req/reply | NCS capabilities query |
| `node.{node_id}.llm.request` | req/reply | LLM inference offload |
| `node.{node_id}.vision.request` | req/reply | Vision inference offload |
| `node.{node_id}.stt.request` | req/reply | STT (scaffold) |
| `node.{node_id}.tts.request` | req/reply | TTS (scaffold) |

## Job Request Envelope

```json
{
  "job_id": "uuid",
  "trace_id": "uuid",
  "actor_agent_id": "sofiia",
  "target_agent_id": "helion",
  "required_type": "llm",
  "deadline_ts": 1740000000000,
  "idempotency_key": "uuid",
  "payload": {
    "messages": [{"role": "user", "content": "..."}],
    "prompt": "...",
    "model": "qwen3:14b",
    "max_tokens": 2048,
    "temperature": 0.2
  },
  "hints": {
    "prefer_models": ["qwen3:14b"]
  }
}
```

## Job Response

```json
{
  "job_id": "uuid",
  "trace_id": "uuid",
  "node_id": "noda2",
  "status": "ok|busy|timeout|error",
  "provider": "ollama",
  "model": "qwen3:14b",
  "latency_ms": 5500,
  "result": {"text": "..."},
  "error": null,
  "cached": false
}
```
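
For a quick programmatic check of this envelope (complementing the `nats` CLI commands below), a minimal Python round-trip sketch — it assumes a NATS server reachable on localhost and a node-worker answering for `noda2`, and is not one of the files in this commit:

```python
import asyncio
import json

import nats  # nats-py, the same client the node-worker uses


async def main():
    nc = await nats.connect("nats://localhost:4222")
    req = {
        "job_id": "demo-1",
        "required_type": "llm",
        "payload": {"prompt": "hi", "max_tokens": 10},
    }
    # request/reply; this client timeout is separate from the job's deadline_ts
    msg = await nc.request("node.noda2.llm.request", json.dumps(req).encode(), timeout=30)
    resp = json.loads(msg.data)
    print(resp["status"], resp.get("result", {}).get("text"))
    await nc.close()


asyncio.run(main())
```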

## Circuit Breaker

Breaker state is kept per `node_id:required_type` pair. It opens after 3 failures
within 60 s and stays open for 120 s.

| Env | Default | Description |
|-----|---------|-------------|
| `ROUTER_OFFLOAD_CB_FAILS` | 3 | Failures to trip |
| `ROUTER_OFFLOAD_CB_WINDOW_S` | 60 | Failure window |
| `ROUTER_OFFLOAD_CB_OPEN_S` | 120 | Open duration |
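
The behaviour can be exercised directly against the helpers in `services/router/offload_client.py` (a sketch with the default thresholds, run with that directory on `sys.path`):

```python
import offload_client  # services/router/offload_client.py

# CB_FAILS failures inside the window trip the breaker for noda2:llm
for _ in range(offload_client.CB_FAILS):
    offload_client.record_failure("noda2", "llm")

assert offload_client.is_node_available("noda2", "llm") is False
assert "noda2" in offload_client.get_unavailable_nodes("llm")
assert offload_client.is_node_available("noda2", "vision")  # types are independent

# a success clears the failure window and closes the breaker
offload_client.record_success("noda2", "llm")
assert offload_client.is_node_available("noda2", "llm") is True
```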

## Idempotency

Node-worker caches `idempotency_key → response` for 10 minutes.
Duplicate requests return the cached response immediately (< 10 ms).
Inflight dedup prevents parallel execution of the same job.
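
In code, the contract looks like this (mirrors `services/node-worker/idempotency.py` and its unit tests):

```python
from idempotency import IdempotencyStore
from models import JobResponse

store = IdempotencyStore()
store.put("key-1", JobResponse(job_id="j1", status="ok"))

hit = store.get("key-1")
assert hit is not None
assert hit.cached is True  # replays are marked cached=true in the response
```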

## Deadline Enforcement

`deadline_ts` is an absolute Unix timestamp in milliseconds. If it has already
expired on arrival, node-worker returns `status=timeout` with
`error.code=DEADLINE_EXCEEDED`. During inference, `asyncio.wait_for` enforces
the remaining time.
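
The budget computation lives on the request model; a sketch of the worker's check (`provider_call` is a hypothetical stand-in for the Ollama/Swapper call):

```python
import asyncio

from models import JobRequest


async def run_with_deadline(job: JobRequest, provider_call):
    """Sketch of the worker's deadline budget; provider_call is hypothetical."""
    remaining = job.remaining_ms()  # 0 once deadline_ts has passed
    if remaining <= 0:
        return "timeout"  # the worker replies DEADLINE_EXCEEDED in this case
    # cap the provider call by the remaining budget
    return await asyncio.wait_for(provider_call(), timeout=remaining / 1000.0)
```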

## Fallback Chain

1. NCS selects a remote model (e.g. noda2)
2. Router sends the NATS offload request
3. If the remote fails (timeout/busy/error):
   - Record a circuit breaker failure
   - Re-select with `exclude_nodes={failed}` + `force_local=True`
   - Execute locally
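
Condensed from the router `/infer` diff in this commit, the same chain as a self-contained function (`sel` is the initial remote `ModelSelection`; return types are simplified for illustration):

```python
from typing import Any, Dict

import offload_client
from model_select import select_model_for_agent


async def infer_with_fallback(nc, agent_id: str, agent_cfg: Dict[str, Any],
                              router_cfg: Dict[str, Any], gcaps: Dict[str, Any],
                              sel, job: Dict[str, Any]):
    resp = await offload_client.offload_infer(
        nats_client=nc, node_id=sel.node,
        required_type="llm", job_payload=job, timeout_ms=25_000,
    )
    if resp and resp.get("status") == "ok":
        return resp  # remote answered in time
    # offload_infer already recorded the breaker failure;
    # re-select away from the failed node and prefer local models
    return await select_model_for_agent(
        agent_id, agent_cfg, router_cfg, gcaps, None,
        exclude_nodes={sel.node}, force_local=True,
    )
```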

## Environment Variables

### Router

| Env | Default | Description |
|-----|---------|-------------|
| `ROUTER_INFER_TIMEOUT_MS` | 25000 | Total inference deadline |
| `ROUTER_OFFLOAD_RETRIES` | 1 | NATS retries on transient errors |
| `ROUTER_OFFLOAD_MAX_CONCURRENCY_REMOTE` | 8 | Max parallel offloads |

### Node Worker

| Env | Default | Description |
|-----|---------|-------------|
| `NODE_ID` | noda2 | Canonical node identifier |
| `NATS_URL` | nats://dagi-nats:4222 | NATS server |
| `OLLAMA_BASE_URL` | http://host.docker.internal:11434 | Ollama API |
| `SWAPPER_URL` | http://swapper-service:8890 | Swapper API |
| `NODE_DEFAULT_LLM` | qwen3:14b | Default LLM model |
| `NODE_DEFAULT_VISION` | llava:13b | Default vision model |
| `NODE_WORKER_MAX_CONCURRENCY` | 2 | Max parallel inferences |

## Verification Commands

```bash
# Node Worker health
curl -s http://localhost:8109/healthz | jq .

# Direct NATS LLM offload
nats request node.noda2.llm.request '{"job_id":"test","required_type":"llm","payload":{"prompt":"hi","max_tokens":10}}'

# Idempotency test (same job_id)
nats request node.noda2.llm.request '{"job_id":"test","idempotency_key":"test","required_type":"llm","payload":{"prompt":"hi"}}'  # → cached=true

# Deadline expired test
nats request node.noda2.llm.request '{"job_id":"expired","required_type":"llm","deadline_ts":1000,"payload":{"prompt":"hi"}}'  # → DEADLINE_EXCEEDED

# Router logs (selection + offload)
docker logs dagi-router-node2 2>&1 | grep -E '\[select\]|\[offload\]|\[fallback\]'
```

## Adding New Nodes

1. Deploy `node-capabilities` + `node-worker` with a unique `NODE_ID`
2. Connect NATS (leafnode or cluster member)
3. Router auto-discovers via `node.*.capabilities.get` scatter-gather
4. No Router config changes needed

services/node-worker/Dockerfile (new file)

FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8109
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8109"]

services/node-worker/__init__.py (new empty file)

services/node-worker/config.py (new file)

"""Node-worker configuration from environment."""
import os

NODE_ID = os.getenv("NODE_ID", "noda2")
NATS_URL = os.getenv("NATS_URL", "nats://dagi-nats:4222")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
SWAPPER_URL = os.getenv("SWAPPER_URL", "http://swapper-service:8890")
DEFAULT_LLM = os.getenv("NODE_DEFAULT_LLM", "qwen3:14b")
DEFAULT_VISION = os.getenv("NODE_DEFAULT_VISION", "llava:13b")
MAX_CONCURRENCY = int(os.getenv("NODE_WORKER_MAX_CONCURRENCY", "2"))
MAX_PAYLOAD_BYTES = int(os.getenv("NODE_WORKER_MAX_PAYLOAD_BYTES", str(1024 * 1024)))
PORT = int(os.getenv("PORT", "8109"))

services/node-worker/idempotency.py (new file)

"""Idempotency cache + inflight dedup for job execution."""
import asyncio
import logging
import time
from typing import Dict, Optional, Tuple

from models import JobResponse

logger = logging.getLogger("idempotency")

CACHE_TTL = 600  # 10 min for successful results
TIMEOUT_TTL = 30  # 30s for timeout results


class IdempotencyStore:
    def __init__(self, max_size: int = 10_000):
        self._cache: Dict[str, Tuple[JobResponse, float]] = {}
        self._inflight: Dict[str, asyncio.Future] = {}
        self._max_size = max_size

    def get(self, key: str) -> Optional[JobResponse]:
        entry = self._cache.get(key)
        if not entry:
            return None
        resp, expires = entry
        if time.time() > expires:
            self._cache.pop(key, None)
            return None
        cached = resp.model_copy()
        cached.cached = True
        return cached

    def put(self, key: str, resp: JobResponse):
        ttl = TIMEOUT_TTL if resp.status == "timeout" else CACHE_TTL
        self._cache[key] = (resp, time.time() + ttl)
        self._evict_if_needed()

    def _evict_if_needed(self):
        if len(self._cache) <= self._max_size:
            return
        # drop expired entries first, then the oldest until under max_size
        now = time.time()
        expired = [k for k, (_, exp) in self._cache.items() if now > exp]
        for k in expired:
            self._cache.pop(k, None)
        if len(self._cache) > self._max_size:
            oldest = sorted(self._cache, key=lambda k: self._cache[k][1])
            for k in oldest[:len(self._cache) - self._max_size]:
                self._cache.pop(k, None)

    async def acquire_inflight(self, key: str) -> Optional[asyncio.Future]:
        """If another coroutine is already processing this key, return its future.
        Otherwise register this coroutine as the processor and return None."""
        if key in self._inflight:
            return self._inflight[key]
        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._inflight[key] = fut
        return None

    def complete_inflight(self, key: str, resp: JobResponse):
        fut = self._inflight.pop(key, None)
        if fut and not fut.done():
            fut.set_result(resp)

services/node-worker/main.py (new file)

"""Node Worker — NATS offload executor for cross-node inference."""
import logging
import os

from fastapi import FastAPI

import config
import worker

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("node-worker")

app = FastAPI(title="Node Worker", version="1.0.0")

_nats_client = None


@app.get("/healthz")
async def healthz():
    connected = bool(_nats_client and _nats_client.is_connected)
    return {
        "status": "ok" if connected else "degraded",
        "node_id": config.NODE_ID,
        "nats_connected": connected,
        "max_concurrency": config.MAX_CONCURRENCY,
    }


@app.on_event("startup")
async def startup():
    global _nats_client
    try:
        import nats as nats_lib
        _nats_client = await nats_lib.connect(config.NATS_URL)
        logger.info(f"✅ NATS connected: {config.NATS_URL}")
        await worker.start(_nats_client)
        logger.info(f"✅ Node Worker ready: node={config.NODE_ID} concurrency={config.MAX_CONCURRENCY}")
    except Exception as e:
        logger.error(f"❌ Startup failed: {e}")


@app.on_event("shutdown")
async def shutdown():
    if _nats_client:
        try:
            await _nats_client.close()
        except Exception:
            pass


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=config.PORT)

services/node-worker/models.py (new file)

"""Canonical job envelope for cross-node inference offload."""
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field
import time
import uuid


def _ulid() -> str:
    return str(uuid.uuid4())


class JobRequest(BaseModel):
    job_id: str = Field(default_factory=_ulid)
    trace_id: str = ""
    actor_agent_id: str = ""
    target_agent_id: str = ""
    required_type: Literal["llm", "vision", "stt", "tts"] = "llm"
    deadline_ts: int = 0
    idempotency_key: str = ""
    payload: Dict[str, Any] = Field(default_factory=dict)
    hints: Dict[str, Any] = Field(default_factory=dict)

    def remaining_ms(self) -> int:
        if self.deadline_ts <= 0:
            return 30_000
        return max(0, self.deadline_ts - int(time.time() * 1000))

    def effective_idem_key(self) -> str:
        return self.idempotency_key or self.job_id


class JobError(BaseModel):
    code: str = "UNKNOWN"
    message: str = ""


class JobResponse(BaseModel):
    job_id: str = ""
    trace_id: str = ""
    node_id: str = ""
    status: Literal["ok", "busy", "timeout", "error"] = "ok"
    provider: str = ""
    model: str = ""
    latency_ms: int = 0
    result: Optional[Dict[str, Any]] = None
    error: Optional[JobError] = None
    cached: bool = False

services/node-worker/providers/__init__.py (new empty file)

services/node-worker/providers/ollama.py (new file)

"""Ollama LLM provider for node-worker."""
import logging
from typing import Any, Dict, List, Optional

import httpx

from config import OLLAMA_BASE_URL, DEFAULT_LLM

logger = logging.getLogger("provider.ollama")


async def infer(
    messages: Optional[List[Dict[str, str]]] = None,
    prompt: str = "",
    model: str = "",
    system: str = "",
    max_tokens: int = 2048,
    temperature: float = 0.2,
    timeout_s: float = 25.0,
) -> Dict[str, Any]:
    model = model or DEFAULT_LLM

    if messages:
        return await _chat(messages, model, max_tokens, temperature, timeout_s)
    return await _generate(prompt, system, model, max_tokens, temperature, timeout_s)


async def _chat(
    messages: List[Dict[str, str]],
    model: str,
    max_tokens: int,
    temperature: float,
    timeout_s: float,
) -> Dict[str, Any]:
    async with httpx.AsyncClient(timeout=timeout_s) as c:
        resp = await c.post(
            f"{OLLAMA_BASE_URL}/api/chat",
            json={
                "model": model,
                "messages": messages,
                "stream": False,
                "options": {"num_predict": max_tokens, "temperature": temperature},
            },
        )
        resp.raise_for_status()
        data = resp.json()
        return {
            "text": data.get("message", {}).get("content", ""),
            "model": model,
            "provider": "ollama",
            "eval_count": data.get("eval_count", 0),
        }


async def _generate(
    prompt: str,
    system: str,
    model: str,
    max_tokens: int,
    temperature: float,
    timeout_s: float,
) -> Dict[str, Any]:
    async with httpx.AsyncClient(timeout=timeout_s) as c:
        resp = await c.post(
            f"{OLLAMA_BASE_URL}/api/generate",
            json={
                "model": model,
                "prompt": prompt,
                "system": system,
                "stream": False,
                "options": {"num_predict": max_tokens, "temperature": temperature},
            },
        )
        resp.raise_for_status()
        data = resp.json()
        return {
            "text": data.get("response", ""),
            "model": model,
            "provider": "ollama",
            "eval_count": data.get("eval_count", 0),
        }

services/node-worker/providers/swapper_vision.py (new file)

"""Swapper vision provider for node-worker."""
import logging
from typing import Any, Dict, List, Optional

import httpx

from config import SWAPPER_URL, DEFAULT_VISION

logger = logging.getLogger("provider.swapper_vision")


async def infer(
    images: Optional[List[str]] = None,
    prompt: str = "",
    model: str = "",
    system: str = "",
    max_tokens: int = 1024,
    temperature: float = 0.2,
    timeout_s: float = 60.0,
) -> Dict[str, Any]:
    model = model or DEFAULT_VISION

    payload: Dict[str, Any] = {
        "model": model,
        "prompt": prompt or "Describe this image.",
        "max_tokens": max_tokens,
        "temperature": temperature,
    }
    if images:
        payload["images"] = images
    if system:
        payload["system"] = system

    async with httpx.AsyncClient(timeout=timeout_s) as c:
        resp = await c.post(f"{SWAPPER_URL}/vision", json=payload)
        resp.raise_for_status()
        data = resp.json()
        return {
            "text": data.get("text", data.get("response", "")),
            "model": model,
            "provider": "swapper_vision",
        }

services/node-worker/requirements.txt (new file)

fastapi>=0.110.0
uvicorn>=0.29.0
httpx>=0.27.0
nats-py>=2.7.0
pydantic>=2.5.0

services/node-worker/worker.py (new file)

"""NATS offload worker — subscribes to node.{NODE_ID}.{type}.request subjects."""
import asyncio
import json
import logging
import time
from typing import Any, Dict

import config
from models import JobRequest, JobResponse, JobError
from idempotency import IdempotencyStore
from providers import ollama, swapper_vision

logger = logging.getLogger("node-worker")

_idem = IdempotencyStore()
_semaphore: asyncio.Semaphore = asyncio.Semaphore(config.MAX_CONCURRENCY)
_nats_client = None


async def start(nats_client):
    global _nats_client
    _nats_client = nats_client

    subjects = [
        f"node.{config.NODE_ID.lower()}.llm.request",
        f"node.{config.NODE_ID.lower()}.vision.request",
    ]
    for subj in subjects:
        await nats_client.subscribe(subj, cb=_handle_request)
        logger.info(f"✅ Subscribed: {subj}")


async def _handle_request(msg):
    t0 = time.time()
    try:
        raw = msg.data
        if len(raw) > config.MAX_PAYLOAD_BYTES:
            await _reply(msg, JobResponse(
                status="error",
                node_id=config.NODE_ID,
                error=JobError(code="PAYLOAD_TOO_LARGE", message=f"max {config.MAX_PAYLOAD_BYTES} bytes"),
            ))
            return

        data = json.loads(raw)
        job = JobRequest(**data)
        job.trace_id = job.trace_id or job.job_id

        idem_key = job.effective_idem_key()
        cached = _idem.get(idem_key)
        if cached:
            logger.info(f"[job.cached] job={job.job_id} trace={job.trace_id} idem={idem_key}")
            await _reply(msg, cached)
            return

        remaining = job.remaining_ms()
        if remaining <= 0:
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="timeout",
                error=JobError(code="DEADLINE_EXCEEDED", message="deadline already passed"),
            )
            _idem.put(idem_key, resp)
            await _reply(msg, resp)
            return

        inflight = await _idem.acquire_inflight(idem_key)
        if inflight is not None:
            try:
                resp = await asyncio.wait_for(inflight, timeout=remaining / 1000.0)
                resp_copy = resp.model_copy()
                resp_copy.cached = True
                await _reply(msg, resp_copy)
            except asyncio.TimeoutError:
                await _reply(msg, JobResponse(
                    job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                    status="timeout", error=JobError(code="INFLIGHT_TIMEOUT"),
                ))
            return

        # Semaphore.locked() is True once all slots are taken — reply busy instead of queueing
        if _semaphore.locked():
            resp = JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="busy",
                error=JobError(code="CONCURRENCY_LIMIT", message=f"max {config.MAX_CONCURRENCY}"),
            )
            _idem.complete_inflight(idem_key, resp)
            await _reply(msg, resp)
            return

        async with _semaphore:
            resp = await _execute(job, remaining)

        _idem.put(idem_key, resp)
        _idem.complete_inflight(idem_key, resp)
        resp.latency_ms = int((time.time() - t0) * 1000)
        await _reply(msg, resp)

    except Exception as e:
        logger.exception(f"Worker handler error: {e}")
        try:
            await _reply(msg, JobResponse(
                node_id=config.NODE_ID, status="error",
                error=JobError(code="INTERNAL", message=str(e)[:200]),
            ))
        except Exception:
            pass


async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
    payload = job.payload
    hints = job.hints
    timeout_s = min(remaining_ms / 1000.0, 120.0)
    model = hints.get("prefer_models", [None])[0] if hints.get("prefer_models") else payload.get("model", "")

    msg_count = len(payload.get("messages", []))
    prompt_chars = len(payload.get("prompt", ""))
    logger.info(
        f"[job.start] job={job.job_id} trace={job.trace_id} "
        f"type={job.required_type} model={model or '?'} "
        f"msgs={msg_count} chars={prompt_chars} deadline_rem={remaining_ms}ms"
    )

    try:
        if job.required_type == "llm":
            result = await asyncio.wait_for(
                ollama.infer(
                    messages=payload.get("messages"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", payload.get("max_tokens", 2048)),
                    temperature=hints.get("temperature", payload.get("temperature", 0.2)),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        elif job.required_type == "vision":
            result = await asyncio.wait_for(
                swapper_vision.infer(
                    images=payload.get("images"),
                    prompt=payload.get("prompt", ""),
                    model=model,
                    system=payload.get("system", ""),
                    max_tokens=hints.get("max_tokens", 1024),
                    temperature=hints.get("temperature", 0.2),
                    timeout_s=timeout_s,
                ),
                timeout=timeout_s,
            )
        else:
            return JobResponse(
                job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
                status="error",
                error=JobError(code="UNSUPPORTED_TYPE", message=f"{job.required_type} not implemented"),
            )

        logger.info(
            f"[job.done] job={job.job_id} status=ok "
            f"provider={result.get('provider')} model={result.get('model')}"
        )
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="ok", provider=result.get("provider", ""), model=result.get("model", ""),
            result=result,
        )

    except asyncio.TimeoutError:
        logger.warning(f"[job.timeout] job={job.job_id} after {timeout_s}s")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="timeout", error=JobError(code="PROVIDER_TIMEOUT"),
        )
    except Exception as e:
        logger.warning(f"[job.error] job={job.job_id} {e}")
        return JobResponse(
            job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
            status="error", error=JobError(code="PROVIDER_ERROR", message=str(e)[:300]),
        )


async def _reply(msg, resp: JobResponse):
    if msg.reply:
        await _nats_client.publish(msg.reply, resp.model_dump_json().encode())

@@ -51,11 +51,13 @@ try:
     import capabilities_client
     import global_capabilities_client
     from model_select import select_model_for_agent, ModelSelection, CLOUD_PROVIDERS as NCS_CLOUD_PROVIDERS
+    import offload_client
     NCS_AVAILABLE = True
 except ImportError:
     NCS_AVAILABLE = False
     capabilities_client = None  # type: ignore[assignment]
     global_capabilities_client = None  # type: ignore[assignment]
+    offload_client = None  # type: ignore[assignment]
 
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

@@ -1707,6 +1709,76 @@ async def agent_infer(agent_id: str, request: InferRequest):
         f"provider={provider} model={model}"
     )
 
+    # =========================================================================
+    # REMOTE OFFLOAD (if model selected on remote node)
+    # =========================================================================
+    nats_client_available = nc is not None and nats_available
+    if (ncs_selection and ncs_selection.via_nats and not ncs_selection.local
+            and nats_client_available and offload_client and nc):
+        infer_timeout = int(os.getenv("ROUTER_INFER_TIMEOUT_MS", "25000"))
+        import uuid as _uuid
+        job_payload = {
+            "job_id": str(_uuid.uuid4()),
+            "trace_id": str(_uuid.uuid4()),
+            "actor_agent_id": request_agent_id or agent_id,
+            "target_agent_id": agent_id,
+            "required_type": ncs_selection.model_type if ncs_selection.model_type != "code" else "llm",
+            "deadline_ts": int(time.time() * 1000) + infer_timeout,
+            "idempotency_key": str(_uuid.uuid4()),
+            "payload": {
+                "prompt": request.prompt,
+                "messages": [{"role": "system", "content": system_prompt}] if system_prompt else [],
+                "model": ncs_selection.name,
+                "max_tokens": request.max_tokens or 2048,
+                "temperature": request.temperature or 0.2,
+            },
+            "hints": {"prefer_models": [ncs_selection.name]},
+        }
+        if request.images:
+            job_payload["payload"]["images"] = request.images
+            job_payload["required_type"] = "vision"
+        job_payload["payload"]["messages"].append({"role": "user", "content": request.prompt})
+
+        offload_resp = await offload_client.offload_infer(
+            nats_client=nc,
+            node_id=ncs_selection.node,
+            required_type=job_payload["required_type"],
+            job_payload=job_payload,
+            timeout_ms=infer_timeout,
+        )
+        if offload_resp and offload_resp.get("status") == "ok":
+            result_text = offload_resp.get("result", {}).get("text", "")
+            return InferResponse(
+                response=result_text,
+                model=f"{offload_resp.get('model', ncs_selection.name)}@{ncs_selection.node}",
+                backend=f"nats-offload:{ncs_selection.node}",
+                tokens_used=offload_resp.get("result", {}).get("eval_count", 0),
+            )
+        else:
+            offload_status = offload_resp.get("status", "none") if offload_resp else "no_reply"
+            logger.warning(
+                f"[fallback] offload to {ncs_selection.node} failed ({offload_status}) "
+                f"→ re-selecting with exclude={ncs_selection.node}, force_local"
+            )
+            try:
+                gcaps = await global_capabilities_client.get_global_capabilities()
+                ncs_selection = await select_model_for_agent(
+                    agent_id, agent_config, router_config, gcaps, request.model,
+                    exclude_nodes={ncs_selection.node}, force_local=True,
+                )
+                if ncs_selection and ncs_selection.name:
+                    provider = ncs_selection.provider
+                    model = ncs_selection.name
+                    llm_profile = router_config.get("llm_profiles", {}).get(default_llm, {})
+                    if ncs_selection.base_url and provider == "ollama":
+                        llm_profile = {**llm_profile, "base_url": ncs_selection.base_url}
+                    logger.info(
+                        f"[fallback.reselect] → local node={ncs_selection.node} "
+                        f"model={model} provider={provider}"
+                    )
+            except Exception as e:
+                logger.warning(f"[fallback.reselect] error: {e}; proceeding with static")
+
     # =========================================================================
     # VISION PROCESSING (if images present)
     # =========================================================================

@@ -9,7 +9,7 @@ Scaling: works with 1 node or 150+. No static node lists.
 import logging
 import time
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Set
 
 logger = logging.getLogger("model_select")
 
@@ -110,6 +110,7 @@ def profile_requirements(
 def select_best_model(
     reqs: ProfileRequirements,
     capabilities: Dict[str, Any],
+    exclude_nodes: Optional[Set[str]] = None,
 ) -> Optional[ModelSelection]:
     """Choose the best served model from global (multi-node) capabilities.
 
@@ -117,18 +118,25 @@ def select_best_model(
     1. Prefer list matches (local first, then remote)
     2. Best candidate by size (local first, then remote)
     3. None → caller should try static fallback
+
+    exclude_nodes: set of node_ids to skip (e.g. circuit-broken nodes).
     """
     served = capabilities.get("served_models", [])
     if not served:
         return None
 
+    exclude = exclude_nodes or set()
+
     search_types = [reqs.required_type]
     if reqs.required_type == "code":
         search_types.append("llm")
     if reqs.required_type == "llm":
         search_types.append("code")
 
-    candidates = [m for m in served if m.get("type") in search_types]
+    candidates = [
+        m for m in served
+        if m.get("type") in search_types and m.get("node", "") not in exclude
+    ]
     if not candidates:
         return None
 
@@ -218,15 +226,21 @@ async def select_model_for_agent(
     router_cfg: Dict[str, Any],
     capabilities: Optional[Dict[str, Any]],
     request_model: Optional[str] = None,
+    exclude_nodes: Optional[Set[str]] = None,
+    force_local: bool = False,
 ) -> ModelSelection:
-    """Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default."""
+    """Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default.
+
+    exclude_nodes: skip these nodes (circuit-broken). Used on fallback re-selection.
+    force_local: prefer local-only models (fallback after remote failure).
+    """
     profile = resolve_effective_profile(
         agent_id, agent_cfg, router_cfg, request_model,
     )
 
     reqs = profile_requirements(profile, agent_cfg, router_cfg)
 
-    if reqs.required_type == "cloud_llm":
+    if reqs.required_type == "cloud_llm" and not force_local:
         static = static_fallback(profile, router_cfg)
         if static:
             static.fallback_reason = ""
@@ -236,14 +250,31 @@
         )
         return static
 
+    excl = set(exclude_nodes) if exclude_nodes else set()
+    try:
+        from offload_client import get_unavailable_nodes
+        cb_nodes = get_unavailable_nodes(reqs.required_type)
+        excl |= cb_nodes
+        if cb_nodes:
+            logger.info(f"[select] circuit-broken nodes for {reqs.required_type}: {cb_nodes}")
+    except ImportError:
+        pass
+
     if capabilities and capabilities.get("served_models"):
-        sel = select_best_model(reqs, capabilities)
+        sel = select_best_model(reqs, capabilities, exclude_nodes=excl)
+        if force_local and sel and not sel.local:
+            sel = select_best_model(
+                reqs, capabilities,
+                exclude_nodes=excl | {n.get("node", "") for n in capabilities.get("served_models", []) if not n.get("local")},
+            )
        if sel:
             logger.info(
                 f"[select] agent={agent_id} profile={profile} → "
-                f"{'NCS' if sel.local else 'REMOTE'} "
+                f"{'LOCAL' if sel.local else 'REMOTE'} "
                 f"node={sel.node} runtime={sel.runtime} "
                 f"model={sel.name} caps_age={sel.caps_age_s}s"
+                f"{' (force_local)' if force_local else ''}"
+                f"{' (excluded: ' + ','.join(excl) + ')' if excl else ''}"
             )
             return sel
         logger.warning(

services/router/offload_client.py (new file)

"""NATS offload client — sends inference requests to remote nodes with
circuit breaker, retries, and deadline enforcement."""
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, Literal, Optional, Set

logger = logging.getLogger("offload_client")

CB_FAILS = int(os.getenv("ROUTER_OFFLOAD_CB_FAILS", "3"))
CB_WINDOW_S = int(os.getenv("ROUTER_OFFLOAD_CB_WINDOW_S", "60"))
CB_OPEN_S = int(os.getenv("ROUTER_OFFLOAD_CB_OPEN_S", "120"))
MAX_RETRIES = int(os.getenv("ROUTER_OFFLOAD_RETRIES", "1"))
MAX_CONCURRENCY = int(os.getenv("ROUTER_OFFLOAD_MAX_CONCURRENCY_REMOTE", "8"))

_semaphore: Optional[asyncio.Semaphore] = None
_breakers: Dict[str, Dict[str, Any]] = {}


def _get_semaphore() -> asyncio.Semaphore:
    global _semaphore
    if _semaphore is None:
        _semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
    return _semaphore


def _breaker_key(node_id: str, req_type: str) -> str:
    return f"{node_id}:{req_type}"


def is_node_available(node_id: str, req_type: str) -> bool:
    key = _breaker_key(node_id, req_type)
    state = _breakers.get(key)
    if not state:
        return True
    open_until = state.get("open_until", 0)
    if open_until > time.time():
        return False
    if open_until > 0 and open_until <= time.time():
        return True
    now = time.time()
    window_start = now - CB_WINDOW_S
    recent = [t for t in state.get("failures", []) if t > window_start]
    state["failures"] = recent
    return len(recent) < CB_FAILS


def record_failure(node_id: str, req_type: str):
    key = _breaker_key(node_id, req_type)
    state = _breakers.setdefault(key, {"failures": [], "open_until": 0})
    state["failures"].append(time.time())
    window_start = time.time() - CB_WINDOW_S
    recent = [t for t in state["failures"] if t > window_start]
    state["failures"] = recent
    if len(recent) >= CB_FAILS:
        state["open_until"] = time.time() + CB_OPEN_S
        logger.warning(f"Circuit OPEN: {key} ({len(recent)} failures in {CB_WINDOW_S}s, open for {CB_OPEN_S}s)")


def record_success(node_id: str, req_type: str):
    key = _breaker_key(node_id, req_type)
    state = _breakers.get(key)
    if state:
        state["failures"] = []
        state["open_until"] = 0


def get_unavailable_nodes(req_type: str) -> Set[str]:
    result = set()
    for key, state in _breakers.items():
        if not key.endswith(f":{req_type}"):
            continue
        nid = key.rsplit(":", 1)[0]
        if not is_node_available(nid, req_type):
            result.add(nid)
    return result


async def offload_infer(
    nats_client,
    node_id: str,
    required_type: Literal["llm", "vision", "stt", "tts"],
    job_payload: Dict[str, Any],
    timeout_ms: int = 25000,
) -> Optional[Dict[str, Any]]:
    """Send inference job to remote node via NATS request/reply.

    Returns parsed JobResponse dict or None on total failure.
    Retries on transient errors (timeout, busy). Does NOT retry on provider errors.
    """
    subject = f"node.{node_id.lower()}.{required_type}.request"
    payload_bytes = json.dumps(job_payload).encode()
    sem = _get_semaphore()

    for attempt in range(1 + MAX_RETRIES):
        timeout_s = timeout_ms / 1000.0
        if timeout_s <= 0:
            logger.warning(f"[offload] deadline exhausted before attempt {attempt}")
            return None

        t0 = time.time()
        try:
            async with sem:
                logger.info(
                    f"[offload.start] node={node_id} subj={subject} "
                    f"timeout={timeout_ms}ms attempt={attempt}"
                )
                msg = await nats_client.request(subject, payload_bytes, timeout=timeout_s)
                resp = json.loads(msg.data)

                latency = int((time.time() - t0) * 1000)
                status = resp.get("status", "error")

                if status == "ok":
                    record_success(node_id, required_type)
                    logger.info(
                        f"[offload.done] node={node_id} status=ok latency={latency}ms "
                        f"provider={resp.get('provider')} model={resp.get('model')} "
                        f"cached={resp.get('cached', False)}"
                    )
                    return resp

                if status in ("timeout", "busy") and attempt < MAX_RETRIES:
                    elapsed = int((time.time() - t0) * 1000)
                    timeout_ms -= elapsed
                    logger.warning(f"[offload.retry] node={node_id} status={status} → retry {attempt+1}")
                    continue

                record_failure(node_id, required_type)
                logger.warning(
                    f"[offload.fail] node={node_id} status={status} "
                    f"error={resp.get('error', {}).get('code', '?')}"
                )
                return resp

        except asyncio.TimeoutError:
            record_failure(node_id, required_type)
            elapsed = int((time.time() - t0) * 1000)
            timeout_ms -= elapsed
            if attempt < MAX_RETRIES:
                logger.warning(f"[offload.timeout] node={node_id} {elapsed}ms → retry {attempt+1}")
                continue
            logger.warning(f"[offload.timeout] node={node_id} all retries exhausted")
            return None

        except Exception as e:
            record_failure(node_id, required_type)
            logger.warning(f"[offload.error] node={node_id} {e}")
            return None

    return None

tests/test_node_worker_deadline.py (new file)

"""Tests for node-worker deadline handling."""
import sys
import os
import time

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services", "node-worker"))

from models import JobRequest


def test_remaining_ms_with_future_deadline():
    job = JobRequest(deadline_ts=int(time.time() * 1000) + 10000)
    rem = job.remaining_ms()
    assert 9000 <= rem <= 10500


def test_remaining_ms_with_past_deadline():
    job = JobRequest(deadline_ts=int(time.time() * 1000) - 5000)
    assert job.remaining_ms() == 0


def test_remaining_ms_no_deadline():
    job = JobRequest(deadline_ts=0)
    assert job.remaining_ms() == 30_000


def test_effective_idem_key_uses_idempotency_key():
    job = JobRequest(job_id="j1", idempotency_key="custom-key")
    assert job.effective_idem_key() == "custom-key"


def test_effective_idem_key_falls_back_to_job_id():
    job = JobRequest(job_id="j2", idempotency_key="")
    assert job.effective_idem_key() == "j2"

tests/test_node_worker_idempotency.py (new file)

"""Tests for node-worker idempotency store."""
import asyncio
import sys
import os
import time

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services", "node-worker"))

from models import JobResponse, JobError
from idempotency import IdempotencyStore


def test_put_and_get():
    store = IdempotencyStore()
    resp = JobResponse(job_id="j1", status="ok", node_id="n1", provider="ollama", model="qwen3:14b")
    store.put("key1", resp)
    cached = store.get("key1")
    assert cached is not None
    assert cached.status == "ok"
    assert cached.cached is True
    assert cached.job_id == "j1"


def test_miss():
    store = IdempotencyStore()
    assert store.get("nonexistent") is None


def test_ttl_expiry():
    store = IdempotencyStore()
    resp = JobResponse(job_id="j2", status="ok", node_id="n1")
    store.put("key2", resp)
    store._cache["key2"] = (resp, time.time() - 1)
    assert store.get("key2") is None


def test_timeout_shorter_ttl():
    store = IdempotencyStore()
    resp = JobResponse(job_id="j3", status="timeout", error=JobError(code="TO"))
    store.put("key3", resp)
    _, expires = store._cache["key3"]
    assert expires - time.time() < 35  # timeout TTL ≈ 30s


def test_inflight_dedup():
    store = IdempotencyStore()

    async def run():
        fut1 = await store.acquire_inflight("key4")
        assert fut1 is None  # first caller gets None (becomes processor)
        fut2 = await store.acquire_inflight("key4")
        assert fut2 is not None  # second caller gets a future to wait on

        resp = JobResponse(job_id="j4", status="ok")
        store.complete_inflight("key4", resp)
        result = await asyncio.wait_for(fut2, timeout=1.0)
        assert result.status == "ok"

    asyncio.run(run())


def test_evict_on_max_size():
    store = IdempotencyStore(max_size=3)
    for i in range(5):
        store.put(f"k{i}", JobResponse(job_id=f"j{i}", status="ok"))
    assert len(store._cache) <= 3

tests/test_offload_circuit_breaker.py (new file)

"""Tests for offload_client circuit breaker."""
import sys
import os
import time

sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services", "router"))

import offload_client


def setup_function():
    offload_client._breakers.clear()


def test_node_available_by_default():
    assert offload_client.is_node_available("noda99", "llm") is True


def test_node_unavailable_after_threshold():
    for _ in range(offload_client.CB_FAILS):
        offload_client.record_failure("noda99", "llm")
    assert offload_client.is_node_available("noda99", "llm") is False


def test_node_available_after_open_expires():
    for _ in range(offload_client.CB_FAILS):
        offload_client.record_failure("noda99", "llm")
    offload_client._breakers["noda99:llm"]["open_until"] = time.time() - 1
    assert offload_client.is_node_available("noda99", "llm") is True


def test_success_resets_breaker():
    for _ in range(offload_client.CB_FAILS):
        offload_client.record_failure("noda99", "llm")
    offload_client._breakers["noda99:llm"]["open_until"] = time.time() - 1
    offload_client.record_success("noda99", "llm")
    assert offload_client.is_node_available("noda99", "llm") is True


def test_get_unavailable_nodes():
    for _ in range(offload_client.CB_FAILS):
        offload_client.record_failure("bad_node", "llm")
    unavail = offload_client.get_unavailable_nodes("llm")
    assert "bad_node" in unavail
    assert offload_client.get_unavailable_nodes("vision") == set()


def test_different_types_independent():
    for _ in range(offload_client.CB_FAILS):
        offload_client.record_failure("noda99", "llm")
    assert offload_client.is_node_available("noda99", "llm") is False
    assert offload_client.is_node_available("noda99", "vision") is True