Files
microdao-daarion/services/image-gen-service/app/main.py
Apple 5290287058 feat: implement TTS, Document processing, and Memory Service /facts API
- TTS: xtts-v2 integration with voice cloning support
- Document: docling integration for PDF/DOCX/PPTX processing
- Memory Service: added /facts/upsert, /facts/{key}, /facts endpoints
- Added required dependencies (TTS, docling)
2026-01-17 08:16:37 -08:00

127 lines
3.6 KiB
Python

import base64
import io
import os
from typing import Optional
import torch
from diffusers import Flux2KleinPipeline
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# FastAPI application exposing the image-generation HTTP API.
app = FastAPI(title="Image Generation Service", version="1.0.0")


class GenerateRequest(BaseModel):
    """Request body for POST /generate."""

    prompt: str = Field(..., min_length=1)  # text prompt; must be non-empty
    negative_prompt: Optional[str] = None  # optional; empty string is treated as absent downstream
    width: int = Field(1024, ge=256, le=2048)  # output width in pixels
    height: int = Field(1024, ge=256, le=2048)  # output height in pixels
    num_inference_steps: int = Field(50, ge=1, le=100)  # diffusion denoising steps
    guidance_scale: float = Field(4.0, ge=0.0, le=20.0)
    seed: Optional[int] = Field(None, ge=0)  # RNG seed; generation is non-deterministic when omitted
MODEL_ID = os.getenv("IMAGE_GEN_MODEL", "black-forest-labs/FLUX.2-klein-base-4B")
DEVICE = os.getenv("IMAGE_GEN_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
DTYPE_ENV = os.getenv("IMAGE_GEN_DTYPE", "float16")
def _resolve_dtype() -> torch.dtype:
if DEVICE.startswith("cuda"):
return torch.float16 if DTYPE_ENV == "float16" else torch.bfloat16
return torch.float32
# Module-level pipeline state: PIPELINE holds the loaded model, LOAD_ERROR the
# last load failure message (None when loading succeeded).
PIPELINE: Optional[Flux2KleinPipeline] = None
LOAD_ERROR: Optional[str] = None


def _load_pipeline() -> None:
    """Load the diffusers pipeline once, recording any failure in LOAD_ERROR.

    Never raises: a failed load leaves PIPELINE as None and stores the error
    string so /health and /info can surface it.
    """
    global PIPELINE, LOAD_ERROR
    on_cuda = DEVICE.startswith("cuda")
    try:
        dtype = _resolve_dtype()
        # bfloat16 is the recommended precision for FLUX.2 Klein on GPU.
        if on_cuda and dtype == torch.float16:
            dtype = torch.bfloat16
        pipeline = Flux2KleinPipeline.from_pretrained(MODEL_ID, torch_dtype=dtype)
        if on_cuda:
            # Offload idle submodules to host memory to reduce VRAM pressure.
            pipeline.enable_model_cpu_offload()
        else:
            pipeline.to(DEVICE)
        PIPELINE, LOAD_ERROR = pipeline, None
    except Exception as exc:  # pragma: no cover - surface error via health/info
        PIPELINE, LOAD_ERROR = None, str(exc)
@app.on_event("startup")
def startup_event() -> None:
    """Load the diffusion pipeline when the server starts.

    Load failures are captured in LOAD_ERROR by _load_pipeline rather than
    crashing startup, so /health can report them as a 503.
    """
    # NOTE(review): @app.on_event is deprecated in newer FastAPI versions in
    # favor of lifespan handlers — consider migrating; behavior unchanged here.
    _load_pipeline()
@app.get("/health")
def health() -> dict:
    """Readiness probe: 503 with the load error while the model is unavailable."""
    if LOAD_ERROR:
        raise HTTPException(status_code=503, detail=LOAD_ERROR)
    dtype_name = str(_resolve_dtype()).replace("torch.", "")
    return dict(
        status="ok",
        model_loaded=PIPELINE is not None,
        model_id=MODEL_ID,
        device=DEVICE,
        dtype=dtype_name,
    )
@app.get("/info")
def info() -> dict:
    """Report model configuration and load state (including any load error)."""
    dtype_name = str(_resolve_dtype()).replace("torch.", "")
    return dict(
        model_id=MODEL_ID,
        device=DEVICE,
        dtype=dtype_name,
        pipeline_loaded=PIPELINE is not None,
        load_error=LOAD_ERROR,
    )
@app.post("/generate")
def generate(payload: GenerateRequest) -> dict:
    """Generate one image from a text prompt.

    Returns:
        dict with the PNG image as base64 ("image_base64"), the seed actually
        used ("seed"), and the model id.

    Raises:
        HTTPException 503: while the model failed to load or is not loaded yet.
    """
    if LOAD_ERROR:
        raise HTTPException(status_code=503, detail=LOAD_ERROR)
    if PIPELINE is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    # Fix: previously an omitted seed was echoed back as None, making the
    # result irreproducible. Draw a random seed when none is given and always
    # seed the generator, so the response reports the seed actually used.
    seed = payload.seed
    if seed is None:
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
    generator = torch.Generator(device="cuda" if DEVICE.startswith("cuda") else "cpu")
    generator.manual_seed(seed)

    with torch.inference_mode():
        result = PIPELINE(
            prompt=payload.prompt,
            # Empty string is normalized to None, as before.
            negative_prompt=payload.negative_prompt or None,
            height=payload.height,
            width=payload.width,
            num_inference_steps=payload.num_inference_steps,
            guidance_scale=payload.guidance_scale,
            generator=generator,
        )

    image = result.images[0]
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")
    return {
        "image_base64": encoded,
        "seed": seed,
        "model_id": MODEL_ID,
    }