"""
Idempotency Middleware for NATS Workers
========================================

Redis-based deduplication for async jobs.

Each job is tracked under the key ``idemp:<job_id>`` as a JSON document whose
``status`` field is ``"in_progress"``, ``"completed"`` or ``"failed"``. When a
result is supplied on completion it is stored separately under
``idemp:result:<job_id>`` so duplicates can be answered from cache.

Usage::

    from idempotency_redis import (
        check_idempotency,
        mark_completed,
        mark_failed,
        try_mark_in_progress,
    )

    async def process_job(job_id: str, payload: dict):
        # Atomically claim the job. This fails if ANY worker already holds
        # the key, so two workers can never both start the same job.
        if not await try_mark_in_progress(job_id):
            status, result = await check_idempotency(job_id)
            if status == "completed":
                return result  # Return cached result
            raise RuntimeError(f"Job {job_id} is {status}")

        try:
            result = await do_work(payload)
            await mark_completed(job_id, result)
            return result
        except Exception as e:
            await mark_failed(job_id, str(e))
            raise

The :func:`idempotent_job` decorator implements exactly this flow.

Note: calling ``check_idempotency`` and then ``mark_in_progress`` is NOT
atomic -- two workers can both observe ``"new"``. Prefer
``try_mark_in_progress`` for claiming jobs.
"""

import functools
import json
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional, Tuple

import redis.asyncio as redis

logger = logging.getLogger(__name__)

# Redis connection
REDIS_URL = os.getenv("REDIS_URL", "redis://dagi-redis:6379")
REDIS_CLIENT = None  # lazily created by get_redis()

# TTLs
IDEMPOTENCY_TTL_HOURS = int(os.getenv("IDEMPOTENCY_TTL_HOURS", "24"))
IN_PROGRESS_TTL_MINUTES = int(os.getenv("IN_PROGRESS_TTL_MINUTES", "30"))

# Marker written to ``result_ref`` when the full result lives under the
# separate ``idemp:result:<job_id>`` key.
_STORED_SENTINEL = "stored"


def _status_key(job_id: str) -> str:
    """Return the Redis key holding the job's status document."""
    return f"idemp:{job_id}"


def _result_key(job_id: str) -> str:
    """Return the Redis key holding the job's cached result."""
    return f"idemp:result:{job_id}"


def _utcnow_iso() -> str:
    """Return the current UTC time as a timezone-aware ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _in_progress_doc(metadata: Optional[Dict[str, Any]]) -> str:
    """Serialise the status document written when a job is claimed."""
    return json.dumps({
        "status": "in_progress",
        "started_at": _utcnow_iso(),
        "metadata": metadata or {},
    })


async def get_redis() -> redis.Redis:
    """Get or create the module-wide Redis client (responses decoded to str)."""
    global REDIS_CLIENT
    if REDIS_CLIENT is None:
        REDIS_CLIENT = await redis.from_url(REDIS_URL, decode_responses=True)
    return REDIS_CLIENT


async def check_idempotency(job_id: str) -> Tuple[str, Optional[Dict[str, Any]]]:
    """
    Check if job_id was already processed.

    Returns:
        (status, result_ref)
        status: "new" | "in_progress" | "completed" | "failed"
            (or, for legacy non-JSON entries, the raw stored string)
        result_ref: None, the cached result (when one was stored), or the
            caller-supplied reference passed to ``mark_completed``.
    """
    r = await get_redis()
    value = await r.get(_status_key(job_id))
    if value is None:
        return ("new", None)

    try:
        data = json.loads(value)
    except json.JSONDecodeError:
        # Legacy format: just status string
        return (value, None)
    if not isinstance(data, dict):
        # Valid JSON but not a status document (e.g. a bare number);
        # treat the raw value as a legacy status string as well.
        return (value, None)

    status = data.get("status")
    result_ref = data.get("result_ref")

    if status == "completed":
        # Prefer the full result if one was stored alongside the status.
        result_data = await r.get(_result_key(job_id))
        if result_data is not None:
            try:
                result_ref = json.loads(result_data)
            except json.JSONDecodeError:
                logger.warning("Corrupt cached result for job %s; ignoring it", job_id)
                result_ref = None
        elif result_ref == _STORED_SENTINEL:
            # The marker promises a stored result that is not there;
            # never hand the internal marker string to callers as a result.
            result_ref = None

    return (status, result_ref)


async def try_mark_in_progress(job_id: str, metadata: Optional[Dict[str, Any]] = None) -> bool:
    """
    Atomically claim a job by marking it in progress.

    Uses ``SET ... NX EX`` so the claim only succeeds if no status key exists
    for ``job_id`` (in any state). This closes the race between
    ``check_idempotency`` and ``mark_in_progress``.

    Returns:
        True if this caller claimed the job, False if the key already existed.
    """
    r = await get_redis()
    acquired = await r.set(
        _status_key(job_id),
        _in_progress_doc(metadata),
        ex=timedelta(minutes=IN_PROGRESS_TTL_MINUTES),
        nx=True,
    )
    if acquired:
        logger.info("Marked job %s as in_progress", job_id)
    return bool(acquired)


async def mark_in_progress(job_id: str, metadata: Optional[Dict[str, Any]] = None):
    """
    Mark job as in progress, unconditionally overwriting any existing status.

    Prefer ``try_mark_in_progress`` when claiming a job, since this function
    cannot detect that another worker already holds it.
    """
    r = await get_redis()
    await r.setex(
        _status_key(job_id),
        timedelta(minutes=IN_PROGRESS_TTL_MINUTES),
        _in_progress_doc(metadata),
    )
    logger.info("Marked job %s as in_progress", job_id)


async def mark_completed(
    job_id: str,
    result: Optional[Dict[str, Any]] = None,
    result_ref: Optional[str] = None,
):
    """
    Mark job as completed.

    Args:
        job_id: Job identifier
        result: Full result data, stored under a separate key for retrieval
            whenever it is not None (falsy results such as {} are stored too).
        result_ref: Reference to result (e.g., NATS subject, file path)

    Raises:
        TypeError: if ``result`` is not JSON-serialisable (raised before any
            Redis write, so no partial state is left behind).
    """
    r = await get_redis()
    ttl = timedelta(hours=IDEMPOTENCY_TTL_HOURS)

    # Serialise first so an unserialisable result fails before any write.
    result_doc = json.dumps(result) if result is not None else None
    if result_ref:
        ref = result_ref
    elif result_doc is not None:
        ref = _STORED_SENTINEL
    else:
        ref = None
    status_doc = json.dumps({
        "status": "completed",
        "completed_at": _utcnow_iso(),
        "result_ref": ref,
    })

    # MULTI/EXEC so the result and the "completed" status land together.
    async with r.pipeline(transaction=True) as pipe:
        if result_doc is not None:
            pipe.setex(_result_key(job_id), ttl, result_doc)
        pipe.setex(_status_key(job_id), ttl, status_doc)
        await pipe.execute()

    logger.info("Marked job %s as completed", job_id)


async def mark_failed(job_id: str, error: str, allow_retry: bool = True):
    """
    Mark job as failed.

    Args:
        job_id: Job identifier
        error: Error message (truncated to 500 characters when stored)
        allow_retry: If True, delete the status key so the job can be retried
            immediately. If False, store a "failed" status with a 5-minute TTL.
    """
    r = await get_redis()
    key = _status_key(job_id)

    if allow_retry:
        # Delete key to allow retry
        await r.delete(key)
        logger.info("Marked job %s as failed (retry allowed), deleted key", job_id)
        return

    # Mark as failed with short TTL (to prevent immediate retry spam)
    data = {
        "status": "failed",
        "failed_at": _utcnow_iso(),
        "error": error[:500],  # Truncate long errors
    }
    await r.setex(key, timedelta(minutes=5), json.dumps(data))
    logger.warning("Marked job %s as failed (no retry): %s", job_id, error[:100])


async def get_job_status(job_id: str) -> Dict[str, Any]:
    """Get the full stored status document for debugging."""
    r = await get_redis()
    value = await r.get(_status_key(job_id))

    if value is None:
        return {"status": "not_found"}

    try:
        data = json.loads(value)
    except json.JSONDecodeError:
        return {"status": value, "raw": value}
    if not isinstance(data, dict):
        # Keep the declared Dict return type for non-document JSON values.
        return {"status": value, "raw": value}
    return data


# ==================== Decorator for Workers ====================

def idempotent_job(job_id_field: str = "job_id"):
    """
    Decorator to make a worker function idempotent.

    The wrapped coroutine receives ``payload`` (a dict containing
    ``job_id_field``) plus any extra arguments. Behaviour:

    - The job is claimed atomically; only one worker runs a given job_id.
    - If the job already completed, the cached result is returned (or
      ``{"status": "already_completed", "job_id": ...}`` when none is cached).
    - If the job is in progress or recently failed without retry,
      ``RuntimeError`` is raised.
    - On success the result is cached; on failure the key is deleted so the
      job can be retried, and the original exception is re-raised.

    Usage:
        @idempotent_job("job_id")
        async def process_workflow(payload: dict):
            # payload must contain job_id
            ...
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(payload: dict, *args, **kwargs):
            job_id = payload.get(job_id_field)
            if not job_id:
                raise ValueError(f"Payload must contain '{job_id_field}'")

            # Atomic claim: a separate check-then-set would let two workers
            # both see "new" and both run the job.
            if not await try_mark_in_progress(job_id, {"function": func.__name__}):
                status, cached = await check_idempotency(job_id)
                if status == "completed":
                    logger.info("Job %s already completed, returning cached result", job_id)
                    return cached or {"status": "already_completed", "job_id": job_id}
                if status == "failed":
                    logger.warning("Job %s recently failed (no retry), skipping", job_id)
                    raise RuntimeError(f"Job {job_id} recently failed and is not retryable yet")
                logger.warning("Job %s already in progress, skipping", job_id)
                raise RuntimeError(f"Job {job_id} already in progress")

            try:
                result = await func(payload, *args, **kwargs)
                await mark_completed(job_id, result=result)
                return result
            except Exception as exc:
                # Release the claim so the job can be retried, but never let
                # a Redis error here mask the job's own exception.
                try:
                    await mark_failed(job_id, str(exc), allow_retry=True)
                except Exception:
                    logger.exception("Could not release idempotency key for job %s", job_id)
                raise

        return wrapper
    return decorator