P3.5-P3.7: 2-layer inventory, capability routing, STT/TTS adapters, Dev Contract

NCS:
- _collect_worker_caps() fetches capability flags from node-worker /caps
- _derive_capabilities() merges served model types + worker provider flags
- installed_artifacts replaces inventory_only (disk scan with DISK_SCAN_PATHS env)
- New endpoints: /capabilities/caps, /capabilities/installed

Node Worker:
- STT_PROVIDER, TTS_PROVIDER, OCR_PROVIDER, IMAGE_PROVIDER env flags
- /caps endpoint returns capabilities + providers for NCS aggregation
- STT adapter (providers/stt_mlx_whisper.py) — remote + local mode
- TTS adapter (providers/tts_mlx_kokoro.py) — remote + local mode
- OCR handler via vision_prompted (ollama_vision with OCR prompt)
- NATS subjects: node.{id}.stt/tts/ocr/image.request

Router:
- POST /v1/capability/{stt,tts,ocr,image} — capability-based offload routing
- GET /v1/capabilities — global view with capabilities_by_node
- require_fresh_caps(ttl) preflight guard
- find_nodes_with_capability(cap) + load-based node selection

Ops:
- ops/fabric_snapshot.py — full runtime snapshot collector
- ops/fabric_preflight.sh — quick check + snapshot save + diff
- docs/fabric_contract.md — Dev Contract v0.1 (preflight-first)
- tests/test_fabric_contract.py — CI enforcement (6 tests)

Made-with: Cursor
This commit is contained in:
Apple
2026-02-27 05:24:09 -08:00
parent 194c87f53c
commit 9a36020316
17 changed files with 1352 additions and 21 deletions

View File

@@ -20,8 +20,10 @@ app = FastAPI(title="Node Capabilities Service", version="1.0.0")
NODE_ID = os.getenv("NODE_ID", "noda2")
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
SWAPPER_URL = os.getenv("SWAPPER_URL", "") # empty = skip Swapper probing
SWAPPER_URL = os.getenv("SWAPPER_URL", "")
LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "")
NODE_WORKER_URL = os.getenv("NODE_WORKER_URL", "http://node-worker:8109")
DISK_SCAN_PATHS = os.getenv("DISK_SCAN_PATHS", "") # comma-sep extra dirs
_cache: Dict[str, Any] = {}
_cache_ts: float = 0
@@ -129,30 +131,48 @@ async def _collect_llama_server() -> Optional[Dict[str, Any]]:
return runtime
async def _collect_worker_caps() -> Dict[str, Any]:
"""Fetch capability flags from local Node Worker."""
default = {"capabilities": {}, "providers": {}, "defaults": {}}
try:
async with httpx.AsyncClient(timeout=3) as c:
r = await c.get(f"{NODE_WORKER_URL}/caps")
if r.status_code == 200:
return r.json()
except Exception as e:
logger.debug(f"Worker caps unavailable: {e}")
return default
def _collect_disk_inventory() -> List[Dict[str, Any]]:
"""Scan known model directories — NOT for routing, only inventory."""
import pathlib
inventory: List[Dict[str, Any]] = []
scan_dirs = [
("cursor_worktrees", pathlib.Path.home() / ".cursor" / "worktrees"),
("jan_ai", pathlib.Path.home() / "Library" / "Application Support" / "Jan"),
("ollama", pathlib.Path.home() / ".ollama" / "models"),
("hf_cache", pathlib.Path.home() / ".cache" / "huggingface" / "hub"),
("comfyui_main", pathlib.Path.home() / "ComfyUI" / "models"),
("comfyui_docs", pathlib.Path.home() / "Documents" / "ComfyUI" / "models"),
("llama_cpp", pathlib.Path.home() / "Library" / "Application Support" / "llama.cpp" / "models"),
("hf_models", pathlib.Path.home() / "hf_models"),
("jan_ai", pathlib.Path.home() / "Library" / "Application Support" / "Jan"),
]
if DISK_SCAN_PATHS:
for p in DISK_SCAN_PATHS.split(","):
p = p.strip()
if p:
scan_dirs.append(("custom", pathlib.Path(p)))
for source, base in scan_dirs:
if not base.exists():
continue
try:
for f in base.rglob("*"):
if f.suffix in (".gguf", ".safetensors", ".bin", ".pt") and f.stat().st_size > 100_000_000:
if f.suffix in (".gguf", ".safetensors", ".bin", ".pt", ".mlx") and f.stat().st_size > 50_000_000:
inventory.append({
"name": f.stem,
"path": str(f.relative_to(pathlib.Path.home())),
"path": str(f),
"source": source,
"size_gb": round(f.stat().st_size / 1e9, 1),
"type": _classify_model(f.stem),
@@ -191,6 +211,24 @@ def _build_served_models(ollama: Dict, swapper: Dict, llama: Optional[Dict]) ->
return served
def _derive_capabilities(served: List[Dict], worker_caps: Dict) -> Dict[str, Any]:
"""Merge served model types + worker provider flags into capability map."""
served_types = {m.get("type", "llm") for m in served}
wc = worker_caps.get("capabilities", {})
wp = worker_caps.get("providers", {})
has_vision = "vision" in served_types or wc.get("vision", False)
return {
"llm": "llm" in served_types or "code" in served_types,
"vision": has_vision,
"stt": wc.get("stt", False),
"tts": wc.get("tts", False),
"ocr": has_vision and wp.get("ocr", "none") != "none",
"image": wc.get("image", False),
"providers": wp,
}
async def _build_capabilities() -> Dict[str, Any]:
global _cache, _cache_ts
@@ -200,6 +238,7 @@ async def _build_capabilities() -> Dict[str, Any]:
ollama = await _collect_ollama()
swapper = await _collect_swapper()
llama = await _collect_llama_server()
worker_caps = await _collect_worker_caps()
disk = _collect_disk_inventory()
served = _build_served_models(ollama, swapper, llama)
@@ -209,6 +248,7 @@ async def _build_capabilities() -> Dict[str, Any]:
node_load = await build_node_load()
runtime_load = await build_runtime_load(runtimes)
capabilities = _derive_capabilities(served, worker_caps)
result = {
"node_id": NODE_ID,
@@ -216,10 +256,12 @@ async def _build_capabilities() -> Dict[str, Any]:
"runtimes": runtimes,
"served_models": served,
"served_count": len(served),
"capabilities": capabilities,
"node_load": node_load,
"runtime_load": runtime_load,
"inventory_only": disk,
"inventory_count": len(disk),
"installed_artifacts": disk,
"installed_count": len(disk),
"worker": worker_caps,
}
_cache = result
@@ -245,6 +287,26 @@ async def capabilities_models():
return JSONResponse(content={"node_id": data["node_id"], "served_models": data["served_models"]})
@app.get("/capabilities/caps")
async def capabilities_caps():
data = await _build_capabilities()
return JSONResponse(content={
"node_id": data["node_id"],
"capabilities": data.get("capabilities", {}),
"worker": data.get("worker", {}),
})
@app.get("/capabilities/installed")
async def capabilities_installed():
data = await _build_capabilities()
return JSONResponse(content={
"node_id": data["node_id"],
"installed_artifacts": data.get("installed_artifacts", []),
"installed_count": data.get("installed_count", 0),
})
@app.post("/capabilities/refresh")
async def capabilities_refresh():
global _cache_ts

View File

@@ -9,3 +9,8 @@ DEFAULT_VISION = os.getenv("NODE_DEFAULT_VISION", "llava:13b")
MAX_CONCURRENCY = int(os.getenv("NODE_WORKER_MAX_CONCURRENCY", "2"))
MAX_PAYLOAD_BYTES = int(os.getenv("NODE_WORKER_MAX_PAYLOAD_BYTES", str(1024 * 1024)))
PORT = int(os.getenv("PORT", "8109"))
STT_PROVIDER = os.getenv("STT_PROVIDER", "none")
TTS_PROVIDER = os.getenv("TTS_PROVIDER", "none")
OCR_PROVIDER = os.getenv("OCR_PROVIDER", "vision_prompted")
IMAGE_PROVIDER = os.getenv("IMAGE_PROVIDER", "none")

View File

@@ -41,6 +41,33 @@ async def prom_metrics():
return {"error": "prometheus_client not installed"}
@app.get("/caps")
async def caps():
"""Capability flags for NCS to aggregate."""
return {
"node_id": config.NODE_ID,
"capabilities": {
"llm": True,
"vision": True,
"stt": config.STT_PROVIDER != "none",
"tts": config.TTS_PROVIDER != "none",
"ocr": config.OCR_PROVIDER != "none",
"image": config.IMAGE_PROVIDER != "none",
},
"providers": {
"stt": config.STT_PROVIDER,
"tts": config.TTS_PROVIDER,
"ocr": config.OCR_PROVIDER,
"image": config.IMAGE_PROVIDER,
},
"defaults": {
"llm": config.DEFAULT_LLM,
"vision": config.DEFAULT_VISION,
},
"concurrency": config.MAX_CONCURRENCY,
}
@app.on_event("startup")
async def startup():
global _nats_client

View File

@@ -14,7 +14,7 @@ class JobRequest(BaseModel):
trace_id: str = ""
actor_agent_id: str = ""
target_agent_id: str = ""
required_type: Literal["llm", "vision", "stt", "tts"] = "llm"
required_type: Literal["llm", "vision", "stt", "tts", "image", "ocr"] = "llm"
deadline_ts: int = 0
idempotency_key: str = ""
payload: Dict[str, Any] = Field(default_factory=dict)

View File

@@ -0,0 +1,135 @@
"""MLX Whisper STT provider — transcribes audio via mlx-whisper on host.
Runs inside Docker; delegates to MLX Whisper HTTP API on the host.
If MLX_WHISPER_URL is not set, falls back to running mlx_whisper directly
(only works when node-worker runs natively, not in Docker).
"""
import base64
import logging
import os
import tempfile
from typing import Any, Dict, Optional
import httpx
logger = logging.getLogger("provider.stt_mlx_whisper")
MLX_WHISPER_URL = os.getenv("MLX_WHISPER_URL", "")
MLX_WHISPER_MODEL = os.getenv("MLX_WHISPER_MODEL", "mlx-community/whisper-large-v3-turbo")
MAX_AUDIO_BYTES = int(os.getenv("STT_MAX_AUDIO_BYTES", str(25 * 1024 * 1024))) # 25MB
_local_model = None
_local_lock = None
def _lazy_init_local():
"""Lazy-load mlx_whisper for native (non-Docker) execution."""
global _local_model, _local_lock
if _local_lock is not None:
return
import asyncio
_local_lock = asyncio.Lock()
async def _transcribe_local(audio_path: str, language: Optional[str]) -> Dict[str, Any]:
"""Transcribe using local mlx_whisper (Apple Silicon only)."""
_lazy_init_local()
async with _local_lock:
import mlx_whisper
kwargs: Dict[str, Any] = {"path_or_hf_repo": MLX_WHISPER_MODEL}
if language:
kwargs["language"] = language
result = mlx_whisper.transcribe(audio_path, **kwargs)
segments = []
for seg in result.get("segments", []):
segments.append({
"start": seg.get("start", 0),
"end": seg.get("end", 0),
"text": seg.get("text", ""),
})
return {
"text": result.get("text", ""),
"segments": segments,
"language": result.get("language", ""),
}
async def _transcribe_remote(audio_b64: str, language: Optional[str]) -> Dict[str, Any]:
"""Transcribe via MLX Whisper HTTP service on host."""
payload: Dict[str, Any] = {"audio_b64": audio_b64}
if language:
payload["language"] = language
async with httpx.AsyncClient(timeout=120) as c:
resp = await c.post(f"{MLX_WHISPER_URL}/transcribe", json=payload)
resp.raise_for_status()
return resp.json()
async def _resolve_audio(payload: Dict[str, Any]) -> tuple:
"""Return (audio_bytes, audio_b64) from payload."""
audio_b64 = payload.get("audio_b64", "")
audio_url = payload.get("audio_url", "")
if audio_b64:
raw = base64.b64decode(audio_b64)
if len(raw) > MAX_AUDIO_BYTES:
raise ValueError(f"Audio exceeds {MAX_AUDIO_BYTES} bytes")
return raw, audio_b64
if audio_url:
if audio_url.startswith(("file://", "/")):
path = audio_url.replace("file://", "")
with open(path, "rb") as f:
raw = f.read()
if len(raw) > MAX_AUDIO_BYTES:
raise ValueError(f"Audio exceeds {MAX_AUDIO_BYTES} bytes")
return raw, base64.b64encode(raw).decode()
async with httpx.AsyncClient(timeout=30) as c:
resp = await c.get(audio_url)
resp.raise_for_status()
raw = resp.content
if len(raw) > MAX_AUDIO_BYTES:
raise ValueError(f"Audio exceeds {MAX_AUDIO_BYTES} bytes")
return raw, base64.b64encode(raw).decode()
raise ValueError("Either audio_b64 or audio_url is required")
async def transcribe(payload: Dict[str, Any]) -> Dict[str, Any]:
"""Canonical STT entry point.
Payload:
audio_url: str (http/file) — OR —
audio_b64: str (base64 encoded)
language: str (optional, e.g. "uk", "en")
format: "text" | "segments" | "json"
"""
language = payload.get("language")
fmt = payload.get("format", "json")
audio_bytes, audio_b64 = await _resolve_audio(payload)
if MLX_WHISPER_URL:
result = await _transcribe_remote(audio_b64, language)
else:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
tmp.write(audio_bytes)
tmp_path = tmp.name
try:
result = await _transcribe_local(tmp_path, language)
finally:
os.unlink(tmp_path)
meta = {
"model": MLX_WHISPER_MODEL,
"provider": "mlx_whisper",
"device": "apple_silicon",
}
if fmt == "text":
return {"text": result.get("text", ""), "meta": meta, "provider": "mlx_whisper", "model": MLX_WHISPER_MODEL}
if fmt == "segments":
return {"text": result.get("text", ""), "segments": result.get("segments", []), "meta": meta, "provider": "mlx_whisper", "model": MLX_WHISPER_MODEL}
return {**result, "meta": meta, "provider": "mlx_whisper", "model": MLX_WHISPER_MODEL}

View File

@@ -0,0 +1,123 @@
"""MLX Kokoro TTS provider — generates speech via kokoro on host.
Runs inside Docker; delegates to Kokoro HTTP API on the host.
Falls back to local kokoro-onnx if running natively on Apple Silicon.
"""
import base64
import logging
import os
import tempfile
from typing import Any, Dict
import httpx
logger = logging.getLogger("provider.tts_mlx_kokoro")
MLX_KOKORO_URL = os.getenv("MLX_KOKORO_URL", "")
MLX_KOKORO_MODEL = os.getenv("MLX_KOKORO_MODEL", "kokoro-v1.0")
DEFAULT_VOICE = os.getenv("TTS_DEFAULT_VOICE", "af_heart")
MAX_TEXT_CHARS = int(os.getenv("TTS_MAX_TEXT_CHARS", "5000"))
DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", "24000"))
_local_pipeline = None
_local_lock = None
def _lazy_init_local():
global _local_lock
if _local_lock is not None:
return
import asyncio
_local_lock = asyncio.Lock()
async def _synthesize_local(text: str, voice: str, sample_rate: int) -> bytes:
"""Synthesize via local kokoro (Apple Silicon)."""
_lazy_init_local()
global _local_pipeline
async with _local_lock:
if _local_pipeline is None:
from kokoro import KPipeline
_local_pipeline = KPipeline(lang_code="a")
logger.info(f"Kokoro pipeline initialized: voice={voice}")
import soundfile as sf
import io
generator = _local_pipeline(text, voice=voice)
all_audio = []
for _, _, audio in generator:
all_audio.append(audio)
if not all_audio:
raise RuntimeError("Kokoro produced no audio")
import numpy as np
combined = np.concatenate(all_audio)
buf = io.BytesIO()
sf.write(buf, combined, sample_rate, format="WAV")
return buf.getvalue()
async def _synthesize_remote(text: str, voice: str, fmt: str, sample_rate: int) -> Dict[str, Any]:
"""Synthesize via Kokoro HTTP service on host."""
payload = {
"text": text,
"voice": voice,
"format": fmt,
"sample_rate": sample_rate,
}
async with httpx.AsyncClient(timeout=120) as c:
resp = await c.post(f"{MLX_KOKORO_URL}/synthesize", json=payload)
resp.raise_for_status()
return resp.json()
async def synthesize(payload: Dict[str, Any]) -> Dict[str, Any]:
"""Canonical TTS entry point.
Payload:
text: str (required)
voice: str (optional, default "af_heart")
format: "wav" | "mp3" (default "wav")
sample_rate: int (default 24000)
"""
text = payload.get("text", "")
if not text:
raise ValueError("text is required")
if len(text) > MAX_TEXT_CHARS:
raise ValueError(f"Text exceeds {MAX_TEXT_CHARS} chars")
voice = payload.get("voice", DEFAULT_VOICE)
fmt = payload.get("format", "wav")
sample_rate = payload.get("sample_rate", DEFAULT_SAMPLE_RATE)
meta = {
"model": MLX_KOKORO_MODEL,
"provider": "mlx_kokoro",
"voice": voice,
"device": "apple_silicon",
}
if MLX_KOKORO_URL:
result = await _synthesize_remote(text, voice, fmt, sample_rate)
return {
"audio_b64": result.get("audio_b64", ""),
"audio_url": result.get("audio_url", ""),
"format": fmt,
"meta": meta,
"provider": "mlx_kokoro",
"model": MLX_KOKORO_MODEL,
}
wav_bytes = await _synthesize_local(text, voice, sample_rate)
audio_b64 = base64.b64encode(wav_bytes).decode()
return {
"audio_b64": audio_b64,
"format": "wav",
"meta": meta,
"provider": "mlx_kokoro",
"model": MLX_KOKORO_MODEL,
}

View File

@@ -10,6 +10,7 @@ import config
from models import JobRequest, JobResponse, JobError
from idempotency import IdempotencyStore
from providers import ollama, ollama_vision
from providers import stt_mlx_whisper, tts_mlx_kokoro
import fabric_metrics as fm
logger = logging.getLogger("node-worker")
@@ -34,6 +35,7 @@ async def start(nats_client):
f"node.{nid}.stt.request",
f"node.{nid}.tts.request",
f"node.{nid}.image.request",
f"node.{nid}.ocr.request",
]
for subj in subjects:
await nats_client.subscribe(subj, cb=_handle_request)
@@ -175,14 +177,52 @@ async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
),
timeout=timeout_s,
)
elif job.required_type in ("stt", "tts", "image"):
elif job.required_type == "stt":
if config.STT_PROVIDER == "none":
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="error",
error=JobError(code="NOT_AVAILABLE", message="STT not configured on this node"),
)
result = await asyncio.wait_for(
stt_mlx_whisper.transcribe(payload), timeout=timeout_s,
)
elif job.required_type == "tts":
if config.TTS_PROVIDER == "none":
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="error",
error=JobError(code="NOT_AVAILABLE", message="TTS not configured on this node"),
)
result = await asyncio.wait_for(
tts_mlx_kokoro.synthesize(payload), timeout=timeout_s,
)
elif job.required_type == "ocr":
if config.OCR_PROVIDER == "none":
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="error",
error=JobError(code="NOT_AVAILABLE", message="OCR not configured on this node"),
)
ocr_prompt = payload.get("prompt", "Extract all text from this image. Return JSON: {\"text\": \"...\", \"language\": \"...\"}")
result = await asyncio.wait_for(
ollama_vision.infer(
images=payload.get("images"),
prompt=ocr_prompt,
model=model or config.DEFAULT_VISION,
system="You are an OCR engine. Extract text precisely. Return valid JSON only.",
max_tokens=hints.get("max_tokens", 4096),
temperature=0.05,
timeout_s=timeout_s,
),
timeout=timeout_s,
)
result["provider"] = "vision_prompted_ocr"
elif job.required_type == "image":
return JobResponse(
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
status="error",
error=JobError(
code="NOT_YET_IMPLEMENTED",
message=f"{job.required_type} adapter coming soon; use direct runtime API for now",
),
error=JobError(code="NOT_YET_IMPLEMENTED", message="Image adapter pending P3.7"),
)
else:
return JobResponse(

View File

@@ -100,8 +100,8 @@ async def _discover_remote_nodes() -> List[Dict[str, Any]]:
sub = await _nats_client.subscribe(inbox)
try:
await _nats_client.publish_request(
"node.*.capabilities.get", inbox, b""
await _nats_client.publish(
CAPS_DISCOVERY_SUBJECT, b"", reply=inbox,
)
await _nats_client.flush()
@@ -183,6 +183,7 @@ async def get_global_capabilities(force: bool = False) -> Dict[str, Any]:
def _build_global_view() -> Dict[str, Any]:
"""Build a unified view from all cached node capabilities."""
all_served: List[Dict[str, Any]] = []
global_caps: Dict[str, Dict[str, Any]] = {}
for node_id, caps in _node_cache.items():
is_local = (node_id.lower() == LOCAL_NODE_ID.lower())
@@ -194,16 +195,27 @@ def _build_global_view() -> Dict[str, Any]:
"local": is_local,
"node_age_s": round(age, 1),
})
node_caps = caps.get("capabilities", {})
if node_caps:
global_caps[node_id] = {
k: v for k, v in node_caps.items() if k != "providers"
}
all_served.sort(key=lambda m: (0 if m.get("local") else 1, m.get("name", "")))
return {
"local_node": LOCAL_NODE_ID,
"nodes": {nid: {"node_id": nid, "served_count": len(c.get("served_models", [])),
"age_s": round(time.time() - _node_timestamps.get(nid, 0), 1)}
for nid, c in _node_cache.items()},
"nodes": {nid: {
"node_id": nid,
"served_count": len(c.get("served_models", [])),
"installed_count": c.get("installed_count", 0),
"capabilities": c.get("capabilities", {}),
"node_load": c.get("node_load", {}),
"age_s": round(time.time() - _node_timestamps.get(nid, 0), 1),
} for nid, c in _node_cache.items()},
"served_models": all_served,
"served_count": len(all_served),
"capabilities_by_node": global_caps,
"node_count": len(_node_cache),
"updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
}
@@ -214,6 +226,44 @@ def get_cached_global() -> Dict[str, Any]:
return _build_global_view()
async def require_fresh_caps(ttl: int = 30) -> Optional[Dict[str, Any]]:
"""Preflight: return global caps only if fresh enough.
Returns None if NCS data is stale beyond ttl — caller should use
safe fallback instead of making routing decisions on outdated info.
"""
if not _node_timestamps:
gcaps = await get_global_capabilities(force=True)
if not _node_timestamps:
return None
return gcaps
oldest = min(_node_timestamps.values())
if (time.time() - oldest) > ttl:
gcaps = await get_global_capabilities(force=True)
oldest = min(_node_timestamps.values()) if _node_timestamps else 0
if (time.time() - oldest) > ttl:
logger.warning("[preflight] caps stale after refresh, age=%ds", int(time.time() - oldest))
return None
return gcaps
return _build_global_view()
def find_nodes_with_capability(cap: str) -> List[str]:
"""Return node IDs that have a given capability enabled."""
result = []
for nid, caps in _node_cache.items():
node_caps = caps.get("capabilities", {})
if node_caps.get(cap, False):
result.append(nid)
return result
def get_node_load(node_id: str) -> Dict[str, Any]:
"""Get cached node_load for a specific node."""
caps = _node_cache.get(node_id, {})
return caps.get("node_load", {})
async def send_offload_request(
node_id: str,
request_type: str,

View File

@@ -1,5 +1,5 @@
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import Response
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel, ConfigDict
from typing import Literal, Optional, Dict, Any, List
import asyncio
@@ -3542,6 +3542,7 @@ async def documents_versions(doc_id: str, agent_id: str, limit: int = 20):
async def list_available_models():
"""List all available models from NCS (global capabilities pool)."""
models = []
caps_by_node = {}
try:
from global_capabilities_client import get_global_capabilities
@@ -3555,6 +3556,7 @@ async def list_available_models():
"size_gb": m.get("size_gb"),
"status": "served",
})
caps_by_node = pool.get("capabilities_by_node", {})
except Exception as e:
logger.warning(f"Cannot get NCS global models: {e}")
@@ -3572,7 +3574,110 @@ async def list_available_models():
except Exception as e:
logger.warning(f"Cannot get Ollama models: {e}")
return {"models": models, "total": len(models)}
return {
"models": models,
"total": len(models),
"capabilities_by_node": caps_by_node,
}
# ── Capability-based offload routing ────────────────────────────────────────
@app.post("/v1/capability/{cap_type}")
async def capability_offload(cap_type: str, request: Request):
"""Route a capability request (stt/tts/ocr/image) to the best node.
Router selects the node based on capabilities_by_node, circuit breaker,
and node_load — no static assumptions about which node has what.
"""
valid_types = {"stt", "tts", "ocr", "image"}
if cap_type not in valid_types:
return JSONResponse(status_code=400, content={
"error": f"Invalid capability type: {cap_type}. Valid: {sorted(valid_types)}",
})
if not NCS_AVAILABLE or not global_capabilities_client:
return JSONResponse(status_code=503, content={
"error": "NCS not available — cannot route capability requests",
})
gcaps = await global_capabilities_client.require_fresh_caps(ttl=30)
if gcaps is None:
return JSONResponse(status_code=503, content={
"error": "NCS caps stale — preflight failed, refusing to route",
})
eligible_nodes = global_capabilities_client.find_nodes_with_capability(cap_type)
if not eligible_nodes:
return JSONResponse(status_code=404, content={
"error": f"No node with capability '{cap_type}' available",
"capabilities_by_node": gcaps.get("capabilities_by_node", {}),
})
unavailable = offload_client.get_unavailable_nodes(cap_type) if offload_client else set()
available = [n for n in eligible_nodes if n.lower() not in {u.lower() for u in unavailable}]
if not available:
return JSONResponse(status_code=503, content={
"error": f"All nodes with '{cap_type}' are circuit-broken",
"eligible": eligible_nodes,
"unavailable": list(unavailable),
})
best_node = available[0]
if len(available) > 1:
loads = []
for nid in available:
nl = global_capabilities_client.get_node_load(nid)
score = nl.get("inflight", 0) * 10
if nl.get("mem_pressure") == "high":
score += 100
loads.append((score, nid))
loads.sort()
best_node = loads[0][1]
payload = await request.json()
logger.info(f"[cap.offload] type={cap_type} → node={best_node} (of {available})")
nats_ok = nc is not None and nats_available
if nats_ok and offload_client:
import uuid as _uuid
job = {
"job_id": str(_uuid.uuid4()),
"required_type": cap_type,
"payload": payload,
"deadline_ts": int(time.time() * 1000) + 60000,
"hints": payload.pop("hints", {}),
}
result = await offload_client.offload_infer(
nats_client=nc, node_id=best_node, required_type=cap_type,
job_payload=job, timeout_ms=60000,
)
if result and result.get("status") == "ok":
return JSONResponse(content=result.get("result", result))
error = result.get("error", {}) if result else {}
return JSONResponse(status_code=502, content={
"error": error.get("message", f"Offload to {best_node} failed"),
"code": error.get("code", "OFFLOAD_FAILED"),
"node": best_node,
})
return JSONResponse(status_code=503, content={
"error": "NATS not connected — cannot offload",
})
@app.get("/v1/capabilities")
async def list_global_capabilities():
"""Return full capabilities view across all nodes."""
if not NCS_AVAILABLE or not global_capabilities_client:
return JSONResponse(status_code=503, content={"error": "NCS not available"})
gcaps = await global_capabilities_client.get_global_capabilities()
return JSONResponse(content={
"node_count": gcaps.get("node_count", 0),
"nodes": gcaps.get("nodes", {}),
"capabilities_by_node": gcaps.get("capabilities_by_node", {}),
"served_count": gcaps.get("served_count", 0),
})
@app.get("/v1/agromatrix/shared-memory/pending")

View File

@@ -81,7 +81,7 @@ def get_unavailable_nodes(req_type: str) -> Set[str]:
async def offload_infer(
nats_client,
node_id: str,
required_type: Literal["llm", "vision", "stt", "tts"],
required_type: Literal["llm", "vision", "stt", "tts", "ocr", "image"],
job_payload: Dict[str, Any],
timeout_ms: int = 25000,
) -> Optional[Dict[str, Any]]: