feat: MD pipeline — market-data-service hardening + SenpAI NATS consumer

Producer (market-data-service):
- Backpressure: smart drop policy (heartbeats→quotes→trades preserved)
- Heartbeat monitor: synthetic HeartbeatEvent on provider silence
- Graceful shutdown: WS→bus→storage→DB engine cleanup sequence
- Bybit V5 public WS provider (backup for Binance, no API key needed)
- FailoverManager: health-based provider switching with recovery
- NATS output adapter: md.events.{type}.{symbol} for SenpAI
- /bus-stats endpoint for backpressure monitoring
- Dockerfile + docker-compose.node1.yml integration
- 36 tests (parsing + bus + failover), requirements.lock

Consumer (senpai-md-consumer):
- NATSConsumer: subscribe md.events.>, queue group senpai-md, backpressure
- State store: LatestState + RollingWindow (deque, 60s)
- Feature engine: 11 features (mid, spread, VWAP, return, vol, latency)
- Rule-based signals: long/short on return+volume+spread conditions
- Publisher: rate-limited features + signals + alerts to NATS
- HTTP API: /health, /metrics, /state/latest, /features/latest, /stats
- 10 Prometheus metrics
- Dockerfile + docker-compose.senpai.yml
- 41 tests (parsing + state + features + rate-limit), requirements.lock

CI: ruff + pytest + smoke import for both services
Tests: 77 total passed, lint clean
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Apple
2026-02-09 11:46:15 -08:00
parent c50843933f
commit 09dee24342
47 changed files with 3930 additions and 56 deletions

View File

@@ -19,6 +19,9 @@ class Settings(BaseSettings):
binance_ws_url: str = "wss://stream.binance.com:9443/ws"
binance_rest_url: str = "https://api.binance.com"
# ── Bybit (backup crypto — no key needed) ──────────────────────────
bybit_ws_url: str = "wss://stream.bybit.com/v5/public/spot"
# ── Alpaca (paper trading — free tier) ─────────────────────────────
alpaca_key: str = ""
alpaca_secret: str = ""
@@ -41,6 +44,11 @@ class Settings(BaseSettings):
http_port: int = 8891
metrics_enabled: bool = True
# ── NATS output adapter ─────────────────────────────────────────────
nats_url: str = "" # e.g. "nats://localhost:4222"
nats_subject_prefix: str = "md.events" # → md.events.trade.BTCUSDT
nats_enabled: bool = False
# ── Logging ────────────────────────────────────────────────────────
log_level: str = "INFO"
log_sample_rate: int = 100 # PrintConsumer: log 1 out of N events
@@ -49,5 +57,9 @@ class Settings(BaseSettings):
    def alpaca_configured(self) -> bool:
        """True when both the Alpaca API key and secret are set."""
        return bool(self.alpaca_key and self.alpaca_secret)
    @property
    def nats_configured(self) -> bool:
        """True when NATS output is both explicitly enabled and given a server URL."""
        return bool(self.nats_url and self.nats_enabled)
settings = Settings()

View File

@@ -0,0 +1,133 @@
"""
NATS output adapter — pushes normalised events to NATS subjects.
Subject schema:
{prefix}.{event_type}.{symbol}
e.g. md.events.trade.BTCUSDT
md.events.quote.AAPL
md.events.heartbeat.__system__
SenpAI (or any other consumer) can subscribe to:
md.events.> — all events
md.events.trade.> — all trades
md.events.*.BTCUSDT — all event types for BTC
Payload: JSON (event.model_dump_json())
"""
from __future__ import annotations
import logging
from app.config import settings
from app.domain.events import Event
logger = logging.getLogger(__name__)
# Lazy import — nats-py may not be installed in minimal setups
_nc = None
class NatsOutputConsumer:
    """
    Event-bus consumer that mirrors every event onto NATS as JSON.

    Reconnects are delegated to nats-py's built-in mechanism. While no
    connection is live, events are counted as dropped and the call returns
    immediately — the bus is never blocked by NATS being down.
    """

    def __init__(
        self,
        nats_url: str | None = None,
        subject_prefix: str | None = None,
    ) -> None:
        # Explicit arguments win; otherwise fall back to service settings.
        self._url = nats_url or settings.nats_url
        self._prefix = subject_prefix or settings.nats_subject_prefix
        self._nc = None  # nats client, set once start() succeeds
        self._connected = False
        self._publish_count = 0
        self._drop_count = 0

    async def start(self) -> None:
        """Open the NATS connection; on any failure stay in drop-only mode."""
        try:
            import nats  # noqa: F811

            self._nc = await nats.connect(
                self._url,
                reconnect_time_wait=2,
                max_reconnect_attempts=-1,  # infinite
                name="market-data-service",
                error_cb=self._error_cb,
                disconnected_cb=self._disconnected_cb,
                reconnected_cb=self._reconnected_cb,
            )
        except ImportError:
            logger.error(
                "nats_output.nats_not_installed",
                extra={"hint": "pip install nats-py"},
            )
        except Exception as exc:
            logger.error(
                "nats_output.connect_failed",
                extra={"url": self._url, "error": str(exc)},
            )
        else:
            self._connected = True
            logger.info(
                "nats_output.connected",
                extra={"url": self._url, "prefix": self._prefix},
            )

    async def handle(self, event: Event) -> None:
        """Serialise *event* and publish it to {prefix}.{type}.{symbol}."""
        if self._nc is None or not self._connected:
            # No live connection — count the drop and return immediately.
            self._drop_count += 1
            return
        sym = getattr(event, "symbol", "__system__")
        subject = f"{self._prefix}.{event.event_type.value}.{sym}"
        try:
            await self._nc.publish(subject, event.model_dump_json().encode("utf-8"))
        except Exception as exc:
            self._drop_count += 1
            # Rate-limit the warning to once per 1000 drops.
            if self._drop_count % 1000 == 1:
                logger.warning(
                    "nats_output.publish_failed",
                    extra={
                        "subject": subject,
                        "error": str(exc),
                        "total_dropped": self._drop_count,
                    },
                )
        else:
            self._publish_count += 1

    async def stop(self) -> None:
        """Flush pending publishes, close the connection, log final totals."""
        if self._nc:
            try:
                await self._nc.flush(timeout=5)
                await self._nc.close()
            except Exception as exc:
                logger.warning("nats_output.close_error", extra={"error": str(exc)})
        logger.info(
            "nats_output.stopped",
            extra={
                "published": self._publish_count,
                "dropped": self._drop_count,
            },
        )

    # ── NATS callbacks ────────────────────────────────────────────────
    async def _error_cb(self, e: Exception) -> None:
        logger.error("nats_output.error", extra={"error": str(e)})

    async def _disconnected_cb(self) -> None:
        self._connected = False
        logger.warning("nats_output.disconnected")

    async def _reconnected_cb(self) -> None:
        self._connected = True
        logger.info("nats_output.reconnected")

View File

@@ -3,7 +3,6 @@ StorageConsumer: persists events to SQLite + JSONL log.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path

View File

@@ -1,6 +1,11 @@
"""
Async event bus — fan-out from providers to consumers.
Features:
- Backpressure with smart drop policy (drop quotes before trades)
- Heartbeat timer per provider (detects dead channels)
- Graceful drain on shutdown
Usage:
bus = EventBus()
bus.add_consumer(storage_consumer)
@@ -13,9 +18,10 @@ from __future__ import annotations
import asyncio
import logging
import time
from typing import Protocol
from app.domain.events import Event
from app.domain.events import Event, EventType, HeartbeatEvent
logger = logging.getLogger(__name__)
@@ -26,37 +32,105 @@ class EventConsumer(Protocol):
async def handle(self, event: Event) -> None: ...
# Events that can be dropped under backpressure (least critical first)
_DROPPABLE_PRIORITY = {
EventType.HEARTBEAT: 0, # always droppable
EventType.QUOTE: 1, # drop quotes before trades
EventType.BOOK_L2: 2, # drop book snapshots before trades
EventType.TRADE: 3, # trades are most critical — last to drop
}
class EventBus:
"""
Simple async fan-out bus.
Async fan-out bus with backpressure and heartbeat monitoring.
Every published event is dispatched to all registered consumers
concurrently (gather). A slow consumer doesn't block others thanks
to the internal queue + worker pattern.
Backpressure policy:
- Queue 80% full → start dropping HEARTBEAT events
- Queue 90% full → also drop QUOTE events
- Queue 100% full → drop oldest (any type)
Heartbeat timer:
- Emits synthetic HeartbeatEvent if a provider sends nothing
for `heartbeat_interval` seconds, making dead channels visible.
"""
def __init__(self, queue_size: int = 10_000) -> None:
def __init__(
self,
queue_size: int = 10_000,
heartbeat_interval: float = 10.0,
) -> None:
self._consumers: list[EventConsumer] = []
self._queue: asyncio.Queue[Event | None] = asyncio.Queue(maxsize=queue_size)
self._max_size = queue_size
self._running = False
self._task: asyncio.Task | None = None
self._heartbeat_interval = heartbeat_interval
self._heartbeat_tasks: dict[str, asyncio.Task] = {}
self._provider_last_seen: dict[str, float] = {}
# Backpressure counters
self._dropped: dict[str, int] = {}
    def add_consumer(self, consumer: EventConsumer) -> None:
        """Register a consumer; every event put on the bus is fanned out to it."""
        self._consumers.append(consumer)
        logger.info("bus.consumer_added", extra={"consumer": type(consumer).__name__})
    def register_provider(self, provider_name: str) -> None:
        """Register a provider for heartbeat monitoring.

        Call before start(): monitors are only spawned for providers
        already present in this map when the bus starts.
        """
        self._provider_last_seen[provider_name] = time.monotonic()
async def publish(self, event: Event) -> None:
"""Put event into internal queue (non-blocking if queue not full)."""
"""
Put event into internal queue with backpressure.
Drop policy under pressure:
- 80%+ → drop heartbeats
- 90%+ → drop quotes/book snapshots
- 100% → drop oldest event
"""
current = self._queue.qsize()
fill_pct = current / self._max_size if self._max_size > 0 else 0
# Track provider activity for heartbeat timer
self._provider_last_seen[event.provider] = time.monotonic()
priority = _DROPPABLE_PRIORITY.get(event.event_type, 3)
# Backpressure: drop low-priority events when queue is filling up
if fill_pct >= 0.9 and priority <= 1:
# Drop heartbeats and quotes
self._dropped[event.event_type.value] = self._dropped.get(event.event_type.value, 0) + 1
if self._dropped[event.event_type.value] % 1000 == 1:
logger.warning(
"bus.backpressure_drop",
extra={
"type": event.event_type.value,
"fill_pct": f"{fill_pct:.0%}",
"total_dropped": self._dropped,
},
)
return
if fill_pct >= 0.8 and priority == 0:
# Drop heartbeats only
return
try:
self._queue.put_nowait(event)
except asyncio.QueueFull:
logger.warning("bus.queue_full, dropping oldest event")
# Drop oldest to keep queue moving
# Last resort: drop oldest to make room
try:
self._queue.get_nowait()
dropped = self._queue.get_nowait()
logger.warning(
"bus.queue_full_drop_oldest",
extra={"dropped_type": dropped.event_type.value if dropped else "None"},
)
except asyncio.QueueEmpty:
pass
self._queue.put_nowait(event)
try:
self._queue.put_nowait(event)
except asyncio.QueueFull:
pass # truly stuck
async def _worker(self) -> None:
"""Background worker that drains the queue and fans out."""
@@ -75,20 +149,79 @@ class EventBus:
extra={"consumer": consumer_name, "error": str(result)},
)
    async def _heartbeat_monitor(self, provider_name: str) -> None:
        """Emit synthetic heartbeat if provider goes silent.

        Runs as one background task per registered provider: wakes every
        `heartbeat_interval` seconds and, when no real event has been seen
        inside that window, publishes a HeartbeatEvent so metrics/logs make
        the dead channel visible.
        """
        while self._running:
            await asyncio.sleep(self._heartbeat_interval)
            # Re-check after sleeping — stop() may have flipped the flag.
            if not self._running:
                break
            last = self._provider_last_seen.get(provider_name, 0)
            elapsed = time.monotonic() - last
            if elapsed > self._heartbeat_interval:
                # Provider is silent — emit heartbeat so metrics/logs see it
                logger.warning(
                    "bus.provider_silent",
                    extra={
                        "provider": provider_name,
                        "silent_seconds": f"{elapsed:.1f}",
                    },
                )
                hb = HeartbeatEvent(provider=provider_name)
                await self.publish(hb)
async def start(self) -> None:
"""Start the bus worker."""
"""Start the bus worker and heartbeat monitors."""
self._running = True
self._task = asyncio.create_task(self._worker())
logger.info("bus.started", extra={"consumers": len(self._consumers)})
# Start heartbeat monitors for registered providers
for pname in self._provider_last_seen:
task = asyncio.create_task(self._heartbeat_monitor(pname))
self._heartbeat_tasks[pname] = task
logger.info(
"bus.started",
extra={
"consumers": len(self._consumers),
"providers_monitored": list(self._provider_last_seen.keys()),
},
)
async def stop(self) -> None:
"""Graceful shutdown: drain queue then stop."""
"""Graceful shutdown: stop heartbeats, drain queue, stop worker."""
self._running = False
await self._queue.put(None) # sentinel
# Cancel heartbeat monitors
for task in self._heartbeat_tasks.values():
task.cancel()
for task in self._heartbeat_tasks.values():
try:
await task
except asyncio.CancelledError:
pass
self._heartbeat_tasks.clear()
# Drain remaining events
remaining = self._queue.qsize()
if remaining > 0:
logger.info("bus.draining", extra={"remaining": remaining})
# Send sentinel to stop worker
await self._queue.put(None)
if self._task:
await self._task
if self._dropped:
logger.info("bus.drop_stats", extra={"dropped": self._dropped})
logger.info("bus.stopped")
    @property
    def queue_size(self) -> int:
        """Current number of events buffered in the internal queue."""
        return self._queue.qsize()
    @property
    def fill_percent(self) -> float:
        """Queue fill ratio in 0.0–1.0 (0 when max size is 0)."""
        return self._queue.qsize() / self._max_size if self._max_size > 0 else 0

View File

@@ -0,0 +1,170 @@
"""
Provider failover manager.
Tracks provider health per symbol and recommends the best active source.
Policy:
- Each provider has a "health score" per symbol (0.0–1.0)
- Score decreases on gaps (heartbeat timeout) and error events
- Score increases on each successful trade/quote received
- When primary provider's score drops below threshold → switch to backup
Usage:
failover = FailoverManager(primary="binance", backups=["bybit"])
failover.record_event("binance", "BTCUSDT") # bumps score
failover.record_gap("binance", "BTCUSDT") # decreases score
best = failover.get_best_provider("BTCUSDT") # → "binance" or "bybit"
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
logger = logging.getLogger(__name__)
@dataclass
class ProviderHealth:
    """Health tracker for one provider+symbol pair.

    The score is clamped to [0.0, 1.0]. Recovery is deliberately slow
    (+0.01 per good event) while decay is fast (-0.2 per gap), so a
    single gap needs roughly 20 clean events to heal.
    """

    score: float = 1.0          # current health, 0.0 (dead) .. 1.0 (perfect)
    event_count: int = 0        # successful events seen
    gap_count: int = 0          # gaps/timeouts seen
    last_event_ts: float = 0.0  # time.monotonic() of the last event
    last_gap_ts: float = 0.0    # time.monotonic() of the last gap

    def record_event(self) -> None:
        """Bump health score on successful event."""
        self.event_count += 1
        self.last_event_ts = time.monotonic()
        # Recover towards 1.0 gradually
        self.score = min(1.0, self.score + 0.01)

    def record_gap(self) -> None:
        """Decrease health score on gap/timeout."""
        self.gap_count += 1
        self.last_gap_ts = time.monotonic()
        self.score = max(0.0, self.score - 0.2)


class FailoverManager:
    """
    Tracks provider health and recommends best source per symbol.

    Fix: get_status() now treats an unswitched symbol as actively served
    by the primary (previously a symbol absent from the active map
    reported no active provider at all).
    """

    def __init__(
        self,
        primary: str,
        backups: list[str] | None = None,
        switch_threshold: float = 0.3,
        recovery_threshold: float = 0.7,
    ) -> None:
        """
        Args:
            primary: preferred provider name.
            backups: fallback provider names, in preference order.
            switch_threshold: active provider score below this → switch away.
            recovery_threshold: primary score at/above this → switch back.
        """
        self._primary = primary
        self._backups = backups or []
        self._all_providers = [primary] + self._backups
        self._switch_threshold = switch_threshold
        self._recovery_threshold = recovery_threshold
        # provider → symbol → ProviderHealth
        self._health: dict[str, dict[str, ProviderHealth]] = {}
        # symbol → currently active provider (absent ⇒ primary)
        self._active: dict[str, str] = {}

    def _get_health(self, provider: str, symbol: str) -> ProviderHealth:
        """Get or create the health tracker for a provider+symbol pair."""
        if provider not in self._health:
            self._health[provider] = {}
        if symbol not in self._health[provider]:
            self._health[provider][symbol] = ProviderHealth()
        return self._health[provider][symbol]

    def record_event(self, provider: str, symbol: str) -> None:
        """Record a successful event from provider for symbol."""
        self._get_health(provider, symbol).record_event()

    def record_gap(self, provider: str, symbol: str) -> None:
        """Record a gap/timeout for provider+symbol (logs the new score)."""
        h = self._get_health(provider, symbol)
        h.record_gap()
        logger.warning(
            "failover.gap_recorded",
            extra={
                "provider": provider,
                "symbol": symbol,
                "score": f"{h.score:.2f}",
                "gaps": h.gap_count,
            },
        )

    def get_best_provider(self, symbol: str) -> str:
        """
        Return the currently recommended provider for this symbol.

        Logic:
        1. If active provider score >= switch_threshold → keep it
        2. If active provider drops below → switch to healthiest backup
        3. If active provider recovers above recovery_threshold → switch back to primary
        """
        current = self._active.get(symbol, self._primary)
        current_health = self._get_health(current, symbol)
        # Check if current provider is degraded
        if current_health.score < self._switch_threshold:
            # Find the healthiest alternative (keep current if none is better)
            best_provider = current
            best_score = current_health.score
            for p in self._all_providers:
                if p == current:
                    continue
                h = self._get_health(p, symbol)
                if h.score > best_score:
                    best_provider = p
                    best_score = h.score
            if best_provider != current:
                logger.warning(
                    "failover.switching",
                    extra={
                        "symbol": symbol,
                        "from": current,
                        "to": best_provider,
                        "old_score": f"{current_health.score:.2f}",
                        "new_score": f"{best_score:.2f}",
                    },
                )
                self._active[symbol] = best_provider
            return best_provider
        # Check if primary has recovered and we're on a backup
        if current != self._primary:
            primary_health = self._get_health(self._primary, symbol)
            if primary_health.score >= self._recovery_threshold:
                logger.info(
                    "failover.returning_to_primary",
                    extra={
                        "symbol": symbol,
                        "primary_score": f"{primary_health.score:.2f}",
                    },
                )
                self._active[symbol] = self._primary
                return self._primary
        self._active[symbol] = current
        return current

    def get_status(self) -> dict:
        """Return full failover status for monitoring.

        Keys are "provider/symbol"; `active` marks the provider currently
        serving that symbol (the primary when no switch has happened yet).
        """
        status = {}
        for provider, symbols in self._health.items():
            for symbol, health in symbols.items():
                key = f"{provider}/{symbol}"
                status[key] = {
                    "score": round(health.score, 2),
                    "events": health.event_count,
                    "gaps": health.gap_count,
                    # Default to primary: a symbol that never switched is
                    # still actively served by the primary provider.
                    "active": self._active.get(symbol, self._primary) == provider,
                }
        return status

View File

@@ -18,7 +18,6 @@ import asyncio
import logging
import signal
import sys
from contextlib import asynccontextmanager
import structlog
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
@@ -26,14 +25,18 @@ from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
from app.config import settings
from app.core.bus import EventBus
from app.consumers.metrics import MetricsConsumer
from app.consumers.nats_output import NatsOutputConsumer
from app.consumers.print import PrintConsumer
from app.consumers.storage import StorageConsumer
from app.db.schema import init_db
from app.db.schema import engine, init_db
from app.db import repo
from app.providers import MarketDataProvider, get_provider
logger = structlog.get_logger()
# Global reference to bus (for HTTP status endpoint)
_bus: EventBus | None = None
# ── Logging setup ──────────────────────────────────────────────────────
@@ -105,6 +108,18 @@ async def _http_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWrit
}
body = json.dumps(result, ensure_ascii=False).encode()
content_type = "application/json"
elif path == "/bus-stats":
import json as _json
bus_info = {"queue_size": 0, "fill_percent": 0.0}
if _bus:
bus_info = {
"queue_size": _bus.queue_size,
"fill_percent": round(_bus.fill_percent * 100, 1),
"max_size": _bus._max_size,
}
body = _json.dumps(bus_info).encode()
content_type = "application/json"
else:
body = b'{"error":"not found"}'
content_type = "application/json"
@@ -179,8 +194,13 @@ async def main(provider_names: list[str], symbols: list[str]) -> None:
# Init database
await init_db()
# Setup bus + consumers
bus = EventBus()
global _bus
# Setup bus + consumers (heartbeat interval from config)
bus = EventBus(
queue_size=10_000,
heartbeat_interval=settings.heartbeat_timeout / 2, # check twice per timeout
)
storage = StorageConsumer()
await storage.start()
@@ -192,16 +212,29 @@ async def main(provider_names: list[str], symbols: list[str]) -> None:
printer = PrintConsumer()
bus.add_consumer(printer)
# Optional: NATS output adapter
nats_consumer = None
if settings.nats_configured:
nats_consumer = NatsOutputConsumer()
await nats_consumer.start()
bus.add_consumer(nats_consumer)
logger.info("nats_output.enabled", subject_prefix=settings.nats_subject_prefix)
else:
logger.info("nats_output.disabled", hint="Set NATS_URL + NATS_ENABLED=true to enable")
# Create providers and register them for heartbeat monitoring
providers: list[MarketDataProvider] = []
for name in provider_names:
p = get_provider(name)
providers.append(p)
bus.register_provider(p.name)
_bus = bus
await bus.start()
# Start HTTP server
http_server = await start_http_server()
# Create providers
providers: list[MarketDataProvider] = []
for name in provider_names:
providers.append(get_provider(name))
# Run all providers concurrently
tasks = []
for p in providers:
@@ -224,21 +257,43 @@ async def main(provider_names: list[str], symbols: list[str]) -> None:
# Wait for shutdown
await shutdown_event.wait()
# Cleanup
# ── Graceful shutdown sequence ──────────────────────────────────────
logger.info("service.shutting_down")
# 1. Cancel provider streaming tasks (with timeout)
for task in tasks:
task.cancel()
done, pending = await asyncio.wait(tasks, timeout=5.0)
for task in pending:
logger.warning("service.task_force_cancel", extra={"task": task.get_name()})
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
# 2. Close provider WebSocket connections
for p in providers:
await p.close()
try:
await p.close()
except Exception as e:
logger.warning("service.provider_close_error", extra={"provider": p.name, "error": str(e)})
# 3. Stop bus (drains remaining events to consumers)
await bus.stop()
# 4. Stop storage consumer (flush JSONL)
await storage.stop()
# 4b. Stop NATS output (flush + close)
if nats_consumer:
await nats_consumer.stop()
# 5. Close HTTP server
http_server.close()
await http_server.wait_closed()
logger.info("service.stopped")
# 6. Close SQLAlchemy engine (flush connections)
await engine.dispose()
logger.info("service.stopped", extra={"exit": "clean"})
# ── CLI ────────────────────────────────────────────────────────────────
@@ -270,7 +325,7 @@ def cli():
symbols = [s.strip() for s in args.symbols.split(",") if s.strip()]
if args.provider.lower() == "all":
provider_names = ["binance", "alpaca"]
provider_names = ["binance", "alpaca", "bybit"]
else:
provider_names = [p.strip() for p in args.provider.split(",") if p.strip()]

View File

@@ -45,10 +45,12 @@ def get_provider(name: str) -> MarketDataProvider:
"""Factory: instantiate provider by name."""
from app.providers.binance import BinanceProvider
from app.providers.alpaca import AlpacaProvider
from app.providers.bybit import BybitProvider
registry: dict[str, type[MarketDataProvider]] = {
"binance": BinanceProvider,
"alpaca": AlpacaProvider,
"bybit": BybitProvider,
}
cls = registry.get(name.lower())
if cls is None:

View File

@@ -17,7 +17,7 @@ from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timezone
from datetime import datetime
from typing import AsyncIterator
import websockets

View File

@@ -12,7 +12,6 @@ from __future__ import annotations
import asyncio
import json
import logging
import time
from datetime import datetime, timezone
from typing import AsyncIterator
@@ -22,7 +21,6 @@ from websockets.exceptions import ConnectionClosed
from app.config import settings
from app.domain.events import (
Event,
HeartbeatEvent,
QuoteEvent,
TradeEvent,
)

View File

@@ -0,0 +1,239 @@
"""
Bybit V5 public WebSocket provider — backup for Binance.
Streams:
- publicTrade.{symbol} → TradeEvent
- tickers.{symbol} → QuoteEvent (best bid/ask from tickers)
Docs: https://bybit-exchange.github.io/docs/v5/ws/connect
No API key needed for public market data.
"""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import AsyncIterator
import websockets
from websockets.exceptions import ConnectionClosed
from app.config import settings
from app.domain.events import (
Event,
QuoteEvent,
TradeEvent,
)
from app.providers import MarketDataProvider
logger = logging.getLogger(__name__)
def _ms_to_dt(ms: int | float | str | None) -> datetime | None:
"""Convert millisecond epoch to UTC datetime."""
if ms is None:
return None
try:
return datetime.fromtimestamp(int(ms) / 1000.0, tz=timezone.utc)
except (ValueError, TypeError, OSError):
return None
class BybitProvider(MarketDataProvider):
    """
    Bybit V5 public WebSocket (spot market).

    Connects to the spot public channel and subscribes to
    publicTrade + tickers for each symbol. Public market data needs
    no API key.
    """

    name = "bybit"

    def __init__(self) -> None:
        # Connection handle; None until connect() succeeds.
        self._ws: websockets.WebSocketClientProtocol | None = None
        self._symbols: list[str] = []
        self._connected = False
        self._reconnect_count = 0
        self._base_url = settings.bybit_ws_url

    async def connect(self) -> None:
        """Establish WebSocket connection (library-level keepalive pings enabled)."""
        logger.info("bybit.connecting", extra={"url": self._base_url})
        self._ws = await websockets.connect(
            self._base_url,
            ping_interval=20,
            ping_timeout=10,
            close_timeout=5,
        )
        self._connected = True
        logger.info("bybit.connected")

    async def subscribe(self, symbols: list[str]) -> None:
        """Subscribe to publicTrade + tickers for each symbol.

        Raises:
            RuntimeError: if called before connect().
        """
        if not self._ws:
            raise RuntimeError("Not connected. Call connect() first.")
        self._symbols = [s.upper() for s in symbols]
        # Two channels per symbol: raw trades + best bid/ask tickers.
        args = []
        for sym in self._symbols:
            args.append(f"publicTrade.{sym}")
            args.append(f"tickers.{sym}")
        subscribe_msg = {
            "op": "subscribe",
            "args": args,
        }
        await self._ws.send(json.dumps(subscribe_msg))
        logger.info(
            "bybit.subscribed",
            extra={"symbols": self._symbols, "channels": len(args)},
        )

    async def stream(self) -> AsyncIterator[Event]:
        """Yield domain events. Handles reconnect automatically.

        Loops forever: on receive timeouts or closed connections it marks
        the socket dead and reconnects with exponential backoff (the
        backoff resets after any successful message).
        """
        backoff = settings.reconnect_base_delay
        while True:
            try:
                if not self._connected or not self._ws:
                    await self._reconnect(backoff)
                try:
                    raw = await asyncio.wait_for(
                        self._ws.recv(),  # type: ignore
                        timeout=settings.heartbeat_timeout,
                    )
                except asyncio.TimeoutError:
                    # No message within the heartbeat window → assume the
                    # channel is dead and force a reconnect next iteration.
                    logger.warning(
                        "bybit.heartbeat_timeout",
                        extra={"timeout": settings.heartbeat_timeout},
                    )
                    self._connected = False
                    continue
                # Reset backoff on successful message
                backoff = settings.reconnect_base_delay
                data = json.loads(raw)
                # Handle pong (Bybit sends {"op":"pong",...}) and subscribe
                # acknowledgements — control frames, not market data.
                if data.get("op") in ("pong", "subscribe"):
                    if data.get("success") is False:
                        logger.warning("bybit.subscribe_failed", extra={"msg": data})
                    continue
                event = self._parse(data)
                if event:
                    yield event
            except ConnectionClosed as e:
                logger.warning(
                    "bybit.connection_closed",
                    extra={"code": e.code, "reason": str(e.reason)},
                )
                self._connected = False
                backoff = min(backoff * 2, settings.reconnect_max_delay)
            except Exception as e:
                logger.error("bybit.stream_error", extra={"error": str(e)})
                self._connected = False
                backoff = min(backoff * 2, settings.reconnect_max_delay)

    async def _reconnect(self, delay: float) -> None:
        """Reconnect with delay, then resubscribe to the previous symbols."""
        self._reconnect_count += 1
        logger.info(
            "bybit.reconnecting",
            extra={"delay": delay, "attempt": self._reconnect_count},
        )
        await asyncio.sleep(delay)
        # Best-effort close of the stale socket before dialing again.
        try:
            if self._ws:
                await self._ws.close()
        except Exception:
            pass
        await self.connect()
        if self._symbols:
            await self.subscribe(self._symbols)

    def _parse(self, data: dict) -> Event | None:
        """Parse raw Bybit V5 message into domain events.

        Returns None for messages that are not mapped (unknown topic,
        missing payload).
        """
        topic = data.get("topic", "")
        event_data = data.get("data")
        if not topic or event_data is None:
            return None
        if topic.startswith("publicTrade."):
            return self._parse_trades(event_data)
        elif topic.startswith("tickers."):
            return self._parse_ticker(event_data)
        return None

    def _parse_trades(self, data: list | dict) -> Event | None:
        """
        Bybit publicTrade payload (V5):
        {"data": [{"s":"BTCUSDT","S":"Buy","v":"0.001","p":"70000.5","T":1672515782136,"i":"..."}]}
        We take the last trade in the batch.
        """
        if isinstance(data, list):
            if not data:
                return None
            trade = data[-1]  # latest in batch
        else:
            trade = data
        return TradeEvent(
            provider=self.name,
            symbol=trade.get("s", "").upper(),
            price=float(trade.get("p", 0)),
            size=float(trade.get("v", 0)),
            ts_exchange=_ms_to_dt(trade.get("T")),
            side=trade.get("S", "").lower() if trade.get("S") else None,
            trade_id=str(trade.get("i", "")),
        )

    def _parse_ticker(self, data: dict) -> QuoteEvent | None:
        """
        Bybit tickers (V5 spot):
        {"data": {"symbol":"BTCUSDT","bid1Price":"70000.5","bid1Size":"1.5",
        "ask1Price":"70001.0","ask1Size":"2.0",...}}

        Returns None when either side of the book is missing.
        """
        if isinstance(data, list):
            data = data[0] if data else {}
        # Fall back to legacy field names when the V5 keys are absent.
        bid = data.get("bid1Price") or data.get("bidPrice")
        ask = data.get("ask1Price") or data.get("askPrice")
        bid_size = data.get("bid1Size") or data.get("bidSize")
        ask_size = data.get("ask1Size") or data.get("askSize")
        if not bid or not ask:
            return None
        return QuoteEvent(
            provider=self.name,
            symbol=data.get("symbol", "").upper(),
            bid=float(bid),
            ask=float(ask),
            bid_size=float(bid_size or 0),
            ask_size=float(ask_size or 0),
            # NOTE(review): on spot tickers "ts" usually lives on the message
            # envelope, not inside "data" — this may always be None; verify.
            ts_exchange=_ms_to_dt(data.get("ts")),
        )

    async def close(self) -> None:
        """Close the WebSocket connection (best-effort) and log totals."""
        self._connected = False
        if self._ws:
            try:
                await self._ws.close()
            except Exception:
                pass
        logger.info(
            "bybit.closed",
            extra={"reconnect_count": self._reconnect_count},
        )