"""NCS-first model selection for DAGI Router — multi-node aware. Resolves an agent's LLM profile into a concrete model+provider using live capabilities from Node Capabilities Services across all nodes. Falls back to static router-config.yml when NCS is unavailable. Scaling: works with 1 node or 150+. No static node lists. """ import logging import time from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Set logger = logging.getLogger("model_select") CLOUD_PROVIDERS = {"deepseek", "mistral", "grok", "openai", "anthropic"} @dataclass class ProfileRequirements: profile_name: str required_type: str # llm | vision | code | stt | tts | cloud_llm prefer: List[str] = field(default_factory=list) provider: Optional[str] = None fallback_profile: Optional[str] = None constraints: Dict[str, Any] = field(default_factory=dict) @dataclass class ModelSelection: runtime: str # ollama | swapper | llama_server | cloud name: str # model name as runtime knows it model_type: str # llm | vision | code | … base_url: str = "" provider: str = "" # cloud provider name if applicable node: str = "" # which node owns this model local: bool = True # is it on the current node? via_ncs: bool = False via_nats: bool = False fallback_reason: str = "" caps_age_s: float = 0.0 # ── Profile resolution ──────────────────────────────────────────────────────── def resolve_effective_profile( agent_id: str, agent_cfg: Dict[str, Any], router_cfg: Dict[str, Any], request_model: Optional[str] = None, ) -> str: if request_model: llm_profiles = router_cfg.get("llm_profiles", {}) for pname, pcfg in llm_profiles.items(): if pcfg.get("model") == request_model: return pname return agent_cfg.get("default_llm", "local_default_coder") def profile_requirements( profile_name: str, agent_cfg: Dict[str, Any], router_cfg: Dict[str, Any], ) -> ProfileRequirements: llm_profiles = router_cfg.get("llm_profiles", {}) selection_policies = router_cfg.get("selection_policies", {}) profile_cfg = llm_profiles.get(profile_name, {}) policy = selection_policies.get(profile_name, {}) if policy: return ProfileRequirements( profile_name=profile_name, required_type=policy.get("required_type", "llm"), prefer=policy.get("prefer", []), provider=policy.get("provider"), fallback_profile=policy.get("fallback_profile") or agent_cfg.get("fallback_llm"), constraints=policy.get("constraints", {}), ) provider = profile_cfg.get("provider", "ollama") model = profile_cfg.get("model", "") if provider in CLOUD_PROVIDERS: return ProfileRequirements( profile_name=profile_name, required_type="cloud_llm", prefer=[], provider=provider, fallback_profile=agent_cfg.get("fallback_llm", "local_default_coder"), ) req_type = "llm" if "vision" in profile_name or "vl" in model.lower(): req_type = "vision" elif "coder" in profile_name or "code" in model.lower(): req_type = "code" return ProfileRequirements( profile_name=profile_name, required_type=req_type, prefer=[model] if model else [], provider=provider, fallback_profile=agent_cfg.get("fallback_llm"), ) # ── Multi-node model selection ──────────────────────────────────────────────── def select_best_model( reqs: ProfileRequirements, capabilities: Dict[str, Any], exclude_nodes: Optional[Set[str]] = None, ) -> Optional[ModelSelection]: """Choose the best served model from global (multi-node) capabilities. Selection order: 1. Prefer list matches (local first, then remote) 2. Best candidate by size (local first, then remote) 3. None → caller should try static fallback exclude_nodes: set of node_ids to skip (e.g. circuit-broken nodes). """ served = capabilities.get("served_models", []) if not served: return None exclude = exclude_nodes or set() search_types = [reqs.required_type] if reqs.required_type == "code": search_types.append("llm") if reqs.required_type == "llm": search_types.append("code") candidates = [ m for m in served if m.get("type") in search_types and m.get("node", "") not in exclude ] if not candidates: return None local_candidates = [m for m in candidates if m.get("local", False)] remote_candidates = [m for m in candidates if not m.get("local", False)] prefer = reqs.prefer if reqs.prefer else [] for pref in prefer: if pref == "*": break for m in local_candidates: if pref == m.get("name") or pref in m.get("name", ""): return _make_selection(m, capabilities) for m in remote_candidates: if pref == m.get("name") or pref in m.get("name", ""): return _make_selection(m, capabilities) if local_candidates: return _make_selection(_pick_best(local_candidates), capabilities) if remote_candidates: return _make_selection(_pick_best(remote_candidates), capabilities) return None def _pick_best(candidates: List[Dict[str, Any]]) -> Dict[str, Any]: running = [m for m in candidates if m.get("running")] pool = running if running else candidates return max(pool, key=lambda m: m.get("size_gb", 0)) def _make_selection( model: Dict[str, Any], capabilities: Dict[str, Any], ) -> ModelSelection: runtime = model.get("runtime", "ollama") is_local = model.get("local", False) node = model.get("node", capabilities.get("local_node", "")) base_url = model.get("base_url", "") return ModelSelection( runtime=runtime, name=model.get("name", ""), model_type=model.get("type", "llm"), base_url=base_url, provider="ollama" if runtime in ("ollama", "llama_server") else runtime, node=node, local=is_local, via_ncs=True, via_nats=not is_local, caps_age_s=model.get("node_age_s", 0.0), ) # ── Static fallback ────────────────────────────────────────────────────────── def static_fallback( profile_name: str, router_cfg: Dict[str, Any], ) -> Optional[ModelSelection]: llm_profiles = router_cfg.get("llm_profiles", {}) cfg = llm_profiles.get(profile_name, {}) if not cfg: return None provider = cfg.get("provider", "ollama") return ModelSelection( runtime="cloud" if provider in CLOUD_PROVIDERS else "ollama", name=cfg.get("model", ""), model_type="cloud_llm" if provider in CLOUD_PROVIDERS else "llm", base_url=cfg.get("base_url", ""), provider=provider, node="local", local=True, via_ncs=False, fallback_reason="NCS unavailable or no match; using static config", ) # ── Top-level orchestrator ──────────────────────────────────────────────────── async def select_model_for_agent( agent_id: str, agent_cfg: Dict[str, Any], router_cfg: Dict[str, Any], capabilities: Optional[Dict[str, Any]], request_model: Optional[str] = None, exclude_nodes: Optional[Set[str]] = None, force_local: bool = False, ) -> ModelSelection: """Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default. exclude_nodes: skip these nodes (circuit-broken). Used on fallback re-selection. force_local: prefer local-only models (fallback after remote failure). """ profile = resolve_effective_profile( agent_id, agent_cfg, router_cfg, request_model, ) reqs = profile_requirements(profile, agent_cfg, router_cfg) if reqs.required_type == "cloud_llm" and not force_local: static = static_fallback(profile, router_cfg) if static: static.fallback_reason = "" logger.info( f"[select] agent={agent_id} profile={profile} → cloud " f"provider={static.provider} model={static.name}" ) return static excl = set(exclude_nodes) if exclude_nodes else set() try: from offload_client import get_unavailable_nodes cb_nodes = get_unavailable_nodes(reqs.required_type) excl |= cb_nodes if cb_nodes: logger.info(f"[select] circuit-broken nodes for {reqs.required_type}: {cb_nodes}") except ImportError: pass if capabilities and capabilities.get("served_models"): sel = select_best_model(reqs, capabilities, exclude_nodes=excl) if force_local and sel and not sel.local: sel = select_best_model( reqs, capabilities, exclude_nodes=excl | {n.get("node", "") for n in capabilities.get("served_models", []) if not n.get("local")}, ) if sel: logger.info( f"[select] agent={agent_id} profile={profile} → " f"{'LOCAL' if sel.local else 'REMOTE'} " f"node={sel.node} runtime={sel.runtime} " f"model={sel.name} caps_age={sel.caps_age_s}s" f"{' (force_local)' if force_local else ''}" f"{' (excluded: ' + ','.join(excl) + ')' if excl else ''}" ) return sel logger.warning( f"[select] agent={agent_id} profile={profile} → no match " f"for type={reqs.required_type} across {capabilities.get('node_count', 0)} node(s)" ) static = static_fallback(profile, router_cfg) if static: logger.info( f"[select] agent={agent_id} profile={profile} → static " f"provider={static.provider} model={static.name}" ) return static if reqs.fallback_profile and reqs.fallback_profile != profile: logger.warning( f"[select] agent={agent_id} profile={profile} not found → " f"fallback_profile={reqs.fallback_profile}" ) return await select_model_for_agent( agent_id, agent_cfg, router_cfg, capabilities, ) logger.error( f"[select] agent={agent_id} ALL methods failed → hard default" ) return ModelSelection( runtime="ollama", name="qwen3:14b", model_type="llm", base_url="http://host.docker.internal:11434", provider="ollama", node="local", local=True, via_ncs=False, fallback_reason="all methods failed; hard default", )