Files
microdao-daarion/services/router/offload_client.py
Apple 9a36020316 P3.5-P3.7: 2-layer inventory, capability routing, STT/TTS adapters, Dev Contract
NCS:
- _collect_worker_caps() fetches capability flags from node-worker /caps
- _derive_capabilities() merges served model types + worker provider flags
- installed_artifacts replaces inventory_only (disk scan with DISK_SCAN_PATHS env)
- New endpoints: /capabilities/caps, /capabilities/installed

Node Worker:
- STT_PROVIDER, TTS_PROVIDER, OCR_PROVIDER, IMAGE_PROVIDER env flags
- /caps endpoint returns capabilities + providers for NCS aggregation
- STT adapter (providers/stt_mlx_whisper.py) — remote + local mode
- TTS adapter (providers/tts_mlx_kokoro.py) — remote + local mode
- OCR handler via vision_prompted (ollama_vision with OCR prompt)
- NATS subjects: node.{id}.stt/tts/ocr/image.request

Router:
- POST /v1/capability/{stt,tts,ocr,image} — capability-based offload routing
- GET /v1/capabilities — global view with capabilities_by_node
- require_fresh_caps(ttl) preflight guard
- find_nodes_with_capability(cap) + load-based node selection

Ops:
- ops/fabric_snapshot.py — full runtime snapshot collector
- ops/fabric_preflight.sh — quick check + snapshot save + diff
- docs/fabric_contract.md — Dev Contract v0.1 (preflight-first)
- tests/test_fabric_contract.py — CI enforcement (6 tests)

Made-with: Cursor
2026-02-27 05:24:09 -08:00

154 lines
5.3 KiB
Python

"""NATS offload client — sends inference requests to remote nodes with
circuit breaker, retries, and deadline enforcement."""
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, Literal, Optional, Set
logger = logging.getLogger("offload_client")
# Circuit-breaker tuning, overridable via environment variables:
# failures within the sliding window that trip the breaker open
CB_FAILS = int(os.getenv("ROUTER_OFFLOAD_CB_FAILS", "3"))
# sliding window (seconds) over which failures are counted
CB_WINDOW_S = int(os.getenv("ROUTER_OFFLOAD_CB_WINDOW_S", "60"))
# how long (seconds) a tripped breaker stays open before allowing a probe
CB_OPEN_S = int(os.getenv("ROUTER_OFFLOAD_CB_OPEN_S", "120"))
# extra attempts after the first, for transient errors only (timeout/busy)
MAX_RETRIES = int(os.getenv("ROUTER_OFFLOAD_RETRIES", "1"))
# upper bound on simultaneous in-flight remote offload requests
MAX_CONCURRENCY = int(os.getenv("ROUTER_OFFLOAD_MAX_CONCURRENCY_REMOTE", "8"))
# Lazily created in _get_semaphore() so it binds to the running event loop.
_semaphore: Optional[asyncio.Semaphore] = None
# Breaker state keyed by "{node_id}:{req_type}":
# {"failures": [epoch timestamps], "open_until": epoch seconds (0 = closed)}
_breakers: Dict[str, Dict[str, Any]] = {}
def _get_semaphore() -> asyncio.Semaphore:
    """Return the shared semaphore bounding concurrent remote offloads.

    Created lazily on first use with MAX_CONCURRENCY permits.
    """
    global _semaphore
    if _semaphore is not None:
        return _semaphore
    _semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
    return _semaphore
def _breaker_key(node_id: str, req_type: str) -> str:
return f"{node_id}:{req_type}"
def is_node_available(node_id: str, req_type: str) -> bool:
    """Return True when the circuit breaker allows sending to this node/type.

    Available when: no breaker state exists yet; a previously opened breaker
    has expired (half-open — allow a probe request through); or the number of
    failures inside the sliding CB_WINDOW_S window is below CB_FAILS.

    Side effect: prunes failure timestamps that fell out of the window.
    """
    key = _breaker_key(node_id, req_type)
    state = _breakers.get(key)
    if not state:
        return True
    # Read the clock once so the open/expired checks use a consistent instant
    # (the original read time.time() twice, which could straddle open_until).
    now = time.time()
    open_until = state.get("open_until", 0)
    if open_until > now:
        return False  # breaker is open — block traffic
    if open_until > 0:
        return True  # open window expired -> half-open, allow a probe
    window_start = now - CB_WINDOW_S
    recent = [t for t in state.get("failures", []) if t > window_start]
    state["failures"] = recent
    return len(recent) < CB_FAILS
def record_failure(node_id: str, req_type: str):
    """Record a failed offload for (node, req_type) and maybe trip the breaker.

    Appends a failure timestamp, prunes entries older than CB_WINDOW_S, and
    when the remaining count reaches CB_FAILS sets ``open_until`` to
    now + CB_OPEN_S (breaker open).
    """
    key = _breaker_key(node_id, req_type)
    state = _breakers.setdefault(key, {"failures": [], "open_until": 0})
    # Use one timestamp for append, prune, and deadline so the bookkeeping is
    # mutually consistent (the original called time.time() three times).
    now = time.time()
    state["failures"].append(now)
    window_start = now - CB_WINDOW_S
    recent = [t for t in state["failures"] if t > window_start]
    state["failures"] = recent
    if len(recent) >= CB_FAILS:
        state["open_until"] = now + CB_OPEN_S
        logger.warning(f"Circuit OPEN: {key} ({len(recent)} failures in {CB_WINDOW_S}s, open for {CB_OPEN_S}s)")
def record_success(node_id: str, req_type: str):
    """Reset breaker state for (node, req_type) after a successful call."""
    state = _breakers.get(_breaker_key(node_id, req_type))
    if not state:
        return
    state["failures"] = []
    state["open_until"] = 0
def get_unavailable_nodes(req_type: str) -> Set[str]:
    """Return the ids of nodes whose breaker currently blocks `req_type`."""
    suffix = f":{req_type}"
    blocked: Set[str] = set()
    for key in _breakers:
        if not key.endswith(suffix):
            continue
        node_id = key.rsplit(":", 1)[0]
        if not is_node_available(node_id, req_type):
            blocked.add(node_id)
    return blocked
async def offload_infer(
    nats_client,
    node_id: str,
    required_type: Literal["llm", "vision", "stt", "tts", "ocr", "image"],
    job_payload: Dict[str, Any],
    timeout_ms: int = 25000,
) -> Optional[Dict[str, Any]]:
    """Send inference job to remote node via NATS request/reply.

    Args:
        nats_client: connected NATS client exposing an async ``request()``.
        node_id: target node id; lowercased when building the subject.
        required_type: capability channel, selects the request subject.
        job_payload: JSON-serializable job description.
        timeout_ms: total deadline budget shared across all attempts.

    Returns parsed JobResponse dict or None on total failure.
    Retries on transient errors (timeout, busy). Does NOT retry on provider errors.
    """
    subject = f"node.{node_id.lower()}.{required_type}.request"
    payload_bytes = json.dumps(job_payload).encode()
    sem = _get_semaphore()
    for attempt in range(1 + MAX_RETRIES):
        # Remaining budget; shrinks by each failed attempt's elapsed time.
        timeout_s = timeout_ms / 1000.0
        if timeout_s <= 0:
            logger.warning(f"[offload] deadline exhausted before attempt {attempt}")
            return None
        t0 = time.time()
        try:
            async with sem:  # cap concurrent in-flight remote requests
                logger.info(
                    f"[offload.start] node={node_id} subj={subject} "
                    f"timeout={timeout_ms}ms attempt={attempt}"
                )
                msg = await nats_client.request(subject, payload_bytes, timeout=timeout_s)
                resp = json.loads(msg.data)
                latency = int((time.time() - t0) * 1000)
                status = resp.get("status", "error")
                if status == "ok":
                    record_success(node_id, required_type)
                    logger.info(
                        f"[offload.done] node={node_id} status=ok latency={latency}ms "
                        f"provider={resp.get('provider')} model={resp.get('model')} "
                        f"cached={resp.get('cached', False)}"
                    )
                    return resp
                # Transient remote conditions: shrink the budget and retry.
                if status in ("timeout", "busy") and attempt < MAX_RETRIES:
                    elapsed = int((time.time() - t0) * 1000)
                    timeout_ms -= elapsed
                    logger.warning(f"[offload.retry] node={node_id} status={status} → retry {attempt+1}")
                    continue
                # Provider/terminal error: count it against the breaker, no retry.
                record_failure(node_id, required_type)
                # `or {}` guards against an explicit "error": null in the
                # response — dict.get's default only covers a MISSING key,
                # so the original could raise AttributeError on None.
                logger.warning(
                    f"[offload.fail] node={node_id} status={status} "
                    f"error={(resp.get('error') or {}).get('code', '?')}"
                )
                return resp
        except asyncio.TimeoutError:
            record_failure(node_id, required_type)
            elapsed = int((time.time() - t0) * 1000)
            timeout_ms -= elapsed
            if attempt < MAX_RETRIES:
                logger.warning(f"[offload.timeout] node={node_id} {elapsed}ms → retry {attempt+1}")
                continue
            logger.warning(f"[offload.timeout] node={node_id} all retries exhausted")
            return None
        except Exception as e:
            # Any other failure (connection loss, bad JSON, ...) is terminal
            # for this job; record it and give up immediately.
            record_failure(node_id, required_type)
            logger.warning(f"[offload.error] node={node_id} {e}")
            return None
    return None