""" Feature engine — incremental feature computation from rolling windows. Features (per symbol): - mid: (bid+ask)/2 - spread_abs: ask - bid - spread_bps: spread_abs / mid * 10000 - trade_vwap_10s: VWAP over last 10 seconds - trade_vwap_60s: VWAP over last 60 seconds - trade_count_10s: number of trades in 10s - trade_volume_10s: total volume in 10s - return_10s: mid_now / mid_10s_ago - 1 - realized_vol_60s: std of log-returns over 60s - latency_ms_p50: p50 exchange-to-receive latency - latency_ms_p95: p95 exchange-to-receive latency Rule-based signal (MVP): - if return_10s > threshold AND volume_10s > threshold AND spread_bps < threshold → emit TradeSignal(direction="long") - opposite for short """ from __future__ import annotations import logging import math from senpai.md_consumer.config import settings from senpai.md_consumer.models import FeatureSnapshot, TradeSignal from senpai.md_consumer.state import LatestState, TradeRecord logger = logging.getLogger(__name__) def compute_features(state: LatestState, symbol: str) -> dict[str, float | None]: """ Compute all features for a symbol from current state. Returns a flat dict of feature_name → value (None if not computable). """ sym = symbol.upper() features: dict[str, float | None] = {} quote = state.get_latest_quote(sym) window = state.get_window(sym) # ── Mid / Spread ────────────────────────────────────────────────── if quote and quote.bid > 0 and quote.ask > 0: mid = (quote.bid + quote.ask) / 2 spread_abs = quote.ask - quote.bid spread_bps = (spread_abs / mid * 10_000) if mid > 0 else None features["mid"] = mid features["spread_abs"] = spread_abs features["spread_bps"] = spread_bps else: features["mid"] = None features["spread_abs"] = None features["spread_bps"] = None if not window: # No rolling data yet — fill with None features.update({ "trade_vwap_10s": None, "trade_vwap_60s": None, "trade_count_10s": None, "trade_volume_10s": None, "return_10s": None, "realized_vol_60s": None, "latency_ms_p50": None, "latency_ms_p95": None, }) return features # ── VWAP ────────────────────────────────────────────────────────── trades_10s = window.trades_since(10.0) trades_60s = list(window.trades) features["trade_vwap_10s"] = _vwap(trades_10s) features["trade_vwap_60s"] = _vwap(trades_60s) # ── Trade count / volume (10s) ──────────────────────────────────── features["trade_count_10s"] = float(len(trades_10s)) features["trade_volume_10s"] = sum(t.size for t in trades_10s) if trades_10s else 0.0 # ── Return 10s ──────────────────────────────────────────────────── features["return_10s"] = _return_over(window, features.get("mid"), 10.0) # ── Realised volatility 60s ─────────────────────────────────────── features["realized_vol_60s"] = _realized_vol(trades_60s) # ── Latency ─────────────────────────────────────────────────────── latencies = _latencies_ms(trades_60s) if latencies: latencies.sort() features["latency_ms_p50"] = _percentile(latencies, 50) features["latency_ms_p95"] = _percentile(latencies, 95) else: features["latency_ms_p50"] = None features["latency_ms_p95"] = None return features def make_feature_snapshot( state: LatestState, symbol: str ) -> FeatureSnapshot: """Create a FeatureSnapshot for publishing.""" features = compute_features(state, symbol) return FeatureSnapshot(symbol=symbol.upper(), features=features) def check_signal( features: dict[str, float | None], symbol: str ) -> TradeSignal | None: """ Rule-based signal MVP. Long if: - return_10s > signal_return_threshold - trade_volume_10s > signal_volume_threshold - spread_bps < signal_spread_max_bps Short if opposite return condition met. """ ret = features.get("return_10s") vol = features.get("trade_volume_10s") spread = features.get("spread_bps") if ret is None or vol is None or spread is None: return None # Spread filter (both directions) if spread > settings.signal_spread_max_bps: return None # Volume filter if vol < settings.signal_volume_threshold: return None # Direction if ret > settings.signal_return_threshold: confidence = min(1.0, ret / (settings.signal_return_threshold * 3)) return TradeSignal( symbol=symbol.upper(), direction="long", confidence=confidence, reason=f"return_10s={ret:.4f} vol_10s={vol:.2f} spread={spread:.1f}bps", features=features, ) elif ret < -settings.signal_return_threshold: confidence = min(1.0, abs(ret) / (settings.signal_return_threshold * 3)) return TradeSignal( symbol=symbol.upper(), direction="short", confidence=confidence, reason=f"return_10s={ret:.4f} vol_10s={vol:.2f} spread={spread:.1f}bps", features=features, ) return None # ── Internal helpers ─────────────────────────────────────────────────── def _vwap(trades: list[TradeRecord]) -> float | None: """Volume-weighted average price.""" if not trades: return None total_value = sum(t.price * t.size for t in trades) total_volume = sum(t.size for t in trades) if total_volume <= 0: return None return total_value / total_volume def _return_over( window, current_mid: float | None, seconds: float ) -> float | None: """ Return over last N seconds. Uses mid price from quotes if available, else latest trade price. """ if current_mid is None or current_mid <= 0: return None # Find the quote mid from N seconds ago quotes = window.quotes_since(seconds) if quotes: oldest = quotes[0] old_mid = (oldest.bid + oldest.ask) / 2 if old_mid > 0: return current_mid / old_mid - 1 # Fallback: use trade prices trades = window.trades_since(seconds) if trades: old_price = trades[0].price if old_price > 0: return current_mid / old_price - 1 return None def _realized_vol(trades: list[TradeRecord]) -> float | None: """ Simple realised volatility: std of log-returns of trade prices. """ if len(trades) < 3: return None prices = [t.price for t in trades if t.price > 0] if len(prices) < 3: return None log_returns = [] for i in range(1, len(prices)): if prices[i - 1] > 0: lr = math.log(prices[i] / prices[i - 1]) log_returns.append(lr) if len(log_returns) < 2: return None mean = sum(log_returns) / len(log_returns) variance = sum((r - mean) ** 2 for r in log_returns) / (len(log_returns) - 1) return math.sqrt(variance) def _latencies_ms(trades: list[TradeRecord]) -> list[float]: """Extract exchange-to-receive latencies in ms.""" latencies = [] for t in trades: if t.ts_exchange is not None and t.ts_recv is not None: lat = (t.ts_recv.timestamp() - t.ts_exchange.timestamp()) * 1000 if 0 < lat < 60_000: # sanity: 0-60s latencies.append(lat) return latencies def _percentile(sorted_data: list[float], p: int) -> float: """Simple percentile from sorted list.""" if not sorted_data: return 0.0 k = (len(sorted_data) - 1) * p / 100 f = math.floor(k) c = math.ceil(k) if f == c: return sorted_data[int(k)] return sorted_data[f] * (c - k) + sorted_data[c] * (k - f)