"""Phase 1 tests: STT/TTS memory_service providers.

Tests:
1. stt_memory_service.transcribe() mocks Memory Service → fabric contract
2. tts_memory_service.synthesize() mocks Memory Service → fabric contract
3. /caps endpoint reflects STT/TTS providers correctly
4. No hardcoded model names in providers
5. Provider switch: memory_service vs none vs mlx_whisper/mlx_kokoro
"""
import base64
import importlib
import json
import os
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

# Ensure providers are importable without full app startup
WORKER_DIR = Path(__file__).parent.parent
sys.path.insert(0, str(WORKER_DIR))


# ── Helpers ────────────────────────────────────────────────────────────────────

def _make_httpx_response(json_body: dict | None = None, content: bytes = b"",
                         headers: dict | None = None, status_code: int = 200):
    """Create a minimal mock httpx.Response.

    Args:
        json_body: If given, ``resp.json()`` returns this dict.
        content: Raw body bytes exposed as ``resp.content``.
        headers: Plain dict exposed as ``resp.headers``. Unlike real httpx
            headers, lookups on it are case-sensitive.
        status_code: Value of ``resp.status_code``.

    Returns:
        A MagicMock whose ``raise_for_status()`` never raises.
    """
    resp = MagicMock()
    resp.status_code = status_code
    resp.content = content
    resp.headers = headers or {}
    if json_body is not None:
        resp.json = MagicMock(return_value=json_body)
    resp.raise_for_status = MagicMock()
    return resp


def _patch_async_client(post):
    """Return a ``patch`` of ``httpx.AsyncClient`` whose instances use *post*.

    Inside the returned context manager, ``httpx.AsyncClient(...)`` yields a
    mock client. Its ``__aenter__`` returns the client itself, so both
    ``async with httpx.AsyncClient() as c`` and direct use work.

    Args:
        post: Async callable (e.g. an ``AsyncMock`` or ``async def``) that is
            installed as the client's ``post`` method.
    """
    mock_client = AsyncMock()
    mock_client.__aenter__ = AsyncMock(return_value=mock_client)
    mock_client.__aexit__ = AsyncMock(return_value=False)
    mock_client.post = post
    return patch("httpx.AsyncClient", return_value=mock_client)


# ── STT Memory Service Provider Tests ─────────────────────────────────────────

class TestSTTMemoryServiceProvider(unittest.IsolatedAsyncioTestCase):
    """stt_memory_service: request forwarding and response → fabric contract."""

    async def test_transcribe_audio_b64_returns_fabric_contract(self):
        """Provider translates Memory Service response to fabric contract."""
        audio_b64 = base64.b64encode(b"fake-wav-bytes").decode()
        mock_resp = _make_httpx_response(
            json_body={"text": "Привіт", "model": "faster-whisper", "language": "uk"},
        )

        with _patch_async_client(AsyncMock(return_value=mock_resp)):
            from providers import stt_memory_service
            result = await stt_memory_service.transcribe({
                "audio_b64": audio_b64,
                "filename": "test.wav",
            })

        self.assertEqual(result["text"], "Привіт")
        self.assertEqual(result["language"], "uk")
        self.assertIn("segments", result)
        self.assertIsInstance(result["segments"], list)
        self.assertEqual(result["provider"], "memory_service")
        self.assertIn("meta", result)
        self.assertEqual(result["meta"]["provider"], "memory_service")

    async def test_transcribe_requires_audio_input(self):
        """Should raise ValueError if no audio_b64 or audio_url provided."""
        from providers import stt_memory_service
        # ``msg`` is only the failure message shown if nothing is raised.
        with self.assertRaises(ValueError, msg="audio_b64 or audio_url is required"):
            await stt_memory_service.transcribe({})

    async def test_transcribe_passes_language_param(self):
        """language param is forwarded to Memory Service."""
        audio_b64 = base64.b64encode(b"fake-wav").decode()
        mock_resp = _make_httpx_response(
            json_body={"text": "Hello", "model": "faster-whisper", "language": "en"},
        )
        captured_kwargs: dict = {}

        async def capture_post(url, **kwargs):
            # Accept any keyword arguments (timeout, headers, ...) so this test
            # checks language forwarding, not the exact call signature.
            captured_kwargs.update(kwargs)
            return mock_resp

        with _patch_async_client(capture_post):
            from providers import stt_memory_service
            importlib.reload(stt_memory_service)
            await stt_memory_service.transcribe({
                "audio_b64": audio_b64,
                "language": "en",
            })

        self.assertIn("files", captured_kwargs,
                      "audio must be uploaded as multipart 'files'")
        params = captured_kwargs.get("params") or {}
        self.assertEqual(params.get("language"), "en")

    def test_no_hardcoded_model_in_stt_provider(self):
        """STT provider must not call any local model directly (all via Memory Service HTTP)."""
        src = (WORKER_DIR / "providers" / "stt_memory_service.py").read_text(encoding="utf-8")
        # These should NOT appear as actual Python imports — provider must not load local models
        banned_imports = ["from faster_whisper", "import faster_whisper",
                          "from mlx_audio", "import mlx_audio", "WhisperModel"]
        for name in banned_imports:
            self.assertNotIn(name, src,
                             f"Local model import '{name}' found in stt_memory_service.py")


# ── TTS Memory Service Provider Tests ─────────────────────────────────────────

class TestTTSMemoryServiceProvider(unittest.IsolatedAsyncioTestCase):
    """tts_memory_service: MP3 bytes → fabric contract, input validation."""

    async def test_synthesize_returns_fabric_contract(self):
        """Provider wraps MP3 bytes into fabric contract with audio_b64."""
        mp3_bytes = b"\xff\xfbfake-mp3-data"
        mock_resp = _make_httpx_response(
            content=mp3_bytes,
            headers={
                "content-type": "audio/mpeg",
                "X-TTS-Engine": "edge-tts",
                "X-TTS-Voice": "uk-UA-PolinaNeural",
            },
        )

        with _patch_async_client(AsyncMock(return_value=mock_resp)):
            from providers import tts_memory_service
            result = await tts_memory_service.synthesize({"text": "Привіт"})

        self.assertIn("audio_b64", result)
        self.assertEqual(base64.b64decode(result["audio_b64"]), mp3_bytes)
        self.assertEqual(result["format"], "mp3")
        self.assertEqual(result["provider"], "memory_service")
        self.assertIn("meta", result)
        self.assertEqual(result["meta"]["engine"], "edge-tts")
        self.assertEqual(result["meta"]["voice"], "uk-UA-PolinaNeural")

    async def test_synthesize_requires_text(self):
        """Should raise ValueError if text is empty."""
        from providers import tts_memory_service
        with self.assertRaises(ValueError):
            await tts_memory_service.synthesize({"text": ""})

    async def test_synthesize_truncates_long_text(self):
        """Text exceeding MAX_TEXT_CHARS is truncated (no crash)."""
        mock_resp = _make_httpx_response(
            content=b"\xff\xfb",
            headers={"content-type": "audio/mpeg",
                     "X-TTS-Engine": "edge-tts", "X-TTS-Voice": "Polina"},
        )
        captured_json: dict = {}

        async def capture_post(url, **kwargs):
            # Record only the JSON body; other kwargs (timeout, headers, ...)
            # are accepted so the provider's exact call shape does not matter.
            captured_json.update(kwargs.get("json") or {})
            return mock_resp

        with _patch_async_client(capture_post):
            from providers import tts_memory_service
            importlib.reload(tts_memory_service)
            max_chars = tts_memory_service.MAX_TEXT_CHARS
            # Always exceed the configured limit so truncation is really exercised.
            long_text = "А" * max(1000, max_chars + 1)
            await tts_memory_service.synthesize({"text": long_text})

        self.assertIn("text", captured_json, "provider did not send 'text' in the JSON body")
        self.assertLessEqual(len(captured_json["text"]), max_chars)

    def test_no_hardcoded_model_in_tts_provider(self):
        """TTS provider must not hardcode any model name."""
        src = (WORKER_DIR / "providers" / "tts_memory_service.py").read_text(encoding="utf-8")
        banned = ["kokoro", "mlx", "espeak", "piper"]
        for name in banned:
            self.assertNotIn(name, src,
                             f"Hardcoded engine '{name}' found in tts_memory_service.py")


# ── /caps endpoint tests ───────────────────────────────────────────────────────

class TestCapsEndpoint(unittest.TestCase):
    """Capability flags derived from the configured STT/TTS provider names."""

    def _get_caps_result(self, stt: str, tts: str) -> dict:
        """Simulate /caps logic from main.py.

        NOTE(review): this duplicates main.py rather than calling it, so it can
        drift from the real endpoint — keep in sync or test main.py directly.
        """
        return {
            "capabilities": {
                "stt": stt != "none",
                "tts": tts != "none",
            },
            "providers": {
                "stt": stt,
                "tts": tts,
            },
        }

    def test_caps_memory_service_stt_tts_true(self):
        r = self._get_caps_result("memory_service", "memory_service")
        self.assertTrue(r["capabilities"]["stt"])
        self.assertTrue(r["capabilities"]["tts"])
        self.assertEqual(r["providers"]["stt"], "memory_service")
        self.assertEqual(r["providers"]["tts"], "memory_service")

    def test_caps_none_stt_tts_false(self):
        r = self._get_caps_result("none", "none")
        self.assertFalse(r["capabilities"]["stt"])
        self.assertFalse(r["capabilities"]["tts"])

    def test_caps_mlx_providers_true(self):
        r = self._get_caps_result("mlx_whisper", "mlx_kokoro")
        self.assertTrue(r["capabilities"]["stt"])
        self.assertTrue(r["capabilities"]["tts"])

    def test_caps_mixed_memory_none(self):
        r = self._get_caps_result("memory_service", "none")
        self.assertTrue(r["capabilities"]["stt"])
        self.assertFalse(r["capabilities"]["tts"])


# ── Provider switch in config ─────────────────────────────────────────────────

class TestProviderConfig(unittest.TestCase):
    """Provider selection read by the ``config`` module from the environment."""

    def _reload_config(self, env_overrides: dict,
                       unset: tuple[str, ...] = ()) -> types.ModuleType:
        """Reload ``config`` under a patched environment, restoring it afterwards.

        Args:
            env_overrides: Environment variables to set during the reload.
            unset: Environment variables to remove during the reload.

        Returns:
            The reloaded ``config`` module (same module object, fresh values).
        """
        import config as cfg_module

        # After the test (and after patch.dict has restored os.environ),
        # reload config again so patched values don't leak into later tests.
        self.addCleanup(importlib.reload, cfg_module)
        with patch.dict(os.environ, env_overrides, clear=False):
            for key in unset:
                os.environ.pop(key, None)
            return importlib.reload(cfg_module)

    def test_default_providers_are_none(self):
        """Default config has no STT/TTS (safe for NODA1)."""
        cfg = self._reload_config({}, unset=("STT_PROVIDER", "TTS_PROVIDER"))
        self.assertEqual(cfg.STT_PROVIDER, "none")
        self.assertEqual(cfg.TTS_PROVIDER, "none")

    def test_memory_service_provider_from_env(self):
        cfg = self._reload_config({"STT_PROVIDER": "memory_service",
                                   "TTS_PROVIDER": "memory_service"})
        self.assertEqual(cfg.STT_PROVIDER, "memory_service")
        self.assertEqual(cfg.TTS_PROVIDER, "memory_service")

    def test_memory_service_url_default(self):
        """Default MEMORY_SERVICE_URL falls back to http://memory-service:8000."""
        # Verify the default value in source regardless of env
        src = (WORKER_DIR / "config.py").read_text(encoding="utf-8")
        self.assertIn("http://memory-service:8000", src)
        self.assertIn("MEMORY_SERVICE_URL", src)


if __name__ == "__main__":
    unittest.main()