Files
microdao-daarion/services/swapper-service/app/main.py

2516 lines
96 KiB
Python

"""
Swapper Service - Dynamic Model Loading Service
Manages loading/unloading LLM and OCR models on-demand to optimize memory usage.
Supports:
- Ollama models (LLM, Vision, Math)
- HuggingFace models (OCR, Document Understanding)
- Lazy loading for OCR models
"""
import os
import asyncio
import logging
import base64
import json
import re
from typing import Optional, Dict, List, Any, Union
from datetime import datetime, timedelta
from enum import Enum
from io import BytesIO
import xml.etree.ElementTree as ET
from fastapi import FastAPI, HTTPException, BackgroundTasks, File, UploadFile, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import httpx
import csv
import zipfile
from io import BytesIO
import mimetypes
import yaml
# Optional imports for HuggingFace models
try:
import torch
from PIL import Image
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
torch = None
Image = None
logger = logging.getLogger(__name__)
# ========== Document Helpers ==========
def _decode_text_bytes(content: bytes) -> str:
"""Decode text with best-effort fallback."""
try:
import chardet
detected = chardet.detect(content)
encoding = detected.get("encoding") or "utf-8"
return content.decode(encoding, errors="replace")
except Exception:
try:
return content.decode("utf-8", errors="replace")
except Exception:
return content.decode("latin-1", errors="replace")
def _csv_to_markdown(content: bytes) -> str:
    """Render CSV bytes as a Markdown table (first row becomes the header)."""
    decoded = _decode_text_bytes(content)
    return _rows_to_markdown(list(csv.reader(decoded.splitlines())))
def _tsv_to_markdown(content: bytes) -> str:
    """Render TSV bytes as a Markdown table (first row becomes the header)."""
    decoded = _decode_text_bytes(content)
    return _rows_to_markdown(list(csv.reader(decoded.splitlines(), delimiter="\t")))
def _rows_to_markdown(rows: List[List[Any]]) -> str:
if not rows:
return ""
width = max(len(r) for r in rows)
norm_rows = []
for r in rows:
rr = [str(c) if c is not None else "" for c in r]
if len(rr) < width:
rr.extend([""] * (width - len(rr)))
norm_rows.append(rr)
header = norm_rows[0]
body = norm_rows[1:]
lines = [
"| " + " | ".join(header) + " |",
"| " + " | ".join(["---"] * len(header)) + " |",
]
for row in body:
lines.append("| " + " | ".join([str(c) if c is not None else "" for c in row]) + " |")
return "\n".join(lines)
def _xlsx_to_markdown(content: bytes) -> str:
    """Convert an XLSX/XLSM workbook to Markdown, one table per sheet.

    Raises:
        HTTPException: 500 if openpyxl is not installed.
    """
    try:
        import openpyxl
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"openpyxl not available: {e}")
    wb = openpyxl.load_workbook(filename=BytesIO(content), data_only=True)
    parts = []
    for sheet in wb.worksheets:
        parts.append(f"## Sheet: {sheet.title}")
        rows = [list(row) for row in sheet.iter_rows(values_only=True)]
        if not rows:
            parts.append("_Empty sheet_")
            continue
        # Delegate rendering to the shared helper so every spreadsheet
        # converter produces identical Markdown and ragged rows get padded
        # (the previous inline version emitted unpadded rows).
        parts.append(_rows_to_markdown(rows))
    return "\n".join(parts)
def _xls_to_markdown(content: bytes) -> str:
    """Convert a legacy .xls workbook to Markdown, one table per sheet.

    Raises:
        HTTPException: 500 if xlrd is not installed.
    """
    try:
        import xlrd
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"xlrd not available: {e}")
    book = xlrd.open_workbook(file_contents=content)
    sections = []
    for sheet in book.sheets():
        sections.append(f"## Sheet: {sheet.name}")
        grid = [
            [sheet.cell_value(r, c) for c in range(sheet.ncols)]
            for r in range(sheet.nrows)
        ]
        if not grid:
            sections.append("_Empty sheet_")
            continue
        sections.append(_rows_to_markdown(grid))
    return "\n\n".join(sections)
def _ods_to_markdown(content: bytes) -> str:
    """Convert an OpenDocument spreadsheet (.ods) to Markdown tables.

    Raises:
        HTTPException: 500 if odfpy is missing, 400 if the file cannot be
        parsed as an ODS document.
    """
    try:
        from odf.opendocument import load
        from odf.table import Table, TableRow, TableCell
        from odf.text import P
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"odfpy not available: {e}")
    try:
        doc = load(BytesIO(content))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid ODS file: {e}")
    parts = []
    for table in doc.spreadsheet.getElementsByType(Table):
        table_name = str(table.getAttribute("name") or "Sheet")
        parts.append(f"## Sheet: {table_name}")
        rows: List[List[str]] = []
        for row in table.getElementsByType(TableRow):
            cells_out: List[str] = []
            for cell in row.getElementsByType(TableCell):
                # Gather the plain text of every paragraph node in the cell.
                txt_parts = []
                for p in cell.getElementsByType(P):
                    txt_parts.extend(
                        [str(getattr(node, "data", "")).strip() for node in p.childNodes if getattr(node, "data", None)]
                    )
                cell_text = " ".join([t for t in txt_parts if t]).strip()
                # ODS run-length-encodes repeated cells via
                # "numbercolumnsrepeated"; expand it, capped at 100 because
                # writers emit huge repeat counts for trailing empty columns.
                repeat_raw = cell.getAttribute("numbercolumnsrepeated")
                try:
                    repeat = int(repeat_raw) if repeat_raw else 1
                except Exception:
                    repeat = 1
                repeat = max(1, min(repeat, 100))
                for _ in range(repeat):
                    cells_out.append(cell_text)
            if cells_out:
                rows.append(cells_out)
        if not rows:
            parts.append("_Empty sheet_")
            continue
        parts.append(_rows_to_markdown(rows))
    return "\n\n".join(parts)
def _docx_to_text(content: bytes) -> str:
    """Extract paragraph text from a .docx document, one paragraph per line.

    Raises:
        HTTPException: 500 if python-docx is not installed.
    """
    try:
        from docx import Document
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"python-docx not available: {e}")
    document = Document(BytesIO(content))
    return "\n".join(paragraph.text for paragraph in document.paragraphs if paragraph.text)
def _pdf_to_text(content: bytes) -> str:
    """Extract plain text from a PDF, joining pages with blank lines.

    Raises:
        HTTPException: 500 if pdfplumber is not installed.
    """
    try:
        import pdfplumber
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"pdfplumber not available: {e}")
    pages = []
    with pdfplumber.open(BytesIO(content)) as pdf:
        for page in pdf.pages:
            extracted = page.extract_text() or ""
            if extracted:
                pages.append(extracted)
    return "\n\n".join(pages)
def _pptx_to_text(content: bytes) -> str:
    """Extract text from a .pptx deck, one '## Slide N' section per slide.

    Raises:
        HTTPException: 500 if python-pptx is not installed.
    """
    try:
        from pptx import Presentation
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"python-pptx not available: {e}")
    deck = Presentation(BytesIO(content))
    sections = []
    for number, slide in enumerate(deck.slides, start=1):
        sections.append(f"## Slide {number}")
        texts = []
        for shape in slide.shapes:
            raw = getattr(shape, "text", None)
            if raw and str(raw).strip():
                texts.append(str(raw).strip())
        if texts:
            sections.extend(texts)
        else:
            sections.append("_No text on this slide_")
    return "\n\n".join(sections)
def _json_to_text(content: bytes) -> str:
    """Pretty-print JSON bytes; return the decoded raw text if parsing fails."""
    text = _decode_text_bytes(content)
    try:
        return json.dumps(json.loads(text), ensure_ascii=False, indent=2)
    except Exception:
        return text
def _yaml_to_text(content: bytes) -> str:
    """Normalize YAML bytes via a parse/dump round-trip; fall back to raw text."""
    text = _decode_text_bytes(content)
    try:
        loaded = yaml.safe_load(text)
        return yaml.safe_dump(loaded, allow_unicode=True, sort_keys=False)
    except Exception:
        return text
def _xml_to_text(content: bytes) -> str:
    """Flatten XML to its concatenated text content; fall back to raw text."""
    raw = _decode_text_bytes(content)
    try:
        tree = ET.fromstring(raw)
        pieces = [chunk.strip() for chunk in tree.itertext() if chunk and chunk.strip()]
        return " ".join(pieces) or raw
    except Exception:
        return raw
def _html_to_text(content: bytes) -> str:
    """Convert HTML bytes to readable text via BeautifulSoup, with a regex fallback."""
    raw = _decode_text_bytes(content)
    try:
        from bs4 import BeautifulSoup
        extracted = BeautifulSoup(raw, "html.parser").get_text(separator="\n")
        extracted = re.sub(r"\n{3,}", "\n\n", extracted)
        return extracted.strip() or raw
    except Exception:
        # bs4 unavailable (or parsing failed): crude tag stripping.
        stripped = re.sub(r"<[^>]+>", " ", raw)
        return re.sub(r"\s+", " ", stripped).strip()
def _rtf_to_text(content: bytes) -> str:
    """Convert RTF bytes to plain text, with a crude control-word-stripping fallback."""
    raw = _decode_text_bytes(content)
    try:
        from striprtf.striprtf import rtf_to_text
        return rtf_to_text(raw)
    except Exception:
        # striprtf unavailable: drop hex escapes, control words and braces.
        cleaned = re.sub(r"\\'[0-9a-fA-F]{2}", " ", raw)
        cleaned = re.sub(r"\\[a-zA-Z]+-?\d* ?", " ", cleaned)
        cleaned = cleaned.replace("{", " ").replace("}", " ")
        return re.sub(r"\s+", " ", cleaned).strip()
def _extract_text_by_ext(filename: str, content: bytes) -> str:
    """Route *content* to the right converter based on the filename extension.

    Raises:
        HTTPException: 400 for extensions with no registered converter.
    """
    ext = filename.split(".")[-1].lower() if "." in filename else ""
    dispatch = {
        "txt": _decode_text_bytes,
        "md": _decode_text_bytes,
        "markdown": _decode_text_bytes,
        "csv": _csv_to_markdown,
        "tsv": _tsv_to_markdown,
        "xlsx": _xlsx_to_markdown,
        "xlsm": _xlsx_to_markdown,
        "xls": _xls_to_markdown,
        "ods": _ods_to_markdown,
        "docx": _docx_to_text,
        "pdf": _pdf_to_text,
        "pptx": _pptx_to_text,
        "json": _json_to_text,
        "yaml": _yaml_to_text,
        "yml": _yaml_to_text,
        "xml": _xml_to_text,
        "html": _html_to_text,
        "htm": _html_to_text,
        "rtf": _rtf_to_text,
    }
    converter = dispatch.get(ext)
    if converter is None:
        raise HTTPException(status_code=400, detail=f"Unsupported file type: .{ext}")
    return converter(content)
def _zip_to_markdown(content: bytes, max_files: int = 50, max_total_mb: int = 100) -> str:
zf = zipfile.ZipFile(BytesIO(content))
members = [m for m in zf.infolist() if not m.is_dir()]
if len(members) > max_files:
raise HTTPException(status_code=400, detail=f"ZIP has слишком много файлов: {len(members)}")
total_size = sum(m.file_size for m in members)
if total_size > max_total_mb * 1024 * 1024:
raise HTTPException(status_code=400, detail=f"ZIP слишком большой: {total_size / 1024 / 1024:.1f} MB")
parts = []
allowed_exts = {
"txt", "md", "markdown", "csv", "tsv",
"xls", "xlsx", "xlsm", "ods",
"docx", "pdf", "pptx",
"json", "yaml", "yml", "xml", "html", "htm", "rtf",
}
processed = []
skipped = []
for member in members:
name = member.filename
ext = name.split(".")[-1].lower() if "." in name else ""
if ext not in allowed_exts:
skipped.append(name)
parts.append(f"## {name}\n_Skipped unsupported file type_")
continue
file_bytes = zf.read(member)
extracted = _extract_text_by_ext(name, file_bytes)
processed.append(name)
parts.append(f"## {name}\n{extracted}")
header_lines = ["# ZIP summary", "Processed files:"]
header_lines.extend([f"- {name}" for name in processed] or ["- (none)"])
if skipped:
header_lines.append("Skipped files:")
header_lines.extend([f"- {name}" for name in skipped])
return "\n\n".join(["\n".join(header_lines), *parts])
# ========== Configuration ==========
# Base URL of the Ollama server used for LLM inference/warm-up calls.
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
# YAML file describing the managed models (keys, backends, sizes, priorities).
SWAPPER_CONFIG_PATH = os.getenv("SWAPPER_CONFIG_PATH", "./config/swapper_config.yaml")
SWAPPER_MODE = os.getenv("SWAPPER_MODE", "single-active")  # single-active or multi-active
# NOTE(review): read from the environment but not referenced elsewhere in this
# file — confirm whether enforcement was intended.
MAX_CONCURRENT_MODELS = int(os.getenv("MAX_CONCURRENT_MODELS", "1"))
# Seconds allowed for the warm-up request when swapping a model in.
MODEL_SWAP_TIMEOUT = int(os.getenv("MODEL_SWAP_TIMEOUT", "30"))
# ========== Models ==========
class ModelStatus(str, Enum):
    """Lifecycle state of a managed model."""
    LOADED = "loaded"
    LOADING = "loading"
    UNLOADED = "unloaded"
    UNLOADING = "unloading"
    ERROR = "error"  # last load/unload attempt failed; see service logs
class ModelBackend(str, Enum):
    """Which runtime serves the model: an Ollama server or in-process HuggingFace."""
    OLLAMA = "ollama"
    HUGGINGFACE = "huggingface"
class ModelInfo(BaseModel):
    """Static description plus runtime bookkeeping for one managed model."""
    name: str  # registry key used throughout the swapper API
    ollama_name: str  # For Ollama models (empty string for HuggingFace entries)
    hf_name: Optional[str] = None  # For HuggingFace models (hub repo id)
    backend: ModelBackend = ModelBackend.OLLAMA
    type: str  # llm, code, vision, math, ocr
    size_gb: float  # approximate on-disk/VRAM footprint from config
    priority: str  # high, medium, low
    status: ModelStatus
    capabilities: List[str] = []  # For OCR models
    loaded_at: Optional[datetime] = None  # last successful load time
    unloaded_at: Optional[datetime] = None  # last unload time
    total_uptime_seconds: float = 0.0  # cumulative loaded time across swaps
    request_count: int = 0  # inference requests served by this model
class SwapperStatus(BaseModel):
    """Snapshot of the swapper service returned by /status and /health."""
    status: str  # overall service health string
    active_model: Optional[str] = None  # currently active LLM model, if any
    available_models: List[str]  # every registered model key
    loaded_models: List[str]  # subset currently in LOADED state
    mode: str  # single-active or multi-active
    total_models: int  # len(available_models)
class ModelMetrics(BaseModel):
    """Per-model usage metrics (uptime and request counts)."""
    model_name: str
    status: str  # ModelStatus value as a plain string
    loaded_at: Optional[datetime] = None
    uptime_hours: float  # total_uptime_seconds expressed in hours
    request_count: int
    total_uptime_seconds: float  # accumulated + current-session uptime
# ========== Swapper Service ==========
class SwapperService:
    """Swapper Service - manages model loading/unloading for Ollama and HuggingFace.

    In single-active mode at most one LLM (Ollama) model is kept warm at a
    time. OCR and image-generation models (HuggingFace) are lazy-loaded on
    first request and tracked separately (active_ocr_model /
    active_image_model) so they do not compete with the LLM slot.
    """
    def __init__(self):
        # Registry of all known models, keyed by the config model key.
        self.models: Dict[str, ModelInfo] = {}
        self.active_model: Optional[str] = None  # Active LLM model
        self.active_ocr_model: Optional[str] = None  # Active OCR model (separate from LLM)
        self.active_image_model: Optional[str] = None  # Active Image Generation model
        # Serializes load/unload so concurrent requests cannot race a swap.
        self.loading_lock = asyncio.Lock()
        self.http_client = httpx.AsyncClient(timeout=300.0)
        self.model_uptime: Dict[str, float] = {}  # Track cumulative uptime per model (seconds)
        self.model_load_times: Dict[str, datetime] = {}  # Track when model was loaded
        # HuggingFace model instances (lazy loaded)
        self.hf_models: Dict[str, Any] = {}  # model_name -> model instance
        self.hf_processors: Dict[str, Any] = {}  # model_name -> processor/tokenizer
        # Device configuration: prefer CUDA when torch is importable and a GPU exists.
        self.device = "cuda" if TORCH_AVAILABLE and torch.cuda.is_available() else "cpu"
        logger.info(f"🔧 Swapper initialized with device: {self.device}")

    async def initialize(self):
        """Initialize Swapper Service - load configuration.

        Reads the YAML config (or falls back to discovering models from the
        Ollama server), registers every model as UNLOADED, and optionally
        warms the configured default model. OCR models stay lazy unless
        ``lazy_load_ocr`` is disabled in the config. Never raises: all
        errors are logged and swallowed so startup continues.
        """
        config = None
        try:
            logger.info(f"🔧 Initializing Swapper Service...")
            logger.info(f"🔧 Config path: {SWAPPER_CONFIG_PATH}")
            logger.info(f"🔧 Config exists: {os.path.exists(SWAPPER_CONFIG_PATH)}")
            if os.path.exists(SWAPPER_CONFIG_PATH):
                with open(SWAPPER_CONFIG_PATH, 'r') as f:
                    config = yaml.safe_load(f)
                models_config = config.get('models', {})
                logger.info(f"🔧 Found {len(models_config)} models in config")
                for model_key, model_config in models_config.items():
                    path = model_config.get('path', '')
                    model_type = model_config.get('type', 'llm')
                    capabilities = model_config.get('capabilities', [])
                    # Determine backend from path prefix ("huggingface:" vs "ollama:").
                    if path.startswith('huggingface:'):
                        hf_name = path.replace('huggingface:', '')
                        logger.info(f"🔧 Adding HuggingFace model: {model_key} -> {hf_name}")
                        self.models[model_key] = ModelInfo(
                            name=model_key,
                            ollama_name="",
                            hf_name=hf_name,
                            backend=ModelBackend.HUGGINGFACE,
                            type=model_type,
                            size_gb=model_config.get('size_gb', 0),
                            priority=model_config.get('priority', 'medium'),
                            capabilities=capabilities,
                            status=ModelStatus.UNLOADED
                        )
                    else:
                        ollama_name = path.replace('ollama:', '')
                        logger.info(f"🔧 Adding Ollama model: {model_key} -> {ollama_name}")
                        self.models[model_key] = ModelInfo(
                            name=model_key,
                            ollama_name=ollama_name,
                            backend=ModelBackend.OLLAMA,
                            type=model_type,
                            size_gb=model_config.get('size_gb', 0),
                            priority=model_config.get('priority', 'medium'),
                            capabilities=capabilities,
                            status=ModelStatus.UNLOADED
                        )
                    self.model_uptime[model_key] = 0.0
                logger.info(f"✅ Loaded {len(self.models)} models into Swapper")
                # Count by backend
                ollama_count = sum(1 for m in self.models.values() if m.backend == ModelBackend.OLLAMA)
                hf_count = sum(1 for m in self.models.values() if m.backend == ModelBackend.HUGGINGFACE)
                logger.info(f"✅ Models: {ollama_count} Ollama, {hf_count} HuggingFace")
            else:
                logger.warning(f"⚠️ Config file not found: {SWAPPER_CONFIG_PATH}, using defaults")
                await self._load_models_from_ollama()
            logger.info(f"✅ Swapper Service initialized with {len(self.models)} models")
            logger.info(f"✅ Model names: {list(self.models.keys())}")
            # Load default LLM model (not OCR - those are lazy loaded)
            if config:
                swapper_config = config.get('swapper', {})
                default_model = swapper_config.get('default_model')
                lazy_load_ocr = swapper_config.get('lazy_load_ocr', True)
                if default_model and default_model in self.models:
                    model_info = self.models[default_model]
                    # Only auto-load non-OCR models
                    if model_info.type != 'ocr' or not lazy_load_ocr:
                        logger.info(f"🔄 Loading default model: {default_model}")
                        success = await self.load_model(default_model)
                        if success:
                            logger.info(f"✅ Default model loaded: {default_model}")
                        else:
                            logger.warning(f"⚠️ Failed to load default model: {default_model}")
                    else:
                        logger.info(f"⏳ OCR model '{default_model}' will be lazy loaded on first request")
                elif default_model:
                    logger.warning(f"⚠️ Default model '{default_model}' not found in models list")
        except Exception as e:
            logger.error(f"❌ Error initializing Swapper Service: {e}", exc_info=True)
            import traceback
            logger.error(f"❌ Traceback: {traceback.format_exc()}")

    async def _load_models_from_ollama(self):
        """Discover models from the Ollama server's /api/tags and register them.

        Fallback path when no config file exists. Every discovered model is
        registered as an UNLOADED 'llm' with medium priority. Errors are
        logged and swallowed.
        """
        try:
            response = await self.http_client.get(f"{OLLAMA_BASE_URL}/api/tags")
            if response.status_code == 200:
                data = response.json()
                for model in data.get('models', []):
                    model_name = model.get('name', '')
                    # Extract base name (remove :latest, :7b, etc.)
                    base_name = model_name.split(':')[0]
                    if base_name not in self.models:
                        size_gb = model.get('size', 0) / (1024**3)  # Convert bytes to GB
                        self.models[base_name] = ModelInfo(
                            name=base_name,
                            ollama_name=model_name,
                            type='llm',  # Default type
                            size_gb=size_gb,
                            priority='medium',
                            status=ModelStatus.UNLOADED
                        )
                        self.model_uptime[base_name] = 0.0
            logger.info(f"✅ Loaded {len(self.models)} models from Ollama")
        except Exception as e:
            logger.error(f"❌ Error loading models from Ollama: {e}")

    async def load_model(self, model_name: str) -> bool:
        """Load a model (unload current if in single-active mode).

        Warms the model by sending a tiny generate request to Ollama; a 200
        response marks the model LOADED and makes it the active model.
        Returns True on success, False otherwise (never raises).

        NOTE(review): this path always posts to Ollama's /api/generate, even
        for HuggingFace-backend entries (whose ollama_name is "") — HF models
        appear to be loaded lazily via _load_hf_model instead. Confirm
        callers only pass Ollama-backend models here.
        """
        async with self.loading_lock:
            try:
                # Check if model exists
                if model_name not in self.models:
                    logger.error(f"❌ Model not found: {model_name}")
                    return False
                model_info = self.models[model_name]
                # If single-active mode and another model is loaded, unload it first
                if SWAPPER_MODE == "single-active" and self.active_model and self.active_model != model_name:
                    await self._unload_model_internal(self.active_model)
                # Load the model
                logger.info(f"🔄 Loading model: {model_name}")
                model_info.status = ModelStatus.LOADING
                # Warm-up request: forces Ollama to page the model in.
                response = await self.http_client.post(
                    f"{OLLAMA_BASE_URL}/api/generate",
                    json={
                        "model": model_info.ollama_name,
                        "prompt": "test",
                        "stream": False
                    },
                    timeout=MODEL_SWAP_TIMEOUT
                )
                if response.status_code == 200:
                    model_info.status = ModelStatus.LOADED
                    model_info.loaded_at = datetime.now()
                    model_info.unloaded_at = None
                    self.active_model = model_name
                    self.model_load_times[model_name] = datetime.now()
                    logger.info(f"✅ Model loaded: {model_name}")
                    return True
                else:
                    model_info.status = ModelStatus.ERROR
                    logger.error(f"❌ Failed to load model: {model_name}")
                    return False
            except Exception as e:
                logger.error(f"❌ Error loading model {model_name}: {e}", exc_info=True)
                if model_name in self.models:
                    self.models[model_name].status = ModelStatus.ERROR
                return False

    async def _unload_model_internal(self, model_name: str) -> bool:
        """Internal method to unload a model.

        Folds the current session's uptime into the per-model accumulator
        and clears the active-model slot. Caller must already hold
        loading_lock. Returns True unless an exception occurs.
        """
        try:
            if model_name not in self.models:
                return False
            model_info = self.models[model_name]
            if model_info.status == ModelStatus.LOADED:
                logger.info(f"🔄 Unloading model: {model_name}")
                model_info.status = ModelStatus.UNLOADING
                # Calculate uptime for this load session and accumulate it.
                if model_name in self.model_load_times:
                    load_time = self.model_load_times[model_name]
                    uptime_seconds = (datetime.now() - load_time).total_seconds()
                    self.model_uptime[model_name] = self.model_uptime.get(model_name, 0.0) + uptime_seconds
                    model_info.total_uptime_seconds = self.model_uptime[model_name]
                    del self.model_load_times[model_name]
                model_info.status = ModelStatus.UNLOADED
                model_info.unloaded_at = datetime.now()
                if self.active_model == model_name:
                    self.active_model = None
                logger.info(f"✅ Model unloaded: {model_name}")
            return True
        except Exception as e:
            logger.error(f"❌ Error unloading model {model_name}: {e}")
            return False

    async def unload_model(self, model_name: str) -> bool:
        """Unload a model (public wrapper that takes the loading lock)."""
        async with self.loading_lock:
            return await self._unload_model_internal(model_name)

    async def get_status(self) -> SwapperStatus:
        """Get Swapper service status.

        NOTE(review): this also folds the active model's elapsed uptime into
        model_uptime and resets its load-time baseline, so each call only
        adds newly elapsed time — a status query therefore mutates the
        uptime bookkeeping.
        """
        # Update uptime for currently loaded model
        if self.active_model and self.active_model in self.model_load_times:
            load_time = self.model_load_times[self.active_model]
            current_uptime = (datetime.now() - load_time).total_seconds()
            self.model_uptime[self.active_model] = self.model_uptime.get(self.active_model, 0.0) + current_uptime
            self.model_load_times[self.active_model] = datetime.now()  # Reset timer
        loaded_models = [
            name for name, model in self.models.items()
            if model.status == ModelStatus.LOADED
        ]
        return SwapperStatus(
            status="healthy",
            active_model=self.active_model,
            available_models=list(self.models.keys()),
            loaded_models=loaded_models,
            mode=SWAPPER_MODE,
            total_models=len(self.models)
        )

    async def get_model_metrics(self, model_name: Optional[str] = None) -> List[ModelMetrics]:
        """Get metrics for model(s).

        Args:
            model_name: Restrict to one model; None means all registered models.
                Unknown names are silently skipped.

        Returns:
            One ModelMetrics per model, with uptime = accumulated time plus
            the still-running current session (read-only; unlike get_status
            this does not reset timers).
        """
        metrics = []
        models_to_check = [model_name] if model_name else list(self.models.keys())
        for name in models_to_check:
            if name not in self.models:
                continue
            model_info = self.models[name]
            # Calculate current uptime: accumulated total plus live session.
            uptime_seconds = self.model_uptime.get(name, 0.0)
            if name in self.model_load_times:
                load_time = self.model_load_times[name]
                current_uptime = (datetime.now() - load_time).total_seconds()
                uptime_seconds += current_uptime
            uptime_hours = uptime_seconds / 3600.0
            metrics.append(ModelMetrics(
                model_name=name,
                status=model_info.status.value,
                loaded_at=model_info.loaded_at,
                uptime_hours=uptime_hours,
                request_count=model_info.request_count,
                total_uptime_seconds=uptime_seconds
            ))
        return metrics

    async def generate(self, model_name: str, prompt: str, system_prompt: Optional[str] = None,
                       max_tokens: int = 2048, temperature: float = 0.7, stream: bool = False) -> Dict[str, Any]:
        """Generate text using a model.

        Loads the model on demand, then proxies the request to Ollama's
        /api/generate.

        Raises:
            ValueError: unknown model, or the on-demand load failed.
            HTTPException: non-200 response from Ollama.
        """
        try:
            # Ensure model is loaded
            if model_name not in self.models:
                raise ValueError(f"Model not found: {model_name}")
            model_info = self.models[model_name]
            # Load model if not loaded
            if model_info.status != ModelStatus.LOADED:
                logger.info(f"🔄 Model {model_name} not loaded, loading now...")
                success = await self.load_model(model_name)
                if not success:
                    raise ValueError(f"Failed to load model: {model_name}")
            # Increment request count
            model_info.request_count += 1
            # Prepare request to Ollama
            request_data = {
                "model": model_info.ollama_name,
                "prompt": prompt,
                "stream": stream,
                "options": {
                    "num_predict": max_tokens,
                    "temperature": temperature
                }
            }
            if system_prompt:
                request_data["system"] = system_prompt
            # Call Ollama
            response = await self.http_client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=request_data,
                timeout=300.0
            )
            if response.status_code == 200:
                data = response.json()
                return {
                    "response": data.get("response", ""),
                    "model": model_name,
                    "done": data.get("done", True),
                    "eval_count": data.get("eval_count", 0),
                    "prompt_eval_count": data.get("prompt_eval_count", 0)
                }
            else:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"Ollama error: {response.text}"
                )
        except Exception as e:
            logger.error(f"❌ Error generating with model {model_name}: {e}", exc_info=True)
            raise

    async def close(self):
        """Close HTTP client and unload HuggingFace models."""
        await self.http_client.aclose()
        # Unload HuggingFace models to free GPU memory
        for model_name in list(self.hf_models.keys()):
            await self._unload_hf_model(model_name)

    async def _load_hf_model(self, model_name: str) -> bool:
        """Load a HuggingFace model (OCR/Document Understanding).

        Picks a loading strategy from the hub repo name: GOT-OCR2.0, Donut,
        TrOCR, FLUX/diffusion, or a generic AutoModel fallback. Stores the
        model in self.hf_models and its processor/tokenizer in
        self.hf_processors. Returns True on success (never raises).
        """
        if not TORCH_AVAILABLE:
            logger.error("❌ PyTorch not available, cannot load HuggingFace models")
            return False
        try:
            model_info = self.models[model_name]
            if model_info.backend != ModelBackend.HUGGINGFACE:
                logger.error(f"❌ Model {model_name} is not a HuggingFace model")
                return False
            hf_name = model_info.hf_name
            logger.info(f"🔄 Loading HuggingFace model: {hf_name} on {self.device}")
            from transformers import AutoModel, AutoTokenizer, AutoProcessor
            # Different loading strategies based on model type
            if "GOT-OCR" in hf_name or "got-ocr" in hf_name.lower():
                # GOT-OCR2.0 specific loading (custom code from the hub repo).
                tokenizer = AutoTokenizer.from_pretrained(hf_name, trust_remote_code=True)
                model = AutoModel.from_pretrained(
                    hf_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    device_map="auto" if self.device == "cuda" else None,
                    low_cpu_mem_usage=True
                )
                self.hf_processors[model_name] = tokenizer
                self.hf_models[model_name] = model
            elif "donut" in hf_name.lower():
                # Donut model loading
                from transformers import DonutProcessor, VisionEncoderDecoderModel
                processor = DonutProcessor.from_pretrained(hf_name)
                model = VisionEncoderDecoderModel.from_pretrained(
                    hf_name,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )
                if self.device == "cuda":
                    model = model.cuda()
                model.eval()
                self.hf_processors[model_name] = processor
                self.hf_models[model_name] = model
            elif "trocr" in hf_name.lower():
                # TrOCR model loading
                from transformers import TrOCRProcessor, VisionEncoderDecoderModel
                processor = TrOCRProcessor.from_pretrained(hf_name)
                model = VisionEncoderDecoderModel.from_pretrained(
                    hf_name,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )
                if self.device == "cuda":
                    model = model.cuda()
                model.eval()
                self.hf_processors[model_name] = processor
                self.hf_models[model_name] = model
            elif "flux" in hf_name.lower() or model_info.type == "image_generation":
                # FLUX / Diffusion model loading
                logger.info(f"🎨 Loading diffusion model: {hf_name}")
                from diffusers import AutoPipelineForText2Image
                pipeline = AutoPipelineForText2Image.from_pretrained(
                    hf_name,
                    torch_dtype=torch.bfloat16,
                    use_safetensors=True
                )
                pipeline.to(self.device)
                pipeline.enable_model_cpu_offload()  # Optimize VRAM usage
                self.hf_models[model_name] = pipeline
                self.hf_processors[model_name] = None  # No separate processor for diffusion
                logger.info(f"✅ Diffusion model loaded: {model_name} with CPU offload enabled")
            else:
                # Generic loading
                processor = AutoProcessor.from_pretrained(hf_name, trust_remote_code=True)
                model = AutoModel.from_pretrained(
                    hf_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )
                if self.device == "cuda":
                    model = model.cuda()
                self.hf_processors[model_name] = processor
                self.hf_models[model_name] = model
            logger.info(f"✅ HuggingFace model loaded: {model_name}")
            return True
        except Exception as e:
            logger.error(f"❌ Failed to load HuggingFace model {model_name}: {e}", exc_info=True)
            return False

    async def _unload_hf_model(self, model_name: str) -> bool:
        """Unload a HuggingFace model to free memory.

        Drops our references, then runs gc and empties the CUDA cache so
        the VRAM is actually released. Returns True unless an exception occurs.
        """
        try:
            if model_name in self.hf_models:
                del self.hf_models[model_name]
            if model_name in self.hf_processors:
                del self.hf_processors[model_name]
            # Force garbage collection
            if TORCH_AVAILABLE:
                import gc
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            logger.info(f"✅ HuggingFace model unloaded: {model_name}")
            return True
        except Exception as e:
            logger.error(f"❌ Failed to unload HuggingFace model {model_name}: {e}")
            return False

    async def ocr_process(self, model_name: str, image_data: bytes, ocr_type: str = "ocr") -> Dict[str, Any]:
        """Process image with OCR model.

        Lazy-loads the requested HF model (swapping out a different active
        OCR model first to save VRAM), then runs model-family-specific
        inference. Returns {"success", "model", "text", "device"}.

        Raises:
            HTTPException: 503 if PyTorch is unavailable.
            ValueError: unknown/non-HF model, or lazy load failed.
        """
        if not TORCH_AVAILABLE:
            raise HTTPException(status_code=503, detail="PyTorch not available")
        try:
            # Ensure model is loaded
            if model_name not in self.models:
                raise ValueError(f"OCR model not found: {model_name}")
            model_info = self.models[model_name]
            if model_info.backend != ModelBackend.HUGGINGFACE:
                raise ValueError(f"Model {model_name} is not an OCR model")
            # Lazy load model if not loaded
            if model_name not in self.hf_models:
                # Unload current OCR model if different (to save VRAM)
                if self.active_ocr_model and self.active_ocr_model != model_name:
                    await self._unload_hf_model(self.active_ocr_model)
                    if self.active_ocr_model in self.models:
                        self.models[self.active_ocr_model].status = ModelStatus.UNLOADED
                logger.info(f"🔄 Lazy loading OCR model: {model_name}")
                model_info.status = ModelStatus.LOADING
                success = await self._load_hf_model(model_name)
                if not success:
                    model_info.status = ModelStatus.ERROR
                    raise ValueError(f"Failed to load OCR model: {model_name}")
                model_info.status = ModelStatus.LOADED
                model_info.loaded_at = datetime.now()
                self.active_ocr_model = model_name
                self.model_load_times[model_name] = datetime.now()
            # Process image
            image = Image.open(BytesIO(image_data)).convert("RGB")
            model = self.hf_models[model_name]
            processor = self.hf_processors[model_name]
            model_info.request_count += 1
            hf_name = model_info.hf_name
            # Different processing based on model type
            if "GOT-OCR" in hf_name or "got-ocr" in hf_name.lower():
                # GOT-OCR2.0 processing (custom .chat API from the hub repo)
                with torch.no_grad():
                    result = model.chat(processor, image, ocr_type=ocr_type)
                text = result if isinstance(result, str) else str(result)
            elif "donut" in hf_name.lower():
                # Donut processing
                task_prompt = "<s_cord-v2>"  # or "<s_docvqa>" depending on task
                decoder_input_ids = processor.tokenizer(
                    task_prompt, add_special_tokens=False, return_tensors="pt"
                ).input_ids
                pixel_values = processor(image, return_tensors="pt").pixel_values
                if self.device == "cuda":
                    pixel_values = pixel_values.cuda()
                    decoder_input_ids = decoder_input_ids.cuda()
                with torch.no_grad():
                    outputs = model.generate(
                        pixel_values,
                        decoder_input_ids=decoder_input_ids,
                        max_length=model.decoder.config.max_position_embeddings,
                        pad_token_id=processor.tokenizer.pad_token_id,
                        eos_token_id=processor.tokenizer.eos_token_id,
                        use_cache=True,
                        bad_words_ids=[[processor.tokenizer.unk_token_id]],
                        return_dict_in_generate=True
                    )
                sequence = processor.batch_decode(outputs.sequences)[0]
                # NOTE(review): for Donut, `text` is the token2json dict, not a
                # plain string — confirm downstream consumers handle both shapes.
                text = processor.token2json(sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, ""))
            elif "trocr" in hf_name.lower():
                # TrOCR processing
                pixel_values = processor(images=image, return_tensors="pt").pixel_values
                if self.device == "cuda":
                    pixel_values = pixel_values.cuda()
                with torch.no_grad():
                    generated_ids = model.generate(pixel_values, max_length=512)
                text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            else:
                # Generic processing (try chat method)
                with torch.no_grad():
                    if hasattr(model, 'chat'):
                        result = model.chat(processor, image)
                        text = result if isinstance(result, str) else str(result)
                    else:
                        text = "Model does not support direct inference"
            return {
                "success": True,
                "model": model_name,
                "text": text,
                "device": self.device
            }
        except Exception as e:
            logger.error(f"❌ OCR processing failed: {e}", exc_info=True)
            raise

    async def image_generate(
        self,
        model_name: str,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 4.0,
        width: int = 1024,
        height: int = 1024
    ) -> Dict[str, Any]:
        """Generate image with diffusion model (lazy loaded).

        Returns a dict with the PNG image base64-encoded plus timing info.

        Raises:
            HTTPException: 503 if PyTorch is unavailable.
            ValueError: unknown/non-image model, or lazy load failed.
        """
        if not TORCH_AVAILABLE:
            raise HTTPException(status_code=503, detail="PyTorch not available")
        import time
        start_time = time.time()
        try:
            # Ensure model exists
            if model_name not in self.models:
                raise ValueError(f"Image model not found: {model_name}")
            model_info = self.models[model_name]
            if model_info.type != "image_generation":
                raise ValueError(f"Model {model_name} is not an image generation model")
            # Lazy load model if not loaded
            if model_name not in self.hf_models:
                # Unload current image model if different (to save VRAM)
                if self.active_image_model and self.active_image_model != model_name:
                    logger.info(f"🔄 Unloading current image model: {self.active_image_model}")
                    await self._unload_hf_model(self.active_image_model)
                    if self.active_image_model in self.models:
                        self.models[self.active_image_model].status = ModelStatus.UNLOADED
                logger.info(f"🎨 Lazy loading image model: {model_name}")
                model_info.status = ModelStatus.LOADING
                success = await self._load_hf_model(model_name)
                if not success:
                    model_info.status = ModelStatus.ERROR
                    raise ValueError(f"Failed to load image model: {model_name}")
                model_info.status = ModelStatus.LOADED
                model_info.loaded_at = datetime.now()
                self.active_image_model = model_name
                self.model_load_times[model_name] = datetime.now()
            # Generate image
            pipeline = self.hf_models[model_name]
            model_info.request_count += 1
            logger.info(f"🎨 Generating image with {model_name}: {prompt[:50]}...")
            with torch.no_grad():
                # FLUX Klein doesn't support negative_prompt, check pipeline type
                pipeline_kwargs = {
                    "prompt": prompt,
                    "num_inference_steps": num_inference_steps,
                    "guidance_scale": guidance_scale,
                    "width": width,
                    "height": height,
                }
                # Only add negative_prompt for models that support it (not FLUX)
                is_flux = "flux" in model_name.lower()
                if negative_prompt and not is_flux:
                    pipeline_kwargs["negative_prompt"] = negative_prompt
                result = pipeline(**pipeline_kwargs)
                image = result.images[0]
            # Convert to base64
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
            generation_time_ms = (time.time() - start_time) * 1000
            logger.info(f"✅ Image generated in {generation_time_ms:.0f}ms")
            return {
                "success": True,
                "model": model_name,
                "image_base64": img_base64,
                "width": width,
                "height": height,
                "generation_time_ms": generation_time_ms,
                "device": self.device
            }
        except Exception as e:
            logger.error(f"❌ Image generation failed: {e}", exc_info=True)
            raise
# ========== FastAPI App ==========
app = FastAPI(
    title="Swapper Service",
    description="Dynamic model loading service for Node #2",
    version="1.0.0"
)
# Wide-open CORS; presumably this service sits behind an internal
# gateway — NOTE(review): confirm it is not exposed publicly with "*".
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Include cabinet API router (import after swapper is created);
# optional, so a missing module only logs a warning.
try:
    from app.cabinet_api import router as cabinet_router
    app.include_router(cabinet_router)
    logger.info("✅ Cabinet API router included")
except ImportError:
    logger.warning("⚠️ cabinet_api module not found, skipping cabinet router")
# Global Swapper instance shared by all endpoints.
swapper = SwapperService()
# NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of
# lifespan handlers — consider migrating on the next framework upgrade.
@app.on_event("startup")
async def startup():
    """Initialize the global Swapper instance when the app starts."""
    await swapper.initialize()
@app.on_event("shutdown")
async def shutdown():
    """Release Swapper resources on app shutdown."""
    await swapper.close()
# ========== API Endpoints ==========

@app.get("/health")
async def health():
    """Liveness probe: service name plus the currently active model and mode."""
    current = await swapper.get_status()
    payload = {
        "status": "healthy",
        "service": "swapper-service",
        "active_model": current.active_model,
        "mode": current.mode,
    }
    return payload
@app.get("/status", response_model=SwapperStatus)
async def get_status():
    """Return the full SwapperStatus model for this service."""
    return await swapper.get_status()
@app.get("/models")
async def list_models():
    """List every registered model with its static config and current status."""
    catalog = []
    for entry in swapper.models.values():
        catalog.append({
            "name": entry.name,
            "ollama_name": entry.ollama_name,
            "type": entry.type,
            "size_gb": entry.size_gb,
            "priority": entry.priority,
            "status": entry.status.value,
        })
    return {"models": catalog}
@app.get("/models/{model_name}")
async def get_model_info(model_name: str):
    """Return config, status, load timestamps and uptime for one model.

    Raises:
        HTTPException(404): when the model is not registered.
    """
    info = swapper.models.get(model_name)
    if info is None:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")

    def _iso(ts):
        # Timestamps are unset before the first load/unload cycle.
        return ts.isoformat() if ts else None

    return {
        "name": info.name,
        "ollama_name": info.ollama_name,
        "type": info.type,
        "size_gb": info.size_gb,
        "priority": info.priority,
        "status": info.status.value,
        "loaded_at": _iso(info.loaded_at),
        "unloaded_at": _iso(info.unloaded_at),
        "total_uptime_seconds": swapper.model_uptime.get(model_name, 0.0),
    }
@app.post("/models/{model_name}/load")
async def load_model_endpoint(model_name: str):
    """Load a model; respond 500 when loading fails."""
    if not await swapper.load_model(model_name):
        raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}")
    return {"status": "success", "model": model_name, "message": f"Model {model_name} loaded"}
@app.post("/models/{model_name}/unload")
async def unload_model_endpoint(model_name: str):
    """Unload a model; respond 500 when unloading fails."""
    if not await swapper.unload_model(model_name):
        raise HTTPException(status_code=500, detail=f"Failed to unload model: {model_name}")
    return {"status": "success", "model": model_name, "message": f"Model {model_name} unloaded"}
@app.get("/metrics")
async def get_metrics(model_name: Optional[str] = None):
    """Metrics for a single model (when model_name is given) or all models."""
    collected = await swapper.get_model_metrics(model_name)
    rows = []
    for metric in collected:
        rows.append(metric.dict())
    return {"metrics": rows}
@app.get("/metrics/{model_name}")
async def get_model_metrics(model_name: str):
    """Metrics for a specific model; 404 when the model is unknown."""
    found = await swapper.get_model_metrics(model_name)
    if not found:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")
    first = found[0]
    return first.dict()
# ========== Chat Completions API (OpenAI-compatible) ==========

class ChatMessage(BaseModel):
    """A single chat turn in OpenAI message format."""
    role: str  # "system", "user" or "assistant"
    content: str  # plain-text message body
class ChatCompletionRequest(BaseModel):
    """Chat completion request (OpenAI-compatible subset)."""
    model: str  # swapper model name to generate with
    messages: List[ChatMessage]  # conversation history, oldest first
    max_tokens: Optional[int] = 2048
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False  # forwarded to swapper.generate
class ChatCompletionResponse(BaseModel):
    """Chat completion response (OpenAI-compatible shape)."""
    id: str  # e.g. "chatcmpl-<unix-ts>"
    object: str = "chat.completion"
    created: int  # unix timestamp
    model: str
    choices: List[Dict[str, Any]]  # [{index, message, finish_reason}]
    usage: Dict[str, int]  # prompt/completion/total token counts
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint.

    Flattens the message list into a single prompt for swapper.generate():
    the last "system" message becomes the system prompt; user and assistant
    turns are joined with newlines. Role labels are added only when the
    conversation actually contains assistant turns, so single-turn requests
    produce exactly the raw user text (backward compatible).

    Raises:
        HTTPException(500): when generation fails for any reason.
    """
    import time
    try:
        system_prompt = None
        parts = []
        # Fix: assistant messages were previously dropped entirely, losing
        # all conversation history on multi-turn requests.
        has_history = any(m.role == "assistant" for m in request.messages)
        for msg in request.messages:
            if msg.role == "system":
                # Last system message wins, as before.
                system_prompt = msg.content
            elif msg.role == "user":
                parts.append(f"User: {msg.content}" if has_history else msg.content)
            elif msg.role == "assistant":
                parts.append(f"Assistant: {msg.content}")
        prompt = "\n".join(parts)

        result = await swapper.generate(
            model_name=request.model,
            prompt=prompt,
            system_prompt=system_prompt,
            max_tokens=request.max_tokens or 2048,
            temperature=request.temperature or 0.7,
            stream=request.stream or False
        )

        # Shape the Ollama-style result into an OpenAI-style response.
        prompt_tokens = result.get("prompt_eval_count", 0)
        completion_tokens = result.get("eval_count", 0)
        return ChatCompletionResponse(
            id=f"chatcmpl-{int(time.time())}",
            created=int(time.time()),
            model=request.model,
            choices=[{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": result["response"]
                },
                "finish_reason": "stop" if result.get("done", True) else None
            }],
            usage={
                "prompt_tokens": prompt_tokens,
                "completion_tokens": completion_tokens,
                "total_tokens": prompt_tokens + completion_tokens
            }
        )
    except Exception as e:
        logger.error(f"❌ Error in chat completions: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
# ========== Generate API (Ollama-compatible) ==========

class GenerateRequest(BaseModel):
    """Generate request (Ollama-compatible)."""
    model: str  # swapper model name
    prompt: str
    system: Optional[str] = None  # optional system prompt
    max_tokens: Optional[int] = 2048
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False
@app.post("/generate")
async def generate(request: GenerateRequest):
    """Ollama-compatible generate endpoint: thin proxy to swapper.generate."""
    try:
        return await swapper.generate(
            model_name=request.model,
            prompt=request.prompt,
            system_prompt=request.system,
            max_tokens=request.max_tokens or 2048,
            temperature=request.temperature or 0.7,
            stream=request.stream or False,
        )
    except Exception as e:
        logger.error(f"❌ Error in generate: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
# ========== VISION API Endpoints ==========

class VisionRequest(BaseModel):
    """Vision (image description) request."""
    model: str = "qwen3-vl-8b"  # vision-language model name
    prompt: str = "Опиши це зображення коротко (2-3 речення)."  # default prompt (Ukrainian): "describe this image briefly"
    images: List[str]  # base64 encoded images; may include a data:... prefix or raw base64
    system: Optional[str] = None  # optional system prompt forwarded to Ollama
    max_tokens: int = 1024
    temperature: float = 0.7
@app.post("/vision")
async def vision_endpoint(request: VisionRequest):
    """
    Vision endpoint - analyze images with Vision-Language models.

    Models:
    - qwen3-vl-8b: Qwen3 Vision-Language model (8GB VRAM)

    Images should be base64 encoded. Can include data:image/... prefix or raw base64.

    Raises:
        HTTPException(500): when Ollama returns a non-200 status or the
        request fails unexpectedly.
    """
    try:
        import time
        start_time = time.time()
        model_name = request.model

        # Convert data URLs to raw base64 (Ollama expects base64 without prefix)
        processed_images = []
        for img in request.images:
            if img.startswith("data:"):
                # Extract base64 part from data URL
                base64_part = img.split(",", 1)[1] if "," in img else img
                processed_images.append(base64_part)
            else:
                processed_images.append(img)

        logger.info(f"🖼️ Vision request: model={model_name}, images={len(processed_images)}, prompt={request.prompt[:50]}...")

        # Map the public model name to an Ollama tag.
        # Fix: the old code did model_name.replace("-", ":"), which replaced
        # EVERY hyphen (e.g. "qwen3-vl-8b" -> "qwen3:vl:8b") — an invalid tag
        # that contradicted its own comment. Only the last hyphen separates
        # the model name from the size tag.
        known_models = {"qwen3-vl-8b": "qwen3-vl:8b"}
        if model_name in known_models:
            ollama_model = known_models[model_name]
        else:
            head, sep, tail = model_name.rpartition("-")
            ollama_model = f"{head}:{tail}" if sep else model_name

        # Build Ollama request
        ollama_payload = {
            "model": ollama_model,
            "prompt": request.prompt,
            "images": processed_images,
            "stream": False,
            "options": {
                "num_predict": request.max_tokens,
                "temperature": request.temperature
            }
        }
        if request.system:
            ollama_payload["system"] = request.system

        # Send to Ollama
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=ollama_payload
            )
            if response.status_code != 200:
                logger.error(f"❌ Ollama vision error: {response.status_code} - {response.text[:200]}")
                raise HTTPException(status_code=500, detail=f"Ollama error: {response.status_code}")
            result = response.json()
            vision_text = result.get("response", "")
            # Debug logging
            if not vision_text:
                logger.warning(f"⚠️ Empty response from Ollama! Result keys: {list(result.keys())}, error: {result.get('error', 'none')}")

        processing_time_ms = (time.time() - start_time) * 1000
        logger.info(f"✅ Vision response: {len(vision_text)} chars in {processing_time_ms:.0f}ms")
        return {
            "success": True,
            "model": model_name,
            "text": vision_text,
            "processing_time_ms": processing_time_ms,
            "images_count": len(processed_images)
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Vision endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/vision/models")
async def vision_models():
    """List registered vision-capable models."""
    entries = []
    for m in swapper.models.values():
        if m.type != "vision":
            continue
        entries.append({
            "name": m.name,
            "type": m.type,
            "status": m.status.value,
            "size_gb": m.size_gb,
        })
    return {"models": entries}
# ========== OCR API Endpoints ==========

class OCRRequest(BaseModel):
    """OCR request (JSON-body alternative to multipart upload)."""
    model: str = "got-ocr2"  # Default to GOT-OCR2.0
    image_base64: Optional[str] = None  # raw base64-encoded image bytes
    image_url: Optional[str] = None  # URL to fetch the image from instead
    ocr_type: str = "ocr"  # ocr, format, table (GOT-OCR2.0 modes)
@app.post("/ocr")
async def ocr_endpoint(
    request: OCRRequest = None,
    file: UploadFile = File(None),
    model: str = Form("got-ocr2"),
    ocr_type: str = Form("ocr")
):
    """
    OCR endpoint - process images with OCR models.

    Models:
    - got-ocr2: Best for documents, tables, formulas (7GB VRAM)
    - donut-base: Document parsing without OCR (3GB VRAM)
    - donut-cord: Receipt/invoice parsing (3GB VRAM)
    - trocr-base: Fast printed text OCR (2GB VRAM)

    OCR Types (for GOT-OCR2.0):
    - ocr: Standard OCR
    - format: Preserve formatting
    - table: Extract tables

    Raises:
        HTTPException(400): no image supplied, download failed, or bad params.
        HTTPException(500): OCR processing failure.
    """
    try:
        image_data = None
        model_name = model
        ocr_type_param = ocr_type
        # Multipart upload takes precedence over the JSON body.
        if file:
            image_data = await file.read()
        elif request:
            model_name = request.model
            ocr_type_param = request.ocr_type
            if request.image_base64:
                image_data = base64.b64decode(request.image_base64)
            elif request.image_url:
                async with httpx.AsyncClient() as client:
                    response = await client.get(request.image_url)
                    if response.status_code == 200:
                        image_data = response.content
                    else:
                        raise HTTPException(status_code=400, detail="Failed to download image")
        if not image_data:
            raise HTTPException(status_code=400, detail="No image provided")
        # Process with OCR (model is lazy-loaded on demand)
        result = await swapper.ocr_process(model_name, image_data, ocr_type_param)
        return result
    except HTTPException:
        # Fix: without this re-raise (present in the sibling /vision, /tts and
        # /document endpoints) the intended 400s above were caught by the
        # generic handler below and converted into 500s.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"❌ OCR endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/ocr/models")
async def list_ocr_models():
    """List HuggingFace-backed OCR models with status and load state."""
    catalog = []
    for model in swapper.models.values():
        if model.backend != ModelBackend.HUGGINGFACE or model.type != "ocr":
            continue
        catalog.append({
            "name": model.name,
            "hf_name": model.hf_name,
            "type": model.type,
            "size_gb": model.size_gb,
            "priority": model.priority,
            "status": model.status.value,
            "capabilities": model.capabilities,
            "loaded": model.name in swapper.hf_models,
        })
    return {
        "ocr_models": catalog,
        "active_ocr_model": swapper.active_ocr_model,
        "device": swapper.device,
    }
@app.post("/ocr/models/{model_name}/load")
async def load_ocr_model(model_name: str):
    """Pre-load an OCR model (optional, models are lazy loaded by default)."""
    if model_name not in swapper.models:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")
    model_info = swapper.models[model_name]
    if model_info.backend != ModelBackend.HUGGINGFACE:
        raise HTTPException(status_code=400, detail="Not an OCR model")
    async with swapper.loading_lock:
        loaded = await swapper._load_hf_model(model_name)
        if not loaded:
            raise HTTPException(status_code=500, detail=f"Failed to load OCR model: {model_name}")
        # Record success: mark loaded and remember the load timestamp.
        model_info.status = ModelStatus.LOADED
        model_info.loaded_at = datetime.now()
        swapper.active_ocr_model = model_name
        swapper.model_load_times[model_name] = datetime.now()
        return {"status": "success", "model": model_name, "message": f"OCR model {model_name} loaded"}
@app.post("/ocr/models/{model_name}/unload")
async def unload_ocr_model(model_name: str):
    """Unload an OCR model to free GPU memory."""
    if model_name not in swapper.hf_models:
        raise HTTPException(status_code=400, detail=f"Model not loaded: {model_name}")
    async with swapper.loading_lock:
        unloaded = await swapper._unload_hf_model(model_name)
        if not unloaded:
            raise HTTPException(status_code=500, detail=f"Failed to unload OCR model: {model_name}")
        # Reflect the unload in the registry and clear the active pointer.
        if model_name in swapper.models:
            swapper.models[model_name].status = ModelStatus.UNLOADED
        if swapper.active_ocr_model == model_name:
            swapper.active_ocr_model = None
        return {"status": "success", "model": model_name, "message": f"OCR model {model_name} unloaded"}
# ========== STT (Speech-to-Text) API Endpoints ==========

class STTRequest(BaseModel):
    """STT request."""
    # NOTE(review): the /stt endpoint below only accepts multipart form data;
    # the base64/url fields of this model are not referenced there — confirm
    # whether a JSON-body path was intended.
    model: str = "faster-whisper-large"
    audio_base64: Optional[str] = None
    audio_url: Optional[str] = None
    language: Optional[str] = None  # auto-detect if not specified
    task: str = "transcribe"  # transcribe or translate
@app.post("/stt")
async def stt_endpoint(
    file: UploadFile = File(None),
    model: str = Form("faster-whisper-large"),
    language: Optional[str] = Form(None),
    task: str = Form("transcribe")
):
    """
    Speech-to-Text endpoint using Faster Whisper.

    Models:
    - faster-whisper-large: Best quality, 99 languages (3GB VRAM)
    - whisper-small: Fast transcription (0.5GB VRAM)

    Raises:
        HTTPException(400): when no audio file was uploaded.
        HTTPException(500): on transcription failure.
    """
    import tempfile
    import os
    try:
        audio_data = None
        if file:
            audio_data = await file.read()
        if not audio_data:
            raise HTTPException(status_code=400, detail="No audio provided")
        # Save audio to temp file (faster-whisper requires file path).
        # The suffix is always .ogg regardless of actual format — presumably
        # the decoder sniffs the container from content; confirm for edge cases.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".ogg") as tmp_file:
            tmp_file.write(audio_data)
            tmp_path = tmp_file.name
        try:
            # Lazy load faster-whisper model
            stt_model = await _get_or_load_stt_model(model)
            # Transcribe
            logger.info(f"🎤 STT: Transcribing audio with {model}...")
            segments, info = stt_model.transcribe(
                tmp_path,
                language=language,
                task=task,
                beam_size=5,
                vad_filter=True,  # Remove silence
                vad_parameters=dict(min_silence_duration_ms=500)
            )
            # Collect all segments into one transcript.
            text_parts = []
            for segment in segments:
                text_parts.append(segment.text.strip())
            full_text = " ".join(text_parts)
            detected_language = info.language if hasattr(info, 'language') else language or "unknown"
            logger.info(f"✅ STT: Transcribed successfully. Language: {detected_language}, Text: {full_text[:100]}...")
            return {
                "success": True,
                "model": model,
                "text": full_text,
                "language": detected_language,
                "device": swapper.device
            }
        finally:
            # Clean up temp file
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)
    except HTTPException:
        # Fix: previously the 400 "No audio provided" was swallowed by the
        # generic handler below and re-raised as a 500.
        raise
    except Exception as e:
        logger.error(f"❌ STT endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
# STT model cache: maps public model name -> loaded WhisperModel instance.
_stt_models = {}
async def _get_or_load_stt_model(model_name: str):
    """Return a cached faster-whisper model, loading it on first use.

    Args:
        model_name: public model name (e.g. "faster-whisper-large");
            unknown names fall back to the "small" checkpoint.

    Returns:
        A faster_whisper.WhisperModel instance (cached in _stt_models).
    """
    # Fast path: already loaded.
    if model_name in _stt_models:
        return _stt_models[model_name]
    from faster_whisper import WhisperModel
    # Map public model names to faster-whisper checkpoint sizes.
    model_map = {
        "faster-whisper-large": "large-v3",
        "faster-whisper-medium": "medium",
        "whisper-small": "small",
        "whisper-base": "base",
        "whisper-tiny": "tiny"
    }
    whisper_size = model_map.get(model_name, "small")
    logger.info(f"🔄 Loading STT model: {model_name} (size: {whisper_size})...")
    # Use GPU if available; int8 keeps CPU inference memory reasonable.
    device = "cuda" if swapper.device == "cuda" else "cpu"
    compute_type = "float16" if device == "cuda" else "int8"
    # NOTE(review): this constructor is synchronous and blocks the event loop
    # while weights load — consider run_in_executor if that matters.
    stt_model = WhisperModel(
        whisper_size,
        device=device,
        compute_type=compute_type
    )
    # Fix: dropped the needless `global _stt_models` — the dict is only
    # mutated here, never rebound, so no global declaration is required.
    _stt_models[model_name] = stt_model
    logger.info(f"✅ STT model {model_name} loaded on {device}")
    return stt_model
@app.get("/stt/models")
async def list_stt_models():
    """List HuggingFace-backed STT models with status and load state."""
    catalog = []
    for model in swapper.models.values():
        if model.backend != ModelBackend.HUGGINGFACE or model.type != "stt":
            continue
        catalog.append({
            "name": model.name,
            "hf_name": model.hf_name,
            "type": model.type,
            "size_gb": model.size_gb,
            "priority": model.priority,
            "status": model.status.value,
            "capabilities": model.capabilities,
            "loaded": model.name in swapper.hf_models,
        })
    return {
        "stt_models": catalog,
        "device": swapper.device,
    }
# ========== TTS (Text-to-Speech) API Endpoints ==========

class TTSRequest(BaseModel):
    """TTS request."""
    model: str = "xtts-v2"
    text: str  # text to synthesize
    language: str = "uk"  # Ukrainian by default
    speaker_wav_base64: Optional[str] = None  # reference audio for voice cloning
    speed: float = 1.0  # playback speed multiplier
@app.post("/tts")
async def tts_endpoint(request: TTSRequest):
    """
    Text-to-Speech endpoint.

    Models:
    - xtts-v2: Best multilingual TTS with voice cloning (2GB VRAM)

    Languages: uk, en, es, fr, de, it, pt, pl, tr, ru, nl, cs, ar, zh-cn, ja, hu, ko

    Raises:
        HTTPException(400): unknown TTS model.
        HTTPException(503): TTS library not installed.
        HTTPException(500): model load or synthesis failure.
    """
    try:
        model_name = request.model
        # Get model config
        tts_model_config = swapper.get_model_config(model_name, "tts")
        if not tts_model_config:
            raise HTTPException(status_code=400, detail=f"TTS model '{model_name}' not found")
        # Load XTTS model if not loaded (lazy, cached in hf_models)
        if model_name not in swapper.hf_models:
            logger.info(f"🔊 Loading TTS model: {model_name}...")
            try:
                from TTS.api import TTS
                # XTTS-v2 model
                tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
                if swapper.device == "cuda":
                    tts = tts.to(swapper.device)
                swapper.hf_models[model_name] = tts
                logger.info(f"✅ TTS model {model_name} loaded on {swapper.device}")
            except ImportError:
                logger.error("❌ TTS library not installed. Run: pip install TTS")
                raise HTTPException(status_code=503, detail="TTS library not installed")
            except Exception as e:
                logger.error(f"❌ Failed to load TTS model: {e}")
                raise HTTPException(status_code=500, detail=f"Failed to load TTS model: {e}")
        tts_model = swapper.hf_models[model_name]
        # Generate speech into a temp wav file.
        import tempfile
        import time
        start_time = time.time()
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            tmp_path = tmp_file.name
        try:
            # Check if voice cloning is requested
            if request.speaker_wav_base64:
                # Decode speaker reference audio
                speaker_audio = base64.b64decode(request.speaker_wav_base64)
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as speaker_file:
                    speaker_file.write(speaker_audio)
                    speaker_path = speaker_file.name
                try:
                    # Generate with voice cloning
                    tts_model.tts_to_file(
                        text=request.text,
                        file_path=tmp_path,
                        speaker_wav=speaker_path,
                        language=request.language,
                        speed=request.speed
                    )
                finally:
                    # Fix: previously the speaker reference file was only
                    # removed on success, leaking a temp file whenever
                    # synthesis raised.
                    if os.path.exists(speaker_path):
                        os.unlink(speaker_path)
            else:
                # Generate with default speaker
                tts_model.tts_to_file(
                    text=request.text,
                    file_path=tmp_path,
                    language=request.language,
                    speed=request.speed
                )
            # Read and encode audio
            with open(tmp_path, "rb") as f:
                audio_data = f.read()
            audio_base64 = base64.b64encode(audio_data).decode("utf-8")
            generation_time_ms = (time.time() - start_time) * 1000
            logger.info(f"✅ TTS generated in {generation_time_ms:.0f}ms, {len(audio_data)} bytes")
            return {
                "success": True,
                "model": model_name,
                "audio_base64": audio_base64,
                "language": request.language,
                "generation_time_ms": generation_time_ms,
                "device": swapper.device
            }
        finally:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ TTS endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/tts/models")
async def list_tts_models():
    """List HuggingFace-backed TTS models with status and load state."""
    catalog = []
    for model in swapper.models.values():
        if model.backend != ModelBackend.HUGGINGFACE or model.type != "tts":
            continue
        catalog.append({
            "name": model.name,
            "hf_name": model.hf_name,
            "type": model.type,
            "size_gb": model.size_gb,
            "priority": model.priority,
            "status": model.status.value,
            "capabilities": model.capabilities,
            "loaded": model.name in swapper.hf_models,
        })
    return {
        "tts_models": catalog,
        "device": swapper.device,
    }
# ========== Document Processing API Endpoints ==========

class DocumentRequest(BaseModel):
    """Document processing request."""
    # NOTE(review): the /document endpoint below only accepts multipart form
    # data; doc_base64/doc_url are not referenced there.
    model: str = "granite-docling"
    doc_base64: Optional[str] = None
    doc_url: Optional[str] = None
    output_format: str = "markdown"  # markdown, json, doctags
@app.post("/document")
async def document_endpoint(
    file: UploadFile = File(None),
    model: str = Form("granite-docling"),
    output_format: str = Form("markdown")
):
    """
    Document processing endpoint using Docling.

    Models:
    - granite-docling: IBM Granite for document structure (2.5GB VRAM)

    Output formats:
    - markdown: Clean markdown text
    - json: Structured JSON with document elements
    - text: Plain text extraction

    Supported files:
    PDF, DOCX, XLS/XLSX/XLSM/ODS, PPTX, TXT/MD/CSV/TSV, JSON/YAML/XML/HTML, RTF, ZIP, images.

    Raises:
        HTTPException(400): no document uploaded.
        HTTPException(500): extraction failure.
        HTTPException(503): no extractor available for this file type.
    """
    try:
        import time
        start_time = time.time()
        doc_data = None
        if file:
            doc_data = await file.read()
        if not doc_data:
            raise HTTPException(status_code=400, detail="No document provided")
        # Determine file type from the extension (defaults to pdf when absent).
        filename = file.filename if file else "document"
        file_ext = filename.split(".")[-1].lower() if "." in filename else "pdf"
        # Deterministic extraction for standard office/text formats —
        # no model involved, pure-Python fast path.
        if file_ext in [
            "txt", "md", "markdown", "csv", "tsv",
            "xlsx", "xls", "xlsm", "ods",
            "json", "yaml", "yml", "xml", "html", "htm", "rtf",
            "pptx", "zip",
        ]:
            try:
                if file_ext == "zip":
                    content = _zip_to_markdown(doc_data)
                    output_format = "markdown"
                else:
                    content = _extract_text_by_ext(filename, doc_data)
                    output_format = (
                        "markdown"
                        if file_ext in {
                            "md", "markdown", "csv", "tsv",
                            "xlsx", "xls", "xlsm", "ods",
                            "json", "yaml", "yml", "xml", "html", "htm", "pptx",
                        }
                        else "text"
                    )
                processing_time_ms = (time.time() - start_time) * 1000
                return {
                    "success": True,
                    "model": "text-extract",
                    "output_format": output_format,
                    "result": content,
                    "filename": filename,
                    "processing_time_ms": processing_time_ms,
                    "device": swapper.device
                }
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"❌ Text extraction failed: {e}", exc_info=True)
                raise HTTPException(status_code=500, detail=f"Text extraction failed: {e}")
        # Save to temp file (Docling and pdfplumber want a path).
        import tempfile
        with tempfile.NamedTemporaryFile(suffix=f".{file_ext}", delete=False) as tmp_file:
            tmp_file.write(doc_data)
            tmp_path = tmp_file.name
        try:
            # Try to use docling library
            try:
                from docling.document_converter import DocumentConverter
                from docling.datamodel.base_models import InputFormat
                # Fix: this log line previously printed the literal "(unknown)"
                # instead of the uploaded filename, which is available here.
                logger.info(f"📄 Processing document: {filename} ({len(doc_data)} bytes)")
                # Initialize converter (lazy load, cached in hf_models)
                if "docling_converter" not in swapper.hf_models:
                    logger.info("🔄 Initializing Docling converter...")
                    converter = DocumentConverter()
                    swapper.hf_models["docling_converter"] = converter
                    logger.info("✅ Docling converter initialized")
                converter = swapper.hf_models["docling_converter"]
                # Convert document
                result = converter.convert(tmp_path)
                doc = result.document
                # Format output (unknown formats fall back to markdown)
                if output_format == "markdown":
                    content = doc.export_to_markdown()
                elif output_format == "json":
                    content = doc.export_to_dict()
                elif output_format == "text":
                    content = doc.export_to_text()
                else:
                    content = doc.export_to_markdown()
                processing_time_ms = (time.time() - start_time) * 1000
                logger.info(f"✅ Document processed in {processing_time_ms:.0f}ms")
                return {
                    "success": True,
                    "model": model,
                    "output_format": output_format,
                    "result": content,
                    "filename": filename,
                    "processing_time_ms": processing_time_ms,
                    "device": swapper.device
                }
            except ImportError:
                # Fallback to pdfplumber/OCR for simpler extraction
                logger.warning("⚠️ Docling not installed, using fallback extraction")
                # For images, use OCR
                if file_ext in ["png", "jpg", "jpeg", "gif", "webp"]:
                    ocr_result = await swapper.ocr_process("got-ocr2", doc_data, "ocr")
                    return {
                        "success": True,
                        "model": "got-ocr2 (fallback)",
                        "output_format": "text",
                        "result": ocr_result.get("text", ""),
                        "filename": filename,
                        "processing_time_ms": (time.time() - start_time) * 1000,
                        "device": swapper.device
                    }
                # For common office/text formats, try deterministic extractors.
                if file_ext in {
                    "docx", "txt", "md", "markdown", "csv", "tsv",
                    "xlsx", "xls", "xlsm", "ods",
                    "pptx", "json", "yaml", "yml", "xml", "html", "htm", "rtf",
                }:
                    try:
                        content = _extract_text_by_ext(filename, doc_data)
                        out_fmt = "markdown" if file_ext not in {"txt", "rtf"} else "text"
                        return {
                            "success": True,
                            "model": "text-extract (fallback)",
                            "output_format": out_fmt,
                            "result": content,
                            "filename": filename,
                            "processing_time_ms": (time.time() - start_time) * 1000,
                            "device": swapper.device
                        }
                    except Exception as e:
                        logger.error(f"Text fallback failed for .{file_ext}: {e}")
                        raise HTTPException(status_code=500, detail=f"Extraction failed for .{file_ext}")
                # For PDFs, try pdfplumber
                if file_ext == "pdf":
                    try:
                        import pdfplumber
                        text_content = []
                        with pdfplumber.open(tmp_path) as pdf:
                            for page in pdf.pages:
                                text = page.extract_text()
                                if text:
                                    text_content.append(text)
                        content = "\n\n".join(text_content)
                        return {
                            "success": True,
                            "model": "pdfplumber (fallback)",
                            "output_format": "text",
                            "result": content,
                            "filename": filename,
                            "processing_time_ms": (time.time() - start_time) * 1000,
                            "device": "cpu"
                        }
                    except ImportError:
                        pass
                # For other documents, return error
                raise HTTPException(
                    status_code=503,
                    detail="Document processing unavailable for this type. Supported: office/text/image/zip standard formats."
                )
        finally:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Document endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/document/models")
async def list_document_models():
    """List HuggingFace-backed document-processing models."""
    catalog = []
    for model in swapper.models.values():
        if model.backend != ModelBackend.HUGGINGFACE or model.type != "document":
            continue
        catalog.append({
            "name": model.name,
            "hf_name": model.hf_name,
            "type": model.type,
            "size_gb": model.size_gb,
            "priority": model.priority,
            "status": model.status.value,
            "capabilities": model.capabilities,
            "loaded": model.name in swapper.hf_models,
        })
    return {
        "document_models": catalog,
        "device": swapper.device,
    }
# ========== Image Generation API Endpoints (FLUX) ==========

class ImageGenerateRequest(BaseModel):
    """Image generation request."""
    model: str = "flux-klein-4b"
    prompt: str
    negative_prompt: str = ""  # ignored by FLUX pipelines (not supported there)
    num_inference_steps: int = 50
    guidance_scale: float = 4.0
    width: int = 1024
    height: int = 1024
@app.post("/image/generate")
async def image_generate_endpoint(request: ImageGenerateRequest):
    """
    Generate image using diffusion model (lazy loaded).

    Models:
    - flux-klein-4b: FLUX.2 Klein 4B (15.4GB VRAM, lazy loaded on demand)

    The model is loaded on first request and unloaded when VRAM is needed for other models.
    """
    try:
        return await swapper.image_generate(
            model_name=request.model,
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            width=request.width,
            height=request.height,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"❌ Image generate endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/image/models")
async def list_image_models():
    """List HuggingFace-backed image generation models."""
    catalog = []
    for model in swapper.models.values():
        if model.backend != ModelBackend.HUGGINGFACE or model.type != "image_generation":
            continue
        catalog.append({
            "name": model.name,
            "hf_name": model.hf_name,
            "type": model.type,
            "size_gb": model.size_gb,
            "priority": model.priority,
            "status": model.status.value,
            "capabilities": model.capabilities,
            "loaded": model.name in swapper.hf_models,
        })
    return {
        "image_models": catalog,
        "active_image_model": swapper.active_image_model,
        "device": swapper.device,
    }
@app.post("/image/models/{model_name}/load")
async def load_image_model(model_name: str):
    """Pre-load an image generation model (optional, models are lazy loaded by default)."""
    if model_name not in swapper.models:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")
    model_info = swapper.models[model_name]
    if model_info.type != "image_generation":
        raise HTTPException(status_code=400, detail="Not an image generation model")
    async with swapper.loading_lock:
        loaded = await swapper._load_hf_model(model_name)
        if not loaded:
            raise HTTPException(status_code=500, detail=f"Failed to load image model: {model_name}")
        # Record success: mark loaded and remember the load timestamp.
        model_info.status = ModelStatus.LOADED
        model_info.loaded_at = datetime.now()
        swapper.active_image_model = model_name
        swapper.model_load_times[model_name] = datetime.now()
        return {"status": "success", "model": model_name, "message": f"Image model {model_name} loaded"}
@app.post("/image/models/{model_name}/unload")
async def unload_image_model(model_name: str):
    """Unload an image generation model to free GPU memory."""
    if model_name not in swapper.hf_models:
        raise HTTPException(status_code=400, detail=f"Model not loaded: {model_name}")
    async with swapper.loading_lock:
        unloaded = await swapper._unload_hf_model(model_name)
        if not unloaded:
            raise HTTPException(status_code=500, detail=f"Failed to unload image model: {model_name}")
        # Reflect the unload in the registry and clear the active pointer.
        if model_name in swapper.models:
            swapper.models[model_name].status = ModelStatus.UNLOADED
        if swapper.active_image_model == model_name:
            swapper.active_image_model = None
        return {"status": "success", "model": model_name, "message": f"Image model {model_name} unloaded"}
# ========== Web Scraping API Endpoints ==========

class WebExtractRequest(BaseModel):
    """Web content extraction request."""
    url: str  # page to extract
    method: str = "auto"  # auto, jina, trafilatura, crawl4ai
    include_links: bool = False  # forwarded to trafilatura only
    include_images: bool = False  # forwarded to trafilatura only
class WebSearchRequest(BaseModel):
    """Web search request."""
    query: str
    max_results: int = 10
    engine: str = "duckduckgo"  # duckduckgo, google
@app.post("/web/extract")
async def web_extract(request: WebExtractRequest):
"""
Extract content from URL using multiple methods.
Methods:
- jina: Jina Reader API (free, JS support, cloud)
- trafilatura: Local extraction (fast, no JS)
- crawl4ai: Full crawling (JS support, local) - requires separate service
- auto: Try jina first, fallback to trafilatura
"""
url = request.url
method = request.method
async def extract_with_jina(url: str) -> dict:
"""Extract using Jina Reader API (free)"""
try:
jina_url = f"https://r.jina.ai/{url}"
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(jina_url)
if response.status_code == 200:
return {
"success": True,
"method": "jina",
"content": response.text,
"url": url
}
else:
return {"success": False, "error": f"Jina returned {response.status_code}"}
except Exception as e:
return {"success": False, "error": str(e)}
async def extract_with_trafilatura(url: str) -> dict:
"""Extract using Trafilatura (local)"""
try:
import trafilatura
downloaded = trafilatura.fetch_url(url)
if downloaded:
text = trafilatura.extract(
downloaded,
include_links=request.include_links,
include_images=request.include_images
)
return {
"success": True,
"method": "trafilatura",
"content": text,
"url": url
}
return {"success": False, "error": "Failed to download page"}
except ImportError:
return {"success": False, "error": "Trafilatura not installed"}
except Exception as e:
return {"success": False, "error": str(e)}
async def extract_with_crawl4ai(url: str) -> dict:
    """Extract page content via the Crawl4AI sidecar service.

    Content is resolved in priority order: markdown (raw/fit variants),
    cleaned_html/extracted_content, then a tag-stripped version of the raw
    HTML. The result is capped at 50 kB to fit LLM context windows.

    Returns {"success", "method", "content", "url"} on success, otherwise
    {"success": False, "error": ...}.
    """
    try:
        crawl4ai_url = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235")
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"{crawl4ai_url}/crawl",
                json={"urls": [url], "word_count_threshold": 10}
            )
            if response.status_code == 200:
                data = response.json()
                # BUGFIX: .get("results", [{}]) only covers a *missing* key;
                # a present-but-empty list raised IndexError on [0].
                results = data.get("results") or [{}]
                result = results[0]
                # Markdown may be a plain string or a dict with raw/fit variants
                markdown_data = result.get("markdown", "")
                if isinstance(markdown_data, dict):
                    content = markdown_data.get("raw_markdown", "") or markdown_data.get("fit_markdown", "")
                else:
                    content = markdown_data
                # Fallback to cleaned_html
                if not content:
                    content = result.get("cleaned_html", "") or result.get("extracted_content", "")
                # Last resort: strip HTML tags (module-level `re` is in scope)
                if not content and result.get("html"):
                    content = re.sub(r'<[^>]+>', ' ', result.get("html", ""))
                    content = re.sub(r'\s+', ' ', content).strip()
                # Limit size for LLM context
                if len(content) > 50000:
                    content = content[:50000] + "\n\n[... truncated ...]"
                return {
                    "success": bool(content),
                    "method": "crawl4ai",
                    "content": content,
                    "url": url
                }
            return {"success": False, "error": f"Crawl4AI returned {response.status_code}"}
    except Exception as e:
        return {"success": False, "error": str(e)}
# Dispatch to the extractor selected by the request's `method` field.
if method == "jina":
    result = await extract_with_jina(url)
elif method == "trafilatura":
    result = await extract_with_trafilatura(url)
elif method == "crawl4ai":
    result = await extract_with_crawl4ai(url)
elif method == "auto":
    # Try jina first (JS support), fallback to trafilatura
    result = await extract_with_jina(url)
    if not result.get("success"):
        logger.info(f"Jina failed, trying trafilatura for {url}")
        result = await extract_with_trafilatura(url)
else:
    raise HTTPException(status_code=400, detail=f"Unknown method: {method}")
# Surface extractor failures as HTTP 500 carrying the backend's error text.
if not result.get("success"):
    raise HTTPException(status_code=500, detail=result.get("error", "Extraction failed"))
return result
@app.post("/web/search")
async def web_search(request: WebSearchRequest):
"""
Search the web using multiple engines with fallback.
Priority: 1) DDGS (DuckDuckGo) 2) Google Search
"""
formatted_results = []
engine_used = "none"
# Method 1: Try DDGS (new package name)
try:
from ddgs import DDGS
ddgs = DDGS()
results = list(ddgs.text(
request.query,
max_results=request.max_results,
region="wt-wt" # Worldwide
))
if results:
for idx, result in enumerate(results):
formatted_results.append({
"position": idx + 1,
"title": result.get("title", ""),
"url": result.get("href", result.get("link", "")),
"snippet": result.get("body", result.get("snippet", ""))
})
engine_used = "ddgs"
logger.info(f"✅ DDGS search found {len(formatted_results)} results for: {request.query[:50]}")
except ImportError:
logger.warning("DDGS not installed, trying Google search")
except Exception as e:
logger.warning(f"DDGS search failed: {e}, trying Google search")
# Method 2: Fallback to Google search
if not formatted_results:
try:
from googlesearch import search as google_search
results = list(google_search(request.query, num_results=request.max_results, lang="uk"))
if results:
for idx, url in enumerate(results):
formatted_results.append({
"position": idx + 1,
"title": url.split("/")[-1].replace("-", " ").replace("_", " ")[:60] or "Result",
"url": url,
"snippet": ""
})
engine_used = "google"
logger.info(f"✅ Google search found {len(formatted_results)} results for: {request.query[:50]}")
except ImportError:
logger.warning("Google search not installed")
except Exception as e:
logger.warning(f"Google search failed: {e}")
# Return results or empty
return {
"success": len(formatted_results) > 0,
"query": request.query,
"results": formatted_results,
"total": len(formatted_results),
"engine": engine_used
}
@app.get("/web/read/{url:path}")
async def web_read_simple(url: str):
"""
Simple GET endpoint to read a URL (uses Jina by default).
Example: GET /web/read/https://example.com
"""
request = WebExtractRequest(url=url, method="auto")
return await web_extract(request)
@app.get("/web/status")
async def web_status():
"""Check availability of web scraping methods"""
# Check Trafilatura
try:
import trafilatura
trafilatura_available = True
except ImportError:
trafilatura_available = False
# Check DuckDuckGo
try:
from duckduckgo_search import DDGS
ddgs_available = True
except ImportError:
ddgs_available = False
# Check Crawl4AI
crawl4ai_url = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235")
crawl4ai_available = False
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"{crawl4ai_url}/health")
crawl4ai_available = response.status_code == 200
except:
pass
return {
"methods": {
"jina": {"available": True, "type": "cloud", "js_support": True},
"trafilatura": {"available": trafilatura_available, "type": "local", "js_support": False},
"crawl4ai": {"available": crawl4ai_available, "type": "local", "js_support": True}
},
"search": {
"duckduckgo": {"available": ddgs_available}
}
}
# ========== Video Generation API (Grok xAI) ==========

# API key for xAI's Grok API; an empty value makes /video/generate return 503.
GROK_API_KEY = os.getenv("GROK_API_KEY", "")
# Base URL of the xAI REST API.
GROK_API_URL = "https://api.x.ai/v1"
class VideoGenerateRequest(BaseModel):
    """Video generation request via Grok"""
    # Text prompt describing the desired content.
    prompt: str
    # NOTE(review): duration is currently not sent to the Grok API by
    # /video/generate — confirm whether it is intentionally reserved.
    duration: int = 6  # seconds (max 6 for Grok)
    # Appended to the prompt as "<prompt>, <style> style".
    style: str = "cinematic"  # cinematic, anime, realistic, abstract
    # NOTE(review): aspect_ratio is also unused by /video/generate — verify.
    aspect_ratio: str = "16:9"  # 16:9, 9:16, 1:1
@app.post("/video/generate")
async def video_generate(request: VideoGenerateRequest):
"""
Generate image using Grok (xAI) API.
Note: Grok API currently supports image generation only (not video).
For video-like content, generate multiple frames and combine externally.
"""
if not GROK_API_KEY:
raise HTTPException(status_code=503, detail="GROK_API_KEY not configured")
try:
async with httpx.AsyncClient(timeout=120.0) as client:
# Grok image generation endpoint
response = await client.post(
f"{GROK_API_URL}/images/generations",
headers={
"Authorization": f"Bearer {GROK_API_KEY}",
"Content-Type": "application/json"
},
json={
"model": "grok-2-image-1212", # Correct model name
"prompt": f"{request.prompt}, {request.style} style",
"n": 1,
"response_format": "url"
}
)
if response.status_code == 200:
data = response.json()
return {
"success": True,
"prompt": request.prompt,
"style": request.style,
"type": "image", # Note: video not available via API
"result": data,
"provider": "grok-xai",
"note": "Grok API supports image generation. Video generation is available only in xAI app."
}
else:
logger.error(f"Grok API error: {response.status_code} - {response.text}")
raise HTTPException(
status_code=response.status_code,
detail=f"Grok API error: {response.text}"
)
except httpx.TimeoutException:
raise HTTPException(status_code=504, detail="Image generation timeout (>120s)")
except Exception as e:
logger.error(f"Image generation error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@app.get("/video/status")
async def video_status():
"""Check Grok image/video generation service status"""
return {
"service": "grok-xai",
"api_key_configured": bool(GROK_API_KEY),
"capabilities": {
"image_generation": True, # grok-2-image-1212
"video_generation": False, # Not available via API (only in xAI app)
"vision_analysis": True # grok-2-vision-1212
},
"models": {
"image": "grok-2-image-1212",
"vision": "grok-2-vision-1212",
"chat": ["grok-3", "grok-3-mini", "grok-4-0709"]
},
"supported_styles": ["cinematic", "anime", "realistic", "abstract", "photorealistic"]
}
# ========== Multimodal Stack Summary ==========
@app.get("/multimodal")
async def get_multimodal_stack():
    """Get full multimodal stack status"""
    def describe(model_type: str):
        # Summarize every registered model of the requested type.
        entries = []
        for model in swapper.models.values():
            if model.type != model_type:
                continue
            entries.append({
                "name": model.name,
                "size_gb": model.size_gb,
                "status": model.status.value,
                "loaded": model.name in swapper.hf_models,
            })
        return entries

    stack = {
        kind: describe(kind)
        for kind in (
            "llm", "vision", "math", "ocr", "document",
            "stt", "tts", "embedding", "image_generation",
        )
    }
    # Video generation is delegated to the external Grok API, not a local model.
    stack["video_generation"] = {"provider": "grok-xai", "available": bool(GROK_API_KEY)}

    return {
        "device": swapper.device,
        "cuda_available": TORCH_AVAILABLE and torch.cuda.is_available(),
        "stack": stack,
        "active_models": {
            "llm": swapper.active_model,
            "ocr": swapper.active_ocr_model,
        },
    }
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8890)