# ── P3.1 change summary (from commit message) ────────────────────────────────
# NCS (services/node-capabilities/metrics.py):
#   - NodeLoad: inflight_jobs, queue_depth, concurrency_limit, estimated_wait_ms,
#     cpu_load_1m, mem_pressure (macOS + Linux), rtt_ms_to_hub
#   - RuntimeLoad: per-runtime healthy, p50_ms, p95_ms from rolling 50-sample window
#   - POST /capabilities/report_latency for node-worker → NCS reporting
#   - NCS fetches worker metrics via NODE_WORKER_URL
# Node Worker:
#   - GET /metrics endpoint (inflight, concurrency, latency buffers)
#   - Latency tracking per job type (llm/vision) with rolling buffer
#   - Fire-and-forget latency reporting to NCS after each successful job
# Router (model_select v3):
#   - score_candidate(): wait + model_latency + cross_node_penalty + prefer_bonus
#   - LOCAL_THRESHOLD_MS=250: prefer local if within threshold of remote
#   - ModelSelection.score field for observability
#   - Structured [score] logs with chosen node, model, and score breakdown
# Tests: 19 new (12 scoring + 7 NCS metrics), 36 total pass
# Docs: ops/runbook_p3_1.md, ops/CHANGELOG_FABRIC.md
# No breaking changes to JobRequest/JobResponse or capabilities schema.
"""Node Capabilities Service — exposes live model inventory for router decisions."""
|
|
import os
|
|
import time
|
|
import logging
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from fastapi import FastAPI, Request
|
|
from fastapi.responses import JSONResponse
|
|
import httpx
|
|
|
|
from metrics import (
|
|
build_node_load, build_runtime_load, record_latency,
|
|
)
|
|
|
|
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("node-capabilities")

app = FastAPI(title="Node Capabilities Service", version="1.0.0")

# Node identity and upstream runtime endpoints (all overridable via env).
NODE_ID = os.getenv("NODE_ID", "noda2")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
SWAPPER_URL = os.getenv("SWAPPER_URL", "http://swapper-service:8890")
# Optional llama.cpp server; empty string disables that collector entirely.
LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "")

# Module-level TTL cache for the built capabilities document.
# NOTE(review): per-process only — workers behind multiple uvicorn processes
# each keep their own copy.
_cache: Dict[str, Any] = {}
_cache_ts: float = 0  # epoch seconds of the last rebuild; 0 = never built
CACHE_TTL = int(os.getenv("CACHE_TTL_SEC", "15"))  # seconds
|
|
|
|
|
|
def _classify_model(name: str) -> str:
|
|
nl = name.lower()
|
|
if any(k in nl for k in ("vl", "vision", "llava", "minicpm-v", "clip")):
|
|
return "vision"
|
|
if any(k in nl for k in ("coder", "starcoder", "codellama", "code")):
|
|
return "code"
|
|
if any(k in nl for k in ("embed", "bge", "minilm", "e5-")):
|
|
return "embedding"
|
|
if any(k in nl for k in ("whisper", "stt")):
|
|
return "stt"
|
|
if any(k in nl for k in ("kokoro", "tts", "bark", "coqui", "xtts")):
|
|
return "tts"
|
|
if any(k in nl for k in ("flux", "sdxl", "stable-diffusion", "ltx")):
|
|
return "image_gen"
|
|
return "llm"
|
|
|
|
|
|
async def _collect_ollama() -> Dict[str, Any]:
    """Probe the Ollama daemon for its model list and currently-running set.

    Returns a runtime descriptor: {"base_url", "status", "models"}, where
    status is "ok", "unknown" (non-200), or "error: ..." on exception.
    Never raises — failures are folded into the status field.
    """
    runtime: Dict[str, Any] = {"base_url": OLLAMA_BASE_URL, "status": "unknown", "models": []}
    try:
        async with httpx.AsyncClient(timeout=5) as c:
            r = await c.get(f"{OLLAMA_BASE_URL}/api/tags")
            if r.status_code == 200:
                data = r.json()
                runtime["status"] = "ok"
                for m in data.get("models", []):
                    runtime["models"].append({
                        "name": m.get("name", ""),
                        "size_bytes": m.get("size", 0),
                        "size_gb": round(m.get("size", 0) / 1e9, 1),
                        "type": _classify_model(m.get("name", "")),
                        "modified": m.get("modified_at", "")[:10],
                        # Default False so the schema is stable even when the
                        # /api/ps probe below fails (previously the key was
                        # simply absent in that case).
                        "running": False,
                    })
            ps = await c.get(f"{OLLAMA_BASE_URL}/api/ps")
            if ps.status_code == 200:
                running_names = {m.get("name", "") for m in ps.json().get("models", [])}
                for model in runtime["models"]:
                    model["running"] = model["name"] in running_names
    except Exception as e:
        runtime["status"] = f"error: {e}"
        # Lazy %-args: formatting only happens if the record is emitted.
        logger.warning("Ollama collector failed: %s", e)
    return runtime
|
|
|
|
|
|
async def _collect_swapper() -> Dict[str, Any]:
    """Collect health, model list, and vision-model list from the swapper.

    Three sequential GETs against SWAPPER_URL (/health, /models,
    /vision/models); each is optional — a non-200 response simply leaves
    the corresponding field at its default. Exceptions are captured in
    the "status" field, never raised.
    """
    info: Dict[str, Any] = {
        "base_url": SWAPPER_URL,
        "status": "unknown",
        "models": [],
        "vision_models": [],
        "active_model": None,
    }
    try:
        async with httpx.AsyncClient(timeout=5) as client:
            health = await client.get(f"{SWAPPER_URL}/health")
            if health.status_code == 200:
                body = health.json()
                info["status"] = body.get("status", "ok")
                info["active_model"] = body.get("active_model")

            models_resp = await client.get(f"{SWAPPER_URL}/models")
            if models_resp.status_code == 200:
                info["models"].extend(
                    {
                        "name": m.get("name", ""),
                        "type": m.get("type", "llm"),
                        "size_gb": m.get("size_gb", 0),
                        "status": m.get("status", "unknown"),
                    }
                    for m in models_resp.json().get("models", [])
                )

            vision_resp = await client.get(f"{SWAPPER_URL}/vision/models")
            if vision_resp.status_code == 200:
                info["vision_models"].extend(
                    {
                        "name": m.get("name", ""),
                        "type": "vision",
                        "size_gb": m.get("size_gb", 0),
                        "status": m.get("status", "unknown"),
                    }
                    for m in vision_resp.json().get("models", [])
                )
    except Exception as e:
        info["status"] = f"error: {e}"
        logger.warning(f"Swapper collector failed: {e}")
    return info
|
|
|
|
|
|
async def _collect_llama_server() -> Optional[Dict[str, Any]]:
    """Probe an optional llama.cpp server via its OpenAI-style /v1/models.

    Returns None when LLAMA_SERVER_URL is unset (collector disabled);
    otherwise a runtime descriptor {"base_url", "status", "models"}.
    Failures are recorded in "status", never raised.
    """
    if not LLAMA_SERVER_URL:
        return None
    runtime: Dict[str, Any] = {"base_url": LLAMA_SERVER_URL, "status": "unknown", "models": []}
    try:
        async with httpx.AsyncClient(timeout=5) as c:
            r = await c.get(f"{LLAMA_SERVER_URL}/v1/models")
            if r.status_code == 200:
                data = r.json()
                runtime["status"] = "ok"
                # OpenAI-style payloads nest under "data"; fall back to "models".
                for m in data.get("data", data.get("models", [])):
                    name = m.get("id", m.get("name", "unknown"))
                    runtime["models"].append({"name": name, "type": "llm"})
    except Exception as e:
        runtime["status"] = f"error: {e}"
        # Log like the sibling collectors — failures were previously silent.
        logger.warning("llama-server collector failed: %s", e)
    return runtime
|
|
|
|
|
|
def _collect_disk_inventory() -> List[Dict[str, Any]]:
    """Scan known model directories — NOT for routing, only inventory.

    Walks a fixed set of (macOS-centric) model cache locations and emits one
    entry per weight file (known extension, > 100 MB). Best-effort: missing
    directories and unreadable files are skipped silently.
    """
    import pathlib

    home = pathlib.Path.home()
    inventory: List[Dict[str, Any]] = []

    scan_dirs = [
        ("cursor_worktrees", home / ".cursor" / "worktrees"),
        ("jan_ai", home / "Library" / "Application Support" / "Jan"),
        ("hf_cache", home / ".cache" / "huggingface" / "hub"),
        ("comfyui_main", home / "ComfyUI" / "models"),
        ("comfyui_docs", home / "Documents" / "ComfyUI" / "models"),
        ("llama_cpp", home / "Library" / "Application Support" / "llama.cpp" / "models"),
        ("hf_models", home / "hf_models"),
    ]

    weight_suffixes = (".gguf", ".safetensors", ".bin", ".pt")
    min_size_bytes = 100_000_000

    for source, base in scan_dirs:
        if not base.exists():
            continue
        try:
            for f in base.rglob("*"):
                if f.suffix not in weight_suffixes:
                    continue
                try:
                    # stat() once per file (was called twice: filter + size_gb).
                    size = f.stat().st_size
                except OSError:
                    # Broken symlink / permission error: skip this file only.
                    # Previously one bad file aborted the whole source dir.
                    continue
                if size <= min_size_bytes:
                    continue
                inventory.append({
                    "name": f.stem,
                    "path": str(f.relative_to(home)),
                    "source": source,
                    "size_gb": round(size / 1e9, 1),
                    "type": _classify_model(f.stem),
                    "served": False,
                })
        except Exception:
            # Keep the original best-effort contract for anything unexpected.
            pass

    return inventory
|
|
|
|
|
|
def _build_served_models(ollama: Dict, swapper: Dict, llama: Optional[Dict]) -> List[Dict[str, Any]]:
|
|
"""Merge all served models into a flat canonical list."""
|
|
served: List[Dict[str, Any]] = []
|
|
seen = set()
|
|
|
|
for m in ollama.get("models", []):
|
|
key = m["name"]
|
|
if key not in seen:
|
|
seen.add(key)
|
|
served.append({**m, "runtime": "ollama", "base_url": ollama["base_url"]})
|
|
|
|
for m in swapper.get("vision_models", []):
|
|
key = f"swapper:{m['name']}"
|
|
if key not in seen:
|
|
seen.add(key)
|
|
served.append({**m, "runtime": "swapper", "base_url": swapper["base_url"]})
|
|
|
|
if llama:
|
|
for m in llama.get("models", []):
|
|
key = f"llama:{m['name']}"
|
|
if key not in seen:
|
|
seen.add(key)
|
|
served.append({**m, "runtime": "llama_server", "base_url": llama["base_url"]})
|
|
|
|
return served
|
|
|
|
|
|
async def _build_capabilities() -> Dict[str, Any]:
    """Assemble the node capabilities document, with a TTL cache.

    On a cache hit (built within the last CACHE_TTL seconds) the previously
    built dict is returned as-is. On a miss, the runtime collectors are
    probed, the disk inventory is scanned, load metrics are attached, and
    the merged result is cached and returned.
    """
    import asyncio  # stdlib; needed only on the cache-miss path

    global _cache, _cache_ts

    if _cache and (time.time() - _cache_ts) < CACHE_TTL:
        return _cache

    # The three runtime probes are independent HTTP calls — overlap them
    # instead of awaiting sequentially (same results, lower cold-cache latency).
    ollama, swapper, llama = await asyncio.gather(
        _collect_ollama(), _collect_swapper(), _collect_llama_server()
    )
    disk = _collect_disk_inventory()
    served = _build_served_models(ollama, swapper, llama)

    runtimes = {"ollama": ollama, "swapper": swapper}
    if llama:
        runtimes["llama_server"] = llama

    node_load = await build_node_load()
    runtime_load = await build_runtime_load(runtimes)

    result = {
        "node_id": NODE_ID,
        "updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "runtimes": runtimes,
        "served_models": served,
        "served_count": len(served),
        "node_load": node_load,
        "runtime_load": runtime_load,
        "inventory_only": disk,
        "inventory_count": len(disk),
    }

    _cache = result
    _cache_ts = time.time()
    return result
|
|
|
|
|
|
@app.get("/healthz")
|
|
async def healthz():
|
|
return {"status": "ok", "node_id": NODE_ID}
|
|
|
|
|
|
@app.get("/capabilities")
|
|
async def capabilities():
|
|
data = await _build_capabilities()
|
|
return JSONResponse(content=data)
|
|
|
|
|
|
@app.get("/capabilities/models")
|
|
async def capabilities_models():
|
|
data = await _build_capabilities()
|
|
return JSONResponse(content={"node_id": data["node_id"], "served_models": data["served_models"]})
|
|
|
|
|
|
@app.post("/capabilities/refresh")
|
|
async def capabilities_refresh():
|
|
global _cache_ts
|
|
_cache_ts = 0
|
|
data = await _build_capabilities()
|
|
return JSONResponse(content={"refreshed": True, "served_count": data["served_count"]})
|
|
|
|
|
|
@app.post("/capabilities/report_latency")
|
|
async def report_latency_endpoint(request: Request):
|
|
data = await request.json()
|
|
runtime = data.get("runtime", "ollama")
|
|
req_type = data.get("type", "llm")
|
|
latency_ms = data.get("latency_ms", 0)
|
|
if latency_ms > 0:
|
|
record_latency(runtime, req_type, latency_ms)
|
|
return {"ok": True}
|
|
|
|
|
|
# ── NATS request/reply (optional) ─────────────────────────────────────────────

# Opt-in flag: the startup hook only subscribes when ENABLE_NATS_CAPS is truthy.
ENABLE_NATS = os.getenv("ENABLE_NATS_CAPS", "false").lower() in ("true", "1", "yes")
NATS_URL = os.getenv("NATS_URL", "nats://dagi-nats:4222")
# Per-node request subject, e.g. "node.noda2.capabilities.get".
NATS_SUBJECT = f"node.{NODE_ID.lower()}.capabilities.get"

# Connection handle: set on startup when NATS is enabled, otherwise stays None.
_nats_client = None
|
|
|
|
|
|
async def _nats_capabilities_handler(msg):
    """Handle NATS request/reply for capabilities.

    Replies with the JSON-encoded capabilities document, or a small
    {"error":"internal"} payload on failure. Never raises — errors are
    logged and swallowed so the subscription stays alive.
    """
    import json as _json
    try:
        data = await _build_capabilities()
        payload = _json.dumps(data).encode()
        # Guard _nats_client here too (the error path already did): the
        # handler can race a shutdown that closes/clears the client.
        if msg.reply and _nats_client:
            await _nats_client.publish(msg.reply, payload)
            logger.debug("NATS reply sent to %s (%d bytes)", msg.reply, len(payload))
    except Exception as e:
        logger.warning("NATS handler error: %s", e)
        if msg.reply and _nats_client:
            await _nats_client.publish(msg.reply, b'{"error":"internal"}')
|
|
|
|
|
|
@app.on_event("startup")
|
|
async def startup_nats():
|
|
global _nats_client
|
|
if not ENABLE_NATS:
|
|
logger.info(f"NATS capabilities disabled (ENABLE_NATS_CAPS={ENABLE_NATS})")
|
|
return
|
|
try:
|
|
import nats as nats_lib
|
|
_nats_client = await nats_lib.connect(NATS_URL)
|
|
await _nats_client.subscribe(NATS_SUBJECT, cb=_nats_capabilities_handler)
|
|
logger.info(f"✅ NATS subscribed: {NATS_SUBJECT} on {NATS_URL}")
|
|
except Exception as e:
|
|
logger.warning(f"⚠️ NATS init failed (non-fatal): {e}")
|
|
_nats_client = None
|
|
|
|
|
|
@app.on_event("shutdown")
|
|
async def shutdown_nats():
|
|
if _nats_client:
|
|
try:
|
|
await _nats_client.close()
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8099")))
|