2516 lines
96 KiB
Python
2516 lines
96 KiB
Python
"""
|
|
Swapper Service - Dynamic Model Loading Service
|
|
Manages loading/unloading LLM and OCR models on-demand to optimize memory usage.
|
|
Supports:
|
|
- Ollama models (LLM, Vision, Math)
|
|
- HuggingFace models (OCR, Document Understanding)
|
|
- Lazy loading for OCR models
|
|
"""
|
|
|
|
import os
|
|
import asyncio
|
|
import logging
|
|
import base64
|
|
import json
|
|
import re
|
|
from typing import Optional, Dict, List, Any, Union
|
|
from datetime import datetime, timedelta
|
|
from enum import Enum
|
|
from io import BytesIO
|
|
import xml.etree.ElementTree as ET
|
|
|
|
from fastapi import FastAPI, HTTPException, BackgroundTasks, File, UploadFile, Form
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
import httpx
|
|
import csv
|
|
import zipfile
|
|
from io import BytesIO
|
|
import mimetypes
|
|
import yaml
|
|
|
|
# Optional imports for HuggingFace models
|
|
try:
|
|
import torch
|
|
from PIL import Image
|
|
TORCH_AVAILABLE = True
|
|
except ImportError:
|
|
TORCH_AVAILABLE = False
|
|
torch = None
|
|
Image = None
|
|
|
|
logger = logging.getLogger(__name__)
|
|
# ========== Document Helpers ==========
|
|
def _decode_text_bytes(content: bytes) -> str:
|
|
"""Decode text with best-effort fallback."""
|
|
try:
|
|
import chardet
|
|
detected = chardet.detect(content)
|
|
encoding = detected.get("encoding") or "utf-8"
|
|
return content.decode(encoding, errors="replace")
|
|
except Exception:
|
|
try:
|
|
return content.decode("utf-8", errors="replace")
|
|
except Exception:
|
|
return content.decode("latin-1", errors="replace")
|
|
|
|
|
|
def _csv_to_markdown(content: bytes) -> str:
    """Render CSV bytes as a Markdown table (first row is the header)."""
    decoded = _decode_text_bytes(content)
    parsed_rows = list(csv.reader(decoded.splitlines()))
    return _rows_to_markdown(parsed_rows)
|
|
|
|
|
|
def _tsv_to_markdown(content: bytes) -> str:
    """Render tab-separated bytes as a Markdown table (first row is the header)."""
    decoded = _decode_text_bytes(content)
    parsed_rows = list(csv.reader(decoded.splitlines(), delimiter="\t"))
    return _rows_to_markdown(parsed_rows)
|
|
|
|
|
|
def _rows_to_markdown(rows: List[List[Any]]) -> str:
|
|
if not rows:
|
|
return ""
|
|
width = max(len(r) for r in rows)
|
|
norm_rows = []
|
|
for r in rows:
|
|
rr = [str(c) if c is not None else "" for c in r]
|
|
if len(rr) < width:
|
|
rr.extend([""] * (width - len(rr)))
|
|
norm_rows.append(rr)
|
|
header = norm_rows[0]
|
|
body = norm_rows[1:]
|
|
lines = [
|
|
"| " + " | ".join(header) + " |",
|
|
"| " + " | ".join(["---"] * len(header)) + " |",
|
|
]
|
|
for row in body:
|
|
lines.append("| " + " | ".join([str(c) if c is not None else "" for c in row]) + " |")
|
|
return "\n".join(lines)
|
|
|
|
|
|
def _xlsx_to_markdown(content: bytes) -> str:
    """Render an XLSX/XLSM workbook as Markdown, one table per sheet.

    Uses the shared `_rows_to_markdown` renderer (as the XLS and ODS
    converters do) instead of duplicating table-building logic, so None
    cells and row padding are handled consistently.

    Raises:
        HTTPException: 500 when openpyxl is not installed.
    """
    try:
        import openpyxl
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"openpyxl not available: {e}")

    wb = openpyxl.load_workbook(filename=BytesIO(content), data_only=True)
    parts = []
    for sheet in wb.worksheets:
        parts.append(f"## Sheet: {sheet.title}")
        rows = list(sheet.iter_rows(values_only=True))
        if not rows:
            parts.append("_Empty sheet_")
            continue
        # Delegate to the shared renderer for consistent formatting.
        parts.append(_rows_to_markdown([list(row) for row in rows]))
    return "\n".join(parts)
|
|
|
|
|
|
def _xls_to_markdown(content: bytes) -> str:
    """Render a legacy XLS workbook as Markdown, one table per sheet."""
    try:
        import xlrd
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"xlrd not available: {e}")

    workbook = xlrd.open_workbook(file_contents=content)
    sections = []
    for sheet in workbook.sheets():
        sections.append(f"## Sheet: {sheet.name}")
        grid = [
            [sheet.cell_value(row_idx, col_idx) for col_idx in range(sheet.ncols)]
            for row_idx in range(sheet.nrows)
        ]
        if not grid:
            sections.append("_Empty sheet_")
            continue
        sections.append(_rows_to_markdown(grid))
    return "\n\n".join(sections)
|
|
|
|
|
|
def _ods_to_markdown(content: bytes) -> str:
    """Render an OpenDocument spreadsheet as Markdown, one table per sheet."""
    try:
        from odf.opendocument import load
        from odf.table import Table, TableRow, TableCell
        from odf.text import P
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"odfpy not available: {e}")

    try:
        doc = load(BytesIO(content))
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid ODS file: {e}")

    sections = []
    for sheet in doc.spreadsheet.getElementsByType(Table):
        sheet_name = str(sheet.getAttribute("name") or "Sheet")
        sections.append(f"## Sheet: {sheet_name}")
        grid: List[List[str]] = []
        for table_row in sheet.getElementsByType(TableRow):
            row_cells: List[str] = []
            for cell in table_row.getElementsByType(TableCell):
                fragments = []
                for paragraph in cell.getElementsByType(P):
                    fragments.extend(
                        str(getattr(node, "data", "")).strip()
                        for node in paragraph.childNodes
                        if getattr(node, "data", None)
                    )
                cell_text = " ".join(f for f in fragments if f).strip()
                # Honour column-repeat markers, capped so a pathological
                # repeat count cannot blow up the output.
                raw_repeat = cell.getAttribute("numbercolumnsrepeated")
                try:
                    count = int(raw_repeat) if raw_repeat else 1
                except Exception:
                    count = 1
                count = max(1, min(count, 100))
                row_cells.extend([cell_text] * count)
            if row_cells:
                grid.append(row_cells)
        if not grid:
            sections.append("_Empty sheet_")
            continue
        sections.append(_rows_to_markdown(grid))
    return "\n\n".join(sections)
|
|
|
|
|
|
def _docx_to_text(content: bytes) -> str:
    """Extract the plain paragraph text from a DOCX document."""
    try:
        from docx import Document
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"python-docx not available: {e}")

    document = Document(BytesIO(content))
    # Keep only non-empty paragraphs, one per line.
    return "\n".join(paragraph.text for paragraph in document.paragraphs if paragraph.text)
|
|
|
|
|
|
def _pdf_to_text(content: bytes) -> str:
    """Extract text from a PDF, one chunk per non-empty page."""
    try:
        import pdfplumber
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"pdfplumber not available: {e}")

    pages = []
    with pdfplumber.open(BytesIO(content)) as pdf:
        for page in pdf.pages:
            extracted = page.extract_text() or ""
            if extracted:
                pages.append(extracted)
    return "\n\n".join(pages)
|
|
|
|
|
|
def _pptx_to_text(content: bytes) -> str:
    """Extract text from a PPTX presentation, grouped under a heading per slide."""
    try:
        from pptx import Presentation
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"python-pptx not available: {e}")

    presentation = Presentation(BytesIO(content))
    chunks = []
    for number, slide in enumerate(presentation.slides, start=1):
        chunks.append(f"## Slide {number}")
        texts = []
        for shape in slide.shapes:
            shape_text = getattr(shape, "text", None)
            if shape_text and str(shape_text).strip():
                texts.append(str(shape_text).strip())
        chunks.extend(texts if texts else ["_No text on this slide_"])
    return "\n\n".join(chunks)
|
|
|
|
|
|
def _json_to_text(content: bytes) -> str:
    """Pretty-print JSON bytes; return the raw text unchanged if parsing fails."""
    raw = _decode_text_bytes(content)
    try:
        return json.dumps(json.loads(raw), ensure_ascii=False, indent=2)
    except Exception:
        return raw
|
|
|
|
|
|
def _yaml_to_text(content: bytes) -> str:
    """Normalize YAML bytes via a parse/dump round-trip; raw text on failure."""
    raw = _decode_text_bytes(content)
    try:
        parsed = yaml.safe_load(raw)
        return yaml.safe_dump(parsed, allow_unicode=True, sort_keys=False)
    except Exception:
        return raw
|
|
|
|
|
|
def _xml_to_text(content: bytes) -> str:
    """Collapse an XML document to its text content; raw text on parse errors."""
    raw = _decode_text_bytes(content)
    try:
        root = ET.fromstring(raw)
        joined = " ".join(piece.strip() for piece in root.itertext() if piece and piece.strip())
        return joined or raw
    except Exception:
        return raw
|
|
|
|
|
|
def _html_to_text(content: bytes) -> str:
    """Extract readable text from HTML, with a regex fallback when bs4 is missing."""
    raw = _decode_text_bytes(content)
    try:
        from bs4 import BeautifulSoup

        extracted = BeautifulSoup(raw, "html.parser").get_text(separator="\n")
        extracted = re.sub(r"\n{3,}", "\n\n", extracted)
        return extracted.strip() or raw
    except Exception:
        # Minimal fallback if bs4 is unavailable: strip tags, squash whitespace.
        stripped = re.sub(r"<[^>]+>", " ", raw)
        stripped = re.sub(r"\s+", " ", stripped)
        return stripped.strip()
|
|
|
|
|
|
def _rtf_to_text(content: bytes) -> str:
    """Convert RTF to plain text, stripping control tokens when striprtf is missing."""
    raw = _decode_text_bytes(content)
    try:
        from striprtf.striprtf import rtf_to_text

        return rtf_to_text(raw)
    except Exception:
        # Basic fallback: drop hex escapes, control words and group braces.
        cleaned = re.sub(r"\\'[0-9a-fA-F]{2}", " ", raw)
        cleaned = re.sub(r"\\[a-zA-Z]+-?\d* ?", " ", cleaned)
        cleaned = cleaned.replace("{", " ").replace("}", " ")
        return re.sub(r"\s+", " ", cleaned).strip()
|
|
|
|
|
|
def _extract_text_by_ext(filename: str, content: bytes) -> str:
    """Dispatch *content* to the extractor matching the filename's extension.

    Raises:
        HTTPException: 400 when the extension is not supported.
    """
    ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
    dispatch = {
        "txt": _decode_text_bytes,
        "md": _decode_text_bytes,
        "markdown": _decode_text_bytes,
        "csv": _csv_to_markdown,
        "tsv": _tsv_to_markdown,
        "xlsx": _xlsx_to_markdown,
        "xlsm": _xlsx_to_markdown,
        "xls": _xls_to_markdown,
        "ods": _ods_to_markdown,
        "docx": _docx_to_text,
        "pdf": _pdf_to_text,
        "pptx": _pptx_to_text,
        "json": _json_to_text,
        "yaml": _yaml_to_text,
        "yml": _yaml_to_text,
        "xml": _xml_to_text,
        "html": _html_to_text,
        "htm": _html_to_text,
        "rtf": _rtf_to_text,
    }
    handler = dispatch.get(ext)
    if handler is None:
        raise HTTPException(status_code=400, detail=f"Unsupported file type: .{ext}")
    return handler(content)
|
|
|
|
|
|
def _zip_to_markdown(content: bytes, max_files: int = 50, max_total_mb: int = 100) -> str:
    """Extract every supported file in a ZIP archive into one Markdown document.

    Args:
        content: Raw ZIP bytes.
        max_files: Maximum number of (non-directory) entries allowed.
        max_total_mb: Maximum combined uncompressed size in MiB.

    Returns:
        A Markdown document: a summary header listing processed/skipped files,
        followed by one "## <name>" section per entry.

    Raises:
        HTTPException: 400 for an invalid archive or one exceeding the limits.
    """
    try:
        zf = zipfile.ZipFile(BytesIO(content))
    except zipfile.BadZipFile as e:
        # Corrupt/non-ZIP payloads are a client error, not a server crash.
        raise HTTPException(status_code=400, detail=f"Invalid ZIP file: {e}")

    members = [m for m in zf.infolist() if not m.is_dir()]
    if len(members) > max_files:
        raise HTTPException(status_code=400, detail=f"ZIP has too many files: {len(members)}")
    total_size = sum(m.file_size for m in members)
    if total_size > max_total_mb * 1024 * 1024:
        raise HTTPException(status_code=400, detail=f"ZIP too large: {total_size / 1024 / 1024:.1f} MB")

    parts = []
    # Extensions _extract_text_by_ext can handle.
    allowed_exts = {
        "txt", "md", "markdown", "csv", "tsv",
        "xls", "xlsx", "xlsm", "ods",
        "docx", "pdf", "pptx",
        "json", "yaml", "yml", "xml", "html", "htm", "rtf",
    }
    processed = []
    skipped = []
    for member in members:
        name = member.filename
        ext = name.split(".")[-1].lower() if "." in name else ""
        if ext not in allowed_exts:
            skipped.append(name)
            parts.append(f"## {name}\n_Skipped unsupported file type_")
            continue
        file_bytes = zf.read(member)
        extracted = _extract_text_by_ext(name, file_bytes)
        processed.append(name)
        parts.append(f"## {name}\n{extracted}")

    header_lines = ["# ZIP summary", "Processed files:"]
    header_lines.extend([f"- {name}" for name in processed] or ["- (none)"])
    if skipped:
        header_lines.append("Skipped files:")
        header_lines.extend(f"- {name}" for name in skipped)
    return "\n\n".join(["\n".join(header_lines), *parts])
|
|
|
|
|
|
# ========== Configuration ==========
|
|
|
|
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
|
|
SWAPPER_CONFIG_PATH = os.getenv("SWAPPER_CONFIG_PATH", "./config/swapper_config.yaml")
|
|
SWAPPER_MODE = os.getenv("SWAPPER_MODE", "single-active") # single-active or multi-active
|
|
MAX_CONCURRENT_MODELS = int(os.getenv("MAX_CONCURRENT_MODELS", "1"))
|
|
MODEL_SWAP_TIMEOUT = int(os.getenv("MODEL_SWAP_TIMEOUT", "30"))
|
|
|
|
# ========== Models ==========
|
|
|
|
class ModelStatus(str, Enum):
    """Lifecycle state of a managed model."""
    LOADED = "loaded"        # resident in memory and ready to serve
    LOADING = "loading"      # load in progress
    UNLOADED = "unloaded"    # not resident; may be (lazy) loaded on demand
    UNLOADING = "unloading"  # unload in progress
    ERROR = "error"          # last load attempt failed
|
|
|
|
class ModelBackend(str, Enum):
    """Which runtime serves the model."""
    OLLAMA = "ollama"            # served via the Ollama HTTP API
    HUGGINGFACE = "huggingface"  # loaded in-process (transformers/diffusers)
|
|
|
|
class ModelInfo(BaseModel):
    """Static description plus runtime bookkeeping for a single managed model."""
    name: str  # registry key used throughout the swapper
    ollama_name: str  # For Ollama models; "" for HuggingFace-backed models
    hf_name: Optional[str] = None  # For HuggingFace models (repo id)
    backend: ModelBackend = ModelBackend.OLLAMA
    type: str  # llm, code, vision, math, ocr, image_generation
    size_gb: float  # approximate footprint, informational
    priority: str  # high, medium, low
    status: ModelStatus
    capabilities: List[str] = []  # For OCR models (pydantic copies defaults per instance)
    loaded_at: Optional[datetime] = None  # last successful load time
    unloaded_at: Optional[datetime] = None  # last unload time
    total_uptime_seconds: float = 0.0  # accumulated time spent loaded
    request_count: int = 0  # inference requests served
|
|
|
|
class SwapperStatus(BaseModel):
    """Swapper service status snapshot."""
    status: str  # overall health; get_status() currently always reports "healthy"
    active_model: Optional[str] = None  # currently active LLM model, if any
    available_models: List[str]  # every registered model name
    loaded_models: List[str]  # names whose status is LOADED
    mode: str  # "single-active" or "multi-active"
    total_models: int  # size of the registry
|
|
|
class ModelMetrics(BaseModel):
    """Model usage metrics."""
    model_name: str
    status: str  # ModelStatus value as a string
    loaded_at: Optional[datetime] = None  # last successful load time
    uptime_hours: float  # total_uptime_seconds / 3600
    request_count: int  # inference requests served
    total_uptime_seconds: float  # accumulated uptime incl. current session
|
|
|
|
# ========== Swapper Service ==========
|
|
|
|
class SwapperService:
|
|
"""Swapper Service - manages model loading/unloading for Ollama and HuggingFace"""
|
|
|
|
    def __init__(self):
        # Registry of every configured model, keyed by model name.
        self.models: Dict[str, ModelInfo] = {}
        self.active_model: Optional[str] = None  # Active LLM model
        self.active_ocr_model: Optional[str] = None  # Active OCR model (separate from LLM)
        self.active_image_model: Optional[str] = None  # Active Image Generation model
        # Serializes load/unload so model swaps cannot interleave.
        self.loading_lock = asyncio.Lock()
        # Shared client for all Ollama calls; generous timeout for long generations.
        self.http_client = httpx.AsyncClient(timeout=300.0)
        self.model_uptime: Dict[str, float] = {}  # Track accumulated uptime per model
        self.model_load_times: Dict[str, datetime] = {}  # Track when each model was loaded

        # HuggingFace model instances (lazy loaded)
        self.hf_models: Dict[str, Any] = {}  # model_name -> model instance
        self.hf_processors: Dict[str, Any] = {}  # model_name -> processor/tokenizer

        # Device configuration: prefer CUDA when torch is importable and sees a GPU.
        self.device = "cuda" if TORCH_AVAILABLE and torch.cuda.is_available() else "cpu"
        logger.info(f"🔧 Swapper initialized with device: {self.device}")
|
|
|
|
    async def initialize(self):
        """Initialize the Swapper Service by loading its configuration.

        Reads the YAML config at SWAPPER_CONFIG_PATH and registers every model
        listed there; the backend is chosen from the ``path`` prefix
        ("huggingface:<repo>" vs an Ollama tag). When the config file is
        missing, models are discovered from the Ollama /api/tags endpoint
        instead. Finally, the configured default model is auto-loaded unless
        it is an OCR model with lazy loading enabled. Errors are logged and
        swallowed; this method never raises.
        """
        # Parsed YAML config; stays None when the config file is absent so the
        # default-model step below is skipped.
        config = None
        try:
            logger.info(f"🔧 Initializing Swapper Service...")
            logger.info(f"🔧 Config path: {SWAPPER_CONFIG_PATH}")
            logger.info(f"🔧 Config exists: {os.path.exists(SWAPPER_CONFIG_PATH)}")

            if os.path.exists(SWAPPER_CONFIG_PATH):
                with open(SWAPPER_CONFIG_PATH, 'r') as f:
                    config = yaml.safe_load(f)
                models_config = config.get('models', {})
                logger.info(f"🔧 Found {len(models_config)} models in config")

                for model_key, model_config in models_config.items():
                    path = model_config.get('path', '')
                    model_type = model_config.get('type', 'llm')
                    capabilities = model_config.get('capabilities', [])

                    # Determine backend from path prefix
                    if path.startswith('huggingface:'):
                        hf_name = path.replace('huggingface:', '')
                        logger.info(f"🔧 Adding HuggingFace model: {model_key} -> {hf_name}")
                        self.models[model_key] = ModelInfo(
                            name=model_key,
                            ollama_name="",  # not served by Ollama
                            hf_name=hf_name,
                            backend=ModelBackend.HUGGINGFACE,
                            type=model_type,
                            size_gb=model_config.get('size_gb', 0),
                            priority=model_config.get('priority', 'medium'),
                            capabilities=capabilities,
                            status=ModelStatus.UNLOADED
                        )
                    else:
                        # "ollama:" prefix is optional; a bare tag works too.
                        ollama_name = path.replace('ollama:', '')
                        logger.info(f"🔧 Adding Ollama model: {model_key} -> {ollama_name}")
                        self.models[model_key] = ModelInfo(
                            name=model_key,
                            ollama_name=ollama_name,
                            backend=ModelBackend.OLLAMA,
                            type=model_type,
                            size_gb=model_config.get('size_gb', 0),
                            priority=model_config.get('priority', 'medium'),
                            capabilities=capabilities,
                            status=ModelStatus.UNLOADED
                        )
                    self.model_uptime[model_key] = 0.0

                logger.info(f"✅ Loaded {len(self.models)} models into Swapper")

                # Count by backend
                ollama_count = sum(1 for m in self.models.values() if m.backend == ModelBackend.OLLAMA)
                hf_count = sum(1 for m in self.models.values() if m.backend == ModelBackend.HUGGINGFACE)
                logger.info(f"✅ Models: {ollama_count} Ollama, {hf_count} HuggingFace")
            else:
                logger.warning(f"⚠️ Config file not found: {SWAPPER_CONFIG_PATH}, using defaults")
                await self._load_models_from_ollama()

            logger.info(f"✅ Swapper Service initialized with {len(self.models)} models")
            logger.info(f"✅ Model names: {list(self.models.keys())}")

            # Load default LLM model (not OCR - those are lazy loaded)
            if config:
                swapper_config = config.get('swapper', {})
                default_model = swapper_config.get('default_model')
                lazy_load_ocr = swapper_config.get('lazy_load_ocr', True)

                if default_model and default_model in self.models:
                    model_info = self.models[default_model]
                    # Only auto-load non-OCR models
                    if model_info.type != 'ocr' or not lazy_load_ocr:
                        logger.info(f"🔄 Loading default model: {default_model}")
                        success = await self.load_model(default_model)
                        if success:
                            logger.info(f"✅ Default model loaded: {default_model}")
                        else:
                            logger.warning(f"⚠️ Failed to load default model: {default_model}")
                    else:
                        logger.info(f"⏳ OCR model '{default_model}' will be lazy loaded on first request")
                elif default_model:
                    logger.warning(f"⚠️ Default model '{default_model}' not found in models list")
        except Exception as e:
            logger.error(f"❌ Error initializing Swapper Service: {e}", exc_info=True)
            # NOTE(review): exc_info=True already logs the traceback; this
            # second formatted dump is redundant but kept as-is.
            import traceback
            logger.error(f"❌ Traceback: {traceback.format_exc()}")
|
|
|
|
async def _load_models_from_ollama(self):
|
|
"""Load available models from Ollama"""
|
|
try:
|
|
response = await self.http_client.get(f"{OLLAMA_BASE_URL}/api/tags")
|
|
if response.status_code == 200:
|
|
data = response.json()
|
|
for model in data.get('models', []):
|
|
model_name = model.get('name', '')
|
|
# Extract base name (remove :latest, :7b, etc.)
|
|
base_name = model_name.split(':')[0]
|
|
|
|
if base_name not in self.models:
|
|
size_gb = model.get('size', 0) / (1024**3) # Convert bytes to GB
|
|
self.models[base_name] = ModelInfo(
|
|
name=base_name,
|
|
ollama_name=model_name,
|
|
type='llm', # Default type
|
|
size_gb=size_gb,
|
|
priority='medium',
|
|
status=ModelStatus.UNLOADED
|
|
)
|
|
self.model_uptime[base_name] = 0.0
|
|
|
|
logger.info(f"✅ Loaded {len(self.models)} models from Ollama")
|
|
except Exception as e:
|
|
logger.error(f"❌ Error loading models from Ollama: {e}")
|
|
|
|
async def load_model(self, model_name: str) -> bool:
|
|
"""Load a model (unload current if in single-active mode)"""
|
|
async with self.loading_lock:
|
|
try:
|
|
# Check if model exists
|
|
if model_name not in self.models:
|
|
logger.error(f"❌ Model not found: {model_name}")
|
|
return False
|
|
|
|
model_info = self.models[model_name]
|
|
|
|
# If single-active mode and another model is loaded, unload it first
|
|
if SWAPPER_MODE == "single-active" and self.active_model and self.active_model != model_name:
|
|
await self._unload_model_internal(self.active_model)
|
|
|
|
# Load the model
|
|
logger.info(f"🔄 Loading model: {model_name}")
|
|
model_info.status = ModelStatus.LOADING
|
|
|
|
# Check if model is already loaded in Ollama
|
|
response = await self.http_client.post(
|
|
f"{OLLAMA_BASE_URL}/api/generate",
|
|
json={
|
|
"model": model_info.ollama_name,
|
|
"prompt": "test",
|
|
"stream": False
|
|
},
|
|
timeout=MODEL_SWAP_TIMEOUT
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
model_info.status = ModelStatus.LOADED
|
|
model_info.loaded_at = datetime.now()
|
|
model_info.unloaded_at = None
|
|
self.active_model = model_name
|
|
self.model_load_times[model_name] = datetime.now()
|
|
logger.info(f"✅ Model loaded: {model_name}")
|
|
return True
|
|
else:
|
|
model_info.status = ModelStatus.ERROR
|
|
logger.error(f"❌ Failed to load model: {model_name}")
|
|
return False
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ Error loading model {model_name}: {e}", exc_info=True)
|
|
if model_name in self.models:
|
|
self.models[model_name].status = ModelStatus.ERROR
|
|
return False
|
|
|
|
async def _unload_model_internal(self, model_name: str) -> bool:
|
|
"""Internal method to unload a model"""
|
|
try:
|
|
if model_name not in self.models:
|
|
return False
|
|
|
|
model_info = self.models[model_name]
|
|
|
|
if model_info.status == ModelStatus.LOADED:
|
|
logger.info(f"🔄 Unloading model: {model_name}")
|
|
model_info.status = ModelStatus.UNLOADING
|
|
|
|
# Calculate uptime
|
|
if model_name in self.model_load_times:
|
|
load_time = self.model_load_times[model_name]
|
|
uptime_seconds = (datetime.now() - load_time).total_seconds()
|
|
self.model_uptime[model_name] = self.model_uptime.get(model_name, 0.0) + uptime_seconds
|
|
model_info.total_uptime_seconds = self.model_uptime[model_name]
|
|
del self.model_load_times[model_name]
|
|
|
|
model_info.status = ModelStatus.UNLOADED
|
|
model_info.unloaded_at = datetime.now()
|
|
|
|
if self.active_model == model_name:
|
|
self.active_model = None
|
|
|
|
logger.info(f"✅ Model unloaded: {model_name}")
|
|
return True
|
|
|
|
except Exception as e:
|
|
logger.error(f"❌ Error unloading model {model_name}: {e}")
|
|
return False
|
|
|
|
async def unload_model(self, model_name: str) -> bool:
|
|
"""Unload a model"""
|
|
async with self.loading_lock:
|
|
return await self._unload_model_internal(model_name)
|
|
|
|
async def get_status(self) -> SwapperStatus:
|
|
"""Get Swapper service status"""
|
|
# Update uptime for currently loaded model
|
|
if self.active_model and self.active_model in self.model_load_times:
|
|
load_time = self.model_load_times[self.active_model]
|
|
current_uptime = (datetime.now() - load_time).total_seconds()
|
|
self.model_uptime[self.active_model] = self.model_uptime.get(self.active_model, 0.0) + current_uptime
|
|
self.model_load_times[self.active_model] = datetime.now() # Reset timer
|
|
|
|
loaded_models = [
|
|
name for name, model in self.models.items()
|
|
if model.status == ModelStatus.LOADED
|
|
]
|
|
|
|
return SwapperStatus(
|
|
status="healthy",
|
|
active_model=self.active_model,
|
|
available_models=list(self.models.keys()),
|
|
loaded_models=loaded_models,
|
|
mode=SWAPPER_MODE,
|
|
total_models=len(self.models)
|
|
)
|
|
|
|
async def get_model_metrics(self, model_name: Optional[str] = None) -> List[ModelMetrics]:
|
|
"""Get metrics for model(s)"""
|
|
metrics = []
|
|
|
|
models_to_check = [model_name] if model_name else list(self.models.keys())
|
|
|
|
for name in models_to_check:
|
|
if name not in self.models:
|
|
continue
|
|
|
|
model_info = self.models[name]
|
|
|
|
# Calculate current uptime
|
|
uptime_seconds = self.model_uptime.get(name, 0.0)
|
|
if name in self.model_load_times:
|
|
load_time = self.model_load_times[name]
|
|
current_uptime = (datetime.now() - load_time).total_seconds()
|
|
uptime_seconds += current_uptime
|
|
|
|
uptime_hours = uptime_seconds / 3600.0
|
|
|
|
metrics.append(ModelMetrics(
|
|
model_name=name,
|
|
status=model_info.status.value,
|
|
loaded_at=model_info.loaded_at,
|
|
uptime_hours=uptime_hours,
|
|
request_count=model_info.request_count,
|
|
total_uptime_seconds=uptime_seconds
|
|
))
|
|
|
|
return metrics
|
|
|
|
    async def generate(self, model_name: str, prompt: str, system_prompt: Optional[str] = None,
                       max_tokens: int = 2048, temperature: float = 0.7, stream: bool = False) -> Dict[str, Any]:
        """Generate text with a model via the Ollama /api/generate endpoint.

        Loads the model on demand if it is not currently loaded.

        Args:
            model_name: Registry key of the model to use.
            prompt: User prompt.
            system_prompt: Optional system prompt forwarded to Ollama.
            max_tokens: Ollama ``num_predict`` option.
            temperature: Sampling temperature.
            stream: Forwarded to Ollama; the response is still read in one piece here.

        Returns:
            Dict with the generated "response", "model", "done" flag and token counts.

        Raises:
            ValueError: Unknown model or load failure.
            HTTPException: Non-200 response from Ollama.

        NOTE(review): HuggingFace-backed models have ollama_name == "" and
        would reach the Ollama call with an empty model field — confirm
        callers only route Ollama models here.
        """
        try:
            # Ensure model is loaded
            if model_name not in self.models:
                raise ValueError(f"Model not found: {model_name}")

            model_info = self.models[model_name]

            # Load model if not loaded
            if model_info.status != ModelStatus.LOADED:
                logger.info(f"🔄 Model {model_name} not loaded, loading now...")
                success = await self.load_model(model_name)
                if not success:
                    raise ValueError(f"Failed to load model: {model_name}")

            # Increment request count
            model_info.request_count += 1

            # Prepare request to Ollama
            request_data = {
                "model": model_info.ollama_name,
                "prompt": prompt,
                "stream": stream,
                "options": {
                    "num_predict": max_tokens,
                    "temperature": temperature
                }
            }

            if system_prompt:
                request_data["system"] = system_prompt

            # Call Ollama
            response = await self.http_client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=request_data,
                timeout=300.0
            )

            if response.status_code == 200:
                data = response.json()
                return {
                    "response": data.get("response", ""),
                    "model": model_name,
                    "done": data.get("done", True),
                    "eval_count": data.get("eval_count", 0),
                    "prompt_eval_count": data.get("prompt_eval_count", 0)
                }
            else:
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"Ollama error: {response.text}"
                )

        except Exception as e:
            logger.error(f"❌ Error generating with model {model_name}: {e}", exc_info=True)
            raise
|
|
|
|
async def close(self):
|
|
"""Close HTTP client and unload HuggingFace models"""
|
|
await self.http_client.aclose()
|
|
|
|
# Unload HuggingFace models to free GPU memory
|
|
for model_name in list(self.hf_models.keys()):
|
|
await self._unload_hf_model(model_name)
|
|
|
|
    async def _load_hf_model(self, model_name: str) -> bool:
        """Load a HuggingFace model (OCR / Document Understanding / diffusion).

        The loading strategy is picked from the repo name: GOT-OCR2.0, Donut,
        TrOCR, diffusion (name contains "flux" or type == "image_generation"),
        or a generic AutoModel fallback. The model and its processor/tokenizer
        are cached in ``self.hf_models`` / ``self.hf_processors``.

        Returns:
            True on success; False when PyTorch is unavailable, the model is
            not HuggingFace-backed, or loading fails.
        """
        if not TORCH_AVAILABLE:
            logger.error("❌ PyTorch not available, cannot load HuggingFace models")
            return False

        try:
            model_info = self.models[model_name]
            if model_info.backend != ModelBackend.HUGGINGFACE:
                logger.error(f"❌ Model {model_name} is not a HuggingFace model")
                return False

            hf_name = model_info.hf_name
            logger.info(f"🔄 Loading HuggingFace model: {hf_name} on {self.device}")

            from transformers import AutoModel, AutoTokenizer, AutoProcessor

            # Different loading strategies based on model type
            if "GOT-OCR" in hf_name or "got-ocr" in hf_name.lower():
                # GOT-OCR2.0 specific loading: custom remote code; fp16 and
                # device_map on GPU to reduce VRAM pressure.
                tokenizer = AutoTokenizer.from_pretrained(hf_name, trust_remote_code=True)
                model = AutoModel.from_pretrained(
                    hf_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                    device_map="auto" if self.device == "cuda" else None,
                    low_cpu_mem_usage=True
                )
                self.hf_processors[model_name] = tokenizer
                self.hf_models[model_name] = model

            elif "donut" in hf_name.lower():
                # Donut model loading (vision encoder-decoder document model)
                from transformers import DonutProcessor, VisionEncoderDecoderModel
                processor = DonutProcessor.from_pretrained(hf_name)
                model = VisionEncoderDecoderModel.from_pretrained(
                    hf_name,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )
                if self.device == "cuda":
                    model = model.cuda()
                model.eval()  # inference only
                self.hf_processors[model_name] = processor
                self.hf_models[model_name] = model

            elif "trocr" in hf_name.lower():
                # TrOCR model loading (text-line recognition)
                from transformers import TrOCRProcessor, VisionEncoderDecoderModel
                processor = TrOCRProcessor.from_pretrained(hf_name)
                model = VisionEncoderDecoderModel.from_pretrained(
                    hf_name,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )
                if self.device == "cuda":
                    model = model.cuda()
                model.eval()  # inference only
                self.hf_processors[model_name] = processor
                self.hf_models[model_name] = model

            elif "flux" in hf_name.lower() or model_info.type == "image_generation":
                # FLUX / Diffusion model loading
                logger.info(f"🎨 Loading diffusion model: {hf_name}")
                from diffusers import AutoPipelineForText2Image

                pipeline = AutoPipelineForText2Image.from_pretrained(
                    hf_name,
                    torch_dtype=torch.bfloat16,
                    use_safetensors=True
                )
                pipeline.to(self.device)
                pipeline.enable_model_cpu_offload()  # Optimize VRAM usage

                self.hf_models[model_name] = pipeline
                self.hf_processors[model_name] = None  # No separate processor for diffusion
                logger.info(f"✅ Diffusion model loaded: {model_name} with CPU offload enabled")

            else:
                # Generic loading via Auto* classes with remote code allowed.
                processor = AutoProcessor.from_pretrained(hf_name, trust_remote_code=True)
                model = AutoModel.from_pretrained(
                    hf_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
                )
                if self.device == "cuda":
                    model = model.cuda()
                self.hf_processors[model_name] = processor
                self.hf_models[model_name] = model

            logger.info(f"✅ HuggingFace model loaded: {model_name}")
            return True

        except Exception as e:
            logger.error(f"❌ Failed to load HuggingFace model {model_name}: {e}", exc_info=True)
            return False
|
|
|
|
async def _unload_hf_model(self, model_name: str) -> bool:
|
|
"""Unload a HuggingFace model to free memory"""
|
|
try:
|
|
if model_name in self.hf_models:
|
|
del self.hf_models[model_name]
|
|
if model_name in self.hf_processors:
|
|
del self.hf_processors[model_name]
|
|
|
|
# Force garbage collection
|
|
if TORCH_AVAILABLE:
|
|
import gc
|
|
gc.collect()
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
|
|
logger.info(f"✅ HuggingFace model unloaded: {model_name}")
|
|
return True
|
|
except Exception as e:
|
|
logger.error(f"❌ Failed to unload HuggingFace model {model_name}: {e}")
|
|
return False
|
|
|
|
    async def ocr_process(self, model_name: str, image_data: bytes, ocr_type: str = "ocr") -> Dict[str, Any]:
        """Run OCR on an image with a HuggingFace-backed model.

        Lazily loads the requested model on first use, swapping out any other
        resident OCR model first to save VRAM, then runs the inference path
        matching the model family (GOT-OCR2.0 / Donut / TrOCR / generic).

        Args:
            model_name: Registry key of an HF-backed OCR model.
            image_data: Raw image bytes (any format PIL can open).
            ocr_type: Task hint forwarded to GOT-OCR's chat() call.

        Returns:
            Dict with "success", "model", "text" and "device".

        Raises:
            HTTPException: 503 when PyTorch is unavailable.
            ValueError: Unknown model, wrong backend, or load failure.
        """
        if not TORCH_AVAILABLE:
            raise HTTPException(status_code=503, detail="PyTorch not available")

        try:
            # Ensure model is loaded
            if model_name not in self.models:
                raise ValueError(f"OCR model not found: {model_name}")

            model_info = self.models[model_name]

            if model_info.backend != ModelBackend.HUGGINGFACE:
                raise ValueError(f"Model {model_name} is not an OCR model")

            # Lazy load model if not loaded
            if model_name not in self.hf_models:
                # Unload current OCR model if different (to save VRAM)
                if self.active_ocr_model and self.active_ocr_model != model_name:
                    await self._unload_hf_model(self.active_ocr_model)
                    if self.active_ocr_model in self.models:
                        self.models[self.active_ocr_model].status = ModelStatus.UNLOADED

                logger.info(f"🔄 Lazy loading OCR model: {model_name}")
                model_info.status = ModelStatus.LOADING

                success = await self._load_hf_model(model_name)
                if not success:
                    model_info.status = ModelStatus.ERROR
                    raise ValueError(f"Failed to load OCR model: {model_name}")

                model_info.status = ModelStatus.LOADED
                model_info.loaded_at = datetime.now()
                self.active_ocr_model = model_name
                self.model_load_times[model_name] = datetime.now()

            # Process image: normalize to RGB regardless of the input format.
            image = Image.open(BytesIO(image_data)).convert("RGB")
            model = self.hf_models[model_name]
            processor = self.hf_processors[model_name]

            model_info.request_count += 1
            hf_name = model_info.hf_name

            # Different processing based on model type
            if "GOT-OCR" in hf_name or "got-ocr" in hf_name.lower():
                # GOT-OCR2.0 processing: remote-code model exposes chat().
                with torch.no_grad():
                    result = model.chat(processor, image, ocr_type=ocr_type)
                text = result if isinstance(result, str) else str(result)

            elif "donut" in hf_name.lower():
                # Donut processing
                task_prompt = "<s_cord-v2>"  # or "<s_docvqa>" depending on task
                decoder_input_ids = processor.tokenizer(
                    task_prompt, add_special_tokens=False, return_tensors="pt"
                ).input_ids

                pixel_values = processor(image, return_tensors="pt").pixel_values
                if self.device == "cuda":
                    pixel_values = pixel_values.cuda()
                    decoder_input_ids = decoder_input_ids.cuda()

                with torch.no_grad():
                    outputs = model.generate(
                        pixel_values,
                        decoder_input_ids=decoder_input_ids,
                        max_length=model.decoder.config.max_position_embeddings,
                        pad_token_id=processor.tokenizer.pad_token_id,
                        eos_token_id=processor.tokenizer.eos_token_id,
                        use_cache=True,
                        bad_words_ids=[[processor.tokenizer.unk_token_id]],
                        return_dict_in_generate=True
                    )

                sequence = processor.batch_decode(outputs.sequences)[0]
                # NOTE(review): token2json returns a structured object, so
                # "text" may not be a plain string for Donut — confirm callers
                # handle that.
                text = processor.token2json(sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, ""))

            elif "trocr" in hf_name.lower():
                # TrOCR processing
                pixel_values = processor(images=image, return_tensors="pt").pixel_values
                if self.device == "cuda":
                    pixel_values = pixel_values.cuda()

                with torch.no_grad():
                    generated_ids = model.generate(pixel_values, max_length=512)

                text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
            else:
                # Generic processing (try chat method)
                with torch.no_grad():
                    if hasattr(model, 'chat'):
                        result = model.chat(processor, image)
                        text = result if isinstance(result, str) else str(result)
                    else:
                        text = "Model does not support direct inference"

            return {
                "success": True,
                "model": model_name,
                "text": text,
                "device": self.device
            }

        except Exception as e:
            logger.error(f"❌ OCR processing failed: {e}", exc_info=True)
            raise
|
|
|
|
    async def image_generate(
        self,
        model_name: str,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 50,
        guidance_scale: float = 4.0,
        width: int = 1024,
        height: int = 1024
    ) -> Dict[str, Any]:
        """Generate image with diffusion model (lazy loaded).

        The requested pipeline is loaded on first use; if a *different*
        image model is currently active it is unloaded first to free VRAM.
        Returns a dict with the generated PNG encoded as base64 plus
        timing/device metadata.

        Args:
            model_name: Key into self.models; must have type "image_generation".
            prompt: Positive text prompt.
            negative_prompt: Skipped for FLUX models (unsupported by those pipelines).
            num_inference_steps: Number of diffusion denoising steps.
            guidance_scale: Classifier-free guidance strength.
            width: Output width in pixels.
            height: Output height in pixels.

        Raises:
            HTTPException: 503 when PyTorch is not installed.
            ValueError: Unknown model, wrong model type, or load failure.
        """
        if not TORCH_AVAILABLE:
            raise HTTPException(status_code=503, detail="PyTorch not available")

        import time
        start_time = time.time()

        try:
            # Ensure model exists
            if model_name not in self.models:
                raise ValueError(f"Image model not found: {model_name}")

            model_info = self.models[model_name]

            if model_info.type != "image_generation":
                raise ValueError(f"Model {model_name} is not an image generation model")

            # Lazy load model if not loaded
            if model_name not in self.hf_models:
                # Unload current image model if different (to save VRAM)
                if self.active_image_model and self.active_image_model != model_name:
                    logger.info(f"🔄 Unloading current image model: {self.active_image_model}")
                    await self._unload_hf_model(self.active_image_model)
                    if self.active_image_model in self.models:
                        self.models[self.active_image_model].status = ModelStatus.UNLOADED

                logger.info(f"🎨 Lazy loading image model: {model_name}")
                model_info.status = ModelStatus.LOADING

                success = await self._load_hf_model(model_name)
                if not success:
                    model_info.status = ModelStatus.ERROR
                    raise ValueError(f"Failed to load image model: {model_name}")

                # Bookkeeping for status/metrics: mark loaded and record when.
                model_info.status = ModelStatus.LOADED
                model_info.loaded_at = datetime.now()
                self.active_image_model = model_name
                self.model_load_times[model_name] = datetime.now()

            # Generate image
            pipeline = self.hf_models[model_name]
            model_info.request_count += 1

            logger.info(f"🎨 Generating image with {model_name}: {prompt[:50]}...")

            with torch.no_grad():
                # FLUX Klein doesn't support negative_prompt, check pipeline type
                pipeline_kwargs = {
                    "prompt": prompt,
                    "num_inference_steps": num_inference_steps,
                    "guidance_scale": guidance_scale,
                    "width": width,
                    "height": height,
                }

                # Only add negative_prompt for models that support it (not FLUX)
                is_flux = "flux" in model_name.lower()
                if negative_prompt and not is_flux:
                    pipeline_kwargs["negative_prompt"] = negative_prompt

                result = pipeline(**pipeline_kwargs)
                image = result.images[0]

            # Convert to base64 so the image can travel in a JSON response.
            buffered = BytesIO()
            image.save(buffered, format="PNG")
            img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

            generation_time_ms = (time.time() - start_time) * 1000
            logger.info(f"✅ Image generated in {generation_time_ms:.0f}ms")

            return {
                "success": True,
                "model": model_name,
                "image_base64": img_base64,
                "width": width,
                "height": height,
                "generation_time_ms": generation_time_ms,
                "device": self.device
            }

        except Exception as e:
            logger.error(f"❌ Image generation failed: {e}", exc_info=True)
            raise
|
|
|
|
# ========== FastAPI App ==========
|
|
|
|
# FastAPI application instance for the Swapper service.
app = FastAPI(
    title="Swapper Service",
    description="Dynamic model loading service for Node #2",
    version="1.0.0"
)

# CORS: wide-open configuration.
# NOTE(review): per the CORS spec, browsers reject credentialed requests when
# the allowed origin is the wildcard "*" — allow_origins=["*"] together with
# allow_credentials=True will not actually send credentials; confirm whether
# credentials are needed, and if so, list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Include cabinet API router (import after swapper is created)
# NOTE(review): the comment above says "after swapper is created", but the
# router is imported/included *before* `swapper` is defined below. If
# app.cabinet_api imports `swapper` from this module at import time, that
# import would fail and be silently swallowed by the ImportError handler —
# verify cabinet_api's import behavior.
try:
    from app.cabinet_api import router as cabinet_router
    app.include_router(cabinet_router)
    logger.info("✅ Cabinet API router included")
except ImportError:
    logger.warning("⚠️ cabinet_api module not found, skipping cabinet router")

# Global Swapper instance used by all endpoints below.
swapper = SwapperService()
|
|
|
|
# NOTE(review): @app.on_event is deprecated in newer FastAPI in favor of
# lifespan handlers — fine while the installed version still supports it.
@app.on_event("startup")
async def startup():
    """Initialize Swapper on startup."""
    await swapper.initialize()
|
|
|
|
@app.on_event("shutdown")
async def shutdown():
    """Close Swapper on shutdown (releases backend resources)."""
    await swapper.close()
|
|
|
|
# ========== API Endpoints ==========
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: reports service health plus active model and mode."""
    current = await swapper.get_status()
    payload = {
        "status": "healthy",
        "service": "swapper-service",
        "active_model": current.active_model,
        "mode": current.mode,
    }
    return payload
|
|
|
|
@app.get("/status", response_model=SwapperStatus)
async def get_status():
    """Get the full Swapper service status (delegates to SwapperService)."""
    return await swapper.get_status()
|
|
|
|
@app.get("/models")
async def list_models():
    """List all available models with their type, size, priority and status."""
    entries = []
    for m in swapper.models.values():
        entries.append({
            "name": m.name,
            "ollama_name": m.ollama_name,
            "type": m.type,
            "size_gb": m.size_gb,
            "priority": m.priority,
            "status": m.status.value,
        })
    return {"models": entries}
|
|
|
|
@app.get("/models/{model_name}")
async def get_model_info(model_name: str):
    """Describe a single model, including load timestamps and total uptime."""
    if model_name not in swapper.models:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")

    info = swapper.models[model_name]
    # Timestamps may be None if the model was never (un)loaded.
    loaded_iso = info.loaded_at.isoformat() if info.loaded_at else None
    unloaded_iso = info.unloaded_at.isoformat() if info.unloaded_at else None

    return {
        "name": info.name,
        "ollama_name": info.ollama_name,
        "type": info.type,
        "size_gb": info.size_gb,
        "priority": info.priority,
        "status": info.status.value,
        "loaded_at": loaded_iso,
        "unloaded_at": unloaded_iso,
        "total_uptime_seconds": swapper.model_uptime.get(model_name, 0.0),
    }
|
|
|
|
@app.post("/models/{model_name}/load")
async def load_model_endpoint(model_name: str):
    """Load a model; 500 if the swapper reports failure."""
    if not await swapper.load_model(model_name):
        raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}")
    return {"status": "success", "model": model_name, "message": f"Model {model_name} loaded"}
|
|
|
|
@app.post("/models/{model_name}/unload")
async def unload_model_endpoint(model_name: str):
    """Unload a model; 500 if the swapper reports failure."""
    if not await swapper.unload_model(model_name):
        raise HTTPException(status_code=500, detail=f"Failed to unload model: {model_name}")
    return {"status": "success", "model": model_name, "message": f"Model {model_name} unloaded"}
|
|
|
|
@app.get("/metrics")
async def get_metrics(model_name: Optional[str] = None):
    """Get metrics for one model (when model_name is given) or for all models."""
    collected = await swapper.get_model_metrics(model_name)
    serialized = [entry.dict() for entry in collected]
    return {"metrics": serialized}
|
|
|
|
@app.get("/metrics/{model_name}")
async def get_model_metrics(model_name: str):
    """Get metrics for a specific model; 404 when it is unknown."""
    found = await swapper.get_model_metrics(model_name)
    if not found:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")
    return found[0].dict()
|
|
|
|
# ========== Chat Completions API (OpenAI-compatible) ==========
|
|
|
|
class ChatMessage(BaseModel):
    """A single chat turn in OpenAI message format."""
    role: str  # system, user, assistant
    content: str  # message text
|
|
|
|
class ChatCompletionRequest(BaseModel):
    """Chat completion request (OpenAI-compatible)."""
    model: str  # swapper model name (see /models)
    messages: List[ChatMessage]  # conversation; only system/user roles are consumed downstream
    max_tokens: Optional[int] = 2048
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False  # passed through to swapper.generate
|
|
|
|
class ChatCompletionResponse(BaseModel):
    """Chat completion response (OpenAI-compatible)."""
    id: str  # e.g. "chatcmpl-<unix timestamp>"
    object: str = "chat.completion"
    created: int  # unix timestamp
    model: str  # echoes the requested model name
    choices: List[Dict[str, Any]]  # OpenAI-style choice dicts
    usage: Dict[str, int]  # prompt/completion/total token counts
|
|
|
|
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint.

    Collapses the message list into a single prompt: the last system message
    becomes the system prompt, user messages are joined with newlines, and
    the result is delegated to swapper.generate().

    NOTE(review): assistant-role messages are dropped, so multi-turn history
    is not forwarded to the model — confirm whether that is intended.
    """
    import time

    try:
        # Extract system prompt and user messages
        system_prompt = None
        user_messages = []

        for msg in request.messages:
            if msg.role == "system":
                system_prompt = msg.content
            elif msg.role == "user":
                user_messages.append(msg.content)

        # Combine user messages into prompt
        prompt = "\n".join(user_messages)

        # Fix: use explicit None checks instead of `or`, so legitimate falsy
        # values (e.g. temperature=0.0 for greedy decoding) are not silently
        # replaced by the defaults.
        max_tokens = request.max_tokens if request.max_tokens is not None else 2048
        temperature = request.temperature if request.temperature is not None else 0.7

        # Generate response
        result = await swapper.generate(
            model_name=request.model,
            prompt=prompt,
            system_prompt=system_prompt,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=request.stream or False
        )

        # Format response in OpenAI style
        return ChatCompletionResponse(
            id=f"chatcmpl-{int(time.time())}",
            created=int(time.time()),
            model=request.model,
            choices=[{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": result["response"]
                },
                "finish_reason": "stop" if result.get("done", True) else None
            }],
            usage={
                "prompt_tokens": result.get("prompt_eval_count", 0),
                "completion_tokens": result.get("eval_count", 0),
                "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0)
            }
        )
    except Exception as e:
        logger.error(f"❌ Error in chat completions: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
# ========== Generate API (Ollama-compatible) ==========
|
|
|
|
class GenerateRequest(BaseModel):
    """Generate request (Ollama-compatible)."""
    model: str  # swapper model name
    prompt: str  # raw prompt text
    system: Optional[str] = None  # optional system prompt
    max_tokens: Optional[int] = 2048
    temperature: Optional[float] = 0.7
    stream: Optional[bool] = False
|
|
|
|
@app.post("/generate")
async def generate(request: GenerateRequest):
    """Ollama-compatible generate endpoint.

    Thin wrapper around swapper.generate(); returns its result unchanged.
    """
    try:
        # Fix: explicit None checks instead of `or`, so temperature=0.0
        # (greedy decoding) is not silently replaced by the 0.7 default.
        max_tokens = request.max_tokens if request.max_tokens is not None else 2048
        temperature = request.temperature if request.temperature is not None else 0.7

        result = await swapper.generate(
            model_name=request.model,
            prompt=request.prompt,
            system_prompt=request.system,
            max_tokens=max_tokens,
            temperature=temperature,
            stream=request.stream or False
        )
        return result
    except Exception as e:
        logger.error(f"❌ Error in generate: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
# ========== VISION API Endpoints ==========
|
|
|
|
class VisionRequest(BaseModel):
    """Vision (image description) request."""
    model: str = "qwen3-vl-8b"
    # Default prompt is Ukrainian: "Describe this image briefly (2-3 sentences)."
    prompt: str = "Опиши це зображення коротко (2-3 речення)."
    images: List[str]  # List of base64 encoded images (can include data: prefix)
    system: Optional[str] = None  # optional system prompt forwarded to Ollama
    max_tokens: int = 1024
    temperature: float = 0.7
|
|
|
|
@app.post("/vision")
async def vision_endpoint(request: VisionRequest):
    """
    Vision endpoint - analyze images with Vision-Language models.

    Models:
    - qwen3-vl-8b: Qwen3 Vision-Language model (8GB VRAM)

    Images should be base64 encoded. Can include data:image/... prefix or raw base64.
    """
    try:
        import time
        start_time = time.time()

        model_name = request.model

        # Convert data URLs to raw base64 (Ollama expects base64 without prefix)
        processed_images = []
        for img in request.images:
            if img.startswith("data:"):
                # Extract base64 part from data URL
                base64_part = img.split(",", 1)[1] if "," in img else img
                processed_images.append(base64_part)
            else:
                processed_images.append(img)

        logger.info(f"🖼️ Vision request: model={model_name}, images={len(processed_images)}, prompt={request.prompt[:50]}...")

        # Map model name to Ollama model.
        # NOTE(review): str.replace swaps *every* dash, so "qwen3-vl-8b" would
        # become "qwen3:vl:8b", not "qwen3:vl-8b" — the explicit override below
        # handles the default model; confirm the generic mapping is correct for
        # any other deployed model names.
        ollama_model = model_name.replace("-", ":")
        if model_name == "qwen3-vl-8b":
            ollama_model = "qwen3-vl:8b"

        # Build Ollama request
        ollama_payload = {
            "model": ollama_model,
            "prompt": request.prompt,
            "images": processed_images,
            "stream": False,
            "options": {
                "num_predict": request.max_tokens,
                "temperature": request.temperature
            }
        }

        if request.system:
            ollama_payload["system"] = request.system

        # Send to Ollama
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=ollama_payload
            )

            if response.status_code != 200:
                logger.error(f"❌ Ollama vision error: {response.status_code} - {response.text[:200]}")
                raise HTTPException(status_code=500, detail=f"Ollama error: {response.status_code}")

            result = response.json()
            vision_text = result.get("response", "")

            # Debug logging for silent model failures (empty responses).
            if not vision_text:
                logger.warning(f"⚠️ Empty response from Ollama! Result keys: {list(result.keys())}, error: {result.get('error', 'none')}")

            processing_time_ms = (time.time() - start_time) * 1000
            logger.info(f"✅ Vision response: {len(vision_text)} chars in {processing_time_ms:.0f}ms")

            return {
                "success": True,
                "model": model_name,
                "text": vision_text,
                "processing_time_ms": processing_time_ms,
                "images_count": len(processed_images)
            }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Vision endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/vision/models")
async def vision_models():
    """List available vision models."""
    entries = []
    for m in swapper.models.values():
        if m.type != "vision":
            continue
        entries.append({
            "name": m.name,
            "type": m.type,
            "status": m.status.value,
            "size_gb": m.size_gb,
        })
    return {"models": entries}
|
|
|
|
|
|
# ========== OCR API Endpoints ==========
|
|
|
|
class OCRRequest(BaseModel):
    """OCR request (JSON body alternative to the multipart upload)."""
    model: str = "got-ocr2"  # Default to GOT-OCR2.0
    image_base64: Optional[str] = None  # raw base64 image payload
    image_url: Optional[str] = None  # URL to fetch the image from
    ocr_type: str = "ocr"  # ocr, format, table
|
|
@app.post("/ocr")
async def ocr_endpoint(
    request: OCRRequest = None,
    file: UploadFile = File(None),
    model: str = Form("got-ocr2"),
    ocr_type: str = Form("ocr")
):
    """
    OCR endpoint - process images with OCR models.

    Models:
    - got-ocr2: Best for documents, tables, formulas (7GB VRAM)
    - donut-base: Document parsing without OCR (3GB VRAM)
    - donut-cord: Receipt/invoice parsing (3GB VRAM)
    - trocr-base: Fast printed text OCR (2GB VRAM)

    OCR Types (for GOT-OCR2.0):
    - ocr: Standard OCR
    - format: Preserve formatting
    - table: Extract tables

    NOTE(review): FastAPI does not support mixing a JSON body model with
    Form/File parameters in one operation — confirm which of the two input
    paths is actually exercised by clients.
    """
    try:
        image_data = None
        model_name = model
        ocr_type_param = ocr_type

        # Get image from request: multipart upload wins over JSON body.
        if file:
            image_data = await file.read()
        elif request:
            model_name = request.model
            ocr_type_param = request.ocr_type

            if request.image_base64:
                image_data = base64.b64decode(request.image_base64)
            elif request.image_url:
                async with httpx.AsyncClient() as client:
                    response = await client.get(request.image_url)
                    if response.status_code == 200:
                        image_data = response.content
                    else:
                        raise HTTPException(status_code=400, detail="Failed to download image")

        if not image_data:
            raise HTTPException(status_code=400, detail="No image provided")

        # Process with OCR
        result = await swapper.ocr_process(model_name, image_data, ocr_type_param)
        return result

    except HTTPException:
        # Fix: previously the generic handler below re-wrapped this
        # endpoint's own 400 HTTPExceptions as 500 errors.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"❌ OCR endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/ocr/models")
async def list_ocr_models():
    """List available OCR models (HuggingFace-backed, type == "ocr")."""
    entries = []
    for m in swapper.models.values():
        if m.backend != ModelBackend.HUGGINGFACE or m.type != "ocr":
            continue
        entries.append({
            "name": m.name,
            "hf_name": m.hf_name,
            "type": m.type,
            "size_gb": m.size_gb,
            "priority": m.priority,
            "status": m.status.value,
            "capabilities": m.capabilities,
            "loaded": m.name in swapper.hf_models,
        })
    return {
        "ocr_models": entries,
        "active_ocr_model": swapper.active_ocr_model,
        "device": swapper.device,
    }
|
|
|
|
@app.post("/ocr/models/{model_name}/load")
async def load_ocr_model(model_name: str):
    """Pre-load an OCR model (optional, models are lazy loaded by default)."""
    if model_name not in swapper.models:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")

    model_info = swapper.models[model_name]
    if model_info.backend != ModelBackend.HUGGINGFACE:
        raise HTTPException(status_code=400, detail="Not an OCR model")

    # Serialize loads against other (un)load operations.
    async with swapper.loading_lock:
        loaded_ok = await swapper._load_hf_model(model_name)
        if loaded_ok:
            model_info.status = ModelStatus.LOADED
            model_info.loaded_at = datetime.now()
            swapper.active_ocr_model = model_name
            swapper.model_load_times[model_name] = datetime.now()
            return {"status": "success", "model": model_name, "message": f"OCR model {model_name} loaded"}
    raise HTTPException(status_code=500, detail=f"Failed to load OCR model: {model_name}")
|
|
|
|
@app.post("/ocr/models/{model_name}/unload")
async def unload_ocr_model(model_name: str):
    """Unload an OCR model to free GPU memory."""
    if model_name not in swapper.hf_models:
        raise HTTPException(status_code=400, detail=f"Model not loaded: {model_name}")

    # Serialize unloads against other (un)load operations.
    async with swapper.loading_lock:
        unloaded_ok = await swapper._unload_hf_model(model_name)
        if unloaded_ok:
            tracked = swapper.models.get(model_name)
            if tracked is not None:
                tracked.status = ModelStatus.UNLOADED
            if swapper.active_ocr_model == model_name:
                swapper.active_ocr_model = None
            return {"status": "success", "model": model_name, "message": f"OCR model {model_name} unloaded"}
    raise HTTPException(status_code=500, detail=f"Failed to unload OCR model: {model_name}")
|
|
|
|
# ========== STT (Speech-to-Text) API Endpoints ==========
|
|
|
|
class STTRequest(BaseModel):
    """STT request.

    NOTE(review): the /stt endpoint currently accepts only multipart form
    fields; the audio_base64/audio_url fields here are not consumed by it.
    """
    model: str = "faster-whisper-large"
    audio_base64: Optional[str] = None
    audio_url: Optional[str] = None
    language: Optional[str] = None  # auto-detect if not specified
    task: str = "transcribe"  # transcribe or translate
|
|
|
@app.post("/stt")
async def stt_endpoint(
    file: UploadFile = File(None),
    model: str = Form("faster-whisper-large"),
    language: Optional[str] = Form(None),
    task: str = Form("transcribe")
):
    """
    Speech-to-Text endpoint using Faster Whisper.

    Models:
    - faster-whisper-large: Best quality, 99 languages (3GB VRAM)
    - whisper-small: Fast transcription (0.5GB VRAM)

    Accepts a multipart audio upload; returns the transcript, the detected
    language, and the device used. Raises 400 when no audio is supplied.
    """
    import tempfile
    import os

    try:
        audio_data = None
        if file:
            audio_data = await file.read()

        if not audio_data:
            raise HTTPException(status_code=400, detail="No audio provided")

        # Save audio to temp file (faster-whisper requires file path)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".ogg") as tmp_file:
            tmp_file.write(audio_data)
            tmp_path = tmp_file.name

        try:
            # Lazy load faster-whisper model
            stt_model = await _get_or_load_stt_model(model)

            # Transcribe
            logger.info(f"🎤 STT: Transcribing audio with {model}...")
            segments, info = stt_model.transcribe(
                tmp_path,
                language=language,
                task=task,
                beam_size=5,
                vad_filter=True,  # Remove silence
                vad_parameters=dict(min_silence_duration_ms=500)
            )

            # Collect all segments into one transcript string.
            text_parts = []
            for segment in segments:
                text_parts.append(segment.text.strip())

            full_text = " ".join(text_parts)
            detected_language = info.language if hasattr(info, 'language') else language or "unknown"

            logger.info(f"✅ STT: Transcribed successfully. Language: {detected_language}, Text: {full_text[:100]}...")

            return {
                "success": True,
                "model": model,
                "text": full_text,
                "language": detected_language,
                "device": swapper.device
            }

        finally:
            # Clean up temp file
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

    except HTTPException:
        # Fix: previously the generic handler below re-wrapped the 400
        # "No audio provided" HTTPException as a 500 error.
        raise
    except Exception as e:
        logger.error(f"❌ STT endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# STT model cache: lazy-loaded faster-whisper instances keyed by API model name.
_stt_models = {}

async def _get_or_load_stt_model(model_name: str):
    """Return a cached faster-whisper model, loading it on first use.

    Args:
        model_name: Public API model name (e.g. "faster-whisper-large");
            unknown names fall back to the "small" checkpoint.

    Returns:
        A loaded faster_whisper.WhisperModel instance.
    """
    # Cache hit: reuse the already-loaded model.
    # (No `global` needed — the module-level dict is mutated, never rebound.)
    if model_name in _stt_models:
        return _stt_models[model_name]

    from faster_whisper import WhisperModel

    # Map public model names to faster-whisper checkpoint sizes.
    model_map = {
        "faster-whisper-large": "large-v3",
        "faster-whisper-medium": "medium",
        "whisper-small": "small",
        "whisper-base": "base",
        "whisper-tiny": "tiny"
    }

    whisper_size = model_map.get(model_name, "small")

    logger.info(f"🔄 Loading STT model: {model_name} (size: {whisper_size})...")

    # Use GPU if available; int8 keeps CPU inference memory-friendly.
    device = "cuda" if swapper.device == "cuda" else "cpu"
    compute_type = "float16" if device == "cuda" else "int8"

    stt_model = WhisperModel(
        whisper_size,
        device=device,
        compute_type=compute_type
    )

    _stt_models[model_name] = stt_model
    logger.info(f"✅ STT model {model_name} loaded on {device}")

    return stt_model
|
|
|
|
@app.get("/stt/models")
async def list_stt_models():
    """List available STT models (HuggingFace-backed, type == "stt")."""
    entries = []
    for m in swapper.models.values():
        if m.backend != ModelBackend.HUGGINGFACE or m.type != "stt":
            continue
        entries.append({
            "name": m.name,
            "hf_name": m.hf_name,
            "type": m.type,
            "size_gb": m.size_gb,
            "priority": m.priority,
            "status": m.status.value,
            "capabilities": m.capabilities,
            "loaded": m.name in swapper.hf_models,
        })
    return {
        "stt_models": entries,
        "device": swapper.device,
    }
|
|
|
|
# ========== TTS (Text-to-Speech) API Endpoints ==========
|
|
|
|
class TTSRequest(BaseModel):
    """TTS request."""
    model: str = "xtts-v2"
    text: str  # text to synthesize
    language: str = "uk"  # Ukrainian by default
    speaker_wav_base64: Optional[str] = None  # For voice cloning
    speed: float = 1.0  # speed factor forwarded to tts_to_file
|
|
|
|
@app.post("/tts")
async def tts_endpoint(request: TTSRequest):
    """
    Text-to-Speech endpoint.

    Models:
    - xtts-v2: Best multilingual TTS with voice cloning (2GB VRAM)

    Languages: uk, en, es, fr, de, it, pt, pl, tr, ru, nl, cs, ar, zh-cn, ja, hu, ko

    Returns the synthesized WAV audio as base64. If speaker_wav_base64 is
    provided it is used as a reference voice for cloning.
    """
    try:
        model_name = request.model

        # Get model config
        tts_model_config = swapper.get_model_config(model_name, "tts")
        if not tts_model_config:
            raise HTTPException(status_code=400, detail=f"TTS model '{model_name}' not found")

        # Lazy-load the XTTS model on first request.
        if model_name not in swapper.hf_models:
            logger.info(f"🔊 Loading TTS model: {model_name}...")

            try:
                from TTS.api import TTS

                # XTTS-v2 model
                tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2")
                if swapper.device == "cuda":
                    tts = tts.to(swapper.device)

                swapper.hf_models[model_name] = tts
                logger.info(f"✅ TTS model {model_name} loaded on {swapper.device}")

            except ImportError:
                logger.error("❌ TTS library not installed. Run: pip install TTS")
                raise HTTPException(status_code=503, detail="TTS library not installed")
            except Exception as e:
                logger.error(f"❌ Failed to load TTS model: {e}")
                raise HTTPException(status_code=500, detail=f"Failed to load TTS model: {e}")

        tts_model = swapper.hf_models[model_name]

        # Generate speech
        import tempfile
        import time

        start_time = time.time()

        # Reserve an output path for the synthesized WAV.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
            tmp_path = tmp_file.name

        try:
            # Check if voice cloning is requested
            if request.speaker_wav_base64:
                # Decode speaker reference audio to a temp file.
                speaker_audio = base64.b64decode(request.speaker_wav_base64)
                with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as speaker_file:
                    speaker_file.write(speaker_audio)
                    speaker_path = speaker_file.name

                try:
                    # Generate with voice cloning
                    tts_model.tts_to_file(
                        text=request.text,
                        file_path=tmp_path,
                        speaker_wav=speaker_path,
                        language=request.language,
                        speed=request.speed
                    )
                finally:
                    # Fix: previously the reference file leaked when
                    # synthesis raised; always remove it.
                    if os.path.exists(speaker_path):
                        os.unlink(speaker_path)
            else:
                # Generate with default speaker
                tts_model.tts_to_file(
                    text=request.text,
                    file_path=tmp_path,
                    language=request.language,
                    speed=request.speed
                )

            # Read and encode audio
            with open(tmp_path, "rb") as f:
                audio_data = f.read()
            audio_base64 = base64.b64encode(audio_data).decode("utf-8")

            generation_time_ms = (time.time() - start_time) * 1000
            logger.info(f"✅ TTS generated in {generation_time_ms:.0f}ms, {len(audio_data)} bytes")

            return {
                "success": True,
                "model": model_name,
                "audio_base64": audio_base64,
                "language": request.language,
                "generation_time_ms": generation_time_ms,
                "device": swapper.device
            }

        finally:
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ TTS endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/tts/models")
async def list_tts_models():
    """List available TTS models (HuggingFace-backed, type == "tts")."""
    entries = []
    for m in swapper.models.values():
        if m.backend != ModelBackend.HUGGINGFACE or m.type != "tts":
            continue
        entries.append({
            "name": m.name,
            "hf_name": m.hf_name,
            "type": m.type,
            "size_gb": m.size_gb,
            "priority": m.priority,
            "status": m.status.value,
            "capabilities": m.capabilities,
            "loaded": m.name in swapper.hf_models,
        })
    return {
        "tts_models": entries,
        "device": swapper.device,
    }
|
|
|
|
# ========== Document Processing API Endpoints ==========
|
|
|
|
class DocumentRequest(BaseModel):
    """Document processing request.

    NOTE(review): the /document endpoint currently accepts only multipart
    form fields; this JSON body model is not wired to it.
    """
    model: str = "granite-docling"
    doc_base64: Optional[str] = None
    doc_url: Optional[str] = None
    output_format: str = "markdown"  # markdown, json, doctags
|
|
@app.post("/document")
async def document_endpoint(
    file: UploadFile = File(None),
    model: str = Form("granite-docling"),
    output_format: str = Form("markdown")
):
    """
    Document processing endpoint using Docling.

    Models:
    - granite-docling: IBM Granite for document structure (2.5GB VRAM)

    Output formats:
    - markdown: Clean markdown text
    - json: Structured JSON with document elements
    - text: Plain text extraction

    Supported files:
    PDF, DOCX, XLS/XLSX/XLSM/ODS, PPTX, TXT/MD/CSV/TSV, JSON/YAML/XML/HTML, RTF, ZIP, images.

    Processing strategy (in order):
    1. Deterministic text extractors for plain/office/zip formats (no model).
    2. Docling converter for everything else (lazy-initialized, cached).
    3. If Docling is missing: OCR for images, text extractors for office
       formats, pdfplumber for PDFs, else 503.
    """
    try:
        import time
        start_time = time.time()

        doc_data = None
        if file:
            doc_data = await file.read()

        if not doc_data:
            raise HTTPException(status_code=400, detail="No document provided")

        # Determine file type from the upload's extension (default "pdf").
        filename = file.filename if file else "document"
        file_ext = filename.split(".")[-1].lower() if "." in filename else "pdf"

        # Handle deterministic extraction for standard office/text formats
        if file_ext in [
            "txt", "md", "markdown", "csv", "tsv",
            "xlsx", "xls", "xlsm", "ods",
            "json", "yaml", "yml", "xml", "html", "htm", "rtf",
            "pptx", "zip",
        ]:
            try:
                if file_ext == "zip":
                    content = _zip_to_markdown(doc_data)
                    output_format = "markdown"
                else:
                    content = _extract_text_by_ext(filename, doc_data)
                    # Structured formats come back as markdown; txt/rtf as text.
                    output_format = (
                        "markdown"
                        if file_ext in {
                            "md", "markdown", "csv", "tsv",
                            "xlsx", "xls", "xlsm", "ods",
                            "json", "yaml", "yml", "xml", "html", "htm", "pptx",
                        }
                        else "text"
                    )
                processing_time_ms = (time.time() - start_time) * 1000
                return {
                    "success": True,
                    "model": "text-extract",
                    "output_format": output_format,
                    "result": content,
                    "filename": filename,
                    "processing_time_ms": processing_time_ms,
                    "device": swapper.device
                }
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"❌ Text extraction failed: {e}", exc_info=True)
                raise HTTPException(status_code=500, detail=f"Text extraction failed: {e}")

        # Save to temp file (Docling/pdfplumber operate on file paths).
        import tempfile
        with tempfile.NamedTemporaryFile(suffix=f".{file_ext}", delete=False) as tmp_file:
            tmp_file.write(doc_data)
            tmp_path = tmp_file.name

        try:
            # Try to use docling library
            try:
                from docling.document_converter import DocumentConverter
                from docling.datamodel.base_models import InputFormat

                # NOTE(review): the "(unknown)" literal below looks like a lost
                # f-string placeholder (probably {filename}) — confirm intent.
                logger.info(f"📄 Processing document: (unknown) ({len(doc_data)} bytes)")

                # Initialize converter (lazy load, cached in hf_models registry)
                if "docling_converter" not in swapper.hf_models:
                    logger.info("🔄 Initializing Docling converter...")
                    converter = DocumentConverter()
                    swapper.hf_models["docling_converter"] = converter
                    logger.info("✅ Docling converter initialized")

                converter = swapper.hf_models["docling_converter"]

                # Convert document
                result = converter.convert(tmp_path)
                doc = result.document

                # Format output (unknown formats fall back to markdown)
                if output_format == "markdown":
                    content = doc.export_to_markdown()
                elif output_format == "json":
                    content = doc.export_to_dict()
                elif output_format == "text":
                    content = doc.export_to_text()
                else:
                    content = doc.export_to_markdown()

                processing_time_ms = (time.time() - start_time) * 1000
                logger.info(f"✅ Document processed in {processing_time_ms:.0f}ms")

                return {
                    "success": True,
                    "model": model,
                    "output_format": output_format,
                    "result": content,
                    "filename": filename,
                    "processing_time_ms": processing_time_ms,
                    "device": swapper.device
                }

            except ImportError:
                # Fallback to pdfplumber/OCR for simpler extraction
                logger.warning("⚠️ Docling not installed, using fallback extraction")

                # For images, use OCR
                if file_ext in ["png", "jpg", "jpeg", "gif", "webp"]:
                    ocr_result = await swapper.ocr_process("got-ocr2", doc_data, "ocr")
                    return {
                        "success": True,
                        "model": "got-ocr2 (fallback)",
                        "output_format": "text",
                        "result": ocr_result.get("text", ""),
                        "filename": filename,
                        "processing_time_ms": (time.time() - start_time) * 1000,
                        "device": swapper.device
                    }

                # For common office/text formats, try deterministic extractors.
                if file_ext in {
                    "docx", "txt", "md", "markdown", "csv", "tsv",
                    "xlsx", "xls", "xlsm", "ods",
                    "pptx", "json", "yaml", "yml", "xml", "html", "htm", "rtf",
                }:
                    try:
                        content = _extract_text_by_ext(filename, doc_data)
                        out_fmt = "markdown" if file_ext not in {"txt", "rtf"} else "text"
                        return {
                            "success": True,
                            "model": "text-extract (fallback)",
                            "output_format": out_fmt,
                            "result": content,
                            "filename": filename,
                            "processing_time_ms": (time.time() - start_time) * 1000,
                            "device": swapper.device
                        }
                    except Exception as e:
                        logger.error(f"Text fallback failed for .{file_ext}: {e}")
                        raise HTTPException(status_code=500, detail=f"Extraction failed for .{file_ext}")

                # For PDFs, try pdfplumber
                if file_ext == "pdf":
                    try:
                        import pdfplumber
                        text_content = []
                        with pdfplumber.open(tmp_path) as pdf:
                            for page in pdf.pages:
                                text = page.extract_text()
                                if text:
                                    text_content.append(text)
                        content = "\n\n".join(text_content)
                        return {
                            "success": True,
                            "model": "pdfplumber (fallback)",
                            "output_format": "text",
                            "result": content,
                            "filename": filename,
                            "processing_time_ms": (time.time() - start_time) * 1000,
                            "device": "cpu"
                        }
                    except ImportError:
                        pass

                # For other documents, return error
                raise HTTPException(
                    status_code=503,
                    detail="Document processing unavailable for this type. Supported: office/text/image/zip standard formats."
                )

        finally:
            # Always remove the temp copy of the upload.
            if os.path.exists(tmp_path):
                os.unlink(tmp_path)

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Document endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/document/models")
async def list_document_models():
    """List available document processing models.

    Returns every registered HuggingFace model of type "document" together
    with its load status and the device the swapper is running on.
    """
    catalog = []
    for info in swapper.models.values():
        # Only HuggingFace-backed document models belong in this listing.
        if info.backend != ModelBackend.HUGGINGFACE or info.type != "document":
            continue
        catalog.append({
            "name": info.name,
            "hf_name": info.hf_name,
            "type": info.type,
            "size_gb": info.size_gb,
            "priority": info.priority,
            "status": info.status.value,
            "capabilities": info.capabilities,
            # "loaded" reflects whether the weights are resident right now.
            "loaded": info.name in swapper.hf_models,
        })
    return {
        "document_models": catalog,
        "device": swapper.device,
    }
|
|
|
|
# ========== Image Generation API Endpoints (FLUX) ==========
|
|
|
|
class ImageGenerateRequest(BaseModel):
    """Image generation request (body of POST /image/generate)."""
    model: str = "flux-klein-4b"  # registry name of the diffusion model to load
    prompt: str  # required positive prompt (no default)
    negative_prompt: str = ""  # concepts the sampler should steer away from
    num_inference_steps: int = 50  # denoising steps; more = slower, higher quality
    guidance_scale: float = 4.0  # classifier-free guidance strength
    width: int = 1024  # output width in pixels
    height: int = 1024  # output height in pixels
|
|
|
|
@app.post("/image/generate")
async def image_generate_endpoint(request: ImageGenerateRequest):
    """
    Generate image using diffusion model (lazy loaded).

    Models:
    - flux-klein-4b: FLUX.2 Klein 4B (15.4GB VRAM, lazy loaded on demand)

    The model is loaded on first request and unloaded when VRAM is needed
    for other models. Bad parameters surface as HTTP 400, everything else
    as HTTP 500.
    """
    try:
        # Delegate the whole generation (including lazy model loading) to
        # the swapper; the endpoint is a thin parameter-mapping shim.
        return await swapper.image_generate(
            model_name=request.model,
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            width=request.width,
            height=request.height,
        )
    except ValueError as e:
        # Swapper raises ValueError for invalid/unknown request parameters.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"❌ Image generate endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/image/models")
async def list_image_models():
    """List available image generation models.

    Mirrors /document/models but filters on type "image_generation" and
    additionally reports which image model is currently active.
    """
    listing = []
    for entry in swapper.models.values():
        if entry.backend != ModelBackend.HUGGINGFACE:
            continue
        if entry.type != "image_generation":
            continue
        listing.append({
            "name": entry.name,
            "hf_name": entry.hf_name,
            "type": entry.type,
            "size_gb": entry.size_gb,
            "priority": entry.priority,
            "status": entry.status.value,
            "capabilities": entry.capabilities,
            "loaded": entry.name in swapper.hf_models,
        })
    return {
        "image_models": listing,
        "active_image_model": swapper.active_image_model,
        "device": swapper.device,
    }
|
|
|
|
@app.post("/image/models/{model_name}/load")
async def load_image_model(model_name: str):
    """Pre-load an image generation model (optional, models are lazy loaded by default).

    404 for unknown names, 400 when the named model is not an image model,
    500 when loading fails.
    """
    if model_name not in swapper.models:
        raise HTTPException(status_code=404, detail=f"Model not found: {model_name}")

    registry_entry = swapper.models[model_name]
    if registry_entry.type != "image_generation":
        raise HTTPException(status_code=400, detail="Not an image generation model")

    # Serialize against concurrent load/unload operations.
    async with swapper.loading_lock:
        loaded = await swapper._load_hf_model(model_name)
        if loaded:
            registry_entry.status = ModelStatus.LOADED
            registry_entry.loaded_at = datetime.now()
            swapper.active_image_model = model_name
            swapper.model_load_times[model_name] = datetime.now()
            return {"status": "success", "model": model_name, "message": f"Image model {model_name} loaded"}
    raise HTTPException(status_code=500, detail=f"Failed to load image model: {model_name}")
|
|
|
|
@app.post("/image/models/{model_name}/unload")
async def unload_image_model(model_name: str):
    """Unload an image generation model to free GPU memory.

    400 if the model isn't currently resident, 500 if unloading fails.
    """
    if model_name not in swapper.hf_models:
        raise HTTPException(status_code=400, detail=f"Model not loaded: {model_name}")

    # Serialize against concurrent load/unload operations.
    async with swapper.loading_lock:
        unloaded = await swapper._unload_hf_model(model_name)
        if unloaded:
            # Keep registry bookkeeping consistent with the actual state.
            entry = swapper.models.get(model_name)
            if entry is not None:
                entry.status = ModelStatus.UNLOADED
            if swapper.active_image_model == model_name:
                swapper.active_image_model = None
            return {"status": "success", "model": model_name, "message": f"Image model {model_name} unloaded"}
    raise HTTPException(status_code=500, detail=f"Failed to unload image model: {model_name}")
|
|
|
|
# ========== Web Scraping API Endpoints ==========
|
|
|
|
class WebExtractRequest(BaseModel):
    """Web content extraction request (body of POST /web/extract)."""
    url: str  # page to extract content from
    method: str = "auto"  # auto, jina, trafilatura, crawl4ai
    include_links: bool = False  # forwarded to trafilatura.extract()
    include_images: bool = False  # forwarded to trafilatura.extract()
|
|
|
|
class WebSearchRequest(BaseModel):
    """Web search request (body of POST /web/search)."""
    query: str  # search terms
    max_results: int = 10  # cap on returned hits
    # NOTE(review): `engine` is accepted but currently ignored by /web/search,
    # which always tries DDGS first and falls back to Google — confirm intent.
    engine: str = "duckduckgo"  # duckduckgo, google
|
|
|
|
@app.post("/web/extract")
async def web_extract(request: WebExtractRequest):
    """
    Extract content from URL using multiple methods.

    Methods:
    - jina: Jina Reader API (free, JS support, cloud)
    - trafilatura: Local extraction (fast, no JS)
    - crawl4ai: Full crawling (JS support, local) - requires separate service
    - auto: Try jina first, fallback to trafilatura

    Returns a dict with ``success``, ``method``, ``content`` and ``url``.
    Raises HTTPException 400 for an unknown method and 500 when every
    attempted extractor fails.
    """
    url = request.url
    method = request.method

    async def extract_with_jina(url: str) -> dict:
        """Extract using Jina Reader API (free)"""
        try:
            jina_url = f"https://r.jina.ai/{url}"
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.get(jina_url)
                if response.status_code == 200:
                    return {
                        "success": True,
                        "method": "jina",
                        "content": response.text,
                        "url": url
                    }
                else:
                    return {"success": False, "error": f"Jina returned {response.status_code}"}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def extract_with_trafilatura(url: str) -> dict:
        """Extract using Trafilatura (local)"""
        try:
            import trafilatura
            downloaded = trafilatura.fetch_url(url)
            if downloaded:
                text = trafilatura.extract(
                    downloaded,
                    include_links=request.include_links,
                    include_images=request.include_images
                )
                # BUGFIX: trafilatura.extract() returns None when it cannot
                # locate main content; previously that was reported as
                # success with a null payload, which short-circuited the
                # "auto" fallback and returned empty content to callers.
                if text:
                    return {
                        "success": True,
                        "method": "trafilatura",
                        "content": text,
                        "url": url
                    }
                return {"success": False, "error": "No content extracted"}
            return {"success": False, "error": "Failed to download page"}
        except ImportError:
            return {"success": False, "error": "Trafilatura not installed"}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def extract_with_crawl4ai(url: str) -> dict:
        """Extract using Crawl4AI service"""
        try:
            crawl4ai_url = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235")
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.post(
                    f"{crawl4ai_url}/crawl",
                    json={"urls": [url], "word_count_threshold": 10}
                )
                if response.status_code == 200:
                    data = response.json()
                    result = data.get("results", [{}])[0]

                    # Get markdown - can be string or dict with raw_markdown
                    markdown_data = result.get("markdown", "")
                    if isinstance(markdown_data, dict):
                        content = markdown_data.get("raw_markdown", "") or markdown_data.get("fit_markdown", "")
                    else:
                        content = markdown_data

                    # Fallback to cleaned_html
                    if not content:
                        content = result.get("cleaned_html", "") or result.get("extracted_content", "")

                    # Last resort: strip HTML tags (re is imported at module level)
                    if not content and result.get("html"):
                        content = re.sub(r'<[^>]+>', ' ', result.get("html", ""))
                        content = re.sub(r'\s+', ' ', content).strip()

                    # Limit size for LLM context
                    if len(content) > 50000:
                        content = content[:50000] + "\n\n[... truncated ...]"

                    return {
                        "success": bool(content),
                        "method": "crawl4ai",
                        "content": content,
                        "url": url
                    }
                return {"success": False, "error": f"Crawl4AI returned {response.status_code}"}
        except Exception as e:
            return {"success": False, "error": str(e)}

    # Execute based on method
    if method == "jina":
        result = await extract_with_jina(url)
    elif method == "trafilatura":
        result = await extract_with_trafilatura(url)
    elif method == "crawl4ai":
        result = await extract_with_crawl4ai(url)
    elif method == "auto":
        # Try jina first (JS support), fallback to trafilatura
        result = await extract_with_jina(url)
        if not result.get("success"):
            logger.info(f"Jina failed, trying trafilatura for {url}")
            result = await extract_with_trafilatura(url)
    else:
        raise HTTPException(status_code=400, detail=f"Unknown method: {method}")

    if not result.get("success"):
        raise HTTPException(status_code=500, detail=result.get("error", "Extraction failed"))

    return result
|
|
|
|
@app.post("/web/search")
async def web_search(request: WebSearchRequest):
    """
    Search the web using multiple engines with fallback.
    Priority: 1) DDGS (DuckDuckGo) 2) Google Search
    """
    hits = []
    engine_used = "none"

    # Primary engine: DDGS (renamed duckduckgo_search package).
    try:
        from ddgs import DDGS
        raw_results = list(DDGS().text(
            request.query,
            max_results=request.max_results,
            region="wt-wt"  # Worldwide
        ))
        if raw_results:
            for idx, item in enumerate(raw_results):
                hits.append({
                    "position": idx + 1,
                    "title": item.get("title", ""),
                    "url": item.get("href", item.get("link", "")),
                    "snippet": item.get("body", item.get("snippet", ""))
                })
            engine_used = "ddgs"
            logger.info(f"✅ DDGS search found {len(hits)} results for: {request.query[:50]}")
    except ImportError:
        logger.warning("DDGS not installed, trying Google search")
    except Exception as e:
        logger.warning(f"DDGS search failed: {e}, trying Google search")

    # Secondary engine: Google gives bare URLs only, so titles are derived
    # from the last path segment and snippets are left empty.
    if not hits:
        try:
            from googlesearch import search as google_search
            urls = list(google_search(request.query, num_results=request.max_results, lang="uk"))
            if urls:
                for idx, hit_url in enumerate(urls):
                    hits.append({
                        "position": idx + 1,
                        "title": hit_url.split("/")[-1].replace("-", " ").replace("_", " ")[:60] or "Result",
                        "url": hit_url,
                        "snippet": ""
                    })
                engine_used = "google"
                logger.info(f"✅ Google search found {len(hits)} results for: {request.query[:50]}")
        except ImportError:
            logger.warning("Google search not installed")
        except Exception as e:
            logger.warning(f"Google search failed: {e}")

    # Return results or empty
    return {
        "success": len(hits) > 0,
        "query": request.query,
        "results": hits,
        "total": len(hits),
        "engine": engine_used
    }
|
|
|
|
@app.get("/web/read/{url:path}")
async def web_read_simple(url: str):
    """
    Simple GET endpoint to read a URL (uses Jina by default).
    Example: GET /web/read/https://example.com
    """
    # Thin wrapper: delegate to the full extractor with automatic method selection.
    return await web_extract(WebExtractRequest(url=url, method="auto"))
|
|
|
|
@app.get("/web/status")
async def web_status():
    """Check availability of web scraping methods.

    Probes the local extraction/search libraries and the Crawl4AI side-car
    service, and reports per-method availability.
    """
    # Check Trafilatura (local extraction)
    try:
        import trafilatura  # noqa: F401
        trafilatura_available = True
    except ImportError:
        trafilatura_available = False

    # Check DuckDuckGo search. /web/search prefers the new `ddgs` package,
    # so probe it first and fall back to the legacy `duckduckgo_search`
    # name (previously only the legacy name was checked, so the status
    # could report unavailable while search actually worked).
    try:
        from ddgs import DDGS  # noqa: F401
        ddgs_available = True
    except ImportError:
        try:
            from duckduckgo_search import DDGS  # noqa: F401
            ddgs_available = True
        except ImportError:
            ddgs_available = False

    # Check Crawl4AI side-car health endpoint (best-effort: any failure
    # simply means the service is unreachable).
    crawl4ai_url = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235")
    crawl4ai_available = False
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            response = await client.get(f"{crawl4ai_url}/health")
            crawl4ai_available = response.status_code == 200
    except Exception:
        pass

    return {
        "methods": {
            "jina": {"available": True, "type": "cloud", "js_support": True},
            "trafilatura": {"available": trafilatura_available, "type": "local", "js_support": False},
            "crawl4ai": {"available": crawl4ai_available, "type": "local", "js_support": True}
        },
        "search": {
            "duckduckgo": {"available": ddgs_available}
        }
    }
|
|
|
|
# ========== Video Generation API (Grok xAI) ==========
|
|
|
|
# xAI (Grok) credentials and endpoint; /video/generate returns 503 when the key is unset.
GROK_API_KEY = os.getenv("GROK_API_KEY", "")
GROK_API_URL = "https://api.x.ai/v1"
|
|
|
|
class VideoGenerateRequest(BaseModel):
    """Video generation request via Grok.

    NOTE(review): /video/generate forwards only `prompt` and `style` to the
    Grok API (which generates still images); `duration` and `aspect_ratio`
    are currently unused — confirm whether they should be dropped.
    """
    prompt: str  # scene description
    duration: int = 6  # seconds (max 6 for Grok)
    style: str = "cinematic"  # cinematic, anime, realistic, abstract
    aspect_ratio: str = "16:9"  # 16:9, 9:16, 1:1
|
|
|
|
@app.post("/video/generate")
async def video_generate(request: VideoGenerateRequest):
    """
    Generate image using Grok (xAI) API.

    Note: Grok API currently supports image generation only (not video).
    For video-like content, generate multiple frames and combine externally.

    Raises 503 when no API key is configured, 504 on timeout, the upstream
    status code on a Grok API error, and 500 on unexpected failures.
    """
    if not GROK_API_KEY:
        raise HTTPException(status_code=503, detail="GROK_API_KEY not configured")

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            # Grok image generation endpoint
            response = await client.post(
                f"{GROK_API_URL}/images/generations",
                headers={
                    "Authorization": f"Bearer {GROK_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "model": "grok-2-image-1212",  # Correct model name
                    "prompt": f"{request.prompt}, {request.style} style",
                    "n": 1,
                    "response_format": "url"
                }
            )

            if response.status_code == 200:
                data = response.json()
                return {
                    "success": True,
                    "prompt": request.prompt,
                    "style": request.style,
                    "type": "image",  # Note: video not available via API
                    "result": data,
                    "provider": "grok-xai",
                    "note": "Grok API supports image generation. Video generation is available only in xAI app."
                }
            else:
                logger.error(f"Grok API error: {response.status_code} - {response.text}")
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"Grok API error: {response.text}"
                )

    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Image generation timeout (>120s)")
    # BUGFIX: the upstream-error HTTPException raised above was previously
    # swallowed by the generic Exception handler and re-wrapped as a 500,
    # discarding the real Grok status code. Re-raise it untouched.
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Image generation error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
@app.get("/video/status")
async def video_status():
    """Check Grok image/video generation service status.

    Static capability report plus whether GROK_API_KEY is configured.
    """
    capabilities = {
        "image_generation": True,  # grok-2-image-1212
        "video_generation": False,  # Not available via API (only in xAI app)
        "vision_analysis": True  # grok-2-vision-1212
    }
    grok_models = {
        "image": "grok-2-image-1212",
        "vision": "grok-2-vision-1212",
        "chat": ["grok-3", "grok-3-mini", "grok-4-0709"]
    }
    return {
        "service": "grok-xai",
        "api_key_configured": bool(GROK_API_KEY),
        "capabilities": capabilities,
        "models": grok_models,
        "supported_styles": ["cinematic", "anime", "realistic", "abstract", "photorealistic"]
    }
|
|
|
|
|
|
# ========== Multimodal Stack Summary ==========
|
|
|
|
@app.get("/multimodal")
async def get_multimodal_stack():
    """Get full multimodal stack status.

    Summarizes every registered model grouped by type, plus the external
    video-generation provider and the currently active LLM/OCR models.
    """
    def _summaries(kind: str):
        # One summary dict per registered model of the given type.
        out = []
        for info in swapper.models.values():
            if info.type != kind:
                continue
            out.append({
                "name": info.name,
                "size_gb": info.size_gb,
                "status": info.status.value,
                "loaded": info.name in swapper.hf_models,
            })
        return out

    stack = {kind: _summaries(kind) for kind in (
        "llm", "vision", "math", "ocr", "document",
        "stt", "tts", "embedding", "image_generation",
    )}
    # Video generation is delegated to Grok (xAI), not a local model.
    stack["video_generation"] = {"provider": "grok-xai", "available": bool(GROK_API_KEY)}

    return {
        "device": swapper.device,
        "cuda_available": TORCH_AVAILABLE and torch.cuda.is_available(),
        "stack": stack,
        "active_models": {
            "llm": swapper.active_model,
            "ocr": swapper.active_ocr_model,
        },
    }
|
|
|
|
if __name__ == "__main__":
    # Run the service directly with uvicorn; binds on all interfaces, port 8890.
    # (Container deployments may launch via an external ASGI command instead.)
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8890)
|