"""NCS-first model selection for DAGI Router — multi-node aware.

Resolves an agent's LLM profile into a concrete model+provider using live
capabilities from Node Capabilities Services across all nodes.
Falls back to static router-config.yml when NCS is unavailable.

Scaling: works with 1 node or 150+. No static node lists.
"""

import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set

logger = logging.getLogger("model_select")

CLOUD_PROVIDERS = {"deepseek", "mistral", "grok", "openai", "anthropic"}

# A local candidate is chosen as long as its score is at most this much worse
# (higher) than the best remote candidate's score.
LOCAL_THRESHOLD_MS = 250

# Profile used when the agent config names no default profile.
_DEFAULT_PROFILE = "local_default_coder"
# Latency assumed for a model when neither the model entry nor the
# runtime_load list supplies a p50 value.
_DEFAULT_MODEL_LATENCY_MS = 1500
# Bonus for the first entry of the prefer list; each later entry gets
# _PREFER_STEP less, never dropping below zero.
_PREFER_MAX_BONUS = 1000
_PREFER_STEP = 100


@dataclass
class ProfileRequirements:
    """What a profile needs from a served model."""

    profile_name: str
    required_type: str  # llm | vision | code | stt | tts | cloud_llm
    prefer: List[str] = field(default_factory=list)
    provider: Optional[str] = None
    fallback_profile: Optional[str] = None
    constraints: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ModelSelection:
    """Concrete outcome of a selection: which model, where, and how chosen."""

    runtime: str  # ollama | swapper | llama_server | cloud
    name: str  # model name as runtime knows it
    model_type: str  # llm | vision | code | …
    base_url: str = ""
    provider: str = ""  # cloud provider name if applicable
    node: str = ""  # which node owns this model
    local: bool = True  # is it on the current node?
    via_ncs: bool = False
    via_nats: bool = False
    fallback_reason: str = ""
    caps_age_s: float = 0.0
    score: int = 0  # lower = faster


def _section(cfg: Optional[Dict[str, Any]], key: str) -> Dict[str, Any]:
    """Return ``cfg[key]`` when it is a dict, otherwise an empty dict.

    Guards against sections that are missing or explicitly null in the
    configuration mapping.
    """
    value = cfg.get(key) if cfg else None
    return value if isinstance(value, dict) else {}


# ── Profile resolution ────────────────────────────────────────────────────────

def resolve_effective_profile(
    agent_id: str,
    agent_cfg: Dict[str, Any],
    router_cfg: Dict[str, Any],
    request_model: Optional[str] = None,
) -> str:
    """Return the profile name to use for this request.

    If ``request_model`` equals the ``model`` of an entry in
    ``router_cfg["llm_profiles"]``, that entry's name wins (first match in
    mapping order). Otherwise the agent's ``default_llm`` is used, falling
    back to ``"local_default_coder"`` when it is missing or empty.

    ``agent_id`` is accepted for interface compatibility and is not used.
    """
    if request_model:
        for pname, pcfg in _section(router_cfg, "llm_profiles").items():
            if isinstance(pcfg, dict) and pcfg.get("model") == request_model:
                return pname
    return agent_cfg.get("default_llm") or _DEFAULT_PROFILE


def profile_requirements(
    profile_name: str,
    agent_cfg: Dict[str, Any],
    router_cfg: Dict[str, Any],
) -> ProfileRequirements:
    """Build the requirements for ``profile_name``.

    An entry in ``router_cfg["selection_policies"]`` takes precedence.
    Without one, requirements are derived from the ``llm_profiles`` entry:
    cloud providers map to ``cloud_llm``; otherwise the type is guessed from
    the profile/model name ("vision"/"vl" → vision, "coder"/"code" → code,
    else llm).
    """
    profile_cfg = _section(_section(router_cfg, "llm_profiles"), profile_name)
    policy = _section(_section(router_cfg, "selection_policies"), profile_name)

    if policy:
        return ProfileRequirements(
            profile_name=profile_name,
            required_type=policy.get("required_type", "llm"),
            # Copy so later mutation of the result cannot alter the config.
            prefer=list(policy.get("prefer") or []),
            provider=policy.get("provider"),
            fallback_profile=(
                policy.get("fallback_profile") or agent_cfg.get("fallback_llm")
            ),
            constraints=dict(policy.get("constraints") or {}),
        )

    provider = profile_cfg.get("provider", "ollama")
    model = profile_cfg.get("model") or ""

    if provider in CLOUD_PROVIDERS:
        return ProfileRequirements(
            profile_name=profile_name,
            required_type="cloud_llm",
            prefer=[],
            provider=provider,
            fallback_profile=agent_cfg.get("fallback_llm", _DEFAULT_PROFILE),
        )

    model_lc = model.lower()
    if "vision" in profile_name or "vl" in model_lc:
        req_type = "vision"
    elif "coder" in profile_name or "code" in model_lc:
        req_type = "code"
    else:
        req_type = "llm"

    return ProfileRequirements(
        profile_name=profile_name,
        required_type=req_type,
        prefer=[model] if model else [],
        provider=provider,
        fallback_profile=agent_cfg.get("fallback_llm"),
    )


# ── Scoring ───────────────────────────────────────────────────────────────────

def score_candidate(
    model: Dict[str, Any],
    capabilities: Dict[str, Any],
    prefer: List[str],
    rtt_hint_ms: int = 60,
) -> int:
    """Lower score = better candidate.

    Formula: wait + model_latency + cross_node_penalty + prefer_bonus

    * wait: ``estimated_wait_ms`` of the owning node's ``node_load``.
    * model_latency: the model's ``model_p50_ms``, else the ``p50_ms`` of the
      matching ``runtime_load`` entry, else 1500.
    * cross_node_penalty: twice the RTT for remote models, 0 for local ones.
    * prefer_bonus: negative for a model matching an entry of ``prefer``
      (substring match); earlier entries give a larger bonus. A ``"*"`` entry
      stops the search. The bonus never turns into a penalty.
    """
    is_local = model.get("local", False)

    node_load = capabilities.get("node_load") or {}
    if not is_local:
        node_id = model.get("node", "")
        # NOTE(review): when the remote node is not listed under "nodes", the
        # top-level node_load is used for it — confirm this is intended.
        for ndata in (capabilities.get("nodes") or {}).values():
            if ndata.get("node_id") == node_id:
                node_load = ndata.get("node_load") or {}
                break

    wait = node_load.get("estimated_wait_ms") or 0

    model_lat = model.get("model_p50_ms") or 0
    if not model_lat:
        runtime = model.get("runtime", "ollama")
        for rload in capabilities.get("runtime_load") or []:
            if rload.get("runtime") == runtime:
                model_lat = rload.get("p50_ms") or 0
                break
    if not model_lat:
        model_lat = _DEFAULT_MODEL_LATENCY_MS

    cross_penalty = 0
    if not is_local:
        rtt = node_load.get("rtt_ms_to_hub") or rtt_hint_ms or 60
        cross_penalty = rtt * 2

    prefer_bonus = 0
    name = model.get("name") or ""
    for i, pref in enumerate(prefer):
        if pref == "*":
            break
        # Skip empty/non-string entries: "" is a substring of every name.
        if isinstance(pref, str) and pref and pref in name:
            # Clamp so entries beyond index 10 are never penalized.
            prefer_bonus = -max(0, _PREFER_MAX_BONUS - i * _PREFER_STEP)
            break

    return wait + model_lat + cross_penalty + prefer_bonus


# ── Multi-node model selection ────────────────────────────────────────────────

def select_best_model(
    reqs: ProfileRequirements,
    capabilities: Dict[str, Any],
    exclude_nodes: Optional[Set[str]] = None,
    *,
    local_only: bool = False,
) -> Optional[ModelSelection]:
    """Choose the best served model from global (multi-node) capabilities.

    Uses scoring: wait + model_latency + cross_node_rtt + prefer_bonus.
    If best local score <= best remote score + LOCAL_THRESHOLD_MS, prefer local.

    exclude_nodes: set of node_ids to skip (e.g. circuit-broken nodes).
    local_only: consider only models flagged ``local``.

    Returns None when no served model matches.
    """
    served = capabilities.get("served_models") or []
    if not served:
        return None

    exclude = exclude_nodes or set()

    # "code" and "llm" models can stand in for each other.
    search_types = [reqs.required_type]
    if reqs.required_type == "code":
        search_types.append("llm")
    elif reqs.required_type == "llm":
        search_types.append("code")

    candidates = [
        m for m in served
        if m.get("type") in search_types
        and m.get("node", "") not in exclude
        and (m.get("local", False) or not local_only)
    ]
    if not candidates:
        return None

    prefer = reqs.prefer or []
    scored = [(score_candidate(m, capabilities, prefer), m) for m in candidates]

    def _by_score(entry):
        return entry[0]

    # min() returns the first minimal entry, i.e. ties keep served order.
    best_local = min(
        (e for e in scored if e[1].get("local", False)),
        key=_by_score,
        default=None,
    )
    best_remote = min(
        (e for e in scored if not e[1].get("local", False)),
        key=_by_score,
        default=None,
    )

    if best_local and best_remote:
        local_wins = best_local[0] <= best_remote[0] + LOCAL_THRESHOLD_MS
        winner = best_local if local_wins else best_remote
    else:
        # candidates is non-empty, so at least one side exists.
        winner = best_local or best_remote

    score, model = winner
    return _make_selection(model, capabilities, score)


def _make_selection(
    model: Dict[str, Any],
    capabilities: Dict[str, Any],
    score: int = 0,
) -> ModelSelection:
    """Convert a served-model entry into a ModelSelection marked via_ncs."""
    runtime = model.get("runtime", "ollama")
    is_local = model.get("local", False)

    return ModelSelection(
        runtime=runtime,
        name=model.get("name", ""),
        model_type=model.get("type", "llm"),
        base_url=model.get("base_url", ""),
        provider="ollama" if runtime in ("ollama", "llama_server") else runtime,
        node=model.get("node", capabilities.get("local_node", "")),
        local=is_local,
        via_ncs=True,
        via_nats=not is_local,
        caps_age_s=model.get("node_age_s", 0.0),
        score=score,
    )


# ── Static fallback ──────────────────────────────────────────────────────────

def static_fallback(
    profile_name: str,
    router_cfg: Dict[str, Any],
) -> Optional[ModelSelection]:
    """Build a selection straight from ``router_cfg["llm_profiles"]``.

    Returns None when the profile is missing or empty.
    """
    cfg = _section(_section(router_cfg, "llm_profiles"), profile_name)
    if not cfg:
        return None

    provider = cfg.get("provider", "ollama")
    is_cloud = provider in CLOUD_PROVIDERS
    return ModelSelection(
        runtime="cloud" if is_cloud else "ollama",
        name=cfg.get("model", ""),
        model_type="cloud_llm" if is_cloud else "llm",
        base_url=cfg.get("base_url", ""),
        provider=provider,
        node="local",
        local=True,
        via_ncs=False,
        fallback_reason="NCS unavailable or no match; using static config",
    )


# ── Top-level orchestrator ────────────────────────────────────────────────────

def _circuit_broken_nodes(required_type: str) -> Set[str]:
    """Return nodes reported unavailable by offload_client, if importable."""
    try:
        from offload_client import get_unavailable_nodes
    except ImportError:
        return set()
    # Accept any iterable (or None) from the helper.
    return set(get_unavailable_nodes(required_type) or ())


def _record_selection_metrics(sel: ModelSelection, required_type: str) -> None:
    """Report the selection to fabric_metrics when that module is available."""
    try:
        from fabric_metrics import inc_model_select, observe_score
    except ImportError:
        return
    inc_model_select(sel.node, sel.runtime, required_type)
    observe_score(sel.score)


def _select_for_profile(
    agent_id: str,
    profile: str,
    reqs: ProfileRequirements,
    router_cfg: Dict[str, Any],
    capabilities: Optional[Dict[str, Any]],
    exclude_nodes: Optional[Set[str]],
    force_local: bool,
) -> Optional[ModelSelection]:
    """Try cloud shortcut → NCS → static config for a single profile."""
    if reqs.required_type == "cloud_llm" and not force_local:
        static = static_fallback(profile, router_cfg)
        if static:
            # Cloud profiles are served from static config by design.
            static.fallback_reason = ""
            logger.info(
                f"[select] agent={agent_id} profile={profile} → cloud "
                f"provider={static.provider} model={static.name}"
            )
            return static

    excl = set(exclude_nodes) if exclude_nodes else set()
    cb_nodes = _circuit_broken_nodes(reqs.required_type)
    if cb_nodes:
        excl |= cb_nodes
        logger.info(
            f"[select] circuit-broken nodes for {reqs.required_type}: {cb_nodes}"
        )

    if capabilities and capabilities.get("served_models"):
        # Filter on the "local" flag rather than excluding remote node ids,
        # which would also drop local models sharing a node id with a remote.
        sel = select_best_model(
            reqs, capabilities, exclude_nodes=excl, local_only=force_local,
        )
        if sel:
            excluded = ",".join(sorted(map(str, excl)))
            logger.info(
                f"[score] agent={agent_id} type={reqs.required_type} "
                f"chosen={'LOCAL' if sel.local else 'REMOTE'}:{sel.node}/{sel.name} "
                f"score={sel.score} caps_age={sel.caps_age_s}s"
                f"{' (force_local)' if force_local else ''}"
                f"{' (excluded: ' + excluded + ')' if excl else ''}"
            )
            _record_selection_metrics(sel, reqs.required_type)
            return sel
        logger.warning(
            f"[select] agent={agent_id} profile={profile} → no match "
            f"for type={reqs.required_type} across "
            f"{capabilities.get('node_count', 0)} node(s)"
        )

    static = static_fallback(profile, router_cfg)
    if static:
        logger.info(
            f"[select] agent={agent_id} profile={profile} → static "
            f"provider={static.provider} model={static.name}"
        )
    return static


async def select_model_for_agent(
    agent_id: str,
    agent_cfg: Dict[str, Any],
    router_cfg: Dict[str, Any],
    capabilities: Optional[Dict[str, Any]],
    request_model: Optional[str] = None,
    exclude_nodes: Optional[Set[str]] = None,
    force_local: bool = False,
) -> ModelSelection:
    """Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default.

    When a profile yields nothing, its ``fallback_profile`` is tried next with
    the same ``exclude_nodes`` / ``force_local`` settings. Each profile is
    tried at most once, so cyclic fallbacks terminate.

    exclude_nodes: skip these nodes (circuit-broken). Used on fallback re-selection.
    force_local: prefer local-only models (fallback after remote failure).

    Always returns a ModelSelection; the last resort is a hard-coded local
    ollama default.
    """
    profile = resolve_effective_profile(
        agent_id, agent_cfg, router_cfg, request_model,
    )

    tried: Set[str] = set()
    while profile and profile not in tried:
        tried.add(profile)
        reqs = profile_requirements(profile, agent_cfg, router_cfg)
        sel = _select_for_profile(
            agent_id, profile, reqs, router_cfg, capabilities,
            exclude_nodes, force_local,
        )
        if sel:
            return sel

        fallback = reqs.fallback_profile
        if not fallback or fallback in tried:
            break
        logger.warning(
            f"[select] agent={agent_id} profile={profile} not found → "
            f"fallback_profile={fallback}"
        )
        profile = fallback

    logger.error(
        f"[select] agent={agent_id} ALL methods failed → hard default"
    )
    return ModelSelection(
        runtime="ollama",
        name="qwen3:14b",
        model_type="llm",
        base_url="http://host.docker.internal:11434",
        provider="ollama",
        node="local",
        local=True,
        via_ncs=False,
        fallback_reason="all methods failed; hard default",
    )