NCS:
- _collect_worker_caps() fetches capability flags from node-worker /caps
- _derive_capabilities() merges served model types + worker provider flags
- installed_artifacts replaces inventory_only (disk scan with DISK_SCAN_PATHS env)
- New endpoints: /capabilities/caps, /capabilities/installed
Node Worker:
- STT_PROVIDER, TTS_PROVIDER, OCR_PROVIDER, IMAGE_PROVIDER env flags
- /caps endpoint returns capabilities + providers for NCS aggregation
- STT adapter (providers/stt_mlx_whisper.py) — remote + local mode
- TTS adapter (providers/tts_mlx_kokoro.py) — remote + local mode
- OCR handler via vision_prompted (ollama_vision with OCR prompt)
- NATS subjects: node.{id}.stt/tts/ocr/image.request
Router:
- POST /v1/capability/{stt,tts,ocr,image} — capability-based offload routing
- GET /v1/capabilities — global view with capabilities_by_node
- require_fresh_caps(ttl) preflight guard
- find_nodes_with_capability(cap) + load-based node selection
Ops:
- ops/fabric_snapshot.py — full runtime snapshot collector
- ops/fabric_preflight.sh — quick check + snapshot save + diff
- docs/fabric_contract.md — Dev Contract v0.1 (preflight-first)
- tests/test_fabric_contract.py — CI enforcement (6 tests)
Made-with: Cursor
296 lines · 9.6 KiB · Python
"""Global Capabilities Client — aggregates model capabilities across all nodes.
|
|
|
|
Design for 150+ nodes:
|
|
- Local NCS: HTTP (fast, always available)
|
|
- Remote nodes: NATS request/reply with wildcard discovery
|
|
- node.*.capabilities.get → each NCS replies with its capabilities
|
|
- No static node list needed — new nodes auto-register by subscribing
|
|
- scatter-gather pattern: send one request, collect N replies
|
|
- TTL cache per node, stale nodes expire automatically
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
import httpx
|
|
|
|
logger = logging.getLogger("global_caps")

# Local NCS (Node Capabilities Service) HTTP endpoint; empty string disables the local fetch.
LOCAL_NCS_URL = os.getenv("NODE_CAPABILITIES_URL", "")
# Identifier of the node this process runs on ("unknown" when unset).
LOCAL_NODE_ID = os.getenv("NODE_ID", "unknown")
# NATS server URL used for discovery and offload requests.
NATS_URL = os.getenv("NATS_URL", "nats://nats:4222")
# Seconds before a node's cached capabilities count as stale (refresh trigger).
CACHE_TTL = int(os.getenv("GLOBAL_CAPS_TTL", "30"))
# How long scatter-gather discovery waits to collect replies, in milliseconds.
NATS_DISCOVERY_TIMEOUT_MS = int(os.getenv("NATS_DISCOVERY_TIMEOUT_MS", "500"))
# Master switch for NATS-based remote discovery (local HTTP fetch still works when off).
NATS_ENABLED = os.getenv("ENABLE_GLOBAL_CAPS_NATS", "true").lower() in ("true", "1")

# Broadcast subject that every NCS instance listens on for discovery requests.
CAPS_DISCOVERY_SUBJECT = "fabric.capabilities.discover"
# NOTE(review): unused in this module — replies go to a per-call new_inbox();
# presumably consumed by NCS-side configuration. Confirm before removing.
CAPS_INBOX_PREFIX = "_CAPS_REPLY"

# node_id -> latest capabilities payload received from that node.
_node_cache: Dict[str, Dict[str, Any]] = {}
# node_id -> unix timestamp of the last refresh; drives TTL checks and eviction.
_node_timestamps: Dict[str, float] = {}
# Lazily-connected NATS client (None when disabled or the connection failed).
_nats_client = None
# Set once initialize() has run, regardless of NATS success.
_initialized = False
|
|
|
|
|
|
async def initialize():
    """Connect to NATS for discovery. Called once at router startup.

    Always leaves ``_initialized`` True: a failed NATS connection is
    non-fatal and the client falls back to local-HTTP-only operation.
    """
    global _nats_client, _initialized

    if not NATS_ENABLED:
        logger.info("Global caps NATS discovery disabled")
        _initialized = True
        return

    try:
        import nats as nats_lib

        _nats_client = await nats_lib.connect(NATS_URL)
    except Exception as e:
        # Degrade gracefully: discovery is off, local fetch still works.
        logger.warning(f"⚠️ Global caps NATS init failed (non-fatal): {e}")
        _nats_client = None
    else:
        logger.info(f"✅ Global caps NATS connected: {NATS_URL}")
    _initialized = True
|
|
|
|
|
|
async def shutdown():
    """Close the NATS connection if one exists; errors on close are ignored."""
    global _nats_client
    if not _nats_client:
        return
    try:
        await _nats_client.close()
    except Exception:
        # Best-effort: a failed close during shutdown is not actionable.
        pass
    _nats_client = None
|
|
|
|
|
|
async def _fetch_local() -> Optional[Dict[str, Any]]:
    """Fetch capabilities from local NCS via HTTP.

    On success the payload is cached under its reported node_id and
    returned; on failure the last cached payload for the local node is
    returned instead (or None when nothing was ever cached).
    """
    if not LOCAL_NCS_URL:
        return None
    try:
        async with httpx.AsyncClient(timeout=3) as client:
            resp = await client.get(LOCAL_NCS_URL)
        if resp.status_code == 200:
            payload = resp.json()
            node_id = payload.get("node_id", LOCAL_NODE_ID)
            _node_cache[node_id] = payload
            _node_timestamps[node_id] = time.time()
            return payload
    except Exception as e:
        logger.warning(f"Local NCS fetch failed: {e}")
    # Fallback: serve whatever we last cached for this node.
    return _node_cache.get(LOCAL_NODE_ID)
|
|
|
|
|
|
async def _discover_remote_nodes() -> List[Dict[str, Any]]:
    """Scatter-gather discovery: send to node.*.capabilities.get, collect replies.

    Each NCS on every node subscribes to node.{node_id}.capabilities.get.
    NATS wildcard routing delivers our request to ALL of them.
    We collect replies within NATS_DISCOVERY_TIMEOUT_MS.

    This scales to 150+ nodes with zero static configuration:
    - New node deploys NCS → subscribes to its subject → automatically discovered.
    - Dead node stops responding → its cache entry expires after TTL.

    Returns the list of remote-node payloads collected this round (the
    local node's own reply, if any, is skipped — _fetch_local owns it).
    """
    if not _nats_client:
        return []

    collected: List[Dict[str, Any]] = []
    # Per-call reply inbox so concurrent discoveries don't cross wires.
    inbox = _nats_client.new_inbox()
    sub = await _nats_client.subscribe(inbox)

    try:
        await _nats_client.publish(
            CAPS_DISCOVERY_SUBJECT, b"", reply=inbox,
        )
        await _nats_client.flush()

        deadline = time.time() + (NATS_DISCOVERY_TIMEOUT_MS / 1000.0)
        while True:
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            try:
                msg = await asyncio.wait_for(
                    sub.next_msg(), timeout=remaining,
                )
            except asyncio.TimeoutError:
                # Collection window elapsed — stop waiting for stragglers.
                break
            try:
                data = json.loads(msg.data)
                node_id = data.get("node_id", "?")
            except Exception as e:
                # BUGFIX: a single malformed reply previously broke out of
                # the loop, discarding every remaining node's reply. Skip
                # the bad payload and keep collecting instead.
                logger.debug(f"Discovery parse error: {e}")
                continue
            if node_id != LOCAL_NODE_ID:
                _node_cache[node_id] = data
                _node_timestamps[node_id] = time.time()
                collected.append(data)
    finally:
        await sub.unsubscribe()

    if collected:
        logger.info(
            f"Discovered {len(collected)} remote node(s): "
            f"{[c.get('node_id', '?') for c in collected]}"
        )
    return collected
|
|
|
|
|
|
def _evict_stale():
    """Remove nodes that haven't refreshed within 3x TTL."""
    cutoff = time.time() - CACHE_TTL * 3
    # Materialize the stale ids first — we mutate the dicts while walking them.
    stale_ids = [node_id for node_id, ts in _node_timestamps.items() if ts < cutoff]
    for node_id in stale_ids:
        _node_cache.pop(node_id, None)
        _node_timestamps.pop(node_id, None)
        logger.info(f"Evicted stale node: {node_id}")
|
|
|
|
|
|
def _needs_refresh() -> bool:
    """True when nothing is cached yet or the oldest node entry exceeds the TTL."""
    if not _node_timestamps:
        return True
    age_of_oldest = time.time() - min(_node_timestamps.values())
    return age_of_oldest > CACHE_TTL
|
|
|
|
|
|
async def get_global_capabilities(force: bool = False) -> Dict[str, Any]:
    """Return merged capabilities from all known nodes.

    Args:
        force: refresh every node even when the cache is still fresh.

    Returns:
        {
            "local_node": "noda1",
            "nodes": {"noda1": {...}, "noda2": {...}, ...},
            "served_models": [...],  # all models with "node" field
            "node_count": 2,
            "updated_at": "...",
        }
    """
    if not (force or _needs_refresh()):
        return _build_global_view()

    _evict_stale()

    fetchers = [_fetch_local()]
    if _nats_client:
        fetchers.append(_discover_remote_nodes())

    # Each fetcher caches whatever it can; individual failures are tolerated.
    await asyncio.gather(*fetchers, return_exceptions=True)

    return _build_global_view()
|
|
|
|
|
|
def _build_global_view() -> Dict[str, Any]:
    """Build a unified view from all cached node capabilities."""
    served: List[Dict[str, Any]] = []
    caps_by_node: Dict[str, Dict[str, Any]] = {}

    for node_id, caps in _node_cache.items():
        local = node_id.lower() == LOCAL_NODE_ID.lower()
        age = time.time() - _node_timestamps.get(node_id, 0)
        # Tag every served model with its origin node and cache age.
        served.extend(
            {**model, "node": node_id, "local": local, "node_age_s": round(age, 1)}
            for model in caps.get("served_models", [])
        )
        node_caps = caps.get("capabilities", {})
        if node_caps:
            # Provider details are node-internal; keep them out of the global map.
            caps_by_node[node_id] = {
                key: val for key, val in node_caps.items() if key != "providers"
            }

    # Local models first, then alphabetical by model name.
    served.sort(key=lambda m: (0 if m.get("local") else 1, m.get("name", "")))

    nodes_summary: Dict[str, Dict[str, Any]] = {}
    for node_id, caps in _node_cache.items():
        nodes_summary[node_id] = {
            "node_id": node_id,
            "served_count": len(caps.get("served_models", [])),
            "installed_count": caps.get("installed_count", 0),
            "capabilities": caps.get("capabilities", {}),
            "node_load": caps.get("node_load", {}),
            "age_s": round(time.time() - _node_timestamps.get(node_id, 0), 1),
        }

    return {
        "local_node": LOCAL_NODE_ID,
        "nodes": nodes_summary,
        "served_models": served,
        "served_count": len(served),
        "capabilities_by_node": caps_by_node,
        "node_count": len(_node_cache),
        "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
|
|
|
|
|
|
def get_cached_global() -> Dict[str, Any]:
    """Build the global view purely from cache — no local or remote fetching."""
    return _build_global_view()
|
|
|
|
|
|
async def require_fresh_caps(ttl: int = 30) -> Optional[Dict[str, Any]]:
    """Preflight: return global caps only if fresh enough.

    Args:
        ttl: maximum acceptable age, in seconds, of the oldest cached node.

    Returns None if NCS data is stale beyond ttl — caller should use
    safe fallback instead of making routing decisions on outdated info.
    """
    # Cold start: nothing cached yet — force one fetch and re-check.
    if not _node_timestamps:
        gcaps = await get_global_capabilities(force=True)
        if not _node_timestamps:
            # Still empty after a forced refresh: no node is reachable.
            return None
        return gcaps
    oldest = min(_node_timestamps.values())
    if (time.time() - oldest) > ttl:
        # Stale: refresh once, then verify the refresh actually helped.
        gcaps = await get_global_capabilities(force=True)
        oldest = min(_node_timestamps.values()) if _node_timestamps else 0
        if (time.time() - oldest) > ttl:
            logger.warning("[preflight] caps stale after refresh, age=%ds", int(time.time() - oldest))
            return None
        return gcaps
    # Fresh enough: serve the cached view without touching the network.
    return _build_global_view()
|
|
|
|
|
|
def find_nodes_with_capability(cap: str) -> List[str]:
    """Return node IDs that have a given capability enabled."""
    return [
        node_id
        for node_id, caps in _node_cache.items()
        if caps.get("capabilities", {}).get(cap, False)
    ]
|
|
|
|
|
|
def get_node_load(node_id: str) -> Dict[str, Any]:
    """Get cached node_load for a specific node (empty dict when unknown)."""
    return _node_cache.get(node_id, {}).get("node_load", {})
|
|
|
|
|
|
async def send_offload_request(
    node_id: str,
    request_type: str,
    payload: Dict[str, Any],
    timeout_s: float = 30.0,
) -> Optional[Dict[str, Any]]:
    """Send an inference request to a remote node via NATS.

    Subject pattern: node.{node_id}.{type}.request
    Reply: inline NATS request/reply

    Returns the decoded reply dict, or None on timeout, transport error,
    or an unparseable reply.
    """
    if not _nats_client:
        logger.warning("Cannot offload: NATS not connected")
        return None

    subject = f"node.{node_id.lower()}.{request_type}.request"
    body = json.dumps(payload).encode()
    try:
        reply = await _nats_client.request(subject, body, timeout=timeout_s)
        return json.loads(reply.data)
    except asyncio.TimeoutError:
        logger.warning(f"Offload timeout: {subject} ({timeout_s}s)")
    except Exception as e:
        logger.warning(f"Offload error: {subject}: {e}")
    return None
|