Files
microdao-daarion/services/router/global_capabilities_client.py
Apple 9a36020316 P3.5-P3.7: 2-layer inventory, capability routing, STT/TTS adapters, Dev Contract
NCS:
- _collect_worker_caps() fetches capability flags from node-worker /caps
- _derive_capabilities() merges served model types + worker provider flags
- installed_artifacts replaces inventory_only (disk scan with DISK_SCAN_PATHS env)
- New endpoints: /capabilities/caps, /capabilities/installed

Node Worker:
- STT_PROVIDER, TTS_PROVIDER, OCR_PROVIDER, IMAGE_PROVIDER env flags
- /caps endpoint returns capabilities + providers for NCS aggregation
- STT adapter (providers/stt_mlx_whisper.py) — remote + local mode
- TTS adapter (providers/tts_mlx_kokoro.py) — remote + local mode
- OCR handler via vision_prompted (ollama_vision with OCR prompt)
- NATS subjects: node.{id}.stt/tts/ocr/image.request

Router:
- POST /v1/capability/{stt,tts,ocr,image} — capability-based offload routing
- GET /v1/capabilities — global view with capabilities_by_node
- require_fresh_caps(ttl) preflight guard
- find_nodes_with_capability(cap) + load-based node selection

Ops:
- ops/fabric_snapshot.py — full runtime snapshot collector
- ops/fabric_preflight.sh — quick check + snapshot save + diff
- docs/fabric_contract.md — Dev Contract v0.1 (preflight-first)
- tests/test_fabric_contract.py — CI enforcement (6 tests)

Made-with: Cursor
2026-02-27 05:24:09 -08:00

296 lines
9.6 KiB
Python

"""Global Capabilities Client — aggregates model capabilities across all nodes.
Design for 150+ nodes:
- Local NCS: HTTP (fast, always available)
- Remote nodes: NATS request/reply with wildcard discovery
- node.*.capabilities.get → each NCS replies with its capabilities
- No static node list needed — new nodes auto-register by subscribing
- scatter-gather pattern: send one request, collect N replies
- TTL cache per node, stale nodes expire automatically
"""
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, List, Optional
import httpx
logger = logging.getLogger("global_caps")
# HTTP endpoint of the local Node Capabilities Service (NCS); empty disables local fetch.
LOCAL_NCS_URL = os.getenv("NODE_CAPABILITIES_URL", "")
# Identity of this node; used to tag the local entry and skip self in discovery replies.
LOCAL_NODE_ID = os.getenv("NODE_ID", "unknown")
NATS_URL = os.getenv("NATS_URL", "nats://nats:4222")
# Per-node cache freshness window in seconds (eviction happens at 3x this value).
CACHE_TTL = int(os.getenv("GLOBAL_CAPS_TTL", "30"))
# How long scatter-gather discovery waits for replies before giving up.
NATS_DISCOVERY_TIMEOUT_MS = int(os.getenv("NATS_DISCOVERY_TIMEOUT_MS", "500"))
NATS_ENABLED = os.getenv("ENABLE_GLOBAL_CAPS_NATS", "true").lower() in ("true", "1")
# Broadcast subject every NCS listens on; replies come back on a per-request inbox.
CAPS_DISCOVERY_SUBJECT = "fabric.capabilities.discover"
# NOTE(review): not referenced anywhere in this module — possibly used by peers; confirm.
CAPS_INBOX_PREFIX = "_CAPS_REPLY"
# node_id -> last capabilities payload received from that node.
_node_cache: Dict[str, Dict[str, Any]] = {}
# node_id -> unix time of last refresh; drives TTL/staleness checks.
_node_timestamps: Dict[str, float] = {}
# Shared NATS connection (None when disabled or connect failed).
_nats_client = None
# Set once initialize() has run, regardless of NATS outcome.
_initialized = False
async def initialize():
    """Connect to NATS for discovery. Called once at router startup.

    Failure to reach NATS is non-fatal: the client stays None and the
    module falls back to local-HTTP-only operation.
    """
    global _nats_client, _initialized
    if not NATS_ENABLED:
        logger.info("Global caps NATS discovery disabled")
        _initialized = True
        return
    try:
        # Imported lazily so the router still starts when nats-py is absent.
        import nats as nats_lib
        _nats_client = await nats_lib.connect(NATS_URL)
        logger.info(f"✅ Global caps NATS connected: {NATS_URL}")
    except Exception as e:
        logger.warning(f"⚠️ Global caps NATS init failed (non-fatal): {e}")
        _nats_client = None
    # Either way we are considered initialized — NATS is best-effort.
    _initialized = True
async def shutdown():
    """Close the shared NATS connection, ignoring close-time errors."""
    global _nats_client
    client = _nats_client
    if not client:
        return
    _nats_client = None
    try:
        await client.close()
    except Exception:
        # Best-effort teardown; a failed close is irrelevant at shutdown.
        pass
async def _fetch_local() -> Optional[Dict[str, Any]]:
    """Fetch capabilities from local NCS via HTTP.

    On success the payload is cached under its reported node_id and
    returned; on any failure the last cached local entry (or None) is
    returned instead.
    """
    if not LOCAL_NCS_URL:
        return None
    try:
        async with httpx.AsyncClient(timeout=3) as client:
            resp = await client.get(LOCAL_NCS_URL)
            if resp.status_code == 200:
                payload = resp.json()
                nid = payload.get("node_id", LOCAL_NODE_ID)
                _node_timestamps[nid] = time.time()
                _node_cache[nid] = payload
                return payload
    except Exception as exc:
        logger.warning(f"Local NCS fetch failed: {exc}")
    # Fall back to whatever we last knew about the local node.
    return _node_cache.get(LOCAL_NODE_ID)
async def _discover_remote_nodes() -> List[Dict[str, Any]]:
    """Scatter-gather discovery: broadcast one request, collect N replies.

    Publishes an empty message on CAPS_DISCOVERY_SUBJECT with a fresh
    inbox as the reply subject; every NCS subscribed to that subject
    answers on the inbox. Replies are collected until
    NATS_DISCOVERY_TIMEOUT_MS elapses.

    This scales to 150+ nodes with zero static configuration:
    - New node deploys NCS → subscribes to the discovery subject → automatically discovered.
    - Dead node stops responding → its cache entry expires after TTL.

    Returns:
        Capability payloads from remote (non-local) nodes; empty list
        when NATS is not connected.
    """
    if not _nats_client:
        return []
    collected: List[Dict[str, Any]] = []
    inbox = _nats_client.new_inbox()
    sub = await _nats_client.subscribe(inbox)
    try:
        await _nats_client.publish(
            CAPS_DISCOVERY_SUBJECT, b"", reply=inbox,
        )
        await _nats_client.flush()
        deadline = time.time() + (NATS_DISCOVERY_TIMEOUT_MS / 1000.0)
        while time.time() < deadline:
            remaining = deadline - time.time()
            if remaining <= 0:
                break
            try:
                msg = await asyncio.wait_for(
                    sub.next_msg(), timeout=remaining,
                )
                data = json.loads(msg.data)
                node_id = data.get("node_id", "?")
                if node_id != LOCAL_NODE_ID:
                    _node_cache[node_id] = data
                    _node_timestamps[node_id] = time.time()
                    collected.append(data)
            except asyncio.TimeoutError:
                # Window elapsed with no further replies.
                break
            except Exception as e:
                # BUGFIX: one malformed reply previously `break`-ed out of
                # the loop and aborted the whole collection window; skip
                # the bad message and keep listening for other nodes.
                logger.debug(f"Discovery parse error: {e}")
                continue
    finally:
        await sub.unsubscribe()
    if collected:
        logger.info(
            f"Discovered {len(collected)} remote node(s): "
            f"{[c.get('node_id', '?') for c in collected]}"
        )
    return collected
def _evict_stale():
    """Remove nodes that haven't refreshed within 3x TTL."""
    cutoff = time.time() - 3 * CACHE_TTL
    # Materialize the stale ids first so we never mutate while iterating.
    for nid in [n for n, ts in _node_timestamps.items() if ts < cutoff]:
        _node_cache.pop(nid, None)
        _node_timestamps.pop(nid, None)
        logger.info(f"Evicted stale node: {nid}")
def _needs_refresh() -> bool:
    """Return True when the cache is empty or its oldest entry exceeds TTL."""
    if not _node_timestamps:
        return True
    oldest_age = time.time() - min(_node_timestamps.values())
    return oldest_age > CACHE_TTL
async def get_global_capabilities(force: bool = False) -> Dict[str, Any]:
    """Return merged capabilities from all known nodes.

    Refreshes the per-node caches (local HTTP + NATS discovery, run
    concurrently) when `force` is set or any entry is older than TTL,
    then builds the unified view.

    Returns:
        {
            "local_node": "noda1",
            "nodes": {"noda1": {...}, "noda2": {...}, ...},
            "served_models": [...],  # all models with "node" field
            "node_count": 2,
            "updated_at": "...",
        }
    """
    stale = force or _needs_refresh()
    if stale:
        _evict_stale()
        refreshers = [_fetch_local()]
        if _nats_client:
            refreshers.append(_discover_remote_nodes())
        # Results land in the module caches as side effects; exceptions
        # are swallowed so one failed source never blocks the view.
        await asyncio.gather(*refreshers, return_exceptions=True)
    return _build_global_view()
def _build_global_view() -> Dict[str, Any]:
    """Build a unified view from all cached node capabilities."""
    served: List[Dict[str, Any]] = []
    caps_by_node: Dict[str, Dict[str, Any]] = {}
    for nid, snapshot in _node_cache.items():
        local = nid.lower() == LOCAL_NODE_ID.lower()
        staleness = time.time() - _node_timestamps.get(nid, 0)
        # Annotate each served model with its owning node and freshness.
        served.extend(
            {
                **model,
                "node": nid,
                "local": local,
                "node_age_s": round(staleness, 1),
            }
            for model in snapshot.get("served_models", [])
        )
        flags = snapshot.get("capabilities", {})
        if flags:
            # "providers" is node-internal detail, not a capability flag.
            caps_by_node[nid] = {k: v for k, v in flags.items() if k != "providers"}
    # Local models first, then alphabetical by model name.
    served.sort(key=lambda m: (0 if m.get("local") else 1, m.get("name", "")))
    nodes_summary = {}
    for nid, snapshot in _node_cache.items():
        nodes_summary[nid] = {
            "node_id": nid,
            "served_count": len(snapshot.get("served_models", [])),
            "installed_count": snapshot.get("installed_count", 0),
            "capabilities": snapshot.get("capabilities", {}),
            "node_load": snapshot.get("node_load", {}),
            "age_s": round(time.time() - _node_timestamps.get(nid, 0), 1),
        }
    return {
        "local_node": LOCAL_NODE_ID,
        "nodes": nodes_summary,
        "served_models": served,
        "served_count": len(served),
        "capabilities_by_node": caps_by_node,
        "node_count": len(_node_cache),
        "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    }
def get_cached_global() -> Dict[str, Any]:
    """Build the global view from cached data only — no network refresh."""
    return _build_global_view()
async def require_fresh_caps(ttl: int = 30) -> Optional[Dict[str, Any]]:
    """Preflight: return global caps only if fresh enough.

    Fast path: if the oldest cached entry is within `ttl` seconds, build
    the view from cache. Otherwise force one refresh and re-check; if the
    data is still stale (or the cache is still empty) return None so the
    caller falls back to safe defaults rather than routing on outdated
    info.

    Improvement over previous version: the empty-cache path now applies
    the same post-refresh staleness check as the stale path (it used to
    return whatever the refresh produced unchecked), and the duplicated
    refresh logic is unified.

    Args:
        ttl: maximum acceptable age in seconds of the oldest node entry.

    Returns:
        The global capabilities view, or None when data cannot be made
        fresh.
    """
    if _node_timestamps:
        oldest = min(_node_timestamps.values())
        if (time.time() - oldest) <= ttl:
            return _build_global_view()
    # Cache empty or stale — force a refresh, then re-validate.
    gcaps = await get_global_capabilities(force=True)
    if not _node_timestamps:
        return None
    oldest = min(_node_timestamps.values())
    if (time.time() - oldest) > ttl:
        logger.warning("[preflight] caps stale after refresh, age=%ds", int(time.time() - oldest))
        return None
    return gcaps
def find_nodes_with_capability(cap: str) -> List[str]:
    """Return node IDs that have a given capability enabled."""
    return [
        nid
        for nid, caps in _node_cache.items()
        if caps.get("capabilities", {}).get(cap, False)
    ]
def get_node_load(node_id: str) -> Dict[str, Any]:
    """Return the cached node_load dict for `node_id` ({} when unknown)."""
    return _node_cache.get(node_id, {}).get("node_load", {})
async def send_offload_request(
    node_id: str,
    request_type: str,
    payload: Dict[str, Any],
    timeout_s: float = 30.0,
) -> Optional[Dict[str, Any]]:
    """Send an inference request to a remote node via NATS.

    Subject pattern: node.{node_id}.{type}.request
    Reply: inline NATS request/reply

    Returns the decoded reply payload, or None on timeout, transport
    error, or when NATS is not connected.
    """
    if not _nats_client:
        logger.warning("Cannot offload: NATS not connected")
        return None
    subject = f"node.{node_id.lower()}.{request_type}.request"
    body = json.dumps(payload).encode()
    try:
        # json.loads stays inside the try so a malformed reply is also
        # treated as an offload failure rather than propagating.
        reply = await _nats_client.request(subject, body, timeout=timeout_s)
        return json.loads(reply.data)
    except asyncio.TimeoutError:
        logger.warning(f"Offload timeout: {subject} ({timeout_s}s)")
        return None
    except Exception as e:
        logger.warning(f"Offload error: {subject}: {e}")
        return None