feat: market-data-service for SenpAI trading agent

New service: real-time market data collection with unified event model.

Architecture:
- Domain events: TradeEvent, QuoteEvent, BookL2Event, HeartbeatEvent
- Provider interface: MarketDataProvider ABC with connect/subscribe/stream/close
- Async EventBus with fan-out to multiple consumers

Providers:
- BinanceProvider: public WebSocket (trades + bookTicker), no API key needed,
  auto-reconnect with exponential backoff, heartbeat timeout detection
- AlpacaProvider: IEX real-time data + paper trading auth,
  dry-run mode when no keys configured (heartbeats only)

Consumers:
- StorageConsumer: SQLite (via SQLAlchemy async) + JSONL append-only log
- MetricsConsumer: Prometheus counters, latency histograms, events/sec gauge
- PrintConsumer: sampled structured logging (1/100 events)

CLI: python -m app run --provider binance --symbols BTCUSDT,ETHUSDT
HTTP: /health, /metrics (Prometheus), /latest?symbol=XXX

Tests: 19/19 passed (Binance parse, Alpaca parse, bus smoke tests)

Config: pydantic-settings + .env, all secrets via environment variables.
Co-authored-by: Cursor <cursoragent@cursor.com>
Apple
2026-02-09 11:19:00 -08:00
parent ad6b6d2662
commit c50843933f
26 changed files with 2036 additions and 0 deletions


@@ -0,0 +1,31 @@
# Market Data Service Configuration
# Copy to .env and fill in your values
# ── Binance (no key needed for public WebSocket) ──────────────────────
BINANCE_WS_URL=wss://stream.binance.com:9443/ws
# ── Alpaca (paper trading — free) ─────────────────────────────────────
# Get free paper keys at: https://app.alpaca.markets/paper/dashboard/overview
ALPACA_KEY=
ALPACA_SECRET=
ALPACA_BASE_URL=https://paper-api.alpaca.markets
ALPACA_DATA_WS_URL=wss://stream.data.alpaca.markets/v2/iex
ALPACA_DRY_RUN=true # Set to false when keys are configured
# ── Storage ───────────────────────────────────────────────────────────
SQLITE_URL=sqlite+aiosqlite:///market_data.db
JSONL_PATH=events.jsonl
# ── Reliability ───────────────────────────────────────────────────────
RECONNECT_MAX_RETRIES=20
RECONNECT_BASE_DELAY=1.0
RECONNECT_MAX_DELAY=60.0
HEARTBEAT_TIMEOUT=30.0
# ── HTTP Server ───────────────────────────────────────────────────────
HTTP_HOST=0.0.0.0
HTTP_PORT=8891
# ── Logging ───────────────────────────────────────────────────────────
LOG_LEVEL=INFO
LOG_SAMPLE_RATE=100


@@ -0,0 +1,153 @@
# Market Data Service (SenpAI)
Real-time market data collection and normalization for the SenpAI/Gordon trading agent.
## Quick Start
### 1. Install
```bash
cd services/market-data-service
pip install -r requirements.txt
```
### 2. Copy config
```bash
cp .env.example .env
```
### 3. Run (Binance — no keys needed)
```bash
python -m app run --provider binance --symbols BTCUSDT,ETHUSDT
```
### 4. Run (Alpaca — paper trading)
First, get free paper-trading API keys:
1. Sign up at https://app.alpaca.markets
2. Switch to **Paper Trading** in the dashboard
3. Go to API Keys → Generate New Key
4. Add to `.env`:
```
ALPACA_KEY=your_key_here
ALPACA_SECRET=your_secret_here
ALPACA_DRY_RUN=false
```
5. Run:
```bash
python -m app run --provider alpaca --symbols AAPL,TSLA
```
Without keys, Alpaca runs in **dry-run mode** (heartbeats only).
### 5. Run both providers
```bash
python -m app run --provider all --symbols BTCUSDT,AAPL
```
## HTTP Endpoints
Once running, the service exposes:
| Endpoint | Description |
|---|---|
| `GET /health` | Service health check |
| `GET /metrics` | Prometheus metrics |
| `GET /latest?symbol=BTCUSDT` | Latest trade + quote from SQLite |
Default port: `8891` (configurable via `HTTP_PORT`).
## View Data
### SQLite
```bash
sqlite3 market_data.db "SELECT * FROM trades ORDER BY ts_recv DESC LIMIT 5;"
```
### JSONL Event Log
```bash
tail -5 events.jsonl | python -m json.tool
```
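Beyond `tail`, the log can be replayed programmatically. A minimal sketch, assuming the default `events.jsonl` path and the `event_type`/`symbol` fields of the event model; `replay` is a hypothetical helper, not part of the service:

```python
import json
from collections import Counter
from pathlib import Path


def replay(path: str = "events.jsonl"):
    """Yield parsed events from the append-only JSONL log."""
    p = Path(path)
    if not p.exists():
        return
    with p.open() as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                continue  # tolerate a torn final line after a crash


# Example: count trade events per symbol
counts = Counter(e["symbol"] for e in replay() if e.get("event_type") == "trade")
print(counts.most_common(5))
```

Because the log is append-only, a crash can leave a partial final line; skipping unparsable lines keeps replay robust.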
### Prometheus Metrics
```bash
curl http://localhost:8891/metrics
```
Key metrics:
- `market_events_total` — events by provider/type/symbol
- `market_exchange_latency_ms` — exchange-to-receive latency
- `market_events_per_second` — throughput gauge
## Architecture
```
Provider (Binance/Alpaca)
│ raw WebSocket messages
Adapter (_parse → domain Event)
│ TradeEvent / QuoteEvent / BookL2Event
EventBus (asyncio.Queue fan-out)
├─▶ StorageConsumer → SQLite + JSONL
├─▶ MetricsConsumer → Prometheus counters/histograms
└─▶ PrintConsumer → structured log (sampled 1/100)
```
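The fan-out above can be exercised as a standalone sketch: a toy bus with plain dicts and local consumer functions, simplified and not the app's actual classes:

```python
import asyncio


class MiniBus:
    """Toy version of the EventBus fan-out: one queue, one worker, N consumers."""

    def __init__(self) -> None:
        self.consumers = []
        self.queue: asyncio.Queue = asyncio.Queue(maxsize=100)

    async def publish(self, event) -> None:
        await self.queue.put(event)

    async def worker(self) -> None:
        while True:
            event = await self.queue.get()
            if event is None:  # shutdown sentinel
                break
            # Dispatch to all consumers concurrently, like the real bus
            await asyncio.gather(*(c(event) for c in self.consumers))


seen: list[str] = []


async def storage(event: dict) -> None:
    seen.append(f"storage:{event['symbol']}")


async def metrics(event: dict) -> None:
    seen.append(f"metrics:{event['symbol']}")


async def demo() -> None:
    bus = MiniBus()
    bus.consumers += [storage, metrics]
    task = asyncio.create_task(bus.worker())
    await bus.publish({"event_type": "trade", "symbol": "BTCUSDT"})
    await bus.publish(None)  # sentinel stops the worker
    await task


asyncio.run(demo())
print(seen)  # → ['storage:BTCUSDT', 'metrics:BTCUSDT']
```

The queue decouples the provider's read loop from consumer latency, which is the same design choice the real `EventBus` makes.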
## Adding a New Provider
1. Create `app/providers/your_provider.py`
2. Subclass `MarketDataProvider`:
```python
from typing import AsyncIterator

from app.providers import MarketDataProvider
from app.domain.events import Event, TradeEvent


class YourProvider(MarketDataProvider):
    name = "your_provider"

    async def connect(self) -> None:
        # Establish connection
        ...

    async def subscribe(self, symbols: list[str]) -> None:
        # Subscribe to streams
        ...

    async def stream(self) -> AsyncIterator[Event]:
        # Yield normalized events, handle reconnect
        while True:
            raw = await self._receive()
            yield self._parse(raw)

    async def close(self) -> None:
        ...
```
3. Register in `app/providers/__init__.py`:
```python
from app.providers.your_provider import YourProvider

registry = {
    ...
    "your_provider": YourProvider,
}
```
4. Run: `python -m app run --provider your_provider --symbols ...`
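A provider that follows this contract can be exercised without running the service. A self-contained sketch with an in-memory fake (simplified stand-ins using dicts; it does not import the app package):

```python
import asyncio
from typing import AsyncIterator


class FakeProvider:
    """Minimal stand-in honoring the connect/subscribe/stream/close lifecycle."""

    name = "fake"

    def __init__(self, canned: list[dict]) -> None:
        self._canned = canned
        self._symbols: list[str] = []

    async def connect(self) -> None:
        pass  # nothing to do for an in-memory feed

    async def subscribe(self, symbols: list[str]) -> None:
        self._symbols = [s.upper() for s in symbols]

    async def stream(self) -> AsyncIterator[dict]:
        # Yield only events for subscribed symbols
        for raw in self._canned:
            if raw["symbol"] in self._symbols:
                yield raw

    async def close(self) -> None:
        pass


async def collect() -> list[dict]:
    p = FakeProvider([
        {"symbol": "AAPL", "price": 190.0},
        {"symbol": "TSLA", "price": 250.0},
    ])
    await p.connect()
    await p.subscribe(["aapl"])
    events = [e async for e in p.stream()]
    await p.close()
    return events


events = asyncio.run(collect())
print(events)  # → [{'symbol': 'AAPL', 'price': 190.0}]
```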
## Tests
```bash
pytest tests/ -v
```
## TODO: Future Providers
- [ ] CoinAPI (REST + WebSocket, paid tier)
- [ ] IQFeed (US equities, DTN subscription)
- [ ] Polygon.io (real-time + historical)
- [ ] Interactive Brokers TWS API


@@ -0,0 +1,4 @@
"""Allow running as: python -m app run --provider binance --symbols BTCUSDT"""
from app.main import cli
cli()


@@ -0,0 +1,53 @@
"""
Configuration via pydantic-settings.
All secrets come from .env; no defaults for sensitive keys.
"""
from __future__ import annotations
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # ── Binance (no key needed for public WS) ──────────────────────────
    binance_ws_url: str = "wss://stream.binance.com:9443/ws"
    binance_rest_url: str = "https://api.binance.com"

    # ── Alpaca (paper trading — free tier) ─────────────────────────────
    alpaca_key: str = ""
    alpaca_secret: str = ""
    alpaca_base_url: str = "https://paper-api.alpaca.markets"
    alpaca_data_ws_url: str = "wss://stream.data.alpaca.markets/v2/iex"
    alpaca_dry_run: bool = True  # True = skip real API calls if no keys

    # ── Storage ────────────────────────────────────────────────────────
    sqlite_url: str = "sqlite+aiosqlite:///market_data.db"
    jsonl_path: str = "events.jsonl"

    # ── Reliability ────────────────────────────────────────────────────
    reconnect_max_retries: int = 20
    reconnect_base_delay: float = 1.0  # seconds, exponential backoff
    reconnect_max_delay: float = 60.0
    heartbeat_timeout: float = 30.0  # no-message timeout → reconnect

    # ── Metrics / HTTP ─────────────────────────────────────────────────
    http_host: str = "0.0.0.0"
    http_port: int = 8891
    metrics_enabled: bool = True

    # ── Logging ────────────────────────────────────────────────────────
    log_level: str = "INFO"
    log_sample_rate: int = 100  # PrintConsumer: log 1 out of N events

    @property
    def alpaca_configured(self) -> bool:
        return bool(self.alpaca_key and self.alpaca_secret)


settings = Settings()


@@ -0,0 +1,98 @@
"""
MetricsConsumer: Prometheus counters + latency histograms.
"""
from __future__ import annotations
import logging
import time
from prometheus_client import Counter, Gauge, Histogram, Summary
from app.domain.events import Event, EventType
logger = logging.getLogger(__name__)
# ── Prometheus metrics ─────────────────────────────────────────────────
EVENTS_TOTAL = Counter(
    "market_events_total",
    "Total market data events received",
    ["provider", "event_type", "symbol"],
)

EVENTS_PER_SECOND = Gauge(
    "market_events_per_second",
    "Approximate events per second",
    ["provider"],
)

EXCHANGE_LATENCY = Histogram(
    "market_exchange_latency_ms",
    "Latency from exchange timestamp to receive (ms)",
    ["provider"],
    buckets=[1, 5, 10, 25, 50, 100, 250, 500, 1000, 5000],
)

RECV_LATENCY = Summary(
    "market_recv_latency_ns",
    "Internal receive-to-handle latency (nanoseconds, monotonic)",
    ["provider"],
)

GAPS = Counter(
    "market_gaps_total",
    "Number of detected message gaps (heartbeat timeouts)",
    ["provider"],
)


class MetricsConsumer:
    """
    Computes and exposes Prometheus metrics from the event stream.
    """

    def __init__(self) -> None:
        self._last_ts: dict[str, float] = {}  # provider → last time.time()
        self._window_counts: dict[str, int] = {}
        self._window_start: dict[str, float] = {}

    async def handle(self, event: Event) -> None:
        provider = event.provider
        event_type = event.event_type.value
        symbol = getattr(event, "symbol", "__heartbeat__")

        # Count
        EVENTS_TOTAL.labels(
            provider=provider,
            event_type=event_type,
            symbol=symbol,
        ).inc()

        # Exchange latency (if ts_exchange available)
        ts_exchange = getattr(event, "ts_exchange", None)
        if ts_exchange is not None:
            latency_ms = (event.ts_recv.timestamp() - ts_exchange.timestamp()) * 1000
            if 0 < latency_ms < 60_000:  # sanity check: 0-60 s
                EXCHANGE_LATENCY.labels(provider=provider).observe(latency_ms)

        # Internal receive-to-handle latency (monotonic delta, not an absolute timestamp)
        RECV_LATENCY.labels(provider=provider).observe(
            time.monotonic_ns() - event.ts_recv_mono_ns
        )

        # Events/sec approximation (1-second window)
        now = time.time()
        if provider not in self._window_start:
            self._window_start[provider] = now
            self._window_counts[provider] = 0
        self._window_counts[provider] += 1
        elapsed = now - self._window_start[provider]
        if elapsed >= 1.0:
            eps = self._window_counts[provider] / elapsed
            EVENTS_PER_SECOND.labels(provider=provider).set(eps)
            self._window_start[provider] = now
            self._window_counts[provider] = 0

        # Gap detection: heartbeat events signal a potential gap
        if event.event_type == EventType.HEARTBEAT:
            GAPS.labels(provider=provider).inc()


@@ -0,0 +1,59 @@
"""
PrintConsumer: structured debug logging (sampled).
"""
from __future__ import annotations
import logging
from app.config import settings
from app.domain.events import Event, EventType
logger = logging.getLogger(__name__)
class PrintConsumer:
    """
    Logs 1 out of every N events for debugging.
    Always logs heartbeats and the first event per symbol.
    """

    def __init__(self, sample_rate: int | None = None) -> None:
        self._sample_rate = sample_rate or settings.log_sample_rate
        self._count = 0
        self._seen_symbols: set[str] = set()

    async def handle(self, event: Event) -> None:
        self._count += 1
        symbol = getattr(event, "symbol", None)

        # Always log heartbeats
        if event.event_type == EventType.HEARTBEAT:
            logger.info(
                "event.heartbeat",
                extra={"provider": event.provider},
            )
            return

        # Always log the first event for a new symbol
        force_log = False
        if symbol and symbol not in self._seen_symbols:
            self._seen_symbols.add(symbol)
            force_log = True

        # Sample
        if force_log or self._count % self._sample_rate == 0:
            extra = {
                "provider": event.provider,
                "type": event.event_type.value,
                "symbol": symbol or "?",
                "n": self._count,
            }
            if event.event_type == EventType.TRADE:
                extra["price"] = getattr(event, "price", None)
                extra["size"] = getattr(event, "size", None)
            elif event.event_type == EventType.QUOTE:
                extra["bid"] = getattr(event, "bid", None)
                extra["ask"] = getattr(event, "ask", None)
            logger.info("event.sample", extra=extra)


@@ -0,0 +1,64 @@
"""
StorageConsumer: persists events to SQLite + JSONL log.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from app.config import settings
from app.domain.events import (
    BookL2Event,
    Event,
    EventType,
    QuoteEvent,
    TradeEvent,
)
from app.db import repo

logger = logging.getLogger(__name__)


class StorageConsumer:
    """
    Writes every event to:
      1. SQLite (via async repo) — structured, queryable.
      2. JSONL file — append-only event log for replay/audit.
    """

    def __init__(self, jsonl_path: str | None = None) -> None:
        self._jsonl_path = Path(jsonl_path or settings.jsonl_path)
        self._jsonl_file = None
        self._count = 0

    async def start(self) -> None:
        """Open the JSONL file for appending."""
        self._jsonl_file = open(
            self._jsonl_path, "a", buffering=1, encoding="utf-8"
        )  # line-buffered
        logger.info("storage.started", extra={"jsonl": str(self._jsonl_path)})

    async def handle(self, event: Event) -> None:
        """Persist one event."""
        # 1. JSONL log (always)
        if self._jsonl_file:
            line = event.model_dump_json()
            self._jsonl_file.write(line + "\n")

        # 2. SQLite (by type)
        if event.event_type == EventType.TRADE:
            assert isinstance(event, TradeEvent)
            await repo.save_trade(event)
        elif event.event_type == EventType.QUOTE:
            assert isinstance(event, QuoteEvent)
            await repo.save_quote(event)
        elif event.event_type == EventType.BOOK_L2:
            assert isinstance(event, BookL2Event)
            await repo.save_book_snapshot(event)
        # Heartbeats → only JSONL, not SQLite

        self._count += 1

    async def stop(self) -> None:
        if self._jsonl_file:
            self._jsonl_file.close()
        logger.info("storage.stopped", extra={"events_written": self._count})


@@ -0,0 +1,94 @@
"""
Async event bus — fan-out from providers to consumers.
Usage:
    bus = EventBus()
    bus.add_consumer(storage_consumer)
    bus.add_consumer(metrics_consumer)

    async for event in provider.stream():
        await bus.publish(event)
"""
from __future__ import annotations
import asyncio
import logging
from typing import Protocol
from app.domain.events import Event
logger = logging.getLogger(__name__)
class EventConsumer(Protocol):
    """Anything with an async handle(event) method."""

    async def handle(self, event: Event) -> None: ...


class EventBus:
    """
    Simple async fan-out bus.

    Every published event is dispatched to all registered consumers
    concurrently (gather). The internal queue + worker pattern decouples
    producers from consumers: a slow consumer back-pressures the queue
    rather than the provider's read loop.
    """

    def __init__(self, queue_size: int = 10_000) -> None:
        self._consumers: list[EventConsumer] = []
        self._queue: asyncio.Queue[Event | None] = asyncio.Queue(maxsize=queue_size)
        self._running = False
        self._task: asyncio.Task | None = None

    def add_consumer(self, consumer: EventConsumer) -> None:
        self._consumers.append(consumer)
        logger.info("bus.consumer_added", extra={"consumer": type(consumer).__name__})

    async def publish(self, event: Event) -> None:
        """Enqueue an event without blocking; on overflow, drop the oldest event."""
        try:
            self._queue.put_nowait(event)
        except asyncio.QueueFull:
            logger.warning("bus.queue_full, dropping oldest event")
            # Drop oldest to keep the queue moving
            try:
                self._queue.get_nowait()
            except asyncio.QueueEmpty:
                pass
            self._queue.put_nowait(event)

    async def _worker(self) -> None:
        """Background worker that drains the queue and fans out."""
        while self._running:
            event = await self._queue.get()
            if event is None:  # shutdown sentinel
                break
            tasks = [c.handle(event) for c in self._consumers]
            if tasks:
                results = await asyncio.gather(*tasks, return_exceptions=True)
                for i, result in enumerate(results):
                    if isinstance(result, Exception):
                        consumer_name = type(self._consumers[i]).__name__
                        logger.error(
                            "bus.consumer_error",
                            extra={"consumer": consumer_name, "error": str(result)},
                        )

    async def start(self) -> None:
        """Start the bus worker."""
        self._running = True
        self._task = asyncio.create_task(self._worker())
        logger.info("bus.started", extra={"consumers": len(self._consumers)})

    async def stop(self) -> None:
        """Graceful shutdown: drain the queue, then stop."""
        self._running = False
        await self._queue.put(None)  # sentinel
        if self._task:
            await self._task
        logger.info("bus.stopped")

    @property
    def queue_size(self) -> int:
        return self._queue.qsize()


@@ -0,0 +1,110 @@
"""
Repository: thin async wrapper over SQLAlchemy for read/write.
"""
from __future__ import annotations
import json
from typing import Optional
from sqlalchemy import desc, select
from app.db.schema import (
    BookSnapshotRecord,
    QuoteRecord,
    TradeRecord,
    async_session,
)
from app.domain.events import BookL2Event, QuoteEvent, TradeEvent


async def save_trade(event: TradeEvent) -> None:
    async with async_session() as session:
        record = TradeRecord(
            provider=event.provider,
            symbol=event.symbol,
            price=event.price,
            size=event.size,
            side=event.side,
            trade_id=event.trade_id,
            ts_exchange=event.ts_exchange,
            ts_recv=event.ts_recv,
        )
        session.add(record)
        await session.commit()


async def save_quote(event: QuoteEvent) -> None:
    async with async_session() as session:
        record = QuoteRecord(
            provider=event.provider,
            symbol=event.symbol,
            bid=event.bid,
            ask=event.ask,
            bid_size=event.bid_size,
            ask_size=event.ask_size,
            ts_exchange=event.ts_exchange,
            ts_recv=event.ts_recv,
        )
        session.add(record)
        await session.commit()


async def save_book_snapshot(event: BookL2Event) -> None:
    async with async_session() as session:
        record = BookSnapshotRecord(
            provider=event.provider,
            symbol=event.symbol,
            bids_json=json.dumps([{"price": b.price, "size": b.size} for b in event.bids]),
            asks_json=json.dumps([{"price": a.price, "size": a.size} for a in event.asks]),
            depth=max(len(event.bids), len(event.asks)),
            ts_exchange=event.ts_exchange,
            ts_recv=event.ts_recv,
        )
        session.add(record)
        await session.commit()


async def get_latest_trade(symbol: str) -> Optional[dict]:
    async with async_session() as session:
        stmt = (
            select(TradeRecord)
            .where(TradeRecord.symbol == symbol.upper())
            .order_by(desc(TradeRecord.ts_recv))
            .limit(1)
        )
        result = await session.execute(stmt)
        row = result.scalar_one_or_none()
        if row is None:
            return None
        return {
            "symbol": row.symbol,
            "price": row.price,
            "size": row.size,
            "side": row.side,
            "provider": row.provider,
            "ts_recv": row.ts_recv.isoformat() if row.ts_recv else None,
            "ts_exchange": row.ts_exchange.isoformat() if row.ts_exchange else None,
        }


async def get_latest_quote(symbol: str) -> Optional[dict]:
    async with async_session() as session:
        stmt = (
            select(QuoteRecord)
            .where(QuoteRecord.symbol == symbol.upper())
            .order_by(desc(QuoteRecord.ts_recv))
            .limit(1)
        )
        result = await session.execute(stmt)
        row = result.scalar_one_or_none()
        if row is None:
            return None
        return {
            "symbol": row.symbol,
            "bid": row.bid,
            "ask": row.ask,
            "bid_size": row.bid_size,
            "ask_size": row.ask_size,
            "provider": row.provider,
            "ts_recv": row.ts_recv.isoformat() if row.ts_recv else None,
        }


@@ -0,0 +1,77 @@
"""
SQLAlchemy async models for market data storage.
"""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import DateTime, Float, Index, Integer, String, Text
from sqlalchemy.ext.asyncio import AsyncAttrs, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from app.config import settings
class Base(AsyncAttrs, DeclarativeBase):
    pass


class TradeRecord(Base):
    __tablename__ = "trades"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    provider: Mapped[str] = mapped_column(String(32), nullable=False)
    symbol: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
    price: Mapped[float] = mapped_column(Float, nullable=False)
    size: Mapped[float] = mapped_column(Float, nullable=False)
    side: Mapped[str | None] = mapped_column(String(8), nullable=True)
    trade_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
    ts_exchange: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    ts_recv: Mapped[datetime] = mapped_column(DateTime, nullable=False)

    __table_args__ = (
        Index("ix_trades_symbol_ts", "symbol", "ts_recv"),
    )


class QuoteRecord(Base):
    __tablename__ = "quotes"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    provider: Mapped[str] = mapped_column(String(32), nullable=False)
    symbol: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
    bid: Mapped[float] = mapped_column(Float, nullable=False)
    ask: Mapped[float] = mapped_column(Float, nullable=False)
    bid_size: Mapped[float] = mapped_column(Float, nullable=False)
    ask_size: Mapped[float] = mapped_column(Float, nullable=False)
    ts_exchange: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    ts_recv: Mapped[datetime] = mapped_column(DateTime, nullable=False)

    __table_args__ = (
        Index("ix_quotes_symbol_ts", "symbol", "ts_recv"),
    )


class BookSnapshotRecord(Base):
    __tablename__ = "book_snapshots"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    provider: Mapped[str] = mapped_column(String(32), nullable=False)
    symbol: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
    bids_json: Mapped[str] = mapped_column(Text, nullable=False)  # JSON
    asks_json: Mapped[str] = mapped_column(Text, nullable=False)  # JSON
    depth: Mapped[int] = mapped_column(Integer, nullable=False)
    ts_exchange: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    ts_recv: Mapped[datetime] = mapped_column(DateTime, nullable=False)


# ── Engine & Session factory ──────────────────────────────────────────
engine = create_async_engine(settings.sqlite_url, echo=False)
async_session = async_sessionmaker(engine, expire_on_commit=False)


async def init_db() -> None:
    """Create all tables."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)


@@ -0,0 +1,89 @@
"""
Unified domain events for market data.
All providers normalize raw messages into these canonical types.
Timestamps are always UTC. ts_exchange may be None if the source
doesn't provide exchange timestamps.
"""
from __future__ import annotations
import time
from datetime import datetime, timezone
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class EventType(str, Enum):
    TRADE = "trade"
    QUOTE = "quote"
    BOOK_L2 = "book_l2"
    HEARTBEAT = "heartbeat"


def _utc_now() -> datetime:
    return datetime.now(timezone.utc)


def _mono_ns() -> int:
    """Monotonic nanoseconds for internal latency measurement."""
    return time.monotonic_ns()


class BaseEvent(BaseModel):
    """Common fields for every market-data event."""

    event_type: EventType
    provider: str
    ts_recv: datetime = Field(default_factory=_utc_now)
    ts_recv_mono_ns: int = Field(default_factory=_mono_ns)


class TradeEvent(BaseEvent):
    """A single matched trade (fill)."""

    event_type: EventType = EventType.TRADE
    symbol: str
    price: float
    size: float
    ts_exchange: Optional[datetime] = None
    side: Optional[str] = None  # "buy" | "sell" | None
    trade_id: Optional[str] = None


class QuoteEvent(BaseEvent):
    """Best bid/ask (top-of-book)."""

    event_type: EventType = EventType.QUOTE
    symbol: str
    bid: float
    ask: float
    bid_size: float
    ask_size: float
    ts_exchange: Optional[datetime] = None


class BookLevel(BaseModel):
    price: float
    size: float


class BookL2Event(BaseEvent):
    """L2 order-book snapshot (partial depth)."""

    event_type: EventType = EventType.BOOK_L2
    symbol: str
    bids: list[BookLevel] = Field(default_factory=list)
    asks: list[BookLevel] = Field(default_factory=list)
    ts_exchange: Optional[datetime] = None


class HeartbeatEvent(BaseEvent):
    """Provider heartbeat / keep-alive signal."""

    event_type: EventType = EventType.HEARTBEAT


# Union type for type-safe consumers
Event = TradeEvent | QuoteEvent | BookL2Event | HeartbeatEvent


@@ -0,0 +1,284 @@
"""
Market Data Service — entry point.
CLI:
    python -m app run --provider binance --symbols BTCUSDT,ETHUSDT
    python -m app run --provider alpaca --symbols AAPL,TSLA
    python -m app run --provider all --symbols BTCUSDT,AAPL

HTTP (optional, runs alongside):
    GET /health
    GET /metrics (Prometheus)
    GET /latest?symbol=BTCUSDT
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import signal
import sys
from contextlib import asynccontextmanager
import structlog
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
from app.config import settings
from app.core.bus import EventBus
from app.consumers.metrics import MetricsConsumer
from app.consumers.print import PrintConsumer
from app.consumers.storage import StorageConsumer
from app.db.schema import init_db
from app.db import repo
from app.providers import MarketDataProvider, get_provider
logger = structlog.get_logger()
# ── Logging setup ──────────────────────────────────────────────────────
def setup_logging() -> None:
    log_level = getattr(logging, settings.log_level.upper(), logging.INFO)
    structlog.configure(
        processors=[
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.TimeStamper(fmt="iso"),
            structlog.dev.ConsoleRenderer(),
        ],
        wrapper_class=structlog.make_filtering_bound_logger(log_level),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(),
    )
    logging.basicConfig(level=log_level, format="%(message)s")
# ── HTTP Server (lightweight, no FastAPI dependency) ───────────────────
async def _http_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Minimal HTTP handler for health/metrics/latest."""
    try:
        request_line = await asyncio.wait_for(reader.readline(), timeout=5.0)
        request_str = request_line.decode("utf-8", errors="replace").strip()

        # Parse method + path
        parts = request_str.split()
        if len(parts) < 2:
            writer.close()
            return
        path = parts[1]

        # Consume headers
        while True:
            line = await reader.readline()
            if line in (b"\r\n", b"\n", b""):
                break

        status = "200 OK"
        if path == "/health":
            body = b'{"status":"ok","service":"market-data-service"}'
            content_type = "application/json"
        elif path == "/metrics":
            body = generate_latest()
            content_type = CONTENT_TYPE_LATEST
        elif path.startswith("/latest"):
            # Parse ?symbol=XXX
            symbol = ""
            if "?" in path:
                query = path.split("?", 1)[1]
                for param in query.split("&"):
                    if param.startswith("symbol="):
                        symbol = param.split("=", 1)[1]
            if not symbol:
                status = "400 Bad Request"
                body = b'{"error":"missing ?symbol=XXX"}'
                content_type = "application/json"
            else:
                import json

                trade = await repo.get_latest_trade(symbol)
                quote = await repo.get_latest_quote(symbol)
                result = {
                    "symbol": symbol.upper(),
                    "latest_trade": trade,
                    "latest_quote": quote,
                }
                body = json.dumps(result, ensure_ascii=False).encode()
                content_type = "application/json"
        else:
            status = "404 Not Found"
            body = b'{"error":"not found"}'
            content_type = "application/json"

        response = (
            f"HTTP/1.1 {status}\r\n"
            f"Content-Type: {content_type}\r\n"
            f"Content-Length: {len(body)}\r\n"
            f"Connection: close\r\n"
            f"\r\n"
        )
        writer.write(response.encode() + body)
        await writer.drain()
    except Exception:
        pass
    finally:
        try:
            writer.close()
            await writer.wait_closed()
        except Exception:
            pass


async def start_http_server() -> asyncio.Server:
    server = await asyncio.start_server(
        _http_handler,
        settings.http_host,
        settings.http_port,
    )
    logger.info(
        "http.started",
        host=settings.http_host,
        port=settings.http_port,
        endpoints=["/health", "/metrics", "/latest?symbol=XXX"],
    )
    return server
# ── Provider runner ────────────────────────────────────────────────────
async def run_provider(
    provider: MarketDataProvider,
    symbols: list[str],
    bus: EventBus,
) -> None:
    """Connect, subscribe, and stream events into the bus."""
    await provider.connect()
    await provider.subscribe(symbols)
    logger.info(
        "provider.streaming",
        provider=provider.name,
        symbols=symbols,
    )
    async for event in provider.stream():
        await bus.publish(event)
# ── Main orchestrator ──────────────────────────────────────────────────
async def main(provider_names: list[str], symbols: list[str]) -> None:
    setup_logging()
    logger.info(
        "service.starting",
        providers=provider_names,
        symbols=symbols,
    )

    # Init database
    await init_db()

    # Set up bus + consumers
    bus = EventBus()
    storage = StorageConsumer()
    await storage.start()
    bus.add_consumer(storage)
    metrics = MetricsConsumer()
    bus.add_consumer(metrics)
    printer = PrintConsumer()
    bus.add_consumer(printer)
    await bus.start()

    # Start HTTP server
    http_server = await start_http_server()

    # Create providers
    providers: list[MarketDataProvider] = []
    for name in provider_names:
        providers.append(get_provider(name))

    # Run all providers concurrently
    tasks = []
    for p in providers:
        tasks.append(asyncio.create_task(run_provider(p, symbols, bus)))

    # Graceful shutdown on SIGINT/SIGTERM
    shutdown_event = asyncio.Event()

    def _signal_handler():
        logger.info("service.shutdown_signal")
        shutdown_event.set()

    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, _signal_handler)
        except NotImplementedError:
            pass  # Windows

    # Wait for shutdown
    await shutdown_event.wait()

    # Cleanup
    logger.info("service.shutting_down")
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)
    for p in providers:
        await p.close()
    await bus.stop()
    await storage.stop()
    http_server.close()
    await http_server.wait_closed()
    logger.info("service.stopped")
# ── CLI ────────────────────────────────────────────────────────────────
def cli():
    parser = argparse.ArgumentParser(
        description="Market Data Service for SenpAI trading agent",
    )
    sub = parser.add_subparsers(dest="command")

    run_parser = sub.add_parser("run", help="Start streaming market data")
    run_parser.add_argument(
        "--provider",
        type=str,
        default="binance",
        help="Provider name: binance, alpaca, all (comma-separated)",
    )
    run_parser.add_argument(
        "--symbols",
        type=str,
        required=True,
        help="Comma-separated symbols (e.g. BTCUSDT,ETHUSDT)",
    )
    args = parser.parse_args()

    if args.command == "run":
        symbols = [s.strip() for s in args.symbols.split(",") if s.strip()]
        if args.provider.lower() == "all":
            provider_names = ["binance", "alpaca"]
        else:
            provider_names = [p.strip() for p in args.provider.split(",") if p.strip()]
        asyncio.run(main(provider_names, symbols))
    else:
        parser.print_help()
        sys.exit(1)


if __name__ == "__main__":
    cli()


@@ -0,0 +1,57 @@
"""
Market data provider interface and registry.
To add a new provider:
    1. Create providers/your_provider.py
    2. Subclass MarketDataProvider
    3. Register in the registry inside get_provider() below
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import AsyncIterator
from app.domain.events import Event
class MarketDataProvider(ABC):
    """
    Base class for all market-data feed adapters.

    Lifecycle: connect() → subscribe() → stream() → close()
    """

    name: str = "unknown"

    @abstractmethod
    async def connect(self) -> None:
        """Establish connection to the data source."""

    @abstractmethod
    async def subscribe(self, symbols: list[str]) -> None:
        """Subscribe to symbols. May be called after reconnect."""

    @abstractmethod
    async def stream(self) -> AsyncIterator[Event]:
        """Yield normalized domain events. Must handle reconnect internally."""
        yield  # type: ignore

    @abstractmethod
    async def close(self) -> None:
        """Graceful shutdown."""


def get_provider(name: str) -> MarketDataProvider:
    """Factory: instantiate a provider by name."""
    from app.providers.binance import BinanceProvider
    from app.providers.alpaca import AlpacaProvider

    registry: dict[str, type[MarketDataProvider]] = {
        "binance": BinanceProvider,
        "alpaca": AlpacaProvider,
    }
    cls = registry.get(name.lower())
    if cls is None:
        available = ", ".join(registry.keys())
        raise ValueError(f"Unknown provider '{name}'. Available: {available}")
    return cls()


@@ -0,0 +1,270 @@
"""
Alpaca Markets provider — paper trading + IEX real-time data.
Requires ALPACA_KEY + ALPACA_SECRET in .env for live mode.
Falls back to dry-run mode if keys are not configured.
Subscribes to:
    - trades → TradeEvent
    - quotes → QuoteEvent

Alpaca WebSocket protocol:
    wss://stream.data.alpaca.markets/v2/iex
    Auth → subscribe → stream messages
"""
from __future__ import annotations
import asyncio
import json
import logging
from datetime import datetime, timezone
from typing import AsyncIterator
import websockets
from websockets.exceptions import ConnectionClosed
from app.config import settings
from app.domain.events import (
Event,
HeartbeatEvent,
QuoteEvent,
TradeEvent,
)
from app.providers import MarketDataProvider
logger = logging.getLogger(__name__)
def _iso_to_dt(ts_str: str | None) -> datetime | None:
"""Parse Alpaca ISO-8601 timestamp to UTC datetime."""
if not ts_str:
return None
try:
# Alpaca uses RFC3339 with Z or +00:00
ts_str = ts_str.replace("Z", "+00:00")
return datetime.fromisoformat(ts_str)
except (ValueError, TypeError):
return None
class AlpacaProvider(MarketDataProvider):
"""
Alpaca IEX real-time data + paper trading integration.
In dry-run mode (no keys), generates synthetic heartbeats
and logs a warning — useful for testing the pipeline without keys.
"""
name = "alpaca"
def __init__(self) -> None:
self._ws: websockets.WebSocketClientProtocol | None = None
self._symbols: list[str] = []
self._connected = False
self._authenticated = False
self._reconnect_count = 0
self._dry_run = not settings.alpaca_configured or settings.alpaca_dry_run
async def connect(self) -> None:
"""Establish WebSocket connection and authenticate."""
if self._dry_run:
logger.warning(
"alpaca.dry_run_mode",
extra={"reason": "No ALPACA_KEY/ALPACA_SECRET or dry_run=True"},
)
self._connected = True
return
url = settings.alpaca_data_ws_url
logger.info("alpaca.connecting", extra={"url": url})
self._ws = await websockets.connect(
url,
ping_interval=20,
ping_timeout=10,
close_timeout=5,
)
# Read welcome message
welcome = await self._ws.recv()
welcome_data = json.loads(welcome)
logger.info("alpaca.welcome", extra={"msg": welcome_data})
# Authenticate
auth_msg = {
"action": "auth",
"key": settings.alpaca_key,
"secret": settings.alpaca_secret,
}
await self._ws.send(json.dumps(auth_msg))
auth_resp = await self._ws.recv()
auth_data = json.loads(auth_resp)
logger.info("alpaca.auth_response", extra={"msg": auth_data})
# Check auth result
if isinstance(auth_data, list):
for msg in auth_data:
if msg.get("T") == "error":
raise ConnectionError(f"Alpaca auth failed: {msg}")
if msg.get("T") == "success" and msg.get("msg") == "authenticated":
self._authenticated = True
self._connected = True
logger.info("alpaca.connected", extra={"authenticated": self._authenticated})
async def subscribe(self, symbols: list[str]) -> None:
"""Subscribe to trades + quotes for symbols."""
self._symbols = [s.upper() for s in symbols]
if self._dry_run:
logger.info(
"alpaca.dry_run_subscribe",
extra={"symbols": self._symbols},
)
return
if not self._ws:
            raise RuntimeError("Not connected. Call connect() first.")
sub_msg = {
"action": "subscribe",
"trades": self._symbols,
"quotes": self._symbols,
}
await self._ws.send(json.dumps(sub_msg))
# Read subscription confirmation
sub_resp = await self._ws.recv()
logger.info("alpaca.subscribed", extra={"response": json.loads(sub_resp)})
async def stream(self) -> AsyncIterator[Event]:
"""Yield domain events. Dry-run mode emits periodic heartbeats."""
if self._dry_run:
async for event in self._dry_run_stream():
yield event
return
backoff = settings.reconnect_base_delay
while True:
try:
if not self._connected or not self._ws:
await self._reconnect(backoff)
try:
raw = await asyncio.wait_for(
self._ws.recv(), # type: ignore
timeout=settings.heartbeat_timeout,
)
except asyncio.TimeoutError:
logger.warning("alpaca.heartbeat_timeout")
self._connected = False
continue
backoff = settings.reconnect_base_delay
messages = json.loads(raw)
# Alpaca sends arrays of messages
if not isinstance(messages, list):
messages = [messages]
for msg in messages:
event = self._parse(msg)
if event:
yield event
except ConnectionClosed as e:
logger.warning(
"alpaca.connection_closed",
extra={"code": e.code, "reason": str(e.reason)},
)
self._connected = False
backoff = min(backoff * 2, settings.reconnect_max_delay)
except Exception as e:
logger.error("alpaca.stream_error", extra={"error": str(e)})
self._connected = False
backoff = min(backoff * 2, settings.reconnect_max_delay)
async def _dry_run_stream(self) -> AsyncIterator[Event]:
"""Emit heartbeats in dry-run mode (no real data)."""
logger.info("alpaca.dry_run_stream_started")
while True:
yield HeartbeatEvent(provider=self.name)
await asyncio.sleep(5.0)
async def _reconnect(self, delay: float) -> None:
self._reconnect_count += 1
logger.info(
"alpaca.reconnecting",
extra={"delay": delay, "attempt": self._reconnect_count},
)
await asyncio.sleep(delay)
try:
if self._ws:
await self._ws.close()
except Exception:
pass
self._authenticated = False
await self.connect()
if self._symbols:
await self.subscribe(self._symbols)
def _parse(self, msg: dict) -> Event | None:
"""Parse single Alpaca message into domain event."""
msg_type = msg.get("T")
if msg_type == "t":
return self._parse_trade(msg)
elif msg_type == "q":
return self._parse_quote(msg)
elif msg_type in ("success", "subscription", "error"):
# Control messages — skip
return None
return None
def _parse_trade(self, data: dict) -> TradeEvent:
"""
Alpaca trade:
{"T":"t", "S":"AAPL", "p":150.25, "s":100, "t":"2024-01-15T...", "i":12345, ...}
"""
return TradeEvent(
provider=self.name,
symbol=data.get("S", "").upper(),
price=float(data.get("p", 0)),
size=float(data.get("s", 0)),
ts_exchange=_iso_to_dt(data.get("t")),
trade_id=str(data.get("i", "")),
)
def _parse_quote(self, data: dict) -> QuoteEvent:
"""
Alpaca quote:
{"T":"q", "S":"AAPL", "bp":150.24, "bs":200, "ap":150.26, "as":100,
"t":"2024-01-15T...", ...}
"""
return QuoteEvent(
provider=self.name,
symbol=data.get("S", "").upper(),
bid=float(data.get("bp", 0)),
ask=float(data.get("ap", 0)),
bid_size=float(data.get("bs", 0)),
ask_size=float(data.get("as", 0)),
ts_exchange=_iso_to_dt(data.get("t")),
)
async def close(self) -> None:
self._connected = False
if self._ws:
try:
await self._ws.close()
except Exception:
pass
logger.info(
"alpaca.closed",
extra={"reconnect_count": self._reconnect_count},
)

@@ -0,0 +1,223 @@
"""
Binance public WebSocket provider.
No API key required. Subscribes to:
- <symbol>@trade → TradeEvent
- <symbol>@bookTicker → QuoteEvent
Auto-reconnect with exponential backoff via tenacity.
"""
from __future__ import annotations
import asyncio
import json
import logging
import time
from datetime import datetime, timezone
from typing import AsyncIterator
import websockets
from websockets.exceptions import ConnectionClosed
from app.config import settings
from app.domain.events import (
Event,
HeartbeatEvent,
QuoteEvent,
TradeEvent,
)
from app.providers import MarketDataProvider
logger = logging.getLogger(__name__)
def _ms_to_dt(ms: int | float | None) -> datetime | None:
"""Convert millisecond epoch to UTC datetime."""
if ms is None:
return None
return datetime.fromtimestamp(ms / 1000.0, tz=timezone.utc)
class BinanceProvider(MarketDataProvider):
"""
Binance public WebSocket streams.
Connects to the combined stream endpoint and subscribes to
trade + bookTicker channels for each symbol.
"""
name = "binance"
def __init__(self) -> None:
self._ws: websockets.WebSocketClientProtocol | None = None
self._symbols: list[str] = []
self._connected = False
self._reconnect_count = 0
self._base_url = settings.binance_ws_url
async def connect(self) -> None:
"""Establish WebSocket connection."""
logger.info("binance.connecting", extra={"url": self._base_url})
self._ws = await websockets.connect(
self._base_url,
ping_interval=20,
ping_timeout=10,
close_timeout=5,
)
self._connected = True
logger.info("binance.connected")
async def subscribe(self, symbols: list[str]) -> None:
"""Subscribe to trade + bookTicker for each symbol."""
if not self._ws:
raise RuntimeError("Not connected. Call connect() first.")
self._symbols = [s.lower() for s in symbols]
streams = []
for sym in self._symbols:
streams.append(f"{sym}@trade")
streams.append(f"{sym}@bookTicker")
subscribe_msg = {
"method": "SUBSCRIBE",
"params": streams,
"id": 1,
}
await self._ws.send(json.dumps(subscribe_msg))
logger.info(
"binance.subscribed",
extra={"symbols": self._symbols, "streams": len(streams)},
)
async def stream(self) -> AsyncIterator[Event]:
"""
Yield domain events. Handles reconnect automatically.
"""
backoff = settings.reconnect_base_delay
while True:
try:
if not self._connected or not self._ws:
await self._reconnect(backoff)
# Set timeout for heartbeat detection
try:
raw = await asyncio.wait_for(
self._ws.recv(), # type: ignore
timeout=settings.heartbeat_timeout,
)
except asyncio.TimeoutError:
logger.warning(
"binance.heartbeat_timeout",
extra={"timeout": settings.heartbeat_timeout},
)
self._connected = False
continue
# Reset backoff on successful message
backoff = settings.reconnect_base_delay
data = json.loads(raw)
# Skip subscription confirmations
if "result" in data and "id" in data:
continue
event = self._parse(data)
if event:
yield event
except ConnectionClosed as e:
logger.warning(
"binance.connection_closed",
extra={"code": e.code, "reason": str(e.reason)},
)
self._connected = False
backoff = min(backoff * 2, settings.reconnect_max_delay)
except Exception as e:
logger.error("binance.stream_error", extra={"error": str(e)})
self._connected = False
backoff = min(backoff * 2, settings.reconnect_max_delay)
async def _reconnect(self, delay: float) -> None:
"""Reconnect with delay, then resubscribe."""
self._reconnect_count += 1
logger.info(
"binance.reconnecting",
extra={"delay": delay, "attempt": self._reconnect_count},
)
await asyncio.sleep(delay)
try:
if self._ws:
await self._ws.close()
except Exception:
pass
await self.connect()
if self._symbols:
await self.subscribe(self._symbols)
def _parse(self, data: dict) -> Event | None:
"""Parse raw Binance JSON into domain events."""
event_type = data.get("e")
if event_type == "trade":
return self._parse_trade(data)
        elif event_type == "bookTicker" or (
            "b" in data and "a" in data and "s" in data and "e" not in data
        ):
# bookTicker doesn't always have "e" field in combined stream
return self._parse_book_ticker(data)
return None
def _parse_trade(self, data: dict) -> TradeEvent:
"""
Binance trade payload:
{
"e": "trade", "E": 1672515782136, "s": "BNBBTC",
"t": 12345, "p": "0.001", "q": "100",
"T": 1672515782136, "m": true
}
"""
return TradeEvent(
provider=self.name,
symbol=data.get("s", "").upper(),
price=float(data.get("p", 0)),
size=float(data.get("q", 0)),
ts_exchange=_ms_to_dt(data.get("T")),
side="sell" if data.get("m") else "buy", # m=True → buyer is maker → trade is a sell
trade_id=str(data.get("t", "")),
)
def _parse_book_ticker(self, data: dict) -> QuoteEvent:
"""
Binance bookTicker payload:
{
"u": 400900217, "s": "BNBUSDT",
"b": "25.35190000", "B": "31.21000000",
"a": "25.36520000", "A": "40.66000000"
}
"""
return QuoteEvent(
provider=self.name,
symbol=data.get("s", "").upper(),
bid=float(data.get("b", 0)),
ask=float(data.get("a", 0)),
bid_size=float(data.get("B", 0)),
ask_size=float(data.get("A", 0)),
            ts_exchange=_ms_to_dt(data.get("E")),  # spot bookTicker carries no event time, so this is usually None
)
async def close(self) -> None:
"""Close the WebSocket connection."""
self._connected = False
if self._ws:
try:
await self._ws.close()
except Exception:
pass
logger.info(
"binance.closed",
extra={"reconnect_count": self._reconnect_count},
)

@@ -0,0 +1,13 @@
[project]
name = "market-data-service"
version = "0.1.0"
description = "Real-time market data collection for SenpAI trading agent"
requires-python = ">=3.11"
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
[tool.ruff]
target-version = "py311"
line-length = 100

@@ -0,0 +1,30 @@
# Market Data Service for SenpAI
# Python 3.11+
# Core
pydantic>=2.5
pydantic-settings>=2.1
# Async
websockets>=12.0
httpx>=0.27
# Database
sqlalchemy[asyncio]>=2.0
aiosqlite>=0.20
# Reliability
tenacity>=8.2
# Logging
structlog>=24.1
# Metrics
prometheus_client>=0.20
# Testing
pytest>=8.0
pytest-asyncio>=0.23
# Linting (optional)
ruff>=0.3

@@ -0,0 +1,107 @@
"""
Unit tests for Alpaca raw → domain event parsing.
"""
import pytest
from app.providers.alpaca import AlpacaProvider
from app.domain.events import EventType
@pytest.fixture
def provider():
return AlpacaProvider()
# ── Trade parsing ──────────────────────────────────────────────────────
ALPACA_TRADE_RAW = {
"T": "t",
"S": "AAPL",
"p": 185.50,
"s": 100,
"t": "2024-01-15T14:30:00.123456Z",
"i": 12345,
"x": "V",
"z": "C",
}
def test_parse_trade_basic(provider):
event = provider._parse(ALPACA_TRADE_RAW)
assert event is not None
assert event.event_type == EventType.TRADE
assert event.symbol == "AAPL"
assert event.price == 185.50
assert event.size == 100
assert event.trade_id == "12345"
assert event.provider == "alpaca"
def test_parse_trade_timestamp(provider):
event = provider._parse(ALPACA_TRADE_RAW)
assert event.ts_exchange is not None
assert event.ts_exchange.year == 2024
assert event.ts_exchange.month == 1
# ── Quote parsing ──────────────────────────────────────────────────────
ALPACA_QUOTE_RAW = {
"T": "q",
"S": "TSLA",
"bp": 250.10,
"bs": 200,
"ap": 250.25,
"as": 150,
"t": "2024-01-15T14:30:01.456789Z",
"x": "V",
"z": "C",
}
def test_parse_quote_basic(provider):
event = provider._parse(ALPACA_QUOTE_RAW)
assert event is not None
assert event.event_type == EventType.QUOTE
assert event.symbol == "TSLA"
assert event.bid == 250.10
assert event.ask == 250.25
assert event.bid_size == 200
assert event.ask_size == 150
assert event.provider == "alpaca"
# ── Control messages ───────────────────────────────────────────────────
def test_parse_success_message(provider):
raw = {"T": "success", "msg": "connected"}
event = provider._parse(raw)
assert event is None
def test_parse_subscription_message(provider):
raw = {"T": "subscription", "trades": ["AAPL"], "quotes": ["AAPL"]}
event = provider._parse(raw)
assert event is None
def test_parse_error_message(provider):
raw = {"T": "error", "code": 402, "msg": "auth failed"}
event = provider._parse(raw)
assert event is None
# ── Edge cases ─────────────────────────────────────────────────────────
def test_parse_trade_missing_timestamp(provider):
raw = {"T": "t", "S": "AAPL", "p": 100, "s": 10, "i": 1}
event = provider._parse(raw)
assert event is not None
assert event.ts_exchange is None
def test_parse_quote_zero_values(provider):
raw = {"T": "q", "S": "SPY", "bp": 0, "bs": 0, "ap": 0, "as": 0}
event = provider._parse(raw)
assert event is not None
assert event.bid == 0
assert event.ask == 0

@@ -0,0 +1,97 @@
"""
Unit tests for Binance raw → domain event parsing.
"""
import pytest
from app.providers.binance import BinanceProvider
from app.domain.events import EventType
@pytest.fixture
def provider():
return BinanceProvider()
# ── Trade parsing ──────────────────────────────────────────────────────
BINANCE_TRADE_RAW = {
"e": "trade",
"E": 1672515782136,
"s": "BTCUSDT",
"t": 123456789,
"p": "42500.50",
"q": "0.015",
"T": 1672515782135,
"m": True,
}
def test_parse_trade_basic(provider):
event = provider._parse(BINANCE_TRADE_RAW)
assert event is not None
assert event.event_type == EventType.TRADE
assert event.symbol == "BTCUSDT"
assert event.price == 42500.50
assert event.size == 0.015
    assert event.side == "sell"  # m=True → buyer is maker → taker sold
assert event.trade_id == "123456789"
assert event.provider == "binance"
def test_parse_trade_ts_exchange(provider):
event = provider._parse(BINANCE_TRADE_RAW)
assert event.ts_exchange is not None
    # 1672515782135 ms → 2022-12-31T19:43:02.135 UTC (deterministic: _ms_to_dt pins tz=UTC)
    assert event.ts_exchange.year == 2022
def test_parse_trade_buy_side(provider):
raw = {**BINANCE_TRADE_RAW, "m": False}
event = provider._parse(raw)
assert event.side == "buy"
# ── BookTicker (Quote) parsing ─────────────────────────────────────────
BINANCE_BOOKTICKER_RAW = {
"u": 400900217,
"s": "ETHUSDT",
"b": "2150.25000000",
"B": "31.21000000",
"a": "2150.50000000",
"A": "40.66000000",
}
def test_parse_bookticker(provider):
event = provider._parse(BINANCE_BOOKTICKER_RAW)
assert event is not None
assert event.event_type == EventType.QUOTE
assert event.symbol == "ETHUSDT"
assert event.bid == 2150.25
assert event.ask == 2150.50
assert event.bid_size == 31.21
assert event.ask_size == 40.66
assert event.provider == "binance"
# ── Edge cases ─────────────────────────────────────────────────────────
def test_parse_unknown_event(provider):
raw = {"e": "aggTrade", "s": "BTCUSDT", "p": "100"}
event = provider._parse(raw)
assert event is None
def test_parse_subscription_confirmation(provider):
raw = {"result": None, "id": 1}
# This is handled in stream(), not _parse(), so _parse should return None
event = provider._parse(raw)
assert event is None
def test_parse_empty_values(provider):
raw = {"e": "trade", "s": "", "p": "0", "q": "0", "T": None, "m": True, "t": ""}
event = provider._parse(raw)
assert event is not None
assert event.price == 0.0
assert event.symbol == ""

@@ -0,0 +1,123 @@
"""
Smoke test for the async event bus.
"""
import asyncio
import pytest
from app.core.bus import EventBus
from app.domain.events import TradeEvent, HeartbeatEvent, EventType
class MockConsumer:
"""Test consumer that collects events."""
def __init__(self):
self.events: list = []
async def handle(self, event):
self.events.append(event)
class FailingConsumer:
"""Consumer that always raises."""
async def handle(self, event):
raise ValueError("I always fail")
@pytest.mark.asyncio
async def test_bus_fanout():
"""Events are delivered to all consumers."""
bus = EventBus()
c1 = MockConsumer()
c2 = MockConsumer()
bus.add_consumer(c1)
bus.add_consumer(c2)
await bus.start()
event = TradeEvent(
provider="test",
symbol="BTCUSDT",
price=42000.0,
size=1.5,
)
await bus.publish(event)
# Give worker time to process
await asyncio.sleep(0.1)
await bus.stop()
assert len(c1.events) == 1
assert len(c2.events) == 1
assert c1.events[0].symbol == "BTCUSDT"
assert c2.events[0].price == 42000.0
@pytest.mark.asyncio
async def test_bus_failing_consumer_doesnt_block():
"""A failing consumer doesn't prevent others from receiving events."""
bus = EventBus()
good = MockConsumer()
bad = FailingConsumer()
bus.add_consumer(bad)
bus.add_consumer(good)
await bus.start()
await bus.publish(HeartbeatEvent(provider="test"))
await asyncio.sleep(0.1)
await bus.stop()
assert len(good.events) == 1
assert good.events[0].event_type == EventType.HEARTBEAT
@pytest.mark.asyncio
async def test_bus_multiple_events():
"""Multiple events are delivered in order."""
bus = EventBus()
consumer = MockConsumer()
bus.add_consumer(consumer)
await bus.start()
for i in range(10):
await bus.publish(
TradeEvent(
provider="test",
symbol=f"SYM{i}",
price=float(i),
size=1.0,
)
)
await asyncio.sleep(0.2)
await bus.stop()
assert len(consumer.events) == 10
symbols = [e.symbol for e in consumer.events]
assert symbols == [f"SYM{i}" for i in range(10)]
@pytest.mark.asyncio
async def test_bus_queue_overflow():
"""Bus handles queue overflow without crashing."""
bus = EventBus(queue_size=3)
consumer = MockConsumer()
bus.add_consumer(consumer)
# Don't start worker — queue will fill up
for i in range(10):
await bus.publish(
HeartbeatEvent(provider="test")
)
# Should not raise
await bus.start()
await asyncio.sleep(0.1)
await bus.stop()
# Some events were dropped, but consumer got the ones that fit
assert len(consumer.events) >= 1