Router model selection: - New model_select.py: resolve_effective_profile → profile_requirements → select_best_model pipeline. NCS-first with graceful static fallback. - selection_policies in router-config.node2.yml define prefer order per profile without hardcoding models (e.g. local_default_coder prefers qwen3:14b then qwen3.5:35b-a3b). - Cloud profiles (cloud_grok, cloud_deepseek) skip NCS; on cloud failure use fallback_profile via NCS for local selection. - Structured logs: selected_profile, required_type, runtime, model, caps_age_s, fallback_reason on every infer request. Grok model fix: - grok-2-1212 no longer exists on xAI API → updated to grok-4-1-fast-reasoning across all 3 hardcoded locations in main.py and router-config.node2.yml. NCS NATS request/reply: - node-capabilities subscribes to node.noda2.capabilities.get (NATS request/reply). Enabled via ENABLE_NATS_CAPS=true in compose. - NODA1 router can query NODA2 capabilities over NATS leafnode without HTTP connectivity. Verified: - NCS: 14 served models from Ollama+Swapper+llama-server - NATS: request/reply returns full capabilities JSON - Sofiia: cloud_grok → grok-4-1-fast-reasoning (tested, 200 OK) - Helion: NCS → qwen3:14b via Ollama (caps_age=23.7s cache hit) - Router health: ok Made-with: Cursor
295 lines · 10 KiB · Python
"""Node Capabilities Service — exposes live model inventory for router decisions."""
|
|
import os
|
|
import time
|
|
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import FastAPI
|
|
from fastapi.responses import JSONResponse
|
|
import httpx
|
|
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("node-capabilities")


app = FastAPI(title="Node Capabilities Service", version="1.0.0")


# Node identity reported in every capabilities payload (also used to build
# the NATS subject further below).
NODE_ID = os.getenv("NODE_ID", "noda2")
# Backend runtime endpoints; defaults match the docker-compose topology.
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
SWAPPER_URL = os.getenv("SWAPPER_URL", "http://swapper-service:8890")
# Optional llama-server runtime; empty string disables its collector.
LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "")

# In-process cache of the assembled capabilities document and the time it
# was built — assumes a single worker process (no cross-process sharing).
_cache: Dict[str, Any] = {}
_cache_ts: float = 0
# Cache lifetime in seconds; keeps frequent router queries off the runtimes.
CACHE_TTL = int(os.getenv("CACHE_TTL_SEC", "15"))
|
|
|
|
|
|
def _classify_model(name: str) -> str:
|
|
nl = name.lower()
|
|
if any(k in nl for k in ("vl", "vision", "llava", "minicpm-v", "clip")):
|
|
return "vision"
|
|
if any(k in nl for k in ("coder", "starcoder", "codellama", "code")):
|
|
return "code"
|
|
if any(k in nl for k in ("embed", "bge", "minilm", "e5-")):
|
|
return "embedding"
|
|
if any(k in nl for k in ("whisper", "stt")):
|
|
return "stt"
|
|
if any(k in nl for k in ("kokoro", "tts", "bark", "coqui", "xtts")):
|
|
return "tts"
|
|
if any(k in nl for k in ("flux", "sdxl", "stable-diffusion", "ltx")):
|
|
return "image_gen"
|
|
return "llm"
|
|
|
|
|
|
async def _collect_ollama() -> Dict[str, Any]:
    """Query the Ollama daemon for installed and currently-loaded models.

    Returns a runtime dict with ``base_url``, ``status`` ("ok", "unknown",
    or "error: ..."), and a ``models`` list where each entry carries
    name/size/type/modified plus a ``running`` flag derived from /api/ps.
    All network failures are swallowed and surfaced via ``status``.
    """
    runtime: Dict[str, Any] = {"base_url": OLLAMA_BASE_URL, "status": "unknown", "models": []}
    try:
        async with httpx.AsyncClient(timeout=5) as c:
            r = await c.get(f"{OLLAMA_BASE_URL}/api/tags")
            if r.status_code == 200:
                data = r.json()
                runtime["status"] = "ok"
                for m in data.get("models", []):
                    name = m.get("name", "")
                    size = m.get("size", 0)
                    runtime["models"].append({
                        "name": name,
                        "size_bytes": size,
                        "size_gb": round(size / 1e9, 1),
                        "type": _classify_model(name),
                        "modified": m.get("modified_at", "")[:10],
                        # Fix: default to False so the key always exists even
                        # when the /api/ps probe below fails or returns non-200.
                        "running": False,
                    })
            ps = await c.get(f"{OLLAMA_BASE_URL}/api/ps")
            if ps.status_code == 200:
                running = ps.json().get("models", [])
                running_names = {m.get("name", "") for m in running}
                for model in runtime["models"]:
                    model["running"] = model["name"] in running_names
    except Exception as e:
        runtime["status"] = f"error: {e}"
        logger.warning("Ollama collector failed: %s", e)
    return runtime
|
|
|
|
|
|
async def _collect_swapper() -> Dict[str, Any]:
    """Probe the swapper service for health, active model, and model lists."""
    runtime: Dict[str, Any] = {
        "base_url": SWAPPER_URL,
        "status": "unknown",
        "models": [],
        "vision_models": [],
        "active_model": None,
    }

    def _entry(raw: Dict[str, Any], forced_type: Optional[str] = None) -> Dict[str, Any]:
        # Normalize one swapper model record; forced_type overrides "type".
        return {
            "name": raw.get("name", ""),
            "type": forced_type if forced_type is not None else raw.get("type", "llm"),
            "size_gb": raw.get("size_gb", 0),
            "status": raw.get("status", "unknown"),
        }

    try:
        async with httpx.AsyncClient(timeout=5) as client:
            health = await client.get(f"{SWAPPER_URL}/health")
            if health.status_code == 200:
                body = health.json()
                runtime["status"] = body.get("status", "ok")
                runtime["active_model"] = body.get("active_model")

            models_resp = await client.get(f"{SWAPPER_URL}/models")
            if models_resp.status_code == 200:
                runtime["models"] = [_entry(m) for m in models_resp.json().get("models", [])]

            vision_resp = await client.get(f"{SWAPPER_URL}/vision/models")
            if vision_resp.status_code == 200:
                runtime["vision_models"] = [
                    _entry(m, "vision") for m in vision_resp.json().get("models", [])
                ]
    except Exception as e:
        runtime["status"] = f"error: {e}"
        logger.warning(f"Swapper collector failed: {e}")
    return runtime
|
|
|
|
|
|
async def _collect_llama_server() -> Optional[Dict[str, Any]]:
    """Query an OpenAI-compatible llama-server for its model list.

    Returns None when LLAMA_SERVER_URL is unset (runtime disabled);
    otherwise a runtime dict shaped like the other collectors' results.
    """
    if not LLAMA_SERVER_URL:
        return None
    runtime: Dict[str, Any] = {"base_url": LLAMA_SERVER_URL, "status": "unknown", "models": []}
    try:
        async with httpx.AsyncClient(timeout=5) as c:
            r = await c.get(f"{LLAMA_SERVER_URL}/v1/models")
            if r.status_code == 200:
                data = r.json()
                runtime["status"] = "ok"
                # Accept both OpenAI-style {"data": [...]} and {"models": [...]}.
                for m in data.get("data", data.get("models", [])):
                    name = m.get("id", m.get("name", "unknown"))
                    runtime["models"].append({"name": name, "type": "llm"})
    except Exception as e:
        runtime["status"] = f"error: {e}"
        # Fix: log like the sibling collectors so failures aren't silent.
        logger.warning("llama-server collector failed: %s", e)
    return runtime
|
|
|
|
|
|
def _collect_disk_inventory() -> List[Dict[str, Any]]:
    """Scan known model directories — NOT for routing, only inventory.

    Best effort: each file is stat'ed once, and unreadable entries (broken
    symlinks, permission errors) are skipped individually instead of
    aborting the entire directory scan, which the previous outer-except
    version did.
    """
    import pathlib
    inventory: List[Dict[str, Any]] = []
    home = pathlib.Path.home()

    scan_dirs = [
        ("cursor_worktrees", home / ".cursor" / "worktrees"),
        ("jan_ai", home / "Library" / "Application Support" / "Jan"),
        ("hf_cache", home / ".cache" / "huggingface" / "hub"),
        ("comfyui_main", home / "ComfyUI" / "models"),
        ("comfyui_docs", home / "Documents" / "ComfyUI" / "models"),
        ("llama_cpp", home / "Library" / "Application Support" / "llama.cpp" / "models"),
        ("hf_models", home / "hf_models"),
    ]

    for source, base in scan_dirs:
        if not base.exists():
            continue
        try:
            for f in base.rglob("*"):
                if f.suffix not in (".gguf", ".safetensors", ".bin", ".pt"):
                    continue
                try:
                    # Single stat per file (was called twice); may raise on
                    # broken symlinks — skip just that entry.
                    size = f.stat().st_size
                except OSError:
                    continue
                if size > 100_000_000:
                    inventory.append({
                        "name": f.stem,
                        "path": str(f.relative_to(home)),
                        "source": source,
                        "size_gb": round(size / 1e9, 1),
                        "type": _classify_model(f.stem),
                        "served": False,
                    })
        except Exception:
            # Directory-level failures (e.g. permission denied mid-walk)
            # stay best-effort: move on to the next source.
            pass

    return inventory
|
|
|
|
|
|
def _build_served_models(ollama: Dict, swapper: Dict, llama: Optional[Dict]) -> List[Dict[str, Any]]:
|
|
"""Merge all served models into a flat canonical list."""
|
|
served: List[Dict[str, Any]] = []
|
|
seen = set()
|
|
|
|
for m in ollama.get("models", []):
|
|
key = m["name"]
|
|
if key not in seen:
|
|
seen.add(key)
|
|
served.append({**m, "runtime": "ollama", "base_url": ollama["base_url"]})
|
|
|
|
for m in swapper.get("vision_models", []):
|
|
key = f"swapper:{m['name']}"
|
|
if key not in seen:
|
|
seen.add(key)
|
|
served.append({**m, "runtime": "swapper", "base_url": swapper["base_url"]})
|
|
|
|
if llama:
|
|
for m in llama.get("models", []):
|
|
key = f"llama:{m['name']}"
|
|
if key not in seen:
|
|
seen.add(key)
|
|
served.append({**m, "runtime": "llama_server", "base_url": llama["base_url"]})
|
|
|
|
return served
|
|
|
|
|
|
async def _build_capabilities() -> Dict[str, Any]:
    """Assemble (and cache) the full node capabilities document.

    Results are cached for CACHE_TTL seconds so frequent router queries
    don't hammer the backend runtimes. The three HTTP collectors are
    independent, so they now run concurrently via asyncio.gather instead
    of sequentially; the disk scan stays synchronous (local stat calls).
    """
    global _cache, _cache_ts

    if _cache and (time.time() - _cache_ts) < CACHE_TTL:
        return _cache

    import asyncio  # local import: only this function needs it

    ollama, swapper, llama = await asyncio.gather(
        _collect_ollama(), _collect_swapper(), _collect_llama_server()
    )
    disk = _collect_disk_inventory()
    served = _build_served_models(ollama, swapper, llama)

    result = {
        "node_id": NODE_ID,
        "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "runtimes": {
            "ollama": ollama,
            "swapper": swapper,
        },
        "served_models": served,
        "served_count": len(served),
        "inventory_only": disk,
        "inventory_count": len(disk),
    }
    # llama_server runtime is only reported when it is configured at all.
    if llama:
        result["runtimes"]["llama_server"] = llama

    _cache = result
    _cache_ts = time.time()
    return result
|
|
|
|
|
|
@app.get("/healthz")
async def healthz():
    """Liveness probe: static OK payload plus this node's identifier."""
    payload = {"status": "ok", "node_id": NODE_ID}
    return payload
|
|
|
|
|
|
@app.get("/capabilities")
async def capabilities():
    """Full capabilities document (served from the TTL cache when fresh)."""
    return JSONResponse(content=await _build_capabilities())
|
|
|
|
|
|
@app.get("/capabilities/models")
async def capabilities_models():
    """Slim view: node id plus the served-model list only."""
    data = await _build_capabilities()
    slim = {"node_id": data["node_id"], "served_models": data["served_models"]}
    return JSONResponse(content=slim)
|
|
|
|
|
|
@app.post("/capabilities/refresh")
async def capabilities_refresh():
    """Invalidate the cache, rebuild immediately, and report the new count."""
    global _cache_ts
    _cache_ts = 0  # zero timestamp forces the TTL check to miss
    rebuilt = await _build_capabilities()
    return JSONResponse(content={"refreshed": True, "served_count": rebuilt["served_count"]})
|
|
|
|
|
|
# ── NATS request/reply (optional) ─────────────────────────────────────────────

# Feature flag: NATS capabilities serving is opt-in via ENABLE_NATS_CAPS.
ENABLE_NATS = os.getenv("ENABLE_NATS_CAPS", "false").lower() in ("true", "1", "yes")
# Broker URL; default targets the compose-internal NATS service.
NATS_URL = os.getenv("NATS_URL", "nats://dagi-nats:4222")
# Request/reply subject derived from the node id,
# e.g. "node.noda2.capabilities.get" for NODE_ID=noda2.
NATS_SUBJECT = f"node.{NODE_ID.lower()}.capabilities.get"

# Connected NATS client; stays None when disabled or when startup failed.
_nats_client = None
|
|
|
|
|
|
async def _nats_capabilities_handler(msg):
    """Handle a NATS request/reply for capabilities.

    Serializes the (cached) capabilities document and publishes it to the
    request's reply subject. On failure a minimal error JSON is sent back
    so the requester fails fast instead of waiting for its timeout.
    """
    import json as _json
    try:
        data = await _build_capabilities()
        payload = _json.dumps(data).encode()
        if msg.reply:
            await _nats_client.publish(msg.reply, payload)
            logger.debug("NATS reply sent to %s (%d bytes)", msg.reply, len(payload))
    except Exception as e:
        logger.warning("NATS handler error: %s", e)
        if msg.reply and _nats_client:
            try:
                await _nats_client.publish(msg.reply, b'{"error":"internal"}')
            except Exception:
                # Fix: if even the error reply fails (connection gone), give
                # up quietly rather than raising out of the subscription
                # callback — the previous version could re-raise here.
                pass
|
|
|
|
|
|
@app.on_event("startup")
async def startup_nats():
    """Connect to NATS and subscribe to the capabilities subject (opt-in).

    No-op when ENABLE_NATS_CAPS is unset/false. Any failure is logged and
    non-fatal: the HTTP endpoints keep working without NATS. The client is
    assigned before subscribe so the handler's global is ready for the
    first incoming request.
    """
    global _nats_client
    if not ENABLE_NATS:
        logger.info(f"NATS capabilities disabled (ENABLE_NATS_CAPS={ENABLE_NATS})")
        return
    try:
        # Imported lazily so the service runs without nats-py installed.
        import nats as nats_lib
        _nats_client = await nats_lib.connect(NATS_URL)
        await _nats_client.subscribe(NATS_SUBJECT, cb=_nats_capabilities_handler)
        logger.info(f"✅ NATS subscribed: {NATS_SUBJECT} on {NATS_URL}")
    except Exception as e:
        logger.warning(f"⚠️ NATS init failed (non-fatal): {e}")
        _nats_client = None
|
|
|
|
|
|
@app.on_event("shutdown")
async def shutdown_nats():
    """Close the NATS connection on shutdown, ignoring close errors."""
    if not _nats_client:
        return
    try:
        await _nats_client.close()
    except Exception:
        pass  # best effort: the process is exiting anyway
|
|
|
|
|
|
if __name__ == "__main__":
    # Standalone dev entry point; PORT env var overrides the default 8099.
    import uvicorn
    port = int(os.getenv("PORT", "8099"))
    uvicorn.run(app, host="0.0.0.0", port=port)
|