import asyncio
import json
import logging
from typing import Any, Dict, Optional

import nats
from fastapi import FastAPI

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

app = FastAPI()

# Shared NATS state, populated by worker_loop() once the connection succeeds.
nc = None
js = None
# True only while worker_loop() is actively fetching messages.
running = False

# Strong reference to the background worker task. asyncio keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected
# before it finishes.
_worker_task: Optional["asyncio.Task[None]"] = None

# Replay-tracking headers propagated from the request to the completion event.
_REPLAY_HEADER_KEYS = (
    "replayed",
    "replay_count",
    "original_subject",
    "original_msg_id",
)


def _decode_payload(msg) -> Optional[Dict[str, Any]]:
    """Return the message body parsed as a JSON object, or None if it is not one."""
    try:
        payload = json.loads(msg.data.decode())
    except ValueError:
        # Covers both UnicodeDecodeError and json.JSONDecodeError.
        return None
    return payload if isinstance(payload, dict) else None


def _string_headers(headers: Dict[str, Any]) -> Dict[str, str]:
    """Drop None values and coerce the rest to str so they can be sent as NATS headers."""
    return {
        str(key): str(value)
        for key, value in headers.items()
        if value is not None
    }


async def process_test_task(msg, data: Optional[Dict[str, Any]] = None):
    """Process a test task with mock execution.

    Args:
        msg: JetStream message carrying a JSON object payload.
        data: Already-decoded payload of ``msg``. When omitted, the payload is
            decoded from ``msg.data``.

    Behaviour:
        * A payload that is not a JSON object is terminated (never redelivered).
        * ``force_fail`` on a message not marked ``replayed: true`` publishes a
          failure record to ``agent.run.failed.dlq`` and ACKs the message.
        * Otherwise a mock result is published to
          ``agent.run.completed.helion`` and the message is ACKed.
        * Any other error NAKs the message so JetStream redelivers it.
    """
    if data is None:
        data = _decode_payload(msg)
    if data is None:
        # Redelivering an unparseable payload can never succeed.
        logger.error(f"Dropping test task with malformed payload on {msg.subject}")
        await msg.term()
        return

    try:
        job_id = data.get("job_id") or data.get("task_id")
        trace_id = data.get("trace_id", "")
        agent_id = data.get("agent_id", "helion")
        logger.info(f"Processing TEST task: {job_id}")

        headers = dict(msg.headers) if msg.headers else {}
        headers_lower = {str(k).lower(): v for k, v in headers.items()}
        replayed = headers_lower.get("replayed") == "true"

        # Forced fail path for DLQ replay validation (test_mode only)
        if data.get("force_fail") and not replayed:
            fail_payload = {
                "status": "failed",
                "job_id": job_id,
                "trace_id": trace_id,
                "agent_id": agent_id,
                "error": "forced_fail",
                "original_subject": msg.subject,
                "data": data,
            }
            await js.publish(
                "agent.run.failed.dlq",
                json.dumps(fail_payload).encode(),
                headers={**headers, "replay_count": "0"},
            )
            logger.info(f"⚠️ Forced fail sent to DLQ: {job_id}")
            await msg.ack()
            return

        # Mock execution
        await asyncio.sleep(0.1)

        result = {
            "status": "completed",
            "job_id": job_id,
            "trace_id": trace_id,
            "agent_id": agent_id,
            "result": "Mock test execution completed",
            "test_mode": True,
        }

        completion_headers = {
            "Nats-Trace-ID": trace_id,
            "Nats-Job-ID": job_id,
            "Nats-Agent-ID": agent_id,
        }
        for key in _REPLAY_HEADER_KEYS:
            value = headers_lower.get(key)
            if value is not None:
                completion_headers[key] = value

        # NOTE(review): the completion subject is fixed to ".helion" even when
        # the payload names another agent_id - confirm this is intended.
        await js.publish(
            "agent.run.completed.helion",
            json.dumps(result).encode(),
            # Header values must be strings; a missing job_id (None) or a
            # numeric id would otherwise make the publish fail and the
            # message be redelivered forever.
            headers=_string_headers(completion_headers),
        )

        logger.info(f"✅ TEST task completed: {job_id}")
        await msg.ack()
    except Exception as e:
        logger.exception(f"Test task failed: {e}")
        await msg.nak()


async def _handle_message(msg):
    """Route one fetched message: run test tasks, ACK everything else unprocessed."""
    data = _decode_payload(msg)
    if data is None:
        logger.error(f"Dropping malformed message on {msg.subject}")
        await msg.term()
        return

    # Only process test messages
    if data.get("workflow_type") == "test" or data.get("test_mode"):
        await process_test_task(msg, data)
    else:
        await msg.ack()  # ACK but don't process non-test


async def worker_loop():
    """Connect to NATS, subscribe to ``agent.run.requested`` and process messages.

    Connection attempts are retried every 5 seconds until one succeeds or the
    task is cancelled. A failed subscription ends the worker. ``running`` is
    True only while messages are being fetched and is always reset on exit,
    so ``/health`` does not report a dead worker as running.
    """
    global nc, js, running

    try:
        # Retry in place rather than spawning a new task per attempt, so the
        # task started in startup() stays the only worker.
        while True:
            try:
                nc = await nats.connect("nats://nats:4222")
                js = nc.jetstream()
                logger.info("✅ Connected to NATS")
                break
            except Exception as e:
                logger.warning(f"NATS connection issue (will retry): {e}")
                await asyncio.sleep(5)

        # Subscribe to messages (ephemeral consumer)
        try:
            sub = await js.pull_subscribe(
                "agent.run.requested",
                None,  # Ephemeral consumer
                stream=None,
            )
            logger.info("✅ Subscribed to agent.run.requested")
        except Exception as e:
            logger.warning(f"Subscription failed: {e}")
            return

        running = True
        while running:
            try:
                msgs = await sub.fetch(5, timeout=5)
            except asyncio.TimeoutError:
                # No messages within the fetch window; poll again.
                continue
            except Exception as e:
                logger.error(f"Worker loop error: {e}")
                await asyncio.sleep(1)
                continue

            for msg in msgs:
                # Isolate failures so one bad message cannot cause the rest
                # of the batch to be skipped.
                try:
                    await _handle_message(msg)
                except Exception as e:
                    logger.exception(f"Worker loop error: {e}")
    except Exception as e:
        logger.exception(f"Worker failed: {e}")
    finally:
        running = False


@app.on_event("startup")
async def startup():
    """Start worker on startup"""
    global _worker_task
    _worker_task = asyncio.create_task(worker_loop())


@app.on_event("shutdown")
async def shutdown():
    """Stop worker on shutdown"""
    global running, nc, _worker_task

    running = False

    # Cancel the worker first so it is not mid-fetch (or sleeping between
    # connection retries) while the connection is being closed.
    task, _worker_task = _worker_task, None
    if task is not None and not task.done():
        task.cancel()
        await asyncio.gather(task, return_exceptions=True)

    if nc:
        await nc.close()


@app.get("/health")
async def health():
    """Report worker liveness and whether the NATS connection is currently up."""
    connected = nc is not None and bool(nc.is_connected)
    return {"status": "ok", "running": running, "connected": connected}