diff --git a/tests/test_sofiia_idempotency_redis.py b/tests/test_sofiia_idempotency_redis.py new file mode 100644 index 00000000..b903ccbb --- /dev/null +++ b/tests/test_sofiia_idempotency_redis.py @@ -0,0 +1,201 @@ +from __future__ import annotations + +import asyncio +import importlib +import json + +import httpx + +import app.idempotency as idem_mod # type: ignore + + +class _FakeRedis: + def __init__(self) -> None: + self._values = {} + self._now = 0.0 + self.last_set_ex = {} + + def _expired(self, key: str) -> bool: + hit = self._values.get(key) + if not hit: + return True + exp = hit.get("expires_at") + if exp is None: + return False + return float(exp) <= self._now + + def set(self, key: str, value, ex=None): + expires_at = None if ex is None else (self._now + float(ex)) + encoded = value if isinstance(value, bytes) else str(value).encode("utf-8") + self._values[key] = {"value": encoded, "expires_at": expires_at} + self.last_set_ex[key] = ex + return True + + def get(self, key: str): + if self._expired(key): + self._values.pop(key, None) + return None + return self._values[key]["value"] + + def delete(self, *keys: str): + deleted = 0 + for key in keys: + if key in self._values: + deleted += 1 + self._values.pop(key, None) + return deleted + + def keys(self, pattern: str): + if pattern.endswith("*"): + pref = pattern[:-1] + return [k for k in self._values.keys() if k.startswith(pref) and not self._expired(k)] + return [k for k in self._values.keys() if k == pattern and not self._expired(k)] + + def advance(self, seconds: float) -> None: + self._now += float(seconds) + + +def _build_entry() -> idem_mod.ReplayEntry: + return idem_mod.ReplayEntry( + message_id="m1", + response_body={"ok": True, "text": "hello"}, + created_at=123.0, + node_id="NODA1", + ) + + +def _create_chat(client, agent_id: str, node_id: str, ref: str) -> str: + r = client.post( + "/api/chats", + json={ + "agent_id": agent_id, + "node_id": node_id, + "source": "web", + "external_chat_ref": ref, + }, + ) + assert r.status_code == 200, r.text + return r.json()["chat"]["chat_id"] + + +def _load_sofiia_module_with_fake_redis(tmp_path, monkeypatch, fake_redis: _FakeRedis): + monkeypatch.setenv("SOFIIA_DATA_DIR", str(tmp_path / "sofiia-data")) + monkeypatch.setenv("ENV", "dev") + monkeypatch.setenv("SOFIIA_IDEMPOTENCY_BACKEND", "redis") + monkeypatch.setenv("SOFIIA_REDIS_URL", "redis://fake:6379/0") + monkeypatch.setenv("SOFIIA_REDIS_PREFIX", "sofiia:idem:test:") + monkeypatch.delenv("SOFIIA_CONSOLE_API_KEY", raising=False) + monkeypatch.setenv("ROUTER_URL", "http://router.local:8000") + monkeypatch.delenv("NODE_ID", raising=False) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + import app.db as db_mod # type: ignore + import app.main as main_mod # type: ignore + + importlib.reload(db_mod) + importlib.reload(idem_mod) + monkeypatch.setattr(idem_mod, "_make_redis_client", lambda _: fake_redis) + idem_mod._STORE = None + importlib.reload(main_mod) + main_mod._rate_buckets.clear() + store = idem_mod.get_idempotency_store() + if hasattr(store, "reset"): + store.reset() + return main_mod + + +class _LocalClient: + def __init__(self, app): + self.app = app + + def request(self, method: str, path: str, **kwargs): + async def _do(): + transport = httpx.ASGITransport(app=self.app) + async with httpx.AsyncClient( + transport=transport, + base_url="http://testserver", + follow_redirects=True, + ) as client: + return await client.request(method, path, **kwargs) + + return asyncio.run(_do()) + + def get(self, path: str, **kwargs): + return self.request("GET", path, **kwargs) + + def post(self, path: str, **kwargs): + return self.request("POST", path, **kwargs) + + +def test_redis_store_set_get_roundtrip(): + fake = _FakeRedis() + store = idem_mod.RedisIdempotencyStore(fake, ttl_seconds=120, prefix="sofiia:idem:test:") + entry = _build_entry() + + store.set("abc", entry) + got = store.get("abc") + assert got is not None + assert got.message_id == entry.message_id + assert got.response_body == entry.response_body + assert got.node_id == entry.node_id + + +def test_redis_store_ttl_expiry(): + fake = _FakeRedis() + store = idem_mod.RedisIdempotencyStore(fake, ttl_seconds=60, prefix="sofiia:idem:test:") + + store.set("abc", _build_entry()) + assert store.get("abc") is not None + fake.advance(61) + assert store.get("abc") is None + + +def test_get_idempotency_store_selects_redis_backend(monkeypatch): + fake = _FakeRedis() + monkeypatch.setenv("SOFIIA_IDEMPOTENCY_BACKEND", "redis") + monkeypatch.setenv("SOFIIA_REDIS_URL", "redis://fake:6379/0") + monkeypatch.setenv("SOFIIA_REDIS_PREFIX", "sofiia:idem:test:") + monkeypatch.setattr(idem_mod, "_make_redis_client", lambda _: fake) + idem_mod._STORE = None + + store = idem_mod.get_idempotency_store() + assert isinstance(store, idem_mod.RedisIdempotencyStore) + + +def test_get_idempotency_store_falls_back_when_redis_unavailable(monkeypatch): + monkeypatch.setenv("SOFIIA_IDEMPOTENCY_BACKEND", "redis") + monkeypatch.setenv("SOFIIA_REDIS_URL", "redis://fake:6379/0") + monkeypatch.setattr(idem_mod, "_make_redis_client", lambda _: (_ for _ in ()).throw(RuntimeError("redis down"))) + idem_mod._STORE = None + + store = idem_mod.get_idempotency_store() + assert isinstance(store, idem_mod.InMemoryIdempotencyStore) + + +def test_send_idempotency_replay_with_redis_backend(tmp_path, monkeypatch): + fake = _FakeRedis() + sofiia_module = _load_sofiia_module_with_fake_redis(tmp_path, monkeypatch, fake) + client = _LocalClient(sofiia_module.app) + + async def _fake_infer(base_url, agent_id, text, **kwargs): + return {"response": f"ok:{agent_id}:{text}", "backend": "fake", "model": "fake-model"} + + monkeypatch.setattr(sofiia_module, "infer", _fake_infer) + chat_id = _create_chat(client, "sofiia", "NODA2", "redis-idem") + headers = {"Idempotency-Key": "redis-idem-1"} + + r1 = client.post(f"/api/chats/{chat_id}/send", json={"text": "ping"}, headers=headers) + r2 = client.post(f"/api/chats/{chat_id}/send", json={"text": "ping"}, headers=headers) + + assert r1.status_code == 200 and r2.status_code == 200 + j1, j2 = r1.json(), r2.json() + assert j1["message"]["message_id"] == j2["message"]["message_id"] + assert j1["idempotency"]["replayed"] is False + assert j2["idempotency"]["replayed"] is True + + raw = fake.get(f"sofiia:idem:test:{chat_id}::redis-idem-1") + assert raw is not None + payload = json.loads(raw.decode("utf-8")) + assert payload["node_id"] == "NODA2"