Architecture for 150+ nodes: - global_capabilities_client.py: NATS scatter-gather discovery using wildcard subject node.*.capabilities.get — zero static node lists. New nodes auto-register by deploying NCS and subscribing to NATS. Dead nodes expire from cache after 3x TTL automatically. Multi-node model_select.py: - ModelSelection now includes node, local, via_nats fields - select_best_model prefers local candidates, then remote - Prefer list resolution: local first, remote second - All logged per request: node, runtime, model, local/remote NODA1 compose: - Added node-capabilities service (NCS) to docker-compose.node1.yml - NATS subscription: node.noda1.capabilities.get - Router env: NODE_CAPABILITIES_URL + ENABLE_GLOBAL_CAPS_NATS=true NODA2 compose: - Router env: ENABLE_GLOBAL_CAPS_NATS=true Router main.py: - Startup: initializes global_capabilities_client (NATS connect + first discovery). Falls back to local-only capabilities_client if unavailable. - /infer: uses get_global_capabilities() for cross-node model pool - Offload support: send_offload_request(node_id, type, payload) via NATS Verified on NODA2: - Global caps: 1 node, 14 models (NODA1 not yet deployed) - Sofiia: cloud_grok → grok-4-1-fast-reasoning (OK) - Helion: NCS → qwen3:14b local (OK) - When NODA1 deploys NCS, its models appear automatically via NATS discovery Made-with: Cursor
246 lines
7.9 KiB
Python
246 lines
7.9 KiB
Python
"""Global Capabilities Client — aggregates model capabilities across all nodes.
|
|
|
|
Design for 150+ nodes:
|
|
- Local NCS: HTTP (fast, always available)
|
|
- Remote nodes: NATS request/reply with wildcard discovery
|
|
- node.*.capabilities.get → each NCS replies with its capabilities
|
|
- No static node list needed — new nodes auto-register by subscribing
|
|
- scatter-gather pattern: send one request, collect N replies
|
|
- TTL cache per node, stale nodes expire automatically
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import httpx
|
|
|
|
logger = logging.getLogger("global_caps")
|
|
|
|
# Local NCS endpoint (HTTP). Empty string disables the local fetch entirely.
LOCAL_NCS_URL = os.getenv("NODE_CAPABILITIES_URL", "")
# This router's node identity; used to tag models as local and to skip our own
# reply during remote discovery.
LOCAL_NODE_ID = os.getenv("NODE_ID", "unknown")
# NATS server used for cross-node discovery and offload.
NATS_URL = os.getenv("NATS_URL", "nats://nats:4222")
# Seconds a cached node entry counts as fresh; eviction happens at 3x this TTL.
CACHE_TTL = int(os.getenv("GLOBAL_CAPS_TTL", "30"))
# How long a single scatter-gather pass waits for replies.
NATS_DISCOVERY_TIMEOUT_MS = int(os.getenv("NATS_DISCOVERY_TIMEOUT_MS", "500"))
# Feature flag for NATS-based discovery; defaults to enabled.
NATS_ENABLED = os.getenv("ENABLE_GLOBAL_CAPS_NATS", "true").lower() in ("true", "1")

# Wildcard subject each NCS is expected to answer capability requests on.
CAPS_DISCOVERY_SUBJECT = "node.*.capabilities.get"
CAPS_INBOX_PREFIX = "_CAPS_REPLY"  # NOTE(review): appears unused in this module — new_inbox() is used instead

# Per-node capability payloads keyed by node_id, plus last-seen timestamps
# used for TTL-based refresh and stale eviction.
_node_cache: Dict[str, Dict[str, Any]] = {}
_node_timestamps: Dict[str, float] = {}
_nats_client = None  # nats client object once initialize() succeeds, else None
_initialized = False  # set True after initialize() has run (even on failure)
|
|
|
|
|
|
async def initialize():
    """Establish the NATS connection used for discovery (call once at startup).

    Non-fatal by design: when NATS is disabled or unreachable the router
    degrades to local-only capabilities. `_initialized` always ends up True.
    """
    global _nats_client, _initialized

    if not NATS_ENABLED:
        logger.info("Global caps NATS discovery disabled")
        _initialized = True
        return

    try:
        # Imported lazily so the module loads even without the nats package.
        import nats as nats_lib

        _nats_client = await nats_lib.connect(NATS_URL)
        logger.info(f"✅ Global caps NATS connected: {NATS_URL}")
    except Exception as e:
        logger.warning(f"⚠️ Global caps NATS init failed (non-fatal): {e}")
        _nats_client = None
    _initialized = True
|
|
|
|
|
|
async def shutdown():
    """Tear down the discovery NATS connection (best effort, idempotent)."""
    global _nats_client
    if _nats_client:
        try:
            await _nats_client.close()
        except Exception:
            # Already shutting down; nothing useful to do with a close error.
            pass
        _nats_client = None
|
|
|
|
|
|
async def _fetch_local() -> Optional[Dict[str, Any]]:
    """Fetch capabilities from the local NCS via HTTP.

    On success the payload is cached under its node_id (falling back to
    LOCAL_NODE_ID) and returned. On ANY failure — network error, timeout,
    or a non-200 response — the last cached entry for this node is returned
    instead, so a transient NCS hiccup does not drop local models from the
    global view.

    Returns None when no NCS URL is configured, or on failure with an
    empty cache.
    """
    if not LOCAL_NCS_URL:
        return None
    try:
        async with httpx.AsyncClient(timeout=3) as c:
            resp = await c.get(LOCAL_NCS_URL)
            if resp.status_code == 200:
                data = resp.json()
                node_id = data.get("node_id", LOCAL_NODE_ID)
                _node_cache[node_id] = data
                _node_timestamps[node_id] = time.time()
                return data
            # Fix: a non-200 reply previously fell through and returned None,
            # bypassing the cache fallback that transport errors already got.
            logger.warning(f"Local NCS returned HTTP {resp.status_code}")
    except Exception as e:
        logger.warning(f"Local NCS fetch failed: {e}")
    return _node_cache.get(LOCAL_NODE_ID)
|
|
|
|
|
|
async def _discover_remote_nodes() -> List[Dict[str, Any]]:
    """Scatter-gather discovery of every remote NCS over NATS.

    Publishes one request carrying a reply inbox, then collects however many
    replies arrive within NATS_DISCOVERY_TIMEOUT_MS. Each reply refreshes that
    node's cache entry; our own node is skipped. Scales with zero static
    configuration: a new node is discovered as soon as its NCS starts
    answering, and a dead node simply stops refreshing and is evicted later
    by _evict_stale().

    NOTE(review): NATS wildcards match only in SUBSCRIPTION subjects, not in
    published ones — a publish to "node.*.capabilities.get" reaches a node
    only if its NCS also subscribes to the wildcard form (or to this literal
    subject). Confirm the NCS subscription side.

    Returns the list of remote capability payloads collected this pass.
    """
    if not _nats_client:
        return []

    collected: List[Dict[str, Any]] = []
    inbox = _nats_client.new_inbox()
    sub = await _nats_client.subscribe(inbox)

    try:
        # Fix: nats-py 2.x has no publish_request(); a publish with an
        # explicit reply subject is the equivalent scatter request. The old
        # call raised AttributeError, which gather(return_exceptions=True)
        # silently swallowed — discovery always returned nothing.
        await _nats_client.publish(CAPS_DISCOVERY_SUBJECT, b"", reply=inbox)
        await _nats_client.flush()

        deadline = time.time() + (NATS_DISCOVERY_TIMEOUT_MS / 1000.0)
        while (remaining := deadline - time.time()) > 0:
            try:
                msg = await asyncio.wait_for(sub.next_msg(), timeout=remaining)
            except asyncio.TimeoutError:
                break  # collection window closed
            try:
                data = json.loads(msg.data)
            except Exception as e:
                # Fix: one malformed reply previously aborted the whole
                # gather (break); skip it and keep collecting the rest.
                logger.debug(f"Discovery parse error: {e}")
                continue
            node_id = data.get("node_id", "?")
            if node_id != LOCAL_NODE_ID:
                _node_cache[node_id] = data
                _node_timestamps[node_id] = time.time()
                collected.append(data)
    finally:
        await sub.unsubscribe()

    if collected:
        logger.info(
            f"Discovered {len(collected)} remote node(s): "
            f"{[c.get('node_id', '?') for c in collected]}"
        )
    return collected
|
|
|
|
|
|
def _evict_stale():
    """Drop cache entries for nodes that have been silent longer than 3x TTL."""
    expiry = time.time() - 3 * CACHE_TTL
    for node_id in [n for n, seen in _node_timestamps.items() if seen < expiry]:
        _node_cache.pop(node_id, None)
        _node_timestamps.pop(node_id, None)
        logger.info(f"Evicted stale node: {node_id}")
|
|
|
|
|
|
def _needs_refresh() -> bool:
    """True when the cache is empty or any node entry is older than CACHE_TTL."""
    if not _node_timestamps:
        return True
    return time.time() - min(_node_timestamps.values()) > CACHE_TTL
|
|
|
|
|
|
async def get_global_capabilities(force: bool = False) -> Dict[str, Any]:
    """Return the merged model-capability view across all known nodes.

    Serves straight from cache while every entry is within CACHE_TTL (unless
    force=True); otherwise evicts stale nodes and refreshes the local NCS
    plus — when NATS is connected — all remote nodes, concurrently.

    Returns a dict with keys: local_node, nodes, served_models,
    served_count, node_count, updated_at.
    """
    if not force and not _needs_refresh():
        # Everything is fresh enough; skip the network round-trips.
        return _build_global_view()

    _evict_stale()

    refreshers = [_fetch_local()]
    if _nats_client:
        refreshers.append(_discover_remote_nodes())
    # Failures update nothing; stale cache entries simply age out.
    await asyncio.gather(*refreshers, return_exceptions=True)

    return _build_global_view()
|
|
|
|
|
|
def _build_global_view() -> Dict[str, Any]:
    """Merge every cached node's capabilities into one router-facing view.

    Each served model is annotated with its owning node, a local/remote
    flag, and the age of that node's cache entry. Local models sort ahead
    of remote ones, then alphabetically by name.
    """
    now = time.time()
    merged: List[Dict[str, Any]] = []

    for nid, caps in _node_cache.items():
        local_flag = nid.lower() == LOCAL_NODE_ID.lower()
        staleness = round(now - _node_timestamps.get(nid, 0), 1)
        merged.extend(
            {**model, "node": nid, "local": local_flag, "node_age_s": staleness}
            for model in caps.get("served_models", [])
        )

    # False (local) sorts before True (remote); ties break on model name.
    merged.sort(key=lambda m: (not m.get("local"), m.get("name", "")))

    node_summaries = {
        nid: {
            "node_id": nid,
            "served_count": len(caps.get("served_models", [])),
            "age_s": round(now - _node_timestamps.get(nid, 0), 1),
        }
        for nid, caps in _node_cache.items()
    }

    return {
        "local_node": LOCAL_NODE_ID,
        "nodes": node_summaries,
        "served_models": merged,
        "served_count": len(merged),
        "node_count": len(_node_cache),
        "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
|
|
|
|
|
|
def get_cached_global() -> Dict[str, Any]:
    """Return the merged view from whatever is already cached (no I/O)."""
    return _build_global_view()
|
|
|
|
|
|
async def send_offload_request(
    node_id: str,
    request_type: str,
    payload: Dict[str, Any],
    timeout_s: float = 30.0,
) -> Optional[Dict[str, Any]]:
    """Forward a request to a remote node over NATS request/reply.

    Subject: node.{node_id}.{request_type}.request (node_id lowercased).
    Returns the decoded JSON reply, or None when NATS is unavailable,
    the remote times out, or any other error occurs.

    NOTE(review): nats-py raises its own nats.errors.TimeoutError on
    request timeouts — confirm whether it subclasses asyncio.TimeoutError;
    if not, timeouts land in the generic handler below (same None result,
    different log line).
    """
    if not _nats_client:
        logger.warning("Cannot offload: NATS not connected")
        return None

    subject = f"node.{node_id.lower()}.{request_type}.request"
    body = json.dumps(payload).encode()
    try:
        reply = await _nats_client.request(subject, body, timeout=timeout_s)
        return json.loads(reply.data)
    except asyncio.TimeoutError:
        logger.warning(f"Offload timeout: {subject} ({timeout_s}s)")
        return None
    except Exception as e:
        logger.warning(f"Offload error: {subject}: {e}")
        return None
|