import base64
import io
import logging
import os
import threading
from typing import Optional

import torch
from diffusers import Flux2KleinPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)

app = FastAPI(title="Image Generation Service", version="1.0.0")

# Inclusive upper bound accepted by torch.Generator.manual_seed().
_MAX_SEED = 0xFFFF_FFFF_FFFF_FFFF


class GenerateRequest(BaseModel):
    """Request body for ``POST /generate``."""

    prompt: str = Field(..., min_length=1)
    negative_prompt: Optional[str] = None
    width: int = Field(1024, ge=256, le=2048)
    height: int = Field(1024, ge=256, le=2048)
    num_inference_steps: int = Field(50, ge=1, le=100)
    guidance_scale: float = Field(4.0, ge=0.0, le=20.0)
    # Upper-bounded so an out-of-range seed is rejected during validation (422)
    # instead of making manual_seed() raise inside the handler (500).
    seed: Optional[int] = Field(None, ge=0, le=_MAX_SEED)


MODEL_ID = os.getenv("IMAGE_GEN_MODEL", "black-forest-labs/FLUX.2-klein-base-4B")
DEVICE = os.getenv("IMAGE_GEN_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
DTYPE_ENV = os.getenv("IMAGE_GEN_DTYPE", "float16")


def _resolve_dtype() -> torch.dtype:
    """Return the dtype the pipeline is loaded with (and reported as).

    Non-CUDA devices always use float32. On CUDA, ``IMAGE_GEN_DTYPE=float32``
    selects float32; every other value, including the ``float16`` default,
    resolves to bfloat16 because FLUX.2 Klein is run in bfloat16 as recommended.
    Loading and the /health and /info endpoints all go through this function,
    so the reported dtype always matches the real one.
    """
    if not DEVICE.startswith("cuda"):
        return torch.float32
    if DTYPE_ENV.strip().lower() == "float32":
        return torch.float32
    # A float16 request is deliberately upgraded to bfloat16 for this model.
    return torch.bfloat16


def _dtype_name() -> str:
    """Return the resolved dtype as a bare name, e.g. ``"bfloat16"``."""
    return str(_resolve_dtype()).replace("torch.", "")


PIPELINE: Optional[Flux2KleinPipeline] = None
LOAD_ERROR: Optional[str] = None
# FastAPI runs sync endpoints in a thread pool, so /generate can be entered by
# several threads at once. The shared diffusers pipeline mutates its scheduler
# state on every call, so calls into it are serialised with this lock.
_PIPELINE_LOCK = threading.Lock()


def _load_pipeline() -> None:
    """Load the model into ``PIPELINE`` or record the failure in ``LOAD_ERROR``.

    Failures are logged and not re-raised; they are surfaced to clients via
    /health and /generate (HTTP 503) and via the ``load_error`` field of /info.
    """
    global PIPELINE, LOAD_ERROR
    try:
        pipe = Flux2KleinPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=_resolve_dtype(),
        )
        if DEVICE.startswith("cuda"):
            # Enable CPU offload to reduce VRAM usage; pass DEVICE so an
            # explicit index such as "cuda:1" is honoured.
            pipe.enable_model_cpu_offload(device=DEVICE)
        else:
            pipe.to(DEVICE)
    except Exception as exc:  # pragma: no cover - surface error via health/info
        logger.exception("Failed to load image generation pipeline %r", MODEL_ID)
        PIPELINE = None
        # Some exceptions stringify to ""; an empty LOAD_ERROR is falsy and
        # would make the endpoints report "not loaded yet" instead of failing.
        LOAD_ERROR = str(exc) or repr(exc)
        return
    PIPELINE = pipe
    LOAD_ERROR = None


@app.on_event("startup")
def startup_event() -> None:
    """Load the pipeline once when the application starts."""
    _load_pipeline()


@app.get("/health")
def health() -> dict:
    """Report service status; respond 503 with the load error if loading failed."""
    if LOAD_ERROR:
        raise HTTPException(status_code=503, detail=LOAD_ERROR)
    return {
        "status": "ok",
        "model_loaded": PIPELINE is not None,
        "model_id": MODEL_ID,
        "device": DEVICE,
        "dtype": _dtype_name(),
    }


@app.get("/info")
def info() -> dict:
    """Return model/device configuration and load state (never raises)."""
    return {
        "model_id": MODEL_ID,
        "device": DEVICE,
        "dtype": _dtype_name(),
        "pipeline_loaded": PIPELINE is not None,
        "load_error": LOAD_ERROR,
    }


@app.post("/generate")
def generate(payload: GenerateRequest) -> dict:
    """Generate one image for ``payload`` and return it as base64-encoded PNG.

    Returns:
        dict with ``image_base64``, the requested ``seed`` (``None`` if none was
        given) and ``model_id``.

    Raises:
        HTTPException: 503 if the model failed to load or is not loaded yet.
    """
    if LOAD_ERROR:
        raise HTTPException(status_code=503, detail=LOAD_ERROR)
    pipeline = PIPELINE
    if pipeline is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    generator = None
    if payload.seed is not None:
        # The generator must live on the same CUDA device the pipeline executes
        # on (DEVICE, see _load_pipeline); non-CUDA devices seed on the CPU.
        generator = torch.Generator(device=DEVICE if DEVICE.startswith("cuda") else "cpu")
        generator.manual_seed(payload.seed)

    with _PIPELINE_LOCK, torch.inference_mode():
        result = pipeline(
            prompt=payload.prompt,
            # An empty string is treated the same as "no negative prompt".
            # NOTE(review): confirm Flux2KleinPipeline.__call__ accepts
            # negative_prompt for the configured checkpoint.
            negative_prompt=payload.negative_prompt or None,
            height=payload.height,
            width=payload.width,
            num_inference_steps=payload.num_inference_steps,
            guidance_scale=payload.guidance_scale,
            generator=generator,
        )

    # Encoding happens outside the lock so it does not block other requests.
    image = result.images[0]
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return {
        "image_base64": encoded,
        "seed": payload.seed,
        "model_id": MODEL_ID,
    }