"""MLX Whisper STT provider — transcribes audio via mlx-whisper on host. Runs inside Docker; delegates to MLX Whisper HTTP API on the host. If MLX_WHISPER_URL is not set, falls back to running mlx_whisper directly (only works when node-worker runs natively, not in Docker). """ import base64 import logging import os import tempfile from typing import Any, Dict, Optional import httpx logger = logging.getLogger("provider.stt_mlx_whisper") MLX_WHISPER_URL = os.getenv("MLX_WHISPER_URL", "") MLX_WHISPER_MODEL = os.getenv("MLX_WHISPER_MODEL", "mlx-community/whisper-large-v3-turbo") MAX_AUDIO_BYTES = int(os.getenv("STT_MAX_AUDIO_BYTES", str(25 * 1024 * 1024))) # 25MB _local_model = None _local_lock = None def _lazy_init_local(): """Lazy-load mlx_whisper for native (non-Docker) execution.""" global _local_model, _local_lock if _local_lock is not None: return import asyncio _local_lock = asyncio.Lock() async def _transcribe_local(audio_path: str, language: Optional[str]) -> Dict[str, Any]: """Transcribe using local mlx_whisper (Apple Silicon only).""" _lazy_init_local() async with _local_lock: import mlx_whisper kwargs: Dict[str, Any] = {"path_or_hf_repo": MLX_WHISPER_MODEL} if language: kwargs["language"] = language result = mlx_whisper.transcribe(audio_path, **kwargs) segments = [] for seg in result.get("segments", []): segments.append({ "start": seg.get("start", 0), "end": seg.get("end", 0), "text": seg.get("text", ""), }) return { "text": result.get("text", ""), "segments": segments, "language": result.get("language", ""), } async def _transcribe_remote(audio_b64: str, language: Optional[str]) -> Dict[str, Any]: """Transcribe via MLX Whisper HTTP service on host.""" payload: Dict[str, Any] = {"audio_b64": audio_b64} if language: payload["language"] = language async with httpx.AsyncClient(timeout=120) as c: resp = await c.post(f"{MLX_WHISPER_URL}/transcribe", json=payload) resp.raise_for_status() return resp.json() async def _resolve_audio(payload: Dict[str, Any]) -> tuple: """Return (audio_bytes, audio_b64) from payload.""" audio_b64 = payload.get("audio_b64", "") audio_url = payload.get("audio_url", "") if audio_b64: raw = base64.b64decode(audio_b64) if len(raw) > MAX_AUDIO_BYTES: raise ValueError(f"Audio exceeds {MAX_AUDIO_BYTES} bytes") return raw, audio_b64 if audio_url: if audio_url.startswith(("file://", "/")): path = audio_url.replace("file://", "") with open(path, "rb") as f: raw = f.read() if len(raw) > MAX_AUDIO_BYTES: raise ValueError(f"Audio exceeds {MAX_AUDIO_BYTES} bytes") return raw, base64.b64encode(raw).decode() async with httpx.AsyncClient(timeout=30) as c: resp = await c.get(audio_url) resp.raise_for_status() raw = resp.content if len(raw) > MAX_AUDIO_BYTES: raise ValueError(f"Audio exceeds {MAX_AUDIO_BYTES} bytes") return raw, base64.b64encode(raw).decode() raise ValueError("Either audio_b64 or audio_url is required") async def transcribe(payload: Dict[str, Any]) -> Dict[str, Any]: """Canonical STT entry point. Payload: audio_url: str (http/file) — OR — audio_b64: str (base64 encoded) language: str (optional, e.g. "uk", "en") format: "text" | "segments" | "json" """ language = payload.get("language") fmt = payload.get("format", "json") audio_bytes, audio_b64 = await _resolve_audio(payload) if MLX_WHISPER_URL: result = await _transcribe_remote(audio_b64, language) else: with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp: tmp.write(audio_bytes) tmp_path = tmp.name try: result = await _transcribe_local(tmp_path, language) finally: os.unlink(tmp_path) meta = { "model": MLX_WHISPER_MODEL, "provider": "mlx_whisper", "device": "apple_silicon", } if fmt == "text": return {"text": result.get("text", ""), "meta": meta, "provider": "mlx_whisper", "model": MLX_WHISPER_MODEL} if fmt == "segments": return {"text": result.get("text", ""), "segments": result.get("segments", []), "meta": meta, "provider": "mlx_whisper", "model": MLX_WHISPER_MODEL} return {**result, "meta": meta, "provider": "mlx_whisper", "model": MLX_WHISPER_MODEL}