P1: NCS-first model selection + NATS capabilities + Grok 4.1
Router model selection: - New model_select.py: resolve_effective_profile → profile_requirements → select_best_model pipeline. NCS-first with graceful static fallback. - selection_policies in router-config.node2.yml define prefer order per profile without hardcoding models (e.g. local_default_coder prefers qwen3:14b then qwen3.5:35b-a3b). - Cloud profiles (cloud_grok, cloud_deepseek) skip NCS; on cloud failure use fallback_profile via NCS for local selection. - Structured logs: selected_profile, required_type, runtime, model, caps_age_s, fallback_reason on every infer request. Grok model fix: - grok-2-1212 no longer exists on xAI API → updated to grok-4-1-fast-reasoning across all 3 hardcoded locations in main.py and router-config.node2.yml. NCS NATS request/reply: - node-capabilities subscribes to node.noda2.capabilities.get (NATS request/reply). Enabled via ENABLE_NATS_CAPS=true in compose. - NODA1 router can query NODA2 capabilities over NATS leafnode without HTTP connectivity. Verified: - NCS: 14 served models from Ollama+Swapper+llama-server - NATS: request/reply returns full capabilities JSON - Sofiia: cloud_grok → grok-4-1-fast-reasoning (tested, 200 OK) - Helion: NCS → qwen3:14b via Ollama (caps_age=23.7s cache hit) - Router health: ok Made-with: Cursor
This commit is contained in:
280
services/router/model_select.py
Normal file
280
services/router/model_select.py
Normal file
@@ -0,0 +1,280 @@
|
||||
"""NCS-first model selection for DAGI Router.
|
||||
|
||||
Resolves an agent's LLM profile into a concrete model+provider using live
|
||||
capabilities from the Node Capabilities Service (NCS). Falls back to static
|
||||
router-config.yml when NCS is unavailable.
|
||||
"""
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger("model_select")
|
||||
|
||||
CLOUD_PROVIDERS = {"deepseek", "mistral", "grok", "openai", "anthropic"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProfileRequirements:
|
||||
profile_name: str
|
||||
required_type: str # llm | vision | code | stt | tts | cloud_llm
|
||||
prefer: List[str] = field(default_factory=list)
|
||||
provider: Optional[str] = None
|
||||
fallback_profile: Optional[str] = None
|
||||
constraints: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelSelection:
|
||||
runtime: str # ollama | swapper | llama_server | cloud
|
||||
name: str # model name as runtime knows it
|
||||
model_type: str # llm | vision | code | …
|
||||
base_url: str = ""
|
||||
provider: str = "" # cloud provider name if applicable
|
||||
via_ncs: bool = False
|
||||
fallback_reason: str = ""
|
||||
caps_age_s: float = 0.0
|
||||
|
||||
|
||||
# ── Profile resolution ────────────────────────────────────────────────────────
|
||||
|
||||
def resolve_effective_profile(
|
||||
agent_id: str,
|
||||
agent_cfg: Dict[str, Any],
|
||||
router_cfg: Dict[str, Any],
|
||||
request_model: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Determine the effective LLM profile name for a request."""
|
||||
if request_model:
|
||||
llm_profiles = router_cfg.get("llm_profiles", {})
|
||||
for pname, pcfg in llm_profiles.items():
|
||||
if pcfg.get("model") == request_model:
|
||||
return pname
|
||||
|
||||
return agent_cfg.get("default_llm", "local_default_coder")
|
||||
|
||||
|
||||
def profile_requirements(
|
||||
profile_name: str,
|
||||
agent_cfg: Dict[str, Any],
|
||||
router_cfg: Dict[str, Any],
|
||||
) -> ProfileRequirements:
|
||||
"""Build selection requirements from a profile definition.
|
||||
|
||||
If the profile has `selection_policy` in config, use it directly.
|
||||
Otherwise, infer from the legacy `provider`/`model` fields.
|
||||
"""
|
||||
llm_profiles = router_cfg.get("llm_profiles", {})
|
||||
selection_policies = router_cfg.get("selection_policies", {})
|
||||
profile_cfg = llm_profiles.get(profile_name, {})
|
||||
|
||||
policy = selection_policies.get(profile_name, {})
|
||||
if policy:
|
||||
return ProfileRequirements(
|
||||
profile_name=profile_name,
|
||||
required_type=policy.get("required_type", "llm"),
|
||||
prefer=policy.get("prefer", []),
|
||||
provider=policy.get("provider"),
|
||||
fallback_profile=policy.get("fallback_profile")
|
||||
or agent_cfg.get("fallback_llm"),
|
||||
constraints=policy.get("constraints", {}),
|
||||
)
|
||||
|
||||
provider = profile_cfg.get("provider", "ollama")
|
||||
model = profile_cfg.get("model", "")
|
||||
|
||||
if provider in CLOUD_PROVIDERS:
|
||||
return ProfileRequirements(
|
||||
profile_name=profile_name,
|
||||
required_type="cloud_llm",
|
||||
prefer=[],
|
||||
provider=provider,
|
||||
fallback_profile=agent_cfg.get("fallback_llm", "local_default_coder"),
|
||||
)
|
||||
|
||||
req_type = "llm"
|
||||
if "vision" in profile_name or "vl" in model.lower():
|
||||
req_type = "vision"
|
||||
elif "coder" in profile_name or "code" in model.lower():
|
||||
req_type = "code"
|
||||
|
||||
return ProfileRequirements(
|
||||
profile_name=profile_name,
|
||||
required_type=req_type,
|
||||
prefer=[model] if model else [],
|
||||
provider=provider,
|
||||
fallback_profile=agent_cfg.get("fallback_llm"),
|
||||
)
|
||||
|
||||
|
||||
# ── NCS-based selection ───────────────────────────────────────────────────────
|
||||
|
||||
def select_best_model(
|
||||
reqs: ProfileRequirements,
|
||||
capabilities: Dict[str, Any],
|
||||
) -> Optional[ModelSelection]:
|
||||
"""Choose the best served model from NCS capabilities.
|
||||
|
||||
Returns None if no suitable model found (caller should try static fallback).
|
||||
"""
|
||||
served = capabilities.get("served_models", [])
|
||||
if not served:
|
||||
return None
|
||||
|
||||
caps_age = time.time() - capabilities.get("_fetch_ts", time.time())
|
||||
|
||||
search_types = [reqs.required_type]
|
||||
if reqs.required_type == "code":
|
||||
search_types.append("llm")
|
||||
if reqs.required_type == "llm":
|
||||
search_types.append("code")
|
||||
|
||||
candidates = [m for m in served if m.get("type") in search_types]
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
prefer = reqs.prefer if reqs.prefer else []
|
||||
|
||||
for pref in prefer:
|
||||
if pref == "*":
|
||||
break
|
||||
for m in candidates:
|
||||
if pref == m.get("name") or pref in m.get("name", ""):
|
||||
return _make_selection(m, capabilities, caps_age, reqs)
|
||||
|
||||
if candidates:
|
||||
best = _pick_best_candidate(candidates)
|
||||
return _make_selection(best, capabilities, caps_age, reqs)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _pick_best_candidate(candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Prefer running models, then largest by size_gb."""
|
||||
running = [m for m in candidates if m.get("running")]
|
||||
pool = running if running else candidates
|
||||
return max(pool, key=lambda m: m.get("size_gb", 0))
|
||||
|
||||
|
||||
def _make_selection(
|
||||
model: Dict[str, Any],
|
||||
capabilities: Dict[str, Any],
|
||||
caps_age: float,
|
||||
reqs: ProfileRequirements,
|
||||
) -> ModelSelection:
|
||||
runtime = model.get("runtime", "ollama")
|
||||
base_url = model.get("base_url", "")
|
||||
if not base_url:
|
||||
runtimes = capabilities.get("runtimes", {})
|
||||
rt = runtimes.get(runtime, {})
|
||||
base_url = rt.get("base_url", "")
|
||||
|
||||
return ModelSelection(
|
||||
runtime=runtime,
|
||||
name=model.get("name", ""),
|
||||
model_type=model.get("type", "llm"),
|
||||
base_url=base_url,
|
||||
provider="ollama" if runtime in ("ollama", "llama_server") else runtime,
|
||||
via_ncs=True,
|
||||
caps_age_s=round(caps_age, 1),
|
||||
)
|
||||
|
||||
|
||||
# ── Static fallback (from router-config profiles) ────────────────────────────
|
||||
|
||||
def static_fallback(
|
||||
profile_name: str,
|
||||
router_cfg: Dict[str, Any],
|
||||
) -> Optional[ModelSelection]:
|
||||
"""Build a ModelSelection from the static llm_profiles config."""
|
||||
llm_profiles = router_cfg.get("llm_profiles", {})
|
||||
cfg = llm_profiles.get(profile_name, {})
|
||||
if not cfg:
|
||||
return None
|
||||
|
||||
provider = cfg.get("provider", "ollama")
|
||||
|
||||
return ModelSelection(
|
||||
runtime="cloud" if provider in CLOUD_PROVIDERS else "ollama",
|
||||
name=cfg.get("model", ""),
|
||||
model_type="cloud_llm" if provider in CLOUD_PROVIDERS else "llm",
|
||||
base_url=cfg.get("base_url", ""),
|
||||
provider=provider,
|
||||
via_ncs=False,
|
||||
fallback_reason="NCS unavailable or no match; using static config",
|
||||
)
|
||||
|
||||
|
||||
# ── Top-level orchestrator ────────────────────────────────────────────────────
|
||||
|
||||
async def select_model_for_agent(
|
||||
agent_id: str,
|
||||
agent_cfg: Dict[str, Any],
|
||||
router_cfg: Dict[str, Any],
|
||||
capabilities: Optional[Dict[str, Any]],
|
||||
request_model: Optional[str] = None,
|
||||
) -> ModelSelection:
|
||||
"""Full selection pipeline: resolve profile → NCS → static fallback.
|
||||
|
||||
This is the single entry point the router calls for each request.
|
||||
"""
|
||||
profile = resolve_effective_profile(
|
||||
agent_id, agent_cfg, router_cfg, request_model,
|
||||
)
|
||||
|
||||
reqs = profile_requirements(profile, agent_cfg, router_cfg)
|
||||
|
||||
if reqs.required_type == "cloud_llm":
|
||||
static = static_fallback(profile, router_cfg)
|
||||
if static:
|
||||
static.fallback_reason = ""
|
||||
logger.info(
|
||||
f"[select] agent={agent_id} profile={profile} → cloud "
|
||||
f"provider={static.provider} model={static.name}"
|
||||
)
|
||||
return static
|
||||
|
||||
if capabilities and capabilities.get("served_models"):
|
||||
sel = select_best_model(reqs, capabilities)
|
||||
if sel:
|
||||
logger.info(
|
||||
f"[select] agent={agent_id} profile={profile} → NCS "
|
||||
f"runtime={sel.runtime} model={sel.name} caps_age={sel.caps_age_s}s"
|
||||
)
|
||||
return sel
|
||||
logger.warning(
|
||||
f"[select] agent={agent_id} profile={profile} → NCS had no match "
|
||||
f"for type={reqs.required_type}; trying static"
|
||||
)
|
||||
|
||||
static = static_fallback(profile, router_cfg)
|
||||
if static:
|
||||
logger.info(
|
||||
f"[select] agent={agent_id} profile={profile} → static "
|
||||
f"provider={static.provider} model={static.name} "
|
||||
f"reason={static.fallback_reason}"
|
||||
)
|
||||
return static
|
||||
|
||||
if reqs.fallback_profile and reqs.fallback_profile != profile:
|
||||
logger.warning(
|
||||
f"[select] agent={agent_id} profile={profile} not found → "
|
||||
f"trying fallback_profile={reqs.fallback_profile}"
|
||||
)
|
||||
return await select_model_for_agent(
|
||||
agent_id, agent_cfg, router_cfg, capabilities,
|
||||
)
|
||||
|
||||
logger.error(
|
||||
f"[select] agent={agent_id} profile={profile} → ALL selection "
|
||||
f"methods failed. Using hard default qwen3:14b"
|
||||
)
|
||||
return ModelSelection(
|
||||
runtime="ollama",
|
||||
name="qwen3:14b",
|
||||
model_type="llm",
|
||||
base_url="http://host.docker.internal:11434",
|
||||
provider="ollama",
|
||||
via_ncs=False,
|
||||
fallback_reason="all methods failed; hard default",
|
||||
)
|
||||
Reference in New Issue
Block a user