"""
|
|
SenpAI Market-Data Consumer — entry point.
|
|
|
|
Orchestrates:
|
|
1. NATS subscription (md.events.>)
|
|
2. Event processing → state updates → feature computation
|
|
3. Feature/signal/alert publishing back to NATS
|
|
4. HTTP API for monitoring
|
|
|
|
Usage:
|
|
python -m senpai.md_consumer
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import signal
|
|
import time
|
|
|
|
import structlog
|
|
|
|
from senpai.md_consumer import api
|
|
from senpai.md_consumer import metrics as m
|
|
from senpai.md_consumer.config import settings
|
|
from senpai.md_consumer.features import (
|
|
check_signal,
|
|
make_feature_snapshot,
|
|
compute_features,
|
|
)
|
|
from senpai.md_consumer.models import (
|
|
AlertEvent,
|
|
EventType,
|
|
TradeEvent,
|
|
QuoteEvent,
|
|
)
|
|
from senpai.md_consumer.nats_consumer import NATSConsumer
|
|
from senpai.md_consumer.publisher import Publisher
|
|
from senpai.md_consumer.state import LatestState
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
|
|
# ── Logging setup ──────────────────────────────────────────────────────
|
|
|
|
|
|
def setup_logging() -> None:
    """Configure structlog and stdlib logging from ``settings.log_level``.

    Unknown level names fall back to INFO. Output is rendered by
    structlog's console renderer with ISO timestamps.
    """
    level_name = settings.log_level.upper()
    log_level = getattr(logging, level_name, logging.INFO)

    processors = [
        structlog.contextvars.merge_contextvars,
        structlog.processors.add_log_level,
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.dev.ConsoleRenderer(),
    ]
    structlog.configure(
        processors=processors,
        wrapper_class=structlog.make_filtering_bound_logger(log_level),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(),
    )

    # Route stdlib loggers at the same level so third-party logs show up too.
    logging.basicConfig(level=log_level, format="%(message)s")
|
|
|
|
|
|
# ── Processing pipeline ───────────────────────────────────────────────
|
|
|
|
|
|
async def process_events(
    consumer: NATSConsumer,
    state: LatestState,
    publisher: Publisher,
) -> None:
    """
    Main processing loop:
    1. Read event from queue
    2. Update state
    3. Compute features
    4. Publish features + check signals
    5. Check alerts

    Runs until the task is cancelled. Feature computation is throttled to
    ``settings.features_pub_rate_hz`` per symbol; signal emission has a
    1-second per-symbol cooldown to avoid flooding.
    """
    last_alert_check = time.monotonic()

    # Per-symbol throttle state: next allowed monotonic time for feature
    # computation / signal emission.
    feature_compute_interval = 1.0 / max(settings.features_pub_rate_hz, 1.0)
    next_feature_compute: dict[str, float] = {}
    next_signal_emit: dict[str, float] = {}
    signal_cooldown_sec = 1.0

    batch_counter = 0
    while True:
        try:
            event = await consumer.queue.get()
        except asyncio.CancelledError:
            break

        # Yield to event loop every N events so HTTP API stays responsive
        batch_counter += 1
        if batch_counter >= 5:
            batch_counter = 0
            await asyncio.sleep(0)

        proc_start = time.monotonic()

        try:
            # Update state based on event type. `symbol` stays None for
            # events that carry no per-symbol state.
            if event.event_type == EventType.TRADE:
                assert isinstance(event, TradeEvent)
                state.update_trade(event)
                symbol = event.symbol

            elif event.event_type == EventType.QUOTE:
                assert isinstance(event, QuoteEvent)
                state.update_quote(event)
                symbol = event.symbol

            elif event.event_type == EventType.HEARTBEAT:
                # Heartbeats don't update state, just track
                symbol = None

            elif event.event_type == EventType.BOOK_L2:
                # TODO: book updates
                symbol = None

            else:
                symbol = None

            # Compute features + publish with per-symbol throttling
            if symbol and settings.features_enabled:
                now_mono = time.monotonic()
                due = next_feature_compute.get(symbol, 0.0)
                if now_mono >= due:
                    snapshot = make_feature_snapshot(state, symbol)
                    # Cache for fast HTTP API responses
                    api.cache_features(symbol, snapshot.features)
                    await publisher.publish_features(snapshot)

                    # Check for trade signal with cooldown to avoid flood
                    sig = check_signal(snapshot.features, symbol)
                    sig_due = next_signal_emit.get(symbol, 0.0)
                    if sig and now_mono >= sig_due:
                        await publisher.publish_signal(sig)
                        next_signal_emit[symbol] = now_mono + signal_cooldown_sec

                    next_feature_compute[symbol] = now_mono + feature_compute_interval

            # Processing latency metric
            proc_ms = (time.monotonic() - proc_start) * 1000
            m.PROCESSING_LATENCY.observe(proc_ms)

        except Exception as e:
            # Boundary handler: log and keep the consumer alive on a bad event.
            logger.error(
                "process.error",
                error=str(e),
                event_type=event.event_type.value if event else "?",
            )

        # Periodic alert checks (every 5 seconds). NOTE(review): this only
        # fires when an event arrives — an idle queue delays alert checks.
        now = time.monotonic()
        if now - last_alert_check > 5.0:
            last_alert_check = now
            await _check_alerts(state, publisher, consumer)
|
|
|
|
|
|
async def _check_alerts(
    state: LatestState,
    publisher: Publisher,
    consumer: NATSConsumer,
) -> None:
    """Check alert conditions (backpressure, per-symbol latency) and emit if needed."""
    # Backpressure alert: fires above 80% queue fill; critical at >= 95%.
    fill = consumer.queue_fill_ratio
    if fill > 0.8:
        severity = "warning" if fill < 0.95 else "critical"
        backpressure_alert = AlertEvent(
            alert_type="backpressure",
            level=severity,
            message=f"Queue fill at {fill:.0%}",
            details={"fill_ratio": fill},
        )
        await publisher.publish_alert(backpressure_alert)

    # Latency alert (per symbol): p95 above the configured threshold.
    for sym in state.symbols:
        p95 = compute_features(state, sym).get("latency_ms_p95")
        if p95 is None or p95 <= settings.alert_latency_ms:
            continue
        latency_alert = AlertEvent(
            alert_type="latency",
            level="warning",
            message=f"{sym} p95 latency {p95:.0f}ms > {settings.alert_latency_ms}ms",
            details={"symbol": sym, "p95_ms": p95},
        )
        await publisher.publish_alert(latency_alert)
|
|
|
|
|
|
# ── Main ───────────────────────────────────────────────────────────────
|
|
|
|
|
|
async def main() -> None:
    """
    Service entry point.

    Wires the state store, NATS consumer/publisher, and HTTP API together,
    starts the processing loop, then waits for SIGINT/SIGTERM and shuts
    everything down in reverse order.
    """
    setup_logging()
    logger.info("service.starting", nats_url=settings.nats_url)

    # State store
    state = LatestState(window_seconds=settings.rolling_window_seconds)

    # NATS consumer
    consumer = NATSConsumer()
    await consumer.connect()
    await consumer.subscribe()

    # Publisher (reuses same NATS connection)
    publisher = Publisher(consumer._nc)

    # Wire up API
    api.set_state(state)

    def _get_stats() -> dict:
        # Runtime health snapshot served by the HTTP monitoring endpoint.
        return {
            "queue_size": consumer.queue.qsize(),
            "queue_fill_ratio": round(consumer.queue_fill_ratio, 3),
            "queue_max": settings.queue_size,
            "events_processed": state.event_count,
            "symbols_tracked": state.symbols,
            "features_enabled": settings.features_enabled,
            "nats_connected": bool(consumer._nc and consumer._nc.is_connected),
        }

    api.set_stats_fn(_get_stats)

    # Start HTTP API
    http_server = await api.start_api()

    # Start processing loop
    process_task = asyncio.create_task(
        process_events(consumer, state, publisher)
    )

    # Graceful shutdown
    shutdown_event = asyncio.Event()

    def _signal_handler() -> None:
        logger.info("service.shutdown_signal")
        shutdown_event.set()

    # Inside a coroutine, get_running_loop() is the correct call;
    # get_event_loop() here is deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, _signal_handler)
        except NotImplementedError:
            # add_signal_handler is unavailable on some platforms (e.g. Windows).
            pass

    logger.info(
        "service.ready",
        subject=settings.nats_subject,
        queue_group=settings.nats_queue_group,
        http_port=settings.http_port,
        features_enabled=settings.features_enabled,
    )

    # Wait for shutdown
    await shutdown_event.wait()

    # ── Cleanup ───────────────────────────────────────────────────────
    logger.info("service.shutting_down")

    process_task.cancel()
    try:
        await process_task
    except asyncio.CancelledError:
        pass

    await consumer.close()

    http_server.close()
    await http_server.wait_closed()

    logger.info(
        "service.stopped",
        events_processed=state.event_count,
        symbols=state.symbols,
    )
|
|
|
|
|
|
def cli() -> None:
    """Console-script entry point: run the async service until shutdown."""
    asyncio.run(main())
|
|
|
|
|
|
# Direct-execution entry point (see module docstring for usage).
if __name__ == "__main__":
    cli()
|