"""Tests for P3.1 scoring-based model selection."""
import sys
import os

# Make the router service importable without installing it as a package.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "services", "router"))

from model_select import (  # noqa: E402  (must follow the sys.path tweak above)
    score_candidate,
    select_best_model,
    ProfileRequirements,
    ModelSelection,
    LOCAL_THRESHOLD_MS,
)
# NOTE(review): ModelSelection and LOCAL_THRESHOLD_MS are not referenced by the
# tests visible here; confirm whether they are intentionally import-checked only.


def _caps(served, node_load=None, runtime_load=None, nodes=None):
    """Build a minimal capabilities dict accepted by the selection functions.

    Args:
        served: List of served-model dicts (typically built with ``_model``).
        node_load: Load info the tests treat as the local node's
            (e.g. ``estimated_wait_ms``, ``rtt_ms_to_hub``); defaults to ``{}``.
        runtime_load: Runtime load entries; defaults to ``[]``.
        nodes: Mapping of node id -> node info (with its own ``node_load``);
            defaults to ``{}``.

    Returns:
        A dict with keys ``served_models``, ``node_load``, ``runtime_load``
        and ``nodes``.
    """
    return {
        "served_models": served,
        "node_load": node_load or {},
        "runtime_load": runtime_load or [],
        "nodes": nodes or {},
    }


def _model(name, typ="llm", local=True, node="n1", runtime="ollama", **kw):
    """Build a served-model dict.

    Any extra keyword arguments are merged in last, so they can add fields
    or override the defaults (including ``base_url``).
    """
    return {
        "name": name,
        "type": typ,
        "local": local,
        "node": node,
        "runtime": runtime,
        "base_url": "http://x",
        **kw,
    }


def _reqs(typ="llm", prefer=None):
    """Build ``ProfileRequirements`` for a profile called ``"test"``.

    ``typ`` is the required model type and ``prefer`` the preferred model
    names (defaults to an empty list). The positional argument order
    presumably matches ``(name, required_type, prefer)``; confirm against
    ``model_select``.
    """
    return ProfileRequirements("test", typ, prefer or [])


# ── 1) local wins when scores close ────────────────────────────────────────
def test_local_wins_when_scores_close():
    caps = _caps(
        served=[
            _model("qwen3:14b", local=True, node="n1"),
            _model("qwen3:14b", local=False, node="n2"),
        ],
        node_load={"estimated_wait_ms": 0, "rtt_ms_to_hub": None},
    )
    sel = select_best_model(_reqs(), caps)
    assert sel is not None
    assert sel.local is True
    assert sel.node == "n1"


# ── 2) remote wins when local wait is high ─────────────────────────────────
def test_remote_wins_when_local_wait_high():
    caps = _caps(
        served=[
            _model("qwen3:14b", local=True, node="n1"),
            _model("qwen3:14b", local=False, node="n2"),
        ],
        node_load={"estimated_wait_ms": 5000, "rtt_ms_to_hub": None},
        nodes={
            "n2": {
                "node_id": "n2",
                "node_load": {"estimated_wait_ms": 0, "rtt_ms_to_hub": 50},
            },
        },
    )
    sel = select_best_model(_reqs(), caps)
    assert sel is not None
    assert sel.local is False
    assert sel.node == "n2"


# ── 3) exclude_nodes works ─────────────────────────────────────────────────
def test_exclude_nodes_works():
    caps = _caps(served=[
        _model("qwen3:14b", local=False, node="n2"),
        _model("qwen3:14b", local=False, node="n3"),
    ])
    sel = select_best_model(_reqs(), caps, exclude_nodes={"n2"})
    assert sel is not None
    assert sel.node == "n3"

# ── 4) breaker open → node excluded (via exclude_nodes) ───────────────────
def test_breaker_excludes_node():
    caps = _caps(served=[
        _model("qwen3:14b", local=False, node="broken"),
        _model("qwen3:14b", local=True, node="n1"),
    ])
    sel = select_best_model(_reqs(), caps, exclude_nodes={"broken"})
    assert sel is not None
    assert sel.node == "n1"


# ── 5) required_type filter ────────────────────────────────────────────────
def test_required_type_filter():
    caps = _caps(served=[
        _model("qwen3:14b", typ="llm"),
        _model("llava:13b", typ="vision"),
    ])
    sel = select_best_model(_reqs(typ="vision"), caps)
    assert sel is not None
    assert sel.name == "llava:13b"


# ── 6) prefer list filter ─────────────────────────────────────────────────
def test_prefer_list_selects_preferred():
    caps = _caps(served=[
        _model("qwen3:14b"),
        _model("qwen3.5:35b"),
    ])
    sel = select_best_model(_reqs(prefer=["qwen3.5:35b"]), caps)
    assert sel is not None
    assert sel.name == "qwen3.5:35b"


# ── 7) score formula — prefer bonus lowers score ──────────────────────────
def test_prefer_bonus_lowers_score():
    m1 = _model("qwen3:14b")
    m2 = _model("qwen3.5:35b")
    caps = _caps(served=[m1, m2])
    s1 = score_candidate(m1, caps, prefer=["qwen3:14b"])
    s2 = score_candidate(m2, caps, prefer=["qwen3:14b"])
    assert s1 < s2


# ── 8) score formula — cross_penalty for remote ──────────────────────────
def test_cross_penalty_for_remote():
    local = _model("m", local=True)
    remote = _model("m", local=False, node="r1")
    caps = _caps(served=[local, remote])
    sl = score_candidate(local, caps, prefer=[])
    sr = score_candidate(remote, caps, prefer=[], rtt_hint_ms=50)
    assert sr > sl


# ── 9) score formula — wait increases score ──────────────────────────────
def test_wait_increases_score():
    m = _model("m", local=True)
    caps_idle = _caps(served=[m], node_load={"estimated_wait_ms": 0})
    caps_busy = _caps(served=[m], node_load={"estimated_wait_ms": 3000})
    s_idle = score_candidate(m, caps_idle, prefer=[])
    s_busy = score_candidate(m, caps_busy, prefer=[])
    assert s_busy > s_idle


# ── 10) no candidates → None ─────────────────────────────────────────────
def test_no_candidates_returns_none():
    caps = _caps(served=[_model("m", typ="stt")])
    sel = select_best_model(_reqs(typ="llm"), caps)
    assert sel is None


# ── 11) local threshold: local wins within threshold even if remote lower ─
def test_local_threshold():
    caps = _caps(
        served=[
            _model("qwen3:14b", local=True, node="n1"),
            _model("qwen3:14b", local=False, node="n2"),
        ],
        node_load={"estimated_wait_ms": 100},
        nodes={
            "n2": {
                "node_id": "n2",
                "node_load": {"estimated_wait_ms": 0, "rtt_ms_to_hub": 10},
            },
        },
    )
    sel = select_best_model(_reqs(), caps)
    # Guard first so a missing selection fails as an assertion rather than
    # an AttributeError on ``None.local``.
    assert sel is not None
    assert sel.local is True


# ── 12) code type cross-filters with llm ─────────────────────────────────
def test_code_type_finds_llm_models():
    caps = _caps(served=[_model("qwen3:14b", typ="llm")])
    sel = select_best_model(_reqs(typ="code"), caps)
    assert sel is not None
    assert sel.name == "qwen3:14b"