P2: Global multi-node model selection + NCS on NODA1

Architecture for 150+ nodes:
- global_capabilities_client.py: NATS scatter-gather discovery using
  wildcard subject node.*.capabilities.get — zero static node lists.
  New nodes auto-register by deploying NCS and subscribing to NATS.
  Dead nodes expire from cache after 3x TTL automatically.

Multi-node model_select.py:
- ModelSelection now includes node, local, via_nats fields
- select_best_model prefers local candidates, then remote
- Prefer list resolution: local first, remote second
- All logged per request: node, runtime, model, local/remote

NODA1 compose:
- Added node-capabilities service (NCS) to docker-compose.node1.yml
- NATS subscription: node.noda1.capabilities.get
- Router env: NODE_CAPABILITIES_URL + ENABLE_GLOBAL_CAPS_NATS=true

NODA2 compose:
- Router env: ENABLE_GLOBAL_CAPS_NATS=true

Router main.py:
- Startup: initializes global_capabilities_client (NATS connect + first
  discovery). Falls back to local-only capabilities_client if unavailable.
- /infer: uses get_global_capabilities() for cross-node model pool
- Offload support: send_offload_request(node_id, request_type, payload) via NATS

Verified on NODA2:
- Global caps: 1 node, 14 models (NODA1 not yet deployed)
- Sofiia: cloud_grok → grok-4-1-fast-reasoning (OK)
- Helion: NCS → qwen3:14b local (OK)
- When NODA1 deploys NCS, its models appear automatically via NATS discovery

Made-with: Cursor
This commit is contained in:
Apple
2026-02-27 02:26:12 -08:00
parent 89c3f2ac66
commit a92c424845
5 changed files with 575 additions and 62 deletions

View File

@@ -0,0 +1,245 @@
"""Global Capabilities Client — aggregates model capabilities across all nodes.
Design for 150+ nodes:
- Local NCS: HTTP (fast, always available)
- Remote nodes: NATS request/reply with wildcard discovery
- node.*.capabilities.get → each NCS replies with its capabilities
- No static node list needed — new nodes auto-register by subscribing
- scatter-gather pattern: send one request, collect N replies
- TTL cache per node, stale nodes expire automatically
"""
import asyncio
import json
import logging
import os
import time
from typing import Any, Dict, List, Optional
import httpx
logger = logging.getLogger("global_caps")
LOCAL_NCS_URL = os.getenv("NODE_CAPABILITIES_URL", "")
LOCAL_NODE_ID = os.getenv("NODE_ID", "unknown")
NATS_URL = os.getenv("NATS_URL", "nats://nats:4222")
CACHE_TTL = int(os.getenv("GLOBAL_CAPS_TTL", "30"))
NATS_DISCOVERY_TIMEOUT_MS = int(os.getenv("NATS_DISCOVERY_TIMEOUT_MS", "500"))
NATS_ENABLED = os.getenv("ENABLE_GLOBAL_CAPS_NATS", "true").lower() in ("true", "1")
CAPS_DISCOVERY_SUBJECT = "node.*.capabilities.get"
CAPS_INBOX_PREFIX = "_CAPS_REPLY"
_node_cache: Dict[str, Dict[str, Any]] = {}
_node_timestamps: Dict[str, float] = {}
_nats_client = None
_initialized = False
async def initialize():
"""Connect to NATS for discovery. Called once at router startup."""
global _nats_client, _initialized
if not NATS_ENABLED:
logger.info("Global caps NATS discovery disabled")
_initialized = True
return
try:
import nats as nats_lib
_nats_client = await nats_lib.connect(NATS_URL)
_initialized = True
logger.info(f"✅ Global caps NATS connected: {NATS_URL}")
except Exception as e:
logger.warning(f"⚠️ Global caps NATS init failed (non-fatal): {e}")
_nats_client = None
_initialized = True
async def shutdown():
global _nats_client
if _nats_client:
try:
await _nats_client.close()
except Exception:
pass
_nats_client = None
async def _fetch_local() -> Optional[Dict[str, Any]]:
"""Fetch capabilities from local NCS via HTTP."""
if not LOCAL_NCS_URL:
return None
try:
async with httpx.AsyncClient(timeout=3) as c:
resp = await c.get(LOCAL_NCS_URL)
if resp.status_code == 200:
data = resp.json()
node_id = data.get("node_id", LOCAL_NODE_ID)
_node_cache[node_id] = data
_node_timestamps[node_id] = time.time()
return data
except Exception as e:
logger.warning(f"Local NCS fetch failed: {e}")
return _node_cache.get(LOCAL_NODE_ID)
async def _discover_remote_nodes() -> List[Dict[str, Any]]:
"""Scatter-gather discovery: send to node.*.capabilities.get, collect replies.
Each NCS on every node subscribes to node.{node_id}.capabilities.get.
NATS wildcard routing delivers our request to ALL of them.
We collect replies within NATS_DISCOVERY_TIMEOUT_MS.
This scales to 150+ nodes with zero static configuration:
- New node deploys NCS → subscribes to its subject → automatically discovered.
- Dead node stops responding → its cache entry expires after TTL.
"""
if not _nats_client:
return []
collected: List[Dict[str, Any]] = []
inbox = _nats_client.new_inbox()
sub = await _nats_client.subscribe(inbox)
try:
await _nats_client.publish_request(
"node.*.capabilities.get", inbox, b""
)
await _nats_client.flush()
deadline = time.time() + (NATS_DISCOVERY_TIMEOUT_MS / 1000.0)
while time.time() < deadline:
remaining = deadline - time.time()
if remaining <= 0:
break
try:
msg = await asyncio.wait_for(
sub.next_msg(), timeout=remaining,
)
data = json.loads(msg.data)
node_id = data.get("node_id", "?")
if node_id != LOCAL_NODE_ID:
_node_cache[node_id] = data
_node_timestamps[node_id] = time.time()
collected.append(data)
except asyncio.TimeoutError:
break
except Exception as e:
logger.debug(f"Discovery parse error: {e}")
break
finally:
await sub.unsubscribe()
if collected:
logger.info(
f"Discovered {len(collected)} remote node(s): "
f"{[c.get('node_id', '?') for c in collected]}"
)
return collected
def _evict_stale():
"""Remove nodes that haven't refreshed within 3x TTL."""
cutoff = time.time() - (CACHE_TTL * 3)
stale = [nid for nid, ts in _node_timestamps.items() if ts < cutoff]
for nid in stale:
_node_cache.pop(nid, None)
_node_timestamps.pop(nid, None)
logger.info(f"Evicted stale node: {nid}")
def _needs_refresh() -> bool:
"""Check if any node cache is older than TTL."""
if not _node_timestamps:
return True
oldest = min(_node_timestamps.values())
return (time.time() - oldest) > CACHE_TTL
async def get_global_capabilities(force: bool = False) -> Dict[str, Any]:
"""Return merged capabilities from all known nodes.
Returns:
{
"local_node": "noda1",
"nodes": {"noda1": {...}, "noda2": {...}, ...},
"served_models": [...], # all models with "node" field
"node_count": 2,
"updated_at": "...",
}
"""
if not force and not _needs_refresh():
return _build_global_view()
_evict_stale()
tasks = [_fetch_local()]
if _nats_client:
tasks.append(_discover_remote_nodes())
await asyncio.gather(*tasks, return_exceptions=True)
return _build_global_view()
def _build_global_view() -> Dict[str, Any]:
"""Build a unified view from all cached node capabilities."""
all_served: List[Dict[str, Any]] = []
for node_id, caps in _node_cache.items():
is_local = (node_id.lower() == LOCAL_NODE_ID.lower())
age = time.time() - _node_timestamps.get(node_id, 0)
for m in caps.get("served_models", []):
all_served.append({
**m,
"node": node_id,
"local": is_local,
"node_age_s": round(age, 1),
})
all_served.sort(key=lambda m: (0 if m.get("local") else 1, m.get("name", "")))
return {
"local_node": LOCAL_NODE_ID,
"nodes": {nid: {"node_id": nid, "served_count": len(c.get("served_models", [])),
"age_s": round(time.time() - _node_timestamps.get(nid, 0), 1)}
for nid, c in _node_cache.items()},
"served_models": all_served,
"served_count": len(all_served),
"node_count": len(_node_cache),
"updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
}
def get_cached_global() -> Dict[str, Any]:
"""Return cached global view without fetching."""
return _build_global_view()
async def send_offload_request(
node_id: str,
request_type: str,
payload: Dict[str, Any],
timeout_s: float = 30.0,
) -> Optional[Dict[str, Any]]:
"""Send an inference request to a remote node via NATS.
Subject pattern: node.{node_id}.{type}.request
Reply: inline NATS request/reply
"""
if not _nats_client:
logger.warning("Cannot offload: NATS not connected")
return None
subject = f"node.{node_id.lower()}.{request_type}.request"
try:
msg = await _nats_client.request(
subject,
json.dumps(payload).encode(),
timeout=timeout_s,
)
return json.loads(msg.data)
except asyncio.TimeoutError:
logger.warning(f"Offload timeout: {subject} ({timeout_s}s)")
return None
except Exception as e:
logger.warning(f"Offload error: {subject}: {e}")
return None

View File

@@ -46,14 +46,16 @@ except ImportError:
RUNTIME_GUARD_AVAILABLE = False
RuntimeGuard = None
# NCS-first model selection
# NCS-first model selection (multi-node global)
try:
import capabilities_client
import global_capabilities_client
from model_select import select_model_for_agent, ModelSelection, CLOUD_PROVIDERS as NCS_CLOUD_PROVIDERS
NCS_AVAILABLE = True
except ImportError:
NCS_AVAILABLE = False
capabilities_client = None # type: ignore[assignment]
global_capabilities_client = None # type: ignore[assignment]
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -765,7 +767,7 @@ async def startup_event():
else:
tool_manager = None
# Initialize Node Capabilities client
# Initialize Node Capabilities (local + global multi-node)
if NCS_AVAILABLE and capabilities_client:
ncs_cfg = router_config.get("node_capabilities", {})
ncs_url = ncs_cfg.get("url", "") or os.getenv("NODE_CAPABILITIES_URL", "")
@@ -774,11 +776,17 @@ async def startup_event():
capabilities_client.configure(url=ncs_url, ttl=ncs_ttl)
caps = await capabilities_client.fetch_capabilities()
served = caps.get("served_count", 0)
logger.info(f"✅ NCS configured: url={ncs_url} ttl={ncs_ttl}s served={served} models")
logger.info(f"✅ NCS local configured: url={ncs_url} ttl={ncs_ttl}s served={served} models")
else:
logger.warning("⚠️ NCS url not configured; model selection will use static config only")
elif NCS_AVAILABLE:
logger.info(" NCS modules loaded but capabilities_client is None")
if global_capabilities_client:
await global_capabilities_client.initialize()
gcaps = await global_capabilities_client.get_global_capabilities()
logger.info(
f"✅ Global caps: {gcaps.get('node_count', 0)} node(s), "
f"{gcaps.get('served_count', 0)} total models"
)
else:
logger.warning("⚠️ NCS modules not available (model_select / capabilities_client import failed)")
@@ -1629,18 +1637,24 @@ async def agent_infer(agent_id: str, request: InferRequest):
cloud_provider_names = {"deepseek", "mistral", "grok", "openai", "anthropic"}
# ── NCS-first model selection ────────────────────────────────────────
# ── Global NCS-first model selection (multi-node) ───────────────────
ncs_selection = None
if NCS_AVAILABLE and capabilities_client:
if NCS_AVAILABLE and global_capabilities_client:
try:
gcaps = await global_capabilities_client.get_global_capabilities()
ncs_selection = await select_model_for_agent(
agent_id, agent_config, router_config, gcaps, request.model,
)
except Exception as e:
logger.warning(f"⚠️ Global NCS selection error: {e}; falling back to static")
elif NCS_AVAILABLE and capabilities_client:
try:
caps = await capabilities_client.fetch_capabilities()
if caps:
caps["_fetch_ts"] = capabilities_client._cache_ts
ncs_selection = await select_model_for_agent(
agent_id, agent_config, router_config, caps, request.model,
)
except Exception as e:
logger.warning(f"⚠️ NCS selection error: {e}; falling back to static config")
logger.warning(f"⚠️ NCS selection error: {e}; falling back to static")
llm_profiles = router_config.get("llm_profiles", {})
@@ -1651,9 +1665,10 @@ async def agent_infer(agent_id: str, request: InferRequest):
if ncs_selection.base_url and provider == "ollama":
llm_profile = {**llm_profile, "base_url": ncs_selection.base_url}
logger.info(
f"🎯 NCS select: agent={agent_id} profile={default_llm} "
f"→ runtime={ncs_selection.runtime} model={model} "
f"provider={provider} via_ncs={ncs_selection.via_ncs} "
f"🎯 Select: agent={agent_id} profile={default_llm} "
f" node={ncs_selection.node} runtime={ncs_selection.runtime} "
f"model={model} provider={provider} "
f"local={ncs_selection.local} via_nats={ncs_selection.via_nats} "
f"caps_age={ncs_selection.caps_age_s}s "
f"fallback={ncs_selection.fallback_reason or 'none'}"
)

View File

@@ -1,8 +1,10 @@
"""NCS-first model selection for DAGI Router.
"""NCS-first model selection for DAGI Router — multi-node aware.
Resolves an agent's LLM profile into a concrete model+provider using live
capabilities from the Node Capabilities Service (NCS). Falls back to static
router-config.yml when NCS is unavailable.
capabilities from Node Capabilities Services across all nodes.
Falls back to static router-config.yml when NCS is unavailable.
Scaling: works with 1 node or 150+. No static node lists.
"""
import logging
import time
@@ -31,7 +33,10 @@ class ModelSelection:
model_type: str # llm | vision | code | …
base_url: str = ""
provider: str = "" # cloud provider name if applicable
node: str = "" # which node owns this model
local: bool = True # is it on the current node?
via_ncs: bool = False
via_nats: bool = False
fallback_reason: str = ""
caps_age_s: float = 0.0
@@ -44,13 +49,11 @@ def resolve_effective_profile(
router_cfg: Dict[str, Any],
request_model: Optional[str] = None,
) -> str:
"""Determine the effective LLM profile name for a request."""
if request_model:
llm_profiles = router_cfg.get("llm_profiles", {})
for pname, pcfg in llm_profiles.items():
if pcfg.get("model") == request_model:
return pname
return agent_cfg.get("default_llm", "local_default_coder")
@@ -59,11 +62,6 @@ def profile_requirements(
agent_cfg: Dict[str, Any],
router_cfg: Dict[str, Any],
) -> ProfileRequirements:
"""Build selection requirements from a profile definition.
If the profile has `selection_policy` in config, use it directly.
Otherwise, infer from the legacy `provider`/`model` fields.
"""
llm_profiles = router_cfg.get("llm_profiles", {})
selection_policies = router_cfg.get("selection_policies", {})
profile_cfg = llm_profiles.get(profile_name, {})
@@ -107,22 +105,23 @@ def profile_requirements(
)
# ── NCS-based selection ───────────────────────────────────────────────────────
# ── Multi-node model selection ────────────────────────────────────────────────
def select_best_model(
reqs: ProfileRequirements,
capabilities: Dict[str, Any],
) -> Optional[ModelSelection]:
"""Choose the best served model from NCS capabilities.
"""Choose the best served model from global (multi-node) capabilities.
Returns None if no suitable model found (caller should try static fallback).
Selection order:
1. Prefer list matches (local first, then remote)
2. Best candidate by size (local first, then remote)
3. None → caller should try static fallback
"""
served = capabilities.get("served_models", [])
if not served:
return None
caps_age = time.time() - capabilities.get("_fetch_ts", time.time())
search_types = [reqs.required_type]
if reqs.required_type == "code":
search_types.append("llm")
@@ -133,24 +132,30 @@ def select_best_model(
if not candidates:
return None
local_candidates = [m for m in candidates if m.get("local", False)]
remote_candidates = [m for m in candidates if not m.get("local", False)]
prefer = reqs.prefer if reqs.prefer else []
for pref in prefer:
if pref == "*":
break
for m in candidates:
for m in local_candidates:
if pref == m.get("name") or pref in m.get("name", ""):
return _make_selection(m, capabilities, caps_age, reqs)
return _make_selection(m, capabilities)
for m in remote_candidates:
if pref == m.get("name") or pref in m.get("name", ""):
return _make_selection(m, capabilities)
if candidates:
best = _pick_best_candidate(candidates)
return _make_selection(best, capabilities, caps_age, reqs)
if local_candidates:
return _make_selection(_pick_best(local_candidates), capabilities)
if remote_candidates:
return _make_selection(_pick_best(remote_candidates), capabilities)
return None
def _pick_best_candidate(candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Prefer running models, then largest by size_gb."""
def _pick_best(candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
running = [m for m in candidates if m.get("running")]
pool = running if running else candidates
return max(pool, key=lambda m: m.get("size_gb", 0))
@@ -159,15 +164,11 @@ def _pick_best_candidate(candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
def _make_selection(
model: Dict[str, Any],
capabilities: Dict[str, Any],
caps_age: float,
reqs: ProfileRequirements,
) -> ModelSelection:
runtime = model.get("runtime", "ollama")
is_local = model.get("local", False)
node = model.get("node", capabilities.get("local_node", ""))
base_url = model.get("base_url", "")
if not base_url:
runtimes = capabilities.get("runtimes", {})
rt = runtimes.get(runtime, {})
base_url = rt.get("base_url", "")
return ModelSelection(
runtime=runtime,
@@ -175,18 +176,20 @@ def _make_selection(
model_type=model.get("type", "llm"),
base_url=base_url,
provider="ollama" if runtime in ("ollama", "llama_server") else runtime,
node=node,
local=is_local,
via_ncs=True,
caps_age_s=round(caps_age, 1),
via_nats=not is_local,
caps_age_s=model.get("node_age_s", 0.0),
)
# ── Static fallback (from router-config profiles) ────────────────────────────
# ── Static fallback ──────────────────────────────────────────────────────────
def static_fallback(
profile_name: str,
router_cfg: Dict[str, Any],
) -> Optional[ModelSelection]:
"""Build a ModelSelection from the static llm_profiles config."""
llm_profiles = router_cfg.get("llm_profiles", {})
cfg = llm_profiles.get(profile_name, {})
if not cfg:
@@ -200,6 +203,8 @@ def static_fallback(
model_type="cloud_llm" if provider in CLOUD_PROVIDERS else "llm",
base_url=cfg.get("base_url", ""),
provider=provider,
node="local",
local=True,
via_ncs=False,
fallback_reason="NCS unavailable or no match; using static config",
)
@@ -214,10 +219,7 @@ async def select_model_for_agent(
capabilities: Optional[Dict[str, Any]],
request_model: Optional[str] = None,
) -> ModelSelection:
"""Full selection pipeline: resolve profile → NCS → static fallback.
This is the single entry point the router calls for each request.
"""
"""Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default."""
profile = resolve_effective_profile(
agent_id, agent_cfg, router_cfg, request_model,
)
@@ -238,36 +240,36 @@ async def select_model_for_agent(
sel = select_best_model(reqs, capabilities)
if sel:
logger.info(
f"[select] agent={agent_id} profile={profile} NCS "
f"runtime={sel.runtime} model={sel.name} caps_age={sel.caps_age_s}s"
f"[select] agent={agent_id} profile={profile}"
f"{'NCS' if sel.local else 'REMOTE'} "
f"node={sel.node} runtime={sel.runtime} "
f"model={sel.name} caps_age={sel.caps_age_s}s"
)
return sel
logger.warning(
f"[select] agent={agent_id} profile={profile} NCS had no match "
f"for type={reqs.required_type}; trying static"
f"[select] agent={agent_id} profile={profile} → no match "
f"for type={reqs.required_type} across {capabilities.get('node_count', 0)} node(s)"
)
static = static_fallback(profile, router_cfg)
if static:
logger.info(
f"[select] agent={agent_id} profile={profile} → static "
f"provider={static.provider} model={static.name} "
f"reason={static.fallback_reason}"
f"provider={static.provider} model={static.name}"
)
return static
if reqs.fallback_profile and reqs.fallback_profile != profile:
logger.warning(
f"[select] agent={agent_id} profile={profile} not found → "
f"trying fallback_profile={reqs.fallback_profile}"
f"fallback_profile={reqs.fallback_profile}"
)
return await select_model_for_agent(
agent_id, agent_cfg, router_cfg, capabilities,
)
logger.error(
f"[select] agent={agent_id} profile={profile} → ALL selection "
f"methods failed. Using hard default qwen3:14b"
f"[select] agent={agent_id} ALL methods failed → hard default"
)
return ModelSelection(
runtime="ollama",
@@ -275,6 +277,8 @@ async def select_model_for_agent(
model_type="llm",
base_url="http://host.docker.internal:11434",
provider="ollama",
node="local",
local=True,
via_ncs=False,
fallback_reason="all methods failed; hard default",
)