""" NodeHealthTracker — M8.0: per-node health state tracking for soft-failover. Tracks invoke outcomes per node and maintains: - EWMA latency estimate - consecutive failure counter - last ok / last error timestamps - derived health state: "healthy" | "degraded" | "down" State transitions ----------------- Any state → "down" : consecutive_failures >= fail_consecutive Any state → "degraded" : ewma_latency_s >= lat_ewma_threshold (and not yet "down") "down"/"degraded" → "healthy" : record_ok() resets consecutive_failures to 0 and ewma is updated towards the actual latency Thread safety ------------- All mutations are protected by a threading.Lock so this can be called from asyncio callbacks (e.g. in `_invoke_and_send` on the event loop thread). Use `record_ok` / `record_error` from within coroutines; they are synchronous (no blocking I/O) so they are safe to call directly without to_thread. """ from __future__ import annotations import logging import threading import time from dataclasses import dataclass, field from typing import Dict, FrozenSet, Optional, Tuple logger = logging.getLogger(__name__) # ── State constants ──────────────────────────────────────────────────────────── NODE_STATE_HEALTHY = "healthy" NODE_STATE_DEGRADED = "degraded" NODE_STATE_DOWN = "down" # Failover-triggering error classes FAILOVER_REASON_TIMEOUT = "timeout" FAILOVER_REASON_HTTP_5XX = "http_5xx" FAILOVER_REASON_NETWORK = "network" # ── Config ──────────────────────────────────────────────────────────────────── @dataclass(frozen=True) class NodeHealthConfig: """ Thresholds controlling when a node is considered degraded or down. fail_consecutive : int number of consecutive failures → "down" lat_ewma_s : float EWMA latency estimate (seconds) threshold → "degraded" ewma_alpha : float EWMA smoothing factor (0..1); higher = more reactive """ fail_consecutive: int = 3 lat_ewma_s: float = 12.0 ewma_alpha: float = 0.3 def __post_init__(self) -> None: if not (0 < self.ewma_alpha <= 1): raise ValueError(f"ewma_alpha must be in (0, 1], got {self.ewma_alpha}") if self.fail_consecutive < 1: raise ValueError(f"fail_consecutive must be ≥ 1, got {self.fail_consecutive}") if self.lat_ewma_s <= 0: raise ValueError(f"lat_ewma_s must be > 0, got {self.lat_ewma_s}") # ── Per-node state ──────────────────────────────────────────────────────────── @dataclass class _NodeState: invoke_ok_total: int = 0 invoke_err_total: int = 0 consecutive_failures: int = 0 last_ok_ts: Optional[float] = None last_err_ts: Optional[float] = None ewma_latency_s: Optional[float] = None # None until first ok record # ── Tracker ─────────────────────────────────────────────────────────────────── class NodeHealthTracker: """ Thread-safe per-node health tracker. Usage: tracker = NodeHealthTracker(NodeHealthConfig()) # On successful invoke tracker.record_ok("NODA1", latency_s=1.4) # On failed invoke tracker.record_error("NODA1", reason=FAILOVER_REASON_TIMEOUT) # Read health state state = tracker.state("NODA1") # "healthy" | "degraded" | "down" fallback = tracker.pick_fallback("NODA1", allowed_nodes=frozenset({"NODA1","NODA2"})) """ def __init__(self, config: Optional[NodeHealthConfig] = None) -> None: self._cfg = config or NodeHealthConfig() self._nodes: Dict[str, _NodeState] = {} self._lock = threading.RLock() # RLock: re-entrant (needed for all_info → as_info_dict) # ── Public mutation API ──────────────────────────────────────────────────── def record_ok(self, node_id: str, latency_s: float) -> None: """Record a successful invoke for node_id with given latency.""" with self._lock: ns = self._get_or_create(node_id) ns.invoke_ok_total += 1 ns.consecutive_failures = 0 ns.last_ok_ts = time.monotonic() if ns.ewma_latency_s is None: ns.ewma_latency_s = latency_s else: alpha = self._cfg.ewma_alpha ns.ewma_latency_s = alpha * latency_s + (1 - alpha) * ns.ewma_latency_s def record_error(self, node_id: str, reason: str = "unknown") -> None: """Record a failed invoke for node_id.""" with self._lock: ns = self._get_or_create(node_id) ns.invoke_err_total += 1 ns.consecutive_failures += 1 ns.last_err_ts = time.monotonic() logger.debug( "NodeHealth: node=%s consecutive_failures=%d reason=%s", node_id, ns.consecutive_failures, reason, ) # ── Public read API ─────────────────────────────────────────────────────── def state(self, node_id: str) -> str: """Return current health state for node_id.""" with self._lock: return self._state_unlocked(node_id) def pick_fallback( self, primary: str, allowed_nodes: FrozenSet[str], ) -> Optional[str]: """ Return the best alternative node for failover. Priority: healthy > degraded > (never down) Returns None if no acceptable fallback exists. """ with self._lock: candidates = sorted(n for n in allowed_nodes if n != primary) # Prefer healthy first for n in candidates: if self._state_unlocked(n) == NODE_STATE_HEALTHY: return n # Accept degraded if no healthy available for n in candidates: if self._state_unlocked(n) == NODE_STATE_DEGRADED: return n # Do not failover to "down" nodes return None def as_info_dict(self, node_id: str) -> dict: """Return a JSON-safe status dict for one node.""" with self._lock: ns = self._nodes.get(node_id) if ns is None: return { "node_id": node_id, "state": NODE_STATE_HEALTHY, "invoke_ok": 0, "invoke_err": 0, "consecutive_failures": 0, "ewma_latency_s": None, "last_ok_ts": None, "last_err_ts": None, } return { "node_id": node_id, "state": self._state_unlocked(node_id), "invoke_ok": ns.invoke_ok_total, "invoke_err": ns.invoke_err_total, "consecutive_failures": ns.consecutive_failures, "ewma_latency_s": round(ns.ewma_latency_s, 3) if ns.ewma_latency_s else None, "last_ok_ts": ns.last_ok_ts, "last_err_ts": ns.last_err_ts, } def all_info(self, allowed_nodes: Optional[FrozenSet[str]] = None) -> Dict[str, dict]: """ Return status dicts for all tracked (or specified) nodes. If allowed_nodes provided, also include entries for unseen nodes (state=healthy). """ with self._lock: keys = set(self._nodes.keys()) if allowed_nodes: keys |= set(allowed_nodes) return {n: self.as_info_dict(n) for n in sorted(keys)} def reset(self, node_id: str) -> None: """Reset health state for a node (e.g. after manual recovery).""" with self._lock: self._nodes.pop(node_id, None) def restore_node( self, node_id: str, ewma_latency_s: Optional[float], consecutive_failures: int, ) -> None: """ Restore persisted node state after a restart (M8.2). Only restores ewma_latency_s and consecutive_failures; counters (invoke_ok_total, invoke_err_total) start from 0 since they are runtime metrics for the current session. """ with self._lock: ns = self._get_or_create(node_id) ns.ewma_latency_s = ewma_latency_s ns.consecutive_failures = max(0, consecutive_failures) # ── Internal ────────────────────────────────────────────────────────────── def _get_or_create(self, node_id: str) -> _NodeState: if node_id not in self._nodes: self._nodes[node_id] = _NodeState() return self._nodes[node_id] def _state_unlocked(self, node_id: str) -> str: ns = self._nodes.get(node_id) if ns is None: return NODE_STATE_HEALTHY # unseen nodes are assumed healthy if ns.consecutive_failures >= self._cfg.fail_consecutive: return NODE_STATE_DOWN if ( ns.ewma_latency_s is not None and ns.ewma_latency_s >= self._cfg.lat_ewma_s ): return NODE_STATE_DEGRADED return NODE_STATE_HEALTHY # ── Parser (env vars → NodeHealthConfig) ────────────────────────────────────── def parse_node_health_config( fail_consecutive: int = 3, lat_ewma_s: float = 12.0, ewma_alpha: float = 0.3, ) -> NodeHealthConfig: """Construct NodeHealthConfig from parsed env values.""" return NodeHealthConfig( fail_consecutive=fail_consecutive, lat_ewma_s=lat_ewma_s, ewma_alpha=ewma_alpha, )