Producer (market-data-service):
- Backpressure: smart drop policy (heartbeats→quotes→trades preserved)
- Heartbeat monitor: synthetic HeartbeatEvent on provider silence
- Graceful shutdown: WS→bus→storage→DB engine cleanup sequence
- Bybit V5 public WS provider (backup for Binance, no API key needed)
- FailoverManager: health-based provider switching with recovery
- NATS output adapter: md.events.{type}.{symbol} for SenpAI
- /bus-stats endpoint for backpressure monitoring
- Dockerfile + docker-compose.node1.yml integration
- 36 tests (parsing + bus + failover), requirements.lock
Consumer (senpai-md-consumer):
- NATSConsumer: subscribe md.events.>, queue group senpai-md, backpressure
- State store: LatestState + RollingWindow (deque, 60s)
- Feature engine: 11 features (mid, spread, VWAP, return, vol, latency)
- Rule-based signals: long/short on return+volume+spread conditions
- Publisher: rate-limited features + signals + alerts to NATS
- HTTP API: /health, /metrics, /state/latest, /features/latest, /stats
- 10 Prometheus metrics
- Dockerfile + docker-compose.senpai.yml
- 41 tests (parsing + state + features + rate-limit), requirements.lock
CI: ruff + pytest + smoke import for both services
Tests: 77 total passed, lint clean
Co-authored-by: Cursor <cursoragent@cursor.com>
340 lines
10 KiB
Python
340 lines
10 KiB
Python
"""
|
|
Market Data Service — entry point.
|
|
|
|
CLI:
|
|
python -m app run --provider binance --symbols BTCUSDT,ETHUSDT
|
|
python -m app run --provider alpaca --symbols AAPL,TSLA
|
|
python -m app run --provider all --symbols BTCUSDT,AAPL
|
|
|
|
HTTP (optional, runs alongside):
|
|
GET /health
|
|
GET /metrics (Prometheus)
|
|
GET /latest?symbol=BTCUSDT
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import logging
|
|
import signal
|
|
import sys
|
|
|
|
import structlog
|
|
from prometheus_client import generate_latest, CONTENT_TYPE_LATEST
|
|
|
|
from app.config import settings
|
|
from app.core.bus import EventBus
|
|
from app.consumers.metrics import MetricsConsumer
|
|
from app.consumers.nats_output import NatsOutputConsumer
|
|
from app.consumers.print import PrintConsumer
|
|
from app.consumers.storage import StorageConsumer
|
|
from app.db.schema import engine, init_db
|
|
from app.db import repo
|
|
from app.providers import MarketDataProvider, get_provider
|
|
|
|
# Module-wide structlog logger used by every function in this file.
logger = structlog.get_logger()

# Reference to the running EventBus. main() assigns it once the bus has been
# built so the HTTP /bus-stats route can report queue statistics; it stays
# None until then.
_bus: EventBus | None = None
|
|
|
|
# ── Logging setup ──────────────────────────────────────────────────────
|
|
|
|
|
|
def setup_logging() -> None:
    """Configure structlog and the stdlib root logger from ``settings.log_level``.

    An unrecognised level name falls back to ``logging.INFO``.
    """
    level = getattr(logging, settings.log_level.upper(), logging.INFO)

    # Processor chain, applied in order to every structlog event.
    pipeline = [
        structlog.contextvars.merge_contextvars,
        structlog.processors.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.dev.ConsoleRenderer(),
    ]

    structlog.configure(
        processors=pipeline,
        wrapper_class=structlog.make_filtering_bound_logger(level),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(),
    )

    # Stdlib loggers (third-party libraries) share the same threshold.
    logging.basicConfig(level=level, format="%(message)s")
|
|
|
|
|
|
# ── HTTP Server (lightweight, no FastAPI dependency) ───────────────────
|
|
|
|
|
|
async def _build_http_response(target: str) -> tuple[int, str, bytes]:
    """Map a request target to ``(status_code, content_type, body)``.

    Routing uses only the path component of *target*, so a query string does
    not stop ``/health``, ``/metrics`` or ``/bus-stats`` from matching.

    Returns 400 for ``/latest`` without a symbol and 404 for unknown paths.
    Exceptions from ``repo`` or JSON serialisation propagate to the caller.
    """
    import json
    from urllib.parse import parse_qs, urlsplit

    json_type = "application/json"
    split = urlsplit(target)
    route = split.path

    if route == "/health":
        return 200, json_type, b'{"status":"ok","service":"market-data-service"}'

    if route == "/metrics":
        return 200, CONTENT_TYPE_LATEST, generate_latest()

    if route.rstrip("/") == "/latest":
        # parse_qs percent-decodes values and drops blank ones; as with the
        # previous hand-rolled loop, the last ``symbol=`` occurrence wins.
        symbol = parse_qs(split.query).get("symbol", [""])[-1]
        if not symbol:
            return 400, json_type, b'{"error":"missing ?symbol=XXX"}'

        trade = await repo.get_latest_trade(symbol)
        quote = await repo.get_latest_quote(symbol)
        payload = {
            "symbol": symbol.upper(),
            "latest_trade": trade,
            "latest_quote": quote,
        }
        return 200, json_type, json.dumps(payload, ensure_ascii=False).encode()

    if route == "/bus-stats":
        # Zeroed placeholder until main() has published the bus.
        bus_info = {"queue_size": 0, "fill_percent": 0.0}
        if _bus:
            bus_info = {
                "queue_size": _bus.queue_size,
                "fill_percent": round(_bus.fill_percent * 100, 1),
                "max_size": _bus._max_size,
            }
        return 200, json_type, json.dumps(bus_info).encode()

    return 404, json_type, b'{"error":"not found"}'


async def _http_handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
    """Serve a single HTTP request on a raw asyncio stream, then close it.

    Handles ``/health``, ``/metrics``, ``/latest?symbol=XXX`` and
    ``/bus-stats``. The response carries a real status code (200 / 400 /
    404 / 500) instead of always ``200 OK``.

    Args:
        reader: Stream the request line and headers are read from.
        writer: Stream the response is written to; always closed on exit.

    Reading the request line and draining the headers are each bounded by a
    5 second timeout so a stalled client cannot hold the connection open.
    Failures are logged rather than silently discarded; nothing is raised to
    the asyncio server.
    """
    from http import HTTPStatus

    io_timeout = 5.0

    async def _skip_headers() -> None:
        # Headers are not used; read up to the blank line (or EOF).
        while (await reader.readline()) not in (b"\r\n", b"\n", b""):
            pass

    try:
        request_line = await asyncio.wait_for(reader.readline(), timeout=io_timeout)
        parts = request_line.decode("utf-8", errors="replace").split()
        if len(parts) < 2:
            # Malformed or empty request line: nothing to answer; the
            # ``finally`` clause closes the connection.
            return
        target = parts[1]

        await asyncio.wait_for(_skip_headers(), timeout=io_timeout)

        try:
            status, content_type, body = await _build_http_response(target)
        except Exception:
            # Tell the client something went wrong instead of dropping the
            # connection without a response.
            logger.exception("http.handler_error", path=target)
            status = 500
            content_type = "application/json"
            body = b'{"error":"internal error"}'

        head = (
            f"HTTP/1.1 {status} {HTTPStatus(status).phrase}\r\n"
            f"Content-Type: {content_type}\r\n"
            f"Content-Length: {len(body)}\r\n"
            f"Connection: close\r\n"
            f"\r\n"
        )
        writer.write(head.encode() + body)
        await writer.drain()
    except (asyncio.TimeoutError, ConnectionError) as exc:
        # Slow or vanished client: expected noise, keep it at debug level.
        logger.debug("http.client_dropped", error=repr(exc))
    except Exception:
        logger.exception("http.connection_error")
    finally:
        try:
            writer.close()
            await writer.wait_closed()
        except Exception:
            # Best-effort close; the peer may already be gone.
            pass
|
|
|
|
|
|
async def start_http_server() -> asyncio.Server:
    """Start the minimal HTTP server and log the endpoints it serves.

    Returns:
        The started ``asyncio.Server`` bound to ``settings.http_host`` and
        ``settings.http_port``. The caller is responsible for closing it.
    """
    server = await asyncio.start_server(
        _http_handler,
        settings.http_host,
        settings.http_port,
    )
    logger.info(
        "http.started",
        host=settings.http_host,
        port=settings.http_port,
        # Mirrors the routes handled by _http_handler; /bus-stats was
        # previously missing from this list.
        endpoints=["/health", "/metrics", "/latest?symbol=XXX", "/bus-stats"],
    )
    return server
|
|
|
|
|
|
# ── Provider runner ────────────────────────────────────────────────────
|
|
|
|
|
|
async def run_provider(
    provider: MarketDataProvider,
    symbols: list[str],
    bus: EventBus,
) -> None:
    """Connect, subscribe, and stream a provider's events into the bus.

    Args:
        provider: Provider to connect and read events from.
        symbols: Symbols passed to ``provider.subscribe``.
        bus: Event bus every streamed event is published to.

    Raises:
        Exception: Whatever the provider or bus raises, re-raised after it
            has been logged. Cancellation propagates untouched.
    """
    try:
        await provider.connect()
        await provider.subscribe(symbols)

        logger.info(
            "provider.streaming",
            provider=provider.name,
            symbols=symbols,
        )

        async for event in provider.stream():
            await bus.publish(event)
    except Exception:
        # This coroutine runs as a background task whose result is only
        # collected at shutdown, so log here or the failure goes unnoticed.
        # CancelledError is a BaseException and is not caught.
        logger.exception("provider.failed", provider=provider.name)
        raise

    # The stream finished without an error; surface that as well, since no
    # more events will arrive from this provider.
    logger.warning("provider.stream_ended", provider=provider.name)
|
|
|
|
|
|
# ── Main orchestrator ──────────────────────────────────────────────────
|
|
|
|
|
|
async def main(provider_names: list[str], symbols: list[str]) -> None:
    """Run the service until SIGINT/SIGTERM, then shut it down in order.

    Startup: logging, database, event bus with its consumers (storage,
    metrics, print, optionally NATS), providers, HTTP server, and one
    streaming task per provider.

    Shutdown: provider tasks -> provider connections -> bus -> storage ->
    NATS output -> HTTP server -> database engine.

    Args:
        provider_names: Names resolved through ``get_provider``.
        symbols: Symbols every provider subscribes to.
    """
    global _bus

    setup_logging()
    logger.info(
        "service.starting",
        providers=provider_names,
        symbols=symbols,
    )

    # Init database
    await init_db()

    # Setup bus + consumers (heartbeat interval from config)
    bus = EventBus(
        queue_size=10_000,
        heartbeat_interval=settings.heartbeat_timeout / 2,  # check twice per timeout
    )

    storage = StorageConsumer()
    await storage.start()
    bus.add_consumer(storage)

    bus.add_consumer(MetricsConsumer())
    bus.add_consumer(PrintConsumer())

    # Optional: NATS output adapter
    nats_consumer = None
    if settings.nats_configured:
        nats_consumer = NatsOutputConsumer()
        await nats_consumer.start()
        bus.add_consumer(nats_consumer)
        logger.info("nats_output.enabled", subject_prefix=settings.nats_subject_prefix)
    else:
        logger.info("nats_output.disabled", hint="Set NATS_URL + NATS_ENABLED=true to enable")

    # Create providers and register them for heartbeat monitoring
    providers: list[MarketDataProvider] = []
    for name in provider_names:
        p = get_provider(name)
        providers.append(p)
        bus.register_provider(p.name)

    # Publish the bus for the /bus-stats HTTP route before starting it.
    _bus = bus
    await bus.start()

    # Start HTTP server
    http_server = await start_http_server()

    # Run all providers concurrently
    tasks = [asyncio.create_task(run_provider(p, symbols, bus)) for p in providers]

    # Graceful shutdown on SIGINT/SIGTERM
    shutdown_event = asyncio.Event()

    def _signal_handler():
        logger.info("service.shutdown_signal")
        shutdown_event.set()

    # We are inside a coroutine, so the running loop is the right one to use.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, _signal_handler)
        except NotImplementedError:
            pass  # Windows

    # Wait for shutdown
    await shutdown_event.wait()

    # ── Graceful shutdown sequence ──────────────────────────────────────
    logger.info("service.shutting_down")

    # 1. Cancel provider streaming tasks (with timeout).
    #    asyncio.wait() raises ValueError on an empty collection, so skip
    #    this step when no provider task was created.
    for task in tasks:
        task.cancel()
    if tasks:
        _done, pending = await asyncio.wait(tasks, timeout=5.0)
        for task in pending:
            logger.warning("service.task_force_cancel", task=task.get_name())
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)

    # 2. Close provider WebSocket connections
    for p in providers:
        try:
            await p.close()
        except Exception as e:
            logger.warning("service.provider_close_error", provider=p.name, error=str(e))

    # 3. Stop bus (drains remaining events to consumers)
    await bus.stop()

    # 4. Stop storage consumer (flush JSONL)
    await storage.stop()

    # 4b. Stop NATS output (flush + close)
    if nats_consumer:
        await nats_consumer.stop()

    # 5. Close HTTP server
    http_server.close()
    await http_server.wait_closed()

    # 6. Close SQLAlchemy engine (flush connections)
    await engine.dispose()

    logger.info("service.stopped", exit="clean")
|
|
|
|
|
|
# ── CLI ────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def cli():
    """Parse command-line arguments and run the requested command.

    ``run`` starts ``main`` under ``asyncio.run``. Any other invocation
    prints the help text and exits with status 1.

    ``--symbols`` and ``--provider`` are comma-separated lists. A list that
    is empty after stripping (for example ``","``) is rejected with an
    argparse usage error (exit status 2) rather than starting a service
    with nothing to stream. ``--provider all`` expands to every built-in
    provider.
    """
    parser = argparse.ArgumentParser(
        description="Market Data Service for SenpAI trading agent",
    )
    sub = parser.add_subparsers(dest="command")

    run_parser = sub.add_parser("run", help="Start streaming market data")
    run_parser.add_argument(
        "--provider",
        type=str,
        default="binance",
        help="Provider name: binance, alpaca, bybit, or all (comma-separated)",
    )
    run_parser.add_argument(
        "--symbols",
        type=str,
        required=True,
        help="Comma-separated symbols (e.g. BTCUSDT,ETHUSDT)",
    )

    args = parser.parse_args()

    if args.command != "run":
        parser.print_help()
        sys.exit(1)

    symbols = [s.strip() for s in args.symbols.split(",") if s.strip()]
    if not symbols:
        run_parser.error("--symbols must list at least one symbol")

    if args.provider.strip().lower() == "all":
        provider_names = ["binance", "alpaca", "bybit"]
    else:
        provider_names = [p.strip() for p in args.provider.split(",") if p.strip()]
    if not provider_names:
        run_parser.error("--provider must list at least one provider")

    asyncio.run(main(provider_names, symbols))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
cli()
|