"""NATS offload client — sends inference requests to remote nodes with circuit breaker, retries, and deadline enforcement."""

import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, Literal, Optional, Set

logger = logging.getLogger("offload_client")

# Tunables (environment overrides), clamped to sane minimums:
# CB_FAILS < 1 would leave a node permanently unavailable after a
# record_success() reset, MAX_CONCURRENCY < 1 would block every call forever
# (0) or raise in Semaphore() (negative), and MAX_RETRIES < 0 would skip
# every attempt.
CB_FAILS = max(1, int(os.getenv("ROUTER_OFFLOAD_CB_FAILS", "3")))
CB_WINDOW_S = int(os.getenv("ROUTER_OFFLOAD_CB_WINDOW_S", "60"))
CB_OPEN_S = int(os.getenv("ROUTER_OFFLOAD_CB_OPEN_S", "120"))
MAX_RETRIES = max(0, int(os.getenv("ROUTER_OFFLOAD_RETRIES", "1")))
MAX_CONCURRENCY = max(1, int(os.getenv("ROUTER_OFFLOAD_MAX_CONCURRENCY_REMOTE", "8")))

# Shared limiter for in-flight remote requests; created on first use by
# _get_semaphore() rather than at import time.
_semaphore: Optional[asyncio.Semaphore] = None

# Circuit-breaker state keyed by "node_id:req_type":
#   {"failures": [time.time() stamps], "open_until": time.time() seconds, or 0}
_breakers: Dict[str, Dict[str, Any]] = {}


def _get_semaphore() -> asyncio.Semaphore:
    """Return the module-wide concurrency semaphore, creating it on first call."""
    global _semaphore
    if _semaphore is None:
        _semaphore = asyncio.Semaphore(MAX_CONCURRENCY)
    return _semaphore


def _breaker_key(node_id: str, req_type: str) -> str:
    """Build the ``_breakers`` key for a (node, request type) pair."""
    return f"{node_id}:{req_type}"


def is_node_available(node_id: str, req_type: str) -> bool:
    """Report whether the breaker for (node_id, req_type) lets requests through.

    - No recorded state: available.
    - Circuit open (``open_until`` in the future): unavailable.
    - Open period elapsed: available again; the state stays tripped until
      ``record_success`` clears it or ``record_failure`` re-opens it.
    - Otherwise: available while fewer than CB_FAILS failures fall within
      the last CB_WINDOW_S seconds. Stale failures are pruned as a side effect.
    """
    state = _breakers.get(_breaker_key(node_id, req_type))
    if not state:
        return True
    now = time.time()
    open_until = state.get("open_until", 0)
    if open_until > now:
        return False
    if open_until > 0:
        return True
    recent = [t for t in state.get("failures", []) if t > now - CB_WINDOW_S]
    state["failures"] = recent
    return len(recent) < CB_FAILS


def record_failure(node_id: str, req_type: str) -> None:
    """Record one failure and open the circuit for CB_OPEN_S seconds once
    CB_FAILS failures fall within the last CB_WINDOW_S seconds."""
    key = _breaker_key(node_id, req_type)
    state = _breakers.setdefault(key, {"failures": [], "open_until": 0})
    now = time.time()
    state["failures"].append(now)
    recent = [t for t in state["failures"] if t > now - CB_WINDOW_S]
    state["failures"] = recent
    if len(recent) >= CB_FAILS:
        state["open_until"] = now + CB_OPEN_S
        logger.warning(f"Circuit OPEN: {key} ({len(recent)} failures in {CB_WINDOW_S}s, open for {CB_OPEN_S}s)")


def record_success(node_id: str, req_type: str) -> None:
    """Close the breaker for (node_id, req_type): drop failures and any open period."""
    state = _breakers.get(_breaker_key(node_id, req_type))
    if state:
        state["failures"] = []
        state["open_until"] = 0


def get_unavailable_nodes(req_type: str) -> Set[str]:
    """Return the node ids whose breaker for ``req_type`` currently rejects requests."""
    suffix = f":{req_type}"
    result: Set[str] = set()
    for key in list(_breakers):
        if not key.endswith(suffix):
            continue
        # Strip the exact ":req_type" suffix so the node id is recovered intact.
        nid = key[: -len(suffix)]
        if not is_node_available(nid, req_type):
            result.add(nid)
    return result


def _error_code(resp: Dict[str, Any]) -> Any:
    """Extract a loggable error code from a non-ok JobResponse.

    Tolerates a missing, null, string, or otherwise non-object ``error`` field
    so that logging a failure can never raise.
    """
    err = resp.get("error")
    if isinstance(err, dict):
        return err.get("code", "?")
    if isinstance(err, str) and err:
        return err
    return "?"


async def offload_infer(
    nats_client,
    node_id: str,
    required_type: Literal["llm", "vision", "stt", "tts", "ocr", "image"],
    job_payload: Dict[str, Any],
    timeout_ms: int = 25000,
) -> Optional[Dict[str, Any]]:
    """Send inference job to remote node via NATS request/reply.

    Returns parsed JobResponse dict or None on total failure.
    Retries on transient errors (timeout, busy). Does NOT retry on provider errors.

    Args:
        nats_client: object exposing ``await request(subject, data, timeout=...)``
            that returns a message with a ``.data`` bytes payload (presumably a
            nats-py client — confirm).
        node_id: target node; lower-cased when building the NATS subject.
        required_type: request type; selects the subject and the breaker key.
        job_payload: JSON-serializable job body.
        timeout_ms: total deadline for the whole call, covering time queued on
            the concurrency semaphore, every attempt, and every retry.

    Returns:
        The decoded response dict when the node replied with ``status == "ok"``
        or with a non-retried status (breaker failure recorded), or None when
        the deadline ran out, the final attempt timed out, the transport
        raised, or the reply was not a JSON object.
    """
    # NOTE(review): the subject lower-cases node_id but breaker keys use it
    # verbatim; callers presumably pass one canonical spelling — confirm.
    subject = f"node.{node_id.lower()}.{required_type}.request"
    payload_bytes = json.dumps(job_payload).encode()
    sem = _get_semaphore()
    # One monotonic deadline: queueing, attempts and retries all draw from the
    # same budget, and wall-clock jumps cannot stretch or shrink it.
    deadline = time.monotonic() + timeout_ms / 1000.0

    for attempt in range(1 + MAX_RETRIES):
        if deadline - time.monotonic() <= 0:
            logger.warning(f"[offload] deadline exhausted before attempt {attempt}")
            return None
        t0 = time.monotonic()
        try:
            async with sem:
                # Waiting for a free slot counts against the deadline, so the
                # request only gets whatever budget is left after queueing.
                remaining_s = deadline - time.monotonic()
                if remaining_s <= 0:
                    logger.warning(
                        f"[offload] deadline exhausted while queued before attempt {attempt}"
                    )
                    return None
                logger.info(
                    f"[offload.start] node={node_id} subj={subject} "
                    f"timeout={int(remaining_s * 1000)}ms attempt={attempt}"
                )
                msg = await nats_client.request(subject, payload_bytes, timeout=remaining_s)
            resp = json.loads(msg.data)
        except asyncio.TimeoutError:
            record_failure(node_id, required_type)
            elapsed = int((time.monotonic() - t0) * 1000)
            if attempt < MAX_RETRIES:
                # The top-of-loop deadline check decides whether time is left.
                logger.warning(f"[offload.timeout] node={node_id} {elapsed}ms → retry {attempt+1}")
                continue
            logger.warning(f"[offload.timeout] node={node_id} all retries exhausted")
            return None
        except Exception as e:
            # Transport errors, undecodable payloads, etc.: charged to the
            # breaker and not retried.
            record_failure(node_id, required_type)
            logger.warning(f"[offload.error] node={node_id} {e}")
            return None

        if not isinstance(resp, dict):
            record_failure(node_id, required_type)
            logger.warning(
                f"[offload.error] node={node_id} unexpected response type {type(resp).__name__}"
            )
            return None

        status = resp.get("status", "error")
        if status == "ok":
            record_success(node_id, required_type)
            latency = int((time.monotonic() - t0) * 1000)
            logger.info(
                f"[offload.done] node={node_id} status=ok latency={latency}ms "
                f"provider={resp.get('provider')} model={resp.get('model')} "
                f"cached={resp.get('cached', False)}"
            )
            return resp
        if status in ("timeout", "busy") and attempt < MAX_RETRIES:
            # Transient node-side condition: retry without charging the breaker.
            logger.warning(f"[offload.retry] node={node_id} status={status} → retry {attempt+1}")
            continue
        record_failure(node_id, required_type)
        logger.warning(
            f"[offload.fail] node={node_id} status={status} "
            f"error={_error_code(resp)}"
        )
        return resp

    # Only reachable if the loop runs zero times (MAX_RETRIES is clamped >= 0).
    return None