"""
Trace Middleware
================

Standardized request correlation across all services.

Headers:
- X-Trace-ID: uuid (the whole end-to-end path)
- X-Request-ID: uuid (single HTTP request)
- X-Job-ID: uuid (async NATS job)
- X-User-ID: user identifier
- X-Agent-ID: target agent
- X-Mode: public|team|private|confidential
- X-Policy-Version: version hash
- X-Prompt-Version: version hash

Usage:
1. The gateway generates the trace_id.
2. Every service forwards it in its outgoing headers.
3. NATS messages carry it in their message headers (metadata).
4. Logs are structured JSON that include the trace_id.
"""

import json
import logging
import uuid
from contextvars import ContextVar
from datetime import datetime, timezone
from functools import wraps
from typing import Any, Dict, Mapping, Optional

from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware

# Context variable holding the current request's trace context, in the shape
# produced by TraceContext.to_log_context(). NOTE: the default ``{}`` is one
# shared object, so treat values returned by ``trace_context.get()`` as read-only.
trace_context: ContextVar[Dict[str, Any]] = ContextVar('trace_context', default={})

logger = logging.getLogger(__name__)


def _utc_iso(ts: Optional[float] = None) -> str:
    """Return a naive ISO-8601 UTC timestamp for *ts* (epoch seconds) or now.

    Produces the same format as the deprecated ``datetime.utcnow().isoformat()``
    (no ``+00:00`` suffix), so existing log/NATS/audit consumers are unaffected.
    """
    moment = datetime.now(timezone.utc) if ts is None else datetime.fromtimestamp(ts, tz=timezone.utc)
    return moment.replace(tzinfo=None).isoformat()


def _lower_keys(headers: Mapping[str, str]) -> Dict[str, str]:
    """Return a copy of *headers* keyed by lower-cased names.

    HTTP header names are case-insensitive, and Starlette already lower-cases
    them, so lookups must not depend on the sender's capitalization.
    """
    return {str(name).lower(): value for name, value in headers.items()}


class TraceContext:
    """Trace context used to correlate one request/job across services.

    Missing ``trace_id`` / ``request_id`` values are replaced with fresh UUID4
    strings. Instances are ordinary mutable objects; for example, the
    middleware assigns ``source_service`` after construction.
    """

    def __init__(
        self,
        trace_id: Optional[str] = None,
        request_id: Optional[str] = None,
        job_id: Optional[str] = None,
        user_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        mode: str = "public",
        policy_version: Optional[str] = None,
        prompt_version: Optional[str] = None,
        source_service: Optional[str] = None,
    ):
        self.trace_id = trace_id or str(uuid.uuid4())
        self.request_id = request_id or str(uuid.uuid4())
        self.job_id = job_id
        self.user_id = user_id
        self.agent_id = agent_id
        self.mode = mode
        self.policy_version = policy_version
        self.prompt_version = prompt_version
        self.source_service = source_service
        self.timestamp = _utc_iso()

    def to_headers(self) -> Dict[str, str]:
        """Convert to outgoing HTTP headers.

        ``X-Trace-ID`` and ``X-Request-ID`` are always present; every other
        header is emitted only when its value is truthy.
        """
        headers = {
            "X-Trace-ID": self.trace_id,
            "X-Request-ID": self.request_id,
        }
        optional = (
            ("X-Job-ID", self.job_id),
            ("X-User-ID", self.user_id),
            ("X-Agent-ID", self.agent_id),
            ("X-Mode", self.mode),
            ("X-Policy-Version", self.policy_version),
            ("X-Prompt-Version", self.prompt_version),
        )
        headers.update({name: value for name, value in optional if value})
        return headers

    def to_nats_headers(self) -> Dict[str, str]:
        """Convert to NATS message headers.

        ``Nats-Job-ID`` falls back to the request id, and missing user/agent ids
        are sent as empty strings. Policy/prompt versions are added only when
        set, so they survive the NATS hop (see ``from_nats``).
        """
        headers = {
            "Nats-Trace-ID": self.trace_id,
            "Nats-Job-ID": self.job_id or self.request_id,
            "Nats-User-ID": self.user_id or "",
            "Nats-Agent-ID": self.agent_id or "",
            "Nats-Mode": self.mode,
            "Nats-Timestamp": self.timestamp,
        }
        if self.policy_version:
            headers["Nats-Policy-Version"] = self.policy_version
        if self.prompt_version:
            headers["Nats-Prompt-Version"] = self.prompt_version
        return headers

    def to_log_context(self) -> Dict[str, Any]:
        """Convert to a structured-logging dict (values may be ``None``)."""
        return {
            "trace_id": self.trace_id,
            "request_id": self.request_id,
            "job_id": self.job_id,
            "user_id": self.user_id,
            "agent_id": self.agent_id,
            "mode": self.mode,
            "policy_version": self.policy_version,
            "prompt_version": self.prompt_version,
            "timestamp": self.timestamp,
        }

    @classmethod
    def from_headers(cls, headers: Mapping[str, str]) -> "TraceContext":
        """Create from incoming HTTP headers (header names matched case-insensitively).

        Empty header values are treated as absent; ``mode`` defaults to "public".
        """
        lowered = _lower_keys(headers)
        return cls(
            trace_id=lowered.get("x-trace-id") or None,
            request_id=lowered.get("x-request-id") or None,
            job_id=lowered.get("x-job-id") or None,
            user_id=lowered.get("x-user-id") or None,
            agent_id=lowered.get("x-agent-id") or None,
            mode=lowered.get("x-mode") or "public",
            policy_version=lowered.get("x-policy-version") or None,
            prompt_version=lowered.get("x-prompt-version") or None,
        )

    @classmethod
    def from_nats(cls, headers: Mapping[str, str]) -> "TraceContext":
        """Create from NATS message headers (inverse of ``to_nats_headers``).

        Empty strings, which ``to_nats_headers`` emits for missing ids, are
        mapped back to ``None``. A new ``request_id`` is generated on this side.
        """
        return cls(
            trace_id=headers.get("Nats-Trace-ID") or None,
            job_id=headers.get("Nats-Job-ID") or None,
            user_id=headers.get("Nats-User-ID") or None,
            agent_id=headers.get("Nats-Agent-ID") or None,
            mode=headers.get("Nats-Mode") or "public",
            policy_version=headers.get("Nats-Policy-Version") or None,
            prompt_version=headers.get("Nats-Prompt-Version") or None,
        )


class TraceMiddleware(BaseHTTPMiddleware):
    """FastAPI/Starlette middleware that extracts or creates a trace context per request."""

    def __init__(self, app, service_name: str):
        super().__init__(app)
        self.service_name = service_name

    async def dispatch(self, request: Request, call_next):
        """Bind the trace context for the request, log start/end, echo ids in the response.

        The ContextVar is reset afterwards so the value cannot leak beyond this
        request, even when the downstream app raises.
        """
        ctx = TraceContext.from_headers(dict(request.headers))
        ctx.source_service = self.service_name

        token = trace_context.set(ctx.to_log_context())
        try:
            logger.info(
                "Request started",
                extra={
                    "trace_id": ctx.trace_id,
                    "request_id": ctx.request_id,
                    "method": request.method,
                    "path": request.url.path,
                    "service": self.service_name,
                },
            )

            try:
                response = await call_next(request)
            except Exception:
                logger.exception(
                    "Request failed",
                    extra={
                        "trace_id": ctx.trace_id,
                        "request_id": ctx.request_id,
                        "service": self.service_name,
                    },
                )
                raise

            # Echo correlation ids so clients can report them.
            response.headers["X-Trace-ID"] = ctx.trace_id
            response.headers["X-Request-ID"] = ctx.request_id

            logger.info(
                "Request completed",
                extra={
                    "trace_id": ctx.trace_id,
                    "request_id": ctx.request_id,
                    "status_code": response.status_code,
                    "service": self.service_name,
                },
            )
            return response
        finally:
            trace_context.reset(token)


def get_current_trace() -> Dict[str, Any]:
    """Return the trace context bound to the current context (``{}`` if none)."""
    return trace_context.get()


def with_trace(func):
    """Decorate an async function so it receives the current trace as ``trace_context``.

    An explicitly passed ``trace_context`` keyword argument takes precedence.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        kwargs.setdefault("trace_context", get_current_trace())
        return await func(*args, **kwargs)
    return wrapper


# ==================== Structured Logging ====================

# Attribute names every LogRecord carries by default. Any other attribute on a
# record was supplied via ``extra=`` (or a filter) and is emitted as a JSON field.
_STANDARD_RECORD_ATTRS = frozenset(
    vars(logging.LogRecord("", logging.INFO, "", 0, "", (), None))
) | {"message", "asctime"}


class TraceLogFormatter(logging.Formatter):
    """JSON log formatter that injects the current trace context."""

    def format(self, record: logging.LogRecord) -> str:
        ctx = trace_context.get()

        log_entry: Dict[str, Any] = {
            "timestamp": _utc_iso(record.created),
            "level": record.levelname,
            "message": record.getMessage(),
            "service": getattr(record, "service", "unknown"),
            # ``or ""``: the context stores None for unset ids; emit "" instead of null.
            "trace_id": ctx.get("trace_id") or "",
            "request_id": ctx.get("request_id") or "",
            "user_id": ctx.get("user_id") or "",
            "agent_id": ctx.get("agent_id") or "",
        }

        # ``extra=`` keys become record attributes, not ``record.extra``.
        extras = {
            key: value
            for key, value in vars(record).items()
            if key not in _STANDARD_RECORD_ATTRS
        }
        # Backward compatibility: a caller passing extra={"extra": {...}} gets it flattened.
        nested = extras.pop("extra", None)
        if isinstance(nested, dict):
            extras.update(nested)
        elif nested is not None:
            extras["extra"] = nested
        log_entry.update(extras)

        if record.exc_info:
            log_entry["exception"] = self.formatException(record.exc_info)
        if record.stack_info:
            log_entry["stack_info"] = self.formatStack(record.stack_info)

        # default=str: a non-JSON-serializable extra must not make the log line fail.
        return json.dumps(log_entry, default=str)


class _ServiceNameFilter(logging.Filter):
    """Set ``record.service`` to a default service name when the record has none."""

    def __init__(self, service_name: str):
        super().__init__()
        self.service_name = service_name

    def filter(self, record: logging.LogRecord) -> bool:
        if not hasattr(record, "service"):
            record.service = self.service_name
        return True


def setup_trace_logging(service_name: str) -> None:
    """Replace the root logger's handlers with one JSON/trace-aware stderr handler at INFO."""
    handler = logging.StreamHandler()
    handler.setFormatter(TraceLogFormatter())
    # Use a handler filter rather than a global LogRecord factory. A factory-set
    # ``service`` attribute makes ``logger.info(..., extra={"service": ...})``
    # (as TraceMiddleware does) raise KeyError in Logger.makeRecord.
    handler.addFilter(_ServiceNameFilter(service_name))

    root_logger = logging.getLogger()
    root_logger.handlers = [handler]
    root_logger.setLevel(logging.INFO)


# ==================== NATS Integration ====================

async def publish_with_trace(js, subject: str, payload: bytes, ctx: TraceContext):
    """Publish a NATS (JetStream) message carrying the trace headers of *ctx*."""
    await js.publish(subject, payload, headers=ctx.to_nats_headers())


def extract_trace_from_msg(msg) -> TraceContext:
    """Build a TraceContext from a NATS message's headers (empty if it has none)."""
    headers = dict(msg.headers) if msg.headers else {}
    return TraceContext.from_nats(headers)


# ==================== Audit Event ====================

def create_audit_event(
    action: str,
    ctx: TraceContext,
    details: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Create a standardized audit event dict for *action* correlated via *ctx*."""
    return {
        "event_id": str(uuid.uuid4()),
        "event_type": f"audit.action.{action}",
        "timestamp": _utc_iso(),
        "trace_id": ctx.trace_id,
        "request_id": ctx.request_id,
        "job_id": ctx.job_id,
        "user_id": ctx.user_id,
        "agent_id": ctx.agent_id,
        "mode": ctx.mode,
        "policy_version": ctx.policy_version,
        "prompt_version": ctx.prompt_version,
        "action": action,
        "details": details or {},
    }