P3.5-P3.7: 2-layer inventory, capability routing, STT/TTS adapters, Dev Contract
NCS:
- _collect_worker_caps() fetches capability flags from node-worker /caps
- _derive_capabilities() merges served model types + worker provider flags
- installed_artifacts replaces inventory_only (disk scan with DISK_SCAN_PATHS env)
- New endpoints: /capabilities/caps, /capabilities/installed
Node Worker:
- STT_PROVIDER, TTS_PROVIDER, OCR_PROVIDER, IMAGE_PROVIDER env flags
- /caps endpoint returns capabilities + providers for NCS aggregation
- STT adapter (providers/stt_mlx_whisper.py) — remote + local mode
- TTS adapter (providers/tts_mlx_kokoro.py) — remote + local mode
- OCR handler via vision_prompted (ollama_vision with OCR prompt)
- NATS subjects: node.{id}.stt/tts/ocr/image.request
Router:
- POST /v1/capability/{stt,tts,ocr,image} — capability-based offload routing
- GET /v1/capabilities — global view with capabilities_by_node
- require_fresh_caps(ttl) preflight guard
- find_nodes_with_capability(cap) + load-based node selection
Ops:
- ops/fabric_snapshot.py — full runtime snapshot collector
- ops/fabric_preflight.sh — quick check + snapshot save + diff
- docs/fabric_contract.md — Dev Contract v0.1 (preflight-first)
- tests/test_fabric_contract.py — CI enforcement (6 tests)
Made-with: Cursor
This commit is contained in:
@@ -20,8 +20,10 @@ app = FastAPI(title="Node Capabilities Service", version="1.0.0")
|
||||
|
||||
NODE_ID = os.getenv("NODE_ID", "noda2")
|
||||
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://host.docker.internal:11434")
|
||||
SWAPPER_URL = os.getenv("SWAPPER_URL", "") # empty = skip Swapper probing
|
||||
SWAPPER_URL = os.getenv("SWAPPER_URL", "")
|
||||
LLAMA_SERVER_URL = os.getenv("LLAMA_SERVER_URL", "")
|
||||
NODE_WORKER_URL = os.getenv("NODE_WORKER_URL", "http://node-worker:8109")
|
||||
DISK_SCAN_PATHS = os.getenv("DISK_SCAN_PATHS", "") # comma-sep extra dirs
|
||||
|
||||
_cache: Dict[str, Any] = {}
|
||||
_cache_ts: float = 0
|
||||
@@ -129,30 +131,48 @@ async def _collect_llama_server() -> Optional[Dict[str, Any]]:
|
||||
return runtime
|
||||
|
||||
|
||||
async def _collect_worker_caps() -> Dict[str, Any]:
    """Query the local Node Worker's /caps endpoint for capability flags.

    Returns an empty-structure fallback when the worker is unreachable,
    errors out, or replies with a non-200 status, so callers never have
    to None-check the result.
    """
    fallback: Dict[str, Any] = {"capabilities": {}, "providers": {}, "defaults": {}}
    try:
        async with httpx.AsyncClient(timeout=3) as client:
            resp = await client.get(f"{NODE_WORKER_URL}/caps")
            if resp.status_code == 200:
                return resp.json()
    except Exception as exc:
        # Worker may simply not be deployed on this node — debug, not error.
        logger.debug(f"Worker caps unavailable: {exc}")
    return fallback
|
||||
|
||||
|
||||
def _collect_disk_inventory() -> List[Dict[str, Any]]:
|
||||
"""Scan known model directories — NOT for routing, only inventory."""
|
||||
import pathlib
|
||||
inventory: List[Dict[str, Any]] = []
|
||||
|
||||
scan_dirs = [
|
||||
("cursor_worktrees", pathlib.Path.home() / ".cursor" / "worktrees"),
|
||||
("jan_ai", pathlib.Path.home() / "Library" / "Application Support" / "Jan"),
|
||||
("ollama", pathlib.Path.home() / ".ollama" / "models"),
|
||||
("hf_cache", pathlib.Path.home() / ".cache" / "huggingface" / "hub"),
|
||||
("comfyui_main", pathlib.Path.home() / "ComfyUI" / "models"),
|
||||
("comfyui_docs", pathlib.Path.home() / "Documents" / "ComfyUI" / "models"),
|
||||
("llama_cpp", pathlib.Path.home() / "Library" / "Application Support" / "llama.cpp" / "models"),
|
||||
("hf_models", pathlib.Path.home() / "hf_models"),
|
||||
("jan_ai", pathlib.Path.home() / "Library" / "Application Support" / "Jan"),
|
||||
]
|
||||
if DISK_SCAN_PATHS:
|
||||
for p in DISK_SCAN_PATHS.split(","):
|
||||
p = p.strip()
|
||||
if p:
|
||||
scan_dirs.append(("custom", pathlib.Path(p)))
|
||||
|
||||
for source, base in scan_dirs:
|
||||
if not base.exists():
|
||||
continue
|
||||
try:
|
||||
for f in base.rglob("*"):
|
||||
if f.suffix in (".gguf", ".safetensors", ".bin", ".pt") and f.stat().st_size > 100_000_000:
|
||||
if f.suffix in (".gguf", ".safetensors", ".bin", ".pt", ".mlx") and f.stat().st_size > 50_000_000:
|
||||
inventory.append({
|
||||
"name": f.stem,
|
||||
"path": str(f.relative_to(pathlib.Path.home())),
|
||||
"path": str(f),
|
||||
"source": source,
|
||||
"size_gb": round(f.stat().st_size / 1e9, 1),
|
||||
"type": _classify_model(f.stem),
|
||||
@@ -191,6 +211,24 @@ def _build_served_models(ollama: Dict, swapper: Dict, llama: Optional[Dict]) ->
|
||||
return served
|
||||
|
||||
|
||||
def _derive_capabilities(served: List[Dict], worker_caps: Dict) -> Dict[str, Any]:
|
||||
"""Merge served model types + worker provider flags into capability map."""
|
||||
served_types = {m.get("type", "llm") for m in served}
|
||||
wc = worker_caps.get("capabilities", {})
|
||||
wp = worker_caps.get("providers", {})
|
||||
|
||||
has_vision = "vision" in served_types or wc.get("vision", False)
|
||||
return {
|
||||
"llm": "llm" in served_types or "code" in served_types,
|
||||
"vision": has_vision,
|
||||
"stt": wc.get("stt", False),
|
||||
"tts": wc.get("tts", False),
|
||||
"ocr": has_vision and wp.get("ocr", "none") != "none",
|
||||
"image": wc.get("image", False),
|
||||
"providers": wp,
|
||||
}
|
||||
|
||||
|
||||
async def _build_capabilities() -> Dict[str, Any]:
|
||||
global _cache, _cache_ts
|
||||
|
||||
@@ -200,6 +238,7 @@ async def _build_capabilities() -> Dict[str, Any]:
|
||||
ollama = await _collect_ollama()
|
||||
swapper = await _collect_swapper()
|
||||
llama = await _collect_llama_server()
|
||||
worker_caps = await _collect_worker_caps()
|
||||
disk = _collect_disk_inventory()
|
||||
served = _build_served_models(ollama, swapper, llama)
|
||||
|
||||
@@ -209,6 +248,7 @@ async def _build_capabilities() -> Dict[str, Any]:
|
||||
|
||||
node_load = await build_node_load()
|
||||
runtime_load = await build_runtime_load(runtimes)
|
||||
capabilities = _derive_capabilities(served, worker_caps)
|
||||
|
||||
result = {
|
||||
"node_id": NODE_ID,
|
||||
@@ -216,10 +256,12 @@ async def _build_capabilities() -> Dict[str, Any]:
|
||||
"runtimes": runtimes,
|
||||
"served_models": served,
|
||||
"served_count": len(served),
|
||||
"capabilities": capabilities,
|
||||
"node_load": node_load,
|
||||
"runtime_load": runtime_load,
|
||||
"inventory_only": disk,
|
||||
"inventory_count": len(disk),
|
||||
"installed_artifacts": disk,
|
||||
"installed_count": len(disk),
|
||||
"worker": worker_caps,
|
||||
}
|
||||
|
||||
_cache = result
|
||||
@@ -245,6 +287,26 @@ async def capabilities_models():
|
||||
return JSONResponse(content={"node_id": data["node_id"], "served_models": data["served_models"]})
|
||||
|
||||
|
||||
@app.get("/capabilities/caps")
async def capabilities_caps():
    """Slim view: this node's capability flags plus the raw worker payload."""
    snapshot = await _build_capabilities()
    body = {
        "node_id": snapshot["node_id"],
        "capabilities": snapshot.get("capabilities", {}),
        "worker": snapshot.get("worker", {}),
    }
    return JSONResponse(content=body)
|
||||
|
||||
|
||||
@app.get("/capabilities/installed")
async def capabilities_installed():
    """Disk-scanned artifact inventory (informational; never used for routing)."""
    snapshot = await _build_capabilities()
    body = {
        "node_id": snapshot["node_id"],
        "installed_artifacts": snapshot.get("installed_artifacts", []),
        "installed_count": snapshot.get("installed_count", 0),
    }
    return JSONResponse(content=body)
|
||||
|
||||
|
||||
@app.post("/capabilities/refresh")
|
||||
async def capabilities_refresh():
|
||||
global _cache_ts
|
||||
|
||||
@@ -9,3 +9,8 @@ DEFAULT_VISION = os.getenv("NODE_DEFAULT_VISION", "llava:13b")
|
||||
MAX_CONCURRENCY = int(os.getenv("NODE_WORKER_MAX_CONCURRENCY", "2"))
|
||||
MAX_PAYLOAD_BYTES = int(os.getenv("NODE_WORKER_MAX_PAYLOAD_BYTES", str(1024 * 1024)))
|
||||
PORT = int(os.getenv("PORT", "8109"))
|
||||
|
||||
STT_PROVIDER = os.getenv("STT_PROVIDER", "none")
|
||||
TTS_PROVIDER = os.getenv("TTS_PROVIDER", "none")
|
||||
OCR_PROVIDER = os.getenv("OCR_PROVIDER", "vision_prompted")
|
||||
IMAGE_PROVIDER = os.getenv("IMAGE_PROVIDER", "none")
|
||||
|
||||
@@ -41,6 +41,33 @@ async def prom_metrics():
|
||||
return {"error": "prometheus_client not installed"}
|
||||
|
||||
|
||||
@app.get("/caps")
async def caps():
    """Capability flags for NCS to aggregate."""
    provider_map = {
        "stt": config.STT_PROVIDER,
        "tts": config.TTS_PROVIDER,
        "ocr": config.OCR_PROVIDER,
        "image": config.IMAGE_PROVIDER,
    }
    # LLM/vision are always on for a worker; the rest are on iff a provider
    # other than the "none" sentinel is configured via env.
    capability_flags = {"llm": True, "vision": True}
    for name, provider in provider_map.items():
        capability_flags[name] = provider != "none"
    return {
        "node_id": config.NODE_ID,
        "capabilities": capability_flags,
        "providers": provider_map,
        "defaults": {
            "llm": config.DEFAULT_LLM,
            "vision": config.DEFAULT_VISION,
        },
        "concurrency": config.MAX_CONCURRENCY,
    }
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
global _nats_client
|
||||
|
||||
@@ -14,7 +14,7 @@ class JobRequest(BaseModel):
|
||||
trace_id: str = ""
|
||||
actor_agent_id: str = ""
|
||||
target_agent_id: str = ""
|
||||
required_type: Literal["llm", "vision", "stt", "tts"] = "llm"
|
||||
required_type: Literal["llm", "vision", "stt", "tts", "image", "ocr"] = "llm"
|
||||
deadline_ts: int = 0
|
||||
idempotency_key: str = ""
|
||||
payload: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
135
services/node-worker/providers/stt_mlx_whisper.py
Normal file
135
services/node-worker/providers/stt_mlx_whisper.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""MLX Whisper STT provider — transcribes audio via mlx-whisper on host.
|
||||
|
||||
Runs inside Docker; delegates to MLX Whisper HTTP API on the host.
|
||||
If MLX_WHISPER_URL is not set, falls back to running mlx_whisper directly
|
||||
(only works when node-worker runs natively, not in Docker).
|
||||
"""
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger("provider.stt_mlx_whisper")
|
||||
|
||||
MLX_WHISPER_URL = os.getenv("MLX_WHISPER_URL", "")
|
||||
MLX_WHISPER_MODEL = os.getenv("MLX_WHISPER_MODEL", "mlx-community/whisper-large-v3-turbo")
|
||||
MAX_AUDIO_BYTES = int(os.getenv("STT_MAX_AUDIO_BYTES", str(25 * 1024 * 1024))) # 25MB
|
||||
|
||||
_local_model = None
|
||||
_local_lock = None
|
||||
|
||||
|
||||
def _lazy_init_local():
    """Create the asyncio lock guarding local (in-process) transcription.

    Deferred to first use so importing this module does not require an event
    loop.  Idempotent: repeated calls keep the existing lock.  Only
    ``_local_lock`` is bound here, so only it is declared global (the original
    also declared the never-assigned ``_local_model``).
    """
    global _local_lock
    if _local_lock is not None:
        return
    import asyncio
    _local_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def _transcribe_local(audio_path: str, language: Optional[str]) -> Dict[str, Any]:
    """Transcribe using local mlx_whisper (Apple Silicon only).

    ``mlx_whisper.transcribe`` is a blocking, compute-heavy call; running it
    directly would stall the event loop for the whole transcription.  It is
    therefore pushed onto the default executor, while ``_local_lock`` still
    serializes transcriptions (one model run at a time).

    Args:
        audio_path: Path to an audio file on local disk.
        language: Optional language hint (e.g. "uk", "en"); autodetect if None.

    Returns:
        Dict with "text", "segments" (start/end/text dicts) and "language".
    """
    _lazy_init_local()
    async with _local_lock:
        import asyncio
        import mlx_whisper

        kwargs: Dict[str, Any] = {"path_or_hf_repo": MLX_WHISPER_MODEL}
        if language:
            kwargs["language"] = language

        # Fix: run the blocking transcribe off the event loop.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, lambda: mlx_whisper.transcribe(audio_path, **kwargs)
        )

        segments = [
            {
                "start": seg.get("start", 0),
                "end": seg.get("end", 0),
                "text": seg.get("text", ""),
            }
            for seg in result.get("segments", [])
        ]
        return {
            "text": result.get("text", ""),
            "segments": segments,
            "language": result.get("language", ""),
        }
|
||||
|
||||
|
||||
async def _transcribe_remote(audio_b64: str, language: Optional[str]) -> Dict[str, Any]:
    """Transcribe via the MLX Whisper HTTP service running on the host."""
    body: Dict[str, Any] = {"audio_b64": audio_b64}
    if language:
        body["language"] = language

    # Generous timeout: large audio can take a while to transcribe.
    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(f"{MLX_WHISPER_URL}/transcribe", json=body)
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
async def _resolve_audio(payload: Dict[str, Any]) -> tuple:
|
||||
"""Return (audio_bytes, audio_b64) from payload."""
|
||||
audio_b64 = payload.get("audio_b64", "")
|
||||
audio_url = payload.get("audio_url", "")
|
||||
|
||||
if audio_b64:
|
||||
raw = base64.b64decode(audio_b64)
|
||||
if len(raw) > MAX_AUDIO_BYTES:
|
||||
raise ValueError(f"Audio exceeds {MAX_AUDIO_BYTES} bytes")
|
||||
return raw, audio_b64
|
||||
|
||||
if audio_url:
|
||||
if audio_url.startswith(("file://", "/")):
|
||||
path = audio_url.replace("file://", "")
|
||||
with open(path, "rb") as f:
|
||||
raw = f.read()
|
||||
if len(raw) > MAX_AUDIO_BYTES:
|
||||
raise ValueError(f"Audio exceeds {MAX_AUDIO_BYTES} bytes")
|
||||
return raw, base64.b64encode(raw).decode()
|
||||
|
||||
async with httpx.AsyncClient(timeout=30) as c:
|
||||
resp = await c.get(audio_url)
|
||||
resp.raise_for_status()
|
||||
raw = resp.content
|
||||
if len(raw) > MAX_AUDIO_BYTES:
|
||||
raise ValueError(f"Audio exceeds {MAX_AUDIO_BYTES} bytes")
|
||||
return raw, base64.b64encode(raw).decode()
|
||||
|
||||
raise ValueError("Either audio_b64 or audio_url is required")
|
||||
|
||||
|
||||
async def transcribe(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Canonical STT entry point.

    Payload:
        audio_url: str (http/file) — OR —
        audio_b64: str (base64 encoded)
        language: str (optional, e.g. "uk", "en")
        format: "text" | "segments" | "json"
    """
    language = payload.get("language")
    out_format = payload.get("format", "json")

    audio_bytes, audio_b64 = await _resolve_audio(payload)

    if MLX_WHISPER_URL:
        # Remote mode: hand the base64 audio to the host-side HTTP service.
        result = await _transcribe_remote(audio_b64, language)
    else:
        # Local mode: mlx_whisper wants a file path, so spill to a temp file.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            result = await _transcribe_local(tmp_path, language)
        finally:
            os.unlink(tmp_path)

    meta = {
        "model": MLX_WHISPER_MODEL,
        "provider": "mlx_whisper",
        "device": "apple_silicon",
    }
    # Shared trailer keys attached to every response shape.
    common = {"meta": meta, "provider": "mlx_whisper", "model": MLX_WHISPER_MODEL}

    if out_format == "text":
        return {"text": result.get("text", ""), **common}
    if out_format == "segments":
        return {"text": result.get("text", ""), "segments": result.get("segments", []), **common}
    return {**result, **common}
|
||||
123
services/node-worker/providers/tts_mlx_kokoro.py
Normal file
123
services/node-worker/providers/tts_mlx_kokoro.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""MLX Kokoro TTS provider — generates speech via kokoro on host.
|
||||
|
||||
Runs inside Docker; delegates to Kokoro HTTP API on the host.
|
||||
Falls back to local kokoro-onnx if running natively on Apple Silicon.
|
||||
"""
|
||||
import base64
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Any, Dict
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger("provider.tts_mlx_kokoro")
|
||||
|
||||
MLX_KOKORO_URL = os.getenv("MLX_KOKORO_URL", "")
|
||||
MLX_KOKORO_MODEL = os.getenv("MLX_KOKORO_MODEL", "kokoro-v1.0")
|
||||
DEFAULT_VOICE = os.getenv("TTS_DEFAULT_VOICE", "af_heart")
|
||||
MAX_TEXT_CHARS = int(os.getenv("TTS_MAX_TEXT_CHARS", "5000"))
|
||||
DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", "24000"))
|
||||
|
||||
_local_pipeline = None
|
||||
_local_lock = None
|
||||
|
||||
|
||||
def _lazy_init_local():
    """Create the asyncio lock for local synthesis on first use (idempotent)."""
    global _local_lock
    if _local_lock is None:
        import asyncio
        _local_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def _synthesize_local(text: str, voice: str, sample_rate: int) -> bytes:
    """Synthesize speech via local kokoro (Apple Silicon); returns WAV bytes.

    The lock both guards one-time pipeline construction and serializes
    synthesis runs — the kokoro pipeline is not assumed concurrency-safe.
    """
    _lazy_init_local()
    global _local_pipeline
    async with _local_lock:
        if _local_pipeline is None:
            from kokoro import KPipeline
            # "a" = American English pipeline; voice is chosen per call.
            _local_pipeline = KPipeline(lang_code="a")
            logger.info(f"Kokoro pipeline initialized: voice={voice}")

        import soundfile as sf
        import io

        chunks = [audio for _, _, audio in _local_pipeline(text, voice=voice)]
        if not chunks:
            raise RuntimeError("Kokoro produced no audio")

        import numpy as np
        waveform = np.concatenate(chunks)

        out = io.BytesIO()
        sf.write(out, waveform, sample_rate, format="WAV")
        return out.getvalue()
|
||||
|
||||
|
||||
async def _synthesize_remote(text: str, voice: str, fmt: str, sample_rate: int) -> Dict[str, Any]:
    """Synthesize via the Kokoro HTTP service running on the host."""
    body = {
        "text": text,
        "voice": voice,
        "format": fmt,
        "sample_rate": sample_rate,
    }
    # Generous timeout: long texts can take a while to synthesize.
    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(f"{MLX_KOKORO_URL}/synthesize", json=body)
        response.raise_for_status()
        return response.json()
|
||||
|
||||
|
||||
async def synthesize(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Canonical TTS entry point.

    Payload:
        text: str (required)
        voice: str (optional, default "af_heart")
        format: "wav" | "mp3" (default "wav")
        sample_rate: int (default 24000)

    Raises:
        ValueError: when text is missing or longer than MAX_TEXT_CHARS.
    """
    text = payload.get("text", "")
    if not text:
        raise ValueError("text is required")
    if len(text) > MAX_TEXT_CHARS:
        raise ValueError(f"Text exceeds {MAX_TEXT_CHARS} chars")

    voice = payload.get("voice", DEFAULT_VOICE)
    out_format = payload.get("format", "wav")
    rate = payload.get("sample_rate", DEFAULT_SAMPLE_RATE)

    meta = {
        "model": MLX_KOKORO_MODEL,
        "provider": "mlx_kokoro",
        "voice": voice,
        "device": "apple_silicon",
    }

    if MLX_KOKORO_URL:
        # Remote mode: host service returns inline audio and/or a URL.
        remote = await _synthesize_remote(text, voice, out_format, rate)
        return {
            "audio_b64": remote.get("audio_b64", ""),
            "audio_url": remote.get("audio_url", ""),
            "format": out_format,
            "meta": meta,
            "provider": "mlx_kokoro",
            "model": MLX_KOKORO_MODEL,
        }

    # Local mode always emits WAV regardless of the requested format.
    wav = await _synthesize_local(text, voice, rate)
    return {
        "audio_b64": base64.b64encode(wav).decode(),
        "format": "wav",
        "meta": meta,
        "provider": "mlx_kokoro",
        "model": MLX_KOKORO_MODEL,
    }
|
||||
@@ -10,6 +10,7 @@ import config
|
||||
from models import JobRequest, JobResponse, JobError
|
||||
from idempotency import IdempotencyStore
|
||||
from providers import ollama, ollama_vision
|
||||
from providers import stt_mlx_whisper, tts_mlx_kokoro
|
||||
import fabric_metrics as fm
|
||||
|
||||
logger = logging.getLogger("node-worker")
|
||||
@@ -34,6 +35,7 @@ async def start(nats_client):
|
||||
f"node.{nid}.stt.request",
|
||||
f"node.{nid}.tts.request",
|
||||
f"node.{nid}.image.request",
|
||||
f"node.{nid}.ocr.request",
|
||||
]
|
||||
for subj in subjects:
|
||||
await nats_client.subscribe(subj, cb=_handle_request)
|
||||
@@ -175,14 +177,52 @@ async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
elif job.required_type in ("stt", "tts", "image"):
|
||||
elif job.required_type == "stt":
|
||||
if config.STT_PROVIDER == "none":
|
||||
return JobResponse(
|
||||
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
||||
status="error",
|
||||
error=JobError(code="NOT_AVAILABLE", message="STT not configured on this node"),
|
||||
)
|
||||
result = await asyncio.wait_for(
|
||||
stt_mlx_whisper.transcribe(payload), timeout=timeout_s,
|
||||
)
|
||||
elif job.required_type == "tts":
|
||||
if config.TTS_PROVIDER == "none":
|
||||
return JobResponse(
|
||||
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
||||
status="error",
|
||||
error=JobError(code="NOT_AVAILABLE", message="TTS not configured on this node"),
|
||||
)
|
||||
result = await asyncio.wait_for(
|
||||
tts_mlx_kokoro.synthesize(payload), timeout=timeout_s,
|
||||
)
|
||||
elif job.required_type == "ocr":
|
||||
if config.OCR_PROVIDER == "none":
|
||||
return JobResponse(
|
||||
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
||||
status="error",
|
||||
error=JobError(code="NOT_AVAILABLE", message="OCR not configured on this node"),
|
||||
)
|
||||
ocr_prompt = payload.get("prompt", "Extract all text from this image. Return JSON: {\"text\": \"...\", \"language\": \"...\"}")
|
||||
result = await asyncio.wait_for(
|
||||
ollama_vision.infer(
|
||||
images=payload.get("images"),
|
||||
prompt=ocr_prompt,
|
||||
model=model or config.DEFAULT_VISION,
|
||||
system="You are an OCR engine. Extract text precisely. Return valid JSON only.",
|
||||
max_tokens=hints.get("max_tokens", 4096),
|
||||
temperature=0.05,
|
||||
timeout_s=timeout_s,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
result["provider"] = "vision_prompted_ocr"
|
||||
elif job.required_type == "image":
|
||||
return JobResponse(
|
||||
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
||||
status="error",
|
||||
error=JobError(
|
||||
code="NOT_YET_IMPLEMENTED",
|
||||
message=f"{job.required_type} adapter coming soon; use direct runtime API for now",
|
||||
),
|
||||
error=JobError(code="NOT_YET_IMPLEMENTED", message="Image adapter pending P3.7"),
|
||||
)
|
||||
else:
|
||||
return JobResponse(
|
||||
|
||||
@@ -100,8 +100,8 @@ async def _discover_remote_nodes() -> List[Dict[str, Any]]:
|
||||
sub = await _nats_client.subscribe(inbox)
|
||||
|
||||
try:
|
||||
await _nats_client.publish_request(
|
||||
"node.*.capabilities.get", inbox, b""
|
||||
await _nats_client.publish(
|
||||
CAPS_DISCOVERY_SUBJECT, b"", reply=inbox,
|
||||
)
|
||||
await _nats_client.flush()
|
||||
|
||||
@@ -183,6 +183,7 @@ async def get_global_capabilities(force: bool = False) -> Dict[str, Any]:
|
||||
def _build_global_view() -> Dict[str, Any]:
|
||||
"""Build a unified view from all cached node capabilities."""
|
||||
all_served: List[Dict[str, Any]] = []
|
||||
global_caps: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
for node_id, caps in _node_cache.items():
|
||||
is_local = (node_id.lower() == LOCAL_NODE_ID.lower())
|
||||
@@ -194,16 +195,27 @@ def _build_global_view() -> Dict[str, Any]:
|
||||
"local": is_local,
|
||||
"node_age_s": round(age, 1),
|
||||
})
|
||||
node_caps = caps.get("capabilities", {})
|
||||
if node_caps:
|
||||
global_caps[node_id] = {
|
||||
k: v for k, v in node_caps.items() if k != "providers"
|
||||
}
|
||||
|
||||
all_served.sort(key=lambda m: (0 if m.get("local") else 1, m.get("name", "")))
|
||||
|
||||
return {
|
||||
"local_node": LOCAL_NODE_ID,
|
||||
"nodes": {nid: {"node_id": nid, "served_count": len(c.get("served_models", [])),
|
||||
"age_s": round(time.time() - _node_timestamps.get(nid, 0), 1)}
|
||||
for nid, c in _node_cache.items()},
|
||||
"nodes": {nid: {
|
||||
"node_id": nid,
|
||||
"served_count": len(c.get("served_models", [])),
|
||||
"installed_count": c.get("installed_count", 0),
|
||||
"capabilities": c.get("capabilities", {}),
|
||||
"node_load": c.get("node_load", {}),
|
||||
"age_s": round(time.time() - _node_timestamps.get(nid, 0), 1),
|
||||
} for nid, c in _node_cache.items()},
|
||||
"served_models": all_served,
|
||||
"served_count": len(all_served),
|
||||
"capabilities_by_node": global_caps,
|
||||
"node_count": len(_node_cache),
|
||||
"updated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
}
|
||||
@@ -214,6 +226,44 @@ def get_cached_global() -> Dict[str, Any]:
|
||||
return _build_global_view()
|
||||
|
||||
|
||||
async def require_fresh_caps(ttl: int = 30) -> Optional[Dict[str, Any]]:
    """Preflight: return global caps only if fresh enough.

    Returns None if NCS data is stale beyond ttl — caller should use
    safe fallback instead of making routing decisions on outdated info.
    Freshness is measured against the *stalest* cached node entry.
    """
    if not _node_timestamps:
        # First contact: force a refresh; accept whatever discovery produced,
        # or bail out if it still found nothing.
        refreshed = await get_global_capabilities(force=True)
        return refreshed if _node_timestamps else None

    stalest = min(_node_timestamps.values())
    if time.time() - stalest <= ttl:
        return _build_global_view()

    # Stale: force one refresh, then re-check before giving up.
    refreshed = await get_global_capabilities(force=True)
    stalest = min(_node_timestamps.values()) if _node_timestamps else 0
    if (time.time() - stalest) > ttl:
        logger.warning("[preflight] caps stale after refresh, age=%ds", int(time.time() - stalest))
        return None
    return refreshed
|
||||
|
||||
|
||||
def find_nodes_with_capability(cap: str) -> List[str]:
    """Return node IDs that have a given capability enabled."""
    return [
        node_id
        for node_id, caps in _node_cache.items()
        if caps.get("capabilities", {}).get(cap, False)
    ]
|
||||
|
||||
|
||||
def get_node_load(node_id: str) -> Dict[str, Any]:
    """Get cached node_load for a specific node (empty dict when unknown)."""
    return _node_cache.get(node_id, {}).get("node_load", {})
|
||||
|
||||
|
||||
async def send_offload_request(
|
||||
node_id: str,
|
||||
request_type: str,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import Response
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Literal, Optional, Dict, Any, List
|
||||
import asyncio
|
||||
@@ -3542,6 +3542,7 @@ async def documents_versions(doc_id: str, agent_id: str, limit: int = 20):
|
||||
async def list_available_models():
|
||||
"""List all available models from NCS (global capabilities pool)."""
|
||||
models = []
|
||||
caps_by_node = {}
|
||||
|
||||
try:
|
||||
from global_capabilities_client import get_global_capabilities
|
||||
@@ -3555,6 +3556,7 @@ async def list_available_models():
|
||||
"size_gb": m.get("size_gb"),
|
||||
"status": "served",
|
||||
})
|
||||
caps_by_node = pool.get("capabilities_by_node", {})
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot get NCS global models: {e}")
|
||||
|
||||
@@ -3572,7 +3574,110 @@ async def list_available_models():
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot get Ollama models: {e}")
|
||||
|
||||
return {"models": models, "total": len(models)}
|
||||
return {
|
||||
"models": models,
|
||||
"total": len(models),
|
||||
"capabilities_by_node": caps_by_node,
|
||||
}
|
||||
|
||||
|
||||
# ── Capability-based offload routing ────────────────────────────────────────
|
||||
|
||||
@app.post("/v1/capability/{cap_type}")
async def capability_offload(cap_type: str, request: Request):
    """Route a capability request (stt/tts/ocr/image) to the best node.

    Router selects the node based on capabilities_by_node, circuit breaker,
    and node_load — no static assumptions about which node has what.
    """
    supported = {"stt", "tts", "ocr", "image"}
    if cap_type not in supported:
        return JSONResponse(status_code=400, content={
            "error": f"Invalid capability type: {cap_type}. Valid: {sorted(supported)}",
        })

    if not NCS_AVAILABLE or not global_capabilities_client:
        return JSONResponse(status_code=503, content={
            "error": "NCS not available — cannot route capability requests",
        })

    # Preflight: refuse to route on stale capability data.
    gcaps = await global_capabilities_client.require_fresh_caps(ttl=30)
    if gcaps is None:
        return JSONResponse(status_code=503, content={
            "error": "NCS caps stale — preflight failed, refusing to route",
        })

    candidates = global_capabilities_client.find_nodes_with_capability(cap_type)
    if not candidates:
        return JSONResponse(status_code=404, content={
            "error": f"No node with capability '{cap_type}' available",
            "capabilities_by_node": gcaps.get("capabilities_by_node", {}),
        })

    # Drop circuit-broken nodes (node IDs compared case-insensitively).
    broken = offload_client.get_unavailable_nodes(cap_type) if offload_client else set()
    broken_lc = {b.lower() for b in broken}
    healthy = [n for n in candidates if n.lower() not in broken_lc]
    if not healthy:
        return JSONResponse(status_code=503, content={
            "error": f"All nodes with '{cap_type}' are circuit-broken",
            "eligible": candidates,
            "unavailable": list(broken),
        })

    # Load-based pick: 10 points per in-flight job, +100 under memory pressure;
    # lowest score wins.
    best_node = healthy[0]
    if len(healthy) > 1:
        scored = []
        for nid in healthy:
            load = global_capabilities_client.get_node_load(nid)
            penalty = load.get("inflight", 0) * 10
            if load.get("mem_pressure") == "high":
                penalty += 100
            scored.append((penalty, nid))
        scored.sort()
        best_node = scored[0][1]

    payload = await request.json()
    logger.info(f"[cap.offload] type={cap_type} → node={best_node} (of {healthy})")

    if nc is not None and nats_available and offload_client:
        import uuid as _uuid
        job = {
            "job_id": str(_uuid.uuid4()),
            "required_type": cap_type,
            "payload": payload,
            "deadline_ts": int(time.time() * 1000) + 60000,
            # hints travel beside the payload, not inside it
            "hints": payload.pop("hints", {}),
        }
        result = await offload_client.offload_infer(
            nats_client=nc, node_id=best_node, required_type=cap_type,
            job_payload=job, timeout_ms=60000,
        )
        if result and result.get("status") == "ok":
            return JSONResponse(content=result.get("result", result))
        error = result.get("error", {}) if result else {}
        return JSONResponse(status_code=502, content={
            "error": error.get("message", f"Offload to {best_node} failed"),
            "code": error.get("code", "OFFLOAD_FAILED"),
            "node": best_node,
        })

    return JSONResponse(status_code=503, content={
        "error": "NATS not connected — cannot offload",
    })
|
||||
|
||||
|
||||
@app.get("/v1/capabilities")
async def list_global_capabilities():
    """Return full capabilities view across all nodes."""
    if not NCS_AVAILABLE or not global_capabilities_client:
        return JSONResponse(status_code=503, content={"error": "NCS not available"})
    gcaps = await global_capabilities_client.get_global_capabilities()
    view = {
        key: gcaps.get(key, default)
        for key, default in (
            ("node_count", 0),
            ("nodes", {}),
            ("capabilities_by_node", {}),
            ("served_count", 0),
        )
    }
    return JSONResponse(content=view)
|
||||
|
||||
|
||||
@app.get("/v1/agromatrix/shared-memory/pending")
|
||||
|
||||
@@ -81,7 +81,7 @@ def get_unavailable_nodes(req_type: str) -> Set[str]:
|
||||
async def offload_infer(
|
||||
nats_client,
|
||||
node_id: str,
|
||||
required_type: Literal["llm", "vision", "stt", "tts"],
|
||||
required_type: Literal["llm", "vision", "stt", "tts", "ocr", "image"],
|
||||
job_payload: Dict[str, Any],
|
||||
timeout_ms: int = 25000,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
|
||||
Reference in New Issue
Block a user