feat: MD pipeline — market-data-service hardening + SenpAI NATS consumer
Producer (market-data-service):
- Backpressure: smart drop policy (heartbeats→quotes→trades preserved)
- Heartbeat monitor: synthetic HeartbeatEvent on provider silence
- Graceful shutdown: WS→bus→storage→DB engine cleanup sequence
- Bybit V5 public WS provider (backup for Binance, no API key needed)
- FailoverManager: health-based provider switching with recovery
- NATS output adapter: md.events.{type}.{symbol} for SenpAI
- /bus-stats endpoint for backpressure monitoring
- Dockerfile + docker-compose.node1.yml integration
- 36 tests (parsing + bus + failover), requirements.lock
Consumer (senpai-md-consumer):
- NATSConsumer: subscribe md.events.>, queue group senpai-md, backpressure
- State store: LatestState + RollingWindow (deque, 60s)
- Feature engine: 11 features (mid, spread, VWAP, return, vol, latency)
- Rule-based signals: long/short on return+volume+spread conditions
- Publisher: rate-limited features + signals + alerts to NATS
- HTTP API: /health, /metrics, /state/latest, /features/latest, /stats
- 10 Prometheus metrics
- Dockerfile + docker-compose.senpai.yml
- 41 tests (parsing + state + features + rate-limit), requirements.lock
CI: ruff + pytest + smoke import for both services
Tests: 77 total passed, lint clean
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -19,6 +19,9 @@ class Settings(BaseSettings):
|
||||
binance_ws_url: str = "wss://stream.binance.com:9443/ws"
|
||||
binance_rest_url: str = "https://api.binance.com"
|
||||
|
||||
# ── Bybit (backup crypto — no key needed) ──────────────────────────
|
||||
bybit_ws_url: str = "wss://stream.bybit.com/v5/public/spot"
|
||||
|
||||
# ── Alpaca (paper trading — free tier) ─────────────────────────────
|
||||
alpaca_key: str = ""
|
||||
alpaca_secret: str = ""
|
||||
@@ -41,6 +44,11 @@ class Settings(BaseSettings):
|
||||
http_port: int = 8891
|
||||
metrics_enabled: bool = True
|
||||
|
||||
# ── NATS output adapter ─────────────────────────────────────────────
|
||||
nats_url: str = "" # e.g. "nats://localhost:4222"
|
||||
nats_subject_prefix: str = "md.events" # → md.events.trade.BTCUSDT
|
||||
nats_enabled: bool = False
|
||||
|
||||
# ── Logging ────────────────────────────────────────────────────────
|
||||
log_level: str = "INFO"
|
||||
log_sample_rate: int = 100 # PrintConsumer: log 1 out of N events
|
||||
@@ -49,5 +57,9 @@ class Settings(BaseSettings):
|
||||
def alpaca_configured(self) -> bool:
|
||||
return bool(self.alpaca_key and self.alpaca_secret)
|
||||
|
||||
@property
|
||||
def nats_configured(self) -> bool:
|
||||
return bool(self.nats_url and self.nats_enabled)
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
133
services/market-data-service/app/consumers/nats_output.py
Normal file
133
services/market-data-service/app/consumers/nats_output.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
NATS output adapter — pushes normalised events to NATS subjects.
|
||||
|
||||
Subject schema:
|
||||
{prefix}.{event_type}.{symbol}
|
||||
e.g. md.events.trade.BTCUSDT
|
||||
md.events.quote.AAPL
|
||||
md.events.heartbeat.__system__
|
||||
|
||||
SenpAI (or any other consumer) can subscribe to:
|
||||
md.events.> — all events
|
||||
md.events.trade.> — all trades
|
||||
md.events.*.BTCUSDT — all event types for BTC
|
||||
|
||||
Payload: JSON (event.model_dump_json())
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.config import settings
|
||||
from app.domain.events import Event
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy import — nats-py may not be installed in minimal setups
|
||||
_nc = None
|
||||
|
||||
|
||||
class NatsOutputConsumer:
    """
    Bus consumer that mirrors every event onto NATS as a JSON payload.

    The subject is built as ``{prefix}.{event_type}.{symbol}``.
    Reconnects are handled by nats-py's own retry machinery; while the
    broker is unreachable, events are dropped (counted, never blocking).
    """

    def __init__(
        self,
        nats_url: str | None = None,
        subject_prefix: str | None = None,
    ) -> None:
        # Explicit constructor args take precedence over global settings.
        self._url = nats_url or settings.nats_url
        self._prefix = subject_prefix or settings.nats_subject_prefix
        self._nc = None           # NATS client, set by start()
        self._connected = False   # tracked via connection callbacks
        self._publish_count = 0   # events successfully published
        self._drop_count = 0      # events lost (offline or publish error)

    async def start(self) -> None:
        """Connect to NATS."""
        try:
            import nats  # noqa: F811

            # Connection resilience is delegated to nats-py: retry
            # forever with a 2-second pause between attempts.
            options = {
                "reconnect_time_wait": 2,
                "max_reconnect_attempts": -1,  # infinite
                "name": "market-data-service",
                "error_cb": self._error_cb,
                "disconnected_cb": self._disconnected_cb,
                "reconnected_cb": self._reconnected_cb,
            }
            self._nc = await nats.connect(self._url, **options)
            self._connected = True
            logger.info(
                "nats_output.connected",
                extra={"url": self._url, "prefix": self._prefix},
            )
        except ImportError:
            logger.error(
                "nats_output.nats_not_installed",
                extra={"hint": "pip install nats-py"},
            )
        except Exception as exc:
            logger.error(
                "nats_output.connect_failed",
                extra={"url": self._url, "error": str(exc)},
            )

    async def handle(self, event: Event) -> None:
        """Publish event to NATS subject."""
        if not (self._nc and self._connected):
            # Offline: count the loss and return immediately, never block.
            self._drop_count += 1
            return

        # System-level events (e.g. heartbeats) carry no symbol attribute.
        sym = getattr(event, "symbol", "__system__")
        subject = ".".join((self._prefix, event.event_type.value, sym))

        try:
            wire = event.model_dump_json().encode("utf-8")
            await self._nc.publish(subject, wire)
        except Exception as exc:
            self._drop_count += 1
            # Log only every 1000th failure to avoid flooding the logs.
            if self._drop_count % 1000 == 1:
                logger.warning(
                    "nats_output.publish_failed",
                    extra={
                        "subject": subject,
                        "error": str(exc),
                        "total_dropped": self._drop_count,
                    },
                )
        else:
            self._publish_count += 1

    async def stop(self) -> None:
        """Flush and close NATS connection."""
        if self._nc:
            try:
                await self._nc.flush(timeout=5)
                await self._nc.close()
            except Exception as exc:
                logger.warning("nats_output.close_error", extra={"error": str(exc)})

        logger.info(
            "nats_output.stopped",
            extra={
                "published": self._publish_count,
                "dropped": self._drop_count,
            },
        )

    # ── NATS callbacks ────────────────────────────────────────────────

    async def _error_cb(self, e: Exception) -> None:
        logger.error("nats_output.error", extra={"error": str(e)})

    async def _disconnected_cb(self) -> None:
        self._connected = False
        logger.warning("nats_output.disconnected")

    async def _reconnected_cb(self) -> None:
        self._connected = True
        logger.info("nats_output.reconnected")
|
||||
@@ -3,7 +3,6 @@ StorageConsumer: persists events to SQLite + JSONL log.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
"""
|
||||
Async event bus — fan-out from providers to consumers.
|
||||
|
||||
Features:
|
||||
- Backpressure with smart drop policy (drop quotes before trades)
|
||||
- Heartbeat timer per provider (detects dead channels)
|
||||
- Graceful drain on shutdown
|
||||
|
||||
Usage:
|
||||
bus = EventBus()
|
||||
bus.add_consumer(storage_consumer)
|
||||
@@ -13,9 +18,10 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Protocol
|
||||
|
||||
from app.domain.events import Event
|
||||
from app.domain.events import Event, EventType, HeartbeatEvent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,37 +32,105 @@ class EventConsumer(Protocol):
|
||||
async def handle(self, event: Event) -> None: ...
|
||||
|
||||
|
||||
# Events that can be dropped under backpressure (least critical first)
|
||||
_DROPPABLE_PRIORITY = {
|
||||
EventType.HEARTBEAT: 0, # always droppable
|
||||
EventType.QUOTE: 1, # drop quotes before trades
|
||||
EventType.BOOK_L2: 2, # drop book snapshots before trades
|
||||
EventType.TRADE: 3, # trades are most critical — last to drop
|
||||
}
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""
|
||||
Simple async fan-out bus.
|
||||
Async fan-out bus with backpressure and heartbeat monitoring.
|
||||
|
||||
Every published event is dispatched to all registered consumers
|
||||
concurrently (gather). A slow consumer doesn't block others thanks
|
||||
to the internal queue + worker pattern.
|
||||
Backpressure policy:
|
||||
- Queue 80% full → start dropping HEARTBEAT events
|
||||
- Queue 90% full → also drop QUOTE events
|
||||
- Queue 100% full → drop oldest (any type)
|
||||
|
||||
Heartbeat timer:
|
||||
- Emits synthetic HeartbeatEvent if a provider sends nothing
|
||||
for `heartbeat_interval` seconds, making dead channels visible.
|
||||
"""
|
||||
|
||||
def __init__(self, queue_size: int = 10_000) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
queue_size: int = 10_000,
|
||||
heartbeat_interval: float = 10.0,
|
||||
) -> None:
|
||||
self._consumers: list[EventConsumer] = []
|
||||
self._queue: asyncio.Queue[Event | None] = asyncio.Queue(maxsize=queue_size)
|
||||
self._max_size = queue_size
|
||||
self._running = False
|
||||
self._task: asyncio.Task | None = None
|
||||
self._heartbeat_interval = heartbeat_interval
|
||||
self._heartbeat_tasks: dict[str, asyncio.Task] = {}
|
||||
self._provider_last_seen: dict[str, float] = {}
|
||||
# Backpressure counters
|
||||
self._dropped: dict[str, int] = {}
|
||||
|
||||
def add_consumer(self, consumer: EventConsumer) -> None:
|
||||
self._consumers.append(consumer)
|
||||
logger.info("bus.consumer_added", extra={"consumer": type(consumer).__name__})
|
||||
|
||||
def register_provider(self, provider_name: str) -> None:
|
||||
"""Register a provider for heartbeat monitoring."""
|
||||
self._provider_last_seen[provider_name] = time.monotonic()
|
||||
|
||||
async def publish(self, event: Event) -> None:
|
||||
"""Put event into internal queue (non-blocking if queue not full)."""
|
||||
"""
|
||||
Put event into internal queue with backpressure.
|
||||
|
||||
Drop policy under pressure:
|
||||
- 80%+ → drop heartbeats
|
||||
- 90%+ → drop quotes/book snapshots
|
||||
- 100% → drop oldest event
|
||||
"""
|
||||
current = self._queue.qsize()
|
||||
fill_pct = current / self._max_size if self._max_size > 0 else 0
|
||||
|
||||
# Track provider activity for heartbeat timer
|
||||
self._provider_last_seen[event.provider] = time.monotonic()
|
||||
|
||||
priority = _DROPPABLE_PRIORITY.get(event.event_type, 3)
|
||||
|
||||
# Backpressure: drop low-priority events when queue is filling up
|
||||
if fill_pct >= 0.9 and priority <= 1:
|
||||
# Drop heartbeats and quotes
|
||||
self._dropped[event.event_type.value] = self._dropped.get(event.event_type.value, 0) + 1
|
||||
if self._dropped[event.event_type.value] % 1000 == 1:
|
||||
logger.warning(
|
||||
"bus.backpressure_drop",
|
||||
extra={
|
||||
"type": event.event_type.value,
|
||||
"fill_pct": f"{fill_pct:.0%}",
|
||||
"total_dropped": self._dropped,
|
||||
},
|
||||
)
|
||||
return
|
||||
|
||||
if fill_pct >= 0.8 and priority == 0:
|
||||
# Drop heartbeats only
|
||||
return
|
||||
|
||||
try:
|
||||
self._queue.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning("bus.queue_full, dropping oldest event")
|
||||
# Drop oldest to keep queue moving
|
||||
# Last resort: drop oldest to make room
|
||||
try:
|
||||
self._queue.get_nowait()
|
||||
dropped = self._queue.get_nowait()
|
||||
logger.warning(
|
||||
"bus.queue_full_drop_oldest",
|
||||
extra={"dropped_type": dropped.event_type.value if dropped else "None"},
|
||||
)
|
||||
except asyncio.QueueEmpty:
|
||||
pass
|
||||
self._queue.put_nowait(event)
|
||||
try:
|
||||
self._queue.put_nowait(event)
|
||||
except asyncio.QueueFull:
|
||||
pass # truly stuck
|
||||
|
||||
async def _worker(self) -> None:
|
||||
"""Background worker that drains the queue and fans out."""
|
||||
@@ -75,20 +149,79 @@ class EventBus:
|
||||
extra={"consumer": consumer_name, "error": str(result)},
|
||||
)
|
||||
|
||||
async def _heartbeat_monitor(self, provider_name: str) -> None:
|
||||
"""Emit synthetic heartbeat if provider goes silent."""
|
||||
while self._running:
|
||||
await asyncio.sleep(self._heartbeat_interval)
|
||||
if not self._running:
|
||||
break
|
||||
|
||||
last = self._provider_last_seen.get(provider_name, 0)
|
||||
elapsed = time.monotonic() - last
|
||||
|
||||
if elapsed > self._heartbeat_interval:
|
||||
# Provider is silent — emit heartbeat so metrics/logs see it
|
||||
logger.warning(
|
||||
"bus.provider_silent",
|
||||
extra={
|
||||
"provider": provider_name,
|
||||
"silent_seconds": f"{elapsed:.1f}",
|
||||
},
|
||||
)
|
||||
hb = HeartbeatEvent(provider=provider_name)
|
||||
await self.publish(hb)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the bus worker."""
|
||||
"""Start the bus worker and heartbeat monitors."""
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._worker())
|
||||
logger.info("bus.started", extra={"consumers": len(self._consumers)})
|
||||
|
||||
# Start heartbeat monitors for registered providers
|
||||
for pname in self._provider_last_seen:
|
||||
task = asyncio.create_task(self._heartbeat_monitor(pname))
|
||||
self._heartbeat_tasks[pname] = task
|
||||
|
||||
logger.info(
|
||||
"bus.started",
|
||||
extra={
|
||||
"consumers": len(self._consumers),
|
||||
"providers_monitored": list(self._provider_last_seen.keys()),
|
||||
},
|
||||
)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Graceful shutdown: drain queue then stop."""
|
||||
"""Graceful shutdown: stop heartbeats, drain queue, stop worker."""
|
||||
self._running = False
|
||||
await self._queue.put(None) # sentinel
|
||||
|
||||
# Cancel heartbeat monitors
|
||||
for task in self._heartbeat_tasks.values():
|
||||
task.cancel()
|
||||
for task in self._heartbeat_tasks.values():
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._heartbeat_tasks.clear()
|
||||
|
||||
# Drain remaining events
|
||||
remaining = self._queue.qsize()
|
||||
if remaining > 0:
|
||||
logger.info("bus.draining", extra={"remaining": remaining})
|
||||
|
||||
# Send sentinel to stop worker
|
||||
await self._queue.put(None)
|
||||
if self._task:
|
||||
await self._task
|
||||
|
||||
if self._dropped:
|
||||
logger.info("bus.drop_stats", extra={"dropped": self._dropped})
|
||||
|
||||
logger.info("bus.stopped")
|
||||
|
||||
@property
|
||||
def queue_size(self) -> int:
|
||||
return self._queue.qsize()
|
||||
|
||||
@property
|
||||
def fill_percent(self) -> float:
|
||||
return self._queue.qsize() / self._max_size if self._max_size > 0 else 0
|
||||
|
||||
170
services/market-data-service/app/core/failover.py
Normal file
170
services/market-data-service/app/core/failover.py
Normal file
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
Provider failover manager.
|
||||
|
||||
Tracks provider health per symbol and recommends the best active source.
|
||||
|
||||
Policy:
|
||||
- Each provider has a "health score" per symbol (0.0 – 1.0)
|
||||
- Score decreases on gaps (heartbeat timeout) and error events
|
||||
- Score increases on each successful trade/quote received
|
||||
- When primary provider's score drops below threshold → switch to backup
|
||||
|
||||
Usage:
|
||||
failover = FailoverManager(primary="binance", backups=["bybit"])
|
||||
failover.record_event("binance", "BTCUSDT") # bumps score
|
||||
failover.record_gap("binance", "BTCUSDT") # decreases score
|
||||
best = failover.get_best_provider("BTCUSDT") # → "binance" or "bybit"
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ProviderHealth:
    """Per provider+symbol health state.

    ``score`` lives in [0.0, 1.0]: it creeps up by 0.01 per received
    event and drops by 0.2 per detected gap, so a single gap takes many
    good events to undo.
    """

    score: float = 1.0          # current health: 0.0 (dead) .. 1.0 (perfect)
    event_count: int = 0        # successful events seen
    gap_count: int = 0          # gaps / timeouts seen
    last_event_ts: float = 0.0  # monotonic timestamp of last event
    last_gap_ts: float = 0.0    # monotonic timestamp of last gap

    def record_event(self) -> None:
        """Bump health score on successful event."""
        self.last_event_ts = time.monotonic()
        self.event_count += 1
        # Slow recovery towards full health.
        self.score = min(self.score + 0.01, 1.0)

    def record_gap(self) -> None:
        """Decrease health score on gap/timeout."""
        self.last_gap_ts = time.monotonic()
        self.gap_count += 1
        self.score = max(self.score - 0.2, 0.0)
|
||||
|
||||
|
||||
class FailoverManager:
    """
    Tracks per-symbol provider health and recommends the best source.

    A symbol is switched away from its active provider once that
    provider's score falls under ``switch_threshold``; it returns to the
    primary once the primary's score climbs back to at least
    ``recovery_threshold``.
    """

    def __init__(
        self,
        primary: str,
        backups: list[str] | None = None,
        switch_threshold: float = 0.3,
        recovery_threshold: float = 0.7,
    ) -> None:
        self._primary = primary
        self._backups = backups or []
        self._all_providers = [primary] + self._backups
        self._switch_threshold = switch_threshold
        self._recovery_threshold = recovery_threshold

        # provider → symbol → ProviderHealth
        self._health: dict[str, dict[str, ProviderHealth]] = {}

        # symbol → currently active provider
        self._active: dict[str, str] = {}

    def _get_health(self, provider: str, symbol: str) -> ProviderHealth:
        """Get or create health tracker."""
        per_symbol = self._health.setdefault(provider, {})
        tracker = per_symbol.get(symbol)
        if tracker is None:
            tracker = per_symbol[symbol] = ProviderHealth()
        return tracker

    def record_event(self, provider: str, symbol: str) -> None:
        """Record a successful event from provider for symbol."""
        self._get_health(provider, symbol).record_event()

    def record_gap(self, provider: str, symbol: str) -> None:
        """Record a gap/timeout for provider+symbol."""
        tracker = self._get_health(provider, symbol)
        tracker.record_gap()
        logger.warning(
            "failover.gap_recorded",
            extra={
                "provider": provider,
                "symbol": symbol,
                "score": f"{tracker.score:.2f}",
                "gaps": tracker.gap_count,
            },
        )

    def get_best_provider(self, symbol: str) -> str:
        """
        Return the currently recommended provider for this symbol.

        Logic:
        1. If active provider score >= switch_threshold → keep it
        2. If active provider drops below → switch to healthiest backup
        3. If active provider recovers above recovery_threshold → switch back to primary
        """
        current = self._active.get(symbol, self._primary)
        current_score = self._get_health(current, symbol).score

        if current_score < self._switch_threshold:
            # Degraded: pick the healthiest candidate.  `current` is
            # compared first, so ties keep the current provider.
            ranked = [current] + [p for p in self._all_providers if p != current]
            winner = max(ranked, key=lambda p: self._get_health(p, symbol).score)

            if winner != current:
                winner_score = self._get_health(winner, symbol).score
                logger.warning(
                    "failover.switching",
                    extra={
                        "symbol": symbol,
                        "from": current,
                        "to": winner,
                        "old_score": f"{current_score:.2f}",
                        "new_score": f"{winner_score:.2f}",
                    },
                )
                self._active[symbol] = winner
                return winner

        if current != self._primary:
            # Running on a backup — flip back once the primary recovers.
            primary_score = self._get_health(self._primary, symbol).score
            if primary_score >= self._recovery_threshold:
                logger.info(
                    "failover.returning_to_primary",
                    extra={
                        "symbol": symbol,
                        "primary_score": f"{primary_score:.2f}",
                    },
                )
                self._active[symbol] = self._primary
                return self._primary

        self._active[symbol] = current
        return current

    def get_status(self) -> dict:
        """Return full failover status for monitoring."""
        return {
            f"{provider}/{symbol}": {
                "score": round(health.score, 2),
                "events": health.event_count,
                "gaps": health.gap_count,
                "active": self._active.get(symbol) == provider,
            }
            for provider, symbols in self._health.items()
            for symbol, health in symbols.items()
        }
|
||||
@@ -18,7 +18,6 @@ import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import structlog
|
||||
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
|
||||
@@ -26,14 +25,18 @@ from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
|
||||
from app.config import settings
|
||||
from app.core.bus import EventBus
|
||||
from app.consumers.metrics import MetricsConsumer
|
||||
from app.consumers.nats_output import NatsOutputConsumer
|
||||
from app.consumers.print import PrintConsumer
|
||||
from app.consumers.storage import StorageConsumer
|
||||
from app.db.schema import init_db
|
||||
from app.db.schema import engine, init_db
|
||||
from app.db import repo
|
||||
from app.providers import MarketDataProvider, get_provider
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Global reference to bus (for HTTP status endpoint)
|
||||
_bus: EventBus | None = None
|
||||
|
||||
# ── Logging setup ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -105,6 +108,18 @@ async def _http_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWrit
|
||||
}
|
||||
body = json.dumps(result, ensure_ascii=False).encode()
|
||||
content_type = "application/json"
|
||||
elif path == "/bus-stats":
|
||||
import json as _json
|
||||
|
||||
bus_info = {"queue_size": 0, "fill_percent": 0.0}
|
||||
if _bus:
|
||||
bus_info = {
|
||||
"queue_size": _bus.queue_size,
|
||||
"fill_percent": round(_bus.fill_percent * 100, 1),
|
||||
"max_size": _bus._max_size,
|
||||
}
|
||||
body = _json.dumps(bus_info).encode()
|
||||
content_type = "application/json"
|
||||
else:
|
||||
body = b'{"error":"not found"}'
|
||||
content_type = "application/json"
|
||||
@@ -179,8 +194,13 @@ async def main(provider_names: list[str], symbols: list[str]) -> None:
|
||||
# Init database
|
||||
await init_db()
|
||||
|
||||
# Setup bus + consumers
|
||||
bus = EventBus()
|
||||
global _bus
|
||||
|
||||
# Setup bus + consumers (heartbeat interval from config)
|
||||
bus = EventBus(
|
||||
queue_size=10_000,
|
||||
heartbeat_interval=settings.heartbeat_timeout / 2, # check twice per timeout
|
||||
)
|
||||
|
||||
storage = StorageConsumer()
|
||||
await storage.start()
|
||||
@@ -192,16 +212,29 @@ async def main(provider_names: list[str], symbols: list[str]) -> None:
|
||||
printer = PrintConsumer()
|
||||
bus.add_consumer(printer)
|
||||
|
||||
# Optional: NATS output adapter
|
||||
nats_consumer = None
|
||||
if settings.nats_configured:
|
||||
nats_consumer = NatsOutputConsumer()
|
||||
await nats_consumer.start()
|
||||
bus.add_consumer(nats_consumer)
|
||||
logger.info("nats_output.enabled", subject_prefix=settings.nats_subject_prefix)
|
||||
else:
|
||||
logger.info("nats_output.disabled", hint="Set NATS_URL + NATS_ENABLED=true to enable")
|
||||
|
||||
# Create providers and register them for heartbeat monitoring
|
||||
providers: list[MarketDataProvider] = []
|
||||
for name in provider_names:
|
||||
p = get_provider(name)
|
||||
providers.append(p)
|
||||
bus.register_provider(p.name)
|
||||
|
||||
_bus = bus
|
||||
await bus.start()
|
||||
|
||||
# Start HTTP server
|
||||
http_server = await start_http_server()
|
||||
|
||||
# Create providers
|
||||
providers: list[MarketDataProvider] = []
|
||||
for name in provider_names:
|
||||
providers.append(get_provider(name))
|
||||
|
||||
# Run all providers concurrently
|
||||
tasks = []
|
||||
for p in providers:
|
||||
@@ -224,21 +257,43 @@ async def main(provider_names: list[str], symbols: list[str]) -> None:
|
||||
# Wait for shutdown
|
||||
await shutdown_event.wait()
|
||||
|
||||
# Cleanup
|
||||
# ── Graceful shutdown sequence ──────────────────────────────────────
|
||||
logger.info("service.shutting_down")
|
||||
|
||||
# 1. Cancel provider streaming tasks (with timeout)
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
done, pending = await asyncio.wait(tasks, timeout=5.0)
|
||||
for task in pending:
|
||||
logger.warning("service.task_force_cancel", extra={"task": task.get_name()})
|
||||
task.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# 2. Close provider WebSocket connections
|
||||
for p in providers:
|
||||
await p.close()
|
||||
try:
|
||||
await p.close()
|
||||
except Exception as e:
|
||||
logger.warning("service.provider_close_error", extra={"provider": p.name, "error": str(e)})
|
||||
|
||||
# 3. Stop bus (drains remaining events to consumers)
|
||||
await bus.stop()
|
||||
|
||||
# 4. Stop storage consumer (flush JSONL)
|
||||
await storage.stop()
|
||||
|
||||
# 4b. Stop NATS output (flush + close)
|
||||
if nats_consumer:
|
||||
await nats_consumer.stop()
|
||||
|
||||
# 5. Close HTTP server
|
||||
http_server.close()
|
||||
await http_server.wait_closed()
|
||||
|
||||
logger.info("service.stopped")
|
||||
# 6. Close SQLAlchemy engine (flush connections)
|
||||
await engine.dispose()
|
||||
|
||||
logger.info("service.stopped", extra={"exit": "clean"})
|
||||
|
||||
|
||||
# ── CLI ────────────────────────────────────────────────────────────────
|
||||
@@ -270,7 +325,7 @@ def cli():
|
||||
symbols = [s.strip() for s in args.symbols.split(",") if s.strip()]
|
||||
|
||||
if args.provider.lower() == "all":
|
||||
provider_names = ["binance", "alpaca"]
|
||||
provider_names = ["binance", "alpaca", "bybit"]
|
||||
else:
|
||||
provider_names = [p.strip() for p in args.provider.split(",") if p.strip()]
|
||||
|
||||
|
||||
@@ -45,10 +45,12 @@ def get_provider(name: str) -> MarketDataProvider:
|
||||
"""Factory: instantiate provider by name."""
|
||||
from app.providers.binance import BinanceProvider
|
||||
from app.providers.alpaca import AlpacaProvider
|
||||
from app.providers.bybit import BybitProvider
|
||||
|
||||
registry: dict[str, type[MarketDataProvider]] = {
|
||||
"binance": BinanceProvider,
|
||||
"alpaca": AlpacaProvider,
|
||||
"bybit": BybitProvider,
|
||||
}
|
||||
cls = registry.get(name.lower())
|
||||
if cls is None:
|
||||
|
||||
@@ -17,7 +17,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
from typing import AsyncIterator
|
||||
|
||||
import websockets
|
||||
|
||||
@@ -12,7 +12,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncIterator
|
||||
|
||||
@@ -22,7 +21,6 @@ from websockets.exceptions import ConnectionClosed
|
||||
from app.config import settings
|
||||
from app.domain.events import (
|
||||
Event,
|
||||
HeartbeatEvent,
|
||||
QuoteEvent,
|
||||
TradeEvent,
|
||||
)
|
||||
|
||||
239
services/market-data-service/app/providers/bybit.py
Normal file
239
services/market-data-service/app/providers/bybit.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
Bybit V5 public WebSocket provider — backup for Binance.
|
||||
|
||||
Streams:
|
||||
- publicTrade.{symbol} → TradeEvent
|
||||
- tickers.{symbol} → QuoteEvent (best bid/ask from tickers)
|
||||
|
||||
Docs: https://bybit-exchange.github.io/docs/v5/ws/connect
|
||||
|
||||
No API key needed for public market data.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import AsyncIterator
|
||||
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosed
|
||||
|
||||
from app.config import settings
|
||||
from app.domain.events import (
|
||||
Event,
|
||||
QuoteEvent,
|
||||
TradeEvent,
|
||||
)
|
||||
from app.providers import MarketDataProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _ms_to_dt(ms: int | float | str | None) -> datetime | None:
|
||||
"""Convert millisecond epoch to UTC datetime."""
|
||||
if ms is None:
|
||||
return None
|
||||
try:
|
||||
return datetime.fromtimestamp(int(ms) / 1000.0, tz=timezone.utc)
|
||||
except (ValueError, TypeError, OSError):
|
||||
return None
|
||||
|
||||
|
||||
class BybitProvider(MarketDataProvider):
    """
    Bybit V5 public WebSocket (spot market).

    Connects to the spot public channel and subscribes to
    publicTrade + tickers for each symbol.

    Serves as a backup market-data source; no API key is needed for
    public streams.  Reconnection with exponential backoff is handled
    inside stream().
    """

    # Provider identifier reported in emitted events and log records.
    name = "bybit"

    def __init__(self) -> None:
        # Active WebSocket connection; None until connect() succeeds.
        self._ws: websockets.WebSocketClientProtocol | None = None
        # Uppercased symbol list, remembered for resubscribe on reconnect.
        self._symbols: list[str] = []
        # True while the socket is believed healthy.
        self._connected = False
        # Total reconnect attempts, reported on close() for diagnostics.
        self._reconnect_count = 0
        self._base_url = settings.bybit_ws_url

    async def connect(self) -> None:
        """Establish WebSocket connection."""
        logger.info("bybit.connecting", extra={"url": self._base_url})
        # Protocol-level ping/pong keeps the link alive between messages.
        self._ws = await websockets.connect(
            self._base_url,
            ping_interval=20,
            ping_timeout=10,
            close_timeout=5,
        )
        self._connected = True
        logger.info("bybit.connected")

    async def subscribe(self, symbols: list[str]) -> None:
        """Subscribe to publicTrade + tickers for each symbol.

        Raises:
            RuntimeError: if connect() has not been called yet.
        """
        if not self._ws:
            raise RuntimeError("Not connected. Call connect() first.")

        self._symbols = [s.upper() for s in symbols]
        # Two channels per symbol: trades and best bid/ask tickers.
        args = []
        for sym in self._symbols:
            args.append(f"publicTrade.{sym}")
            args.append(f"tickers.{sym}")

        subscribe_msg = {
            "op": "subscribe",
            "args": args,
        }
        await self._ws.send(json.dumps(subscribe_msg))
        logger.info(
            "bybit.subscribed",
            extra={"symbols": self._symbols, "channels": len(args)},
        )

    async def stream(self) -> AsyncIterator[Event]:
        """Yield domain events. Handles reconnect automatically.

        Infinite loop: silence beyond settings.heartbeat_timeout, a closed
        connection, or any stream error triggers a reconnect with
        exponential backoff (doubled per failure, capped at
        settings.reconnect_max_delay, reset after a successful message).
        """
        backoff = settings.reconnect_base_delay

        while True:
            try:
                if not self._connected or not self._ws:
                    await self._reconnect(backoff)

                try:
                    raw = await asyncio.wait_for(
                        self._ws.recv(),  # type: ignore
                        timeout=settings.heartbeat_timeout,
                    )
                except asyncio.TimeoutError:
                    # No traffic at all within the heartbeat window —
                    # assume the channel is dead; reconnect on next loop.
                    logger.warning(
                        "bybit.heartbeat_timeout",
                        extra={"timeout": settings.heartbeat_timeout},
                    )
                    self._connected = False
                    continue

                # Reset backoff on successful message
                backoff = settings.reconnect_base_delay

                data = json.loads(raw)

                # Handle pong (Bybit sends {"op":"pong",...}) and
                # subscription acks — neither carries market data.
                if data.get("op") in ("pong", "subscribe"):
                    if data.get("success") is False:
                        logger.warning("bybit.subscribe_failed", extra={"msg": data})
                    continue

                event = self._parse(data)
                if event:
                    yield event

            except ConnectionClosed as e:
                logger.warning(
                    "bybit.connection_closed",
                    extra={"code": e.code, "reason": str(e.reason)},
                )
                self._connected = False
                backoff = min(backoff * 2, settings.reconnect_max_delay)

            except Exception as e:
                # Broad catch so a single malformed message cannot kill
                # the generator; the connection is conservatively rebuilt.
                logger.error("bybit.stream_error", extra={"error": str(e)})
                self._connected = False
                backoff = min(backoff * 2, settings.reconnect_max_delay)

    async def _reconnect(self, delay: float) -> None:
        """Reconnect with delay, then resubscribe."""
        self._reconnect_count += 1
        logger.info(
            "bybit.reconnecting",
            extra={"delay": delay, "attempt": self._reconnect_count},
        )
        await asyncio.sleep(delay)

        # Best-effort close of the stale socket before dialing again.
        try:
            if self._ws:
                await self._ws.close()
        except Exception:
            pass

        await self.connect()
        if self._symbols:
            await self.subscribe(self._symbols)

    def _parse(self, data: dict) -> Event | None:
        """Parse raw Bybit V5 message into domain events.

        Returns None for envelopes without a recognised topic/payload.
        """
        topic = data.get("topic", "")
        event_data = data.get("data")

        if not topic or event_data is None:
            return None

        if topic.startswith("publicTrade."):
            return self._parse_trades(event_data)
        elif topic.startswith("tickers."):
            return self._parse_ticker(event_data)

        return None

    def _parse_trades(self, data: list | dict) -> Event | None:
        """
        Bybit publicTrade payload (V5):
            {"data": [{"s":"BTCUSDT","S":"Buy","v":"0.001","p":"70000.5","T":1672515782136,"i":"..."}]}
        We take the last trade in the batch.
        """
        if isinstance(data, list):
            if not data:
                return None
            trade = data[-1]  # latest in batch
        else:
            trade = data

        return TradeEvent(
            provider=self.name,
            symbol=trade.get("s", "").upper(),
            price=float(trade.get("p", 0)),
            size=float(trade.get("v", 0)),
            ts_exchange=_ms_to_dt(trade.get("T")),
            side=trade.get("S", "").lower() if trade.get("S") else None,
            trade_id=str(trade.get("i", "")),
        )

    def _parse_ticker(self, data: dict) -> QuoteEvent | None:
        """
        Bybit tickers (V5 spot):
            {"data": {"symbol":"BTCUSDT","bid1Price":"70000.5","bid1Size":"1.5",
                      "ask1Price":"70001.0","ask1Size":"2.0",...}}

        Returns None when either side of the book is missing.
        """
        if isinstance(data, list):
            data = data[0] if data else {}

        # Fallback key names guard against alternate payload variants.
        bid = data.get("bid1Price") or data.get("bidPrice")
        ask = data.get("ask1Price") or data.get("askPrice")
        bid_size = data.get("bid1Size") or data.get("bidSize")
        ask_size = data.get("ask1Size") or data.get("askSize")

        if not bid or not ask:
            return None

        return QuoteEvent(
            provider=self.name,
            symbol=data.get("symbol", "").upper(),
            bid=float(bid),
            ask=float(ask),
            bid_size=float(bid_size or 0),
            ask_size=float(ask_size or 0),
            # NOTE(review): in Bybit V5 the "ts" stamp sits on the message
            # envelope rather than inside "data", so this may always be
            # None here — verify against live payloads.
            ts_exchange=_ms_to_dt(data.get("ts")),
        )

    async def close(self) -> None:
        """Close the WebSocket connection."""
        self._connected = False
        if self._ws:
            try:
                await self._ws.close()
            except Exception:
                pass
        logger.info(
            "bybit.closed",
            extra={"reconnect_count": self._reconnect_count},
        )
|
||||
Reference in New Issue
Block a user