NCS:
- _collect_worker_caps() fetches capability flags from node-worker /caps
- _derive_capabilities() merges served model types + worker provider flags
- installed_artifacts replaces inventory_only (disk scan with DISK_SCAN_PATHS env)
- New endpoints: /capabilities/caps, /capabilities/installed
Node Worker:
- STT_PROVIDER, TTS_PROVIDER, OCR_PROVIDER, IMAGE_PROVIDER env flags
- /caps endpoint returns capabilities + providers for NCS aggregation
- STT adapter (providers/stt_mlx_whisper.py) — remote + local mode
- TTS adapter (providers/tts_mlx_kokoro.py) — remote + local mode
- OCR handler via vision_prompted (ollama_vision with OCR prompt)
- NATS subjects: node.{id}.stt/tts/ocr/image.request
Router:
- POST /v1/capability/{stt,tts,ocr,image} — capability-based offload routing
- GET /v1/capabilities — global view with capabilities_by_node
- require_fresh_caps(ttl) preflight guard
- find_nodes_with_capability(cap) + load-based node selection
Ops:
- ops/fabric_snapshot.py — full runtime snapshot collector
- ops/fabric_preflight.sh — quick check + snapshot save + diff
- docs/fabric_contract.md — Dev Contract v0.1 (preflight-first)
- tests/test_fabric_contract.py — CI enforcement (6 tests)
Made-with: Cursor
136 lines
4.7 KiB
Python
"""MLX Whisper STT provider — transcribes audio via mlx-whisper on host.
|
|
|
|
Runs inside Docker; delegates to MLX Whisper HTTP API on the host.
|
|
If MLX_WHISPER_URL is not set, falls back to running mlx_whisper directly
|
|
(only works when node-worker runs natively, not in Docker).
|
|
"""
|
|
import base64
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
from typing import Any, Dict, Optional
|
|
|
|
import httpx
|
|
|
|
logger = logging.getLogger("provider.stt_mlx_whisper")
|
|
|
|
MLX_WHISPER_URL = os.getenv("MLX_WHISPER_URL", "")
|
|
MLX_WHISPER_MODEL = os.getenv("MLX_WHISPER_MODEL", "mlx-community/whisper-large-v3-turbo")
|
|
MAX_AUDIO_BYTES = int(os.getenv("STT_MAX_AUDIO_BYTES", str(25 * 1024 * 1024))) # 25MB
|
|
|
|
_local_model = None
|
|
_local_lock = None
|
|
|
|
|
|
def _lazy_init_local():
    """Lazily create the asyncio lock that serializes local transcriptions.

    Idempotent: once the lock exists, subsequent calls are no-ops.  asyncio
    is imported lazily so that importing this module stays cheap for
    deployments that only ever use the remote (HTTP) path.
    """
    # NOTE: the original also declared `global _local_model`, but this
    # function never assigns it — only _local_lock is created here.
    global _local_lock
    if _local_lock is not None:
        return
    import asyncio

    _local_lock = asyncio.Lock()
|
async def _transcribe_local(audio_path: str, language: Optional[str]) -> Dict[str, Any]:
    """Transcribe using local mlx_whisper (Apple Silicon only).

    The mlx_whisper call is synchronous and compute-bound; running it
    inline would stall the event loop for the entire transcription, so it
    is offloaded to a worker thread.  The module-level lock still
    serializes transcriptions so only one model run happens at a time.

    Args:
        audio_path: Path to a local audio file readable by mlx_whisper.
        language: Optional language hint (e.g. "uk", "en"); omitted when falsy.

    Returns:
        Dict with "text", "segments" (list of start/end/text dicts) and
        "language", mirroring the remote service's response shape.
    """
    import asyncio

    _lazy_init_local()
    async with _local_lock:
        import mlx_whisper  # heavy import; only needed on the native path

        kwargs: Dict[str, Any] = {"path_or_hf_repo": MLX_WHISPER_MODEL}
        if language:
            kwargs["language"] = language
        # Blocking, CPU/GPU-bound call — run it off the event loop.
        result = await asyncio.to_thread(mlx_whisper.transcribe, audio_path, **kwargs)

    segments = [
        {
            "start": seg.get("start", 0),
            "end": seg.get("end", 0),
            "text": seg.get("text", ""),
        }
        for seg in result.get("segments", [])
    ]
    return {
        "text": result.get("text", ""),
        "segments": segments,
        "language": result.get("language", ""),
    }
|
async def _transcribe_remote(audio_b64: str, language: Optional[str]) -> Dict[str, Any]:
    """Transcribe via the MLX Whisper HTTP service running on the host.

    POSTs the base64-encoded audio (plus an optional language hint) to
    ``{MLX_WHISPER_URL}/transcribe`` and returns the decoded JSON response.
    Raises httpx.HTTPStatusError on non-2xx responses.
    """
    request_body: Dict[str, Any] = (
        {"audio_b64": audio_b64, "language": language}
        if language
        else {"audio_b64": audio_b64}
    )

    async with httpx.AsyncClient(timeout=120) as client:
        response = await client.post(f"{MLX_WHISPER_URL}/transcribe", json=request_body)
        response.raise_for_status()
        return response.json()
|
def _ensure_size(raw: bytes) -> bytes:
    """Return *raw* unchanged, or raise ValueError if it exceeds MAX_AUDIO_BYTES."""
    if len(raw) > MAX_AUDIO_BYTES:
        raise ValueError(f"Audio exceeds {MAX_AUDIO_BYTES} bytes")
    return raw


async def _resolve_audio(payload: Dict[str, Any]) -> tuple:
    """Return (audio_bytes, audio_b64) from payload.

    Accepts either:
        audio_b64: base64-encoded audio bytes, or
        audio_url: a "file://" URL, an absolute path, or an http(s) URL
            (fetched with a 30s timeout).

    Raises:
        ValueError: when neither source is present, or the decoded audio
            exceeds MAX_AUDIO_BYTES.
    """
    audio_b64 = payload.get("audio_b64", "")
    audio_url = payload.get("audio_url", "")

    if audio_b64:
        raw = _ensure_size(base64.b64decode(audio_b64))
        return raw, audio_b64

    if audio_url:
        if audio_url.startswith(("file://", "/")):
            # Strip only a leading scheme; str.replace would also mangle a
            # path that happened to contain "file://" later in the string.
            path = audio_url[len("file://"):] if audio_url.startswith("file://") else audio_url
            with open(path, "rb") as f:
                raw = _ensure_size(f.read())
            return raw, base64.b64encode(raw).decode()

        async with httpx.AsyncClient(timeout=30) as c:
            resp = await c.get(audio_url)
            resp.raise_for_status()
            raw = _ensure_size(resp.content)
            return raw, base64.b64encode(raw).decode()

    raise ValueError("Either audio_b64 or audio_url is required")
|
async def transcribe(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Canonical STT entry point.

    Payload:
        audio_url: str (http/file) — OR —
        audio_b64: str (base64 encoded)
        language: str (optional, e.g. "uk", "en")
        format: "text" | "segments" | "json"
    """
    language = payload.get("language")
    fmt = payload.get("format", "json")

    audio_bytes, audio_b64 = await _resolve_audio(payload)

    if MLX_WHISPER_URL:
        # Remote mode: delegate to the host-side HTTP service.
        result = await _transcribe_remote(audio_b64, language)
    else:
        # Native mode: spill audio to a temp file, transcribe, then clean up.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            result = await _transcribe_local(tmp_path, language)
        finally:
            os.unlink(tmp_path)

    meta = {
        "model": MLX_WHISPER_MODEL,
        "provider": "mlx_whisper",
        "device": "apple_silicon",
    }
    # Trailer shared by every response shape.
    common = {"meta": meta, "provider": "mlx_whisper", "model": MLX_WHISPER_MODEL}

    if fmt == "text":
        return {"text": result.get("text", ""), **common}
    if fmt == "segments":
        return {
            "text": result.get("text", ""),
            "segments": result.get("segments", []),
            **common,
        }
    return {**result, **common}
|