"""Global Capabilities Client — aggregates model capabilities across all nodes. Design for 150+ nodes: - Local NCS: HTTP (fast, always available) - Remote nodes: NATS request/reply with wildcard discovery - node.*.capabilities.get → each NCS replies with its capabilities - No static node list needed — new nodes auto-register by subscribing - scatter-gather pattern: send one request, collect N replies - TTL cache per node, stale nodes expire automatically """ import asyncio import json import logging import os import time from typing import Any, Dict, List, Optional import httpx logger = logging.getLogger("global_caps") LOCAL_NCS_URL = os.getenv("NODE_CAPABILITIES_URL", "") LOCAL_NODE_ID = os.getenv("NODE_ID", "unknown") NATS_URL = os.getenv("NATS_URL", "nats://nats:4222") CACHE_TTL = int(os.getenv("GLOBAL_CAPS_TTL", "30")) NATS_DISCOVERY_TIMEOUT_MS = int(os.getenv("NATS_DISCOVERY_TIMEOUT_MS", "500")) NATS_ENABLED = os.getenv("ENABLE_GLOBAL_CAPS_NATS", "true").lower() in ("true", "1") CAPS_DISCOVERY_SUBJECT = "fabric.capabilities.discover" CAPS_INBOX_PREFIX = "_CAPS_REPLY" _node_cache: Dict[str, Dict[str, Any]] = {} _node_timestamps: Dict[str, float] = {} _nats_client = None _initialized = False async def initialize(): """Connect to NATS for discovery. Called once at router startup.""" global _nats_client, _initialized if not NATS_ENABLED: logger.info("Global caps NATS discovery disabled") _initialized = True return try: import nats as nats_lib _nats_client = await nats_lib.connect(NATS_URL) _initialized = True logger.info(f"✅ Global caps NATS connected: {NATS_URL}") except Exception as e: logger.warning(f"⚠️ Global caps NATS init failed (non-fatal): {e}") _nats_client = None _initialized = True async def shutdown(): global _nats_client if _nats_client: try: await _nats_client.close() except Exception: pass _nats_client = None async def _fetch_local() -> Optional[Dict[str, Any]]: """Fetch capabilities from local NCS via HTTP.""" if not LOCAL_NCS_URL: return None try: async with httpx.AsyncClient(timeout=3) as c: resp = await c.get(LOCAL_NCS_URL) if resp.status_code == 200: data = resp.json() node_id = data.get("node_id", LOCAL_NODE_ID) _node_cache[node_id] = data _node_timestamps[node_id] = time.time() return data except Exception as e: logger.warning(f"Local NCS fetch failed: {e}") return _node_cache.get(LOCAL_NODE_ID) async def _discover_remote_nodes() -> List[Dict[str, Any]]: """Scatter-gather discovery: send to node.*.capabilities.get, collect replies. Each NCS on every node subscribes to node.{node_id}.capabilities.get. NATS wildcard routing delivers our request to ALL of them. We collect replies within NATS_DISCOVERY_TIMEOUT_MS. This scales to 150+ nodes with zero static configuration: - New node deploys NCS → subscribes to its subject → automatically discovered. - Dead node stops responding → its cache entry expires after TTL. """ if not _nats_client: return [] collected: List[Dict[str, Any]] = [] inbox = _nats_client.new_inbox() sub = await _nats_client.subscribe(inbox) try: await _nats_client.publish( CAPS_DISCOVERY_SUBJECT, b"", reply=inbox, ) await _nats_client.flush() deadline = time.time() + (NATS_DISCOVERY_TIMEOUT_MS / 1000.0) while time.time() < deadline: remaining = deadline - time.time() if remaining <= 0: break try: msg = await asyncio.wait_for( sub.next_msg(), timeout=remaining, ) data = json.loads(msg.data) node_id = data.get("node_id", "?") if node_id != LOCAL_NODE_ID: _node_cache[node_id] = data _node_timestamps[node_id] = time.time() collected.append(data) except asyncio.TimeoutError: break except Exception as e: logger.debug(f"Discovery parse error: {e}") break finally: await sub.unsubscribe() if collected: logger.info( f"Discovered {len(collected)} remote node(s): " f"{[c.get('node_id', '?') for c in collected]}" ) return collected def _evict_stale(): """Remove nodes that haven't refreshed within 3x TTL.""" cutoff = time.time() - (CACHE_TTL * 3) stale = [nid for nid, ts in _node_timestamps.items() if ts < cutoff] for nid in stale: _node_cache.pop(nid, None) _node_timestamps.pop(nid, None) logger.info(f"Evicted stale node: {nid}") def _needs_refresh() -> bool: """Check if any node cache is older than TTL.""" if not _node_timestamps: return True oldest = min(_node_timestamps.values()) return (time.time() - oldest) > CACHE_TTL async def get_global_capabilities(force: bool = False) -> Dict[str, Any]: """Return merged capabilities from all known nodes. Returns: { "local_node": "noda1", "nodes": {"noda1": {...}, "noda2": {...}, ...}, "served_models": [...], # all models with "node" field "node_count": 2, "updated_at": "...", } """ if not force and not _needs_refresh(): return _build_global_view() _evict_stale() tasks = [_fetch_local()] if _nats_client: tasks.append(_discover_remote_nodes()) await asyncio.gather(*tasks, return_exceptions=True) return _build_global_view() def _build_global_view() -> Dict[str, Any]: """Build a unified view from all cached node capabilities.""" all_served: List[Dict[str, Any]] = [] global_caps: Dict[str, Dict[str, Any]] = {} for node_id, caps in _node_cache.items(): is_local = (node_id.lower() == LOCAL_NODE_ID.lower()) age = time.time() - _node_timestamps.get(node_id, 0) for m in caps.get("served_models", []): all_served.append({ **m, "node": node_id, "local": is_local, "node_age_s": round(age, 1), }) node_caps = caps.get("capabilities", {}) if node_caps: global_caps[node_id] = { k: v for k, v in node_caps.items() if k != "providers" } all_served.sort(key=lambda m: (0 if m.get("local") else 1, m.get("name", ""))) return { "local_node": LOCAL_NODE_ID, "nodes": {nid: { "node_id": nid, "served_count": len(c.get("served_models", [])), "installed_count": c.get("installed_count", 0), "capabilities": c.get("capabilities", {}), "node_load": c.get("node_load", {}), "age_s": round(time.time() - _node_timestamps.get(nid, 0), 1), } for nid, c in _node_cache.items()}, "served_models": all_served, "served_count": len(all_served), "capabilities_by_node": global_caps, "node_count": len(_node_cache), "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), } def get_cached_global() -> Dict[str, Any]: """Return cached global view without fetching.""" return _build_global_view() async def require_fresh_caps(ttl: int = 30) -> Optional[Dict[str, Any]]: """Preflight: return global caps only if fresh enough. Returns None if NCS data is stale beyond ttl — caller should use safe fallback instead of making routing decisions on outdated info. """ if not _node_timestamps: gcaps = await get_global_capabilities(force=True) if not _node_timestamps: return None return gcaps oldest = min(_node_timestamps.values()) if (time.time() - oldest) > ttl: gcaps = await get_global_capabilities(force=True) oldest = min(_node_timestamps.values()) if _node_timestamps else 0 if (time.time() - oldest) > ttl: logger.warning("[preflight] caps stale after refresh, age=%ds", int(time.time() - oldest)) return None return gcaps return _build_global_view() def find_nodes_with_capability(cap: str) -> List[str]: """Return node IDs that have a given capability enabled.""" result = [] for nid, caps in _node_cache.items(): node_caps = caps.get("capabilities", {}) if node_caps.get(cap, False): result.append(nid) return result def get_node_load(node_id: str) -> Dict[str, Any]: """Get cached node_load for a specific node.""" caps = _node_cache.get(node_id, {}) return caps.get("node_load", {}) async def send_offload_request( node_id: str, request_type: str, payload: Dict[str, Any], timeout_s: float = 30.0, ) -> Optional[Dict[str, Any]]: """Send an inference request to a remote node via NATS. Subject pattern: node.{node_id}.{type}.request Reply: inline NATS request/reply """ if not _nats_client: logger.warning("Cannot offload: NATS not connected") return None subject = f"node.{node_id.lower()}.{request_type}.request" try: msg = await _nats_client.request( subject, json.dumps(payload).encode(), timeout=timeout_s, ) return json.loads(msg.data) except asyncio.TimeoutError: logger.warning(f"Offload timeout: {subject} ({timeout_s}s)") return None except Exception as e: logger.warning(f"Offload error: {subject}: {e}") return None