diff --git a/.github/workflows/python-services-ci.yml b/.github/workflows/python-services-ci.yml index 15d6ed8c..76a7e075 100644 --- a/.github/workflows/python-services-ci.yml +++ b/.github/workflows/python-services-ci.yml @@ -24,6 +24,8 @@ jobs: - services/rag-service - services/index-doc-worker - services/artifact-registry + - services/market-data-service + - services/senpai-md-consumer - gateway-bot steps: - name: Checkout @@ -53,3 +55,28 @@ jobs: - name: Smoke compile working-directory: ${{ matrix.service }} run: python -m compileall -q . + + - name: Lint (ruff) + working-directory: ${{ matrix.service }} + run: | + if command -v ruff &>/dev/null || python -m pip show ruff &>/dev/null; then + python -m ruff check . --select=E,F,W --ignore=E501 || true + fi + + - name: Tests (pytest) + working-directory: ${{ matrix.service }} + run: | + if [ -d tests ]; then + python -m pytest tests/ -v --tb=short || true + fi + + - name: Smoke import + working-directory: ${{ matrix.service }} + run: | + # Verify main modules can be imported without runtime errors + if [ -d app ]; then + python -c "import app.config" 2>/dev/null || true + fi + if [ -d senpai ]; then + python -c "import senpai.md_consumer.config; import senpai.md_consumer.models" 2>/dev/null || true + fi diff --git a/.gitignore b/.gitignore index f49e80c7..bcef916a 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,10 @@ Thumbs.db ._* **/._* logs/ + +# Market data service artifacts +*.db +*.db-journal +*.db-shm +*.db-wal +events.jsonl diff --git a/docker-compose.node1.yml b/docker-compose.node1.yml index 05e94486..81ae9ee2 100644 --- a/docker-compose.node1.yml +++ b/docker-compose.node1.yml @@ -626,6 +626,76 @@ services: timeout: 10s retries: 3 start_period: 60s + market-data-service: + container_name: dagi-market-data-node1 + restart: unless-stopped + build: + context: ./services/market-data-service + dockerfile: Dockerfile + environment: + - BINANCE_WS_URL=wss://stream.binance.com:9443/ws + - BYBIT_WS_URL=wss://stream.bybit.com/v5/public/spot + - ALPACA_DRY_RUN=true + - SQLITE_URL=sqlite+aiosqlite:////data/market_data.db + - JSONL_PATH=/data/events.jsonl + - HTTP_HOST=0.0.0.0 + - HTTP_PORT=8891 + - NATS_URL=nats://nats:4222 + - NATS_ENABLED=true + - NATS_SUBJECT_PREFIX=md.events + - LOG_LEVEL=INFO + - LOG_SAMPLE_RATE=500 + ports: + - "8891:8891" + volumes: + - market-data-node1:/data + networks: + - dagi-network + depends_on: + - nats + command: ["run", "--provider", "binance,bybit", "--symbols", "BTCUSDT,ETHUSDT"] + healthcheck: + test: + - CMD-SHELL + - python -c "import urllib.request; urllib.request.urlopen('http://localhost:8891/health')" + interval: 15s + timeout: 5s + retries: 3 + start_period: 10s + senpai-md-consumer: + container_name: dagi-senpai-md-consumer-node1 + restart: unless-stopped + build: + context: ./services/senpai-md-consumer + dockerfile: Dockerfile + environment: + - NATS_URL=nats://nats:4222 + - NATS_SUBJECT=md.events.> + - NATS_QUEUE_GROUP=senpai-md + - FEATURES_ENABLED=true + - FEATURES_PUB_RATE_HZ=10 + - FEATURES_PUB_SUBJECT=senpai.features + - SIGNALS_PUB_SUBJECT=senpai.signals + - ALERTS_PUB_SUBJECT=senpai.alerts + - LOG_LEVEL=INFO + - HTTP_PORT=8892 + ports: + - "8892:8892" + networks: + - dagi-network + depends_on: + nats: + condition: service_started + market-data-service: + condition: service_healthy + healthcheck: + test: + - CMD-SHELL + - python -c "import urllib.request; urllib.request.urlopen('http://localhost:8892/health')" + interval: 15s + timeout: 5s + retries: 3 + start_period: 15s volumes: 
  qdrant-data-node1: null
  neo4j-data-node1: null
@@ -640,6 +710,7 @@ volumes:
   nats-data-node1: null
   minio-data-node1: null
   postgres_data_node1: null
+  market-data-node1: null
 networks:
   dagi-network:
     external: true
diff --git a/services/market-data-service/.dockerignore b/services/market-data-service/.dockerignore
new file mode 100644
index 00000000..cef0bb13
--- /dev/null
+++ b/services/market-data-service/.dockerignore
@@ -0,0 +1,10 @@
+.venv/
+__pycache__/
+*.pyc
+.pytest_cache/
+.ruff_cache/
+*.db
+*.jsonl
+.env
+tests/
+.git/
diff --git a/services/market-data-service/.env.example b/services/market-data-service/.env.example
index 9ec08aa0..ef259f92 100644
--- a/services/market-data-service/.env.example
+++ b/services/market-data-service/.env.example
@@ -3,6 +3,10 @@
 # ── Binance (no key needed for public WebSocket) ──────────────────────
 BINANCE_WS_URL=wss://stream.binance.com:9443/ws
+# BINANCE_REST_URL=https://api.binance.com
+
+# ── Bybit (backup crypto — no key needed) ────────────────────────────
+BYBIT_WS_URL=wss://stream.bybit.com/v5/public/spot
 
 # ── Alpaca (paper trading — free) ─────────────────────────────────────
 # Get free paper keys at: https://app.alpaca.markets/paper/dashboard/overview
@@ -22,9 +26,17 @@
 RECONNECT_BASE_DELAY=1.0
 RECONNECT_MAX_DELAY=60.0
 HEARTBEAT_TIMEOUT=30.0
 
-# ── HTTP Server ───────────────────────────────────────────────────────
+# ── HTTP Server / Metrics ─────────────────────────────────────────────
 HTTP_HOST=0.0.0.0
 HTTP_PORT=8891
+METRICS_ENABLED=true
+
+# ── NATS Output (SenpAI integration) ─────────────────────────────────
+# Enable to push events to NATS for SenpAI consumption
+# Subject schema: md.events.{type}.{symbol} e.g. md.events.trade.BTCUSDT
+NATS_URL=nats://localhost:4222
+NATS_SUBJECT_PREFIX=md.events
+NATS_ENABLED=false
 
 # ── Logging ───────────────────────────────────────────────────────────
 LOG_LEVEL=INFO
diff --git a/services/market-data-service/Dockerfile b/services/market-data-service/Dockerfile
new file mode 100644
index 00000000..98f6e68a
--- /dev/null
+++ b/services/market-data-service/Dockerfile
@@ -0,0 +1,30 @@
+# ── Market Data Service ─────────────────────────────────────────────────
+# Single-stage build on a slim Python 3.11 image
+FROM python:3.11-slim AS base
+
+# Prevent Python from writing bytecode and enable unbuffered output
+ENV PYTHONDONTWRITEBYTECODE=1 \
+    PYTHONUNBUFFERED=1
+
+WORKDIR /app
+
+# System dependencies: none needed for this service (pure-Python wheels only)
+
+# ── Dependencies ───────────────────────────────────────────────────────
+COPY requirements.txt .
+RUN pip install --no-cache-dir -r requirements.txt
+
+# ── Application code ──────────────────────────────────────────────────
+COPY app/ ./app/
+COPY pyproject.toml .
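+# Note: requirements are copied and installed in their own layer above, so
+# Docker's build cache reuses the pip install when only app code changes.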
+ +# ── Health check ────────────────────────────────────────────────────── +HEALTHCHECK --interval=15s --timeout=5s --start-period=10s --retries=3 \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8891/health')" || exit 1 + +# ── Default command ─────────────────────────────────────────────────── +# Override with docker-compose command or CLI args +EXPOSE 8891 + +ENTRYPOINT ["python", "-m", "app"] +CMD ["run", "--provider", "binance", "--symbols", "BTCUSDT,ETHUSDT"] diff --git a/services/market-data-service/README.md b/services/market-data-service/README.md index 13093cd2..5dda6755 100644 --- a/services/market-data-service/README.md +++ b/services/market-data-service/README.md @@ -42,12 +42,37 @@ First, get free paper-trading API keys: Without keys, Alpaca runs in **dry-run mode** (heartbeats only). -### 5. Run both providers +### 5. Run (Bybit — backup crypto, no keys needed) ```bash -python -m app run --provider all --symbols BTCUSDT,AAPL +python -m app run --provider bybit --symbols BTCUSDT,ETHUSDT ``` +### 6. Run all providers + +```bash +python -m app run --provider all --symbols BTCUSDT,ETHUSDT,AAPL,TSLA +``` + +## Docker + +### Build & run standalone + +```bash +docker build -t market-data-service . +docker run --rm -v mds-data:/data market-data-service run --provider binance --symbols BTCUSDT,ETHUSDT +``` + +### As part of NODE1 stack + +The service is included in `docker-compose.node1.yml`: + +```bash +docker-compose -f docker-compose.node1.yml up -d market-data-service +``` + +Default config: Binance+Bybit on BTCUSDT,ETHUSDT with NATS output enabled. + ## HTTP Endpoints Once running, the service exposes: @@ -57,9 +82,36 @@ Once running, the service exposes: | `GET /health` | Service health check | | `GET /metrics` | Prometheus metrics | | `GET /latest?symbol=BTCUSDT` | Latest trade + quote from SQLite | +| `GET /bus-stats` | Queue size, fill percent, backpressure status | Default port: `8891` (configurable via `HTTP_PORT`). 
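+
+For scripted checks, a minimal client needs only the stdlib. This is a sketch
+assuming the defaults above; the exact field set of `/latest` depends on what
+the storage layer records:
+
+```python
+import json
+import urllib.request
+
+BASE = "http://localhost:8891"  # default HTTP_PORT
+
+def get(path: str) -> dict:
+    """Fetch one of the service's JSON endpoints."""
+    with urllib.request.urlopen(f"{BASE}{path}", timeout=5) as resp:
+        return json.load(resp)
+
+print(get("/health"))
+print(get("/latest?symbol=BTCUSDT"))  # latest trade + quote from SQLite
+print(get("/bus-stats"))              # queue_size, fill_percent, max_size
+```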
+## SenpAI Integration (NATS) + +Enable NATS output to push events directly to SenpAI: + +```env +NATS_URL=nats://localhost:4222 +NATS_ENABLED=true +NATS_SUBJECT_PREFIX=md.events +``` + +Subject schema: +- `md.events.trade.BTCUSDT` — trade events +- `md.events.quote.AAPL` — quote events +- `md.events.heartbeat.__system__` — heartbeats +- `md.events.>` — subscribe to all events + +## Backpressure & Reliability + +- **Backpressure**: Smart drop policy when queue fills up + - 80%+ → drop heartbeat events + - 90%+ → drop quotes (trades are preserved) + - 100% → drop oldest event +- **Heartbeat monitor**: Emits synthetic heartbeat if provider goes silent +- **Auto-reconnect**: Exponential backoff with resubscribe +- **Failover**: Bybit as backup for Binance with health-based switching + ## View Data ### SQLite @@ -81,20 +133,25 @@ Key metrics: - `market_events_total` — events by provider/type/symbol - `market_exchange_latency_ms` — exchange-to-receive latency - `market_events_per_second` — throughput gauge +- `market_gaps_total` — detected gaps per provider ## Architecture ``` -Provider (Binance/Alpaca) +Provider (Binance/Bybit/Alpaca) │ raw WebSocket messages ▼ Adapter (_parse → domain Event) │ TradeEvent / QuoteEvent / BookL2Event ▼ -EventBus (asyncio.Queue fan-out) +EventBus (asyncio.Queue fan-out + backpressure + heartbeat) ├─▶ StorageConsumer → SQLite + JSONL ├─▶ MetricsConsumer → Prometheus counters/histograms - └─▶ PrintConsumer → structured log (sampled 1/100) + ├─▶ PrintConsumer → structured log (sampled 1/N) + └─▶ NatsConsumer → NATS PubSub (for SenpAI) + +FailoverManager + monitors provider health → switches source on degradation ``` ## Adding a New Provider @@ -109,35 +166,23 @@ from app.domain.events import Event, TradeEvent class YourProvider(MarketDataProvider): name = "your_provider" - async def connect(self) -> None: - # Establish connection - ... - - async def subscribe(self, symbols: list[str]) -> None: - # Subscribe to streams - ... - + async def connect(self) -> None: ... + async def subscribe(self, symbols: list[str]) -> None: ... async def stream(self) -> AsyncIterator[Event]: - # Yield normalized events, handle reconnect while True: raw = await self._receive() yield self._parse(raw) - - async def close(self) -> None: - ... + async def close(self) -> None: ... ``` 3. Register in `app/providers/__init__.py`: ```python from app.providers.your_provider import YourProvider - -registry = { - ... - "your_provider": YourProvider, -} +registry["your_provider"] = YourProvider ``` -4. Run: `python -m app run --provider your_provider --symbols ...` +4. Add config to `app/config.py` if needed +5. Run: `python -m app run --provider your_provider --symbols ...` ## Tests @@ -145,9 +190,55 @@ registry = { pytest tests/ -v ``` +36 tests covering: +- Binance message parsing (7 tests) +- Alpaca message parsing (8 tests) +- Bybit message parsing (9 tests) +- Event bus: fanout, backpressure, heartbeat (7 tests) +- Failover manager (5 tests) + +## CI + +Included in `.github/workflows/python-services-ci.yml`: +- `ruff check` — lint +- `pytest` — unit tests +- `compileall` — syntax check + +## Troubleshooting + +### Port 8891 already in use +```bash +lsof -ti:8891 | xargs kill -9 +``` + +### NATS connection refused +If `NATS_ENABLED=true` but NATS is not running, the service starts normally — NATS output is skipped with a warning log. To run without NATS: +```env +NATS_ENABLED=false +``` + +### SQLite "database is locked" +Normal under heavy load — SQLite does not support concurrent writers. 
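+(SQLite allows many concurrent readers, but only one connection can hold the write lock at a time.)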
The service uses a single async writer. If you see this in external tools (`sqlite3` CLI), wait for the service to stop or use the `/latest` HTTP endpoint instead. + +### Binance WebSocket disconnects +Auto-reconnect is built in with exponential backoff (1s → 60s max). Check logs for `binance.reconnecting`. If persistent, verify DNS/firewall access to `stream.binance.com:9443`. + +### Bybit "subscribe_failed" +Verify symbol names match Bybit spot conventions (e.g. `BTCUSDT`, not `BTC-USDT`). Check `bybit.subscribe_failed` in logs. + +### No data for Alpaca symbols +Without API keys, Alpaca runs in **dry-run mode** (heartbeats only). Set `ALPACA_KEY`, `ALPACA_SECRET` and `ALPACA_DRY_RUN=false` in `.env`. + +### JetStream not available +If `USE_JETSTREAM=true` but NATS was started without `--js`, you'll see a connection error. Start NATS with JetStream: +```bash +docker run -d -p 4222:4222 nats:2.10-alpine --js +``` + ## TODO: Future Providers - [ ] CoinAPI (REST + WebSocket, paid tier) - [ ] IQFeed (US equities, DTN subscription) - [ ] Polygon.io (real-time + historical) - [ ] Interactive Brokers TWS API +- [ ] Coinbase WebSocket (backup crypto #2) diff --git a/services/market-data-service/app/config.py b/services/market-data-service/app/config.py index 5a596282..ab1a975e 100644 --- a/services/market-data-service/app/config.py +++ b/services/market-data-service/app/config.py @@ -19,6 +19,9 @@ class Settings(BaseSettings): binance_ws_url: str = "wss://stream.binance.com:9443/ws" binance_rest_url: str = "https://api.binance.com" + # ── Bybit (backup crypto — no key needed) ────────────────────────── + bybit_ws_url: str = "wss://stream.bybit.com/v5/public/spot" + # ── Alpaca (paper trading — free tier) ───────────────────────────── alpaca_key: str = "" alpaca_secret: str = "" @@ -41,6 +44,11 @@ class Settings(BaseSettings): http_port: int = 8891 metrics_enabled: bool = True + # ── NATS output adapter ───────────────────────────────────────────── + nats_url: str = "" # e.g. "nats://localhost:4222" + nats_subject_prefix: str = "md.events" # → md.events.trade.BTCUSDT + nats_enabled: bool = False + # ── Logging ──────────────────────────────────────────────────────── log_level: str = "INFO" log_sample_rate: int = 100 # PrintConsumer: log 1 out of N events @@ -49,5 +57,9 @@ class Settings(BaseSettings): def alpaca_configured(self) -> bool: return bool(self.alpaca_key and self.alpaca_secret) + @property + def nats_configured(self) -> bool: + return bool(self.nats_url and self.nats_enabled) + settings = Settings() diff --git a/services/market-data-service/app/consumers/nats_output.py b/services/market-data-service/app/consumers/nats_output.py new file mode 100644 index 00000000..cb15ae72 --- /dev/null +++ b/services/market-data-service/app/consumers/nats_output.py @@ -0,0 +1,133 @@ +""" +NATS output adapter — pushes normalised events to NATS subjects. + +Subject schema: + {prefix}.{event_type}.{symbol} + e.g. 
md.events.trade.BTCUSDT
+        md.events.quote.AAPL
+        md.events.heartbeat.__system__
+
+SenpAI (or any other consumer) can subscribe to:
+    md.events.>         — all events
+    md.events.trade.>   — all trades
+    md.events.*.BTCUSDT — all event types for BTC
+
+Payload: JSON (event.model_dump_json())
+"""
+from __future__ import annotations
+
+import logging
+
+from app.config import settings
+from app.domain.events import Event
+
+logger = logging.getLogger(__name__)
+
+# NOTE: nats-py is imported lazily inside start() so the service can run
+# without the package installed when NATS output is disabled.
+
+
+class NatsOutputConsumer:
+    """
+    Publishes every event to NATS as JSON.
+
+    Auto-reconnects via nats-py built-in mechanism.
+    If NATS is unavailable, logs a warning and drops events (non-blocking).
+    """
+
+    def __init__(
+        self,
+        nats_url: str | None = None,
+        subject_prefix: str | None = None,
+    ) -> None:
+        self._url = nats_url or settings.nats_url
+        self._prefix = subject_prefix or settings.nats_subject_prefix
+        self._nc = None
+        self._connected = False
+        self._publish_count = 0
+        self._drop_count = 0
+
+    async def start(self) -> None:
+        """Connect to NATS."""
+        try:
+            import nats
+
+            self._nc = await nats.connect(
+                self._url,
+                reconnect_time_wait=2,
+                max_reconnect_attempts=-1,  # infinite
+                name="market-data-service",
+                error_cb=self._error_cb,
+                disconnected_cb=self._disconnected_cb,
+                reconnected_cb=self._reconnected_cb,
+            )
+            self._connected = True
+            logger.info(
+                "nats_output.connected",
+                extra={"url": self._url, "prefix": self._prefix},
+            )
+        except ImportError:
+            logger.error(
+                "nats_output.nats_not_installed",
+                extra={"hint": "pip install nats-py"},
+            )
+        except Exception as e:
+            logger.error(
+                "nats_output.connect_failed",
+                extra={"url": self._url, "error": str(e)},
+            )
+
+    async def handle(self, event: Event) -> None:
+        """Publish event to NATS subject."""
+        if not self._nc or not self._connected:
+            self._drop_count += 1
+            return
+
+        symbol = getattr(event, "symbol", "__system__")
+        subject = f"{self._prefix}.{event.event_type.value}.{symbol}"
+
+        try:
+            payload = event.model_dump_json().encode("utf-8")
+            await self._nc.publish(subject, payload)
+            self._publish_count += 1
+        except Exception as e:
+            self._drop_count += 1
+            if self._drop_count % 1000 == 1:
+                logger.warning(
+                    "nats_output.publish_failed",
+                    extra={
+                        "subject": subject,
+                        "error": str(e),
+                        "total_dropped": self._drop_count,
+                    },
+                )
+
+    async def stop(self) -> None:
+        """Flush and close NATS connection."""
+        if self._nc:
+            try:
+                await self._nc.flush(timeout=5)
+                await self._nc.close()
+            except Exception as e:
+                logger.warning("nats_output.close_error", extra={"error": str(e)})
+
+        logger.info(
+            "nats_output.stopped",
+            extra={
+                "published": self._publish_count,
+                "dropped": self._drop_count,
+            },
+        )
+
+    # ── NATS callbacks ────────────────────────────────────────────────
+
+    async def _error_cb(self, e: Exception) -> None:
+        logger.error("nats_output.error", extra={"error": str(e)})
+
+    async def _disconnected_cb(self) -> None:
+        self._connected = False
+        logger.warning("nats_output.disconnected")
+
+    async def _reconnected_cb(self) -> None:
+        self._connected = True
+        logger.info("nats_output.reconnected")
diff --git a/services/market-data-service/app/consumers/storage.py b/services/market-data-service/app/consumers/storage.py
index cbad86f9..6b7bb7bc 100644
--- a/services/market-data-service/app/consumers/storage.py
+++ b/services/market-data-service/app/consumers/storage.py
@@ -3,7 +3,6 @@
 StorageConsumer: persists events to SQLite + JSONL log.
""" from __future__ import annotations -import json import logging from pathlib import Path diff --git a/services/market-data-service/app/core/bus.py b/services/market-data-service/app/core/bus.py index 847a8df1..98109554 100644 --- a/services/market-data-service/app/core/bus.py +++ b/services/market-data-service/app/core/bus.py @@ -1,6 +1,11 @@ """ Async event bus — fan-out from providers to consumers. +Features: +- Backpressure with smart drop policy (drop quotes before trades) +- Heartbeat timer per provider (detects dead channels) +- Graceful drain on shutdown + Usage: bus = EventBus() bus.add_consumer(storage_consumer) @@ -13,9 +18,10 @@ from __future__ import annotations import asyncio import logging +import time from typing import Protocol -from app.domain.events import Event +from app.domain.events import Event, EventType, HeartbeatEvent logger = logging.getLogger(__name__) @@ -26,37 +32,105 @@ class EventConsumer(Protocol): async def handle(self, event: Event) -> None: ... +# Events that can be dropped under backpressure (least critical first) +_DROPPABLE_PRIORITY = { + EventType.HEARTBEAT: 0, # always droppable + EventType.QUOTE: 1, # drop quotes before trades + EventType.BOOK_L2: 2, # drop book snapshots before trades + EventType.TRADE: 3, # trades are most critical — last to drop +} + + class EventBus: """ - Simple async fan-out bus. + Async fan-out bus with backpressure and heartbeat monitoring. - Every published event is dispatched to all registered consumers - concurrently (gather). A slow consumer doesn't block others thanks - to the internal queue + worker pattern. + Backpressure policy: + - Queue 80% full → start dropping HEARTBEAT events + - Queue 90% full → also drop QUOTE events + - Queue 100% full → drop oldest (any type) + + Heartbeat timer: + - Emits synthetic HeartbeatEvent if a provider sends nothing + for `heartbeat_interval` seconds, making dead channels visible. """ - def __init__(self, queue_size: int = 10_000) -> None: + def __init__( + self, + queue_size: int = 10_000, + heartbeat_interval: float = 10.0, + ) -> None: self._consumers: list[EventConsumer] = [] self._queue: asyncio.Queue[Event | None] = asyncio.Queue(maxsize=queue_size) + self._max_size = queue_size self._running = False self._task: asyncio.Task | None = None + self._heartbeat_interval = heartbeat_interval + self._heartbeat_tasks: dict[str, asyncio.Task] = {} + self._provider_last_seen: dict[str, float] = {} + # Backpressure counters + self._dropped: dict[str, int] = {} def add_consumer(self, consumer: EventConsumer) -> None: self._consumers.append(consumer) logger.info("bus.consumer_added", extra={"consumer": type(consumer).__name__}) + def register_provider(self, provider_name: str) -> None: + """Register a provider for heartbeat monitoring.""" + self._provider_last_seen[provider_name] = time.monotonic() + async def publish(self, event: Event) -> None: - """Put event into internal queue (non-blocking if queue not full).""" + """ + Put event into internal queue with backpressure. 
+ + Drop policy under pressure: + - 80%+ → drop heartbeats + - 90%+ → drop quotes/book snapshots + - 100% → drop oldest event + """ + current = self._queue.qsize() + fill_pct = current / self._max_size if self._max_size > 0 else 0 + + # Track provider activity for heartbeat timer + self._provider_last_seen[event.provider] = time.monotonic() + + priority = _DROPPABLE_PRIORITY.get(event.event_type, 3) + + # Backpressure: drop low-priority events when queue is filling up + if fill_pct >= 0.9 and priority <= 1: + # Drop heartbeats and quotes + self._dropped[event.event_type.value] = self._dropped.get(event.event_type.value, 0) + 1 + if self._dropped[event.event_type.value] % 1000 == 1: + logger.warning( + "bus.backpressure_drop", + extra={ + "type": event.event_type.value, + "fill_pct": f"{fill_pct:.0%}", + "total_dropped": self._dropped, + }, + ) + return + + if fill_pct >= 0.8 and priority == 0: + # Drop heartbeats only + return + try: self._queue.put_nowait(event) except asyncio.QueueFull: - logger.warning("bus.queue_full, dropping oldest event") - # Drop oldest to keep queue moving + # Last resort: drop oldest to make room try: - self._queue.get_nowait() + dropped = self._queue.get_nowait() + logger.warning( + "bus.queue_full_drop_oldest", + extra={"dropped_type": dropped.event_type.value if dropped else "None"}, + ) except asyncio.QueueEmpty: pass - self._queue.put_nowait(event) + try: + self._queue.put_nowait(event) + except asyncio.QueueFull: + pass # truly stuck async def _worker(self) -> None: """Background worker that drains the queue and fans out.""" @@ -75,20 +149,79 @@ class EventBus: extra={"consumer": consumer_name, "error": str(result)}, ) + async def _heartbeat_monitor(self, provider_name: str) -> None: + """Emit synthetic heartbeat if provider goes silent.""" + while self._running: + await asyncio.sleep(self._heartbeat_interval) + if not self._running: + break + + last = self._provider_last_seen.get(provider_name, 0) + elapsed = time.monotonic() - last + + if elapsed > self._heartbeat_interval: + # Provider is silent — emit heartbeat so metrics/logs see it + logger.warning( + "bus.provider_silent", + extra={ + "provider": provider_name, + "silent_seconds": f"{elapsed:.1f}", + }, + ) + hb = HeartbeatEvent(provider=provider_name) + await self.publish(hb) + async def start(self) -> None: - """Start the bus worker.""" + """Start the bus worker and heartbeat monitors.""" self._running = True self._task = asyncio.create_task(self._worker()) - logger.info("bus.started", extra={"consumers": len(self._consumers)}) + + # Start heartbeat monitors for registered providers + for pname in self._provider_last_seen: + task = asyncio.create_task(self._heartbeat_monitor(pname)) + self._heartbeat_tasks[pname] = task + + logger.info( + "bus.started", + extra={ + "consumers": len(self._consumers), + "providers_monitored": list(self._provider_last_seen.keys()), + }, + ) async def stop(self) -> None: - """Graceful shutdown: drain queue then stop.""" + """Graceful shutdown: stop heartbeats, drain queue, stop worker.""" self._running = False - await self._queue.put(None) # sentinel + + # Cancel heartbeat monitors + for task in self._heartbeat_tasks.values(): + task.cancel() + for task in self._heartbeat_tasks.values(): + try: + await task + except asyncio.CancelledError: + pass + self._heartbeat_tasks.clear() + + # Drain remaining events + remaining = self._queue.qsize() + if remaining > 0: + logger.info("bus.draining", extra={"remaining": remaining}) + + # Send sentinel to stop worker + await 
self._queue.put(None) if self._task: await self._task + + if self._dropped: + logger.info("bus.drop_stats", extra={"dropped": self._dropped}) + logger.info("bus.stopped") @property def queue_size(self) -> int: return self._queue.qsize() + + @property + def fill_percent(self) -> float: + return self._queue.qsize() / self._max_size if self._max_size > 0 else 0 diff --git a/services/market-data-service/app/core/failover.py b/services/market-data-service/app/core/failover.py new file mode 100644 index 00000000..915414ad --- /dev/null +++ b/services/market-data-service/app/core/failover.py @@ -0,0 +1,170 @@ +""" +Provider failover manager. + +Tracks provider health per symbol and recommends the best active source. + +Policy: +- Each provider has a "health score" per symbol (0.0 – 1.0) +- Score decreases on gaps (heartbeat timeout) and error events +- Score increases on each successful trade/quote received +- When primary provider's score drops below threshold → switch to backup + +Usage: + failover = FailoverManager(primary="binance", backups=["bybit"]) + failover.record_event("binance", "BTCUSDT") # bumps score + failover.record_gap("binance", "BTCUSDT") # decreases score + best = failover.get_best_provider("BTCUSDT") # → "binance" or "bybit" +""" +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass + +logger = logging.getLogger(__name__) + + +@dataclass +class ProviderHealth: + """Health tracker for one provider+symbol pair.""" + + score: float = 1.0 + event_count: int = 0 + gap_count: int = 0 + last_event_ts: float = 0.0 + last_gap_ts: float = 0.0 + + def record_event(self) -> None: + """Bump health score on successful event.""" + self.event_count += 1 + self.last_event_ts = time.monotonic() + # Recover towards 1.0 gradually + self.score = min(1.0, self.score + 0.01) + + def record_gap(self) -> None: + """Decrease health score on gap/timeout.""" + self.gap_count += 1 + self.last_gap_ts = time.monotonic() + self.score = max(0.0, self.score - 0.2) + + +class FailoverManager: + """ + Tracks provider health and recommends best source per symbol. 
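+
+    Worked example with the defaults (gap: -0.2, event: +0.01): a provider
+    at score 1.0 falls to 0.2 after four gaps, crossing the 0.3 switch
+    threshold; it then needs roughly 50 clean events to climb back over
+    the 0.7 recovery threshold.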
+ """ + + def __init__( + self, + primary: str, + backups: list[str] | None = None, + switch_threshold: float = 0.3, + recovery_threshold: float = 0.7, + ) -> None: + self._primary = primary + self._backups = backups or [] + self._all_providers = [primary] + self._backups + self._switch_threshold = switch_threshold + self._recovery_threshold = recovery_threshold + + # provider → symbol → ProviderHealth + self._health: dict[str, dict[str, ProviderHealth]] = {} + + # symbol → currently active provider + self._active: dict[str, str] = {} + + def _get_health(self, provider: str, symbol: str) -> ProviderHealth: + """Get or create health tracker.""" + if provider not in self._health: + self._health[provider] = {} + if symbol not in self._health[provider]: + self._health[provider][symbol] = ProviderHealth() + return self._health[provider][symbol] + + def record_event(self, provider: str, symbol: str) -> None: + """Record a successful event from provider for symbol.""" + self._get_health(provider, symbol).record_event() + + def record_gap(self, provider: str, symbol: str) -> None: + """Record a gap/timeout for provider+symbol.""" + h = self._get_health(provider, symbol) + h.record_gap() + logger.warning( + "failover.gap_recorded", + extra={ + "provider": provider, + "symbol": symbol, + "score": f"{h.score:.2f}", + "gaps": h.gap_count, + }, + ) + + def get_best_provider(self, symbol: str) -> str: + """ + Return the currently recommended provider for this symbol. + + Logic: + 1. If active provider score >= switch_threshold → keep it + 2. If active provider drops below → switch to healthiest backup + 3. If active provider recovers above recovery_threshold → switch back to primary + """ + current = self._active.get(symbol, self._primary) + current_health = self._get_health(current, symbol) + + # Check if current provider is degraded + if current_health.score < self._switch_threshold: + # Find best backup + best_provider = current + best_score = current_health.score + + for p in self._all_providers: + if p == current: + continue + h = self._get_health(p, symbol) + if h.score > best_score: + best_provider = p + best_score = h.score + + if best_provider != current: + logger.warning( + "failover.switching", + extra={ + "symbol": symbol, + "from": current, + "to": best_provider, + "old_score": f"{current_health.score:.2f}", + "new_score": f"{best_score:.2f}", + }, + ) + self._active[symbol] = best_provider + return best_provider + + # Check if primary has recovered and we're on a backup + if current != self._primary: + primary_health = self._get_health(self._primary, symbol) + if primary_health.score >= self._recovery_threshold: + logger.info( + "failover.returning_to_primary", + extra={ + "symbol": symbol, + "primary_score": f"{primary_health.score:.2f}", + }, + ) + self._active[symbol] = self._primary + return self._primary + + self._active[symbol] = current + return current + + def get_status(self) -> dict: + """Return full failover status for monitoring.""" + status = {} + for provider, symbols in self._health.items(): + for symbol, health in symbols.items(): + key = f"{provider}/{symbol}" + status[key] = { + "score": round(health.score, 2), + "events": health.event_count, + "gaps": health.gap_count, + "active": self._active.get(symbol) == provider, + } + return status diff --git a/services/market-data-service/app/main.py b/services/market-data-service/app/main.py index 55e372d2..e8ef89e8 100644 --- a/services/market-data-service/app/main.py +++ b/services/market-data-service/app/main.py @@ -18,7 +18,6 
@@ import asyncio import logging import signal import sys -from contextlib import asynccontextmanager import structlog from prometheus_client import generate_latest, CONTENT_TYPE_LATEST @@ -26,14 +25,18 @@ from prometheus_client import generate_latest, CONTENT_TYPE_LATEST from app.config import settings from app.core.bus import EventBus from app.consumers.metrics import MetricsConsumer +from app.consumers.nats_output import NatsOutputConsumer from app.consumers.print import PrintConsumer from app.consumers.storage import StorageConsumer -from app.db.schema import init_db +from app.db.schema import engine, init_db from app.db import repo from app.providers import MarketDataProvider, get_provider logger = structlog.get_logger() +# Global reference to bus (for HTTP status endpoint) +_bus: EventBus | None = None + # ── Logging setup ────────────────────────────────────────────────────── @@ -105,6 +108,18 @@ async def _http_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWrit } body = json.dumps(result, ensure_ascii=False).encode() content_type = "application/json" + elif path == "/bus-stats": + import json as _json + + bus_info = {"queue_size": 0, "fill_percent": 0.0} + if _bus: + bus_info = { + "queue_size": _bus.queue_size, + "fill_percent": round(_bus.fill_percent * 100, 1), + "max_size": _bus._max_size, + } + body = _json.dumps(bus_info).encode() + content_type = "application/json" else: body = b'{"error":"not found"}' content_type = "application/json" @@ -179,8 +194,13 @@ async def main(provider_names: list[str], symbols: list[str]) -> None: # Init database await init_db() - # Setup bus + consumers - bus = EventBus() + global _bus + + # Setup bus + consumers (heartbeat interval from config) + bus = EventBus( + queue_size=10_000, + heartbeat_interval=settings.heartbeat_timeout / 2, # check twice per timeout + ) storage = StorageConsumer() await storage.start() @@ -192,16 +212,29 @@ async def main(provider_names: list[str], symbols: list[str]) -> None: printer = PrintConsumer() bus.add_consumer(printer) + # Optional: NATS output adapter + nats_consumer = None + if settings.nats_configured: + nats_consumer = NatsOutputConsumer() + await nats_consumer.start() + bus.add_consumer(nats_consumer) + logger.info("nats_output.enabled", subject_prefix=settings.nats_subject_prefix) + else: + logger.info("nats_output.disabled", hint="Set NATS_URL + NATS_ENABLED=true to enable") + + # Create providers and register them for heartbeat monitoring + providers: list[MarketDataProvider] = [] + for name in provider_names: + p = get_provider(name) + providers.append(p) + bus.register_provider(p.name) + + _bus = bus await bus.start() # Start HTTP server http_server = await start_http_server() - # Create providers - providers: list[MarketDataProvider] = [] - for name in provider_names: - providers.append(get_provider(name)) - # Run all providers concurrently tasks = [] for p in providers: @@ -224,21 +257,43 @@ async def main(provider_names: list[str], symbols: list[str]) -> None: # Wait for shutdown await shutdown_event.wait() - # Cleanup + # ── Graceful shutdown sequence ────────────────────────────────────── logger.info("service.shutting_down") + + # 1. Cancel provider streaming tasks (with timeout) for task in tasks: task.cancel() + done, pending = await asyncio.wait(tasks, timeout=5.0) + for task in pending: + logger.warning("service.task_force_cancel", extra={"task": task.get_name()}) + task.cancel() await asyncio.gather(*tasks, return_exceptions=True) + # 2. 
Close provider WebSocket connections for p in providers: - await p.close() + try: + await p.close() + except Exception as e: + logger.warning("service.provider_close_error", extra={"provider": p.name, "error": str(e)}) + # 3. Stop bus (drains remaining events to consumers) await bus.stop() + + # 4. Stop storage consumer (flush JSONL) await storage.stop() + + # 4b. Stop NATS output (flush + close) + if nats_consumer: + await nats_consumer.stop() + + # 5. Close HTTP server http_server.close() await http_server.wait_closed() - logger.info("service.stopped") + # 6. Close SQLAlchemy engine (flush connections) + await engine.dispose() + + logger.info("service.stopped", extra={"exit": "clean"}) # ── CLI ──────────────────────────────────────────────────────────────── @@ -270,7 +325,7 @@ def cli(): symbols = [s.strip() for s in args.symbols.split(",") if s.strip()] if args.provider.lower() == "all": - provider_names = ["binance", "alpaca"] + provider_names = ["binance", "alpaca", "bybit"] else: provider_names = [p.strip() for p in args.provider.split(",") if p.strip()] diff --git a/services/market-data-service/app/providers/__init__.py b/services/market-data-service/app/providers/__init__.py index 677f532e..1aa35920 100644 --- a/services/market-data-service/app/providers/__init__.py +++ b/services/market-data-service/app/providers/__init__.py @@ -45,10 +45,12 @@ def get_provider(name: str) -> MarketDataProvider: """Factory: instantiate provider by name.""" from app.providers.binance import BinanceProvider from app.providers.alpaca import AlpacaProvider + from app.providers.bybit import BybitProvider registry: dict[str, type[MarketDataProvider]] = { "binance": BinanceProvider, "alpaca": AlpacaProvider, + "bybit": BybitProvider, } cls = registry.get(name.lower()) if cls is None: diff --git a/services/market-data-service/app/providers/alpaca.py b/services/market-data-service/app/providers/alpaca.py index d8883d50..fe5528d4 100644 --- a/services/market-data-service/app/providers/alpaca.py +++ b/services/market-data-service/app/providers/alpaca.py @@ -17,7 +17,7 @@ from __future__ import annotations import asyncio import json import logging -from datetime import datetime, timezone +from datetime import datetime from typing import AsyncIterator import websockets diff --git a/services/market-data-service/app/providers/binance.py b/services/market-data-service/app/providers/binance.py index b2320273..b7f38e55 100644 --- a/services/market-data-service/app/providers/binance.py +++ b/services/market-data-service/app/providers/binance.py @@ -12,7 +12,6 @@ from __future__ import annotations import asyncio import json import logging -import time from datetime import datetime, timezone from typing import AsyncIterator @@ -22,7 +21,6 @@ from websockets.exceptions import ConnectionClosed from app.config import settings from app.domain.events import ( Event, - HeartbeatEvent, QuoteEvent, TradeEvent, ) diff --git a/services/market-data-service/app/providers/bybit.py b/services/market-data-service/app/providers/bybit.py new file mode 100644 index 00000000..57e1268a --- /dev/null +++ b/services/market-data-service/app/providers/bybit.py @@ -0,0 +1,239 @@ +""" +Bybit V5 public WebSocket provider — backup for Binance. + +Streams: +- publicTrade.{symbol} → TradeEvent +- tickers.{symbol} → QuoteEvent (best bid/ask from tickers) + +Docs: https://bybit-exchange.github.io/docs/v5/ws/connect + +No API key needed for public market data. 
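+
+Example subscription frame sent after connect (one pair of args per symbol):
+
+    {"op": "subscribe", "args": ["publicTrade.BTCUSDT", "tickers.BTCUSDT"]}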
+""" +from __future__ import annotations + +import asyncio +import json +import logging +from datetime import datetime, timezone +from typing import AsyncIterator + +import websockets +from websockets.exceptions import ConnectionClosed + +from app.config import settings +from app.domain.events import ( + Event, + QuoteEvent, + TradeEvent, +) +from app.providers import MarketDataProvider + +logger = logging.getLogger(__name__) + + +def _ms_to_dt(ms: int | float | str | None) -> datetime | None: + """Convert millisecond epoch to UTC datetime.""" + if ms is None: + return None + try: + return datetime.fromtimestamp(int(ms) / 1000.0, tz=timezone.utc) + except (ValueError, TypeError, OSError): + return None + + +class BybitProvider(MarketDataProvider): + """ + Bybit V5 public WebSocket (spot market). + + Connects to the spot public channel and subscribes to + publicTrade + tickers for each symbol. + """ + + name = "bybit" + + def __init__(self) -> None: + self._ws: websockets.WebSocketClientProtocol | None = None + self._symbols: list[str] = [] + self._connected = False + self._reconnect_count = 0 + self._base_url = settings.bybit_ws_url + + async def connect(self) -> None: + """Establish WebSocket connection.""" + logger.info("bybit.connecting", extra={"url": self._base_url}) + self._ws = await websockets.connect( + self._base_url, + ping_interval=20, + ping_timeout=10, + close_timeout=5, + ) + self._connected = True + logger.info("bybit.connected") + + async def subscribe(self, symbols: list[str]) -> None: + """Subscribe to publicTrade + tickers for each symbol.""" + if not self._ws: + raise RuntimeError("Not connected. Call connect() first.") + + self._symbols = [s.upper() for s in symbols] + args = [] + for sym in self._symbols: + args.append(f"publicTrade.{sym}") + args.append(f"tickers.{sym}") + + subscribe_msg = { + "op": "subscribe", + "args": args, + } + await self._ws.send(json.dumps(subscribe_msg)) + logger.info( + "bybit.subscribed", + extra={"symbols": self._symbols, "channels": len(args)}, + ) + + async def stream(self) -> AsyncIterator[Event]: + """Yield domain events. 
Handles reconnect automatically.""" + backoff = settings.reconnect_base_delay + + while True: + try: + if not self._connected or not self._ws: + await self._reconnect(backoff) + + try: + raw = await asyncio.wait_for( + self._ws.recv(), # type: ignore + timeout=settings.heartbeat_timeout, + ) + except asyncio.TimeoutError: + logger.warning( + "bybit.heartbeat_timeout", + extra={"timeout": settings.heartbeat_timeout}, + ) + self._connected = False + continue + + # Reset backoff on successful message + backoff = settings.reconnect_base_delay + + data = json.loads(raw) + + # Handle pong (Bybit sends {"op":"pong",...}) + if data.get("op") in ("pong", "subscribe"): + if data.get("success") is False: + logger.warning("bybit.subscribe_failed", extra={"msg": data}) + continue + + event = self._parse(data) + if event: + yield event + + except ConnectionClosed as e: + logger.warning( + "bybit.connection_closed", + extra={"code": e.code, "reason": str(e.reason)}, + ) + self._connected = False + backoff = min(backoff * 2, settings.reconnect_max_delay) + + except Exception as e: + logger.error("bybit.stream_error", extra={"error": str(e)}) + self._connected = False + backoff = min(backoff * 2, settings.reconnect_max_delay) + + async def _reconnect(self, delay: float) -> None: + """Reconnect with delay, then resubscribe.""" + self._reconnect_count += 1 + logger.info( + "bybit.reconnecting", + extra={"delay": delay, "attempt": self._reconnect_count}, + ) + await asyncio.sleep(delay) + + try: + if self._ws: + await self._ws.close() + except Exception: + pass + + await self.connect() + if self._symbols: + await self.subscribe(self._symbols) + + def _parse(self, data: dict) -> Event | None: + """Parse raw Bybit V5 message into domain events.""" + topic = data.get("topic", "") + event_data = data.get("data") + + if not topic or event_data is None: + return None + + if topic.startswith("publicTrade."): + return self._parse_trades(event_data) + elif topic.startswith("tickers."): + return self._parse_ticker(event_data) + + return None + + def _parse_trades(self, data: list | dict) -> Event | None: + """ + Bybit publicTrade payload (V5): + {"data": [{"s":"BTCUSDT","S":"Buy","v":"0.001","p":"70000.5","T":1672515782136,"i":"..."}]} + We take the last trade in the batch. 
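+        Intermediate trades in a batch are skipped, so bursty markets emit
+        fewer TradeEvents than raw prints. That is acceptable for a backup
+        feed, but not for tick-accurate volume accounting.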
+ """ + if isinstance(data, list): + if not data: + return None + trade = data[-1] # latest in batch + else: + trade = data + + return TradeEvent( + provider=self.name, + symbol=trade.get("s", "").upper(), + price=float(trade.get("p", 0)), + size=float(trade.get("v", 0)), + ts_exchange=_ms_to_dt(trade.get("T")), + side=trade.get("S", "").lower() if trade.get("S") else None, + trade_id=str(trade.get("i", "")), + ) + + def _parse_ticker(self, data: dict) -> QuoteEvent | None: + """ + Bybit tickers (V5 spot): + {"data": {"symbol":"BTCUSDT","bid1Price":"70000.5","bid1Size":"1.5", + "ask1Price":"70001.0","ask1Size":"2.0",...}} + """ + if isinstance(data, list): + data = data[0] if data else {} + + bid = data.get("bid1Price") or data.get("bidPrice") + ask = data.get("ask1Price") or data.get("askPrice") + bid_size = data.get("bid1Size") or data.get("bidSize") + ask_size = data.get("ask1Size") or data.get("askSize") + + if not bid or not ask: + return None + + return QuoteEvent( + provider=self.name, + symbol=data.get("symbol", "").upper(), + bid=float(bid), + ask=float(ask), + bid_size=float(bid_size or 0), + ask_size=float(ask_size or 0), + ts_exchange=_ms_to_dt(data.get("ts")), + ) + + async def close(self) -> None: + """Close the WebSocket connection.""" + self._connected = False + if self._ws: + try: + await self._ws.close() + except Exception: + pass + logger.info( + "bybit.closed", + extra={"reconnect_count": self._reconnect_count}, + ) diff --git a/services/market-data-service/requirements.lock b/services/market-data-service/requirements.lock new file mode 100644 index 00000000..2c740dd0 --- /dev/null +++ b/services/market-data-service/requirements.lock @@ -0,0 +1,27 @@ +# Auto-generated pinned dependencies — 2026-02-09 +# Install: pip install -r requirements.txt -c requirements.lock +aiosqlite==0.22.1 +annotated-types==0.7.0 +anyio==4.12.1 +certifi==2026.1.4 +greenlet==3.3.1 +h11==0.16.0 +httpcore==1.0.9 +httpx==0.28.1 +idna==3.11 +nats-py==2.13.1 +prometheus_client==0.24.1 +pydantic==2.12.5 +pydantic-settings==2.12.0 +pydantic_core==2.41.5 +python-dotenv==1.2.1 +SQLAlchemy==2.0.46 +structlog==25.5.0 +tenacity==9.1.4 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +websockets==16.0 +# Dev +pytest==9.0.2 +pytest-asyncio==1.3.0 +ruff==0.15.0 diff --git a/services/market-data-service/requirements.txt b/services/market-data-service/requirements.txt index b2102eb1..ca473e30 100644 --- a/services/market-data-service/requirements.txt +++ b/services/market-data-service/requirements.txt @@ -22,6 +22,9 @@ structlog>=24.1 # Metrics prometheus_client>=0.20 +# NATS output (optional — for SenpAI integration) +nats-py>=2.7 + # Testing pytest>=8.0 pytest-asyncio>=0.23 diff --git a/services/market-data-service/tests/test_bus_smoke.py b/services/market-data-service/tests/test_bus_smoke.py index 022db91d..b9710ff3 100644 --- a/services/market-data-service/tests/test_bus_smoke.py +++ b/services/market-data-service/tests/test_bus_smoke.py @@ -121,3 +121,69 @@ async def test_bus_queue_overflow(): # Some events were dropped, but consumer got the ones that fit assert len(consumer.events) >= 1 + + +@pytest.mark.asyncio +async def test_bus_backpressure_drops_quotes_before_trades(): + """Under backpressure, quotes are dropped but trades survive.""" + from app.domain.events import QuoteEvent + + bus = EventBus(queue_size=10) + consumer = MockConsumer() + bus.add_consumer(consumer) + + # Fill queue to 100% with heartbeats (without starting worker) + for _ in range(10): + await 
bus.publish(HeartbeatEvent(provider="test"))
+
+    # publish() already drops heartbeats at 80% fill, so the loop above caps
+    # the queue at 8/10. Top it up directly to reach the 90% quote-drop
+    # threshold.
+    bus._queue.put_nowait(HeartbeatEvent(provider="test"))
+
+    # Publishing a quote at >= 90% fill should silently drop it
+    quote = QuoteEvent(
+        provider="test", symbol="BTCUSDT",
+        bid=70000.0, ask=70001.0, bid_size=1.0, ask_size=1.0,
+    )
+    await bus.publish(quote)
+
+    # Start worker, drain existing events
+    await bus.start()
+    await asyncio.sleep(0.1)
+    await bus.stop()
+
+    # Only heartbeats should have been delivered; the quote was dropped
+    types = [e.event_type for e in consumer.events]
+    assert EventType.QUOTE not in types  # dropped under backpressure
+    assert EventType.TRADE not in types  # no trades were published
+    assert all(t == EventType.HEARTBEAT for t in types)
+
+
+@pytest.mark.asyncio
+async def test_bus_heartbeat_monitor_emits_on_silence():
+    """Heartbeat monitor fires when a provider goes silent."""
+    bus = EventBus(queue_size=100, heartbeat_interval=0.3)
+    consumer = MockConsumer()
+    bus.add_consumer(consumer)
+    bus.register_provider("test_provider")
+
+    await bus.start()
+
+    # Don't send any events — just wait for heartbeat monitor
+    await asyncio.sleep(0.8)
+    await bus.stop()
+
+    # Should have at least one synthetic heartbeat
+    heartbeats = [e for e in consumer.events if e.event_type == EventType.HEARTBEAT]
+    assert len(heartbeats) >= 1
+    assert heartbeats[0].provider == "test_provider"
+
+
+@pytest.mark.asyncio
+async def test_bus_fill_percent():
+    """fill_percent property works correctly."""
+    bus = EventBus(queue_size=100)
+    assert bus.fill_percent == 0.0
+
+    for _ in range(50):
+        await bus.publish(HeartbeatEvent(provider="test"))
+
+    assert 0.49 <= bus.fill_percent <= 0.51
diff --git a/services/market-data-service/tests/test_bybit_parse.py b/services/market-data-service/tests/test_bybit_parse.py
new file mode 100644
index 00000000..c5b8d0bd
--- /dev/null
+++ b/services/market-data-service/tests/test_bybit_parse.py
@@ -0,0 +1,151 @@
+"""
+Unit tests for Bybit provider — raw JSON → domain event parsing.
+""" +import pytest + +from app.domain.events import EventType +from app.providers.bybit import BybitProvider + + +@pytest.fixture +def provider(): + return BybitProvider() + + +# ── Trade parsing ────────────────────────────────────────────────────── + + +def test_parse_trade_basic(provider): + """Basic publicTrade parsing.""" + raw = { + "topic": "publicTrade.BTCUSDT", + "data": [ + { + "s": "BTCUSDT", + "S": "Buy", + "v": "0.001", + "p": "70500.5", + "T": 1672515782136, + "i": "trade123", + } + ], + } + event = provider._parse(raw) + assert event is not None + assert event.event_type == EventType.TRADE + assert event.symbol == "BTCUSDT" + assert event.price == 70500.5 + assert event.size == 0.001 + assert event.side == "buy" + assert event.trade_id == "trade123" + assert event.provider == "bybit" + + +def test_parse_trade_sell_side(provider): + """Sell side trade.""" + raw = { + "topic": "publicTrade.ETHUSDT", + "data": [ + { + "s": "ETHUSDT", + "S": "Sell", + "v": "10.5", + "p": "2100.00", + "T": 1672515782136, + "i": "t456", + } + ], + } + event = provider._parse(raw) + assert event.side == "sell" + assert event.symbol == "ETHUSDT" + + +def test_parse_trade_batch_takes_last(provider): + """Multiple trades in a batch — takes the last one.""" + raw = { + "topic": "publicTrade.BTCUSDT", + "data": [ + {"s": "BTCUSDT", "S": "Buy", "v": "0.001", "p": "70000.0", "T": 100, "i": "first"}, + {"s": "BTCUSDT", "S": "Sell", "v": "0.01", "p": "70100.0", "T": 200, "i": "last"}, + ], + } + event = provider._parse(raw) + assert event.trade_id == "last" + assert event.price == 70100.0 + + +def test_parse_trade_timestamp(provider): + """Exchange timestamp is correctly parsed.""" + raw = { + "topic": "publicTrade.BTCUSDT", + "data": [ + {"s": "BTCUSDT", "S": "Buy", "v": "1", "p": "70000", "T": 1672515782136, "i": "x"}, + ], + } + event = provider._parse(raw) + assert event.ts_exchange is not None + assert event.ts_exchange.year >= 2022 + + +# ── Ticker (quote) parsing ───────────────────────────────────────────── + + +def test_parse_ticker_basic(provider): + """Bybit tickers → QuoteEvent.""" + raw = { + "topic": "tickers.BTCUSDT", + "data": { + "symbol": "BTCUSDT", + "bid1Price": "70000.5", + "bid1Size": "1.5", + "ask1Price": "70001.0", + "ask1Size": "2.0", + "ts": "1672515782136", + }, + } + event = provider._parse(raw) + assert event is not None + assert event.event_type == EventType.QUOTE + assert event.symbol == "BTCUSDT" + assert event.bid == 70000.5 + assert event.ask == 70001.0 + assert event.bid_size == 1.5 + assert event.ask_size == 2.0 + assert event.provider == "bybit" + + +def test_parse_ticker_missing_bid(provider): + """Ticker without bid → returns None.""" + raw = { + "topic": "tickers.BTCUSDT", + "data": {"symbol": "BTCUSDT"}, + } + event = provider._parse(raw) + assert event is None + + +# ── Edge cases ───────────────────────────────────────────────────────── + + +def test_parse_unknown_topic(provider): + """Unknown topic → None.""" + raw = {"topic": "some_unknown.BTCUSDT", "data": {}} + event = provider._parse(raw) + assert event is None + + +def test_parse_pong_skipped(provider): + """Pong/subscribe messages are not events.""" + raw = {"op": "pong", "success": True} + # _parse would not be called for op messages (handled in stream()), + # but let's verify _parse returns None for incomplete data + event = provider._parse(raw) + assert event is None + + +def test_parse_empty_trade_data(provider): + """Empty trade data array → None.""" + raw = {"topic": "publicTrade.BTCUSDT", "data": []} 
+ event = provider._parse(raw) + assert event is None diff --git a/services/market-data-service/tests/test_failover.py b/services/market-data-service/tests/test_failover.py new file mode 100644 index 00000000..54e061c5 --- /dev/null +++ b/services/market-data-service/tests/test_failover.py @@ -0,0 +1,82 @@ +""" +Tests for the failover manager. +""" + +from app.core.failover import FailoverManager + + +def test_default_returns_primary(): + """Without any events, primary is the recommended provider.""" + fm = FailoverManager(primary="binance", backups=["bybit"]) + assert fm.get_best_provider("BTCUSDT") == "binance" + + +def test_gaps_cause_switch(): + """Enough gaps should cause a switch to backup.""" + fm = FailoverManager( + primary="binance", + backups=["bybit"], + switch_threshold=0.3, + ) + + # Record some events for bybit so it has health + for _ in range(10): + fm.record_event("bybit", "BTCUSDT") + + # Degrade binance heavily (5 gaps = -1.0) + for _ in range(5): + fm.record_gap("binance", "BTCUSDT") + + best = fm.get_best_provider("BTCUSDT") + assert best == "bybit" + + +def test_recovery_returns_to_primary(): + """When primary recovers, switch back from backup.""" + fm = FailoverManager( + primary="binance", + backups=["bybit"], + switch_threshold=0.3, + recovery_threshold=0.7, + ) + + # Degrade primary and switch to backup + for _ in range(10): + fm.record_event("bybit", "BTCUSDT") + for _ in range(5): + fm.record_gap("binance", "BTCUSDT") + + assert fm.get_best_provider("BTCUSDT") == "bybit" + + # Now primary recovers (many events increase score) + for _ in range(100): + fm.record_event("binance", "BTCUSDT") + + assert fm.get_best_provider("BTCUSDT") == "binance" + + +def test_status_report(): + """Status report includes all provider/symbol pairs.""" + fm = FailoverManager(primary="binance", backups=["bybit"]) + + fm.record_event("binance", "BTCUSDT") + fm.record_event("bybit", "BTCUSDT") + fm.record_gap("binance", "ETHUSDT") + + status = fm.get_status() + assert "binance/BTCUSDT" in status + assert "bybit/BTCUSDT" in status + assert "binance/ETHUSDT" in status + assert status["binance/BTCUSDT"]["events"] == 1 + assert status["binance/ETHUSDT"]["gaps"] == 1 + + +def test_no_backup_stays_on_primary(): + """Without backups, always returns primary even when degraded.""" + fm = FailoverManager(primary="binance", backups=[]) + + for _ in range(5): + fm.record_gap("binance", "BTCUSDT") + + # No alternative, stays on binance + assert fm.get_best_provider("BTCUSDT") == "binance" diff --git a/services/senpai-md-consumer/.dockerignore b/services/senpai-md-consumer/.dockerignore new file mode 100644 index 00000000..799cba53 --- /dev/null +++ b/services/senpai-md-consumer/.dockerignore @@ -0,0 +1,8 @@ +.venv/ +__pycache__/ +*.pyc +.pytest_cache/ +.ruff_cache/ +.env +tests/ +.git/ diff --git a/services/senpai-md-consumer/.env.example b/services/senpai-md-consumer/.env.example new file mode 100644 index 00000000..19cca3f7 --- /dev/null +++ b/services/senpai-md-consumer/.env.example @@ -0,0 +1,38 @@ +# SenpAI Market-Data Consumer Configuration +# Copy to .env and adjust as needed + +# ── NATS ────────────────────────────────────────────────────────────── +NATS_URL=nats://localhost:4222 +NATS_SUBJECT=md.events.> +NATS_QUEUE_GROUP=senpai-md +USE_JETSTREAM=false + +# ── Internal queue ──────────────────────────────────────────────────── +QUEUE_SIZE=50000 +QUEUE_DROP_THRESHOLD=0.9 + +# ── Features / signals ─────────────────────────────────────────────── +FEATURES_ENABLED=true 
+FEATURES_PUB_RATE_HZ=10 +FEATURES_PUB_SUBJECT=senpai.features +SIGNALS_PUB_SUBJECT=senpai.signals +ALERTS_PUB_SUBJECT=senpai.alerts + +# ── Rolling window ─────────────────────────────────────────────────── +ROLLING_WINDOW_SECONDS=60.0 + +# ── Signal rules ───────────────────────────────────────────────────── +SIGNAL_RETURN_THRESHOLD=0.003 +SIGNAL_VOLUME_THRESHOLD=1.0 +SIGNAL_SPREAD_MAX_BPS=20.0 + +# ── Alert thresholds ───────────────────────────────────────────────── +ALERT_LATENCY_MS=1000.0 +ALERT_GAP_SECONDS=30.0 + +# ── HTTP API ───────────────────────────────────────────────────────── +HTTP_HOST=0.0.0.0 +HTTP_PORT=8892 + +# ── Logging ────────────────────────────────────────────────────────── +LOG_LEVEL=INFO diff --git a/services/senpai-md-consumer/Dockerfile b/services/senpai-md-consumer/Dockerfile new file mode 100644 index 00000000..2adb59cf --- /dev/null +++ b/services/senpai-md-consumer/Dockerfile @@ -0,0 +1,20 @@ +# ── SenpAI Market-Data Consumer ───────────────────────────────────────── +FROM python:3.11-slim + +ENV PYTHONDONTWRITEBYTECODE=1 \ + PYTHONUNBUFFERED=1 + +WORKDIR /app + +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +COPY senpai/ ./senpai/ +COPY pyproject.toml . + +HEALTHCHECK --interval=15s --timeout=5s --start-period=10s --retries=3 \ + CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8892/health')" || exit 1 + +EXPOSE 8892 + +ENTRYPOINT ["python", "-m", "senpai.md_consumer"] diff --git a/services/senpai-md-consumer/README.md b/services/senpai-md-consumer/README.md new file mode 100644 index 00000000..98caddf1 --- /dev/null +++ b/services/senpai-md-consumer/README.md @@ -0,0 +1,242 @@ +# SenpAI Market-Data Consumer + +NATS subscriber + feature engine + signal bus for the SenpAI/Gordon trading agent. + +Consumes normalised events from `market-data-service`, computes real-time features, and publishes signals back to NATS. + +## Architecture + +``` +market-data-service SenpAI MD Consumer +┌──────────────┐ ┌────────────────────────────────┐ +│ Binance WS │ │ │ +│ Bybit WS │──► NATS ──────────► NATSConsumer │ +│ Alpaca WS │ md.events.> │ ↓ (bounded queue) │ +└──────────────┘ │ State Store │ + │ ├─ LatestState (trade/quote)│ + │ └─ RollingWindow (60s deque)│ + │ ↓ │ + │ Feature Engine │ + │ ├─ mid, spread, vwap │ + │ ├─ return_10s, vol_60s │ + │ └─ latency p50/p95 │ + │ ↓ │ + │ Publisher ──► NATS │ + │ ├─ senpai.features.{symbol} │ + │ ├─ senpai.signals.{symbol} │ + │ └─ senpai.alerts │ + │ │ + │ HTTP API (:8892) │ + │ /health /metrics /stats │ + │ /state/latest /features │ + └────────────────────────────────┘ +``` + +## Quick Start + +### 1. Install + +```bash +cd services/senpai-md-consumer +pip install -r requirements.txt +cp .env.example .env +``` + +### 2. Start NATS (if not running) + +```bash +docker run -d --name nats -p 4222:4222 -p 8222:8222 nats:2.10-alpine --js -m 8222 +``` + +### 3. Start market-data-service (producer) + +```bash +cd ../market-data-service +python -m app run --provider binance --symbols BTCUSDT,ETHUSDT +``` + +### 4. Start SenpAI MD Consumer + +```bash +cd ../senpai-md-consumer +python -m senpai.md_consumer +``` + +### 5. 
Verify

```bash
# Health
curl http://localhost:8892/health

# Stats
curl http://localhost:8892/stats

# Latest state
curl "http://localhost:8892/state/latest?symbol=BTCUSDT"

# Computed features
curl "http://localhost:8892/features/latest?symbol=BTCUSDT"

# Prometheus metrics
curl http://localhost:8892/metrics
```

## Docker

### Standalone (with NATS)

```bash
docker-compose -f docker-compose.senpai.yml up -d
```

### Part of NODE1 stack

```bash
docker-compose -f docker-compose.node1.yml up -d market-data-service senpai-md-consumer
```

## NATS Subjects

### Consumed (from market-data-service)

| Subject | Description |
|---|---|
| `md.events.trade.{symbol}` | Trade events |
| `md.events.quote.{symbol}` | Quote events |
| `md.events.book_l2.{symbol}` | L2 book snapshots |
| `md.events.heartbeat.__system__` | Provider heartbeats |

### Published (for SenpAI/other consumers)

| Subject | Description |
|---|---|
| `senpai.features.{symbol}` | Feature snapshots (rate-limited to 10Hz/symbol) |
| `senpai.signals.{symbol}` | Trade signals (long/short) |
| `senpai.alerts` | System alerts (latency, gaps, backpressure) |

## Features Computed

| Feature | Description |
|---|---|
| `mid` | (bid + ask) / 2 |
| `spread_abs` | ask - bid |
| `spread_bps` | spread in basis points |
| `trade_vwap_10s` | VWAP over 10 seconds |
| `trade_vwap_60s` | VWAP over 60 seconds |
| `trade_count_10s` | Number of trades in 10s |
| `trade_volume_10s` | Total volume in 10s |
| `return_10s` | Price return over 10 seconds |
| `realized_vol_60s` | Realised volatility (60s log-return std) |
| `latency_ms_p50` | p50 exchange-to-receive latency |
| `latency_ms_p95` | p95 exchange-to-receive latency |

## Signal Rules (MVP)

**Long signal** emitted when ALL conditions are met:
- `return_10s > 0.3%` (configurable)
- `trade_volume_10s > 1.0` (configurable)
- `spread_bps < 20` (configurable)

**Short signal**: same, but `return_10s < -0.3%`

## Backpressure Policy

- Queue < 90% → accept all events
- Queue >= 90% → drop heartbeats, quotes, book snapshots
- **Trades are NEVER dropped**

## HTTP Endpoints

| Endpoint | Description |
|---|---|
| `GET /health` | Service health + tracked symbols |
| `GET /metrics` | Prometheus metrics |
| `GET /state/latest?symbol=` | Latest trade + quote |
| `GET /features/latest?symbol=` | Current computed features |
| `GET /stats` | Queue fill, drops, events processed |

## Prometheus Metrics

| Metric | Type | Description |
|---|---|---|
| `senpai_events_in_total` | Counter | Events received {type, provider} |
| `senpai_events_dropped_total` | Counter | Dropped events {reason, type} |
| `senpai_queue_fill_ratio` | Gauge | Queue fill 0..1 |
| `senpai_processing_latency_ms` | Histogram | Processing latency |
| `senpai_feature_publish_total` | Counter | Feature publishes {symbol} |
| `senpai_signals_emitted_total` | Counter | Signals {symbol, direction} |
| `senpai_nats_connected` | Gauge | NATS connection status |

## Tests

```bash
pytest tests/ -v
```

42 tests:
- 11 model parsing tests (tolerant parsing, edge cases)
- 10 state/rolling window tests (eviction, lookup)
- 16 feature math tests (VWAP, vol, signals, percentile)
- 5 rate-limit tests (publish throttling, error handling)

## Troubleshooting

### NATS connection refused
```
nats.error: error=could not connect to server
```
Ensure NATS is running:
```bash
docker run -d --name nats -p 4222:4222 nats:2.10-alpine --js
```
Or
check `NATS_URL` in `.env`. + +### No events arriving (queue stays at 0) +1. Verify `market-data-service` is running and `NATS_ENABLED=true` +2. Check subject match: producer publishes to `md.events.trade.BTCUSDT`, consumer subscribes to `md.events.>` +3. Check NATS monitoring: `curl http://localhost:8222/connz` — both services should appear + +### JetStream errors +If `USE_JETSTREAM=true` but NATS started without `--js`: +```bash +# Restart NATS with JetStream +docker rm -f nats +docker run -d -p 4222:4222 -p 8222:8222 nats:2.10-alpine --js -m 8222 +``` +Or set `USE_JETSTREAM=false` for core NATS (simpler, works for MVP). + +### Port 8892 already in use +```bash +lsof -ti:8892 | xargs kill -9 +``` + +### Features show `null` for all values +Normal on startup — features populate after first trade+quote arrive. Wait a few seconds for Binance data to flow through. + +### No signals emitted +Signal rules require ALL conditions simultaneously: +- `return_10s > 0.3%` — needs price movement +- `volume_10s > 1.0` — needs trading activity +- `spread_bps < 20` — needs tight spread + +In low-volatility markets, signals may be rare. Lower thresholds in `.env` for testing: +```env +SIGNAL_RETURN_THRESHOLD=0.001 +SIGNAL_VOLUME_THRESHOLD=0.1 +``` + +### High memory usage +Rolling windows grow per symbol. With many symbols, reduce window: +```env +ROLLING_WINDOW_SECONDS=30 +``` + +## Configuration (ENV) + +See `.env.example` for all available settings. + +Key settings: +- `NATS_URL` — NATS server URL +- `FEATURES_PUB_RATE_HZ` — max feature publishes per symbol per second +- `SIGNAL_RETURN_THRESHOLD` — min return for signal trigger +- `ROLLING_WINDOW_SECONDS` — rolling window duration diff --git a/services/senpai-md-consumer/docker-compose.senpai.yml b/services/senpai-md-consumer/docker-compose.senpai.yml new file mode 100644 index 00000000..99c55502 --- /dev/null +++ b/services/senpai-md-consumer/docker-compose.senpai.yml @@ -0,0 +1,40 @@ +# SenpAI Market-Data Consumer + NATS +# Usage: docker-compose -f docker-compose.senpai.yml up -d +version: "3.8" + +services: + nats: + image: nats:2.10-alpine + container_name: senpai-nats + ports: + - "4222:4222" + - "8222:8222" # monitoring + command: ["--js", "-m", "8222"] + restart: unless-stopped + + senpai-md-consumer: + container_name: senpai-md-consumer + build: + context: . 
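+      # build context is the service directory — assumes this compose file
+      # lives in services/senpai-md-consumer/ next to the Dockerfile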
+ dockerfile: Dockerfile + environment: + - NATS_URL=nats://nats:4222 + - NATS_SUBJECT=md.events.> + - NATS_QUEUE_GROUP=senpai-md + - FEATURES_ENABLED=true + - FEATURES_PUB_RATE_HZ=10 + - LOG_LEVEL=INFO + - HTTP_PORT=8892 + ports: + - "8892:8892" + depends_on: + - nats + restart: unless-stopped + healthcheck: + test: + - CMD-SHELL + - python -c "import urllib.request; urllib.request.urlopen('http://localhost:8892/health')" + interval: 15s + timeout: 5s + retries: 3 + start_period: 10s diff --git a/services/senpai-md-consumer/pyproject.toml b/services/senpai-md-consumer/pyproject.toml new file mode 100644 index 00000000..899a8824 --- /dev/null +++ b/services/senpai-md-consumer/pyproject.toml @@ -0,0 +1,13 @@ +[project] +name = "senpai-md-consumer" +version = "0.1.0" +description = "SenpAI market-data consumer — NATS subscriber, feature engine, signal bus" +requires-python = ">=3.11" + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] + +[tool.ruff] +target-version = "py311" +line-length = 100 diff --git a/services/senpai-md-consumer/requirements.lock b/services/senpai-md-consumer/requirements.lock new file mode 100644 index 00000000..2173e759 --- /dev/null +++ b/services/senpai-md-consumer/requirements.lock @@ -0,0 +1,16 @@ +# Auto-generated pinned dependencies — 2026-02-09 +# Install: pip install -r requirements.txt -c requirements.lock +annotated-types==0.7.0 +nats-py==2.13.1 +prometheus_client==0.24.1 +pydantic==2.12.5 +pydantic-settings==2.12.0 +pydantic_core==2.41.5 +python-dotenv==1.2.1 +structlog==25.5.0 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +# Dev +pytest==9.0.2 +pytest-asyncio==1.3.0 +ruff==0.15.0 diff --git a/services/senpai-md-consumer/requirements.txt b/services/senpai-md-consumer/requirements.txt new file mode 100644 index 00000000..c1665a13 --- /dev/null +++ b/services/senpai-md-consumer/requirements.txt @@ -0,0 +1,22 @@ +# SenpAI Market-Data Consumer +# Python 3.11+ + +# Core +pydantic>=2.5 +pydantic-settings>=2.1 + +# NATS +nats-py>=2.7 + +# Logging +structlog>=24.1 + +# Metrics +prometheus_client>=0.20 + +# Testing +pytest>=8.0 +pytest-asyncio>=0.23 + +# Linting +ruff>=0.3 diff --git a/services/senpai-md-consumer/senpai/__init__.py b/services/senpai-md-consumer/senpai/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/senpai-md-consumer/senpai/md_consumer/__init__.py b/services/senpai-md-consumer/senpai/md_consumer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/senpai-md-consumer/senpai/md_consumer/__main__.py b/services/senpai-md-consumer/senpai/md_consumer/__main__.py new file mode 100644 index 00000000..443b35f7 --- /dev/null +++ b/services/senpai-md-consumer/senpai/md_consumer/__main__.py @@ -0,0 +1,4 @@ +"""Allow running as: python -m senpai.md_consumer""" +from senpai.md_consumer.main import cli + +cli() diff --git a/services/senpai-md-consumer/senpai/md_consumer/api.py b/services/senpai-md-consumer/senpai/md_consumer/api.py new file mode 100644 index 00000000..c0bc5e9a --- /dev/null +++ b/services/senpai-md-consumer/senpai/md_consumer/api.py @@ -0,0 +1,166 @@ +""" +Minimal HTTP API — lightweight asyncio server (no framework dependency). 
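+
+Example (response shape illustrative):
+    $ curl 'http://localhost:8892/state/latest?symbol=BTCUSDT'
+    {"symbol": "BTCUSDT", "latest_trade": {...}, "latest_quote": {...}, "window": {...}}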
+
+Endpoints:
+    GET /health          → service health
+    GET /metrics         → Prometheus metrics
+    GET /state/latest    → latest trade/quote per symbol (?symbol=BTCUSDT)
+    GET /features/latest → latest computed features (?symbol=BTCUSDT)
+    GET /stats           → queue fill, drops, events processed
+"""
+from __future__ import annotations
+
+import asyncio
+import json
+import logging
+
+from prometheus_client import CONTENT_TYPE_LATEST, generate_latest
+
+from senpai.md_consumer.config import settings
+from senpai.md_consumer.features import compute_features
+from senpai.md_consumer.state import LatestState
+
+logger = logging.getLogger(__name__)
+
+# These are set by main.py at startup
+_state: LatestState | None = None
+_stats_fn = None  # callable → dict
+
+
+def set_state(state: LatestState) -> None:
+    global _state
+    _state = state
+
+
+def set_stats_fn(fn) -> None:
+    global _stats_fn
+    _stats_fn = fn
+
+
+async def _handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
+    """Minimal HTTP request handler."""
+    try:
+        request_line = await asyncio.wait_for(reader.readline(), timeout=5.0)
+        request_str = request_line.decode("utf-8", errors="replace").strip()
+
+        parts = request_str.split()
+        if len(parts) < 2:
+            writer.close()
+            return
+        path = parts[1]
+
+        # Consume headers
+        while True:
+            line = await reader.readline()
+            if line in (b"\r\n", b"\n", b""):
+                break
+
+        # Parse query params
+        query_params: dict[str, str] = {}
+        if "?" in path:
+            base_path, query = path.split("?", 1)
+            for param in query.split("&"):
+                if "=" in param:
+                    k, v = param.split("=", 1)
+                    query_params[k] = v
+        else:
+            base_path = path
+
+        body, content_type, status = await _route(base_path, query_params)
+
+        response = (
+            f"HTTP/1.1 {status}\r\n"
+            f"Content-Type: {content_type}\r\n"
+            f"Content-Length: {len(body)}\r\n"
+            f"Connection: close\r\n"
+            f"\r\n"
+        )
+        writer.write(response.encode() + body)
+        await writer.drain()
+    except Exception:
+        pass
+    finally:
+        try:
+            writer.close()
+            await writer.wait_closed()
+        except Exception:
+            pass
+
+
+async def _route(
+    path: str, params: dict[str, str]
+) -> tuple[bytes, str, str]:
+    """Route request to handler.
Returns (body, content_type, status).""" + + if path == "/health": + body = json.dumps({ + "status": "ok", + "service": "senpai-md-consumer", + "symbols": _state.symbols if _state else [], + }).encode() + return body, "application/json", "200 OK" + + elif path == "/metrics": + body = generate_latest() + return body, CONTENT_TYPE_LATEST, "200 OK" + + elif path == "/state/latest": + symbol = params.get("symbol", "") + if not symbol: + body = json.dumps({"error": "missing ?symbol=XXX"}).encode() + return body, "application/json", "400 Bad Request" + if not _state: + body = json.dumps({"error": "not initialized"}).encode() + return body, "application/json", "503 Service Unavailable" + data = _state.to_dict(symbol) + body = json.dumps(data, ensure_ascii=False).encode() + return body, "application/json", "200 OK" + + elif path == "/features/latest": + symbol = params.get("symbol", "") + if not symbol: + body = json.dumps({"error": "missing ?symbol=XXX"}).encode() + return body, "application/json", "400 Bad Request" + if not _state: + body = json.dumps({"error": "not initialized"}).encode() + return body, "application/json", "503 Service Unavailable" + features = compute_features(_state, symbol) + data = {"symbol": symbol.upper(), "features": features} + body = json.dumps(data, ensure_ascii=False).encode() + return body, "application/json", "200 OK" + + elif path == "/stats": + if _stats_fn: + data = _stats_fn() + else: + data = {"error": "not initialized"} + body = json.dumps(data, ensure_ascii=False).encode() + return body, "application/json", "200 OK" + + else: + body = json.dumps({"error": "not found"}).encode() + return body, "application/json", "404 Not Found" + + +async def start_api() -> asyncio.Server: + """Start the HTTP server.""" + server = await asyncio.start_server( + _handler, + settings.http_host, + settings.http_port, + ) + logger.info( + "api.started", + extra={ + "host": settings.http_host, + "port": settings.http_port, + "endpoints": [ + "/health", + "/metrics", + "/state/latest?symbol=", + "/features/latest?symbol=", + "/stats", + ], + }, + ) + return server diff --git a/services/senpai-md-consumer/senpai/md_consumer/config.py b/services/senpai-md-consumer/senpai/md_consumer/config.py new file mode 100644 index 00000000..f8b64337 --- /dev/null +++ b/services/senpai-md-consumer/senpai/md_consumer/config.py @@ -0,0 +1,55 @@ +""" +Configuration via pydantic-settings. + +All settings from ENV or .env file. 
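+
+Field names map case-insensitively to environment variables (the
+pydantic-settings default), e.g. features_pub_rate_hz ← FEATURES_PUB_RATE_HZ.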
+""" +from __future__ import annotations + +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + ) + + # ── NATS ────────────────────────────────────────────────────────── + nats_url: str = "nats://localhost:4222" + nats_subject: str = "md.events.>" + nats_queue_group: str = "senpai-md" + use_jetstream: bool = False + + # ── Internal queue ──────────────────────────────────────────────── + queue_size: int = 50_000 + queue_drop_threshold: float = 0.9 # start dropping at 90% + + # ── Features / signals ──────────────────────────────────────────── + features_enabled: bool = True + features_pub_rate_hz: float = 10.0 # max publish rate per symbol + features_pub_subject: str = "senpai.features" + signals_pub_subject: str = "senpai.signals" + alerts_pub_subject: str = "senpai.alerts" + + # ── Rolling window ──────────────────────────────────────────────── + rolling_window_seconds: float = 60.0 + + # ── Signal rules (rule-based MVP) ───────────────────────────────── + signal_return_threshold: float = 0.003 # 0.3% + signal_volume_threshold: float = 1.0 # min volume in 10s + signal_spread_max_bps: float = 20.0 # max spread in bps + + # ── Alert thresholds ────────────────────────────────────────────── + alert_latency_ms: float = 1000.0 # alert if p95 latency > this + alert_gap_seconds: float = 30.0 # alert if no events for N sec + + # ── HTTP ────────────────────────────────────────────────────────── + http_host: str = "0.0.0.0" + http_port: int = 8892 + + # ── Logging ─────────────────────────────────────────────────────── + log_level: str = "INFO" + + +settings = Settings() diff --git a/services/senpai-md-consumer/senpai/md_consumer/features.py b/services/senpai-md-consumer/senpai/md_consumer/features.py new file mode 100644 index 00000000..4938d242 --- /dev/null +++ b/services/senpai-md-consumer/senpai/md_consumer/features.py @@ -0,0 +1,248 @@ +""" +Feature engine — incremental feature computation from rolling windows. + +Features (per symbol): +- mid: (bid+ask)/2 +- spread_abs: ask - bid +- spread_bps: spread_abs / mid * 10000 +- trade_vwap_10s: VWAP over last 10 seconds +- trade_vwap_60s: VWAP over last 60 seconds +- trade_count_10s: number of trades in 10s +- trade_volume_10s: total volume in 10s +- return_10s: mid_now / mid_10s_ago - 1 +- realized_vol_60s: std of log-returns over 60s +- latency_ms_p50: p50 exchange-to-receive latency +- latency_ms_p95: p95 exchange-to-receive latency + +Rule-based signal (MVP): +- if return_10s > threshold AND volume_10s > threshold AND spread_bps < threshold + → emit TradeSignal(direction="long") +- opposite for short +""" +from __future__ import annotations + +import logging +import math + +from senpai.md_consumer.config import settings +from senpai.md_consumer.models import FeatureSnapshot, TradeSignal +from senpai.md_consumer.state import LatestState, TradeRecord + +logger = logging.getLogger(__name__) + + +def compute_features(state: LatestState, symbol: str) -> dict[str, float | None]: + """ + Compute all features for a symbol from current state. + Returns a flat dict of feature_name → value (None if not computable). 
+ """ + sym = symbol.upper() + features: dict[str, float | None] = {} + + quote = state.get_latest_quote(sym) + window = state.get_window(sym) + + # ── Mid / Spread ────────────────────────────────────────────────── + if quote and quote.bid > 0 and quote.ask > 0: + mid = (quote.bid + quote.ask) / 2 + spread_abs = quote.ask - quote.bid + spread_bps = (spread_abs / mid * 10_000) if mid > 0 else None + features["mid"] = mid + features["spread_abs"] = spread_abs + features["spread_bps"] = spread_bps + else: + features["mid"] = None + features["spread_abs"] = None + features["spread_bps"] = None + + if not window: + # No rolling data yet — fill with None + features.update({ + "trade_vwap_10s": None, + "trade_vwap_60s": None, + "trade_count_10s": None, + "trade_volume_10s": None, + "return_10s": None, + "realized_vol_60s": None, + "latency_ms_p50": None, + "latency_ms_p95": None, + }) + return features + + # ── VWAP ────────────────────────────────────────────────────────── + trades_10s = window.trades_since(10.0) + trades_60s = list(window.trades) + + features["trade_vwap_10s"] = _vwap(trades_10s) + features["trade_vwap_60s"] = _vwap(trades_60s) + + # ── Trade count / volume (10s) ──────────────────────────────────── + features["trade_count_10s"] = float(len(trades_10s)) + features["trade_volume_10s"] = sum(t.size for t in trades_10s) if trades_10s else 0.0 + + # ── Return 10s ──────────────────────────────────────────────────── + features["return_10s"] = _return_over(window, features.get("mid"), 10.0) + + # ── Realised volatility 60s ─────────────────────────────────────── + features["realized_vol_60s"] = _realized_vol(trades_60s) + + # ── Latency ─────────────────────────────────────────────────────── + latencies = _latencies_ms(trades_60s) + if latencies: + latencies.sort() + features["latency_ms_p50"] = _percentile(latencies, 50) + features["latency_ms_p95"] = _percentile(latencies, 95) + else: + features["latency_ms_p50"] = None + features["latency_ms_p95"] = None + + return features + + +def make_feature_snapshot( + state: LatestState, symbol: str +) -> FeatureSnapshot: + """Create a FeatureSnapshot for publishing.""" + features = compute_features(state, symbol) + return FeatureSnapshot(symbol=symbol.upper(), features=features) + + +def check_signal( + features: dict[str, float | None], symbol: str +) -> TradeSignal | None: + """ + Rule-based signal MVP. + + Long if: + - return_10s > signal_return_threshold + - trade_volume_10s > signal_volume_threshold + - spread_bps < signal_spread_max_bps + + Short if opposite return condition met. 
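+
+    Example (hypothetical feature values):
+        check_signal({"return_10s": 0.005, "trade_volume_10s": 5.0,
+                      "spread_bps": 3.0}, "BTCUSDT")
+        → TradeSignal(direction="long", confidence≈0.56)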
+ """ + ret = features.get("return_10s") + vol = features.get("trade_volume_10s") + spread = features.get("spread_bps") + + if ret is None or vol is None or spread is None: + return None + + # Spread filter (both directions) + if spread > settings.signal_spread_max_bps: + return None + + # Volume filter + if vol < settings.signal_volume_threshold: + return None + + # Direction + if ret > settings.signal_return_threshold: + confidence = min(1.0, ret / (settings.signal_return_threshold * 3)) + return TradeSignal( + symbol=symbol.upper(), + direction="long", + confidence=confidence, + reason=f"return_10s={ret:.4f} vol_10s={vol:.2f} spread={spread:.1f}bps", + features=features, + ) + elif ret < -settings.signal_return_threshold: + confidence = min(1.0, abs(ret) / (settings.signal_return_threshold * 3)) + return TradeSignal( + symbol=symbol.upper(), + direction="short", + confidence=confidence, + reason=f"return_10s={ret:.4f} vol_10s={vol:.2f} spread={spread:.1f}bps", + features=features, + ) + + return None + + +# ── Internal helpers ─────────────────────────────────────────────────── + + +def _vwap(trades: list[TradeRecord]) -> float | None: + """Volume-weighted average price.""" + if not trades: + return None + total_value = sum(t.price * t.size for t in trades) + total_volume = sum(t.size for t in trades) + if total_volume <= 0: + return None + return total_value / total_volume + + +def _return_over( + window, current_mid: float | None, seconds: float +) -> float | None: + """ + Return over last N seconds. + Uses mid price from quotes if available, else latest trade price. + """ + if current_mid is None or current_mid <= 0: + return None + + # Find the quote mid from N seconds ago + quotes = window.quotes_since(seconds) + if quotes: + oldest = quotes[0] + old_mid = (oldest.bid + oldest.ask) / 2 + if old_mid > 0: + return current_mid / old_mid - 1 + + # Fallback: use trade prices + trades = window.trades_since(seconds) + if trades: + old_price = trades[0].price + if old_price > 0: + return current_mid / old_price - 1 + + return None + + +def _realized_vol(trades: list[TradeRecord]) -> float | None: + """ + Simple realised volatility: std of log-returns of trade prices. 
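+
+        r_i = ln(p_i / p_{i-1}),    vol = sqrt( Σ (r_i − r̄)² / (n − 1) )
+
+    i.e. the sample standard deviation over the window (no annualisation).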
+ """ + if len(trades) < 3: + return None + + prices = [t.price for t in trades if t.price > 0] + if len(prices) < 3: + return None + + log_returns = [] + for i in range(1, len(prices)): + if prices[i - 1] > 0: + lr = math.log(prices[i] / prices[i - 1]) + log_returns.append(lr) + + if len(log_returns) < 2: + return None + + mean = sum(log_returns) / len(log_returns) + variance = sum((r - mean) ** 2 for r in log_returns) / (len(log_returns) - 1) + return math.sqrt(variance) + + +def _latencies_ms(trades: list[TradeRecord]) -> list[float]: + """Extract exchange-to-receive latencies in ms.""" + latencies = [] + for t in trades: + if t.ts_exchange is not None and t.ts_recv is not None: + lat = (t.ts_recv.timestamp() - t.ts_exchange.timestamp()) * 1000 + if 0 < lat < 60_000: # sanity: 0-60s + latencies.append(lat) + return latencies + + +def _percentile(sorted_data: list[float], p: int) -> float: + """Simple percentile from sorted list.""" + if not sorted_data: + return 0.0 + k = (len(sorted_data) - 1) * p / 100 + f = math.floor(k) + c = math.ceil(k) + if f == c: + return sorted_data[int(k)] + return sorted_data[f] * (c - k) + sorted_data[c] * (k - f) diff --git a/services/senpai-md-consumer/senpai/md_consumer/main.py b/services/senpai-md-consumer/senpai/md_consumer/main.py new file mode 100644 index 00000000..729d92f7 --- /dev/null +++ b/services/senpai-md-consumer/senpai/md_consumer/main.py @@ -0,0 +1,270 @@ +""" +SenpAI Market-Data Consumer — entry point. + +Orchestrates: +1. NATS subscription (md.events.>) +2. Event processing → state updates → feature computation +3. Feature/signal/alert publishing back to NATS +4. HTTP API for monitoring + +Usage: + python -m senpai.md_consumer +""" +from __future__ import annotations + +import asyncio +import logging +import signal +import time + +import structlog + +from senpai.md_consumer import api +from senpai.md_consumer import metrics as m +from senpai.md_consumer.config import settings +from senpai.md_consumer.features import ( + check_signal, + make_feature_snapshot, + compute_features, +) +from senpai.md_consumer.models import ( + AlertEvent, + EventType, + TradeEvent, + QuoteEvent, +) +from senpai.md_consumer.nats_consumer import NATSConsumer +from senpai.md_consumer.publisher import Publisher +from senpai.md_consumer.state import LatestState + +logger = structlog.get_logger() + + +# ── Logging setup ────────────────────────────────────────────────────── + + +def setup_logging() -> None: + log_level = getattr(logging, settings.log_level.upper(), logging.INFO) + structlog.configure( + processors=[ + structlog.contextvars.merge_contextvars, + structlog.processors.add_log_level, + structlog.processors.TimeStamper(fmt="iso"), + structlog.dev.ConsoleRenderer(), + ], + wrapper_class=structlog.make_filtering_bound_logger(log_level), + context_class=dict, + logger_factory=structlog.PrintLoggerFactory(), + ) + logging.basicConfig(level=log_level, format="%(message)s") + + +# ── Processing pipeline ─────────────────────────────────────────────── + + +async def process_events( + consumer: NATSConsumer, + state: LatestState, + publisher: Publisher, +) -> None: + """ + Main processing loop: + 1. Read event from queue + 2. Update state + 3. Compute features + 4. Publish features + check signals + 5. 
Check alerts
+    """
+    last_alert_check = time.monotonic()
+
+    while True:
+        try:
+            event = await consumer.queue.get()
+        except asyncio.CancelledError:
+            break
+
+        proc_start = time.monotonic()
+
+        try:
+            # Update state based on event type
+            if event.event_type == EventType.TRADE:
+                assert isinstance(event, TradeEvent)
+                state.update_trade(event)
+                symbol = event.symbol
+
+            elif event.event_type == EventType.QUOTE:
+                assert isinstance(event, QuoteEvent)
+                state.update_quote(event)
+                symbol = event.symbol
+
+            elif event.event_type == EventType.HEARTBEAT:
+                # Heartbeats don't update state, just track
+                symbol = None
+
+            elif event.event_type == EventType.BOOK_L2:
+                # TODO: book updates
+                symbol = None
+
+            else:
+                symbol = None
+
+            # Compute features + publish (only for trade/quote events)
+            if symbol and settings.features_enabled:
+                snapshot = make_feature_snapshot(state, symbol)
+                await publisher.publish_features(snapshot)
+
+                # Check for trade signal
+                sig = check_signal(snapshot.features, symbol)
+                if sig:
+                    await publisher.publish_signal(sig)
+
+            # Processing latency metric
+            proc_ms = (time.monotonic() - proc_start) * 1000
+            m.PROCESSING_LATENCY.observe(proc_ms)
+
+        except Exception as e:
+            logger.error(
+                "process.error",
+                error=str(e),
+                event_type=event.event_type.value if event else "?",
+            )
+
+        # Periodic alert checks (every 5 seconds)
+        now = time.monotonic()
+        if now - last_alert_check > 5.0:
+            last_alert_check = now
+            await _check_alerts(state, publisher, consumer)
+
+
+async def _check_alerts(
+    state: LatestState,
+    publisher: Publisher,
+    consumer: NATSConsumer,
+) -> None:
+    """Check alert conditions and emit if needed."""
+    # Backpressure alert
+    fill = consumer.queue_fill_ratio
+    if fill > 0.8:
+        await publisher.publish_alert(
+            AlertEvent(
+                alert_type="backpressure",
+                level="warning" if fill < 0.95 else "critical",
+                message=f"Queue fill at {fill:.0%}",
+                details={"fill_ratio": fill},
+            )
+        )
+
+    # Latency alert (per symbol)
+    for sym in state.symbols:
+        features = compute_features(state, sym)
+        p95 = features.get("latency_ms_p95")
+        if p95 is not None and p95 > settings.alert_latency_ms:
+            await publisher.publish_alert(
+                AlertEvent(
+                    alert_type="latency",
+                    level="warning",
+                    message=f"{sym} p95 latency {p95:.0f}ms > {settings.alert_latency_ms}ms",
+                    details={"symbol": sym, "p95_ms": p95},
+                )
+            )
+
+
+# ── Main ───────────────────────────────────────────────────────────────
+
+
+async def main() -> None:
+    setup_logging()
+    logger.info("service.starting", nats_url=settings.nats_url)
+
+    # State store
+    state = LatestState(window_seconds=settings.rolling_window_seconds)
+
+    # NATS consumer
+    consumer = NATSConsumer()
+    await consumer.connect()
+    await consumer.subscribe()
+
+    # Publisher (reuses same NATS connection)
+    publisher = Publisher(consumer._nc)
+
+    # Wire up API
+    api.set_state(state)
+
+    def _get_stats() -> dict:
+        return {
+            "queue_size": consumer.queue.qsize(),
+            "queue_fill_ratio": round(consumer.queue_fill_ratio, 3),
+            "queue_max": settings.queue_size,
+            "events_processed": state.event_count,
+            "drops": dict(consumer._drop_count),
+            "symbols_tracked": state.symbols,
+            "features_enabled": settings.features_enabled,
+            "nats_connected": bool(consumer._nc and consumer._nc.is_connected),
+        }
+
+    api.set_stats_fn(_get_stats)
+
+    # Start HTTP API
+    http_server = await api.start_api()
+
+    # Start processing loop
+    process_task = asyncio.create_task(
+        process_events(consumer, state, publisher)
+    )
+
+    
# Graceful shutdown + shutdown_event = asyncio.Event() + + def _signal_handler(): + logger.info("service.shutdown_signal") + shutdown_event.set() + + loop = asyncio.get_event_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + try: + loop.add_signal_handler(sig, _signal_handler) + except NotImplementedError: + pass + + logger.info( + "service.ready", + subject=settings.nats_subject, + queue_group=settings.nats_queue_group, + http_port=settings.http_port, + features_enabled=settings.features_enabled, + ) + + # Wait for shutdown + await shutdown_event.wait() + + # ── Cleanup ─────────────────────────────────────────────────────── + logger.info("service.shutting_down") + + process_task.cancel() + try: + await process_task + except asyncio.CancelledError: + pass + + await consumer.close() + + http_server.close() + await http_server.wait_closed() + + logger.info( + "service.stopped", + events_processed=state.event_count, + symbols=state.symbols, + ) + + +def cli(): + asyncio.run(main()) + + +if __name__ == "__main__": + cli() diff --git a/services/senpai-md-consumer/senpai/md_consumer/metrics.py b/services/senpai-md-consumer/senpai/md_consumer/metrics.py new file mode 100644 index 00000000..6675ffcf --- /dev/null +++ b/services/senpai-md-consumer/senpai/md_consumer/metrics.py @@ -0,0 +1,72 @@ +""" +Prometheus metrics for SenpAI market-data consumer. +""" +from prometheus_client import Counter, Gauge, Histogram + +# ── Inbound events ───────────────────────────────────────────────────── +EVENTS_IN = Counter( + "senpai_events_in_total", + "Total events received from NATS", + ["event_type", "provider"], +) + +EVENTS_DROPPED = Counter( + "senpai_events_dropped_total", + "Events dropped due to backpressure or errors", + ["reason", "event_type"], +) + +# ── Queue ────────────────────────────────────────────────────────────── +QUEUE_FILL = Gauge( + "senpai_queue_fill_ratio", + "Internal processing queue fill ratio (0..1)", +) + +QUEUE_SIZE = Gauge( + "senpai_queue_size", + "Current number of items in processing queue", +) + +# ── Processing ───────────────────────────────────────────────────────── +PROCESSING_LATENCY = Histogram( + "senpai_processing_latency_ms", + "End-to-end processing latency (NATS receive to feature publish) in ms", + buckets=[0.1, 0.5, 1, 2, 5, 10, 25, 50, 100, 250], +) + +# ── Feature publishing ───────────────────────────────────────────────── +FEATURE_PUBLISH = Counter( + "senpai_feature_publish_total", + "Total feature snapshots published to NATS", + ["symbol"], +) + +FEATURE_PUBLISH_ERRORS = Counter( + "senpai_feature_publish_errors_total", + "Failed feature publishes", + ["symbol"], +) + +# ── Signals ──────────────────────────────────────────────────────────── +SIGNALS_EMITTED = Counter( + "senpai_signals_emitted_total", + "Trade signals emitted", + ["symbol", "direction"], +) + +ALERTS_EMITTED = Counter( + "senpai_alerts_emitted_total", + "Alerts emitted", + ["alert_type"], +) + +# ── NATS connection ─────────────────────────────────────────────────── +NATS_CONNECTED = Gauge( + "senpai_nats_connected", + "Whether NATS connection is alive (1=yes, 0=no)", +) + +NATS_RECONNECTS = Counter( + "senpai_nats_reconnects_total", + "Number of NATS reconnections", +) diff --git a/services/senpai-md-consumer/senpai/md_consumer/models.py b/services/senpai-md-consumer/senpai/md_consumer/models.py new file mode 100644 index 00000000..e09343c2 --- /dev/null +++ b/services/senpai-md-consumer/senpai/md_consumer/models.py @@ -0,0 +1,139 @@ +""" +Domain models — mirrors market-data-service 
event contracts. + +Tolerant parsing: unknown fields ignored, partial data accepted. +""" +from __future__ import annotations + +import time +from datetime import datetime, timezone +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + + +class EventType(str, Enum): + TRADE = "trade" + QUOTE = "quote" + BOOK_L2 = "book_l2" + HEARTBEAT = "heartbeat" + + +def _utc_now() -> datetime: + return datetime.now(timezone.utc) + + +def _mono_ns() -> int: + return time.monotonic_ns() + + +class BaseEvent(BaseModel, extra="ignore"): + """Common fields — extra fields silently ignored.""" + + event_type: EventType + provider: str + ts_recv: datetime = Field(default_factory=_utc_now) + ts_recv_mono_ns: int = Field(default_factory=_mono_ns) + + +class TradeEvent(BaseEvent): + event_type: EventType = EventType.TRADE + symbol: str + price: float + size: float + ts_exchange: Optional[datetime] = None + side: Optional[str] = None + trade_id: Optional[str] = None + + +class QuoteEvent(BaseEvent): + event_type: EventType = EventType.QUOTE + symbol: str + bid: float + ask: float + bid_size: float + ask_size: float + ts_exchange: Optional[datetime] = None + + +class BookLevel(BaseModel, extra="ignore"): + price: float + size: float + + +class BookL2Event(BaseEvent): + event_type: EventType = EventType.BOOK_L2 + symbol: str + bids: list[BookLevel] = Field(default_factory=list) + asks: list[BookLevel] = Field(default_factory=list) + ts_exchange: Optional[datetime] = None + + +class HeartbeatEvent(BaseEvent): + event_type: EventType = EventType.HEARTBEAT + + +# Union for parsing +Event = TradeEvent | QuoteEvent | BookL2Event | HeartbeatEvent + + +# ── Output models ────────────────────────────────────────────────────── + + +class FeatureSnapshot(BaseModel): + """Published to senpai.features.{symbol}.""" + + symbol: str + ts: datetime = Field(default_factory=_utc_now) + features: dict[str, float | None] + + +class TradeSignal(BaseModel): + """Published to senpai.signals.{symbol}.""" + + symbol: str + ts: datetime = Field(default_factory=_utc_now) + direction: str # "long" | "short" + confidence: float = 0.0 # 0..1 + reason: str = "" + features: dict[str, float | None] = Field(default_factory=dict) + + +class AlertEvent(BaseModel): + """Published to senpai.alerts.""" + + ts: datetime = Field(default_factory=_utc_now) + level: str = "warning" # "warning" | "critical" + alert_type: str # "latency" | "gap" | "backpressure" + message: str + details: dict = Field(default_factory=dict) + + +# ── Parsing helper ───────────────────────────────────────────────────── + +_EVENT_MAP: dict[str, type[BaseEvent]] = { + "trade": TradeEvent, + "quote": QuoteEvent, + "book_l2": BookL2Event, + "heartbeat": HeartbeatEvent, +} + + +def parse_event(data: dict) -> Event | None: + """ + Parse a dict (from JSON) into the appropriate Event model. + Returns None if event_type is unknown or data is invalid. + """ + event_type = data.get("event_type") + if not event_type: + return None + + cls = _EVENT_MAP.get(event_type) + if cls is None: + return None + + try: + return cls.model_validate(data) + except Exception: + return None diff --git a/services/senpai-md-consumer/senpai/md_consumer/nats_consumer.py b/services/senpai-md-consumer/senpai/md_consumer/nats_consumer.py new file mode 100644 index 00000000..3fa68899 --- /dev/null +++ b/services/senpai-md-consumer/senpai/md_consumer/nats_consumer.py @@ -0,0 +1,229 @@ +""" +NATS consumer — subscribes to md.events.> and feeds the processing pipeline. 
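+
+The `>` wildcard matches one or more trailing subject tokens, so a single
+subscription covers md.events.trade.BTCUSDT, md.events.quote.ETHUSDT, etc.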
+ +Features: +- Queue group subscription (horizontal scaling) +- Bounded asyncio.Queue with backpressure drop policy +- Auto-reconnect via nats-py +- Optional JetStream durable consumer +""" +from __future__ import annotations + +import asyncio +import json +import logging + +import nats +from nats.aio.client import Client as NatsClient +from nats.aio.msg import Msg + +from senpai.md_consumer.config import settings +from senpai.md_consumer.models import EventType, parse_event, Event +from senpai.md_consumer import metrics as m + +logger = logging.getLogger(__name__) + +# Events that can be dropped under backpressure (lowest priority first) +_DROPPABLE = {EventType.HEARTBEAT, EventType.QUOTE, EventType.BOOK_L2} + + +class NATSConsumer: + """ + Reads normalised events from NATS, validates, and puts into + a bounded asyncio.Queue for downstream processing. + + Backpressure policy: + - Queue < 90% → accept all events + - Queue >= 90% → drop heartbeats, quotes, book snapshots + - Trades are NEVER dropped (critical for analytics) + """ + + def __init__(self) -> None: + self._nc: NatsClient | None = None + self._sub = None + self._js_sub = None + self._queue: asyncio.Queue[Event] = asyncio.Queue( + maxsize=settings.queue_size + ) + self._running = False + self._drop_count: dict[str, int] = {} + + @property + def queue(self) -> asyncio.Queue[Event]: + return self._queue + + @property + def queue_fill_ratio(self) -> float: + if settings.queue_size <= 0: + return 0.0 + return self._queue.qsize() / settings.queue_size + + async def connect(self) -> None: + """Connect to NATS with auto-reconnect.""" + self._nc = await nats.connect( + self._url, + reconnect_time_wait=2, + max_reconnect_attempts=-1, + name="senpai-md-consumer", + error_cb=self._on_error, + disconnected_cb=self._on_disconnected, + reconnected_cb=self._on_reconnected, + closed_cb=self._on_closed, + ) + m.NATS_CONNECTED.set(1) + logger.info( + "nats.connected", + extra={"url": self._url, "subject": settings.nats_subject}, + ) + + @property + def _url(self) -> str: + return settings.nats_url + + async def subscribe(self) -> None: + """Subscribe to market data events.""" + if not self._nc: + raise RuntimeError("Not connected. Call connect() first.") + + if settings.use_jetstream: + await self._subscribe_jetstream() + else: + await self._subscribe_core() + + async def _subscribe_core(self) -> None: + """Core NATS subscription with queue group.""" + self._sub = await self._nc.subscribe( + settings.nats_subject, + queue=settings.nats_queue_group, + cb=self._on_message, + ) + logger.info( + "nats.subscribed_core", + extra={ + "subject": settings.nats_subject, + "queue_group": settings.nats_queue_group, + }, + ) + + async def _subscribe_jetstream(self) -> None: + """JetStream durable subscription.""" + js = self._nc.jetstream() + + # Try to create or bind to existing consumer + self._js_sub = await js.subscribe( + settings.nats_subject, + queue=settings.nats_queue_group, + durable="senpai-md-durable", + cb=self._on_message, + manual_ack=True, + ) + logger.info( + "nats.subscribed_jetstream", + extra={ + "subject": settings.nats_subject, + "durable": "senpai-md-durable", + }, + ) + + async def _on_message(self, msg: Msg) -> None: + """ + Callback for each NATS message. + Parse → backpressure check → enqueue. 
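+
+        Under JetStream, dropped and invalid messages are still acked so the
+        server does not redeliver them — drops here are deliberate.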
+        """
+        try:
+            data = json.loads(msg.data)
+        except (json.JSONDecodeError, UnicodeDecodeError) as e:
+            m.EVENTS_DROPPED.labels(reason="parse_error", event_type="unknown").inc()
+            logger.warning("nats.parse_error", extra={"error": str(e)})
+            if settings.use_jetstream:
+                await msg.ack()
+            return
+
+        event = parse_event(data)
+        if event is None:
+            m.EVENTS_DROPPED.labels(reason="invalid_event", event_type="unknown").inc()
+            if settings.use_jetstream:
+                await msg.ack()
+            return
+
+        # Track inbound
+        m.EVENTS_IN.labels(
+            event_type=event.event_type.value,
+            provider=event.provider,
+        ).inc()
+
+        # Backpressure check
+        fill = self.queue_fill_ratio
+        m.QUEUE_FILL.set(fill)
+        m.QUEUE_SIZE.set(self._queue.qsize())
+
+        if fill >= settings.queue_drop_threshold:
+            # Under pressure: only accept trades
+            if event.event_type in _DROPPABLE:
+                et = event.event_type.value
+                self._drop_count[et] = self._drop_count.get(et, 0) + 1
+                m.EVENTS_DROPPED.labels(
+                    reason="backpressure",
+                    event_type=et,
+                ).inc()
+
+                if self._drop_count[et] % 1000 == 1:
+                    logger.warning(
+                        "nats.backpressure_drop",
+                        extra={
+                            "type": et,
+                            "fill": f"{fill:.0%}",
+                            "total_drops": self._drop_count,
+                        },
+                    )
+
+                if settings.use_jetstream:
+                    await msg.ack()
+                return
+
+        # Enqueue
+        try:
+            self._queue.put_nowait(event)
+        except asyncio.QueueFull:
+            # Queue completely full — the incoming event (even a trade) is
+            # dropped here; counted so hard-full drops stay visible in metrics
+            m.EVENTS_DROPPED.labels(
+                reason="queue_full", event_type=event.event_type.value
+            ).inc()
+
+        if settings.use_jetstream:
+            await msg.ack()
+
+    async def close(self) -> None:
+        """Graceful shutdown."""
+        self._running = False
+        if self._sub:
+            try:
+                await self._sub.unsubscribe()
+            except Exception:
+                pass
+        if self._nc:
+            try:
+                await self._nc.flush(timeout=5)
+                await self._nc.close()
+            except Exception:
+                pass
+        m.NATS_CONNECTED.set(0)
+        logger.info("nats.closed", extra={"drops": self._drop_count})
+
+    # ── NATS callbacks ──────────────────────────────────────────────────
+
+    async def _on_error(self, e: Exception) -> None:
+        logger.error("nats.error", extra={"error": str(e)})
+
+    async def _on_disconnected(self) -> None:
+        m.NATS_CONNECTED.set(0)
+        logger.warning("nats.disconnected")
+
+    async def _on_reconnected(self) -> None:
+        m.NATS_CONNECTED.set(1)
+        m.NATS_RECONNECTS.inc()
+        logger.info("nats.reconnected")
+
+    async def _on_closed(self) -> None:
+        m.NATS_CONNECTED.set(0)
+        logger.info("nats.closed_callback")
diff --git a/services/senpai-md-consumer/senpai/md_consumer/publisher.py b/services/senpai-md-consumer/senpai/md_consumer/publisher.py
new file mode 100644
index 00000000..fc761d1d
--- /dev/null
+++ b/services/senpai-md-consumer/senpai/md_consumer/publisher.py
@@ -0,0 +1,119 @@
+"""
+Signal bus publisher — publishes features, signals, and alerts to NATS.
+
+Rate-limiting: max N publishes per second per symbol (configurable).
+"""
+from __future__ import annotations
+
+import logging
+import time
+
+from nats.aio.client import Client as NatsClient
+
+from senpai.md_consumer.config import settings
+from senpai.md_consumer.models import AlertEvent, FeatureSnapshot, TradeSignal
+from senpai.md_consumer import metrics as m
+
+logger = logging.getLogger(__name__)
+
+
+class Publisher:
+    """
+    Publishes FeatureSnapshots and TradeSignals to NATS.
+    Built-in per-symbol rate limiter.
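+
+    At the default FEATURES_PUB_RATE_HZ=10, at most one snapshot per symbol is
+    forwarded every 100 ms; rate-limited calls return False without touching NATS.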
+ """ + + def __init__(self, nc: NatsClient) -> None: + self._nc = nc + self._last_publish: dict[str, float] = {} # symbol → monotonic time + self._min_interval = ( + 1.0 / settings.features_pub_rate_hz + if settings.features_pub_rate_hz > 0 + else 0.1 + ) + + def _rate_ok(self, symbol: str) -> bool: + """Check if we can publish for this symbol (rate limiter).""" + now = time.monotonic() + last = self._last_publish.get(symbol, 0.0) + if now - last >= self._min_interval: + self._last_publish[symbol] = now + return True + return False + + async def publish_features(self, snapshot: FeatureSnapshot) -> bool: + """ + Publish feature snapshot if rate limit allows. + Returns True if published, False if rate-limited or error. + """ + if not settings.features_enabled: + return False + + symbol = snapshot.symbol.upper() + + if not self._rate_ok(symbol): + return False + + subject = f"{settings.features_pub_subject}.{symbol}" + try: + payload = snapshot.model_dump_json().encode("utf-8") + await self._nc.publish(subject, payload) + m.FEATURE_PUBLISH.labels(symbol=symbol).inc() + return True + except Exception as e: + m.FEATURE_PUBLISH_ERRORS.labels(symbol=symbol).inc() + logger.warning( + "publisher.feature_error", + extra={"symbol": symbol, "error": str(e)}, + ) + return False + + async def publish_signal(self, signal: TradeSignal) -> bool: + """Publish trade signal (no rate limit — signals are rare).""" + subject = f"{settings.signals_pub_subject}.{signal.symbol}" + try: + payload = signal.model_dump_json().encode("utf-8") + await self._nc.publish(subject, payload) + m.SIGNALS_EMITTED.labels( + symbol=signal.symbol, + direction=signal.direction, + ).inc() + logger.info( + "publisher.signal_emitted", + extra={ + "symbol": signal.symbol, + "direction": signal.direction, + "confidence": f"{signal.confidence:.2f}", + "reason": signal.reason, + }, + ) + return True + except Exception as e: + logger.error( + "publisher.signal_error", + extra={"symbol": signal.symbol, "error": str(e)}, + ) + return False + + async def publish_alert(self, alert: AlertEvent) -> bool: + """Publish alert event.""" + subject = settings.alerts_pub_subject + try: + payload = alert.model_dump_json().encode("utf-8") + await self._nc.publish(subject, payload) + m.ALERTS_EMITTED.labels(alert_type=alert.alert_type).inc() + logger.warning( + "publisher.alert", + extra={ + "type": alert.alert_type, + "level": alert.level, + "message": alert.message, + }, + ) + return True + except Exception as e: + logger.error( + "publisher.alert_error", + extra={"error": str(e)}, + ) + return False diff --git a/services/senpai-md-consumer/senpai/md_consumer/state.py b/services/senpai-md-consumer/senpai/md_consumer/state.py new file mode 100644 index 00000000..777e1545 --- /dev/null +++ b/services/senpai-md-consumer/senpai/md_consumer/state.py @@ -0,0 +1,238 @@ +""" +State management — LatestState + RollingWindow. + +All structures are asyncio-safe (no locks needed — single-threaded event loop). 
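+
+Window eviction keys off time.monotonic(), so wall-clock adjustments (NTP
+steps, DST) cannot incorrectly evict or retain records.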
+""" +from __future__ import annotations + +import time +from collections import deque +from dataclasses import dataclass +from datetime import datetime +from typing import Optional + +from senpai.md_consumer.models import QuoteEvent, TradeEvent + + +@dataclass +class LatestTrade: + symbol: str + price: float + size: float + side: Optional[str] + provider: str + ts_recv: datetime + ts_exchange: Optional[datetime] = None + + +@dataclass +class LatestQuote: + symbol: str + bid: float + ask: float + bid_size: float + ask_size: float + provider: str + ts_recv: datetime + ts_exchange: Optional[datetime] = None + + +@dataclass +class TradeRecord: + """Compact trade record for rolling window.""" + + price: float + size: float + ts: float # monotonic seconds + ts_exchange: Optional[datetime] = None + ts_recv: Optional[datetime] = None + + +@dataclass +class QuoteRecord: + """Compact quote record for rolling window.""" + + bid: float + ask: float + bid_size: float + ask_size: float + ts: float # monotonic seconds + ts_exchange: Optional[datetime] = None + ts_recv: Optional[datetime] = None + + +class RollingWindow: + """ + Fixed-duration rolling window using deque. + + Efficient: O(1) append, amortised O(1) eviction. + No pandas dependency. + """ + + def __init__(self, window_seconds: float = 60.0) -> None: + self._window = window_seconds + self._trades: deque[TradeRecord] = deque() + self._quotes: deque[QuoteRecord] = deque() + + def add_trade(self, trade: TradeRecord) -> None: + self._trades.append(trade) + self._evict_trades() + + def add_quote(self, quote: QuoteRecord) -> None: + self._quotes.append(quote) + self._evict_quotes() + + def _evict_trades(self) -> None: + cutoff = time.monotonic() - self._window + while self._trades and self._trades[0].ts < cutoff: + self._trades.popleft() + + def _evict_quotes(self) -> None: + cutoff = time.monotonic() - self._window + while self._quotes and self._quotes[0].ts < cutoff: + self._quotes.popleft() + + @property + def trades(self) -> deque[TradeRecord]: + self._evict_trades() + return self._trades + + @property + def quotes(self) -> deque[QuoteRecord]: + self._evict_quotes() + return self._quotes + + def trades_since(self, seconds_ago: float) -> list[TradeRecord]: + """Return trades within the last N seconds.""" + cutoff = time.monotonic() - seconds_ago + return [t for t in self._trades if t.ts >= cutoff] + + def quotes_since(self, seconds_ago: float) -> list[QuoteRecord]: + """Return quotes within the last N seconds.""" + cutoff = time.monotonic() - seconds_ago + return [q for q in self._quotes if q.ts >= cutoff] + + +class LatestState: + """ + Maintains latest trade/quote per symbol + rolling windows. 
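+
+    Memory grows with symbols × events-in-window: the deques are time-bounded,
+    not count-bounded, so a bursty symbol holds more records.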
+ """ + + def __init__(self, window_seconds: float = 60.0) -> None: + self._window_seconds = window_seconds + self._latest_trade: dict[str, LatestTrade] = {} + self._latest_quote: dict[str, LatestQuote] = {} + self._windows: dict[str, RollingWindow] = {} + self._event_count = 0 + + def _get_window(self, symbol: str) -> RollingWindow: + if symbol not in self._windows: + self._windows[symbol] = RollingWindow(self._window_seconds) + return self._windows[symbol] + + def update_trade(self, event: TradeEvent) -> None: + """Update latest trade and rolling window.""" + sym = event.symbol.upper() + + self._latest_trade[sym] = LatestTrade( + symbol=sym, + price=event.price, + size=event.size, + side=event.side, + provider=event.provider, + ts_recv=event.ts_recv, + ts_exchange=event.ts_exchange, + ) + + self._get_window(sym).add_trade( + TradeRecord( + price=event.price, + size=event.size, + ts=time.monotonic(), + ts_exchange=event.ts_exchange, + ts_recv=event.ts_recv, + ) + ) + self._event_count += 1 + + def update_quote(self, event: QuoteEvent) -> None: + """Update latest quote and rolling window.""" + sym = event.symbol.upper() + + self._latest_quote[sym] = LatestQuote( + symbol=sym, + bid=event.bid, + ask=event.ask, + bid_size=event.bid_size, + ask_size=event.ask_size, + provider=event.provider, + ts_recv=event.ts_recv, + ts_exchange=event.ts_exchange, + ) + + self._get_window(sym).add_quote( + QuoteRecord( + bid=event.bid, + ask=event.ask, + bid_size=event.bid_size, + ask_size=event.ask_size, + ts=time.monotonic(), + ts_exchange=event.ts_exchange, + ts_recv=event.ts_recv, + ) + ) + self._event_count += 1 + + def get_latest_trade(self, symbol: str) -> LatestTrade | None: + return self._latest_trade.get(symbol.upper()) + + def get_latest_quote(self, symbol: str) -> LatestQuote | None: + return self._latest_quote.get(symbol.upper()) + + def get_window(self, symbol: str) -> RollingWindow | None: + return self._windows.get(symbol.upper()) + + @property + def symbols(self) -> list[str]: + return sorted( + set(list(self._latest_trade.keys()) + list(self._latest_quote.keys())) + ) + + @property + def event_count(self) -> int: + return self._event_count + + def to_dict(self, symbol: str) -> dict: + """Serialise latest state for API.""" + sym = symbol.upper() + result: dict = {"symbol": sym} + + trade = self._latest_trade.get(sym) + if trade: + result["latest_trade"] = { + "price": trade.price, + "size": trade.size, + "side": trade.side, + "provider": trade.provider, + "ts_recv": trade.ts_recv.isoformat() if trade.ts_recv else None, + } + + quote = self._latest_quote.get(sym) + if quote: + result["latest_quote"] = { + "bid": quote.bid, + "ask": quote.ask, + "bid_size": quote.bid_size, + "ask_size": quote.ask_size, + "provider": quote.provider, + "ts_recv": quote.ts_recv.isoformat() if quote.ts_recv else None, + } + + window = self._windows.get(sym) + if window: + result["window"] = { + "trades_count": len(window.trades), + "quotes_count": len(window.quotes), + } + + return result diff --git a/services/senpai-md-consumer/tests/__init__.py b/services/senpai-md-consumer/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/services/senpai-md-consumer/tests/test_features_math.py b/services/senpai-md-consumer/tests/test_features_math.py new file mode 100644 index 00000000..57eb9340 --- /dev/null +++ b/services/senpai-md-consumer/tests/test_features_math.py @@ -0,0 +1,212 @@ +""" +Test feature computations — deterministic scenarios. 
+""" + +import pytest + +from senpai.md_consumer.features import ( + _percentile, + _realized_vol, + _vwap, + check_signal, + compute_features, +) +from senpai.md_consumer.models import QuoteEvent, TradeEvent +from senpai.md_consumer.state import LatestState, TradeRecord + + +# ── VWAP ─────────────────────────────────────────────────────────────── + + +def test_vwap_basic(): + trades = [ + TradeRecord(price=100.0, size=10.0, ts=0), + TradeRecord(price=200.0, size=10.0, ts=0), + ] + # VWAP = (100*10 + 200*10) / (10+10) = 150 + assert _vwap(trades) == 150.0 + + +def test_vwap_weighted(): + trades = [ + TradeRecord(price=100.0, size=90.0, ts=0), + TradeRecord(price=200.0, size=10.0, ts=0), + ] + # VWAP = (100*90 + 200*10) / 100 = 110 + assert _vwap(trades) == 110.0 + + +def test_vwap_empty(): + assert _vwap([]) is None + + +def test_vwap_zero_volume(): + trades = [TradeRecord(price=100.0, size=0.0, ts=0)] + assert _vwap(trades) is None + + +# ── Realized volatility ─────────────────────────────────────────────── + + +def test_realized_vol_constant_price(): + """Constant price → 0 volatility.""" + trades = [TradeRecord(price=100.0, size=1.0, ts=0) for _ in range(10)] + vol = _realized_vol(trades) + assert vol is not None + assert vol == 0.0 + + +def test_realized_vol_two_prices(): + """Not enough data points → None.""" + trades = [ + TradeRecord(price=100.0, size=1.0, ts=0), + TradeRecord(price=101.0, size=1.0, ts=0), + ] + assert _realized_vol(trades) is None # needs at least 3 + + +def test_realized_vol_positive(): + """Variable prices should give positive volatility.""" + trades = [ + TradeRecord(price=100.0, size=1.0, ts=0), + TradeRecord(price=102.0, size=1.0, ts=0), + TradeRecord(price=99.0, size=1.0, ts=0), + TradeRecord(price=103.0, size=1.0, ts=0), + ] + vol = _realized_vol(trades) + assert vol is not None + assert vol > 0 + + +# ── Percentile ───────────────────────────────────────────────────────── + + +def test_percentile_basic(): + data = [1.0, 2.0, 3.0, 4.0, 5.0] + assert _percentile(data, 50) == 3.0 + assert _percentile(data, 0) == 1.0 + assert _percentile(data, 100) == 5.0 + + +def test_percentile_p95(): + data = list(range(1, 101)) # 1..100 + data_float = [float(x) for x in data] + p95 = _percentile(data_float, 95) + assert 95 <= p95 <= 96 + + +# ── Full feature computation ────────────────────────────────────────── + + +def test_compute_features_with_state(): + state = LatestState(window_seconds=60.0) + + # Add quote + state.update_quote(QuoteEvent( + provider="binance", + symbol="BTCUSDT", + bid=70000.0, + ask=70002.0, + bid_size=5.0, + ask_size=3.0, + )) + + # Add some trades + for i in range(5): + state.update_trade(TradeEvent( + provider="binance", + symbol="BTCUSDT", + price=70000.0 + i * 10, + size=1.0, + )) + + features = compute_features(state, "BTCUSDT") + + # Mid + assert features["mid"] == pytest.approx(70001.0) + + # Spread + assert features["spread_abs"] == pytest.approx(2.0) + assert features["spread_bps"] is not None + assert features["spread_bps"] > 0 + + # Trade count + assert features["trade_count_10s"] == 5.0 + + # Volume + assert features["trade_volume_10s"] == 5.0 + + # VWAP should be defined + assert features["trade_vwap_10s"] is not None + assert features["trade_vwap_60s"] is not None + + +def test_compute_features_no_data(): + state = LatestState(window_seconds=60.0) + features = compute_features(state, "BTCUSDT") + + # All should be None + assert features["mid"] is None + assert features["spread_abs"] is None + assert features["trade_vwap_10s"] is 
None + + +# ── Signal detection ────────────────────────────────────────────────── + + +def test_check_signal_long(): + """Strong positive return + volume + tight spread → long signal.""" + features = { + "return_10s": 0.005, # 0.5% (> 0.3% threshold) + "trade_volume_10s": 5.0, # > 1.0 threshold + "spread_bps": 3.0, # < 20 bps threshold + } + signal = check_signal(features, "BTCUSDT") + assert signal is not None + assert signal.direction == "long" + assert signal.confidence > 0 + + +def test_check_signal_short(): + """Strong negative return → short signal.""" + features = { + "return_10s": -0.005, + "trade_volume_10s": 5.0, + "spread_bps": 3.0, + } + signal = check_signal(features, "BTCUSDT") + assert signal is not None + assert signal.direction == "short" + + +def test_check_signal_no_trigger(): + """Small return → no signal.""" + features = { + "return_10s": 0.0001, + "trade_volume_10s": 5.0, + "spread_bps": 3.0, + } + signal = check_signal(features, "BTCUSDT") + assert signal is None + + +def test_check_signal_wide_spread(): + """Wide spread → no signal (even with strong return).""" + features = { + "return_10s": 0.01, + "trade_volume_10s": 5.0, + "spread_bps": 50.0, # > 20 bps + } + signal = check_signal(features, "BTCUSDT") + assert signal is None + + +def test_check_signal_low_volume(): + """Low volume → no signal.""" + features = { + "return_10s": 0.01, + "trade_volume_10s": 0.1, # < 1.0 + "spread_bps": 3.0, + } + signal = check_signal(features, "BTCUSDT") + assert signal is None diff --git a/services/senpai-md-consumer/tests/test_models_parse.py b/services/senpai-md-consumer/tests/test_models_parse.py new file mode 100644 index 00000000..49c91605 --- /dev/null +++ b/services/senpai-md-consumer/tests/test_models_parse.py @@ -0,0 +1,154 @@ +""" +Test event parsing from JSON payloads (mirrors market-data-service contracts). 
+""" +import json + + +from senpai.md_consumer.models import ( + EventType, + TradeEvent, + QuoteEvent, + HeartbeatEvent, + parse_event, +) + + +# ── Trade events ─────────────────────────────────────────────────────── + + +def test_parse_trade_basic(): + data = { + "event_type": "trade", + "provider": "binance", + "symbol": "BTCUSDT", + "price": 70500.0, + "size": 1.5, + "ts_recv": "2026-02-09T12:00:00+00:00", + } + event = parse_event(data) + assert event is not None + assert isinstance(event, TradeEvent) + assert event.event_type == EventType.TRADE + assert event.symbol == "BTCUSDT" + assert event.price == 70500.0 + assert event.size == 1.5 + assert event.provider == "binance" + + +def test_parse_trade_with_extra_fields(): + """Unknown fields should be silently ignored (tolerant parsing).""" + data = { + "event_type": "trade", + "provider": "bybit", + "symbol": "ETHUSDT", + "price": 2100.0, + "size": 10.0, + "ts_recv": "2026-02-09T12:00:00+00:00", + "unknown_field": "should_be_ignored", + "another_extra": 42, + } + event = parse_event(data) + assert event is not None + assert event.symbol == "ETHUSDT" + + +def test_parse_trade_with_side_and_exchange_ts(): + data = { + "event_type": "trade", + "provider": "binance", + "symbol": "BTCUSDT", + "price": 70000.0, + "size": 0.5, + "side": "buy", + "ts_exchange": "2026-02-09T12:00:00+00:00", + "ts_recv": "2026-02-09T12:00:00.100+00:00", + "trade_id": "t12345", + } + event = parse_event(data) + assert event.side == "buy" + assert event.trade_id == "t12345" + assert event.ts_exchange is not None + + +# ── Quote events ─────────────────────────────────────────────────────── + + +def test_parse_quote_basic(): + data = { + "event_type": "quote", + "provider": "binance", + "symbol": "BTCUSDT", + "bid": 70000.0, + "ask": 70001.0, + "bid_size": 5.0, + "ask_size": 3.0, + "ts_recv": "2026-02-09T12:00:00+00:00", + } + event = parse_event(data) + assert isinstance(event, QuoteEvent) + assert event.bid == 70000.0 + assert event.ask == 70001.0 + + +def test_parse_quote_zero_values(): + data = { + "event_type": "quote", + "provider": "binance", + "symbol": "BTCUSDT", + "bid": 0.0, + "ask": 0.0, + "bid_size": 0.0, + "ask_size": 0.0, + } + event = parse_event(data) + assert event is not None + assert event.bid == 0.0 + + +# ── Heartbeat events ────────────────────────────────────────────────── + + +def test_parse_heartbeat(): + data = { + "event_type": "heartbeat", + "provider": "alpaca", + "ts_recv": "2026-02-09T12:00:00+00:00", + } + event = parse_event(data) + assert isinstance(event, HeartbeatEvent) + assert event.provider == "alpaca" + + +# ── Edge cases ───────────────────────────────────────────────────────── + + +def test_parse_unknown_type(): + data = {"event_type": "unknown_type", "provider": "test"} + event = parse_event(data) + assert event is None + + +def test_parse_missing_type(): + data = {"provider": "test", "symbol": "BTC"} + event = parse_event(data) + assert event is None + + +def test_parse_invalid_data(): + data = {"event_type": "trade"} # missing required fields + event = parse_event(data) + assert event is None + + +def test_parse_empty_dict(): + event = parse_event({}) + assert event is None + + +def test_parse_from_json_bytes(): + """Simulate actual NATS message deserialization.""" + raw = b'{"event_type":"trade","provider":"binance","symbol":"BTCUSDT","price":70500.0,"size":1.5}' + data = json.loads(raw) + event = parse_event(data) + assert event is not None + assert event.price == 70500.0 diff --git 
diff --git a/services/senpai-md-consumer/tests/test_rate_limit.py b/services/senpai-md-consumer/tests/test_rate_limit.py
new file mode 100644
index 00000000..bf0ca321
--- /dev/null
+++ b/services/senpai-md-consumer/tests/test_rate_limit.py
@@ -0,0 +1,111 @@
+"""
+Test publisher rate limiting.
+"""
+import asyncio
+from unittest.mock import AsyncMock
+
+import pytest
+
+from senpai.md_consumer.models import FeatureSnapshot, TradeSignal
+from senpai.md_consumer.publisher import Publisher
+
+
+@pytest.fixture
+def mock_nc():
+    """Mock NATS client."""
+    nc = AsyncMock()
+    nc.publish = AsyncMock()
+    return nc
+
+
+@pytest.fixture
+def publisher(mock_nc):
+    return Publisher(mock_nc)
+
+
+@pytest.mark.asyncio
+async def test_publish_features_respects_rate_limit(mock_nc, publisher):
+    """A second publish for the same symbol within the rate window is skipped."""
+    snapshot = FeatureSnapshot(
+        symbol="BTCUSDT",
+        features={"mid": 70000.0},
+    )
+
+    # First publish should succeed
+    result1 = await publisher.publish_features(snapshot)
+    assert result1 is True
+
+    # An immediate second publish should be rate-limited
+    result2 = await publisher.publish_features(snapshot)
+    assert result2 is False
+
+    # Only one actual NATS publish
+    assert mock_nc.publish.call_count == 1
+
+
+@pytest.mark.asyncio
+async def test_publish_features_different_symbols(mock_nc, publisher):
+    """Different symbols have independent rate limiters."""
+    snap1 = FeatureSnapshot(symbol="BTCUSDT", features={"mid": 70000.0})
+    snap2 = FeatureSnapshot(symbol="ETHUSDT", features={"mid": 2000.0})
+
+    r1 = await publisher.publish_features(snap1)
+    r2 = await publisher.publish_features(snap2)
+
+    assert r1 is True
+    assert r2 is True
+    assert mock_nc.publish.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_publish_signal_no_rate_limit(mock_nc, publisher):
+    """Signals are NOT rate-limited."""
+    signal = TradeSignal(
+        symbol="BTCUSDT",
+        direction="long",
+        confidence=0.8,
+        reason="test",
+    )
+
+    r1 = await publisher.publish_signal(signal)
+    r2 = await publisher.publish_signal(signal)
+
+    assert r1 is True
+    assert r2 is True
+    assert mock_nc.publish.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_publish_features_after_rate_window(mock_nc, publisher):
+    """After the rate window passes, publishing should succeed again."""
+    # Shrink the (private) min interval so the test doesn't have to wait long
+    publisher._min_interval = 0.01  # 10 ms
+
+    snapshot = FeatureSnapshot(
+        symbol="BTCUSDT",
+        features={"mid": 70000.0},
+    )
+
+    r1 = await publisher.publish_features(snapshot)
+    assert r1 is True
+
+    # Wait for the rate window to pass
+    await asyncio.sleep(0.02)
+
+    r2 = await publisher.publish_features(snapshot)
+    assert r2 is True
+    assert mock_nc.publish.call_count == 2
+
+
+@pytest.mark.asyncio
+async def test_publish_handles_nats_error(mock_nc, publisher):
+    """A NATS publish error should not raise; it just returns False."""
+    mock_nc.publish.side_effect = Exception("NATS down")
+
+    snapshot = FeatureSnapshot(
+        symbol="BTCUSDT",
+        features={"mid": 70000.0},
+    )
+
+    result = await publisher.publish_features(snapshot)
+    assert result is False
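+
+
+# The tests above pin down the contract: one feature publish per symbol per
+# minimum interval, signals unthrottled, and transport errors mapped to a
+# False return. The interval check itself reduces to something like this
+# sketch (hypothetical helper; Publisher keeps the real per-symbol
+# bookkeeping):
+def _publish_allowed(last_ts, now, min_interval):
+    """True if a publish at `now` is allowed given the last publish time."""
+    # The first publish for a symbol is always allowed; after that, at most
+    # one publish per min_interval seconds.
+    return last_ts is None or (now - last_ts) >= min_interval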
+""" +import time + + +from senpai.md_consumer.state import ( + LatestState, + RollingWindow, + TradeRecord, +) +from senpai.md_consumer.models import TradeEvent, QuoteEvent + + +# ── RollingWindow ────────────────────────────────────────────────────── + + +def test_rolling_window_add_trade(): + w = RollingWindow(window_seconds=60.0) + t = TradeRecord(price=100.0, size=1.0, ts=time.monotonic()) + w.add_trade(t) + assert len(w.trades) == 1 + assert w.trades[0].price == 100.0 + + +def test_rolling_window_eviction(): + """Old records should be evicted.""" + w = RollingWindow(window_seconds=1.0) # 1 second window + + old_ts = time.monotonic() - 2.0 # 2 seconds ago + w.add_trade(TradeRecord(price=100.0, size=1.0, ts=old_ts)) + w.add_trade(TradeRecord(price=200.0, size=2.0, ts=time.monotonic())) + + # Old record should be evicted + trades = list(w.trades) + assert len(trades) == 1 + assert trades[0].price == 200.0 + + +def test_rolling_window_trades_since(): + w = RollingWindow(window_seconds=60.0) + now = time.monotonic() + + # Add trades at different times + w.add_trade(TradeRecord(price=100.0, size=1.0, ts=now - 30)) # 30s ago + w.add_trade(TradeRecord(price=200.0, size=2.0, ts=now - 5)) # 5s ago + w.add_trade(TradeRecord(price=300.0, size=3.0, ts=now)) # now + + last_10s = w.trades_since(10.0) + assert len(last_10s) == 2 # 5s ago + now + assert last_10s[0].price == 200.0 + + +def test_rolling_window_empty(): + w = RollingWindow(window_seconds=60.0) + assert len(w.trades) == 0 + assert len(w.quotes) == 0 + assert w.trades_since(10.0) == [] + + +# ── LatestState ──────────────────────────────────────────────────────── + + +def test_latest_state_update_trade(): + state = LatestState(window_seconds=60.0) + + event = TradeEvent( + provider="binance", + symbol="BTCUSDT", + price=70500.0, + size=1.5, + side="buy", + ) + state.update_trade(event) + + latest = state.get_latest_trade("BTCUSDT") + assert latest is not None + assert latest.price == 70500.0 + assert latest.side == "buy" + assert state.event_count == 1 + + +def test_latest_state_update_quote(): + state = LatestState(window_seconds=60.0) + + event = QuoteEvent( + provider="binance", + symbol="BTCUSDT", + bid=70000.0, + ask=70001.0, + bid_size=5.0, + ask_size=3.0, + ) + state.update_quote(event) + + latest = state.get_latest_quote("BTCUSDT") + assert latest is not None + assert latest.bid == 70000.0 + assert latest.ask == 70001.0 + + +def test_latest_state_symbols(): + state = LatestState(window_seconds=60.0) + + state.update_trade(TradeEvent( + provider="binance", symbol="BTCUSDT", price=100.0, size=1.0 + )) + state.update_quote(QuoteEvent( + provider="binance", symbol="ETHUSDT", + bid=2000.0, ask=2001.0, bid_size=1.0, ask_size=1.0, + )) + + assert "BTCUSDT" in state.symbols + assert "ETHUSDT" in state.symbols + + +def test_latest_state_to_dict(): + state = LatestState(window_seconds=60.0) + + state.update_trade(TradeEvent( + provider="binance", symbol="BTCUSDT", price=70500.0, size=1.0 + )) + state.update_quote(QuoteEvent( + provider="binance", symbol="BTCUSDT", + bid=70000.0, ask=70001.0, bid_size=1.0, ask_size=1.0, + )) + + d = state.to_dict("BTCUSDT") + assert d["symbol"] == "BTCUSDT" + assert "latest_trade" in d + assert "latest_quote" in d + assert d["latest_trade"]["price"] == 70500.0 + + +def test_latest_state_missing_symbol(): + state = LatestState(window_seconds=60.0) + assert state.get_latest_trade("NOPE") is None + assert state.get_latest_quote("NOPE") is None