P3.5-P3.7: 2-layer inventory, capability routing, STT/TTS adapters, Dev Contract
NCS:
- _collect_worker_caps() fetches capability flags from node-worker /caps
- _derive_capabilities() merges served model types + worker provider flags
- installed_artifacts replaces inventory_only (disk scan with DISK_SCAN_PATHS env)
- New endpoints: /capabilities/caps, /capabilities/installed
Node Worker:
- STT_PROVIDER, TTS_PROVIDER, OCR_PROVIDER, IMAGE_PROVIDER env flags
- /caps endpoint returns capabilities + providers for NCS aggregation
- STT adapter (providers/stt_mlx_whisper.py) — remote + local mode
- TTS adapter (providers/tts_mlx_kokoro.py) — remote + local mode
- OCR handler via vision_prompted (ollama_vision with OCR prompt)
- NATS subjects: node.{id}.stt/tts/ocr/image.request
Router:
- POST /v1/capability/{stt,tts,ocr,image} — capability-based offload routing
- GET /v1/capabilities — global view with capabilities_by_node
- require_fresh_caps(ttl) preflight guard
- find_nodes_with_capability(cap) + load-based node selection
Ops:
- ops/fabric_snapshot.py — full runtime snapshot collector
- ops/fabric_preflight.sh — quick check + snapshot save + diff
- docs/fabric_contract.md — Dev Contract v0.1 (preflight-first)
- tests/test_fabric_contract.py — CI enforcement (6 tests)
Made-with: Cursor
This commit is contained in:
# --- Node-worker tunables (all overridable via environment variables) ---

# Default vision model used when a job does not specify one.
DEFAULT_VISION = os.getenv("NODE_DEFAULT_VISION", "llava:13b")

# Maximum number of jobs this worker processes concurrently.
MAX_CONCURRENCY = int(os.getenv("NODE_WORKER_MAX_CONCURRENCY", "2"))

# Upper bound on an accepted request payload, in bytes (default 1 MiB).
MAX_PAYLOAD_BYTES = int(os.getenv("NODE_WORKER_MAX_PAYLOAD_BYTES", str(1024 * 1024)))

# HTTP port the worker listens on.
PORT = int(os.getenv("PORT", "8109"))

# Capability provider flags. The sentinel "none" disables the capability on
# this node; any other value names the provider implementation to use.
STT_PROVIDER = os.getenv("STT_PROVIDER", "none")
TTS_PROVIDER = os.getenv("TTS_PROVIDER", "none")
# OCR defaults to reusing the vision model with an OCR-specific prompt.
OCR_PROVIDER = os.getenv("OCR_PROVIDER", "vision_prompted")
IMAGE_PROVIDER = os.getenv("IMAGE_PROVIDER", "none")
@@ -41,6 +41,33 @@ async def prom_metrics():
|
||||
return {"error": "prometheus_client not installed"}
|
||||
|
||||
|
||||
@app.get("/caps")
|
||||
async def caps():
|
||||
"""Capability flags for NCS to aggregate."""
|
||||
return {
|
||||
"node_id": config.NODE_ID,
|
||||
"capabilities": {
|
||||
"llm": True,
|
||||
"vision": True,
|
||||
"stt": config.STT_PROVIDER != "none",
|
||||
"tts": config.TTS_PROVIDER != "none",
|
||||
"ocr": config.OCR_PROVIDER != "none",
|
||||
"image": config.IMAGE_PROVIDER != "none",
|
||||
},
|
||||
"providers": {
|
||||
"stt": config.STT_PROVIDER,
|
||||
"tts": config.TTS_PROVIDER,
|
||||
"ocr": config.OCR_PROVIDER,
|
||||
"image": config.IMAGE_PROVIDER,
|
||||
},
|
||||
"defaults": {
|
||||
"llm": config.DEFAULT_LLM,
|
||||
"vision": config.DEFAULT_VISION,
|
||||
},
|
||||
"concurrency": config.MAX_CONCURRENCY,
|
||||
}
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
global _nats_client
|
||||
|
||||
    # Correlation id propagated across services for tracing.
    trace_id: str = ""
    # Id of the agent that initiated the job — presumably; confirm against router.
    actor_agent_id: str = ""
    # Id of the agent the job targets — presumably; confirm against router.
    target_agent_id: str = ""
    # Capability class the executing node must provide.
    required_type: Literal["llm", "vision", "stt", "tts", "image", "ocr"] = "llm"
    # Unix timestamp; 0 appears to mean "no deadline" — confirm with scheduler.
    deadline_ts: int = 0
    # Key used to de-duplicate retried submissions.
    idempotency_key: str = ""
    # Capability-specific request body (audio, text, images, hints, ...).
    payload: Dict[str, Any] = Field(default_factory=dict)
||||
135
services/node-worker/providers/stt_mlx_whisper.py
Normal file
135
services/node-worker/providers/stt_mlx_whisper.py
Normal file
"""MLX Whisper STT provider — transcribes audio via mlx-whisper on host.

Runs inside Docker; delegates to MLX Whisper HTTP API on the host.
If MLX_WHISPER_URL is not set, falls back to running mlx_whisper directly
(only works when node-worker runs natively, not in Docker).
"""
import base64
import logging
import os
import tempfile
from typing import Any, Dict, Optional

import httpx

logger = logging.getLogger("provider.stt_mlx_whisper")

# Base URL of the host-side MLX Whisper HTTP service; an empty string
# selects the local in-process mlx_whisper fallback.
MLX_WHISPER_URL = os.getenv("MLX_WHISPER_URL", "")
# Model repo passed to mlx_whisper when transcribing locally.
MLX_WHISPER_MODEL = os.getenv("MLX_WHISPER_MODEL", "mlx-community/whisper-large-v3-turbo")
# Reject audio larger than this many bytes.
MAX_AUDIO_BYTES = int(os.getenv("STT_MAX_AUDIO_BYTES", str(25 * 1024 * 1024)))  # 25MB

# Lazily-created state for local mode. _local_lock serializes access to
# mlx_whisper; _local_model is declared but never assigned in this file
# (presumably reserved for future model caching — confirm).
_local_model = None
_local_lock = None
||||
|
||||
def _lazy_init_local():
    """Create the asyncio lock used for local transcription (idempotent)."""
    global _local_model, _local_lock
    if _local_lock is None:
        import asyncio

        _local_lock = asyncio.Lock()
|
||||
|
||||
async def _transcribe_local(audio_path: str, language: Optional[str]) -> Dict[str, Any]:
    """Transcribe using local mlx_whisper (Apple Silicon only).

    Args:
        audio_path: Path to an audio file readable by mlx_whisper.
        language: Optional language hint (e.g. "uk", "en").

    Returns:
        Dict with "text", "segments" (start/end/text dicts) and "language".
    """
    import asyncio

    _lazy_init_local()
    async with _local_lock:
        import mlx_whisper

        kwargs: Dict[str, Any] = {"path_or_hf_repo": MLX_WHISPER_MODEL}
        if language:
            kwargs["language"] = language
        # mlx_whisper.transcribe is a blocking, compute-heavy call; run it in
        # the default executor so the event loop stays responsive — otherwise
        # the caller's asyncio.wait_for timeout could never fire while a
        # transcription is in progress.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, lambda: mlx_whisper.transcribe(audio_path, **kwargs)
        )
    segments = [
        {
            "start": seg.get("start", 0),
            "end": seg.get("end", 0),
            "text": seg.get("text", ""),
        }
        for seg in result.get("segments", [])
    ]
    return {
        "text": result.get("text", ""),
        "segments": segments,
        "language": result.get("language", ""),
    }
||||
|
||||
|
||||
async def _transcribe_remote(audio_b64: str, language: Optional[str]) -> Dict[str, Any]:
    """Transcribe via MLX Whisper HTTP service on host."""
    request_body: Dict[str, Any] = {"audio_b64": audio_b64}
    if language:
        request_body["language"] = language

    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(f"{MLX_WHISPER_URL}/transcribe", json=request_body)
        response.raise_for_status()
        return response.json()
|
||||
|
||||
async def _resolve_audio(payload: Dict[str, Any]) -> tuple:
    """Return (audio_bytes, audio_b64) from payload.

    Accepts either of:
        audio_b64: base64-encoded audio, or
        audio_url: http(s) URL, file:// URL, or absolute filesystem path.

    Raises:
        ValueError: if neither source is present, or the decoded audio
            exceeds MAX_AUDIO_BYTES.
    """

    def _checked(raw: bytes) -> bytes:
        # Single size guard shared by all three source branches.
        if len(raw) > MAX_AUDIO_BYTES:
            raise ValueError(f"Audio exceeds {MAX_AUDIO_BYTES} bytes")
        return raw

    audio_b64 = payload.get("audio_b64", "")
    audio_url = payload.get("audio_url", "")

    if audio_b64:
        return _checked(base64.b64decode(audio_b64)), audio_b64

    if audio_url:
        if audio_url.startswith(("file://", "/")):
            # Strip only a leading "file://" scheme. str.replace would also
            # mangle any later occurrence of the substring inside the path.
            path = audio_url[len("file://"):] if audio_url.startswith("file://") else audio_url
            with open(path, "rb") as f:
                raw = _checked(f.read())
            return raw, base64.b64encode(raw).decode()

        async with httpx.AsyncClient(timeout=30) as c:
            resp = await c.get(audio_url)
            resp.raise_for_status()
            raw = _checked(resp.content)
            return raw, base64.b64encode(raw).decode()

    raise ValueError("Either audio_b64 or audio_url is required")
||||
|
||||
|
||||
async def transcribe(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Canonical STT entry point.

    Payload:
        audio_url: str (http/file) — OR —
        audio_b64: str (base64 encoded)
        language: str (optional, e.g. "uk", "en")
        format: "text" | "segments" | "json"
    """
    language = payload.get("language")
    fmt = payload.get("format", "json")

    audio_bytes, audio_b64 = await _resolve_audio(payload)

    if MLX_WHISPER_URL:
        # Remote mode: the host-side HTTP service does the heavy lifting.
        result = await _transcribe_remote(audio_b64, language)
    else:
        # Local mode: mlx_whisper needs a file path, so spill to a temp file.
        # NOTE(review): the ".wav" suffix assumes WAV input — confirm callers.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            result = await _transcribe_local(tmp_path, language)
        finally:
            os.unlink(tmp_path)

    meta = {
        "model": MLX_WHISPER_MODEL,
        "provider": "mlx_whisper",
        "device": "apple_silicon",
    }
    # Keys shared by every response shape.
    trailer = {"meta": meta, "provider": "mlx_whisper", "model": MLX_WHISPER_MODEL}

    if fmt == "text":
        return {"text": result.get("text", ""), **trailer}
    if fmt == "segments":
        return {
            "text": result.get("text", ""),
            "segments": result.get("segments", []),
            **trailer,
        }
    return {**result, **trailer}
||||
123
services/node-worker/providers/tts_mlx_kokoro.py
Normal file
123
services/node-worker/providers/tts_mlx_kokoro.py
Normal file
"""MLX Kokoro TTS provider — generates speech via kokoro on host.

Runs inside Docker; delegates to Kokoro HTTP API on the host.
Falls back to local kokoro-onnx if running natively on Apple Silicon.
"""
import base64
import logging
import os
import tempfile
from typing import Any, Dict

import httpx

logger = logging.getLogger("provider.tts_mlx_kokoro")

# Base URL of the host-side Kokoro HTTP service; an empty string selects
# the local in-process pipeline fallback.
MLX_KOKORO_URL = os.getenv("MLX_KOKORO_URL", "")
# Model identifier reported in response metadata.
MLX_KOKORO_MODEL = os.getenv("MLX_KOKORO_MODEL", "kokoro-v1.0")
# Voice used when the request does not specify one.
DEFAULT_VOICE = os.getenv("TTS_DEFAULT_VOICE", "af_heart")
# Reject input text longer than this many characters.
MAX_TEXT_CHARS = int(os.getenv("TTS_MAX_TEXT_CHARS", "5000"))
# Sample rate (Hz) used when the request does not specify one.
DEFAULT_SAMPLE_RATE = int(os.getenv("TTS_SAMPLE_RATE", "24000"))

# Lazily-created state for local mode: the pipeline is built on first use
# and _local_lock serializes access to it.
_local_pipeline = None
_local_lock = None
||||
|
||||
def _lazy_init_local():
    """Create the asyncio lock guarding the local Kokoro pipeline (idempotent)."""
    global _local_lock
    if _local_lock is None:
        import asyncio

        _local_lock = asyncio.Lock()
||||
|
||||
async def _synthesize_local(text: str, voice: str, sample_rate: int) -> bytes:
    """Synthesize via local kokoro (Apple Silicon).

    Builds the KPipeline on first use, runs synthesis, and encodes the
    concatenated chunks as a WAV byte string.

    Raises:
        RuntimeError: if Kokoro yields no audio chunks.
    """
    import asyncio

    _lazy_init_local()

    def _run() -> bytes:
        # Pipeline construction, synthesis, and WAV encoding are all blocking
        # and compute-heavy, so the whole body runs in the default executor —
        # keeping the event loop responsive and letting caller-side timeouts
        # actually fire.
        global _local_pipeline
        if _local_pipeline is None:
            from kokoro import KPipeline

            _local_pipeline = KPipeline(lang_code="a")
            logger.info(f"Kokoro pipeline initialized: voice={voice}")

        import io

        import numpy as np
        import soundfile as sf

        chunks = [audio for _, _, audio in _local_pipeline(text, voice=voice)]
        if not chunks:
            raise RuntimeError("Kokoro produced no audio")

        buf = io.BytesIO()
        # NOTE(review): kokoro's native output rate is fixed by the model;
        # writing the WAV header with an arbitrary sample_rate only relabels
        # it (changing playback speed) — confirm this is intended.
        sf.write(buf, np.concatenate(chunks), sample_rate, format="WAV")
        return buf.getvalue()

    # The lock still serializes pipeline init/use across concurrent calls.
    async with _local_lock:
        return await asyncio.get_running_loop().run_in_executor(None, _run)
||||
|
||||
|
||||
async def _synthesize_remote(text: str, voice: str, fmt: str, sample_rate: int) -> Dict[str, Any]:
    """Synthesize via Kokoro HTTP service on host."""
    request_body = {
        "text": text,
        "voice": voice,
        "format": fmt,
        "sample_rate": sample_rate,
    }
    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(f"{MLX_KOKORO_URL}/synthesize", json=request_body)
        response.raise_for_status()
        return response.json()
||||
|
||||
|
||||
async def synthesize(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Canonical TTS entry point.

    Payload:
        text: str (required)
        voice: str (optional, default "af_heart")
        format: "wav" | "mp3" (default "wav")
        sample_rate: int (default 24000)
    """
    text = payload.get("text", "")
    if not text:
        raise ValueError("text is required")
    if len(text) > MAX_TEXT_CHARS:
        raise ValueError(f"Text exceeds {MAX_TEXT_CHARS} chars")

    voice = payload.get("voice", DEFAULT_VOICE)
    fmt = payload.get("format", "wav")
    sample_rate = payload.get("sample_rate", DEFAULT_SAMPLE_RATE)

    meta = {
        "model": MLX_KOKORO_MODEL,
        "provider": "mlx_kokoro",
        "voice": voice,
        "device": "apple_silicon",
    }
    # Keys shared by both response shapes.
    common = {"meta": meta, "provider": "mlx_kokoro", "model": MLX_KOKORO_MODEL}

    if MLX_KOKORO_URL:
        # Remote mode: the host service returns audio as b64 and/or a URL.
        remote = await _synthesize_remote(text, voice, fmt, sample_rate)
        return {
            "audio_b64": remote.get("audio_b64", ""),
            "audio_url": remote.get("audio_url", ""),
            "format": fmt,
            **common,
        }

    # Local mode always produces WAV, regardless of the requested format.
    wav_bytes = await _synthesize_local(text, voice, sample_rate)
    return {
        "audio_b64": base64.b64encode(wav_bytes).decode(),
        "format": "wav",
        **common,
    }
|
||||
@@ -10,6 +10,7 @@ import config
|
||||
from models import JobRequest, JobResponse, JobError
|
||||
from idempotency import IdempotencyStore
|
||||
from providers import ollama, ollama_vision
|
||||
from providers import stt_mlx_whisper, tts_mlx_kokoro
|
||||
import fabric_metrics as fm
|
||||
|
||||
logger = logging.getLogger("node-worker")
|
||||
@@ -34,6 +35,7 @@ async def start(nats_client):
|
||||
f"node.{nid}.stt.request",
|
||||
f"node.{nid}.tts.request",
|
||||
f"node.{nid}.image.request",
|
||||
f"node.{nid}.ocr.request",
|
||||
]
|
||||
for subj in subjects:
|
||||
await nats_client.subscribe(subj, cb=_handle_request)
|
||||
@@ -175,14 +177,52 @@ async def _execute(job: JobRequest, remaining_ms: int) -> JobResponse:
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
elif job.required_type in ("stt", "tts", "image"):
|
||||
elif job.required_type == "stt":
|
||||
if config.STT_PROVIDER == "none":
|
||||
return JobResponse(
|
||||
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
||||
status="error",
|
||||
error=JobError(code="NOT_AVAILABLE", message="STT not configured on this node"),
|
||||
)
|
||||
result = await asyncio.wait_for(
|
||||
stt_mlx_whisper.transcribe(payload), timeout=timeout_s,
|
||||
)
|
||||
elif job.required_type == "tts":
|
||||
if config.TTS_PROVIDER == "none":
|
||||
return JobResponse(
|
||||
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
||||
status="error",
|
||||
error=JobError(code="NOT_AVAILABLE", message="TTS not configured on this node"),
|
||||
)
|
||||
result = await asyncio.wait_for(
|
||||
tts_mlx_kokoro.synthesize(payload), timeout=timeout_s,
|
||||
)
|
||||
elif job.required_type == "ocr":
|
||||
if config.OCR_PROVIDER == "none":
|
||||
return JobResponse(
|
||||
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
||||
status="error",
|
||||
error=JobError(code="NOT_AVAILABLE", message="OCR not configured on this node"),
|
||||
)
|
||||
ocr_prompt = payload.get("prompt", "Extract all text from this image. Return JSON: {\"text\": \"...\", \"language\": \"...\"}")
|
||||
result = await asyncio.wait_for(
|
||||
ollama_vision.infer(
|
||||
images=payload.get("images"),
|
||||
prompt=ocr_prompt,
|
||||
model=model or config.DEFAULT_VISION,
|
||||
system="You are an OCR engine. Extract text precisely. Return valid JSON only.",
|
||||
max_tokens=hints.get("max_tokens", 4096),
|
||||
temperature=0.05,
|
||||
timeout_s=timeout_s,
|
||||
),
|
||||
timeout=timeout_s,
|
||||
)
|
||||
result["provider"] = "vision_prompted_ocr"
|
||||
elif job.required_type == "image":
|
||||
return JobResponse(
|
||||
job_id=job.job_id, trace_id=job.trace_id, node_id=config.NODE_ID,
|
||||
status="error",
|
||||
error=JobError(
|
||||
code="NOT_YET_IMPLEMENTED",
|
||||
message=f"{job.required_type} adapter coming soon; use direct runtime API for now",
|
||||
),
|
||||
error=JobError(code="NOT_YET_IMPLEMENTED", message="Image adapter pending P3.7"),
|
||||
)
|
||||
else:
|
||||
return JobResponse(
|
||||
|
||||
Reference in New Issue
Block a user