119 lines
4.8 KiB
Python
119 lines
4.8 KiB
Python
# services/comfy-agent/app/worker.py
|
|
import asyncio
|
|
import uuid
|
|
import os
|
|
import json
|
|
from typing import Any, Dict, Optional, Tuple
|
|
from .jobs import JOB_STORE
|
|
from .storage import make_job_dir, publish_result_url
|
|
from .comfyui_client import ComfyUIClient
|
|
from .config import settings
|
|
|
|
# In-process FIFO of pending jobs. Each entry is a
# (job_id, gen_type, prompt_graph) tuple: enqueue() puts entries here and
# worker_loop() takes them off one at a time.
_queue: "asyncio.Queue[Tuple[str, str, Dict[str, Any]]]" = asyncio.Queue()
|
|
|
|
def enqueue(job_id: str, gen_type: str, prompt_graph: Dict[str, Any]) -> None:
    """Schedule a job for the worker loop without waiting.

    The three arguments are bundled into a single tuple and pushed onto the
    module-level queue with ``put_nowait``.
    """
    entry = (job_id, gen_type, prompt_graph)
    _queue.put_nowait(entry)
|
|
|
|
async def _extract_first_output(
    history: Dict[str, Any], job_dir: str, client: "ComfyUIClient"
) -> Optional[str]:
    """Write the history manifest and download the first listed output asset.

    The complete *history* is always dumped to ``manifest.json`` inside
    *job_dir*. The history is then walked (entry -> ``outputs`` -> node ->
    ``images`` / ``gifs`` / ``videos``) for the first item that carries a
    non-empty string ``filename``. That asset is requested through
    ``client.http.get("/view", ...)`` and stored in *job_dir* under the base
    name of its ``filename``.

    Args:
        history: Mapping whose values are per-prompt dicts; values or nested
            fields of an unexpected type are skipped instead of raising.
        job_dir: Directory that receives the manifest and the asset.
        client: Object exposing an awaitable ``http.get``. It is only touched
            once an asset has been found.

    Returns:
        The base filename of the stored asset, or ``None`` when the history
        lists no usable asset or the download/write fails. The manifest is
        kept in either case.

    Raises:
        OSError, TypeError, ValueError: if the manifest itself cannot be
            written or serialized.
    """
    # Imported locally so this function does not rely on a module-level import.
    import logging

    log = logging.getLogger(__name__)

    # Keep full history for debugging/reproducibility.
    manifest_path = os.path.join(job_dir, "manifest.json")
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)

    def _iter_assets():
        """Yield asset dicts in history order, skipping malformed levels."""
        prompts = history.values() if isinstance(history, dict) else ()
        for prompt_data in prompts:
            outputs = (
                prompt_data.get("outputs")
                if isinstance(prompt_data, dict)
                else None
            )
            # "outputs" may be missing or null; treat that as "nothing here".
            if not isinstance(outputs, dict):
                continue
            for node_out in outputs.values():
                if not isinstance(node_out, dict):
                    continue
                for key in ("images", "gifs", "videos"):
                    items = node_out.get(key)
                    if not isinstance(items, (list, tuple)):
                        continue
                    for item in items:
                        if not isinstance(item, dict):
                            continue
                        name = item.get("filename")
                        if isinstance(name, str) and name:
                            yield item

    first = next(_iter_assets(), None)
    if first is None:
        return None

    # Only the base name is used locally so the asset cannot escape job_dir.
    filename = os.path.basename(first["filename"])
    if not filename:
        log.warning(
            "ComfyUI asset %r has no usable base name", first["filename"]
        )
        return None

    params = {
        "filename": first["filename"],
        "subfolder": first.get("subfolder", ""),
        "type": first.get("type", "output"),
    }

    try:
        resp = await client.http.get("/view", params=params)
        resp.raise_for_status()
        out_path = os.path.join(job_dir, filename)
        with open(out_path, "wb") as f:
            f.write(resp.content)
    except Exception:
        # Best-effort: the result stays manifest-only if the /view download
        # fails, but the reason is logged instead of being dropped.
        log.warning(
            "Failed to download ComfyUI asset %r into %s",
            first["filename"],
            job_dir,
            exc_info=True,
        )
        return None
    return filename
|
|
|
|
|
|
def _extract_history_error(history: Dict[str, Any]) -> Optional[str]:
    """Return the first execution error recorded in a history payload.

    Every value of *history* is checked for a ``status`` dict whose
    ``status_str`` equals ``"error"``. For the first such entry, its
    ``messages`` are scanned for ``["execution_error", payload]`` pairs and
    the first non-blank ``exception_message`` (or, failing that,
    ``exception_type``) is returned.

    Args:
        history: Mapping whose values are per-prompt dicts. Values and nested
            fields of an unexpected type are skipped instead of raising.

    Returns:
        The stripped error text; ``"comfy_execution_error"`` when an entry is
        flagged as errored but carries no usable text; ``None`` when no entry
        is flagged. A non-``None`` result is never an empty string, so callers
        may test it for truthiness.
    """
    if not isinstance(history, dict):
        return None

    for prompt_data in history.values():
        if not isinstance(prompt_data, dict):
            continue
        status = prompt_data.get("status")
        if not isinstance(status, dict) or status.get("status_str") != "error":
            continue

        messages = status.get("messages")
        if not isinstance(messages, (list, tuple)):
            # Missing or null "messages": still an error, just without detail.
            messages = ()

        for item in messages:
            if not (isinstance(item, (list, tuple)) and len(item) >= 2):
                continue
            if item[0] != "execution_error":
                continue
            payload = item[1] if isinstance(item[1], dict) else {}
            # Strip before testing so a whitespace-only message cannot be
            # returned as "" (which the caller would read as "no error").
            for key in ("exception_message", "exception_type"):
                text = str(payload.get(key) or "").strip()
                if text:
                    return text

        return "comfy_execution_error"

    return None
|
|
|
|
async def worker_loop() -> None:
    """Consume queued jobs forever, running each one as its own task.

    One ``ComfyUIClient`` is shared by all jobs. Each dequeued job is started
    immediately as a task; the semaphore inside the task limits how many run
    at once to ``settings.MAX_CONCURRENCY``. Job state is reported through
    ``JOB_STORE.update``.

    This coroutine never returns on its own.
    """
    client = ComfyUIClient()
    sem = asyncio.Semaphore(settings.MAX_CONCURRENCY)

    # The event loop keeps only weak references to tasks, so a task whose
    # handle is dropped can be garbage-collected before it finishes. Hold a
    # strong reference until each task is done.
    in_flight: "set[asyncio.Task[None]]" = set()

    async def run_one(job_id: str, gen_type: str, prompt_graph: Dict[str, Any]) -> None:
        """Run one job end to end and record its outcome in JOB_STORE.

        ``gen_type`` is accepted for signature parity with the queue entries
        but is not used here.
        """
        async with sem:
            JOB_STORE.update(job_id, status="running", progress=0.01)

            client_id = f"comfy-agent-{uuid.uuid4().hex}"

            def on_p(p: float, msg: str) -> None:
                JOB_STORE.update(job_id, progress=float(p), message=msg)

            try:
                prompt_id = await client.queue_prompt(prompt_graph, client_id=client_id)
                JOB_STORE.update(job_id, comfy_prompt_id=prompt_id)

                await client.wait_progress(client_id=client_id, prompt_id=prompt_id, on_progress=on_p)

                hist = await client.get_history(prompt_id)
                job_dir = make_job_dir(job_id)
                hist_error = _extract_history_error(hist)
                if hist_error:
                    # Called for its side effect: the manifest is written even
                    # for failed runs. The returned filename is ignored.
                    await _extract_first_output(hist, job_dir, client)
                    JOB_STORE.update(job_id, status="failed", message="failed", error=hist_error)
                    return
                fname = await _extract_first_output(hist, job_dir, client)
                if not fname:
                    JOB_STORE.update(job_id, status="failed", message="failed", error="No outputs found in ComfyUI history")
                    return

                local_path = os.path.join(job_dir, fname)
                url = publish_result_url(job_id, fname, local_path)
                JOB_STORE.update(job_id, status="succeeded", progress=1.0, result_url=url)

            except Exception as e:
                # Some exceptions (e.g. asyncio.TimeoutError) stringify to "";
                # fall back to the class name so a failed job always has a reason.
                JOB_STORE.update(
                    job_id,
                    status="failed",
                    message="failed",
                    error=str(e) or type(e).__name__,
                )

    while True:
        job_id, gen_type, prompt_graph = await _queue.get()
        task = asyncio.create_task(run_one(job_id, gen_type, prompt_graph))
        in_flight.add(task)
        task.add_done_callback(in_flight.discard)
|