"""MLX Whisper STT Service — lightweight HTTP wrapper for mlx-whisper on Apple Silicon.

Runs natively on host (not in Docker) to access Metal/MPS acceleration.

Port: 8200 (override with the PORT environment variable)
"""

import asyncio
import base64
import logging
import os
import tempfile
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, Tuple

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
import uvicorn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mlx-stt")

app = FastAPI(title="MLX Whisper STT", version="1.0.0")

# Hugging Face repo id (or local path) of the MLX Whisper weights.
MODEL = os.getenv("MLX_WHISPER_MODEL", "mlx-community/whisper-large-v3-turbo")
# Upper bound on accepted audio payload size, in bytes (default 25 MiB).
MAX_AUDIO_BYTES = int(os.getenv("STT_MAX_AUDIO_BYTES", str(25 * 1024 * 1024)))

# The imported ``mlx_whisper`` module once the warm-up has succeeded; None until then.
_whisper = None

# Every MLX call (warm-up and transcription) runs on this one worker thread:
# - the blocking transcription no longer stalls the asyncio event loop, so
#   /health and request parsing stay responsive while a job runs;
# - max_workers=1 serializes inference (one job at a time, others queue);
# - the model is loaded and used from the same thread.
_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="mlx-whisper")


def _load_model():
    """Import mlx_whisper and run one warm-up transcription (downloads weights if needed).

    Blocking; intended to run on ``_executor``. Returns immediately if the model
    is already loaded. ``_whisper`` is only set after the warm-up succeeds, so
    ``/health`` never reports ready for a model that failed to load.
    """
    global _whisper
    if _whisper is not None:
        return
    logger.info("Loading MLX Whisper model: %s", MODEL)
    t0 = time.perf_counter()
    import mlx_whisper
    import numpy as np

    # Warm up on one second of silence passed as an in-memory array. An empty
    # string would be treated as a file path to decode and fail, aborting
    # startup. 16000 samples assumes Whisper's 16 kHz input rate.
    silence = np.zeros(16000, dtype=np.float32)
    mlx_whisper.transcribe(silence, path_or_hf_repo=MODEL)
    _whisper = mlx_whisper
    logger.info("MLX Whisper ready in %.1fs", time.perf_counter() - t0)


class TranscribeRequest(BaseModel):
    """Request body for POST /transcribe.

    One audio source is used; ``audio_b64`` takes precedence over ``audio_url``.
    """

    # Base64-encoded audio bytes.
    audio_b64: str = ""
    # "file://..." or an absolute path is read from the host filesystem;
    # anything else is fetched with an HTTP GET.
    audio_url: str = ""
    # Forwarded to mlx_whisper when set; when unset no language is passed.
    language: Optional[str] = None
    # NOTE(review): accepted for API compatibility but not used — the response
    # always contains text, segments and language.
    format: str = Field(default="json", description="text|segments|json")


class TranscribeResponse(BaseModel):
    """Response body for POST /transcribe."""

    text: str = ""
    # Items are {"start", "end", "text"} dicts.
    segments: list = Field(default_factory=list)
    language: str = ""
    # {"model", "duration_ms" (inference wall time only), "device"}.
    meta: dict = Field(default_factory=dict)


@app.on_event("startup")
async def startup():
    """Load and warm up the model on the MLX worker thread before serving."""
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(_executor, _load_model)


@app.get("/health")
async def health():
    """Report the configured model and whether it has finished loading."""
    return {"status": "ok", "model": MODEL, "ready": _whisper is not None}


def _decode_b64(data: str) -> bytes:
    """Decode ``audio_b64``, mapping malformed input to HTTP 400."""
    try:
        return base64.b64decode(data)
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise HTTPException(400, "audio_b64 is not valid base64") from exc


def _read_local_file(audio_url: str) -> bytes:
    """Read a ``file://`` URL or absolute path from the host filesystem.

    NOTE(security): the server binds 0.0.0.0, and this reads any path the
    process can access on behalf of any client. Restrict it to an allow-listed
    directory (or drop this branch) if the port is reachable by untrusted callers.
    """
    path = audio_url[len("file://"):] if audio_url.startswith("file://") else audio_url
    try:
        # Check the size first so oversized files are never loaded into memory.
        if os.path.getsize(path) > MAX_AUDIO_BYTES:
            raise HTTPException(413, f"Audio exceeds {MAX_AUDIO_BYTES} bytes")
        with open(path, "rb") as fh:
            return fh.read()
    except OSError as exc:
        # Deliberately generic: don't reveal whether the path exists.
        raise HTTPException(400, "Cannot read audio file") from exc


async def _fetch_url(url: str) -> bytes:
    """Download ``url``, aborting with 413 as soon as MAX_AUDIO_BYTES is exceeded.

    Transport failures and non-2xx upstream responses become HTTP 502.
    """
    import httpx

    try:
        async with httpx.AsyncClient(timeout=30) as client:
            async with client.stream("GET", url) as resp:
                resp.raise_for_status()
                buf = bytearray()
                async for chunk in resp.aiter_bytes():
                    buf += chunk
                    if len(buf) > MAX_AUDIO_BYTES:
                        raise HTTPException(413, f"Audio exceeds {MAX_AUDIO_BYTES} bytes")
                return bytes(buf)
    except httpx.HTTPError as exc:
        raise HTTPException(502, f"Failed to fetch audio_url: {exc}") from exc


async def _load_audio_bytes(req: TranscribeRequest) -> bytes:
    """Return the raw audio bytes from whichever source the request supplies."""
    if req.audio_b64:
        return _decode_b64(req.audio_b64)
    if req.audio_url.startswith(("file://", "/")):
        return _read_local_file(req.audio_url)
    return await _fetch_url(req.audio_url)


def _transcribe_sync(raw: bytes, language: Optional[str]) -> Tuple[dict, int]:
    """Write ``raw`` to a temp file and transcribe it; return (result, inference_ms).

    Blocking; runs on ``_executor``. The temp file is created and removed on the
    worker thread, so a cancelled request cannot delete it while transcription
    is still reading it, and a failed write still removes it.
    """
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    try:
        with os.fdopen(fd, "wb") as fh:
            fh.write(raw)
        kwargs = {"path_or_hf_repo": MODEL}
        if language:
            kwargs["language"] = language
        t0 = time.perf_counter()
        result = _whisper.transcribe(tmp_path, **kwargs)
        duration_ms = int((time.perf_counter() - t0) * 1000)
    finally:
        os.unlink(tmp_path)
    return result, duration_ms


@app.post("/transcribe", response_model=TranscribeResponse)
async def transcribe(req: TranscribeRequest):
    """Transcribe audio supplied as base64, a host file path, or an HTTP(S) URL.

    Raises HTTPException:
        400 if no audio source is given, the base64 is malformed, the file
            cannot be read, or the audio payload is empty;
        413 if the audio exceeds MAX_AUDIO_BYTES;
        502 if fetching ``audio_url`` fails;
        503 if the model has not finished loading.
    """
    if not req.audio_b64 and not req.audio_url:
        raise HTTPException(400, "audio_b64 or audio_url required")
    if _whisper is None:
        raise HTTPException(503, "Model not loaded")

    raw = await _load_audio_bytes(req)
    if not raw:
        raise HTTPException(400, "Audio payload is empty")
    if len(raw) > MAX_AUDIO_BYTES:
        raise HTTPException(413, f"Audio exceeds {MAX_AUDIO_BYTES} bytes")

    # Off-load to the single MLX worker thread: the event loop stays free and
    # concurrent requests queue behind the running job.
    loop = asyncio.get_running_loop()
    result, duration_ms = await loop.run_in_executor(
        _executor, _transcribe_sync, raw, req.language
    )

    segments = [
        {"start": s.get("start", 0), "end": s.get("end", 0), "text": s.get("text", "")}
        for s in result.get("segments", [])
    ]
    return TranscribeResponse(
        text=result.get("text", ""),
        segments=segments,
        language=result.get("language", ""),
        meta={"model": MODEL, "duration_ms": duration_ms, "device": "apple_silicon"},
    )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "8200")))