# services/comfy-agent/app/api.py
import hashlib
import json
from fastapi import APIRouter, Header, HTTPException
from .models import GenerateImageRequest, GenerateVideoRequest, JobStatus
from .jobs import JOB_STORE
from .worker import enqueue
from . import idempotency
from .config import settings
router = APIRouter()
def _req_hash(gen_type: str, payload: dict) -> str:
normalized = json.dumps({"type": gen_type, "payload": payload}, sort_keys=True, separators=(",", ":"))
return hashlib.sha256(normalized.encode("utf-8")).hexdigest()
def _resolve_idempotency_key(
*,
header_key: str | None,
body_key: str | None,
) -> str | None:
key = (header_key or body_key or "").strip()
return key or None
def _create_job_with_idempotency(
    *,
    gen_type: str,
    idem_key: str | None,
    req_hash: str,
) -> tuple[JobStatus, bool]:
    """Create (or look up) a job record, honoring an optional idempotency key.

    Returns:
        (job_status, should_enqueue) — should_enqueue is False when the key
        maps to an already-reserved job that must not be run again.

    Raises:
        HTTPException: 409 when the key was previously used with a different
        request payload.
    """
    # Without a key (or without a backing store) there is nothing to
    # deduplicate: create a fresh job and always enqueue it.
    if not idem_key or idempotency.IDEMPOTENCY_STORE is None:
        return JOB_STORE.create(gen_type), True

    reservation = idempotency.IDEMPOTENCY_STORE.reserve(
        idem_key=idem_key,
        gen_type=gen_type,
        req_hash=req_hash,
        job_id=JOB_STORE.new_job_id(),
    )

    if reservation.decision == "conflict":
        # Same key, different payload: refuse rather than silently dedupe.
        raise HTTPException(status_code=409, detail="idempotency_key_reused_with_different_payload")

    if reservation.decision == "exists":
        known = JOB_STORE.get(reservation.job_id)
        if known:
            return known, False
        # The in-memory JOB_STORE was lost (e.g. process restart): recreate a
        # queued placeholder under the reserved id, but never enqueue a
        # duplicate run for it.
        return JOB_STORE.create(gen_type, job_id=reservation.job_id), False

    # Fresh reservation: create the job under the reserved id and enqueue.
    return JOB_STORE.create(gen_type, job_id=reservation.job_id), True
def _build_workflow_t2i(req: GenerateImageRequest) -> dict:
# Basic SD 1.5 workflow
# Node structure: CheckpointLoader -> CLIP Encode -> KSampler -> VAE Decode -> SaveImage
return {
"3": {
"inputs": {
"seed": req.seed if req.seed else 42,
"steps": req.steps,
"cfg": 7.0,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": ["4", 0],
"positive": ["6", 0],
"negative": ["7", 0],
"latent_image": ["5", 0]
},
"class_type": "KSampler"
},
"4": {
"inputs": {
"ckpt_name": "v1-5-pruned-emaonly.safetensors"
},
"class_type": "CheckpointLoaderSimple"
},
"5": {
"inputs": {
"width": req.width,
"height": req.height,
"batch_size": 1
},
"class_type": "EmptyLatentImage"
},
"6": {
"inputs": {
"text": req.prompt,
"clip": ["4", 1]
},
"class_type": "CLIPTextEncode"
},
"7": {
"inputs": {
"text": req.negative_prompt if req.negative_prompt else "text, watermark, blurry",
"clip": ["4", 1]
},
"class_type": "CLIPTextEncode"
},
"8": {
"inputs": {
"samples": ["3", 0],
"vae": ["4", 2]
},
"class_type": "VAEDecode"
},
"9": {
"inputs": {
"filename_prefix": "comfy-agent",
"images": ["8", 0]
},
"class_type": "SaveImage"
}
}
def _build_workflow_t2v(req: GenerateVideoRequest) -> dict:
    """Build an LTX-2 text-to-video ComfyUI graph for *req*.

    Node wiring: CheckpointLoaderSimple(4) + LTXAVTextEncoderLoader(5)
    -> CLIPTextEncode(6 positive, 7 negative) -> LTXVConditioning(69)
    -> SamplerCustom(72) with EmptyLTXVLatentVideo(70), KSamplerSelect(73)
    and LTXVScheduler(71) -> VAEDecode(8) -> CreateVideo(78) -> SaveVideo(79).

    Raises:
        HTTPException: 503 when no LTX text encoder is configured.
    """
    if not settings.LTX_TEXT_ENCODER:
        raise HTTPException(status_code=503, detail="ltx_text_encoder_not_configured")
    # fps is clamped to at least 1 and converted to float for the graph nodes.
    frame_rate = float(max(1, req.fps))
    # Explicit frame count wins; otherwise derive from seconds * fps, plus one
    # frame (presumably an LTX requirement for the latent length — TODO confirm).
    length = req.frames if req.frames and req.frames > 0 else (max(1, req.seconds) * max(1, req.fps) + 1)
    neg = req.negative_prompt if req.negative_prompt else "low quality, worst quality, deformed, distorted, disfigured, motion artifacts"
    # LTX-2 text-to-video pipeline with SaveVideo output node.
    return {
        # Checkpoint supplies the diffusion model (slot 0) and VAE (slot 2).
        "4": {
            "class_type": "CheckpointLoaderSimple",
            "inputs": {
                "ckpt_name": settings.LTX_CKPT_NAME,
            },
        },
        # Dedicated LTX text encoder; both prompt encoders consume its output.
        "5": {
            "class_type": "LTXAVTextEncoderLoader",
            "inputs": {
                "text_encoder": settings.LTX_TEXT_ENCODER,
                "ckpt_name": settings.LTX_CKPT_NAME,
                "device": settings.LTX_DEVICE,
            },
        },
        # Positive prompt conditioning.
        "6": {
            "class_type": "CLIPTextEncode",
            "inputs": {
                "text": req.prompt,
                "clip": ["5", 0],
            },
        },
        # Negative prompt conditioning (request value or built-in default).
        "7": {
            "class_type": "CLIPTextEncode",
            "inputs": {
                "text": neg,
                "clip": ["5", 0],
            },
        },
        # Wraps both conditionings with the target frame rate for LTX.
        "69": {
            "class_type": "LTXVConditioning",
            "inputs": {
                "positive": ["6", 0],
                "negative": ["7", 0],
                "frame_rate": frame_rate,
            },
        },
        # Empty latent video of the requested spatial size and frame length.
        "70": {
            "class_type": "EmptyLTXVLatentVideo",
            "inputs": {
                "width": req.width,
                "height": req.height,
                "length": length,
                "batch_size": 1,
            },
        },
        # Sampler selection is configuration-driven.
        "73": {
            "class_type": "KSamplerSelect",
            "inputs": {
                "sampler_name": settings.LTX_SAMPLER,
            },
        },
        # Sigma schedule; shift/stretch/terminal knobs come from settings.
        "71": {
            "class_type": "LTXVScheduler",
            "inputs": {
                "steps": req.steps,
                "max_shift": settings.LTX_MAX_SHIFT,
                "base_shift": settings.LTX_BASE_SHIFT,
                "stretch": settings.LTX_STRETCH,
                "terminal": settings.LTX_TERMINAL,
            },
        },
        # Core sampling step; note `is not None` so seed=0 is respected.
        "72": {
            "class_type": "SamplerCustom",
            "inputs": {
                "model": ["4", 0],
                "add_noise": True,
                "noise_seed": req.seed if req.seed is not None else 42,
                "cfg": req.cfg,
                "positive": ["69", 0],
                "negative": ["69", 1],
                "sampler": ["73", 0],
                "sigmas": ["71", 0],
                "latent_image": ["70", 0],
            },
        },
        # Decode latents back to image frames with the checkpoint's VAE.
        "8": {
            "class_type": "VAEDecode",
            "inputs": {
                "samples": ["72", 0],
                "vae": ["4", 2],
            },
        },
        # Assemble decoded frames into a video stream at the requested fps.
        "78": {
            "class_type": "CreateVideo",
            "inputs": {
                "images": ["8", 0],
                "fps": frame_rate,
            },
        },
        # Persist the video with the caller-chosen container format and codec.
        "79": {
            "class_type": "SaveVideo",
            "inputs": {
                "video": ["78", 0],
                "filename_prefix": "comfy-agent/video",
                "format": req.format,
                "codec": req.codec,
            },
        },
    }
@router.post("/generate/image", response_model=JobStatus)
async def generate_image(
req: GenerateImageRequest,
idempotency_key: str | None = Header(default=None, alias="Idempotency-Key"),
):
idem_key = _resolve_idempotency_key(header_key=idempotency_key, body_key=req.idempotency_key)
req_hash = _req_hash("text-to-image", req.model_dump(mode="json", exclude={"idempotency_key"}))
job, should_enqueue = _create_job_with_idempotency(
gen_type="text-to-image",
idem_key=idem_key,
req_hash=req_hash,
)
graph = _build_workflow_t2i(req)
if should_enqueue:
enqueue(job.job_id, "text-to-image", graph)
return JOB_STORE.get(job.job_id)
@router.post("/generate/video", response_model=JobStatus)
async def generate_video(
req: GenerateVideoRequest,
idempotency_key: str | None = Header(default=None, alias="Idempotency-Key"),
):
idem_key = _resolve_idempotency_key(header_key=idempotency_key, body_key=req.idempotency_key)
req_hash = _req_hash("text-to-video", req.model_dump(mode="json", exclude={"idempotency_key"}))
job, should_enqueue = _create_job_with_idempotency(
gen_type="text-to-video",
idem_key=idem_key,
req_hash=req_hash,
)
graph = _build_workflow_t2v(req)
if should_enqueue:
enqueue(job.job_id, "text-to-video", graph)
return JOB_STORE.get(job.job_id)
@router.get("/status/{job_id}", response_model=JobStatus)
async def status(job_id: str):
job = JOB_STORE.get(job_id)
if not job:
raise HTTPException(status_code=404, detail="job_not_found")
return job
@router.get("/result/{job_id}", response_model=JobStatus)
async def result(job_id: str):
job = JOB_STORE.get(job_id)
if not job:
raise HTTPException(status_code=404, detail="job_not_found")
return job