Files
microdao-daarion/services/router/model_select.py
Apple c4b94a327d P2.2+P2.3: NATS offload node-worker + router offload integration
Node Worker (services/node-worker/):
- NATS subscriber for node.{NODE_ID}.llm.request / vision.request
- Canonical JobRequest/JobResponse envelope (Pydantic)
- Idempotency cache (TTL 10min) with inflight dedup
- Deadline enforcement (DEADLINE_EXCEEDED on expired jobs)
- Concurrency limiter (semaphore, returns busy)
- Ollama + Swapper vision providers

Router offload (services/router/offload_client.py):
- NATS req/reply with configurable retries
- Circuit breaker per node+type (opens after 3 failures within 60s; stays open for 120s)
- Concurrency semaphore for remote requests

Model selection (services/router/model_select.py):
- exclude_nodes parameter for circuit-broken nodes
- force_local flag for fallback re-selection
- Integrated circuit breaker state awareness

Router /infer pipeline:
- Remote offload path when NCS selects remote node
- Automatic fallback: exclude failed node → force_local re-select
- Deadline propagation from router to node-worker

Tests: 17 unit tests (idempotency, deadline, circuit breaker)
Docs: ops/offload_routing.md (subjects, envelope, verification)
Made-with: Cursor
2026-02-27 02:44:05 -08:00

316 lines
11 KiB
Python

"""NCS-first model selection for DAGI Router — multi-node aware.
Resolves an agent's LLM profile into a concrete model+provider using live
capabilities from Node Capabilities Services across all nodes.
Falls back to static router-config.yml when NCS is unavailable.
Scaling: works with 1 node or 150+. No static node lists.
"""
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
logger = logging.getLogger(name="model_select")
# Profiles whose provider is in this set are treated as cloud models
# (runtime "cloud", required type "cloud_llm") instead of node-served ones.
CLOUD_PROVIDERS = set(("anthropic", "deepseek", "grok", "mistral", "openai"))
@dataclass
class ProfileRequirements:
    """Requirements a resolved LLM profile places on model selection.

    Built by ``profile_requirements`` from ``selection_policies`` or
    ``llm_profiles`` and consumed by the selection functions in this module.
    Field order is part of the positional constructor signature.
    """

    # Name of the profile these requirements were derived from.
    profile_name: str
    required_type: str # llm | vision | code | stt | tts | cloud_llm
    # Model names (exact or substring) to try first, in order; a "*" entry
    # ends the prefer scan and lets best-by-size selection take over.
    prefer: List[str] = field(default_factory=list)
    # Provider taken from the policy or profile config (e.g. "ollama" or a
    # cloud provider name).
    provider: Optional[str] = None
    # Profile to try next when this one cannot be satisfied.
    fallback_profile: Optional[str] = None
    # Constraints copied verbatim from the selection policy; not interpreted
    # by the selection code visible in this module.
    constraints: Dict[str, Any] = field(default_factory=dict)
@dataclass
class ModelSelection:
    """Concrete model chosen for a request: what to call, via which runtime, and where.

    Produced by ``_make_selection`` (live NCS capabilities), ``static_fallback``
    (router config) and the hard default in ``select_model_for_agent``.
    Field order is part of the positional constructor signature.
    """

    runtime: str # ollama | swapper | llama_server | cloud
    name: str # model name as runtime knows it
    model_type: str # llm | vision | code | …
    base_url: str = ""
    provider: str = "" # cloud provider for cloud profiles; on the NCS path "ollama" for ollama/llama_server runtimes, else the runtime name
    node: str = "" # which node owns this model
    local: bool = True # is it on the current node?
    via_ncs: bool = False  # True when chosen from live NCS capabilities
    via_nats: bool = False  # set to `not local` on the NCS path; presumably means the request is offloaded over NATS
    fallback_reason: str = ""  # why a non-NCS path was taken; empty for NCS picks and configured cloud profiles
    caps_age_s: float = 0.0  # the model's `node_age_s` from NCS (capabilities age of the owning node)
# ── Profile resolution ────────────────────────────────────────────────────────
def resolve_effective_profile(
    agent_id: str,
    agent_cfg: Dict[str, Any],
    router_cfg: Dict[str, Any],
    request_model: Optional[str] = None,
) -> str:
    """Return the name of the LLM profile that should serve this request.

    If ``request_model`` is given, the first entry in ``router_cfg["llm_profiles"]``
    whose ``model`` equals it is chosen. Otherwise (or when nothing matches)
    the agent's ``default_llm`` is used, defaulting to ``"local_default_coder"``.
    ``agent_id`` is part of the interface but is not read here.
    """
    # Only consult the profile table when the caller pinned a model.
    candidates = router_cfg.get("llm_profiles", {}) if request_model else {}
    for name, profile_cfg in candidates.items():
        if profile_cfg.get("model") == request_model:
            return name
    return agent_cfg.get("default_llm", "local_default_coder")
def profile_requirements(
    profile_name: str,
    agent_cfg: Dict[str, Any],
    router_cfg: Dict[str, Any],
) -> ProfileRequirements:
    """Derive selection requirements for ``profile_name``.

    An explicit ``selection_policies[profile_name]`` entry takes precedence and
    is copied field by field (its ``fallback_profile`` falls back to the
    agent's ``fallback_llm``). Without a policy, requirements are inferred from
    ``llm_profiles[profile_name]``:

    * provider in ``CLOUD_PROVIDERS`` → ``cloud_llm``, empty prefer list,
      fallback to the agent's ``fallback_llm`` (default ``"local_default_coder"``);
    * otherwise the type is guessed from names — ``"vision"`` in the profile
      name or ``"vl"`` in the model → ``vision``; ``"coder"`` in the profile
      name or ``"code"`` in the model → ``code``; anything else → ``llm`` —
      and the configured model (if any) is the sole prefer entry.
    """
    profile_cfg = router_cfg.get("llm_profiles", {}).get(profile_name, {})
    policy = router_cfg.get("selection_policies", {}).get(profile_name, {})

    if policy:
        return ProfileRequirements(
            profile_name=profile_name,
            required_type=policy.get("required_type", "llm"),
            prefer=policy.get("prefer", []),
            provider=policy.get("provider"),
            fallback_profile=(
                policy.get("fallback_profile") or agent_cfg.get("fallback_llm")
            ),
            constraints=policy.get("constraints", {}),
        )

    provider = profile_cfg.get("provider", "ollama")
    if provider in CLOUD_PROVIDERS:
        return ProfileRequirements(
            profile_name=profile_name,
            required_type="cloud_llm",
            prefer=[],
            provider=provider,
            fallback_profile=agent_cfg.get("fallback_llm", "local_default_coder"),
        )

    model = profile_cfg.get("model", "")
    # model.lower() stays inside the conditions so it is only evaluated when
    # the profile-name test does not already decide the type.
    if "vision" in profile_name or "vl" in model.lower():
        guessed_type = "vision"
    elif "coder" in profile_name or "code" in model.lower():
        guessed_type = "code"
    else:
        guessed_type = "llm"

    return ProfileRequirements(
        profile_name=profile_name,
        required_type=guessed_type,
        prefer=[model] if model else [],
        provider=provider,
        fallback_profile=agent_cfg.get("fallback_llm"),
    )
# ── Multi-node model selection ────────────────────────────────────────────────
def select_best_model(
    reqs: ProfileRequirements,
    capabilities: Dict[str, Any],
    exclude_nodes: Optional[Set[str]] = None,
) -> Optional[ModelSelection]:
    """Choose the best served model from global (multi-node) capabilities.

    Selection order:
      1. ``reqs.prefer`` entries, in order. For each entry local candidates are
         tried before remote ones, and within each group an exact name match
         beats a substring match. A ``"*"`` entry stops the prefer scan; empty
         or non-string entries are ignored (an empty string would otherwise
         match every model).
      2. Best candidate by size (running models first), local before remote.
      3. ``None`` → caller should try the static fallback.

    Args:
        reqs: requirements; ``required_type`` and ``prefer`` are read.
        capabilities: NCS payload; ``served_models`` is a list of model dicts
            read here via ``type``, ``node``, ``local`` and ``name``.
        exclude_nodes: node ids to skip (e.g. circuit-broken nodes).

    Returns:
        A ``ModelSelection`` built by ``_make_selection``, or ``None``.
    """
    served = capabilities.get("served_models", [])
    if not served:
        return None
    exclude = exclude_nodes or set()

    # "code" and "llm" models are accepted interchangeably for each other.
    search_types = [reqs.required_type]
    if reqs.required_type == "code":
        search_types.append("llm")
    elif reqs.required_type == "llm":
        search_types.append("code")

    candidates = [
        m for m in served
        if m.get("type") in search_types and m.get("node", "") not in exclude
    ]
    if not candidates:
        return None

    local_candidates = [m for m in candidates if m.get("local", False)]
    remote_candidates = [m for m in candidates if not m.get("local", False)]

    for pref in reqs.prefer or []:
        if pref == "*":
            break
        if not isinstance(pref, str) or not pref:
            continue
        for pool in (local_candidates, remote_candidates):
            match = _find_preferred(pref, pool)
            if match is not None:
                return _make_selection(match, capabilities)

    for pool in (local_candidates, remote_candidates):
        if pool:
            return _make_selection(_pick_best(pool), capabilities)
    return None


def _find_preferred(
    pref: str, pool: List[Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Return the first model in ``pool`` named exactly ``pref``, else the first whose name contains it."""
    # A missing or null name is treated as "" so the substring test cannot raise.
    named = [(m, m.get("name") or "") for m in pool]
    for model, name in named:
        if name == pref:
            return model
    for model, name in named:
        if pref in name:
            return model
    return None
def _pick_best(candidates: List[Dict[str, Any]]) -> Dict[str, Any]:
running = [m for m in candidates if m.get("running")]
pool = running if running else candidates
return max(pool, key=lambda m: m.get("size_gb", 0))
def _make_selection(
    model: Dict[str, Any],
    capabilities: Dict[str, Any],
) -> ModelSelection:
    """Turn one NCS ``served_models`` entry into a ``ModelSelection``.

    The runtime defaults to ``"ollama"``. The provider is reported as
    ``"ollama"`` for both the ``ollama`` and ``llama_server`` runtimes and as the
    runtime name otherwise. A model without a ``node`` key is attributed to
    ``capabilities["local_node"]``. Non-local models are flagged ``via_nats``.
    """
    runtime = model.get("runtime", "ollama")
    on_this_node = model.get("local", False)
    ollama_compatible = runtime in ("ollama", "llama_server")

    return ModelSelection(
        runtime=runtime,
        name=model.get("name", ""),
        model_type=model.get("type", "llm"),
        base_url=model.get("base_url", ""),
        provider="ollama" if ollama_compatible else runtime,
        node=model.get("node", capabilities.get("local_node", "")),
        local=on_this_node,
        via_ncs=True,
        via_nats=not on_this_node,
        caps_age_s=model.get("node_age_s", 0.0),
    )
# ── Static fallback ──────────────────────────────────────────────────────────
def static_fallback(
    profile_name: str,
    router_cfg: Dict[str, Any],
) -> Optional[ModelSelection]:
    """Build a selection straight from the static ``llm_profiles`` config.

    Returns ``None`` when ``profile_name`` has no entry (or an empty one).
    Providers in ``CLOUD_PROVIDERS`` map to runtime ``"cloud"`` and type
    ``"cloud_llm"``; every other provider maps to runtime ``"ollama"`` and type
    ``"llm"``. The result is always marked local and non-NCS, and carries a
    ``fallback_reason`` explaining that static config was used.
    """
    profile_cfg = router_cfg.get("llm_profiles", {}).get(profile_name, {})
    if not profile_cfg:
        return None

    provider = profile_cfg.get("provider", "ollama")
    # NOTE(review): non-cloud providers (e.g. swapper, llama_server) are
    # reported as runtime "ollama" here — confirm this is intended.
    runtime, model_type = (
        ("cloud", "cloud_llm") if provider in CLOUD_PROVIDERS else ("ollama", "llm")
    )
    return ModelSelection(
        runtime=runtime,
        name=profile_cfg.get("model", ""),
        model_type=model_type,
        base_url=profile_cfg.get("base_url", ""),
        provider=provider,
        node="local",
        local=True,
        via_ncs=False,
        fallback_reason="NCS unavailable or no match; using static config",
    )
# ── Top-level orchestrator ────────────────────────────────────────────────────
async def select_model_for_agent(
    agent_id: str,
    agent_cfg: Dict[str, Any],
    router_cfg: Dict[str, Any],
    capabilities: Optional[Dict[str, Any]],
    request_model: Optional[str] = None,
    exclude_nodes: Optional[Set[str]] = None,
    force_local: bool = False,
) -> ModelSelection:
    """Full selection pipeline: resolve profile → NCS (multi-node) → static → hard default.

    The effective profile is tried first, then each profile along its
    ``fallback_profile`` chain. For every profile the order is: the cloud
    shortcut for ``cloud_llm`` profiles (unless ``force_local``), live NCS
    capabilities, then the static ``llm_profiles`` entry. If every profile in
    the chain fails, a hard-coded local Ollama default is returned, so this
    never returns ``None``.

    Args:
        agent_id: used for profile resolution and logging.
        agent_cfg: agent config (``default_llm`` / ``fallback_llm`` are read
            by the profile helpers).
        router_cfg: router config (``llm_profiles``, ``selection_policies``).
        capabilities: aggregated NCS payload, or ``None`` when NCS is unavailable.
        request_model: model explicitly requested by the caller; mapped to a
            profile by exact ``model`` match.
        exclude_nodes: skip these nodes (circuit-broken). Used on fallback
            re-selection; applied to every profile in the fallback chain.
        force_local: prefer local-only models (fallback after remote failure);
            applied to every profile in the fallback chain.
    """
    profile = resolve_effective_profile(
        agent_id, agent_cfg, router_cfg, request_model,
    )
    # Walk the fallback chain iteratively. The visited set stops self-referencing
    # or cyclic chains (A → B → A) instead of recursing until RecursionError.
    visited: Set[str] = set()
    while True:
        visited.add(profile)
        reqs = profile_requirements(profile, agent_cfg, router_cfg)
        sel = _select_for_profile(
            agent_id, profile, reqs, router_cfg, capabilities,
            exclude_nodes, force_local,
        )
        if sel is not None:
            return sel
        next_profile = reqs.fallback_profile
        if not next_profile or next_profile in visited:
            break
        logger.warning(
            f"[select] agent={agent_id} profile={profile} not found → "
            f"fallback_profile={next_profile}"
        )
        profile = next_profile

    logger.error(
        f"[select] agent={agent_id} ALL methods failed → hard default"
    )
    return ModelSelection(
        runtime="ollama",
        name="qwen3:14b",
        model_type="llm",
        base_url="http://host.docker.internal:11434",
        provider="ollama",
        node="local",
        local=True,
        via_ncs=False,
        fallback_reason="all methods failed; hard default",
    )


def _circuit_broken_nodes(required_type: str) -> Set[str]:
    """Return node ids whose offload circuit breaker is open for ``required_type``.

    ``offload_client`` is optional here: if it cannot be imported, no nodes are
    excluded. Errors raised by ``get_unavailable_nodes`` itself propagate.
    """
    try:
        from offload_client import get_unavailable_nodes
    except ImportError:
        return set()
    cb_nodes = get_unavailable_nodes(required_type)
    if cb_nodes:
        logger.info(f"[select] circuit-broken nodes for {required_type}: {cb_nodes}")
    return set(cb_nodes) if cb_nodes else set()


def _select_for_profile(
    agent_id: str,
    profile: str,
    reqs: ProfileRequirements,
    router_cfg: Dict[str, Any],
    capabilities: Optional[Dict[str, Any]],
    exclude_nodes: Optional[Set[str]],
    force_local: bool,
) -> Optional[ModelSelection]:
    """Try cloud shortcut → NCS → static config for one profile; ``None`` if all miss."""
    if reqs.required_type == "cloud_llm" and not force_local:
        static = static_fallback(profile, router_cfg)
        if static:
            # A configured cloud profile is the intended choice, not a fallback.
            static.fallback_reason = ""
            logger.info(
                f"[select] agent={agent_id} profile={profile} → cloud "
                f"provider={static.provider} model={static.name}"
            )
            return static

    excl = set(exclude_nodes) if exclude_nodes else set()
    excl |= _circuit_broken_nodes(reqs.required_type)

    if capabilities and capabilities.get("served_models"):
        sel = select_best_model(reqs, capabilities, exclude_nodes=excl)
        if force_local and sel and not sel.local:
            # A remote model won; re-select with every remote node excluded.
            remote_nodes = {
                m.get("node", "")
                for m in capabilities.get("served_models", [])
                if not m.get("local")
            }
            sel = select_best_model(
                reqs, capabilities, exclude_nodes=excl | remote_nodes,
            )
        if sel:
            logger.info(
                f"[select] agent={agent_id} profile={profile} → "
                f"{'LOCAL' if sel.local else 'REMOTE'} "
                f"node={sel.node} runtime={sel.runtime} "
                f"model={sel.name} caps_age={sel.caps_age_s}s"
                f"{' (force_local)' if force_local else ''}"
                f"{' (excluded: ' + ','.join(sorted(excl)) + ')' if excl else ''}"
            )
            return sel
        logger.warning(
            f"[select] agent={agent_id} profile={profile} → no match "
            f"for type={reqs.required_type} across {capabilities.get('node_count', 0)} node(s)"
        )

    static = static_fallback(profile, router_cfg)
    if static:
        logger.info(
            f"[select] agent={agent_id} profile={profile} → static "
            f"provider={static.provider} model={static.name}"
        )
        return static
    return None