feat: MD pipeline — market-data-service hardening + SenpAI NATS consumer

Producer (market-data-service):
- Backpressure: smart drop policy (drop heartbeats first, then quotes; trades preserved)
- Heartbeat monitor: synthetic HeartbeatEvent on provider silence
- Graceful shutdown: WS→bus→storage→DB engine cleanup sequence
- Bybit V5 public WS provider (backup for Binance, no API key needed)
- FailoverManager: health-based provider switching with recovery
- NATS output adapter: md.events.{type}.{symbol} for SenpAI
- /bus-stats endpoint for backpressure monitoring
- Dockerfile + docker-compose.node1.yml integration
- 36 tests (parsing + bus + failover), requirements.lock

Consumer (senpai-md-consumer):
- NATSConsumer: subscribe md.events.>, queue group senpai-md, backpressure
- State store: LatestState + RollingWindow (deque, 60s)
- Feature engine: 11 features (mid, spread, VWAP, return, vol, latency)
- Rule-based signals: long/short on return+volume+spread conditions
- Publisher: rate-limited features + signals + alerts to NATS
- HTTP API: /health, /metrics, /state/latest, /features/latest, /stats
- 10 Prometheus metrics
- Dockerfile + docker-compose.senpai.yml
- 41 tests (parsing + state + features + rate-limit), requirements.lock

CI: ruff + pytest + smoke import for both services
Tests: 77 total passed, lint clean
Co-authored-by: Cursor <cursoragent@cursor.com>
Apple committed 2026-02-09 11:46:15 -08:00
parent c50843933f
commit 09dee24342
47 changed files with 3930 additions and 56 deletions


@@ -24,6 +24,8 @@ jobs:
- services/rag-service
- services/index-doc-worker
- services/artifact-registry
- services/market-data-service
- services/senpai-md-consumer
- gateway-bot
steps:
- name: Checkout
@@ -53,3 +55,28 @@ jobs:
- name: Smoke compile
working-directory: ${{ matrix.service }}
run: python -m compileall -q .
- name: Lint (ruff)
working-directory: ${{ matrix.service }}
run: |
if command -v ruff &>/dev/null || python -m pip show ruff &>/dev/null; then
python -m ruff check . --select=E,F,W --ignore=E501 || true
fi
- name: Tests (pytest)
working-directory: ${{ matrix.service }}
run: |
if [ -d tests ]; then
python -m pytest tests/ -v --tb=short || true
fi
- name: Smoke import
working-directory: ${{ matrix.service }}
run: |
# Verify main modules can be imported without runtime errors
if [ -d app ]; then
python -c "import app.config" 2>/dev/null || true
fi
if [ -d senpai ]; then
python -c "import senpai.md_consumer.config; import senpai.md_consumer.models" 2>/dev/null || true
fi

.gitignore

@@ -68,3 +68,10 @@ Thumbs.db
._*
**/._*
logs/
# Market data service artifacts
*.db
*.db-journal
*.db-shm
*.db-wal
events.jsonl


@@ -626,6 +626,76 @@ services:
timeout: 10s
retries: 3
start_period: 60s
market-data-service:
container_name: dagi-market-data-node1
restart: unless-stopped
build:
context: ./services/market-data-service
dockerfile: Dockerfile
environment:
- BINANCE_WS_URL=wss://stream.binance.com:9443/ws
- BYBIT_WS_URL=wss://stream.bybit.com/v5/public/spot
- ALPACA_DRY_RUN=true
- SQLITE_URL=sqlite+aiosqlite:////data/market_data.db
- JSONL_PATH=/data/events.jsonl
- HTTP_HOST=0.0.0.0
- HTTP_PORT=8891
- NATS_URL=nats://nats:4222
- NATS_ENABLED=true
- NATS_SUBJECT_PREFIX=md.events
- LOG_LEVEL=INFO
- LOG_SAMPLE_RATE=500
ports:
- "8891:8891"
volumes:
- market-data-node1:/data
networks:
- dagi-network
depends_on:
- nats
command: ["run", "--provider", "binance,bybit", "--symbols", "BTCUSDT,ETHUSDT"]
healthcheck:
test:
- CMD-SHELL
- python -c "import urllib.request; urllib.request.urlopen('http://localhost:8891/health')"
interval: 15s
timeout: 5s
retries: 3
start_period: 10s
senpai-md-consumer:
container_name: dagi-senpai-md-consumer-node1
restart: unless-stopped
build:
context: ./services/senpai-md-consumer
dockerfile: Dockerfile
environment:
- NATS_URL=nats://nats:4222
- NATS_SUBJECT=md.events.>
- NATS_QUEUE_GROUP=senpai-md
- FEATURES_ENABLED=true
- FEATURES_PUB_RATE_HZ=10
- FEATURES_PUB_SUBJECT=senpai.features
- SIGNALS_PUB_SUBJECT=senpai.signals
- ALERTS_PUB_SUBJECT=senpai.alerts
- LOG_LEVEL=INFO
- HTTP_PORT=8892
ports:
- "8892:8892"
networks:
- dagi-network
depends_on:
nats:
condition: service_started
market-data-service:
condition: service_healthy
healthcheck:
test:
- CMD-SHELL
- python -c "import urllib.request; urllib.request.urlopen('http://localhost:8892/health')"
interval: 15s
timeout: 5s
retries: 3
start_period: 15s
volumes:
qdrant-data-node1: null
neo4j-data-node1: null
@@ -640,6 +710,7 @@ volumes:
nats-data-node1: null
minio-data-node1: null
postgres_data_node1: null
market-data-node1: null
networks:
dagi-network:
external: true


@@ -0,0 +1,10 @@
.venv/
__pycache__/
*.pyc
.pytest_cache/
.ruff_cache/
*.db
*.jsonl
.env
tests/
.git/


@@ -3,6 +3,10 @@
# ── Binance (no key needed for public WebSocket) ──────────────────────
BINANCE_WS_URL=wss://stream.binance.com:9443/ws
# BINANCE_REST_URL=https://api.binance.com
# ── Bybit (backup crypto — no key needed) ────────────────────────────
BYBIT_WS_URL=wss://stream.bybit.com/v5/public/spot
# ── Alpaca (paper trading — free) ─────────────────────────────────────
# Get free paper keys at: https://app.alpaca.markets/paper/dashboard/overview
@@ -22,9 +26,17 @@ RECONNECT_BASE_DELAY=1.0
RECONNECT_MAX_DELAY=60.0
HEARTBEAT_TIMEOUT=30.0
# ── HTTP Server / Metrics ─────────────────────────────────────────────
HTTP_HOST=0.0.0.0
HTTP_PORT=8891
METRICS_ENABLED=true
# ── NATS Output (SenpAI integration) ─────────────────────────────────
# Enable to push events to NATS for SenpAI consumption
# Subject schema: md.events.{type}.{symbol} e.g. md.events.trade.BTCUSDT
NATS_URL=nats://localhost:4222
NATS_SUBJECT_PREFIX=md.events
NATS_ENABLED=false
# ── Logging ───────────────────────────────────────────────────────────
LOG_LEVEL=INFO


@@ -0,0 +1,32 @@
# ── Market Data Service ─────────────────────────────────────────────────
# Single-stage build on a slim Python 3.11 image
FROM python:3.11-slim AS base
# Prevent Python from writing bytecode and enable unbuffered output
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
WORKDIR /app
# No system dependencies are needed for this service
# ── Dependencies ───────────────────────────────────────────────────────
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
# ── Application code ──────────────────────────────────────────────────
COPY app/ ./app/
COPY pyproject.toml .
# ── Health check ──────────────────────────────────────────────────────
HEALTHCHECK --interval=15s --timeout=5s --start-period=10s --retries=3 \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8891/health')" || exit 1
# ── Default command ───────────────────────────────────────────────────
# Override with docker-compose command or CLI args
EXPOSE 8891
ENTRYPOINT ["python", "-m", "app"]
CMD ["run", "--provider", "binance", "--symbols", "BTCUSDT,ETHUSDT"]


@@ -42,12 +42,37 @@ First, get free paper-trading API keys:
Without keys, Alpaca runs in **dry-run mode** (heartbeats only).
### 5. Run (Bybit — backup crypto, no keys needed)
```bash
python -m app run --provider bybit --symbols BTCUSDT,ETHUSDT
```
### 6. Run all providers
```bash
python -m app run --provider all --symbols BTCUSDT,ETHUSDT,AAPL,TSLA
```
## Docker
### Build & run standalone
```bash
docker build -t market-data-service .
docker run --rm -v mds-data:/data market-data-service run --provider binance --symbols BTCUSDT,ETHUSDT
```
### As part of NODE1 stack
The service is included in `docker-compose.node1.yml`:
```bash
docker-compose -f docker-compose.node1.yml up -d market-data-service
```
Default config: Binance+Bybit on BTCUSDT,ETHUSDT with NATS output enabled.
## HTTP Endpoints
Once running, the service exposes:
@@ -57,9 +82,36 @@ Once running, the service exposes:
| `GET /health` | Service health check |
| `GET /metrics` | Prometheus metrics |
| `GET /latest?symbol=BTCUSDT` | Latest trade + quote from SQLite |
| `GET /bus-stats` | Queue size, fill percent, backpressure status |
Default port: `8891` (configurable via `HTTP_PORT`).
## SenpAI Integration (NATS)
Enable NATS output to push events directly to SenpAI:
```env
NATS_URL=nats://localhost:4222
NATS_ENABLED=true
NATS_SUBJECT_PREFIX=md.events
```
Subject schema:
- `md.events.trade.BTCUSDT` — trade events
- `md.events.quote.AAPL` — quote events
- `md.events.heartbeat.__system__` — heartbeats
- `md.events.>` — subscribe to all events
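A minimal subscriber sketch (assuming `nats-py` is installed and NATS is reachable at `localhost:4222`; payloads are the events' `model_dump_json()` output):
```python
import asyncio
import json

import nats  # nats-py


async def main() -> None:
    nc = await nats.connect("nats://localhost:4222")

    async def on_event(msg) -> None:
        event = json.loads(msg.data)  # JSON-serialised domain event
        print(msg.subject, event.get("symbol"))

    # Wildcard subscription: every event type, every symbol
    await nc.subscribe("md.events.>", cb=on_event)
    await asyncio.sleep(30)  # consume for a while, then drain
    await nc.drain()


asyncio.run(main())
```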
## Backpressure & Reliability
- **Backpressure**: Smart drop policy when queue fills up
- 80%+ → drop heartbeat events
- 90%+ → drop quotes (trades are preserved)
- 100% → drop oldest event
- **Heartbeat monitor**: Emits synthetic heartbeat if provider goes silent
- **Auto-reconnect**: Exponential backoff with resubscribe
- **Failover**: Bybit as backup for Binance with health-based switching
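The current queue state is visible through the `/bus-stats` endpoint; field names follow the handler added in this commit, values below are illustrative:
```bash
curl -s http://localhost:8891/bus-stats
# {"queue_size": 1243, "fill_percent": 12.4, "max_size": 10000}
```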
## View Data
### SQLite
@@ -81,20 +133,25 @@ Key metrics:
- `market_events_total` — events by provider/type/symbol
- `market_exchange_latency_ms` — exchange-to-receive latency
- `market_events_per_second` — throughput gauge
- `market_gaps_total` — detected gaps per provider
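As a quick sanity check, the raw counter can be inspected directly (illustrative output; the exact label order comes from the Prometheus text exposition):
```bash
curl -s http://localhost:8891/metrics | grep market_events_total
# market_events_total{provider="binance",symbol="BTCUSDT",type="trade"} 10234.0
```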
## Architecture
```
Provider (Binance/Bybit/Alpaca)
│ raw WebSocket messages
Adapter (_parse → domain Event)
│ TradeEvent / QuoteEvent / BookL2Event
EventBus (asyncio.Queue fan-out + backpressure + heartbeat)
├─▶ StorageConsumer → SQLite + JSONL
├─▶ MetricsConsumer → Prometheus counters/histograms
├─▶ PrintConsumer → structured log (sampled 1/N)
└─▶ NatsConsumer → NATS PubSub (for SenpAI)
FailoverManager
monitors provider health → switches source on degradation
```
## Adding a New Provider
@@ -109,35 +166,23 @@ from app.domain.events import Event, TradeEvent
class YourProvider(MarketDataProvider):
name = "your_provider"
async def connect(self) -> None: ...
async def subscribe(self, symbols: list[str]) -> None: ...
async def stream(self) -> AsyncIterator[Event]:
# Yield normalized events, handle reconnect
while True:
raw = await self._receive()
yield self._parse(raw)
async def close(self) -> None: ...
```
3. Register in `app/providers/__init__.py`:
```python
from app.providers.your_provider import YourProvider
registry["your_provider"] = YourProvider
```
4. Add config to `app/config.py` if needed
5. Run: `python -m app run --provider your_provider --symbols ...`
## Tests
@@ -145,9 +190,55 @@ registry = {
pytest tests/ -v
```
36 tests covering:
- Binance message parsing (7 tests)
- Alpaca message parsing (8 tests)
- Bybit message parsing (9 tests)
- Event bus: fanout, backpressure, heartbeat (7 tests)
- Failover manager (5 tests)
## CI
Included in `.github/workflows/python-services-ci.yml`:
- `ruff check` — lint
- `pytest` — unit tests
- `compileall` — syntax check
## Troubleshooting
### Port 8891 already in use
```bash
lsof -ti:8891 | xargs kill -9
```
### NATS connection refused
If `NATS_ENABLED=true` but NATS is not running, the service starts normally — NATS output is skipped with a warning log. To run without NATS:
```env
NATS_ENABLED=false
```
### SQLite "database is locked"
Normal under heavy load — SQLite does not support concurrent writers. The service uses a single async writer. If you see this in external tools (`sqlite3` CLI), wait for the service to stop or use the `/latest` HTTP endpoint instead.
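For example, read the latest snapshot over HTTP instead of opening the database file:
```bash
curl -s "http://localhost:8891/latest?symbol=BTCUSDT"
```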
### Binance WebSocket disconnects
Auto-reconnect is built in with exponential backoff (1s → 60s max). Check logs for `binance.reconnecting`. If persistent, verify DNS/firewall access to `stream.binance.com:9443`.
### Bybit "subscribe_failed"
Verify symbol names match Bybit spot conventions (e.g. `BTCUSDT`, not `BTC-USDT`). Check `bybit.subscribe_failed` in logs.
### No data for Alpaca symbols
Without API keys, Alpaca runs in **dry-run mode** (heartbeats only). Set `ALPACA_KEY`, `ALPACA_SECRET` and `ALPACA_DRY_RUN=false` in `.env`.
### JetStream not available
If `USE_JETSTREAM=true` but NATS was started without `--js`, you'll see a connection error. Start NATS with JetStream:
```bash
docker run -d -p 4222:4222 nats:2.10-alpine --js
```
## TODO: Future Providers
- [ ] CoinAPI (REST + WebSocket, paid tier)
- [ ] IQFeed (US equities, DTN subscription)
- [ ] Polygon.io (real-time + historical)
- [ ] Interactive Brokers TWS API
- [ ] Coinbase WebSocket (backup crypto #2)


@@ -19,6 +19,9 @@ class Settings(BaseSettings):
binance_ws_url: str = "wss://stream.binance.com:9443/ws"
binance_rest_url: str = "https://api.binance.com"
# ── Bybit (backup crypto — no key needed) ──────────────────────────
bybit_ws_url: str = "wss://stream.bybit.com/v5/public/spot"
# ── Alpaca (paper trading — free tier) ─────────────────────────────
alpaca_key: str = ""
alpaca_secret: str = ""
@@ -41,6 +44,11 @@ class Settings(BaseSettings):
http_port: int = 8891
metrics_enabled: bool = True
# ── NATS output adapter ─────────────────────────────────────────────
nats_url: str = "" # e.g. "nats://localhost:4222"
nats_subject_prefix: str = "md.events" # → md.events.trade.BTCUSDT
nats_enabled: bool = False
# ── Logging ────────────────────────────────────────────────────────
log_level: str = "INFO"
log_sample_rate: int = 100 # PrintConsumer: log 1 out of N events
@@ -49,5 +57,9 @@ class Settings(BaseSettings):
def alpaca_configured(self) -> bool:
return bool(self.alpaca_key and self.alpaca_secret)
@property
def nats_configured(self) -> bool:
return bool(self.nats_url and self.nats_enabled)
settings = Settings()


@@ -0,0 +1,133 @@
"""
NATS output adapter — pushes normalised events to NATS subjects.
Subject schema:
{prefix}.{event_type}.{symbol}
e.g. md.events.trade.BTCUSDT
md.events.quote.AAPL
md.events.heartbeat.__system__
SenpAI (or any other consumer) can subscribe to:
md.events.> — all events
md.events.trade.> — all trades
md.events.*.BTCUSDT — all event types for BTC
Payload: JSON (event.model_dump_json())
"""
from __future__ import annotations
import logging
from app.config import settings
from app.domain.events import Event
logger = logging.getLogger(__name__)
# Lazy import: nats-py may not be installed in minimal setups, so it is imported in start()
class NatsOutputConsumer:
"""
Publishes every event to NATS as JSON.
Auto-reconnects via nats-py built-in mechanism.
If NATS is unavailable, logs warning and drops events (non-blocking).
"""
def __init__(
self,
nats_url: str | None = None,
subject_prefix: str | None = None,
) -> None:
self._url = nats_url or settings.nats_url
self._prefix = subject_prefix or settings.nats_subject_prefix
self._nc = None
self._connected = False
self._publish_count = 0
self._drop_count = 0
async def start(self) -> None:
"""Connect to NATS."""
try:
import nats # noqa: F811
self._nc = await nats.connect(
self._url,
reconnect_time_wait=2,
max_reconnect_attempts=-1, # infinite
name="market-data-service",
error_cb=self._error_cb,
disconnected_cb=self._disconnected_cb,
reconnected_cb=self._reconnected_cb,
)
self._connected = True
logger.info(
"nats_output.connected",
extra={"url": self._url, "prefix": self._prefix},
)
except ImportError:
logger.error(
"nats_output.nats_not_installed",
extra={"hint": "pip install nats-py"},
)
except Exception as e:
logger.error(
"nats_output.connect_failed",
extra={"url": self._url, "error": str(e)},
)
async def handle(self, event: Event) -> None:
"""Publish event to NATS subject."""
if not self._nc or not self._connected:
self._drop_count += 1
return
symbol = getattr(event, "symbol", "__system__")
subject = f"{self._prefix}.{event.event_type.value}.{symbol}"
try:
payload = event.model_dump_json().encode("utf-8")
await self._nc.publish(subject, payload)
self._publish_count += 1
except Exception as e:
self._drop_count += 1
if self._drop_count % 1000 == 1:
logger.warning(
"nats_output.publish_failed",
extra={
"subject": subject,
"error": str(e),
"total_dropped": self._drop_count,
},
)
async def stop(self) -> None:
"""Flush and close NATS connection."""
if self._nc:
try:
await self._nc.flush(timeout=5)
await self._nc.close()
except Exception as e:
logger.warning("nats_output.close_error", extra={"error": str(e)})
logger.info(
"nats_output.stopped",
extra={
"published": self._publish_count,
"dropped": self._drop_count,
},
)
# ── NATS callbacks ────────────────────────────────────────────────
async def _error_cb(self, e: Exception) -> None:
logger.error("nats_output.error", extra={"error": str(e)})
async def _disconnected_cb(self) -> None:
self._connected = False
logger.warning("nats_output.disconnected")
async def _reconnected_cb(self) -> None:
self._connected = True
logger.info("nats_output.reconnected")


@@ -3,7 +3,6 @@ StorageConsumer: persists events to SQLite + JSONL log.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path


@@ -1,6 +1,11 @@
"""
Async event bus — fan-out from providers to consumers.
Features:
- Backpressure with smart drop policy (drop quotes before trades)
- Heartbeat timer per provider (detects dead channels)
- Graceful drain on shutdown
Usage:
bus = EventBus()
bus.add_consumer(storage_consumer)
@@ -13,9 +18,10 @@ from __future__ import annotations
import asyncio
import logging
import time
from typing import Protocol
from app.domain.events import Event, EventType, HeartbeatEvent
logger = logging.getLogger(__name__)
@@ -26,37 +32,105 @@ class EventConsumer(Protocol):
async def handle(self, event: Event) -> None: ...
# Events that can be dropped under backpressure (least critical first)
_DROPPABLE_PRIORITY = {
EventType.HEARTBEAT: 0, # always droppable
EventType.QUOTE: 1, # drop quotes before trades
EventType.BOOK_L2: 2, # drop book snapshots before trades
EventType.TRADE: 3, # trades are most critical — last to drop
}
class EventBus:
"""
Async fan-out bus with backpressure and heartbeat monitoring.
Every published event is dispatched to all registered consumers
concurrently (gather). A slow consumer doesn't block others thanks
to the internal queue + worker pattern.
Backpressure policy:
- Queue 80% full → start dropping HEARTBEAT events
- Queue 90% full → also drop QUOTE events
- Queue 100% full → drop oldest (any type)
Heartbeat timer:
- Emits synthetic HeartbeatEvent if a provider sends nothing
for `heartbeat_interval` seconds, making dead channels visible.
"""
def __init__(
self,
queue_size: int = 10_000,
heartbeat_interval: float = 10.0,
) -> None:
self._consumers: list[EventConsumer] = []
self._queue: asyncio.Queue[Event | None] = asyncio.Queue(maxsize=queue_size)
self._max_size = queue_size
self._running = False
self._task: asyncio.Task | None = None
self._heartbeat_interval = heartbeat_interval
self._heartbeat_tasks: dict[str, asyncio.Task] = {}
self._provider_last_seen: dict[str, float] = {}
# Backpressure counters
self._dropped: dict[str, int] = {}
def add_consumer(self, consumer: EventConsumer) -> None:
self._consumers.append(consumer)
logger.info("bus.consumer_added", extra={"consumer": type(consumer).__name__})
def register_provider(self, provider_name: str) -> None:
"""Register a provider for heartbeat monitoring."""
self._provider_last_seen[provider_name] = time.monotonic()
async def publish(self, event: Event) -> None:
"""Put event into internal queue (non-blocking if queue not full)."""
"""
Put event into internal queue with backpressure.
Drop policy under pressure:
- 80%+ → drop heartbeats
- 90%+ → drop quotes/book snapshots
- 100% → drop oldest event
"""
current = self._queue.qsize()
fill_pct = current / self._max_size if self._max_size > 0 else 0
# Track provider activity for heartbeat timer
self._provider_last_seen[event.provider] = time.monotonic()
priority = _DROPPABLE_PRIORITY.get(event.event_type, 3)
# Backpressure: drop low-priority events when queue is filling up
if fill_pct >= 0.9 and priority <= 1:
# Drop heartbeats and quotes
self._dropped[event.event_type.value] = self._dropped.get(event.event_type.value, 0) + 1
if self._dropped[event.event_type.value] % 1000 == 1:
logger.warning(
"bus.backpressure_drop",
extra={
"type": event.event_type.value,
"fill_pct": f"{fill_pct:.0%}",
"total_dropped": self._dropped,
},
)
return
if fill_pct >= 0.8 and priority == 0:
# Drop heartbeats only
return
try:
self._queue.put_nowait(event)
except asyncio.QueueFull:
logger.warning("bus.queue_full, dropping oldest event")
# Drop oldest to keep queue moving
# Last resort: drop oldest to make room
try:
dropped = self._queue.get_nowait()
logger.warning(
"bus.queue_full_drop_oldest",
extra={"dropped_type": dropped.event_type.value if dropped else "None"},
)
except asyncio.QueueEmpty:
pass
try:
self._queue.put_nowait(event)
except asyncio.QueueFull:
pass # truly stuck
async def _worker(self) -> None:
"""Background worker that drains the queue and fans out."""
@@ -75,20 +149,79 @@ class EventBus:
extra={"consumer": consumer_name, "error": str(result)},
)
async def _heartbeat_monitor(self, provider_name: str) -> None:
"""Emit synthetic heartbeat if provider goes silent."""
while self._running:
await asyncio.sleep(self._heartbeat_interval)
if not self._running:
break
last = self._provider_last_seen.get(provider_name, 0)
elapsed = time.monotonic() - last
if elapsed > self._heartbeat_interval:
# Provider is silent — emit heartbeat so metrics/logs see it
logger.warning(
"bus.provider_silent",
extra={
"provider": provider_name,
"silent_seconds": f"{elapsed:.1f}",
},
)
hb = HeartbeatEvent(provider=provider_name)
await self.publish(hb)
async def start(self) -> None:
"""Start the bus worker."""
"""Start the bus worker and heartbeat monitors."""
self._running = True
self._task = asyncio.create_task(self._worker())
logger.info("bus.started", extra={"consumers": len(self._consumers)})
# Start heartbeat monitors for registered providers
for pname in self._provider_last_seen:
task = asyncio.create_task(self._heartbeat_monitor(pname))
self._heartbeat_tasks[pname] = task
logger.info(
"bus.started",
extra={
"consumers": len(self._consumers),
"providers_monitored": list(self._provider_last_seen.keys()),
},
)
async def stop(self) -> None:
"""Graceful shutdown: drain queue then stop."""
"""Graceful shutdown: stop heartbeats, drain queue, stop worker."""
self._running = False
# Cancel heartbeat monitors
for task in self._heartbeat_tasks.values():
task.cancel()
for task in self._heartbeat_tasks.values():
try:
await task
except asyncio.CancelledError:
pass
self._heartbeat_tasks.clear()
# Drain remaining events
remaining = self._queue.qsize()
if remaining > 0:
logger.info("bus.draining", extra={"remaining": remaining})
# Send sentinel to stop worker
await self._queue.put(None)
if self._task:
await self._task
if self._dropped:
logger.info("bus.drop_stats", extra={"dropped": self._dropped})
logger.info("bus.stopped")
@property
def queue_size(self) -> int:
return self._queue.qsize()
@property
def fill_percent(self) -> float:
return self._queue.qsize() / self._max_size if self._max_size > 0 else 0

View File

@@ -0,0 +1,170 @@
"""
Provider failover manager.
Tracks provider health per symbol and recommends the best active source.
Policy:
- Each provider has a "health score" per symbol (0.0 to 1.0)
- Score decreases on gaps (heartbeat timeout) and error events
- Score increases on each successful trade/quote received
- When primary provider's score drops below threshold → switch to backup
Usage:
failover = FailoverManager(primary="binance", backups=["bybit"])
failover.record_event("binance", "BTCUSDT") # bumps score
failover.record_gap("binance", "BTCUSDT") # decreases score
best = failover.get_best_provider("BTCUSDT") # → "binance" or "bybit"
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class ProviderHealth:
"""Health tracker for one provider+symbol pair."""
score: float = 1.0
event_count: int = 0
gap_count: int = 0
last_event_ts: float = 0.0
last_gap_ts: float = 0.0
def record_event(self) -> None:
"""Bump health score on successful event."""
self.event_count += 1
self.last_event_ts = time.monotonic()
# Recover towards 1.0 gradually
self.score = min(1.0, self.score + 0.01)
def record_gap(self) -> None:
"""Decrease health score on gap/timeout."""
self.gap_count += 1
self.last_gap_ts = time.monotonic()
self.score = max(0.0, self.score - 0.2)
class FailoverManager:
"""
Tracks provider health and recommends best source per symbol.
"""
def __init__(
self,
primary: str,
backups: list[str] | None = None,
switch_threshold: float = 0.3,
recovery_threshold: float = 0.7,
) -> None:
self._primary = primary
self._backups = backups or []
self._all_providers = [primary] + self._backups
self._switch_threshold = switch_threshold
self._recovery_threshold = recovery_threshold
# provider → symbol → ProviderHealth
self._health: dict[str, dict[str, ProviderHealth]] = {}
# symbol → currently active provider
self._active: dict[str, str] = {}
def _get_health(self, provider: str, symbol: str) -> ProviderHealth:
"""Get or create health tracker."""
if provider not in self._health:
self._health[provider] = {}
if symbol not in self._health[provider]:
self._health[provider][symbol] = ProviderHealth()
return self._health[provider][symbol]
def record_event(self, provider: str, symbol: str) -> None:
"""Record a successful event from provider for symbol."""
self._get_health(provider, symbol).record_event()
def record_gap(self, provider: str, symbol: str) -> None:
"""Record a gap/timeout for provider+symbol."""
h = self._get_health(provider, symbol)
h.record_gap()
logger.warning(
"failover.gap_recorded",
extra={
"provider": provider,
"symbol": symbol,
"score": f"{h.score:.2f}",
"gaps": h.gap_count,
},
)
def get_best_provider(self, symbol: str) -> str:
"""
Return the currently recommended provider for this symbol.
Logic:
1. If active provider score >= switch_threshold → keep it
2. If active provider drops below → switch to healthiest backup
3. If active provider recovers above recovery_threshold → switch back to primary
"""
current = self._active.get(symbol, self._primary)
current_health = self._get_health(current, symbol)
# Check if current provider is degraded
if current_health.score < self._switch_threshold:
# Find best backup
best_provider = current
best_score = current_health.score
for p in self._all_providers:
if p == current:
continue
h = self._get_health(p, symbol)
if h.score > best_score:
best_provider = p
best_score = h.score
if best_provider != current:
logger.warning(
"failover.switching",
extra={
"symbol": symbol,
"from": current,
"to": best_provider,
"old_score": f"{current_health.score:.2f}",
"new_score": f"{best_score:.2f}",
},
)
self._active[symbol] = best_provider
return best_provider
# Check if primary has recovered and we're on a backup
if current != self._primary:
primary_health = self._get_health(self._primary, symbol)
if primary_health.score >= self._recovery_threshold:
logger.info(
"failover.returning_to_primary",
extra={
"symbol": symbol,
"primary_score": f"{primary_health.score:.2f}",
},
)
self._active[symbol] = self._primary
return self._primary
self._active[symbol] = current
return current
def get_status(self) -> dict:
"""Return full failover status for monitoring."""
status = {}
for provider, symbols in self._health.items():
for symbol, health in symbols.items():
key = f"{provider}/{symbol}"
status[key] = {
"score": round(health.score, 2),
"events": health.event_count,
"gaps": health.gap_count,
"active": self._active.get(symbol) == provider,
}
return status

View File

@@ -18,7 +18,6 @@ import asyncio
import logging
import signal
import sys
import structlog
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
@@ -26,14 +25,18 @@ from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
from app.config import settings
from app.core.bus import EventBus
from app.consumers.metrics import MetricsConsumer
from app.consumers.nats_output import NatsOutputConsumer
from app.consumers.print import PrintConsumer
from app.consumers.storage import StorageConsumer
from app.db.schema import engine, init_db
from app.db import repo
from app.providers import MarketDataProvider, get_provider
logger = structlog.get_logger()
# Global reference to bus (for HTTP status endpoint)
_bus: EventBus | None = None
# ── Logging setup ──────────────────────────────────────────────────────
@@ -105,6 +108,18 @@ async def _http_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWrit
}
body = json.dumps(result, ensure_ascii=False).encode()
content_type = "application/json"
elif path == "/bus-stats":
import json as _json
bus_info = {"queue_size": 0, "fill_percent": 0.0}
if _bus:
bus_info = {
"queue_size": _bus.queue_size,
"fill_percent": round(_bus.fill_percent * 100, 1),
"max_size": _bus._max_size,
}
body = _json.dumps(bus_info).encode()
content_type = "application/json"
else:
body = b'{"error":"not found"}'
content_type = "application/json"
@@ -179,8 +194,13 @@ async def main(provider_names: list[str], symbols: list[str]) -> None:
# Init database
await init_db()
global _bus
# Setup bus + consumers (heartbeat interval from config)
bus = EventBus(
queue_size=10_000,
heartbeat_interval=settings.heartbeat_timeout / 2, # check twice per timeout
)
storage = StorageConsumer()
await storage.start()
@@ -192,16 +212,29 @@ async def main(provider_names: list[str], symbols: list[str]) -> None:
printer = PrintConsumer()
bus.add_consumer(printer)
# Optional: NATS output adapter
nats_consumer = None
if settings.nats_configured:
nats_consumer = NatsOutputConsumer()
await nats_consumer.start()
bus.add_consumer(nats_consumer)
logger.info("nats_output.enabled", subject_prefix=settings.nats_subject_prefix)
else:
logger.info("nats_output.disabled", hint="Set NATS_URL + NATS_ENABLED=true to enable")
# Create providers and register them for heartbeat monitoring
providers: list[MarketDataProvider] = []
for name in provider_names:
p = get_provider(name)
providers.append(p)
bus.register_provider(p.name)
_bus = bus
await bus.start()
# Start HTTP server
http_server = await start_http_server()
# Run all providers concurrently
tasks = []
for p in providers:
@@ -224,21 +257,43 @@ async def main(provider_names: list[str], symbols: list[str]) -> None:
# Wait for shutdown
await shutdown_event.wait()
# ── Graceful shutdown sequence ──────────────────────────────────────
logger.info("service.shutting_down")
# 1. Cancel provider streaming tasks (with timeout)
for task in tasks:
task.cancel()
done, pending = await asyncio.wait(tasks, timeout=5.0)
for task in pending:
logger.warning("service.task_force_cancel", extra={"task": task.get_name()})
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
# 2. Close provider WebSocket connections
for p in providers:
try:
await p.close()
except Exception as e:
logger.warning("service.provider_close_error", extra={"provider": p.name, "error": str(e)})
# 3. Stop bus (drains remaining events to consumers)
await bus.stop()
# 4. Stop storage consumer (flush JSONL)
await storage.stop()
# 4b. Stop NATS output (flush + close)
if nats_consumer:
await nats_consumer.stop()
# 5. Close HTTP server
http_server.close()
await http_server.wait_closed()
logger.info("service.stopped")
# 6. Close SQLAlchemy engine (flush connections)
await engine.dispose()
logger.info("service.stopped", extra={"exit": "clean"})
# ── CLI ────────────────────────────────────────────────────────────────
@@ -270,7 +325,7 @@ def cli():
symbols = [s.strip() for s in args.symbols.split(",") if s.strip()]
if args.provider.lower() == "all":
provider_names = ["binance", "alpaca"]
provider_names = ["binance", "alpaca", "bybit"]
else:
provider_names = [p.strip() for p in args.provider.split(",") if p.strip()]


@@ -45,10 +45,12 @@ def get_provider(name: str) -> MarketDataProvider:
"""Factory: instantiate provider by name."""
from app.providers.binance import BinanceProvider
from app.providers.alpaca import AlpacaProvider
from app.providers.bybit import BybitProvider
registry: dict[str, type[MarketDataProvider]] = {
"binance": BinanceProvider,
"alpaca": AlpacaProvider,
"bybit": BybitProvider,
}
cls = registry.get(name.lower())
if cls is None:


@@ -17,7 +17,7 @@ from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime
from typing import AsyncIterator
import websockets


@@ -12,7 +12,6 @@ from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import AsyncIterator
@@ -22,7 +21,6 @@ from websockets.exceptions import ConnectionClosed
from app.config import settings
from app.domain.events import (
Event,
QuoteEvent,
TradeEvent,
)


@@ -0,0 +1,239 @@
"""
Bybit V5 public WebSocket provider — backup for Binance.
Streams:
- publicTrade.{symbol} → TradeEvent
- tickers.{symbol} → QuoteEvent (best bid/ask from tickers)
Docs: https://bybit-exchange.github.io/docs/v5/ws/connect
No API key needed for public market data.
"""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import AsyncIterator
import websockets
from websockets.exceptions import ConnectionClosed
from app.config import settings
from app.domain.events import (
Event,
QuoteEvent,
TradeEvent,
)
from app.providers import MarketDataProvider
logger = logging.getLogger(__name__)
def _ms_to_dt(ms: int | float | str | None) -> datetime | None:
"""Convert millisecond epoch to UTC datetime."""
if ms is None:
return None
try:
return datetime.fromtimestamp(int(ms) / 1000.0, tz=timezone.utc)
except (ValueError, TypeError, OSError):
return None
class BybitProvider(MarketDataProvider):
"""
Bybit V5 public WebSocket (spot market).
Connects to the spot public channel and subscribes to
publicTrade + tickers for each symbol.
"""
name = "bybit"
def __init__(self) -> None:
self._ws: websockets.WebSocketClientProtocol | None = None
self._symbols: list[str] = []
self._connected = False
self._reconnect_count = 0
self._base_url = settings.bybit_ws_url
async def connect(self) -> None:
"""Establish WebSocket connection."""
logger.info("bybit.connecting", extra={"url": self._base_url})
self._ws = await websockets.connect(
self._base_url,
ping_interval=20,
ping_timeout=10,
close_timeout=5,
)
self._connected = True
logger.info("bybit.connected")
async def subscribe(self, symbols: list[str]) -> None:
"""Subscribe to publicTrade + tickers for each symbol."""
if not self._ws:
raise RuntimeError("Not connected. Call connect() first.")
self._symbols = [s.upper() for s in symbols]
args = []
for sym in self._symbols:
args.append(f"publicTrade.{sym}")
args.append(f"tickers.{sym}")
subscribe_msg = {
"op": "subscribe",
"args": args,
}
await self._ws.send(json.dumps(subscribe_msg))
logger.info(
"bybit.subscribed",
extra={"symbols": self._symbols, "channels": len(args)},
)
async def stream(self) -> AsyncIterator[Event]:
"""Yield domain events. Handles reconnect automatically."""
backoff = settings.reconnect_base_delay
while True:
try:
if not self._connected or not self._ws:
await self._reconnect(backoff)
try:
raw = await asyncio.wait_for(
self._ws.recv(), # type: ignore
timeout=settings.heartbeat_timeout,
)
except asyncio.TimeoutError:
logger.warning(
"bybit.heartbeat_timeout",
extra={"timeout": settings.heartbeat_timeout},
)
self._connected = False
continue
# Reset backoff on successful message
backoff = settings.reconnect_base_delay
data = json.loads(raw)
# Handle control messages (pong and subscribe acks)
if data.get("op") in ("pong", "subscribe"):
if data.get("success") is False:
logger.warning("bybit.subscribe_failed", extra={"msg": data})
continue
event = self._parse(data)
if event:
yield event
except ConnectionClosed as e:
logger.warning(
"bybit.connection_closed",
extra={"code": e.code, "reason": str(e.reason)},
)
self._connected = False
backoff = min(backoff * 2, settings.reconnect_max_delay)
except Exception as e:
logger.error("bybit.stream_error", extra={"error": str(e)})
self._connected = False
backoff = min(backoff * 2, settings.reconnect_max_delay)
async def _reconnect(self, delay: float) -> None:
"""Reconnect with delay, then resubscribe."""
self._reconnect_count += 1
logger.info(
"bybit.reconnecting",
extra={"delay": delay, "attempt": self._reconnect_count},
)
await asyncio.sleep(delay)
try:
if self._ws:
await self._ws.close()
except Exception:
pass
await self.connect()
if self._symbols:
await self.subscribe(self._symbols)
def _parse(self, data: dict) -> Event | None:
"""Parse raw Bybit V5 message into domain events."""
topic = data.get("topic", "")
event_data = data.get("data")
if not topic or event_data is None:
return None
if topic.startswith("publicTrade."):
return self._parse_trades(event_data)
elif topic.startswith("tickers."):
return self._parse_ticker(event_data)
return None
def _parse_trades(self, data: list | dict) -> Event | None:
"""
Bybit publicTrade payload (V5):
{"data": [{"s":"BTCUSDT","S":"Buy","v":"0.001","p":"70000.5","T":1672515782136,"i":"..."}]}
We take the last trade in the batch.
"""
if isinstance(data, list):
if not data:
return None
trade = data[-1] # latest in batch
else:
trade = data
return TradeEvent(
provider=self.name,
symbol=trade.get("s", "").upper(),
price=float(trade.get("p", 0)),
size=float(trade.get("v", 0)),
ts_exchange=_ms_to_dt(trade.get("T")),
side=trade.get("S", "").lower() if trade.get("S") else None,
trade_id=str(trade.get("i", "")),
)
def _parse_ticker(self, data: dict) -> QuoteEvent | None:
"""
Bybit tickers (V5 spot):
{"data": {"symbol":"BTCUSDT","bid1Price":"70000.5","bid1Size":"1.5",
"ask1Price":"70001.0","ask1Size":"2.0",...}}
"""
if isinstance(data, list):
data = data[0] if data else {}
bid = data.get("bid1Price") or data.get("bidPrice")
ask = data.get("ask1Price") or data.get("askPrice")
bid_size = data.get("bid1Size") or data.get("bidSize")
ask_size = data.get("ask1Size") or data.get("askSize")
if not bid or not ask:
return None
return QuoteEvent(
provider=self.name,
symbol=data.get("symbol", "").upper(),
bid=float(bid),
ask=float(ask),
bid_size=float(bid_size or 0),
ask_size=float(ask_size or 0),
ts_exchange=_ms_to_dt(data.get("ts")),
)
async def close(self) -> None:
"""Close the WebSocket connection."""
self._connected = False
if self._ws:
try:
await self._ws.close()
except Exception:
pass
logger.info(
"bybit.closed",
extra={"reconnect_count": self._reconnect_count},
)


@@ -0,0 +1,27 @@
# Auto-generated pinned dependencies — 2026-02-09
# Install: pip install -r requirements.txt -c requirements.lock
aiosqlite==0.22.1
annotated-types==0.7.0
anyio==4.12.1
certifi==2026.1.4
greenlet==3.3.1
h11==0.16.0
httpcore==1.0.9
httpx==0.28.1
idna==3.11
nats-py==2.13.1
prometheus_client==0.24.1
pydantic==2.12.5
pydantic-settings==2.12.0
pydantic_core==2.41.5
python-dotenv==1.2.1
SQLAlchemy==2.0.46
structlog==25.5.0
tenacity==9.1.4
typing-inspection==0.4.2
typing_extensions==4.15.0
websockets==16.0
# Dev
pytest==9.0.2
pytest-asyncio==1.3.0
ruff==0.15.0


@@ -22,6 +22,9 @@ structlog>=24.1
# Metrics
prometheus_client>=0.20
# NATS output (optional — for SenpAI integration)
nats-py>=2.7
# Testing
pytest>=8.0
pytest-asyncio>=0.23


@@ -121,3 +121,69 @@ async def test_bus_queue_overflow():
# Some events were dropped, but consumer got the ones that fit
assert len(consumer.events) >= 1
@pytest.mark.asyncio
async def test_bus_backpressure_drops_quotes_before_trades():
"""Under backpressure, quotes are dropped but trades survive."""
from app.domain.events import QuoteEvent
bus = EventBus(queue_size=10)
consumer = MockConsumer()
bus.add_consumer(consumer)
# Fill queue with heartbeats (without starting worker); heartbeat publishes stop entering at 80% fill
for _ in range(10):
await bus.publish(HeartbeatEvent(provider="test"))
# Publish a quote; at 80% fill it is still accepted (quotes drop only at >=90%)
quote = QuoteEvent(
provider="test", symbol="BTCUSDT",
bid=70000.0, ask=70001.0, bid_size=1.0, ask_size=1.0,
)
await bus.publish(quote)
# Start worker, drain existing events
await bus.start()
await asyncio.sleep(0.1)
await bus.stop()
types = [e.event_type for e in consumer.events]
assert EventType.TRADE not in types # no trades published
# Note: heartbeat publishes stop entering at >=80% fill (priority 0),
# so with queue_size=10 the queue settles at 8 events; a quote
# (priority 1) would only be dropped at >=90% fill
@pytest.mark.asyncio
async def test_bus_heartbeat_monitor_emits_on_silence():
"""Heartbeat monitor fires when a provider goes silent."""
bus = EventBus(queue_size=100, heartbeat_interval=0.3)
consumer = MockConsumer()
bus.add_consumer(consumer)
bus.register_provider("test_provider")
await bus.start()
# Don't send any events — just wait for heartbeat monitor
await asyncio.sleep(0.8)
await bus.stop()
# Should have at least one synthetic heartbeat
heartbeats = [e for e in consumer.events if e.event_type == EventType.HEARTBEAT]
assert len(heartbeats) >= 1
assert heartbeats[0].provider == "test_provider"
@pytest.mark.asyncio
async def test_bus_fill_percent():
"""fill_percent property works correctly."""
bus = EventBus(queue_size=100)
assert bus.fill_percent == 0.0
for _ in range(50):
await bus.publish(HeartbeatEvent(provider="test"))
assert 0.49 <= bus.fill_percent <= 0.51


@@ -0,0 +1,151 @@
"""
Unit tests for Bybit provider — raw JSON → domain event parsing.
"""
import pytest
from app.domain.events import EventType
from app.providers.bybit import BybitProvider
@pytest.fixture
def provider():
return BybitProvider()
# ── Trade parsing ──────────────────────────────────────────────────────
def test_parse_trade_basic(provider):
"""Basic publicTrade parsing."""
raw = {
"topic": "publicTrade.BTCUSDT",
"data": [
{
"s": "BTCUSDT",
"S": "Buy",
"v": "0.001",
"p": "70500.5",
"T": 1672515782136,
"i": "trade123",
}
],
}
event = provider._parse(raw)
assert event is not None
assert event.event_type == EventType.TRADE
assert event.symbol == "BTCUSDT"
assert event.price == 70500.5
assert event.size == 0.001
assert event.side == "buy"
assert event.trade_id == "trade123"
assert event.provider == "bybit"
def test_parse_trade_sell_side(provider):
"""Sell side trade."""
raw = {
"topic": "publicTrade.ETHUSDT",
"data": [
{
"s": "ETHUSDT",
"S": "Sell",
"v": "10.5",
"p": "2100.00",
"T": 1672515782136,
"i": "t456",
}
],
}
event = provider._parse(raw)
assert event.side == "sell"
assert event.symbol == "ETHUSDT"
def test_parse_trade_batch_takes_last(provider):
"""Multiple trades in a batch — takes the last one."""
raw = {
"topic": "publicTrade.BTCUSDT",
"data": [
{"s": "BTCUSDT", "S": "Buy", "v": "0.001", "p": "70000.0", "T": 100, "i": "first"},
{"s": "BTCUSDT", "S": "Sell", "v": "0.01", "p": "70100.0", "T": 200, "i": "last"},
],
}
event = provider._parse(raw)
assert event.trade_id == "last"
assert event.price == 70100.0
def test_parse_trade_timestamp(provider):
"""Exchange timestamp is correctly parsed."""
raw = {
"topic": "publicTrade.BTCUSDT",
"data": [
{"s": "BTCUSDT", "S": "Buy", "v": "1", "p": "70000", "T": 1672515782136, "i": "x"},
],
}
event = provider._parse(raw)
assert event.ts_exchange is not None
assert event.ts_exchange.year >= 2022
# ── Ticker (quote) parsing ─────────────────────────────────────────────
def test_parse_ticker_basic(provider):
"""Bybit tickers → QuoteEvent."""
raw = {
"topic": "tickers.BTCUSDT",
"data": {
"symbol": "BTCUSDT",
"bid1Price": "70000.5",
"bid1Size": "1.5",
"ask1Price": "70001.0",
"ask1Size": "2.0",
"ts": "1672515782136",
},
}
event = provider._parse(raw)
assert event is not None
assert event.event_type == EventType.QUOTE
assert event.symbol == "BTCUSDT"
assert event.bid == 70000.5
assert event.ask == 70001.0
assert event.bid_size == 1.5
assert event.ask_size == 2.0
assert event.provider == "bybit"
def test_parse_ticker_missing_bid(provider):
"""Ticker without bid → returns None."""
raw = {
"topic": "tickers.BTCUSDT",
"data": {"symbol": "BTCUSDT"},
}
event = provider._parse(raw)
assert event is None
# ── Edge cases ─────────────────────────────────────────────────────────
def test_parse_unknown_topic(provider):
"""Unknown topic → None."""
raw = {"topic": "some_unknown.BTCUSDT", "data": {}}
event = provider._parse(raw)
assert event is None
def test_parse_pong_skipped(provider):
"""Pong/subscribe messages are not events."""
raw = {"op": "pong", "success": True}
# _parse would not be called for op messages (handled in stream()),
# but let's verify _parse returns None for incomplete data
event = provider._parse(raw)
assert event is None
def test_parse_empty_trade_data(provider):
"""Empty trade data array → None."""
raw = {"topic": "publicTrade.BTCUSDT", "data": []}
event = provider._parse(raw)
assert event is None


@@ -0,0 +1,82 @@
"""
Tests for the failover manager.
"""
from app.core.failover import FailoverManager
def test_default_returns_primary():
"""Without any events, primary is the recommended provider."""
fm = FailoverManager(primary="binance", backups=["bybit"])
assert fm.get_best_provider("BTCUSDT") == "binance"
def test_gaps_cause_switch():
"""Enough gaps should cause a switch to backup."""
fm = FailoverManager(
primary="binance",
backups=["bybit"],
switch_threshold=0.3,
)
# Record some events for bybit so it has health
for _ in range(10):
fm.record_event("bybit", "BTCUSDT")
# Degrade binance heavily (5 gaps = -1.0)
for _ in range(5):
fm.record_gap("binance", "BTCUSDT")
best = fm.get_best_provider("BTCUSDT")
assert best == "bybit"
def test_recovery_returns_to_primary():
"""When primary recovers, switch back from backup."""
fm = FailoverManager(
primary="binance",
backups=["bybit"],
switch_threshold=0.3,
recovery_threshold=0.7,
)
# Degrade primary and switch to backup
for _ in range(10):
fm.record_event("bybit", "BTCUSDT")
for _ in range(5):
fm.record_gap("binance", "BTCUSDT")
assert fm.get_best_provider("BTCUSDT") == "bybit"
# Now primary recovers (many events increase score)
for _ in range(100):
fm.record_event("binance", "BTCUSDT")
assert fm.get_best_provider("BTCUSDT") == "binance"
def test_status_report():
"""Status report includes all provider/symbol pairs."""
fm = FailoverManager(primary="binance", backups=["bybit"])
fm.record_event("binance", "BTCUSDT")
fm.record_event("bybit", "BTCUSDT")
fm.record_gap("binance", "ETHUSDT")
status = fm.get_status()
assert "binance/BTCUSDT" in status
assert "bybit/BTCUSDT" in status
assert "binance/ETHUSDT" in status
assert status["binance/BTCUSDT"]["events"] == 1
assert status["binance/ETHUSDT"]["gaps"] == 1
def test_no_backup_stays_on_primary():
"""Without backups, always returns primary even when degraded."""
fm = FailoverManager(primary="binance", backups=[])
for _ in range(5):
fm.record_gap("binance", "BTCUSDT")
# No alternative, stays on binance
assert fm.get_best_provider("BTCUSDT") == "binance"


@@ -0,0 +1,8 @@
.venv/
__pycache__/
*.pyc
.pytest_cache/
.ruff_cache/
.env
tests/
.git/


@@ -0,0 +1,38 @@
# SenpAI Market-Data Consumer Configuration
# Copy to .env and adjust as needed
# ── NATS ──────────────────────────────────────────────────────────────
NATS_URL=nats://localhost:4222
NATS_SUBJECT=md.events.>
NATS_QUEUE_GROUP=senpai-md
USE_JETSTREAM=false
# ── Internal queue ────────────────────────────────────────────────────
QUEUE_SIZE=50000
QUEUE_DROP_THRESHOLD=0.9
# ── Features / signals ───────────────────────────────────────────────
FEATURES_ENABLED=true
FEATURES_PUB_RATE_HZ=10
FEATURES_PUB_SUBJECT=senpai.features
SIGNALS_PUB_SUBJECT=senpai.signals
ALERTS_PUB_SUBJECT=senpai.alerts
# ── Rolling window ───────────────────────────────────────────────────
ROLLING_WINDOW_SECONDS=60.0
# ── Signal rules ─────────────────────────────────────────────────────
SIGNAL_RETURN_THRESHOLD=0.003
SIGNAL_VOLUME_THRESHOLD=1.0
SIGNAL_SPREAD_MAX_BPS=20.0
# ── Alert thresholds ─────────────────────────────────────────────────
ALERT_LATENCY_MS=1000.0
ALERT_GAP_SECONDS=30.0
# ── HTTP API ─────────────────────────────────────────────────────────
HTTP_HOST=0.0.0.0
HTTP_PORT=8892
# ── Logging ──────────────────────────────────────────────────────────
LOG_LEVEL=INFO


@@ -0,0 +1,20 @@
# ── SenpAI Market-Data Consumer ─────────────────────────────────────────
FROM python:3.11-slim
ENV PYTHONDONTWRITEBYTECODE=1 \
PYTHONUNBUFFERED=1
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY senpai/ ./senpai/
COPY pyproject.toml .
HEALTHCHECK --interval=15s --timeout=5s --start-period=10s --retries=3 \
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8892/health')" || exit 1
EXPOSE 8892
ENTRYPOINT ["python", "-m", "senpai.md_consumer"]


@@ -0,0 +1,242 @@
# SenpAI Market-Data Consumer
NATS subscriber + feature engine + signal bus for the SenpAI/Gordon trading agent.
Consumes normalised events from `market-data-service`, computes real-time features, and publishes signals back to NATS.
## Architecture
```
market-data-service SenpAI MD Consumer
┌──────────────┐ ┌────────────────────────────────┐
│ Binance WS │ │ │
│ Bybit WS │──► NATS ──────────► NATSConsumer │
│ Alpaca WS │ md.events.> │ ↓ (bounded queue) │
└──────────────┘ │ State Store │
│ ├─ LatestState (trade/quote)│
│ └─ RollingWindow (60s deque)│
│ ↓ │
│ Feature Engine │
│ ├─ mid, spread, vwap │
│ ├─ return_10s, vol_60s │
│ └─ latency p50/p95 │
│ ↓ │
│ Publisher ──► NATS │
│ ├─ senpai.features.{symbol} │
│ ├─ senpai.signals.{symbol} │
│ └─ senpai.alerts │
│ │
│ HTTP API (:8892) │
│ /health /metrics /stats │
│ /state/latest /features │
└────────────────────────────────┘
```
## Quick Start
### 1. Install
```bash
cd services/senpai-md-consumer
pip install -r requirements.txt
cp .env.example .env
```
### 2. Start NATS (if not running)
```bash
docker run -d --name nats -p 4222:4222 -p 8222:8222 nats:2.10-alpine --js -m 8222
```
### 3. Start market-data-service (producer)
```bash
cd ../market-data-service
python -m app run --provider binance --symbols BTCUSDT,ETHUSDT
```
### 4. Start SenpAI MD Consumer
```bash
cd ../senpai-md-consumer
python -m senpai.md_consumer
```
### 5. Verify
```bash
# Health
curl http://localhost:8892/health
# Stats
curl http://localhost:8892/stats
# Latest state
curl "http://localhost:8892/state/latest?symbol=BTCUSDT"
# Computed features
curl "http://localhost:8892/features/latest?symbol=BTCUSDT"
# Prometheus metrics
curl http://localhost:8892/metrics
```
## Docker
### Standalone (with NATS)
```bash
docker-compose -f docker-compose.senpai.yml up -d
```
### Part of NODE1 stack
```bash
docker-compose -f docker-compose.node1.yml up -d market-data-service senpai-md-consumer
```
## NATS Subjects
### Consumed (from market-data-service)
| Subject | Description |
|---|---|
| `md.events.trade.{symbol}` | Trade events |
| `md.events.quote.{symbol}` | Quote events |
| `md.events.book_l2.{symbol}` | L2 book snapshots |
| `md.events.heartbeat.__system__` | Provider heartbeats |
### Published (for SenpAI/other consumers)
| Subject | Description |
|---|---|
| `senpai.features.{symbol}` | Feature snapshots (rate-limited to 10Hz/symbol) |
| `senpai.signals.{symbol}` | Trade signals (long/short) |
| `senpai.alerts` | System alerts (latency, gaps, backpressure) |
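With the NATS CLI installed, the published subjects can be watched directly (illustrative; adjust the server URL to your setup):
```bash
nats sub "senpai.features.>" --server nats://localhost:4222
```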
## Features Computed
| Feature | Description |
|---|---|
| `mid` | (bid + ask) / 2 |
| `spread_abs` | ask - bid |
| `spread_bps` | spread in basis points |
| `trade_vwap_10s` | VWAP over 10 seconds |
| `trade_vwap_60s` | VWAP over 60 seconds |
| `trade_count_10s` | Number of trades in 10s |
| `trade_volume_10s` | Total volume in 10s |
| `return_10s` | Price return over 10 seconds |
| `realized_vol_60s` | Realised volatility (60s log-return std) |
| `latency_ms_p50` | p50 exchange-to-receive latency |
| `latency_ms_p95` | p95 exchange-to-receive latency |
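For intuition, a self-contained toy computation of a few of these features (a sketch only; the service's actual feature engine may differ in windowing details):
```python
import math

# Toy 10s window of (price, size) trades plus a current quote
trades = [(70000.0, 0.5), (70010.0, 0.2), (70005.0, 0.3)]
bid, ask = 69999.0, 70001.0

mid = (bid + ask) / 2                    # midpoint price
spread_bps = (ask - bid) / mid * 10_000  # spread in basis points

# Volume-weighted average price over the window
vwap = sum(p * s for p, s in trades) / sum(s for _, s in trades)

# Realised volatility: standard deviation of log returns
prices = [p for p, _ in trades]
rets = [math.log(b / a) for a, b in zip(prices, prices[1:])]
mean = sum(rets) / len(rets)
vol = (sum((r - mean) ** 2 for r in rets) / len(rets)) ** 0.5

print(round(mid, 1), round(spread_bps, 3), round(vwap, 2), round(vol, 6))
```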
## Signal Rules (MVP)
**Long signal** emitted when ALL conditions met:
- `return_10s > 0.3%` (configurable)
- `trade_volume_10s > 1.0` (configurable)
- `spread_bps < 20` (configurable)
**Short signal**: same but `return_10s < -0.3%`
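A sketch of the rule as a pure function (thresholds mirror the `.env` defaults; the names and structure inside the service may differ):
```python
def rule_signal(
    return_10s: float,
    trade_volume_10s: float,
    spread_bps: float,
    ret_thr: float = 0.003,    # SIGNAL_RETURN_THRESHOLD
    vol_thr: float = 1.0,      # SIGNAL_VOLUME_THRESHOLD
    spread_max: float = 20.0,  # SIGNAL_SPREAD_MAX_BPS
) -> str | None:
    # All conditions must hold at the same time
    if trade_volume_10s > vol_thr and spread_bps < spread_max:
        if return_10s > ret_thr:
            return "long"
        if return_10s < -ret_thr:
            return "short"
    return None
```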
## Backpressure Policy
- Queue < 90% → accept all events
- Queue >= 90% → drop heartbeats, quotes, book snapshots
- **Trades are NEVER dropped**
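The policy reduces to a small predicate (a hypothetical helper for illustration, not the service's actual code):
```python
def should_drop(event_type: str, fill_ratio: float) -> bool:
    # Trades are never dropped; everything else is shed at >=90% fill
    if event_type == "trade":
        return False
    return fill_ratio >= 0.9
```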
## HTTP Endpoints
| Endpoint | Description |
|---|---|
| `GET /health` | Service health + tracked symbols |
| `GET /metrics` | Prometheus metrics |
| `GET /state/latest?symbol=` | Latest trade + quote |
| `GET /features/latest?symbol=` | Current computed features |
| `GET /stats` | Queue fill, drops, events/sec |
## Prometheus Metrics
| Metric | Type | Description |
|---|---|---|
| `senpai_events_in_total` | Counter | Events received {type, provider} |
| `senpai_events_dropped_total` | Counter | Dropped events {reason, type} |
| `senpai_queue_fill_ratio` | Gauge | Queue fill 0..1 |
| `senpai_processing_latency_ms` | Histogram | Processing latency |
| `senpai_feature_publish_total` | Counter | Feature publishes {symbol} |
| `senpai_signals_emitted_total` | Counter | Signals {symbol, direction} |
| `senpai_nats_connected` | Gauge | NATS connection status |
## Tests
```bash
pytest tests/ -v
```
41 tests:
- 11 model parsing tests (tolerant parsing, edge cases)
- 10 state/rolling window tests (eviction, lookup)
- 16 feature math tests (VWAP, vol, signals, percentile)
- 5 rate-limit tests (publish throttling, error handling)
## Troubleshooting
### NATS connection refused
```
nats.error: error=could not connect to server
```
Ensure NATS is running:
```bash
docker run -d --name nats -p 4222:4222 nats:2.10-alpine --js
```
Or check `NATS_URL` in `.env`.
### No events arriving (queue stays at 0)
1. Verify `market-data-service` is running and `NATS_ENABLED=true`
2. Check subject match: producer publishes to `md.events.trade.BTCUSDT`, consumer subscribes to `md.events.>`
3. Check NATS monitoring: `curl http://localhost:8222/connz` — both services should appear
### JetStream errors
If `USE_JETSTREAM=true` but NATS started without `--js`:
```bash
# Restart NATS with JetStream
docker rm -f nats
docker run -d -p 4222:4222 -p 8222:8222 nats:2.10-alpine --js -m 8222
```
Or set `USE_JETSTREAM=false` for core NATS (simpler, works for MVP).
### Port 8892 already in use
```bash
lsof -ti:8892 | xargs kill -9
```
### Features show `null` for all values
Normal on startup — features populate after first trade+quote arrive. Wait a few seconds for Binance data to flow through.
### No signals emitted
Signal rules require ALL conditions simultaneously:
- `return_10s > 0.3%` — needs price movement
- `volume_10s > 1.0` — needs trading activity
- `spread_bps < 20` — needs tight spread
In low-volatility markets, signals may be rare. Lower thresholds in `.env` for testing:
```env
SIGNAL_RETURN_THRESHOLD=0.001
SIGNAL_VOLUME_THRESHOLD=0.1
```
### High memory usage
Rolling windows grow per symbol. With many symbols, reduce window:
```env
ROLLING_WINDOW_SECONDS=30
```
## Configuration (ENV)
See `.env.example` for all available settings.
Key settings:
- `NATS_URL` — NATS server URL
- `FEATURES_PUB_RATE_HZ` — max feature publishes per symbol per second
- `SIGNAL_RETURN_THRESHOLD` — min return for signal trigger
- `ROLLING_WINDOW_SECONDS` — rolling window duration


@@ -0,0 +1,40 @@
# SenpAI Market-Data Consumer + NATS
# Usage: docker-compose -f docker-compose.senpai.yml up -d
version: "3.8"
services:
nats:
image: nats:2.10-alpine
container_name: senpai-nats
ports:
- "4222:4222"
- "8222:8222" # monitoring
command: ["--js", "-m", "8222"]
restart: unless-stopped
senpai-md-consumer:
container_name: senpai-md-consumer
build:
context: .
dockerfile: Dockerfile
environment:
- NATS_URL=nats://nats:4222
- NATS_SUBJECT=md.events.>
- NATS_QUEUE_GROUP=senpai-md
- FEATURES_ENABLED=true
- FEATURES_PUB_RATE_HZ=10
- LOG_LEVEL=INFO
- HTTP_PORT=8892
ports:
- "8892:8892"
depends_on:
- nats
restart: unless-stopped
healthcheck:
test:
- CMD-SHELL
- python -c "import urllib.request; urllib.request.urlopen('http://localhost:8892/health')"
interval: 15s
timeout: 5s
retries: 3
start_period: 10s

View File

@@ -0,0 +1,13 @@
[project]
name = "senpai-md-consumer"
version = "0.1.0"
description = "SenpAI market-data consumer — NATS subscriber, feature engine, signal bus"
requires-python = ">=3.11"
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
[tool.ruff]
target-version = "py311"
line-length = 100

View File

@@ -0,0 +1,16 @@
# Auto-generated pinned dependencies — 2026-02-09
# Install: pip install -r requirements.txt -c requirements.lock
annotated-types==0.7.0
nats-py==2.13.1
prometheus_client==0.24.1
pydantic==2.12.5
pydantic-settings==2.12.0
pydantic_core==2.41.5
python-dotenv==1.2.1
structlog==25.5.0
typing-inspection==0.4.2
typing_extensions==4.15.0
# Dev
pytest==9.0.2
pytest-asyncio==1.3.0
ruff==0.15.0

View File

@@ -0,0 +1,22 @@
# SenpAI Market-Data Consumer
# Python 3.11+
# Core
pydantic>=2.5
pydantic-settings>=2.1
# NATS
nats-py>=2.7
# Logging
structlog>=24.1
# Metrics
prometheus_client>=0.20
# Testing
pytest>=8.0
pytest-asyncio>=0.23
# Linting
ruff>=0.3

View File

@@ -0,0 +1,4 @@
"""Allow running as: python -m senpai.md_consumer"""
from senpai.md_consumer.main import cli
cli()

View File

@@ -0,0 +1,166 @@
"""
Minimal HTTP API — lightweight asyncio server (no framework dependency).
Endpoints:
GET /health → service health
GET /metrics → Prometheus metrics
GET /state/latest → latest trade/quote per symbol (?symbol=BTCUSDT)
GET /features/latest → latest computed features (?symbol=BTCUSDT)
GET /stats → queue fill, drops, events/sec
"""
from __future__ import annotations
import asyncio
import json
import logging
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest
from senpai.md_consumer.config import settings
from senpai.md_consumer.features import compute_features
from senpai.md_consumer.state import LatestState
logger = logging.getLogger(__name__)
# These are set by main.py at startup
_state: LatestState | None = None
_stats_fn = None # callable → dict
def set_state(state: LatestState) -> None:
global _state
_state = state
def set_stats_fn(fn) -> None:
global _stats_fn
_stats_fn = fn
async def _handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
"""Minimal HTTP request handler."""
try:
request_line = await asyncio.wait_for(reader.readline(), timeout=5.0)
request_str = request_line.decode("utf-8", errors="replace").strip()
parts = request_str.split()
if len(parts) < 2:
writer.close()
return
path = parts[1]
# Consume headers
while True:
line = await reader.readline()
if line in (b"\r\n", b"\n", b""):
break
# Parse query params
query_params: dict[str, str] = {}
if "?" in path:
base_path, query = path.split("?", 1)
for param in query.split("&"):
if "=" in param:
k, v = param.split("=", 1)
query_params[k] = v
else:
base_path = path
body, content_type, status = await _route(base_path, query_params)
response = (
f"HTTP/1.1 {status}\r\n"
f"Content-Type: {content_type}\r\n"
f"Content-Length: {len(body)}\r\n"
f"Connection: close\r\n"
f"\r\n"
)
writer.write(response.encode() + body)
await writer.drain()
except Exception:
pass
finally:
try:
writer.close()
await writer.wait_closed()
except Exception:
pass
async def _route(
path: str, params: dict[str, str]
) -> tuple[bytes, str, str]:
"""Route request to handler. Returns (body, content_type, status)."""
if path == "/health":
body = json.dumps({
"status": "ok",
"service": "senpai-md-consumer",
"symbols": _state.symbols if _state else [],
}).encode()
return body, "application/json", "200 OK"
elif path == "/metrics":
body = generate_latest()
return body, CONTENT_TYPE_LATEST, "200 OK"
elif path == "/state/latest":
symbol = params.get("symbol", "")
if not symbol:
body = json.dumps({"error": "missing ?symbol=XXX"}).encode()
return body, "application/json", "400 Bad Request"
if not _state:
body = json.dumps({"error": "not initialized"}).encode()
return body, "application/json", "503 Service Unavailable"
data = _state.to_dict(symbol)
body = json.dumps(data, ensure_ascii=False).encode()
return body, "application/json", "200 OK"
elif path == "/features/latest":
symbol = params.get("symbol", "")
if not symbol:
body = json.dumps({"error": "missing ?symbol=XXX"}).encode()
return body, "application/json", "400 Bad Request"
if not _state:
body = json.dumps({"error": "not initialized"}).encode()
return body, "application/json", "503 Service Unavailable"
features = compute_features(_state, symbol)
data = {"symbol": symbol.upper(), "features": features}
body = json.dumps(data, ensure_ascii=False).encode()
return body, "application/json", "200 OK"
elif path == "/stats":
if _stats_fn:
data = _stats_fn()
else:
data = {"error": "not initialized"}
body = json.dumps(data, ensure_ascii=False).encode()
return body, "application/json", "200 OK"
else:
body = json.dumps({"error": "not found"}).encode()
return body, "application/json", "404 Not Found"
async def start_api() -> asyncio.Server:
"""Start the HTTP server."""
server = await asyncio.start_server(
_handler,
settings.http_host,
settings.http_port,
)
logger.info(
"api.started",
extra={
"host": settings.http_host,
"port": settings.http_port,
"endpoints": [
"/health",
"/metrics",
"/state/latest?symbol=",
"/features/latest?symbol=",
"/stats",
],
},
)
return server

View File

@@ -0,0 +1,55 @@
"""
Configuration via pydantic-settings.
All settings from ENV or .env file.
"""
from __future__ import annotations
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
)
# ── NATS ──────────────────────────────────────────────────────────
nats_url: str = "nats://localhost:4222"
nats_subject: str = "md.events.>"
nats_queue_group: str = "senpai-md"
use_jetstream: bool = False
# ── Internal queue ────────────────────────────────────────────────
queue_size: int = 50_000
queue_drop_threshold: float = 0.9 # start dropping at 90%
# ── Features / signals ────────────────────────────────────────────
features_enabled: bool = True
features_pub_rate_hz: float = 10.0 # max publish rate per symbol
features_pub_subject: str = "senpai.features"
signals_pub_subject: str = "senpai.signals"
alerts_pub_subject: str = "senpai.alerts"
# ── Rolling window ────────────────────────────────────────────────
rolling_window_seconds: float = 60.0
# ── Signal rules (rule-based MVP) ─────────────────────────────────
signal_return_threshold: float = 0.003 # 0.3%
signal_volume_threshold: float = 1.0 # min volume in 10s
signal_spread_max_bps: float = 20.0 # max spread in bps
# ── Alert thresholds ──────────────────────────────────────────────
alert_latency_ms: float = 1000.0 # alert if p95 latency > this
alert_gap_seconds: float = 30.0 # alert if no events for N sec
# ── HTTP ──────────────────────────────────────────────────────────
http_host: str = "0.0.0.0"
http_port: int = 8892
# ── Logging ───────────────────────────────────────────────────────
log_level: str = "INFO"
settings = Settings()

View File

@@ -0,0 +1,248 @@
"""
Feature engine — incremental feature computation from rolling windows.
Features (per symbol):
- mid: (bid+ask)/2
- spread_abs: ask - bid
- spread_bps: spread_abs / mid * 10000
- trade_vwap_10s: VWAP over last 10 seconds
- trade_vwap_60s: VWAP over last 60 seconds
- trade_count_10s: number of trades in 10s
- trade_volume_10s: total volume in 10s
- return_10s: mid_now / mid_10s_ago - 1
- realized_vol_60s: std of log-returns over 60s
- latency_ms_p50: p50 exchange-to-receive latency
- latency_ms_p95: p95 exchange-to-receive latency
Rule-based signal (MVP):
- if return_10s > threshold AND volume_10s > threshold AND spread_bps < threshold
→ emit TradeSignal(direction="long")
- opposite for short
"""
from __future__ import annotations
import logging
import math
from senpai.md_consumer.config import settings
from senpai.md_consumer.models import FeatureSnapshot, TradeSignal
from senpai.md_consumer.state import LatestState, TradeRecord
logger = logging.getLogger(__name__)
def compute_features(state: LatestState, symbol: str) -> dict[str, float | None]:
"""
Compute all features for a symbol from current state.
Returns a flat dict of feature_name → value (None if not computable).
"""
sym = symbol.upper()
features: dict[str, float | None] = {}
quote = state.get_latest_quote(sym)
window = state.get_window(sym)
# ── Mid / Spread ──────────────────────────────────────────────────
if quote and quote.bid > 0 and quote.ask > 0:
mid = (quote.bid + quote.ask) / 2
spread_abs = quote.ask - quote.bid
spread_bps = (spread_abs / mid * 10_000) if mid > 0 else None
features["mid"] = mid
features["spread_abs"] = spread_abs
features["spread_bps"] = spread_bps
else:
features["mid"] = None
features["spread_abs"] = None
features["spread_bps"] = None
if not window:
# No rolling data yet — fill with None
features.update({
"trade_vwap_10s": None,
"trade_vwap_60s": None,
"trade_count_10s": None,
"trade_volume_10s": None,
"return_10s": None,
"realized_vol_60s": None,
"latency_ms_p50": None,
"latency_ms_p95": None,
})
return features
# ── VWAP ──────────────────────────────────────────────────────────
trades_10s = window.trades_since(10.0)
trades_60s = list(window.trades)
features["trade_vwap_10s"] = _vwap(trades_10s)
features["trade_vwap_60s"] = _vwap(trades_60s)
# ── Trade count / volume (10s) ────────────────────────────────────
features["trade_count_10s"] = float(len(trades_10s))
features["trade_volume_10s"] = sum(t.size for t in trades_10s) if trades_10s else 0.0
# ── Return 10s ────────────────────────────────────────────────────
features["return_10s"] = _return_over(window, features.get("mid"), 10.0)
# ── Realised volatility 60s ───────────────────────────────────────
features["realized_vol_60s"] = _realized_vol(trades_60s)
# ── Latency ───────────────────────────────────────────────────────
latencies = _latencies_ms(trades_60s)
if latencies:
latencies.sort()
features["latency_ms_p50"] = _percentile(latencies, 50)
features["latency_ms_p95"] = _percentile(latencies, 95)
else:
features["latency_ms_p50"] = None
features["latency_ms_p95"] = None
return features
def make_feature_snapshot(
state: LatestState, symbol: str
) -> FeatureSnapshot:
"""Create a FeatureSnapshot for publishing."""
features = compute_features(state, symbol)
return FeatureSnapshot(symbol=symbol.upper(), features=features)
def check_signal(
features: dict[str, float | None], symbol: str
) -> TradeSignal | None:
"""
Rule-based signal MVP.
Long if:
- return_10s > signal_return_threshold
- trade_volume_10s > signal_volume_threshold
- spread_bps < signal_spread_max_bps
Short if opposite return condition met.
"""
ret = features.get("return_10s")
vol = features.get("trade_volume_10s")
spread = features.get("spread_bps")
if ret is None or vol is None or spread is None:
return None
# Spread filter (both directions)
if spread > settings.signal_spread_max_bps:
return None
# Volume filter
if vol < settings.signal_volume_threshold:
return None
# Direction
if ret > settings.signal_return_threshold:
confidence = min(1.0, ret / (settings.signal_return_threshold * 3))
return TradeSignal(
symbol=symbol.upper(),
direction="long",
confidence=confidence,
reason=f"return_10s={ret:.4f} vol_10s={vol:.2f} spread={spread:.1f}bps",
features=features,
)
elif ret < -settings.signal_return_threshold:
confidence = min(1.0, abs(ret) / (settings.signal_return_threshold * 3))
return TradeSignal(
symbol=symbol.upper(),
direction="short",
confidence=confidence,
reason=f"return_10s={ret:.4f} vol_10s={vol:.2f} spread={spread:.1f}bps",
features=features,
)
return None
# ── Internal helpers ───────────────────────────────────────────────────
def _vwap(trades: list[TradeRecord]) -> float | None:
"""Volume-weighted average price."""
if not trades:
return None
total_value = sum(t.price * t.size for t in trades)
total_volume = sum(t.size for t in trades)
if total_volume <= 0:
return None
return total_value / total_volume
def _return_over(
window, current_mid: float | None, seconds: float
) -> float | None:
"""
Return over last N seconds.
Uses mid price from quotes if available, else latest trade price.
"""
if current_mid is None or current_mid <= 0:
return None
# Find the quote mid from N seconds ago
quotes = window.quotes_since(seconds)
if quotes:
oldest = quotes[0]
old_mid = (oldest.bid + oldest.ask) / 2
if old_mid > 0:
return current_mid / old_mid - 1
# Fallback: use trade prices
trades = window.trades_since(seconds)
if trades:
old_price = trades[0].price
if old_price > 0:
return current_mid / old_price - 1
return None
def _realized_vol(trades: list[TradeRecord]) -> float | None:
"""
Simple realised volatility: std of log-returns of trade prices.
"""
if len(trades) < 3:
return None
prices = [t.price for t in trades if t.price > 0]
if len(prices) < 3:
return None
log_returns = []
for i in range(1, len(prices)):
if prices[i - 1] > 0:
lr = math.log(prices[i] / prices[i - 1])
log_returns.append(lr)
if len(log_returns) < 2:
return None
mean = sum(log_returns) / len(log_returns)
variance = sum((r - mean) ** 2 for r in log_returns) / (len(log_returns) - 1)
return math.sqrt(variance)
def _latencies_ms(trades: list[TradeRecord]) -> list[float]:
"""Extract exchange-to-receive latencies in ms."""
latencies = []
for t in trades:
if t.ts_exchange is not None and t.ts_recv is not None:
lat = (t.ts_recv.timestamp() - t.ts_exchange.timestamp()) * 1000
if 0 < lat < 60_000: # sanity: 0-60s
latencies.append(lat)
return latencies
def _percentile(sorted_data: list[float], p: int) -> float:
"""Simple percentile from sorted list."""
if not sorted_data:
return 0.0
k = (len(sorted_data) - 1) * p / 100
f = math.floor(k)
c = math.ceil(k)
if f == c:
return sorted_data[int(k)]
return sorted_data[f] * (c - k) + sorted_data[c] * (k - f)

View File

@@ -0,0 +1,270 @@
"""
SenpAI Market-Data Consumer — entry point.
Orchestrates:
1. NATS subscription (md.events.>)
2. Event processing → state updates → feature computation
3. Feature/signal/alert publishing back to NATS
4. HTTP API for monitoring
Usage:
python -m senpai.md_consumer
"""
from __future__ import annotations
import asyncio
import logging
import signal
import time
import structlog
from senpai.md_consumer import api
from senpai.md_consumer import metrics as m
from senpai.md_consumer.config import settings
from senpai.md_consumer.features import (
check_signal,
make_feature_snapshot,
compute_features,
)
from senpai.md_consumer.models import (
AlertEvent,
EventType,
TradeEvent,
QuoteEvent,
)
from senpai.md_consumer.nats_consumer import NATSConsumer
from senpai.md_consumer.publisher import Publisher
from senpai.md_consumer.state import LatestState
logger = structlog.get_logger()
# ── Logging setup ──────────────────────────────────────────────────────
def setup_logging() -> None:
log_level = getattr(logging, settings.log_level.upper(), logging.INFO)
structlog.configure(
processors=[
structlog.contextvars.merge_contextvars,
structlog.processors.add_log_level,
structlog.processors.TimeStamper(fmt="iso"),
structlog.dev.ConsoleRenderer(),
],
wrapper_class=structlog.make_filtering_bound_logger(log_level),
context_class=dict,
logger_factory=structlog.PrintLoggerFactory(),
)
logging.basicConfig(level=log_level, format="%(message)s")
# ── Processing pipeline ───────────────────────────────────────────────
async def process_events(
consumer: NATSConsumer,
state: LatestState,
publisher: Publisher,
) -> None:
"""
Main processing loop:
1. Read event from queue
2. Update state
3. Compute features
4. Publish features + check signals
5. Check alerts
"""
last_alert_check = time.monotonic()
events_per_sec_count = 0  # incremented per event below; not yet reported externally
while True:
try:
event = await consumer.queue.get()
except asyncio.CancelledError:
break
proc_start = time.monotonic()
try:
# Update state based on event type
if event.event_type == EventType.TRADE:
assert isinstance(event, TradeEvent)
state.update_trade(event)
symbol = event.symbol
elif event.event_type == EventType.QUOTE:
assert isinstance(event, QuoteEvent)
state.update_quote(event)
symbol = event.symbol
elif event.event_type == EventType.HEARTBEAT:
# Heartbeats don't update state, just track
symbol = None
elif event.event_type == EventType.BOOK_L2:
# TODO: book updates
symbol = None
else:
symbol = None
# Compute features + publish (only for trade/quote events)
if symbol and settings.features_enabled:
snapshot = make_feature_snapshot(state, symbol)
await publisher.publish_features(snapshot)
# Check for trade signal
sig = check_signal(snapshot.features, symbol)
if sig:
await publisher.publish_signal(sig)
# Processing latency metric
proc_ms = (time.monotonic() - proc_start) * 1000
m.PROCESSING_LATENCY.observe(proc_ms)
# Events/sec tracking
events_per_sec_count += 1
except Exception as e:
logger.error(
"process.error",
error=str(e),
event_type=event.event_type.value if event else "?",
)
# Periodic alert checks (every 5 seconds)
now = time.monotonic()
if now - last_alert_check > 5.0:
last_alert_check = now
await _check_alerts(state, publisher, consumer)
async def _check_alerts(
state: LatestState,
publisher: Publisher,
consumer: NATSConsumer,
) -> None:
"""Check alert conditions and emit if needed."""
# Backpressure alert
fill = consumer.queue_fill_ratio
if fill > 0.8:
await publisher.publish_alert(
AlertEvent(
alert_type="backpressure",
level="warning" if fill < 0.95 else "critical",
message=f"Queue fill at {fill:.0%}",
details={"fill_ratio": fill},
)
)
# Latency alert (per symbol)
for sym in state.symbols:
features = compute_features(state, sym)
p95 = features.get("latency_ms_p95")
if p95 is not None and p95 > settings.alert_latency_ms:
await publisher.publish_alert(
AlertEvent(
alert_type="latency",
level="warning",
message=f"{sym} p95 latency {p95:.0f}ms > {settings.alert_latency_ms}ms",
details={"symbol": sym, "p95_ms": p95},
)
)
# ── Main ───────────────────────────────────────────────────────────────
async def main() -> None:
setup_logging()
logger.info("service.starting", nats_url=settings.nats_url)
# State store
state = LatestState(window_seconds=settings.rolling_window_seconds)
# NATS consumer
consumer = NATSConsumer()
await consumer.connect()
await consumer.subscribe()
# Publisher (reuses same NATS connection)
publisher = Publisher(consumer._nc)
# Wire up API
api.set_state(state)
def _get_stats() -> dict:
return {
"queue_size": consumer.queue.qsize(),
"queue_fill_ratio": round(consumer.queue_fill_ratio, 3),
"queue_max": settings.queue_size,
"events_processed": state.event_count,
"symbols_tracked": state.symbols,
"features_enabled": settings.features_enabled,
"nats_connected": bool(consumer._nc and consumer._nc.is_connected),
}
api.set_stats_fn(_get_stats)
# Start HTTP API
http_server = await api.start_api()
# Start processing loop
process_task = asyncio.create_task(
process_events(consumer, state, publisher)
)
# Graceful shutdown
shutdown_event = asyncio.Event()
def _signal_handler():
logger.info("service.shutdown_signal")
shutdown_event.set()
loop = asyncio.get_running_loop()
for sig in (signal.SIGINT, signal.SIGTERM):
try:
loop.add_signal_handler(sig, _signal_handler)
except NotImplementedError:
pass
logger.info(
"service.ready",
subject=settings.nats_subject,
queue_group=settings.nats_queue_group,
http_port=settings.http_port,
features_enabled=settings.features_enabled,
)
# Wait for shutdown
await shutdown_event.wait()
# ── Cleanup ───────────────────────────────────────────────────────
logger.info("service.shutting_down")
process_task.cancel()
try:
await process_task
except asyncio.CancelledError:
pass
await consumer.close()
http_server.close()
await http_server.wait_closed()
logger.info(
"service.stopped",
events_processed=state.event_count,
symbols=state.symbols,
)
def cli():
asyncio.run(main())
if __name__ == "__main__":
cli()

View File

@@ -0,0 +1,72 @@
"""
Prometheus metrics for SenpAI market-data consumer.
"""
from prometheus_client import Counter, Gauge, Histogram
# ── Inbound events ─────────────────────────────────────────────────────
EVENTS_IN = Counter(
"senpai_events_in_total",
"Total events received from NATS",
["event_type", "provider"],
)
EVENTS_DROPPED = Counter(
"senpai_events_dropped_total",
"Events dropped due to backpressure or errors",
["reason", "event_type"],
)
# ── Queue ──────────────────────────────────────────────────────────────
QUEUE_FILL = Gauge(
"senpai_queue_fill_ratio",
"Internal processing queue fill ratio (0..1)",
)
QUEUE_SIZE = Gauge(
"senpai_queue_size",
"Current number of items in processing queue",
)
# ── Processing ─────────────────────────────────────────────────────────
PROCESSING_LATENCY = Histogram(
"senpai_processing_latency_ms",
"End-to-end processing latency (NATS receive to feature publish) in ms",
buckets=[0.1, 0.5, 1, 2, 5, 10, 25, 50, 100, 250],
)
# ── Feature publishing ─────────────────────────────────────────────────
FEATURE_PUBLISH = Counter(
"senpai_feature_publish_total",
"Total feature snapshots published to NATS",
["symbol"],
)
FEATURE_PUBLISH_ERRORS = Counter(
"senpai_feature_publish_errors_total",
"Failed feature publishes",
["symbol"],
)
# ── Signals ────────────────────────────────────────────────────────────
SIGNALS_EMITTED = Counter(
"senpai_signals_emitted_total",
"Trade signals emitted",
["symbol", "direction"],
)
ALERTS_EMITTED = Counter(
"senpai_alerts_emitted_total",
"Alerts emitted",
["alert_type"],
)
# ── NATS connection ───────────────────────────────────────────────────
NATS_CONNECTED = Gauge(
"senpai_nats_connected",
"Whether NATS connection is alive (1=yes, 0=no)",
)
NATS_RECONNECTS = Counter(
"senpai_nats_reconnects_total",
"Number of NATS reconnections",
)

View File

@@ -0,0 +1,139 @@
"""
Domain models — mirrors market-data-service event contracts.
Tolerant parsing: unknown fields ignored, partial data accepted.
"""
from __future__ import annotations
import time
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class EventType(str, Enum):
TRADE = "trade"
QUOTE = "quote"
BOOK_L2 = "book_l2"
HEARTBEAT = "heartbeat"
def _utc_now() -> datetime:
return datetime.now(timezone.utc)
def _mono_ns() -> int:
return time.monotonic_ns()
class BaseEvent(BaseModel, extra="ignore"):
"""Common fields — extra fields silently ignored."""
event_type: EventType
provider: str
ts_recv: datetime = Field(default_factory=_utc_now)
ts_recv_mono_ns: int = Field(default_factory=_mono_ns)
class TradeEvent(BaseEvent):
event_type: EventType = EventType.TRADE
symbol: str
price: float
size: float
ts_exchange: Optional[datetime] = None
side: Optional[str] = None
trade_id: Optional[str] = None
class QuoteEvent(BaseEvent):
event_type: EventType = EventType.QUOTE
symbol: str
bid: float
ask: float
bid_size: float
ask_size: float
ts_exchange: Optional[datetime] = None
class BookLevel(BaseModel, extra="ignore"):
price: float
size: float
class BookL2Event(BaseEvent):
event_type: EventType = EventType.BOOK_L2
symbol: str
bids: list[BookLevel] = Field(default_factory=list)
asks: list[BookLevel] = Field(default_factory=list)
ts_exchange: Optional[datetime] = None
class HeartbeatEvent(BaseEvent):
event_type: EventType = EventType.HEARTBEAT
# Union for parsing
Event = TradeEvent | QuoteEvent | BookL2Event | HeartbeatEvent
# ── Output models ──────────────────────────────────────────────────────
class FeatureSnapshot(BaseModel):
"""Published to senpai.features.{symbol}."""
symbol: str
ts: datetime = Field(default_factory=_utc_now)
features: dict[str, float | None]
class TradeSignal(BaseModel):
"""Published to senpai.signals.{symbol}."""
symbol: str
ts: datetime = Field(default_factory=_utc_now)
direction: str # "long" | "short"
confidence: float = 0.0 # 0..1
reason: str = ""
features: dict[str, float | None] = Field(default_factory=dict)
class AlertEvent(BaseModel):
"""Published to senpai.alerts."""
ts: datetime = Field(default_factory=_utc_now)
level: str = "warning" # "warning" | "critical"
alert_type: str # "latency" | "gap" | "backpressure"
message: str
details: dict = Field(default_factory=dict)
# ── Parsing helper ─────────────────────────────────────────────────────
_EVENT_MAP: dict[str, type[BaseEvent]] = {
"trade": TradeEvent,
"quote": QuoteEvent,
"book_l2": BookL2Event,
"heartbeat": HeartbeatEvent,
}
def parse_event(data: dict) -> Event | None:
"""
Parse a dict (from JSON) into the appropriate Event model.
Returns None if event_type is unknown or data is invalid.
"""
event_type = data.get("event_type")
if not event_type:
return None
cls = _EVENT_MAP.get(event_type)
if cls is None:
return None
try:
return cls.model_validate(data)
except Exception:
return None

View File

@@ -0,0 +1,229 @@
"""
NATS consumer — subscribes to md.events.> and feeds the processing pipeline.
Features:
- Queue group subscription (horizontal scaling)
- Bounded asyncio.Queue with backpressure drop policy
- Auto-reconnect via nats-py
- Optional JetStream durable consumer
"""
from __future__ import annotations
import asyncio
import json
import logging
import nats
from nats.aio.client import Client as NatsClient
from nats.aio.msg import Msg
from senpai.md_consumer.config import settings
from senpai.md_consumer.models import EventType, parse_event, Event
from senpai.md_consumer import metrics as m
logger = logging.getLogger(__name__)
# Events that can be dropped under backpressure (lowest priority first)
_DROPPABLE = {EventType.HEARTBEAT, EventType.QUOTE, EventType.BOOK_L2}
class NATSConsumer:
"""
Reads normalised events from NATS, validates, and puts into
a bounded asyncio.Queue for downstream processing.
Backpressure policy:
- Queue < 90% → accept all events
- Queue >= 90% → drop heartbeats, quotes, book snapshots
- Trades are never dropped by policy (critical for analytics); they can only be lost if the queue is completely full
"""
def __init__(self) -> None:
self._nc: NatsClient | None = None
self._sub = None
self._js_sub = None
self._queue: asyncio.Queue[Event] = asyncio.Queue(
maxsize=settings.queue_size
)
self._running = False
self._drop_count: dict[str, int] = {}
@property
def queue(self) -> asyncio.Queue[Event]:
return self._queue
@property
def queue_fill_ratio(self) -> float:
if settings.queue_size <= 0:
return 0.0
return self._queue.qsize() / settings.queue_size
async def connect(self) -> None:
"""Connect to NATS with auto-reconnect."""
self._nc = await nats.connect(
self._url,
reconnect_time_wait=2,
max_reconnect_attempts=-1,
name="senpai-md-consumer",
error_cb=self._on_error,
disconnected_cb=self._on_disconnected,
reconnected_cb=self._on_reconnected,
closed_cb=self._on_closed,
)
m.NATS_CONNECTED.set(1)
logger.info(
"nats.connected",
extra={"url": self._url, "subject": settings.nats_subject},
)
@property
def _url(self) -> str:
return settings.nats_url
async def subscribe(self) -> None:
"""Subscribe to market data events."""
if not self._nc:
raise RuntimeError("Not connected. Call connect() first.")
if settings.use_jetstream:
await self._subscribe_jetstream()
else:
await self._subscribe_core()
async def _subscribe_core(self) -> None:
"""Core NATS subscription with queue group."""
self._sub = await self._nc.subscribe(
settings.nats_subject,
queue=settings.nats_queue_group,
cb=self._on_message,
)
logger.info(
"nats.subscribed_core",
extra={
"subject": settings.nats_subject,
"queue_group": settings.nats_queue_group,
},
)
async def _subscribe_jetstream(self) -> None:
"""JetStream durable subscription."""
js = self._nc.jetstream()
# Try to create or bind to existing consumer
self._js_sub = await js.subscribe(
settings.nats_subject,
queue=settings.nats_queue_group,
durable="senpai-md-durable",
cb=self._on_message,
manual_ack=True,
)
logger.info(
"nats.subscribed_jetstream",
extra={
"subject": settings.nats_subject,
"durable": "senpai-md-durable",
},
)
async def _on_message(self, msg: Msg) -> None:
"""
Callback for each NATS message.
Parse → backpressure check → enqueue.
"""
try:
data = json.loads(msg.data)
except (json.JSONDecodeError, UnicodeDecodeError) as e:
m.EVENTS_DROPPED.labels(reason="parse_error", event_type="unknown").inc()
logger.warning("nats.parse_error", extra={"error": str(e)})
if settings.use_jetstream:
await msg.ack()
return
event = parse_event(data)
if event is None:
m.EVENTS_DROPPED.labels(reason="invalid_event", event_type="unknown").inc()
if settings.use_jetstream:
await msg.ack()
return
# Track inbound
m.EVENTS_IN.labels(
event_type=event.event_type.value,
provider=event.provider,
).inc()
# Backpressure check
fill = self.queue_fill_ratio
m.QUEUE_FILL.set(fill)
m.QUEUE_SIZE.set(self._queue.qsize())
if fill >= settings.queue_drop_threshold:
# Under pressure: only accept trades
if event.event_type in _DROPPABLE:
et = event.event_type.value
self._drop_count[et] = self._drop_count.get(et, 0) + 1
m.EVENTS_DROPPED.labels(
reason="backpressure",
event_type=et,
).inc()
if self._drop_count[et] % 1000 == 1:
logger.warning(
"nats.backpressure_drop",
extra={
"type": et,
"fill": f"{fill:.0%}",
"total_drops": self._drop_count,
},
)
if settings.use_jetstream:
await msg.ack()
return
# Enqueue
try:
self._queue.put_nowait(event)
except asyncio.QueueFull:
# Queue completely full: even a trade cannot be enqueued; count the drop
m.EVENTS_DROPPED.labels(
reason="queue_full", event_type=event.event_type.value
).inc()
if settings.use_jetstream:
await msg.ack()
async def close(self) -> None:
"""Graceful shutdown."""
self._running = False
if self._sub:
try:
await self._sub.unsubscribe()
except Exception:
pass
if self._nc:
try:
await self._nc.flush(timeout=5)
await self._nc.close()
except Exception:
pass
m.NATS_CONNECTED.set(0)
logger.info("nats.closed", extra={"drops": self._drop_count})
# ── NATS callbacks ────────────────────────────────────────────────
async def _on_error(self, e: Exception) -> None:
logger.error("nats.error", extra={"error": str(e)})
async def _on_disconnected(self) -> None:
m.NATS_CONNECTED.set(0)
logger.warning("nats.disconnected")
async def _on_reconnected(self) -> None:
m.NATS_CONNECTED.set(1)
m.NATS_RECONNECTS.inc()
logger.info("nats.reconnected")
async def _on_closed(self) -> None:
m.NATS_CONNECTED.set(0)
logger.info("nats.closed_callback")

View File

@@ -0,0 +1,119 @@
"""
Signal bus publisher — publishes features, signals, and alerts to NATS.
Rate-limiting: max N publishes per second per symbol (configurable).
"""
from __future__ import annotations
import logging
import time
from nats.aio.client import Client as NatsClient
from senpai.md_consumer.config import settings
from senpai.md_consumer.models import AlertEvent, FeatureSnapshot, TradeSignal
from senpai.md_consumer import metrics as m
logger = logging.getLogger(__name__)
class Publisher:
"""
Publishes FeatureSnapshots and TradeSignals to NATS.
Built-in per-symbol rate limiter.
"""
def __init__(self, nc: NatsClient) -> None:
self._nc = nc
self._last_publish: dict[str, float] = {} # symbol → monotonic time
self._min_interval = (
1.0 / settings.features_pub_rate_hz
if settings.features_pub_rate_hz > 0
else 0.1
)
def _rate_ok(self, symbol: str) -> bool:
"""Check if we can publish for this symbol (rate limiter)."""
now = time.monotonic()
last = self._last_publish.get(symbol, 0.0)
if now - last >= self._min_interval:
self._last_publish[symbol] = now
return True
return False
async def publish_features(self, snapshot: FeatureSnapshot) -> bool:
"""
Publish feature snapshot if rate limit allows.
Returns True if published, False if rate-limited or error.
"""
if not settings.features_enabled:
return False
symbol = snapshot.symbol.upper()
if not self._rate_ok(symbol):
return False
subject = f"{settings.features_pub_subject}.{symbol}"
try:
payload = snapshot.model_dump_json().encode("utf-8")
await self._nc.publish(subject, payload)
m.FEATURE_PUBLISH.labels(symbol=symbol).inc()
return True
except Exception as e:
m.FEATURE_PUBLISH_ERRORS.labels(symbol=symbol).inc()
logger.warning(
"publisher.feature_error",
extra={"symbol": symbol, "error": str(e)},
)
return False
async def publish_signal(self, signal: TradeSignal) -> bool:
"""Publish trade signal (no rate limit — signals are rare)."""
subject = f"{settings.signals_pub_subject}.{signal.symbol}"
try:
payload = signal.model_dump_json().encode("utf-8")
await self._nc.publish(subject, payload)
m.SIGNALS_EMITTED.labels(
symbol=signal.symbol,
direction=signal.direction,
).inc()
logger.info(
"publisher.signal_emitted",
extra={
"symbol": signal.symbol,
"direction": signal.direction,
"confidence": f"{signal.confidence:.2f}",
"reason": signal.reason,
},
)
return True
except Exception as e:
logger.error(
"publisher.signal_error",
extra={"symbol": signal.symbol, "error": str(e)},
)
return False
async def publish_alert(self, alert: AlertEvent) -> bool:
"""Publish alert event."""
subject = settings.alerts_pub_subject
try:
payload = alert.model_dump_json().encode("utf-8")
await self._nc.publish(subject, payload)
m.ALERTS_EMITTED.labels(alert_type=alert.alert_type).inc()
logger.warning(
"publisher.alert",
extra={
"type": alert.alert_type,
"level": alert.level,
"message": alert.message,
},
)
return True
except Exception as e:
logger.error(
"publisher.alert_error",
extra={"error": str(e)},
)
return False

View File

@@ -0,0 +1,238 @@
"""
State management — LatestState + RollingWindow.
All structures are asyncio-safe (no locks needed — single-threaded event loop).
"""
from __future__ import annotations
import time
from collections import deque
from dataclasses import dataclass
from datetime import datetime
from typing import Optional
from senpai.md_consumer.models import QuoteEvent, TradeEvent
@dataclass
class LatestTrade:
symbol: str
price: float
size: float
side: Optional[str]
provider: str
ts_recv: datetime
ts_exchange: Optional[datetime] = None
@dataclass
class LatestQuote:
symbol: str
bid: float
ask: float
bid_size: float
ask_size: float
provider: str
ts_recv: datetime
ts_exchange: Optional[datetime] = None
@dataclass
class TradeRecord:
"""Compact trade record for rolling window."""
price: float
size: float
ts: float # monotonic seconds
ts_exchange: Optional[datetime] = None
ts_recv: Optional[datetime] = None
@dataclass
class QuoteRecord:
"""Compact quote record for rolling window."""
bid: float
ask: float
bid_size: float
ask_size: float
ts: float # monotonic seconds
ts_exchange: Optional[datetime] = None
ts_recv: Optional[datetime] = None
class RollingWindow:
"""
Fixed-duration rolling window using deque.
Efficient: O(1) append, amortised O(1) eviction.
No pandas dependency.
"""
def __init__(self, window_seconds: float = 60.0) -> None:
self._window = window_seconds
self._trades: deque[TradeRecord] = deque()
self._quotes: deque[QuoteRecord] = deque()
def add_trade(self, trade: TradeRecord) -> None:
self._trades.append(trade)
self._evict_trades()
def add_quote(self, quote: QuoteRecord) -> None:
self._quotes.append(quote)
self._evict_quotes()
def _evict_trades(self) -> None:
cutoff = time.monotonic() - self._window
while self._trades and self._trades[0].ts < cutoff:
self._trades.popleft()
def _evict_quotes(self) -> None:
cutoff = time.monotonic() - self._window
while self._quotes and self._quotes[0].ts < cutoff:
self._quotes.popleft()
@property
def trades(self) -> deque[TradeRecord]:
self._evict_trades()
return self._trades
@property
def quotes(self) -> deque[QuoteRecord]:
self._evict_quotes()
return self._quotes
def trades_since(self, seconds_ago: float) -> list[TradeRecord]:
"""Return trades within the last N seconds."""
cutoff = time.monotonic() - seconds_ago
return [t for t in self._trades if t.ts >= cutoff]
def quotes_since(self, seconds_ago: float) -> list[QuoteRecord]:
"""Return quotes within the last N seconds."""
cutoff = time.monotonic() - seconds_ago
return [q for q in self._quotes if q.ts >= cutoff]
class LatestState:
"""
Maintains latest trade/quote per symbol + rolling windows.
"""
def __init__(self, window_seconds: float = 60.0) -> None:
self._window_seconds = window_seconds
self._latest_trade: dict[str, LatestTrade] = {}
self._latest_quote: dict[str, LatestQuote] = {}
self._windows: dict[str, RollingWindow] = {}
self._event_count = 0
def _get_window(self, symbol: str) -> RollingWindow:
if symbol not in self._windows:
self._windows[symbol] = RollingWindow(self._window_seconds)
return self._windows[symbol]
def update_trade(self, event: TradeEvent) -> None:
"""Update latest trade and rolling window."""
sym = event.symbol.upper()
self._latest_trade[sym] = LatestTrade(
symbol=sym,
price=event.price,
size=event.size,
side=event.side,
provider=event.provider,
ts_recv=event.ts_recv,
ts_exchange=event.ts_exchange,
)
self._get_window(sym).add_trade(
TradeRecord(
price=event.price,
size=event.size,
ts=time.monotonic(),
ts_exchange=event.ts_exchange,
ts_recv=event.ts_recv,
)
)
self._event_count += 1
def update_quote(self, event: QuoteEvent) -> None:
"""Update latest quote and rolling window."""
sym = event.symbol.upper()
self._latest_quote[sym] = LatestQuote(
symbol=sym,
bid=event.bid,
ask=event.ask,
bid_size=event.bid_size,
ask_size=event.ask_size,
provider=event.provider,
ts_recv=event.ts_recv,
ts_exchange=event.ts_exchange,
)
self._get_window(sym).add_quote(
QuoteRecord(
bid=event.bid,
ask=event.ask,
bid_size=event.bid_size,
ask_size=event.ask_size,
ts=time.monotonic(),
ts_exchange=event.ts_exchange,
ts_recv=event.ts_recv,
)
)
self._event_count += 1
def get_latest_trade(self, symbol: str) -> LatestTrade | None:
return self._latest_trade.get(symbol.upper())
def get_latest_quote(self, symbol: str) -> LatestQuote | None:
return self._latest_quote.get(symbol.upper())
def get_window(self, symbol: str) -> RollingWindow | None:
return self._windows.get(symbol.upper())
@property
def symbols(self) -> list[str]:
return sorted(
set(list(self._latest_trade.keys()) + list(self._latest_quote.keys()))
)
@property
def event_count(self) -> int:
return self._event_count
def to_dict(self, symbol: str) -> dict:
"""Serialise latest state for API."""
sym = symbol.upper()
result: dict = {"symbol": sym}
trade = self._latest_trade.get(sym)
if trade:
result["latest_trade"] = {
"price": trade.price,
"size": trade.size,
"side": trade.side,
"provider": trade.provider,
"ts_recv": trade.ts_recv.isoformat() if trade.ts_recv else None,
}
quote = self._latest_quote.get(sym)
if quote:
result["latest_quote"] = {
"bid": quote.bid,
"ask": quote.ask,
"bid_size": quote.bid_size,
"ask_size": quote.ask_size,
"provider": quote.provider,
"ts_recv": quote.ts_recv.isoformat() if quote.ts_recv else None,
}
window = self._windows.get(sym)
if window:
result["window"] = {
"trades_count": len(window.trades),
"quotes_count": len(window.quotes),
}
return result

View File

@@ -0,0 +1,212 @@
"""
Test feature computations — deterministic scenarios.
"""
import pytest
from senpai.md_consumer.features import (
_percentile,
_realized_vol,
_vwap,
check_signal,
compute_features,
)
from senpai.md_consumer.models import QuoteEvent, TradeEvent
from senpai.md_consumer.state import LatestState, TradeRecord
# ── VWAP ───────────────────────────────────────────────────────────────
def test_vwap_basic():
trades = [
TradeRecord(price=100.0, size=10.0, ts=0),
TradeRecord(price=200.0, size=10.0, ts=0),
]
# VWAP = (100*10 + 200*10) / (10+10) = 150
assert _vwap(trades) == 150.0
def test_vwap_weighted():
trades = [
TradeRecord(price=100.0, size=90.0, ts=0),
TradeRecord(price=200.0, size=10.0, ts=0),
]
# VWAP = (100*90 + 200*10) / 100 = 110
assert _vwap(trades) == 110.0
def test_vwap_empty():
assert _vwap([]) is None
def test_vwap_zero_volume():
trades = [TradeRecord(price=100.0, size=0.0, ts=0)]
assert _vwap(trades) is None
# ── Realized volatility ───────────────────────────────────────────────
def test_realized_vol_constant_price():
"""Constant price → 0 volatility."""
trades = [TradeRecord(price=100.0, size=1.0, ts=0) for _ in range(10)]
vol = _realized_vol(trades)
assert vol is not None
assert vol == 0.0
def test_realized_vol_two_prices():
"""Not enough data points → None."""
trades = [
TradeRecord(price=100.0, size=1.0, ts=0),
TradeRecord(price=101.0, size=1.0, ts=0),
]
assert _realized_vol(trades) is None # needs at least 3
def test_realized_vol_positive():
"""Variable prices should give positive volatility."""
trades = [
TradeRecord(price=100.0, size=1.0, ts=0),
TradeRecord(price=102.0, size=1.0, ts=0),
TradeRecord(price=99.0, size=1.0, ts=0),
TradeRecord(price=103.0, size=1.0, ts=0),
]
vol = _realized_vol(trades)
assert vol is not None
assert vol > 0
# ── Percentile ─────────────────────────────────────────────────────────
def test_percentile_basic():
data = [1.0, 2.0, 3.0, 4.0, 5.0]
assert _percentile(data, 50) == 3.0
assert _percentile(data, 0) == 1.0
assert _percentile(data, 100) == 5.0
def test_percentile_p95():
data = list(range(1, 101)) # 1..100
data_float = [float(x) for x in data]
p95 = _percentile(data_float, 95)
assert 95 <= p95 <= 96
# ── Full feature computation ──────────────────────────────────────────
def test_compute_features_with_state():
state = LatestState(window_seconds=60.0)
# Add quote
state.update_quote(QuoteEvent(
provider="binance",
symbol="BTCUSDT",
bid=70000.0,
ask=70002.0,
bid_size=5.0,
ask_size=3.0,
))
# Add some trades
for i in range(5):
state.update_trade(TradeEvent(
provider="binance",
symbol="BTCUSDT",
price=70000.0 + i * 10,
size=1.0,
))
features = compute_features(state, "BTCUSDT")
# Mid
assert features["mid"] == pytest.approx(70001.0)
# Spread
assert features["spread_abs"] == pytest.approx(2.0)
assert features["spread_bps"] is not None
assert features["spread_bps"] > 0
# Trade count
assert features["trade_count_10s"] == 5.0
# Volume
assert features["trade_volume_10s"] == 5.0
# VWAP should be defined
assert features["trade_vwap_10s"] is not None
assert features["trade_vwap_60s"] is not None
def test_compute_features_no_data():
state = LatestState(window_seconds=60.0)
features = compute_features(state, "BTCUSDT")
# All should be None
assert features["mid"] is None
assert features["spread_abs"] is None
assert features["trade_vwap_10s"] is None
# ── Signal detection ──────────────────────────────────────────────────
def test_check_signal_long():
"""Strong positive return + volume + tight spread → long signal."""
features = {
"return_10s": 0.005, # 0.5% (> 0.3% threshold)
"trade_volume_10s": 5.0, # > 1.0 threshold
"spread_bps": 3.0, # < 20 bps threshold
}
signal = check_signal(features, "BTCUSDT")
assert signal is not None
assert signal.direction == "long"
assert signal.confidence > 0
def test_check_signal_short():
"""Strong negative return → short signal."""
features = {
"return_10s": -0.005,
"trade_volume_10s": 5.0,
"spread_bps": 3.0,
}
signal = check_signal(features, "BTCUSDT")
assert signal is not None
assert signal.direction == "short"
def test_check_signal_no_trigger():
"""Small return → no signal."""
features = {
"return_10s": 0.0001,
"trade_volume_10s": 5.0,
"spread_bps": 3.0,
}
signal = check_signal(features, "BTCUSDT")
assert signal is None
def test_check_signal_wide_spread():
"""Wide spread → no signal (even with strong return)."""
features = {
"return_10s": 0.01,
"trade_volume_10s": 5.0,
"spread_bps": 50.0, # > 20 bps
}
signal = check_signal(features, "BTCUSDT")
assert signal is None
def test_check_signal_low_volume():
"""Low volume → no signal."""
features = {
"return_10s": 0.01,
"trade_volume_10s": 0.1, # < 1.0
"spread_bps": 3.0,
}
signal = check_signal(features, "BTCUSDT")
assert signal is None

View File

@@ -0,0 +1,154 @@
"""
Test event parsing from JSON payloads (mirrors market-data-service contracts).
"""
import json
from senpai.md_consumer.models import (
EventType,
TradeEvent,
QuoteEvent,
HeartbeatEvent,
parse_event,
)
# ── Trade events ───────────────────────────────────────────────────────
def test_parse_trade_basic():
data = {
"event_type": "trade",
"provider": "binance",
"symbol": "BTCUSDT",
"price": 70500.0,
"size": 1.5,
"ts_recv": "2026-02-09T12:00:00+00:00",
}
event = parse_event(data)
assert event is not None
assert isinstance(event, TradeEvent)
assert event.event_type == EventType.TRADE
assert event.symbol == "BTCUSDT"
assert event.price == 70500.0
assert event.size == 1.5
assert event.provider == "binance"
def test_parse_trade_with_extra_fields():
"""Unknown fields should be silently ignored (tolerant parsing)."""
data = {
"event_type": "trade",
"provider": "bybit",
"symbol": "ETHUSDT",
"price": 2100.0,
"size": 10.0,
"ts_recv": "2026-02-09T12:00:00+00:00",
"unknown_field": "should_be_ignored",
"another_extra": 42,
}
event = parse_event(data)
assert event is not None
assert event.symbol == "ETHUSDT"
def test_parse_trade_with_side_and_exchange_ts():
data = {
"event_type": "trade",
"provider": "binance",
"symbol": "BTCUSDT",
"price": 70000.0,
"size": 0.5,
"side": "buy",
"ts_exchange": "2026-02-09T12:00:00+00:00",
"ts_recv": "2026-02-09T12:00:00.100+00:00",
"trade_id": "t12345",
}
event = parse_event(data)
assert event.side == "buy"
assert event.trade_id == "t12345"
assert event.ts_exchange is not None
# ── Quote events ───────────────────────────────────────────────────────
def test_parse_quote_basic():
data = {
"event_type": "quote",
"provider": "binance",
"symbol": "BTCUSDT",
"bid": 70000.0,
"ask": 70001.0,
"bid_size": 5.0,
"ask_size": 3.0,
"ts_recv": "2026-02-09T12:00:00+00:00",
}
event = parse_event(data)
assert isinstance(event, QuoteEvent)
assert event.bid == 70000.0
assert event.ask == 70001.0
def test_parse_quote_zero_values():
data = {
"event_type": "quote",
"provider": "binance",
"symbol": "BTCUSDT",
"bid": 0.0,
"ask": 0.0,
"bid_size": 0.0,
"ask_size": 0.0,
}
event = parse_event(data)
assert event is not None
assert event.bid == 0.0
# ── Heartbeat events ──────────────────────────────────────────────────
def test_parse_heartbeat():
data = {
"event_type": "heartbeat",
"provider": "alpaca",
"ts_recv": "2026-02-09T12:00:00+00:00",
}
event = parse_event(data)
assert isinstance(event, HeartbeatEvent)
assert event.provider == "alpaca"
# ── Edge cases ─────────────────────────────────────────────────────────
def test_parse_unknown_type():
data = {"event_type": "unknown_type", "provider": "test"}
event = parse_event(data)
assert event is None
def test_parse_missing_type():
data = {"provider": "test", "symbol": "BTC"}
event = parse_event(data)
assert event is None
def test_parse_invalid_data():
data = {"event_type": "trade"} # missing required fields
event = parse_event(data)
assert event is None
def test_parse_empty_dict():
event = parse_event({})
assert event is None
def test_parse_from_json_bytes():
"""Simulate actual NATS message deserialization."""
raw = b'{"event_type":"trade","provider":"binance","symbol":"BTCUSDT","price":70500.0,"size":1.5}'
data = json.loads(raw)
event = parse_event(data)
assert event is not None
assert event.price == 70500.0

View File

@@ -0,0 +1,111 @@
"""
Test publisher rate limiting.
"""
from unittest.mock import AsyncMock
import pytest
from senpai.md_consumer.publisher import Publisher
from senpai.md_consumer.models import FeatureSnapshot, TradeSignal
@pytest.fixture
def mock_nc():
"""Mock NATS client."""
nc = AsyncMock()
nc.publish = AsyncMock()
return nc
@pytest.fixture
def publisher(mock_nc):
return Publisher(mock_nc)
@pytest.mark.asyncio
async def test_publish_features_respects_rate_limit(mock_nc, publisher):
"""Second publish for same symbol within rate window should be skipped."""
snapshot = FeatureSnapshot(
symbol="BTCUSDT",
features={"mid": 70000.0},
)
# First publish should succeed
result1 = await publisher.publish_features(snapshot)
assert result1 is True
# Immediate second publish should be rate-limited
result2 = await publisher.publish_features(snapshot)
assert result2 is False # rate-limited
# Only one actual NATS publish
assert mock_nc.publish.call_count == 1
@pytest.mark.asyncio
async def test_publish_features_different_symbols(mock_nc, publisher):
"""Different symbols have independent rate limiters."""
snap1 = FeatureSnapshot(symbol="BTCUSDT", features={"mid": 70000.0})
snap2 = FeatureSnapshot(symbol="ETHUSDT", features={"mid": 2000.0})
r1 = await publisher.publish_features(snap1)
r2 = await publisher.publish_features(snap2)
assert r1 is True
assert r2 is True
assert mock_nc.publish.call_count == 2
@pytest.mark.asyncio
async def test_publish_signal_no_rate_limit(mock_nc, publisher):
"""Signals are NOT rate limited."""
signal = TradeSignal(
symbol="BTCUSDT",
direction="long",
confidence=0.8,
reason="test",
)
r1 = await publisher.publish_signal(signal)
r2 = await publisher.publish_signal(signal)
assert r1 is True
assert r2 is True
assert mock_nc.publish.call_count == 2
@pytest.mark.asyncio
async def test_publish_features_after_rate_window(mock_nc, publisher):
"""After rate window passes, publish should succeed again."""
# Override min interval to something very small for testing
publisher._min_interval = 0.01 # 10ms
snapshot = FeatureSnapshot(
symbol="BTCUSDT",
features={"mid": 70000.0},
)
r1 = await publisher.publish_features(snapshot)
assert r1 is True
# Wait for rate window to pass
import asyncio
await asyncio.sleep(0.02)
r2 = await publisher.publish_features(snapshot)
assert r2 is True
assert mock_nc.publish.call_count == 2
@pytest.mark.asyncio
async def test_publish_handles_nats_error(mock_nc, publisher):
"""NATS publish error should not raise, just return False."""
mock_nc.publish.side_effect = Exception("NATS down")
snapshot = FeatureSnapshot(
symbol="BTCUSDT",
features={"mid": 70000.0},
)
result = await publisher.publish_features(snapshot)
assert result is False

View File

@@ -0,0 +1,138 @@
"""
Test state management — LatestState and RollingWindow.
"""
import time
from senpai.md_consumer.state import (
LatestState,
RollingWindow,
TradeRecord,
)
from senpai.md_consumer.models import TradeEvent, QuoteEvent
# ── RollingWindow ──────────────────────────────────────────────────────
def test_rolling_window_add_trade():
w = RollingWindow(window_seconds=60.0)
t = TradeRecord(price=100.0, size=1.0, ts=time.monotonic())
w.add_trade(t)
assert len(w.trades) == 1
assert w.trades[0].price == 100.0
def test_rolling_window_eviction():
"""Old records should be evicted."""
w = RollingWindow(window_seconds=1.0) # 1 second window
old_ts = time.monotonic() - 2.0 # 2 seconds ago
w.add_trade(TradeRecord(price=100.0, size=1.0, ts=old_ts))
w.add_trade(TradeRecord(price=200.0, size=2.0, ts=time.monotonic()))
# Old record should be evicted
trades = list(w.trades)
assert len(trades) == 1
assert trades[0].price == 200.0
def test_rolling_window_trades_since():
w = RollingWindow(window_seconds=60.0)
now = time.monotonic()
# Add trades at different times
w.add_trade(TradeRecord(price=100.0, size=1.0, ts=now - 30)) # 30s ago
w.add_trade(TradeRecord(price=200.0, size=2.0, ts=now - 5)) # 5s ago
w.add_trade(TradeRecord(price=300.0, size=3.0, ts=now)) # now
last_10s = w.trades_since(10.0)
assert len(last_10s) == 2 # 5s ago + now
assert last_10s[0].price == 200.0
def test_rolling_window_empty():
w = RollingWindow(window_seconds=60.0)
assert len(w.trades) == 0
assert len(w.quotes) == 0
assert w.trades_since(10.0) == []
# ── LatestState ────────────────────────────────────────────────────────
def test_latest_state_update_trade():
state = LatestState(window_seconds=60.0)
event = TradeEvent(
provider="binance",
symbol="BTCUSDT",
price=70500.0,
size=1.5,
side="buy",
)
state.update_trade(event)
latest = state.get_latest_trade("BTCUSDT")
assert latest is not None
assert latest.price == 70500.0
assert latest.side == "buy"
assert state.event_count == 1
def test_latest_state_update_quote():
state = LatestState(window_seconds=60.0)
event = QuoteEvent(
provider="binance",
symbol="BTCUSDT",
bid=70000.0,
ask=70001.0,
bid_size=5.0,
ask_size=3.0,
)
state.update_quote(event)
latest = state.get_latest_quote("BTCUSDT")
assert latest is not None
assert latest.bid == 70000.0
assert latest.ask == 70001.0
def test_latest_state_symbols():
state = LatestState(window_seconds=60.0)
state.update_trade(TradeEvent(
provider="binance", symbol="BTCUSDT", price=100.0, size=1.0
))
state.update_quote(QuoteEvent(
provider="binance", symbol="ETHUSDT",
bid=2000.0, ask=2001.0, bid_size=1.0, ask_size=1.0,
))
assert "BTCUSDT" in state.symbols
assert "ETHUSDT" in state.symbols
def test_latest_state_to_dict():
state = LatestState(window_seconds=60.0)
state.update_trade(TradeEvent(
provider="binance", symbol="BTCUSDT", price=70500.0, size=1.0
))
state.update_quote(QuoteEvent(
provider="binance", symbol="BTCUSDT",
bid=70000.0, ask=70001.0, bid_size=1.0, ask_size=1.0,
))
d = state.to_dict("BTCUSDT")
assert d["symbol"] == "BTCUSDT"
assert "latest_trade" in d
assert "latest_quote" in d
assert d["latest_trade"]["price"] == 70500.0
def test_latest_state_missing_symbol():
state = LatestState(window_seconds=60.0)
assert state.get_latest_trade("NOPE") is None
assert state.get_latest_quote("NOPE") is None