""" NATS consumer — subscribes to md.events.> and feeds the processing pipeline. Features: - Queue group subscription (horizontal scaling) - Bounded asyncio.Queue with backpressure drop policy - Auto-reconnect via nats-py - Optional JetStream durable consumer """ from __future__ import annotations import asyncio import json import logging import nats from nats.aio.client import Client as NatsClient from nats.aio.msg import Msg from senpai.md_consumer.config import settings from senpai.md_consumer.models import EventType, parse_event, Event from senpai.md_consumer import metrics as m logger = logging.getLogger(__name__) # Events that can be dropped under backpressure (lowest priority first) _DROPPABLE = {EventType.HEARTBEAT, EventType.QUOTE, EventType.BOOK_L2} class NATSConsumer: """ Reads normalised events from NATS, validates, and puts into a bounded asyncio.Queue for downstream processing. Backpressure policy: - Queue < 90% → accept all events - Queue >= 90% → drop heartbeats, quotes, book snapshots - Trades are NEVER dropped (critical for analytics) """ def __init__(self) -> None: self._nc: NatsClient | None = None self._sub = None self._js_sub = None self._queue: asyncio.Queue[Event] = asyncio.Queue( maxsize=settings.queue_size ) self._running = False self._drop_count: dict[str, int] = {} @property def queue(self) -> asyncio.Queue[Event]: return self._queue @property def queue_fill_ratio(self) -> float: if settings.queue_size <= 0: return 0.0 return self._queue.qsize() / settings.queue_size async def connect(self) -> None: """Connect to NATS with auto-reconnect.""" self._nc = await nats.connect( self._url, reconnect_time_wait=2, max_reconnect_attempts=-1, name="senpai-md-consumer", error_cb=self._on_error, disconnected_cb=self._on_disconnected, reconnected_cb=self._on_reconnected, closed_cb=self._on_closed, ) m.NATS_CONNECTED.set(1) logger.info( "nats.connected", extra={"url": self._url, "subject": settings.nats_subject}, ) @property def _url(self) -> str: return settings.nats_url async def subscribe(self) -> None: """Subscribe to market data events.""" if not self._nc: raise RuntimeError("Not connected. Call connect() first.") if settings.use_jetstream: await self._subscribe_jetstream() else: await self._subscribe_core() async def _subscribe_core(self) -> None: """Core NATS subscription with queue group.""" self._sub = await self._nc.subscribe( settings.nats_subject, queue=settings.nats_queue_group, cb=self._on_message, ) logger.info( "nats.subscribed_core", extra={ "subject": settings.nats_subject, "queue_group": settings.nats_queue_group, }, ) async def _subscribe_jetstream(self) -> None: """JetStream durable subscription.""" js = self._nc.jetstream() # Try to create or bind to existing consumer self._js_sub = await js.subscribe( settings.nats_subject, queue=settings.nats_queue_group, durable="senpai-md-durable", cb=self._on_message, manual_ack=True, ) logger.info( "nats.subscribed_jetstream", extra={ "subject": settings.nats_subject, "durable": "senpai-md-durable", }, ) async def _on_message(self, msg: Msg) -> None: """ Callback for each NATS message. Parse → backpressure check → enqueue. """ try: data = json.loads(msg.data) except (json.JSONDecodeError, UnicodeDecodeError) as e: m.EVENTS_DROPPED.labels(reason="parse_error", event_type="unknown").inc() logger.warning("nats.parse_error", extra={"error": str(e)}) if settings.use_jetstream: await msg.ack() return event = parse_event(data) if event is None: m.EVENTS_DROPPED.labels(reason="invalid_event", event_type="unknown").inc() if settings.use_jetstream: await msg.ack() return # Track inbound m.EVENTS_IN.labels( event_type=event.event_type.value, provider=event.provider, ).inc() # Backpressure check fill = self.queue_fill_ratio m.QUEUE_FILL.set(fill) m.QUEUE_SIZE.set(self._queue.qsize()) if fill >= settings.queue_drop_threshold: # Under pressure: only accept trades if event.event_type in _DROPPABLE: et = event.event_type.value self._drop_count[et] = self._drop_count.get(et, 0) + 1 m.EVENTS_DROPPED.labels( reason="backpressure", event_type=et, ).inc() if self._drop_count[et] % 1000 == 1: logger.warning( "nats.backpressure_drop", extra={ "type": et, "fill": f"{fill:.0%}", "total_drops": self._drop_count, }, ) if settings.use_jetstream: await msg.ack() return # Enqueue try: self._queue.put_nowait(event) except asyncio.QueueFull: # Last resort: try to drop oldest non-trade m.EVENTS_DROPPED.labels( reason="queue_full", event_type=event.event_type.value ).inc() if settings.use_jetstream: await msg.ack() async def close(self) -> None: """Graceful shutdown.""" self._running = False if self._sub: try: await self._sub.unsubscribe() except Exception: pass if self._nc: try: await self._nc.flush(timeout=5) await self._nc.close() except Exception: pass m.NATS_CONNECTED.set(0) logger.info("nats.closed", extra={"drops": self._drop_count}) # ── NATS callbacks ──────────────────────────────────────────────── async def _on_error(self, e: Exception) -> None: logger.error("nats.error", extra={"error": str(e)}) async def _on_disconnected(self) -> None: m.NATS_CONNECTED.set(0) logger.warning("nats.disconnected") async def _on_reconnected(self) -> None: m.NATS_CONNECTED.set(1) m.NATS_RECONNECTS.inc() logger.info("nats.reconnected") async def _on_closed(self) -> None: m.NATS_CONNECTED.set(0) logger.info("nats.closed_callback")