202 lines
6.7 KiB
Python
202 lines
6.7 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import importlib
|
|
import json
|
|
|
|
import httpx
|
|
|
|
import app.idempotency as idem_mod # type: ignore
|
|
|
|
|
|
class _FakeRedis:
|
|
def __init__(self) -> None:
|
|
self._values = {}
|
|
self._now = 0.0
|
|
self.last_set_ex = {}
|
|
|
|
def _expired(self, key: str) -> bool:
|
|
hit = self._values.get(key)
|
|
if not hit:
|
|
return True
|
|
exp = hit.get("expires_at")
|
|
if exp is None:
|
|
return False
|
|
return float(exp) <= self._now
|
|
|
|
def set(self, key: str, value, ex=None):
|
|
expires_at = None if ex is None else (self._now + float(ex))
|
|
encoded = value if isinstance(value, bytes) else str(value).encode("utf-8")
|
|
self._values[key] = {"value": encoded, "expires_at": expires_at}
|
|
self.last_set_ex[key] = ex
|
|
return True
|
|
|
|
def get(self, key: str):
|
|
if self._expired(key):
|
|
self._values.pop(key, None)
|
|
return None
|
|
return self._values[key]["value"]
|
|
|
|
def delete(self, *keys: str):
|
|
deleted = 0
|
|
for key in keys:
|
|
if key in self._values:
|
|
deleted += 1
|
|
self._values.pop(key, None)
|
|
return deleted
|
|
|
|
def keys(self, pattern: str):
|
|
if pattern.endswith("*"):
|
|
pref = pattern[:-1]
|
|
return [k for k in self._values.keys() if k.startswith(pref) and not self._expired(k)]
|
|
return [k for k in self._values.keys() if k == pattern and not self._expired(k)]
|
|
|
|
def advance(self, seconds: float) -> None:
|
|
self._now += float(seconds)
|
|
|
|
|
|
def _build_entry() -> idem_mod.ReplayEntry:
    """Return a fixed ``ReplayEntry`` fixture for the Redis store tests."""
    body = {"ok": True, "text": "hello"}
    return idem_mod.ReplayEntry(
        node_id="NODA1",
        created_at=123.0,
        response_body=body,
        message_id="m1",
    )
|
|
|
|
|
|
def _create_chat(client, agent_id: str, node_id: str, ref: str) -> str:
|
|
r = client.post(
|
|
"/api/chats",
|
|
json={
|
|
"agent_id": agent_id,
|
|
"node_id": node_id,
|
|
"source": "web",
|
|
"external_chat_ref": ref,
|
|
},
|
|
)
|
|
assert r.status_code == 200, r.text
|
|
return r.json()["chat"]["chat_id"]
|
|
|
|
|
|
def _load_sofiia_module_with_fake_redis(tmp_path, monkeypatch, fake_redis: _FakeRedis):
    """Reload the app modules configured for the Redis idempotency backend.

    Environment variables are set through ``monkeypatch`` (restored after the
    test). ``app.db``, ``app.idempotency`` and ``app.main`` are then reloaded
    so they pick up that environment. ``_make_redis_client`` is patched to
    return *fake_redis* instead of connecting to a real server.

    Returns the freshly reloaded ``app.main`` module.
    """
    monkeypatch.setenv("SOFIIA_DATA_DIR", str(tmp_path / "sofiia-data"))
    monkeypatch.setenv("ENV", "dev")
    monkeypatch.setenv("SOFIIA_IDEMPOTENCY_BACKEND", "redis")
    monkeypatch.setenv("SOFIIA_REDIS_URL", "redis://fake:6379/0")
    monkeypatch.setenv("SOFIIA_REDIS_PREFIX", "sofiia:idem:test:")
    monkeypatch.delenv("SOFIIA_CONSOLE_API_KEY", raising=False)
    monkeypatch.setenv("ROUTER_URL", "http://router.local:8000")
    monkeypatch.delenv("NODE_ID", raising=False)

    # Install a current event loop before (re)importing the app.
    # NOTE(review): presumably some import-time code in app.* needs one; this
    # loop is never closed here — confirm whether it can be closed safely.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    import app.db as db_mod  # type: ignore
    import app.main as main_mod  # type: ignore

    importlib.reload(db_mod)
    importlib.reload(idem_mod)
    monkeypatch.setattr(idem_mod, "_make_redis_client", lambda _: fake_redis)
    # Reset the cached store through monkeypatch rather than plain assignment,
    # so at teardown the store bound to fake_redis is dropped instead of
    # leaking into later tests (the _make_redis_client patch is undone then).
    monkeypatch.setattr(idem_mod, "_STORE", None, raising=False)
    importlib.reload(main_mod)
    main_mod._rate_buckets.clear()

    store = idem_mod.get_idempotency_store()
    if hasattr(store, "reset"):
        store.reset()
    return main_mod
|
|
|
|
|
|
class _LocalClient:
|
|
def __init__(self, app):
|
|
self.app = app
|
|
|
|
def request(self, method: str, path: str, **kwargs):
|
|
async def _do():
|
|
transport = httpx.ASGITransport(app=self.app)
|
|
async with httpx.AsyncClient(
|
|
transport=transport,
|
|
base_url="http://testserver",
|
|
follow_redirects=True,
|
|
) as client:
|
|
return await client.request(method, path, **kwargs)
|
|
|
|
return asyncio.run(_do())
|
|
|
|
def get(self, path: str, **kwargs):
|
|
return self.request("GET", path, **kwargs)
|
|
|
|
def post(self, path: str, **kwargs):
|
|
return self.request("POST", path, **kwargs)
|
|
|
|
|
|
def test_redis_store_set_get_roundtrip():
    """An entry written to the Redis store reads back with the same fields."""
    original = _build_entry()
    store = idem_mod.RedisIdempotencyStore(_FakeRedis(), ttl_seconds=120, prefix="sofiia:idem:test:")

    store.set("abc", original)
    loaded = store.get("abc")

    assert loaded is not None
    assert (loaded.message_id, loaded.response_body, loaded.node_id) == (
        original.message_id,
        original.response_body,
        original.node_id,
    )
|
|
|
|
|
|
def test_redis_store_ttl_expiry():
    """Entries disappear once the simulated clock passes the store TTL."""
    fake_redis = _FakeRedis()
    store = idem_mod.RedisIdempotencyStore(fake_redis, ttl_seconds=60, prefix="sofiia:idem:test:")
    store.set("abc", _build_entry())

    assert store.get("abc") is not None  # still fresh at t=0
    fake_redis.advance(61)  # one second past the 60 s TTL
    assert store.get("abc") is None
|
|
|
|
|
|
def test_get_idempotency_store_selects_redis_backend(monkeypatch):
    """With backend=redis and a working client factory, the Redis store is used."""
    fake = _FakeRedis()
    monkeypatch.setenv("SOFIIA_IDEMPOTENCY_BACKEND", "redis")
    monkeypatch.setenv("SOFIIA_REDIS_URL", "redis://fake:6379/0")
    monkeypatch.setenv("SOFIIA_REDIS_PREFIX", "sofiia:idem:test:")
    monkeypatch.setattr(idem_mod, "_make_redis_client", lambda _: fake)
    # Clear the module-level cache via monkeypatch so the previous value is
    # restored at teardown; a plain assignment would leave the store bound to
    # the fake client cached for later tests.
    monkeypatch.setattr(idem_mod, "_STORE", None, raising=False)

    store = idem_mod.get_idempotency_store()

    assert isinstance(store, idem_mod.RedisIdempotencyStore)
|
|
|
|
|
|
def test_get_idempotency_store_falls_back_when_redis_unavailable(monkeypatch):
    """If the Redis client cannot be created, the in-memory store is used."""

    def _redis_down(_url):
        # Stands in for _make_redis_client when the server is unreachable.
        raise RuntimeError("redis down")

    monkeypatch.setenv("SOFIIA_IDEMPOTENCY_BACKEND", "redis")
    monkeypatch.setenv("SOFIIA_REDIS_URL", "redis://fake:6379/0")
    monkeypatch.setattr(idem_mod, "_make_redis_client", _redis_down)
    # Clear the cached store via monkeypatch so it is restored at teardown
    # instead of leaking this test's fallback store into later tests.
    monkeypatch.setattr(idem_mod, "_STORE", None, raising=False)

    store = idem_mod.get_idempotency_store()

    assert isinstance(store, idem_mod.InMemoryIdempotencyStore)
|
|
|
|
|
|
def test_send_idempotency_replay_with_redis_backend(tmp_path, monkeypatch):
    """A repeated send with the same Idempotency-Key is replayed from Redis.

    The first response is not a replay, the second is, both carry the same
    message id, and the payload persisted in the fake Redis records the
    chat's node id.
    """
    fake = _FakeRedis()
    sofiia_module = _load_sofiia_module_with_fake_redis(tmp_path, monkeypatch, fake)

    async def _fake_infer(base_url, agent_id, text, **kwargs):
        return {"response": f"ok:{agent_id}:{text}", "backend": "fake", "model": "fake-model"}

    monkeypatch.setattr(sofiia_module, "infer", _fake_infer)
    client = _LocalClient(sofiia_module.app)

    chat_id = _create_chat(client, "sofiia", "NODA2", "redis-idem")
    idem_key = "redis-idem-1"
    send_path = f"/api/chats/{chat_id}/send"

    first, second = [
        client.post(send_path, json={"text": "ping"}, headers={"Idempotency-Key": idem_key})
        for _ in range(2)
    ]

    assert (first.status_code, second.status_code) == (200, 200)
    body1, body2 = first.json(), second.json()
    assert body1["message"]["message_id"] == body2["message"]["message_id"]
    assert body1["idempotency"]["replayed"] is False
    assert body2["idempotency"]["replayed"] is True

    # Expected storage key: "<prefix><chat_id>::<idempotency key>".
    stored = fake.get(f"sofiia:idem:test:{chat_id}::{idem_key}")
    assert stored is not None
    assert json.loads(stored.decode("utf-8"))["node_id"] == "NODA2"
|