"""Node Capabilities Service — exposes live model inventory for router decisions.""" import os import time import logging from typing import Any, Dict, List, Optional from fastapi import FastAPI, Request from fastapi.responses import JSONResponse import httpx from metrics import ( build_node_load, build_runtime_load, record_latency, ) import prom_metrics logging.basicConfig(level=logging.INFO) logger = logging.getLogger("node-capabilities") app = FastAPI(title="Node Capabilities Service", version="1.0.0") NODE_ID = os.getenv("NODE_ID", "noda2") OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434") SWAPPER_URL = os.getenv("SWAPPER_URL", "") LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "") NODE_WORKER_URL = os.getenv("NODE_WORKER_URL", "http://node-worker:8109") DISK_SCAN_PATHS = os.getenv("DISK_SCAN_PATHS", "") # comma-sep extra dirs _cache: Dict[str, Any] = {} _cache_ts: float = 0 CACHE_TTL = int(os.getenv("CACHE_TTL_SEC", "15")) def _classify_model(name: str) -> str: nl = name.lower() if any(k in nl for k in ("vl", "vision", "llava", "minicpm-v", "clip")): return "vision" if any(k in nl for k in ("coder", "starcoder", "codellama", "code")): return "code" if any(k in nl for k in ("embed", "bge", "minilm", "e5-")): return "embedding" if any(k in nl for k in ("whisper", "stt")): return "stt" if any(k in nl for k in ("kokoro", "tts", "bark", "coqui", "xtts")): return "tts" if any(k in nl for k in ("flux", "sdxl", "stable-diffusion", "ltx")): return "image_gen" return "llm" async def _collect_ollama() -> Dict[str, Any]: runtime: Dict[str, Any] = {"base_url": OLLAMA_BASE_URL, "status": "unknown", "models": []} try: async with httpx.AsyncClient(timeout=5) as c: r = await c.get(f"{OLLAMA_BASE_URL}/api/tags") if r.status_code == 200: data = r.json() runtime["status"] = "ok" for m in data.get("models", []): runtime["models"].append({ "name": m.get("name", ""), "size_bytes": m.get("size", 0), "size_gb": round(m.get("size", 0) / 1e9, 1), 
"type": _classify_model(m.get("name", "")), "modified": m.get("modified_at", "")[:10], }) ps = await c.get(f"{OLLAMA_BASE_URL}/api/ps") if ps.status_code == 200: running = ps.json().get("models", []) running_names = {m.get("name", "") for m in running} for model in runtime["models"]: model["running"] = model["name"] in running_names except Exception as e: runtime["status"] = f"error: {e}" logger.warning(f"Ollama collector failed: {e}") return runtime async def _collect_swapper() -> Dict[str, Any]: runtime: Dict[str, Any] = {"base_url": SWAPPER_URL or "n/a", "status": "unknown", "models": [], "vision_models": [], "active_model": None} if not SWAPPER_URL: runtime["status"] = "disabled" return runtime try: async with httpx.AsyncClient(timeout=5) as c: h = await c.get(f"{SWAPPER_URL}/health") if h.status_code == 200: hd = h.json() runtime["status"] = hd.get("status", "ok") runtime["active_model"] = hd.get("active_model") mr = await c.get(f"{SWAPPER_URL}/models") if mr.status_code == 200: for m in mr.json().get("models", []): runtime["models"].append({ "name": m.get("name", ""), "type": m.get("type", "llm"), "size_gb": m.get("size_gb", 0), "status": m.get("status", "unknown"), }) vr = await c.get(f"{SWAPPER_URL}/vision/models") if vr.status_code == 200: for m in vr.json().get("models", []): runtime["vision_models"].append({ "name": m.get("name", ""), "type": "vision", "size_gb": m.get("size_gb", 0), "status": m.get("status", "unknown"), }) except Exception as e: runtime["status"] = f"error: {e}" logger.warning(f"Swapper collector failed: {e}") return runtime async def _collect_llama_server() -> Optional[Dict[str, Any]]: if not LLAMA_SERVER_URL: return None runtime: Dict[str, Any] = {"base_url": LLAMA_SERVER_URL, "status": "unknown", "models": []} try: async with httpx.AsyncClient(timeout=5) as c: r = await c.get(f"{LLAMA_SERVER_URL}/v1/models") if r.status_code == 200: data = r.json() runtime["status"] = "ok" for m in data.get("data", data.get("models", [])): name = 
m.get("id", m.get("name", "unknown")) runtime["models"].append({"name": name, "type": "llm"}) except Exception as e: runtime["status"] = f"error: {e}" return runtime async def _collect_worker_caps() -> Dict[str, Any]: """Fetch capability flags from local Node Worker.""" default = {"capabilities": {}, "providers": {}, "defaults": {}} try: async with httpx.AsyncClient(timeout=3) as c: r = await c.get(f"{NODE_WORKER_URL}/caps") if r.status_code == 200: return r.json() except Exception as e: logger.debug(f"Worker caps unavailable: {e}") return default def _collect_disk_inventory() -> List[Dict[str, Any]]: """Scan known model directories — NOT for routing, only inventory.""" import pathlib inventory: List[Dict[str, Any]] = [] scan_dirs = [ ("ollama", pathlib.Path.home() / ".ollama" / "models"), ("hf_cache", pathlib.Path.home() / ".cache" / "huggingface" / "hub"), ("comfyui_main", pathlib.Path.home() / "ComfyUI" / "models"), ("comfyui_docs", pathlib.Path.home() / "Documents" / "ComfyUI" / "models"), ("llama_cpp", pathlib.Path.home() / "Library" / "Application Support" / "llama.cpp" / "models"), ("hf_models", pathlib.Path.home() / "hf_models"), ("jan_ai", pathlib.Path.home() / "Library" / "Application Support" / "Jan"), ] if DISK_SCAN_PATHS: for p in DISK_SCAN_PATHS.split(","): p = p.strip() if p: scan_dirs.append(("custom", pathlib.Path(p))) for source, base in scan_dirs: if not base.exists(): continue try: for f in base.rglob("*"): if f.suffix in (".gguf", ".safetensors", ".bin", ".pt", ".mlx") and f.stat().st_size > 50_000_000: inventory.append({ "name": f.stem, "path": str(f), "source": source, "size_gb": round(f.stat().st_size / 1e9, 1), "type": _classify_model(f.stem), "served": False, }) except Exception: pass return inventory def _build_served_models(ollama: Dict, swapper: Dict, llama: Optional[Dict]) -> List[Dict[str, Any]]: """Merge all served models into a flat canonical list.""" served: List[Dict[str, Any]] = [] seen = set() for m in ollama.get("models", []): 
key = m["name"] if key not in seen: seen.add(key) served.append({**m, "runtime": "ollama", "base_url": ollama["base_url"]}) for m in swapper.get("vision_models", []): key = f"swapper:{m['name']}" if key not in seen: seen.add(key) served.append({**m, "runtime": "swapper", "base_url": swapper["base_url"]}) if llama: for m in llama.get("models", []): key = f"llama:{m['name']}" if key not in seen: seen.add(key) served.append({**m, "runtime": "llama_server", "base_url": llama["base_url"]}) return served def _derive_capabilities(served: List[Dict], worker_caps: Dict) -> Dict[str, Any]: """Merge served model types + worker provider flags into capability map. Voice HA caps (voice_tts, voice_llm, voice_stt) pass through directly from node-worker /caps where they are validated against active NATS subscriptions. """ served_types = {m.get("type", "llm") for m in served} wc = worker_caps.get("capabilities", {}) wp = worker_caps.get("providers", {}) has_vision = "vision" in served_types or wc.get("vision", False) return { "llm": "llm" in served_types or "code" in served_types, "vision": has_vision, "stt": wc.get("stt", False), "tts": wc.get("tts", False), "ocr": has_vision and wp.get("ocr", "none") != "none", "image": wc.get("image", False), # Voice HA — pass through from node-worker /caps (reflects active subjects) "voice_tts": wc.get("voice_tts", False), "voice_llm": wc.get("voice_llm", False), "voice_stt": wc.get("voice_stt", False), "providers": wp, "voice_concurrency": worker_caps.get("voice_concurrency", {}), } async def _build_capabilities() -> Dict[str, Any]: global _cache, _cache_ts if _cache and (time.time() - _cache_ts) < CACHE_TTL: return _cache ollama = await _collect_ollama() swapper = await _collect_swapper() llama = await _collect_llama_server() worker_caps = await _collect_worker_caps() disk = _collect_disk_inventory() served = _build_served_models(ollama, swapper, llama) runtimes = {"ollama": ollama, "swapper": swapper} if llama: runtimes["llama_server"] = 
llama node_load = await build_node_load() runtime_load = await build_runtime_load(runtimes) capabilities = _derive_capabilities(served, worker_caps) result = { "node_id": NODE_ID, "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), "runtimes": runtimes, "served_models": served, "served_count": len(served), "capabilities": capabilities, "node_load": node_load, "runtime_load": runtime_load, "installed_artifacts": disk, "installed_count": len(disk), "worker": worker_caps, } _cache = result _cache_ts = time.time() return result @app.get("/healthz") async def healthz(): return {"status": "ok", "node_id": NODE_ID} @app.get("/capabilities") async def capabilities(): data = await _build_capabilities() prom_metrics.update_from_caps(data) return JSONResponse(content=data) @app.get("/capabilities/models") async def capabilities_models(): data = await _build_capabilities() return JSONResponse(content={"node_id": data["node_id"], "served_models": data["served_models"]}) @app.get("/capabilities/caps") async def capabilities_caps(): data = await _build_capabilities() return JSONResponse(content={ "node_id": data["node_id"], "capabilities": data.get("capabilities", {}), "worker": data.get("worker", {}), }) @app.get("/capabilities/installed") async def capabilities_installed(): data = await _build_capabilities() return JSONResponse(content={ "node_id": data["node_id"], "installed_artifacts": data.get("installed_artifacts", []), "installed_count": data.get("installed_count", 0), }) @app.post("/capabilities/refresh") async def capabilities_refresh(): global _cache_ts _cache_ts = 0 data = await _build_capabilities() return JSONResponse(content={"refreshed": True, "served_count": data["served_count"]}) @app.get("/prom_metrics") async def prom_metrics_endpoint(): data = prom_metrics.get_metrics_text() if data: from fastapi.responses import Response return Response(content=data, media_type="text/plain; charset=utf-8") return {"error": "prometheus_client not installed"} 
@app.post("/capabilities/report_latency")
async def report_latency_endpoint(request: Request):
    """Record a latency sample reported by a caller.

    Expects a JSON object with optional "runtime" (default "ollama"),
    "type" (default "llm") and "latency_ms" (default 0). Only positive
    latencies are recorded. Malformed bodies get HTTP 400 instead of
    surfacing as an unhandled server error.
    """
    try:
        data = await request.json()
    except ValueError:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        return JSONResponse(status_code=400, content={"ok": False, "error": "invalid JSON body"})
    if not isinstance(data, dict):
        return JSONResponse(status_code=400, content={"ok": False, "error": "JSON object expected"})

    runtime = data.get("runtime", "ollama")
    req_type = data.get("type", "llm")
    latency_ms = data.get("latency_ms", 0)
    # bool is an int subclass; reject it explicitly along with non-numbers.
    if isinstance(latency_ms, bool) or not isinstance(latency_ms, (int, float)):
        return JSONResponse(status_code=400, content={"ok": False, "error": "latency_ms must be a number"})

    if latency_ms > 0:
        record_latency(runtime, req_type, latency_ms)
    return {"ok": True}


# ── NATS request/reply (optional) ─────────────────────────────────────────────

ENABLE_NATS = os.getenv("ENABLE_NATS_CAPS", "false").lower() in ("true", "1", "yes")
NATS_URL = os.getenv("NATS_URL", "nats://dagi-nats:4222")
NATS_SUBJECT = f"node.{NODE_ID.lower()}.capabilities.get"
NATS_BROADCAST_SUBJECT = "fabric.capabilities.discover"

# Connected NATS client, or None when NATS is disabled or failed to initialise.
_nats_client = None


async def _nats_capabilities_handler(msg):
    """Handle NATS request/reply for capabilities.

    Builds the capabilities payload and, if the message carries a reply
    subject, publishes it there as JSON. On failure a small error document is
    sent instead. Never raises, so a bad request cannot escape into the
    subscription callback machinery.
    """
    import json as _json

    try:
        data = await _build_capabilities()
        payload = _json.dumps(data).encode()
        if msg.reply:
            await _nats_client.publish(msg.reply, payload)
            logger.debug(f"NATS reply sent to {msg.reply} ({len(payload)} bytes)")
    except Exception as e:
        logger.warning(f"NATS handler error: {e}")
        if msg.reply and _nats_client:
            try:
                await _nats_client.publish(msg.reply, b'{"error":"internal"}')
            except Exception as pub_err:
                # The connection itself may be the problem; just record it.
                logger.warning(f"NATS error reply failed: {pub_err}")


@app.on_event("startup")
async def startup_nats():
    """Connect to NATS and subscribe to the capabilities subjects, if enabled.

    Failure is non-fatal: the service keeps running over HTTP only and
    `_nats_client` is left as None.
    """
    global _nats_client
    if not ENABLE_NATS:
        logger.info(f"NATS capabilities disabled (ENABLE_NATS_CAPS={ENABLE_NATS})")
        return
    try:
        import nats as nats_lib

        _nats_client = await nats_lib.connect(NATS_URL)
        await _nats_client.subscribe(NATS_SUBJECT, cb=_nats_capabilities_handler)
        await _nats_client.subscribe(NATS_BROADCAST_SUBJECT, cb=_nats_capabilities_handler)
        logger.info(f"✅ NATS subscribed: {NATS_SUBJECT} + {NATS_BROADCAST_SUBJECT} on {NATS_URL}")
    except Exception as e:
        logger.warning(f"⚠️ NATS init failed (non-fatal): {e}")
        # If connect succeeded but a subscribe failed, do not leak the connection.
        if _nats_client is not None:
            try:
                await _nats_client.close()
            except Exception as close_err:
                logger.debug(f"NATS close after failed init also failed: {close_err}")
        _nats_client = None


@app.on_event("shutdown")
async def shutdown_nats():
    """Close the NATS connection on shutdown; close errors are logged, not raised."""
    if _nats_client:
        try:
            await _nats_client.close()
        except Exception as e:
            logger.debug(f"NATS close failed during shutdown: {e}")


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8099")))