""" Swapper Service - Dynamic Model Loading Service Manages loading/unloading LLM and OCR models on-demand to optimize memory usage. Supports: - Ollama models (LLM, Vision, Math) - HuggingFace models (OCR, Document Understanding) - Lazy loading for OCR models """ import os import asyncio import logging import base64 import json import re from typing import Optional, Dict, List, Any, Union from datetime import datetime, timedelta from enum import Enum from io import BytesIO import xml.etree.ElementTree as ET from fastapi import FastAPI, HTTPException, BackgroundTasks, File, UploadFile, Form from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import httpx import csv import zipfile from io import BytesIO import mimetypes import yaml # Optional imports for HuggingFace models try: import torch from PIL import Image TORCH_AVAILABLE = True except ImportError: TORCH_AVAILABLE = False torch = None Image = None logger = logging.getLogger(__name__) # ========== Document Helpers ========== def _decode_text_bytes(content: bytes) -> str: """Decode text with best-effort fallback.""" try: import chardet detected = chardet.detect(content) encoding = detected.get("encoding") or "utf-8" return content.decode(encoding, errors="replace") except Exception: try: return content.decode("utf-8", errors="replace") except Exception: return content.decode("latin-1", errors="replace") def _csv_to_markdown(content: bytes) -> str: text = _decode_text_bytes(content) reader = csv.reader(text.splitlines()) rows = list(reader) return _rows_to_markdown(rows) def _tsv_to_markdown(content: bytes) -> str: text = _decode_text_bytes(content) reader = csv.reader(text.splitlines(), delimiter="\t") rows = list(reader) return _rows_to_markdown(rows) def _rows_to_markdown(rows: List[List[Any]]) -> str: if not rows: return "" width = max(len(r) for r in rows) norm_rows = [] for r in rows: rr = [str(c) if c is not None else "" for c in r] if len(rr) < width: rr.extend([""] * (width - len(rr))) norm_rows.append(rr) header = norm_rows[0] body = norm_rows[1:] lines = [ "| " + " | ".join(header) + " |", "| " + " | ".join(["---"] * len(header)) + " |", ] for row in body: lines.append("| " + " | ".join([str(c) if c is not None else "" for c in row]) + " |") return "\n".join(lines) def _xlsx_to_markdown(content: bytes) -> str: try: import openpyxl except Exception as e: raise HTTPException(status_code=500, detail=f"openpyxl not available: {e}") wb = openpyxl.load_workbook(filename=BytesIO(content), data_only=True) parts = [] for sheet in wb.worksheets: parts.append(f"## Sheet: {sheet.title}") rows = list(sheet.iter_rows(values_only=True)) if not rows: parts.append("_Empty sheet_") continue header = [str(c) if c is not None else "" for c in rows[0]] body = rows[1:] parts.append("| " + " | ".join(header) + " |") parts.append("| " + " | ".join(["---"] * len(header)) + " |") for row in body: parts.append("| " + " | ".join([str(c) if c is not None else "" for c in row]) + " |") return "\n".join(parts) def _xls_to_markdown(content: bytes) -> str: try: import xlrd except Exception as e: raise HTTPException(status_code=500, detail=f"xlrd not available: {e}") wb = xlrd.open_workbook(file_contents=content) parts = [] for s in wb.sheets(): parts.append(f"## Sheet: {s.name}") rows = [] for r in range(s.nrows): rows.append([s.cell_value(r, c) for c in range(s.ncols)]) if not rows: parts.append("_Empty sheet_") continue parts.append(_rows_to_markdown(rows)) return "\n\n".join(parts) def _ods_to_markdown(content: bytes) -> str: try: from odf.opendocument import load from odf.table import Table, TableRow, TableCell from odf.text import P except Exception as e: raise HTTPException(status_code=500, detail=f"odfpy not available: {e}") try: doc = load(BytesIO(content)) except Exception as e: raise HTTPException(status_code=400, detail=f"Invalid ODS file: {e}") parts = [] for table in doc.spreadsheet.getElementsByType(Table): table_name = str(table.getAttribute("name") or "Sheet") parts.append(f"## Sheet: {table_name}") rows: List[List[str]] = [] for row in table.getElementsByType(TableRow): cells_out: List[str] = [] for cell in row.getElementsByType(TableCell): txt_parts = [] for p in cell.getElementsByType(P): txt_parts.extend( [str(getattr(node, "data", "")).strip() for node in p.childNodes if getattr(node, "data", None)] ) cell_text = " ".join([t for t in txt_parts if t]).strip() repeat_raw = cell.getAttribute("numbercolumnsrepeated") try: repeat = int(repeat_raw) if repeat_raw else 1 except Exception: repeat = 1 repeat = max(1, min(repeat, 100)) for _ in range(repeat): cells_out.append(cell_text) if cells_out: rows.append(cells_out) if not rows: parts.append("_Empty sheet_") continue parts.append(_rows_to_markdown(rows)) return "\n\n".join(parts) def _docx_to_text(content: bytes) -> str: try: from docx import Document except Exception as e: raise HTTPException(status_code=500, detail=f"python-docx not available: {e}") doc = Document(BytesIO(content)) lines = [p.text for p in doc.paragraphs if p.text] return "\n".join(lines) def _pdf_to_text(content: bytes) -> str: try: import pdfplumber except Exception as e: raise HTTPException(status_code=500, detail=f"pdfplumber not available: {e}") text_content = [] with pdfplumber.open(BytesIO(content)) as pdf: for page in pdf.pages: page_text = page.extract_text() or "" if page_text: text_content.append(page_text) return "\n\n".join(text_content) def _pptx_to_text(content: bytes) -> str: try: from pptx import Presentation except Exception as e: raise HTTPException(status_code=500, detail=f"python-pptx not available: {e}") prs = Presentation(BytesIO(content)) parts = [] for idx, slide in enumerate(prs.slides, start=1): parts.append(f"## Slide {idx}") slide_lines = [] for shape in slide.shapes: text = getattr(shape, "text", None) if text and str(text).strip(): slide_lines.append(str(text).strip()) parts.extend(slide_lines if slide_lines else ["_No text on this slide_"]) return "\n\n".join(parts) def _json_to_text(content: bytes) -> str: raw = _decode_text_bytes(content) try: parsed = json.loads(raw) return json.dumps(parsed, ensure_ascii=False, indent=2) except Exception: return raw def _yaml_to_text(content: bytes) -> str: raw = _decode_text_bytes(content) try: parsed = yaml.safe_load(raw) return yaml.safe_dump(parsed, allow_unicode=True, sort_keys=False) except Exception: return raw def _xml_to_text(content: bytes) -> str: raw = _decode_text_bytes(content) try: root = ET.fromstring(raw) text = " ".join([t.strip() for t in root.itertext() if t and t.strip()]) return text or raw except Exception: return raw def _html_to_text(content: bytes) -> str: raw = _decode_text_bytes(content) try: from bs4 import BeautifulSoup soup = BeautifulSoup(raw, "html.parser") text = soup.get_text(separator="\n") text = re.sub(r"\n{3,}", "\n\n", text) return text.strip() or raw except Exception: # Minimal fallback if bs4 is unavailable text = re.sub(r"<[^>]+>", " ", raw) text = re.sub(r"\s+", " ", text) return text.strip() def _rtf_to_text(content: bytes) -> str: raw = _decode_text_bytes(content) try: from striprtf.striprtf import rtf_to_text return rtf_to_text(raw) except Exception: # Basic fallback: strip common RTF control tokens text = re.sub(r"\\'[0-9a-fA-F]{2}", " ", raw) text = re.sub(r"\\[a-zA-Z]+-?\d* ?", " ", text) text = text.replace("{", " ").replace("}", " ") return re.sub(r"\s+", " ", text).strip() def _extract_text_by_ext(filename: str, content: bytes) -> str: ext = filename.split(".")[-1].lower() if "." in filename else "" if ext in ["txt", "md", "markdown"]: return _decode_text_bytes(content) if ext == "csv": return _csv_to_markdown(content) if ext == "tsv": return _tsv_to_markdown(content) if ext in {"xlsx", "xlsm"}: return _xlsx_to_markdown(content) if ext == "xls": return _xls_to_markdown(content) if ext == "ods": return _ods_to_markdown(content) if ext == "docx": return _docx_to_text(content) if ext == "pdf": return _pdf_to_text(content) if ext == "pptx": return _pptx_to_text(content) if ext == "json": return _json_to_text(content) if ext in {"yaml", "yml"}: return _yaml_to_text(content) if ext == "xml": return _xml_to_text(content) if ext in {"html", "htm"}: return _html_to_text(content) if ext == "rtf": return _rtf_to_text(content) raise HTTPException(status_code=400, detail=f"Unsupported file type: .{ext}") def _zip_to_markdown(content: bytes, max_files: int = 50, max_total_mb: int = 100) -> str: zf = zipfile.ZipFile(BytesIO(content)) members = [m for m in zf.infolist() if not m.is_dir()] if len(members) > max_files: raise HTTPException(status_code=400, detail=f"ZIP has слишком много файлов: {len(members)}") total_size = sum(m.file_size for m in members) if total_size > max_total_mb * 1024 * 1024: raise HTTPException(status_code=400, detail=f"ZIP слишком большой: {total_size / 1024 / 1024:.1f} MB") parts = [] allowed_exts = { "txt", "md", "markdown", "csv", "tsv", "xls", "xlsx", "xlsm", "ods", "docx", "pdf", "pptx", "json", "yaml", "yml", "xml", "html", "htm", "rtf", } processed = [] skipped = [] for member in members: name = member.filename ext = name.split(".")[-1].lower() if "." in name else "" if ext not in allowed_exts: skipped.append(name) parts.append(f"## {name}\n_Skipped unsupported file type_") continue file_bytes = zf.read(member) extracted = _extract_text_by_ext(name, file_bytes) processed.append(name) parts.append(f"## {name}\n{extracted}") header_lines = ["# ZIP summary", "Processed files:"] header_lines.extend([f"- {name}" for name in processed] or ["- (none)"]) if skipped: header_lines.append("Skipped files:") header_lines.extend([f"- {name}" for name in skipped]) return "\n\n".join(["\n".join(header_lines), *parts]) # ========== Configuration ========== OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") SWAPPER_CONFIG_PATH = os.getenv("SWAPPER_CONFIG_PATH", "./config/swapper_config.yaml") SWAPPER_MODE = os.getenv("SWAPPER_MODE", "single-active") # single-active or multi-active MAX_CONCURRENT_MODELS = int(os.getenv("MAX_CONCURRENT_MODELS", "1")) MODEL_SWAP_TIMEOUT = int(os.getenv("MODEL_SWAP_TIMEOUT", "30")) # ========== Models ========== class ModelStatus(str, Enum): """Model status""" LOADED = "loaded" LOADING = "loading" UNLOADED = "unloaded" UNLOADING = "unloading" ERROR = "error" class ModelBackend(str, Enum): """Model backend type""" OLLAMA = "ollama" HUGGINGFACE = "huggingface" class ModelInfo(BaseModel): """Model information""" name: str ollama_name: str # For Ollama models hf_name: Optional[str] = None # For HuggingFace models backend: ModelBackend = ModelBackend.OLLAMA type: str # llm, code, vision, math, ocr size_gb: float priority: str # high, medium, low status: ModelStatus capabilities: List[str] = [] # For OCR models loaded_at: Optional[datetime] = None unloaded_at: Optional[datetime] = None total_uptime_seconds: float = 0.0 request_count: int = 0 class SwapperStatus(BaseModel): """Swapper service status""" status: str active_model: Optional[str] = None available_models: List[str] loaded_models: List[str] mode: str total_models: int class ModelMetrics(BaseModel): """Model usage metrics""" model_name: str status: str loaded_at: Optional[datetime] = None uptime_hours: float request_count: int total_uptime_seconds: float # ========== Swapper Service ========== class SwapperService: """Swapper Service - manages model loading/unloading for Ollama and HuggingFace""" def __init__(self): self.models: Dict[str, ModelInfo] = {} self.active_model: Optional[str] = None # Active LLM model self.active_ocr_model: Optional[str] = None # Active OCR model (separate from LLM) self.active_image_model: Optional[str] = None # Active Image Generation model self.loading_lock = asyncio.Lock() self.http_client = httpx.AsyncClient(timeout=300.0) self.model_uptime: Dict[str, float] = {} # Track uptime per model self.model_load_times: Dict[str, datetime] = {} # Track when model was loaded # HuggingFace model instances (lazy loaded) self.hf_models: Dict[str, Any] = {} # model_name -> model instance self.hf_processors: Dict[str, Any] = {} # model_name -> processor/tokenizer # Device configuration self.device = "cuda" if TORCH_AVAILABLE and torch.cuda.is_available() else "cpu" logger.info(f"🔧 Swapper initialized with device: {self.device}") async def initialize(self): """Initialize Swapper Service - load configuration""" config = None try: logger.info(f"🔧 Initializing Swapper Service...") logger.info(f"🔧 Config path: {SWAPPER_CONFIG_PATH}") logger.info(f"🔧 Config exists: {os.path.exists(SWAPPER_CONFIG_PATH)}") if os.path.exists(SWAPPER_CONFIG_PATH): with open(SWAPPER_CONFIG_PATH, 'r') as f: config = yaml.safe_load(f) models_config = config.get('models', {}) logger.info(f"🔧 Found {len(models_config)} models in config") for model_key, model_config in models_config.items(): path = model_config.get('path', '') model_type = model_config.get('type', 'llm') capabilities = model_config.get('capabilities', []) # Determine backend from path prefix if path.startswith('huggingface:'): hf_name = path.replace('huggingface:', '') logger.info(f"🔧 Adding HuggingFace model: {model_key} -> {hf_name}") self.models[model_key] = ModelInfo( name=model_key, ollama_name="", hf_name=hf_name, backend=ModelBackend.HUGGINGFACE, type=model_type, size_gb=model_config.get('size_gb', 0), priority=model_config.get('priority', 'medium'), capabilities=capabilities, status=ModelStatus.UNLOADED ) else: ollama_name = path.replace('ollama:', '') logger.info(f"🔧 Adding Ollama model: {model_key} -> {ollama_name}") self.models[model_key] = ModelInfo( name=model_key, ollama_name=ollama_name, backend=ModelBackend.OLLAMA, type=model_type, size_gb=model_config.get('size_gb', 0), priority=model_config.get('priority', 'medium'), capabilities=capabilities, status=ModelStatus.UNLOADED ) self.model_uptime[model_key] = 0.0 logger.info(f"✅ Loaded {len(self.models)} models into Swapper") # Count by backend ollama_count = sum(1 for m in self.models.values() if m.backend == ModelBackend.OLLAMA) hf_count = sum(1 for m in self.models.values() if m.backend == ModelBackend.HUGGINGFACE) logger.info(f"✅ Models: {ollama_count} Ollama, {hf_count} HuggingFace") else: logger.warning(f"⚠️ Config file not found: {SWAPPER_CONFIG_PATH}, using defaults") await self._load_models_from_ollama() logger.info(f"✅ Swapper Service initialized with {len(self.models)} models") logger.info(f"✅ Model names: {list(self.models.keys())}") # Load default LLM model (not OCR - those are lazy loaded) if config: swapper_config = config.get('swapper', {}) default_model = swapper_config.get('default_model') lazy_load_ocr = swapper_config.get('lazy_load_ocr', True) if default_model and default_model in self.models: model_info = self.models[default_model] # Only auto-load non-OCR models if model_info.type != 'ocr' or not lazy_load_ocr: logger.info(f"🔄 Loading default model: {default_model}") success = await self.load_model(default_model) if success: logger.info(f"✅ Default model loaded: {default_model}") else: logger.warning(f"⚠️ Failed to load default model: {default_model}") else: logger.info(f"⏳ OCR model '{default_model}' will be lazy loaded on first request") elif default_model: logger.warning(f"⚠️ Default model '{default_model}' not found in models list") except Exception as e: logger.error(f"❌ Error initializing Swapper Service: {e}", exc_info=True) import traceback logger.error(f"❌ Traceback: {traceback.format_exc()}") async def _load_models_from_ollama(self): """Load available models from Ollama""" try: response = await self.http_client.get(f"{OLLAMA_BASE_URL}/api/tags") if response.status_code == 200: data = response.json() for model in data.get('models', []): model_name = model.get('name', '') # Extract base name (remove :latest, :7b, etc.) base_name = model_name.split(':')[0] if base_name not in self.models: size_gb = model.get('size', 0) / (1024**3) # Convert bytes to GB self.models[base_name] = ModelInfo( name=base_name, ollama_name=model_name, type='llm', # Default type size_gb=size_gb, priority='medium', status=ModelStatus.UNLOADED ) self.model_uptime[base_name] = 0.0 logger.info(f"✅ Loaded {len(self.models)} models from Ollama") except Exception as e: logger.error(f"❌ Error loading models from Ollama: {e}") async def load_model(self, model_name: str) -> bool: """Load a model (unload current if in single-active mode)""" async with self.loading_lock: try: # Check if model exists if model_name not in self.models: logger.error(f"❌ Model not found: {model_name}") return False model_info = self.models[model_name] # If single-active mode and another model is loaded, unload it first if SWAPPER_MODE == "single-active" and self.active_model and self.active_model != model_name: await self._unload_model_internal(self.active_model) # Load the model logger.info(f"🔄 Loading model: {model_name}") model_info.status = ModelStatus.LOADING # Check if model is already loaded in Ollama response = await self.http_client.post( f"{OLLAMA_BASE_URL}/api/generate", json={ "model": model_info.ollama_name, "prompt": "test", "stream": False }, timeout=MODEL_SWAP_TIMEOUT ) if response.status_code == 200: model_info.status = ModelStatus.LOADED model_info.loaded_at = datetime.now() model_info.unloaded_at = None self.active_model = model_name self.model_load_times[model_name] = datetime.now() logger.info(f"✅ Model loaded: {model_name}") return True else: model_info.status = ModelStatus.ERROR logger.error(f"❌ Failed to load model: {model_name}") return False except Exception as e: logger.error(f"❌ Error loading model {model_name}: {e}", exc_info=True) if model_name in self.models: self.models[model_name].status = ModelStatus.ERROR return False async def _unload_model_internal(self, model_name: str) -> bool: """Internal method to unload a model""" try: if model_name not in self.models: return False model_info = self.models[model_name] if model_info.status == ModelStatus.LOADED: logger.info(f"🔄 Unloading model: {model_name}") model_info.status = ModelStatus.UNLOADING # Calculate uptime if model_name in self.model_load_times: load_time = self.model_load_times[model_name] uptime_seconds = (datetime.now() - load_time).total_seconds() self.model_uptime[model_name] = self.model_uptime.get(model_name, 0.0) + uptime_seconds model_info.total_uptime_seconds = self.model_uptime[model_name] del self.model_load_times[model_name] model_info.status = ModelStatus.UNLOADED model_info.unloaded_at = datetime.now() if self.active_model == model_name: self.active_model = None logger.info(f"✅ Model unloaded: {model_name}") return True except Exception as e: logger.error(f"❌ Error unloading model {model_name}: {e}") return False async def unload_model(self, model_name: str) -> bool: """Unload a model""" async with self.loading_lock: return await self._unload_model_internal(model_name) async def get_status(self) -> SwapperStatus: """Get Swapper service status""" # Update uptime for currently loaded model if self.active_model and self.active_model in self.model_load_times: load_time = self.model_load_times[self.active_model] current_uptime = (datetime.now() - load_time).total_seconds() self.model_uptime[self.active_model] = self.model_uptime.get(self.active_model, 0.0) + current_uptime self.model_load_times[self.active_model] = datetime.now() # Reset timer loaded_models = [ name for name, model in self.models.items() if model.status == ModelStatus.LOADED ] return SwapperStatus( status="healthy", active_model=self.active_model, available_models=list(self.models.keys()), loaded_models=loaded_models, mode=SWAPPER_MODE, total_models=len(self.models) ) async def get_model_metrics(self, model_name: Optional[str] = None) -> List[ModelMetrics]: """Get metrics for model(s)""" metrics = [] models_to_check = [model_name] if model_name else list(self.models.keys()) for name in models_to_check: if name not in self.models: continue model_info = self.models[name] # Calculate current uptime uptime_seconds = self.model_uptime.get(name, 0.0) if name in self.model_load_times: load_time = self.model_load_times[name] current_uptime = (datetime.now() - load_time).total_seconds() uptime_seconds += current_uptime uptime_hours = uptime_seconds / 3600.0 metrics.append(ModelMetrics( model_name=name, status=model_info.status.value, loaded_at=model_info.loaded_at, uptime_hours=uptime_hours, request_count=model_info.request_count, total_uptime_seconds=uptime_seconds )) return metrics async def generate(self, model_name: str, prompt: str, system_prompt: Optional[str] = None, max_tokens: int = 2048, temperature: float = 0.7, stream: bool = False) -> Dict[str, Any]: """Generate text using a model""" try: # Ensure model is loaded if model_name not in self.models: raise ValueError(f"Model not found: {model_name}") model_info = self.models[model_name] # Load model if not loaded if model_info.status != ModelStatus.LOADED: logger.info(f"🔄 Model {model_name} not loaded, loading now...") success = await self.load_model(model_name) if not success: raise ValueError(f"Failed to load model: {model_name}") # Increment request count model_info.request_count += 1 # Prepare request to Ollama request_data = { "model": model_info.ollama_name, "prompt": prompt, "stream": stream, "options": { "num_predict": max_tokens, "temperature": temperature } } if system_prompt: request_data["system"] = system_prompt # Call Ollama response = await self.http_client.post( f"{OLLAMA_BASE_URL}/api/generate", json=request_data, timeout=300.0 ) if response.status_code == 200: data = response.json() return { "response": data.get("response", ""), "model": model_name, "done": data.get("done", True), "eval_count": data.get("eval_count", 0), "prompt_eval_count": data.get("prompt_eval_count", 0) } else: raise HTTPException( status_code=response.status_code, detail=f"Ollama error: {response.text}" ) except Exception as e: logger.error(f"❌ Error generating with model {model_name}: {e}", exc_info=True) raise async def close(self): """Close HTTP client and unload HuggingFace models""" await self.http_client.aclose() # Unload HuggingFace models to free GPU memory for model_name in list(self.hf_models.keys()): await self._unload_hf_model(model_name) async def _load_hf_model(self, model_name: str) -> bool: """Load a HuggingFace model (OCR/Document Understanding)""" if not TORCH_AVAILABLE: logger.error("❌ PyTorch not available, cannot load HuggingFace models") return False try: model_info = self.models[model_name] if model_info.backend != ModelBackend.HUGGINGFACE: logger.error(f"❌ Model {model_name} is not a HuggingFace model") return False hf_name = model_info.hf_name logger.info(f"🔄 Loading HuggingFace model: {hf_name} on {self.device}") from transformers import AutoModel, AutoTokenizer, AutoProcessor # Different loading strategies based on model type if "GOT-OCR" in hf_name or "got-ocr" in hf_name.lower(): # GOT-OCR2.0 specific loading tokenizer = AutoTokenizer.from_pretrained(hf_name, trust_remote_code=True) model = AutoModel.from_pretrained( hf_name, trust_remote_code=True, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32, device_map="auto" if self.device == "cuda" else None, low_cpu_mem_usage=True ) self.hf_processors[model_name] = tokenizer self.hf_models[model_name] = model elif "donut" in hf_name.lower(): # Donut model loading from transformers import DonutProcessor, VisionEncoderDecoderModel processor = DonutProcessor.from_pretrained(hf_name) model = VisionEncoderDecoderModel.from_pretrained( hf_name, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32 ) if self.device == "cuda": model = model.cuda() model.eval() self.hf_processors[model_name] = processor self.hf_models[model_name] = model elif "trocr" in hf_name.lower(): # TrOCR model loading from transformers import TrOCRProcessor, VisionEncoderDecoderModel processor = TrOCRProcessor.from_pretrained(hf_name) model = VisionEncoderDecoderModel.from_pretrained( hf_name, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32 ) if self.device == "cuda": model = model.cuda() model.eval() self.hf_processors[model_name] = processor self.hf_models[model_name] = model elif "flux" in hf_name.lower() or model_info.type == "image_generation": # FLUX / Diffusion model loading logger.info(f"🎨 Loading diffusion model: {hf_name}") from diffusers import AutoPipelineForText2Image diffusion_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32 pipeline = AutoPipelineForText2Image.from_pretrained( hf_name, torch_dtype=diffusion_dtype ) pipeline.to(self.device) if self.device == "cuda": pipeline.enable_model_cpu_offload() # Optimize VRAM usage on CUDA self.hf_models[model_name] = pipeline self.hf_processors[model_name] = None # No separate processor for diffusion logger.info(f"✅ Diffusion model loaded: {model_name} (device={self.device})") else: # Generic loading processor = AutoProcessor.from_pretrained(hf_name, trust_remote_code=True) model = AutoModel.from_pretrained( hf_name, trust_remote_code=True, torch_dtype=torch.float16 if self.device == "cuda" else torch.float32 ) if self.device == "cuda": model = model.cuda() self.hf_processors[model_name] = processor self.hf_models[model_name] = model logger.info(f"✅ HuggingFace model loaded: {model_name}") return True except Exception as e: logger.error(f"❌ Failed to load HuggingFace model {model_name}: {e}", exc_info=True) return False async def _unload_hf_model(self, model_name: str) -> bool: """Unload a HuggingFace model to free memory""" try: if model_name in self.hf_models: del self.hf_models[model_name] if model_name in self.hf_processors: del self.hf_processors[model_name] # Force garbage collection if TORCH_AVAILABLE: import gc gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() logger.info(f"✅ HuggingFace model unloaded: {model_name}") return True except Exception as e: logger.error(f"❌ Failed to unload HuggingFace model {model_name}: {e}") return False async def ocr_process(self, model_name: str, image_data: bytes, ocr_type: str = "ocr") -> Dict[str, Any]: """Process image with OCR model""" if not TORCH_AVAILABLE: raise HTTPException(status_code=503, detail="PyTorch not available") try: # Ensure model is loaded if model_name not in self.models: raise ValueError(f"OCR model not found: {model_name}") model_info = self.models[model_name] if model_info.backend != ModelBackend.HUGGINGFACE: raise ValueError(f"Model {model_name} is not an OCR model") # Lazy load model if not loaded if model_name not in self.hf_models: # Unload current OCR model if different (to save VRAM) if self.active_ocr_model and self.active_ocr_model != model_name: await self._unload_hf_model(self.active_ocr_model) if self.active_ocr_model in self.models: self.models[self.active_ocr_model].status = ModelStatus.UNLOADED logger.info(f"🔄 Lazy loading OCR model: {model_name}") model_info.status = ModelStatus.LOADING success = await self._load_hf_model(model_name) if not success: model_info.status = ModelStatus.ERROR raise ValueError(f"Failed to load OCR model: {model_name}") model_info.status = ModelStatus.LOADED model_info.loaded_at = datetime.now() self.active_ocr_model = model_name self.model_load_times[model_name] = datetime.now() # Process image image = Image.open(BytesIO(image_data)).convert("RGB") model = self.hf_models[model_name] processor = self.hf_processors[model_name] model_info.request_count += 1 hf_name = model_info.hf_name # Different processing based on model type if "GOT-OCR" in hf_name or "got-ocr" in hf_name.lower(): # GOT-OCR2.0 processing with torch.no_grad(): result = model.chat(processor, image, ocr_type=ocr_type) text = result if isinstance(result, str) else str(result) elif "donut" in hf_name.lower(): # Donut processing task_prompt = "" # or "" depending on task decoder_input_ids = processor.tokenizer( task_prompt, add_special_tokens=False, return_tensors="pt" ).input_ids pixel_values = processor(image, return_tensors="pt").pixel_values if self.device == "cuda": pixel_values = pixel_values.cuda() decoder_input_ids = decoder_input_ids.cuda() with torch.no_grad(): outputs = model.generate( pixel_values, decoder_input_ids=decoder_input_ids, max_length=model.decoder.config.max_position_embeddings, pad_token_id=processor.tokenizer.pad_token_id, eos_token_id=processor.tokenizer.eos_token_id, use_cache=True, bad_words_ids=[[processor.tokenizer.unk_token_id]], return_dict_in_generate=True ) sequence = processor.batch_decode(outputs.sequences)[0] text = processor.token2json(sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")) elif "trocr" in hf_name.lower(): # TrOCR processing pixel_values = processor(images=image, return_tensors="pt").pixel_values if self.device == "cuda": pixel_values = pixel_values.cuda() with torch.no_grad(): generated_ids = model.generate(pixel_values, max_length=512) text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] else: # Generic processing (try chat method) with torch.no_grad(): if hasattr(model, 'chat'): result = model.chat(processor, image) text = result if isinstance(result, str) else str(result) else: text = "Model does not support direct inference" return { "success": True, "model": model_name, "text": text, "device": self.device } except Exception as e: logger.error(f"❌ OCR processing failed: {e}", exc_info=True) raise async def image_generate( self, model_name: str, prompt: str, negative_prompt: str = "", num_inference_steps: int = 50, guidance_scale: float = 4.0, width: int = 1024, height: int = 1024 ) -> Dict[str, Any]: """Generate image with diffusion model (lazy loaded)""" if not TORCH_AVAILABLE: raise HTTPException(status_code=503, detail="PyTorch not available") import time start_time = time.time() try: # Ensure model exists if model_name not in self.models: raise ValueError(f"Image model not found: {model_name}") model_info = self.models[model_name] if model_info.type != "image_generation": raise ValueError(f"Model {model_name} is not an image generation model") # Lazy load model if not loaded if model_name not in self.hf_models: # Unload current image model if different (to save VRAM) if self.active_image_model and self.active_image_model != model_name: logger.info(f"🔄 Unloading current image model: {self.active_image_model}") await self._unload_hf_model(self.active_image_model) if self.active_image_model in self.models: self.models[self.active_image_model].status = ModelStatus.UNLOADED logger.info(f"🎨 Lazy loading image model: {model_name}") model_info.status = ModelStatus.LOADING success = await self._load_hf_model(model_name) if not success: model_info.status = ModelStatus.ERROR raise ValueError(f"Failed to load image model: {model_name}") model_info.status = ModelStatus.LOADED model_info.loaded_at = datetime.now() self.active_image_model = model_name self.model_load_times[model_name] = datetime.now() # Generate image pipeline = self.hf_models[model_name] model_info.request_count += 1 logger.info(f"🎨 Generating image with {model_name}: {prompt[:50]}...") with torch.no_grad(): # FLUX Klein doesn't support negative_prompt, check pipeline type pipeline_kwargs = { "prompt": prompt, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, "width": width, "height": height, } # Only add negative_prompt for models that support it (not FLUX) is_flux = "flux" in model_name.lower() if negative_prompt and not is_flux: pipeline_kwargs["negative_prompt"] = negative_prompt result = pipeline(**pipeline_kwargs) image = result.images[0] # Convert to base64 buffered = BytesIO() image.save(buffered, format="PNG") img_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8") generation_time_ms = (time.time() - start_time) * 1000 logger.info(f"✅ Image generated in {generation_time_ms:.0f}ms") return { "success": True, "model": model_name, "image_base64": img_base64, "width": width, "height": height, "generation_time_ms": generation_time_ms, "device": self.device } except Exception as e: logger.error(f"❌ Image generation failed: {e}", exc_info=True) raise # ========== FastAPI App ========== app = FastAPI( title="Swapper Service", description="Dynamic model loading service for Node #2", version="1.0.0" ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # Include cabinet API router (import after swapper is created) try: from app.cabinet_api import router as cabinet_router app.include_router(cabinet_router) logger.info("✅ Cabinet API router included") except ImportError: logger.warning("⚠️ cabinet_api module not found, skipping cabinet router") # Global Swapper instance swapper = SwapperService() @app.on_event("startup") async def startup(): """Initialize Swapper on startup""" await swapper.initialize() @app.on_event("shutdown") async def shutdown(): """Close Swapper on shutdown""" await swapper.close() # ========== API Endpoints ========== @app.get("/health") async def health(): """Health check endpoint""" status = await swapper.get_status() return { "status": "healthy", "service": "swapper-service", "active_model": status.active_model, "mode": status.mode } @app.get("/status", response_model=SwapperStatus) async def get_status(): """Get Swapper service status""" return await swapper.get_status() @app.get("/models") async def list_models(): """List all available models""" return { "models": [ { "name": model.name, "ollama_name": model.ollama_name, "type": model.type, "size_gb": model.size_gb, "priority": model.priority, "status": model.status.value } for model in swapper.models.values() ] } @app.get("/models/{model_name}") async def get_model_info(model_name: str): """Get information about a specific model""" if model_name not in swapper.models: raise HTTPException(status_code=404, detail=f"Model not found: {model_name}") model_info = swapper.models[model_name] return { "name": model_info.name, "ollama_name": model_info.ollama_name, "type": model_info.type, "size_gb": model_info.size_gb, "priority": model_info.priority, "status": model_info.status.value, "loaded_at": model_info.loaded_at.isoformat() if model_info.loaded_at else None, "unloaded_at": model_info.unloaded_at.isoformat() if model_info.unloaded_at else None, "total_uptime_seconds": swapper.model_uptime.get(model_name, 0.0) } @app.post("/models/{model_name}/load") async def load_model_endpoint(model_name: str): """Load a model""" success = await swapper.load_model(model_name) if success: return {"status": "success", "model": model_name, "message": f"Model {model_name} loaded"} raise HTTPException(status_code=500, detail=f"Failed to load model: {model_name}") @app.post("/models/{model_name}/unload") async def unload_model_endpoint(model_name: str): """Unload a model""" success = await swapper.unload_model(model_name) if success: return {"status": "success", "model": model_name, "message": f"Model {model_name} unloaded"} raise HTTPException(status_code=500, detail=f"Failed to unload model: {model_name}") @app.get("/metrics") async def get_metrics(model_name: Optional[str] = None): """Get metrics for model(s)""" metrics = await swapper.get_model_metrics(model_name) return { "metrics": [metric.dict() for metric in metrics] } @app.get("/metrics/{model_name}") async def get_model_metrics(model_name: str): """Get metrics for a specific model""" metrics = await swapper.get_model_metrics(model_name) if not metrics: raise HTTPException(status_code=404, detail=f"Model not found: {model_name}") return metrics[0].dict() # ========== Chat Completions API (OpenAI-compatible) ========== class ChatMessage(BaseModel): """Chat message""" role: str # system, user, assistant content: str class ChatCompletionRequest(BaseModel): """Chat completion request (OpenAI-compatible)""" model: str messages: List[ChatMessage] max_tokens: Optional[int] = 2048 temperature: Optional[float] = 0.7 stream: Optional[bool] = False class ChatCompletionResponse(BaseModel): """Chat completion response (OpenAI-compatible)""" id: str object: str = "chat.completion" created: int model: str choices: List[Dict[str, Any]] usage: Dict[str, int] @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) async def chat_completions(request: ChatCompletionRequest): """OpenAI-compatible chat completions endpoint""" import time try: # Extract system prompt and user messages system_prompt = None user_messages = [] for msg in request.messages: if msg.role == "system": system_prompt = msg.content elif msg.role == "user": user_messages.append(msg.content) # Combine user messages into prompt prompt = "\n".join(user_messages) # Generate response result = await swapper.generate( model_name=request.model, prompt=prompt, system_prompt=system_prompt, max_tokens=request.max_tokens or 2048, temperature=request.temperature or 0.7, stream=request.stream or False ) # Format response in OpenAI style return ChatCompletionResponse( id=f"chatcmpl-{int(time.time())}", created=int(time.time()), model=request.model, choices=[{ "index": 0, "message": { "role": "assistant", "content": result["response"] }, "finish_reason": "stop" if result.get("done", True) else None }], usage={ "prompt_tokens": result.get("prompt_eval_count", 0), "completion_tokens": result.get("eval_count", 0), "total_tokens": result.get("prompt_eval_count", 0) + result.get("eval_count", 0) } ) except Exception as e: logger.error(f"❌ Error in chat completions: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) # ========== Generate API (Ollama-compatible) ========== class GenerateRequest(BaseModel): """Generate request (Ollama-compatible)""" model: str prompt: str system: Optional[str] = None max_tokens: Optional[int] = 2048 temperature: Optional[float] = 0.7 stream: Optional[bool] = False @app.post("/generate") async def generate(request: GenerateRequest): """Ollama-compatible generate endpoint""" try: result = await swapper.generate( model_name=request.model, prompt=request.prompt, system_prompt=request.system, max_tokens=request.max_tokens or 2048, temperature=request.temperature or 0.7, stream=request.stream or False ) return result except Exception as e: logger.error(f"❌ Error in generate: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) # ========== VISION API Endpoints ========== class VisionRequest(BaseModel): """Vision (image description) request""" model: str = "qwen3-vl-8b" prompt: str = "Опиши це зображення коротко (2-3 речення)." images: List[str] # List of base64 encoded images (can include data: prefix) system: Optional[str] = None max_tokens: int = 1024 temperature: float = 0.7 @app.post("/vision") async def vision_endpoint(request: VisionRequest): """ Vision endpoint - analyze images with Vision-Language models. Models: - qwen3-vl-8b: Qwen3 Vision-Language model (8GB VRAM) Images should be base64 encoded. Can include data:image/... prefix or raw base64. """ try: import time start_time = time.time() model_name = request.model # Convert data URLs to raw base64 (Ollama expects base64 without prefix) processed_images = [] for img in request.images: if img.startswith("data:"): # Extract base64 part from data URL base64_part = img.split(",", 1)[1] if "," in img else img processed_images.append(base64_part) else: processed_images.append(img) logger.info(f"🖼️ Vision request: model={model_name}, images={len(processed_images)}, prompt={request.prompt[:50]}...") # Map model name to Ollama model ollama_model = model_name.replace("-", ":") # qwen3-vl-8b -> qwen3:vl-8b if model_name == "qwen3-vl-8b": ollama_model = "qwen3-vl:8b" # Build Ollama request ollama_payload = { "model": ollama_model, "prompt": request.prompt, "images": processed_images, "stream": False, "options": { "num_predict": request.max_tokens, "temperature": request.temperature } } if request.system: ollama_payload["system"] = request.system # Send to Ollama async with httpx.AsyncClient(timeout=120.0) as client: response = await client.post( f"{OLLAMA_BASE_URL}/api/generate", json=ollama_payload ) if response.status_code != 200: logger.error(f"❌ Ollama vision error: {response.status_code} - {response.text[:200]}") raise HTTPException(status_code=500, detail=f"Ollama error: {response.status_code}") result = response.json() vision_text = result.get("response", "") # Debug logging if not vision_text: logger.warning(f"⚠️ Empty response from Ollama! Result keys: {list(result.keys())}, error: {result.get('error', 'none')}") processing_time_ms = (time.time() - start_time) * 1000 logger.info(f"✅ Vision response: {len(vision_text)} chars in {processing_time_ms:.0f}ms") return { "success": True, "model": model_name, "text": vision_text, "processing_time_ms": processing_time_ms, "images_count": len(processed_images) } except HTTPException: raise except Exception as e: logger.error(f"❌ Vision endpoint error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.get("/vision/models") async def vision_models(): """List available vision models""" vision_models = [m for m in swapper.models.values() if m.type == "vision"] return { "models": [ { "name": m.name, "type": m.type, "status": m.status.value, "size_gb": m.size_gb } for m in vision_models ] } # ========== OCR API Endpoints ========== class OCRRequest(BaseModel): """OCR request""" model: str = "got-ocr2" # Default to GOT-OCR2.0 image_base64: Optional[str] = None image_url: Optional[str] = None ocr_type: str = "ocr" # ocr, format, table @app.post("/ocr") async def ocr_endpoint( request: OCRRequest = None, file: UploadFile = File(None), model: str = Form("got-ocr2"), ocr_type: str = Form("ocr") ): """ OCR endpoint - process images with OCR models. Models: - got-ocr2: Best for documents, tables, formulas (7GB VRAM) - donut-base: Document parsing without OCR (3GB VRAM) - donut-cord: Receipt/invoice parsing (3GB VRAM) - trocr-base: Fast printed text OCR (2GB VRAM) OCR Types (for GOT-OCR2.0): - ocr: Standard OCR - format: Preserve formatting - table: Extract tables """ try: image_data = None model_name = model ocr_type_param = ocr_type # Get image from request if file: image_data = await file.read() elif request: model_name = request.model ocr_type_param = request.ocr_type if request.image_base64: image_data = base64.b64decode(request.image_base64) elif request.image_url: async with httpx.AsyncClient() as client: response = await client.get(request.image_url) if response.status_code == 200: image_data = response.content else: raise HTTPException(status_code=400, detail="Failed to download image") if not image_data: raise HTTPException(status_code=400, detail="No image provided") # Process with OCR result = await swapper.ocr_process(model_name, image_data, ocr_type_param) return result except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: logger.error(f"❌ OCR endpoint error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.get("/ocr/models") async def list_ocr_models(): """List available OCR models""" ocr_models = [ { "name": model.name, "hf_name": model.hf_name, "type": model.type, "size_gb": model.size_gb, "priority": model.priority, "status": model.status.value, "capabilities": model.capabilities, "loaded": model.name in swapper.hf_models } for model in swapper.models.values() if model.backend == ModelBackend.HUGGINGFACE and model.type == "ocr" ] return { "ocr_models": ocr_models, "active_ocr_model": swapper.active_ocr_model, "device": swapper.device } @app.post("/ocr/models/{model_name}/load") async def load_ocr_model(model_name: str): """Pre-load an OCR model (optional, models are lazy loaded by default)""" if model_name not in swapper.models: raise HTTPException(status_code=404, detail=f"Model not found: {model_name}") model_info = swapper.models[model_name] if model_info.backend != ModelBackend.HUGGINGFACE: raise HTTPException(status_code=400, detail="Not an OCR model") async with swapper.loading_lock: success = await swapper._load_hf_model(model_name) if success: model_info.status = ModelStatus.LOADED model_info.loaded_at = datetime.now() swapper.active_ocr_model = model_name swapper.model_load_times[model_name] = datetime.now() return {"status": "success", "model": model_name, "message": f"OCR model {model_name} loaded"} raise HTTPException(status_code=500, detail=f"Failed to load OCR model: {model_name}") @app.post("/ocr/models/{model_name}/unload") async def unload_ocr_model(model_name: str): """Unload an OCR model to free GPU memory""" if model_name not in swapper.hf_models: raise HTTPException(status_code=400, detail=f"Model not loaded: {model_name}") async with swapper.loading_lock: success = await swapper._unload_hf_model(model_name) if success: if model_name in swapper.models: swapper.models[model_name].status = ModelStatus.UNLOADED if swapper.active_ocr_model == model_name: swapper.active_ocr_model = None return {"status": "success", "model": model_name, "message": f"OCR model {model_name} unloaded"} raise HTTPException(status_code=500, detail=f"Failed to unload OCR model: {model_name}") # ========== STT (Speech-to-Text) API Endpoints ========== class STTRequest(BaseModel): """STT request""" model: str = "faster-whisper-large" audio_base64: Optional[str] = None audio_url: Optional[str] = None language: Optional[str] = None # auto-detect if not specified task: str = "transcribe" # transcribe or translate @app.post("/stt") async def stt_endpoint( file: UploadFile = File(None), model: str = Form("faster-whisper-large"), language: Optional[str] = Form(None), task: str = Form("transcribe") ): """ Speech-to-Text endpoint using Faster Whisper. Models: - faster-whisper-large: Best quality, 99 languages (3GB VRAM) - whisper-small: Fast transcription (0.5GB VRAM) """ import tempfile import os try: audio_data = None if file: audio_data = await file.read() if not audio_data: raise HTTPException(status_code=400, detail="No audio provided") # Save audio to temp file (faster-whisper requires file path) with tempfile.NamedTemporaryFile(delete=False, suffix=".ogg") as tmp_file: tmp_file.write(audio_data) tmp_path = tmp_file.name try: # Lazy load faster-whisper model stt_model = await _get_or_load_stt_model(model) # Transcribe logger.info(f"🎤 STT: Transcribing audio with {model}...") segments, info = stt_model.transcribe( tmp_path, language=language, task=task, beam_size=5, vad_filter=True, # Remove silence vad_parameters=dict(min_silence_duration_ms=500) ) # Collect all segments text_parts = [] for segment in segments: text_parts.append(segment.text.strip()) full_text = " ".join(text_parts) detected_language = info.language if hasattr(info, 'language') else language or "unknown" logger.info(f"✅ STT: Transcribed successfully. Language: {detected_language}, Text: {full_text[:100]}...") return { "success": True, "model": model, "text": full_text, "language": detected_language, "device": swapper.device } finally: # Clean up temp file if os.path.exists(tmp_path): os.unlink(tmp_path) except Exception as e: logger.error(f"❌ STT endpoint error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) # STT model cache _stt_models = {} async def _get_or_load_stt_model(model_name: str): """Get or load STT model (lazy loading)""" global _stt_models if model_name in _stt_models: return _stt_models[model_name] from faster_whisper import WhisperModel # Map model names to faster-whisper sizes model_map = { "faster-whisper-large": "large-v3", "faster-whisper-medium": "medium", "whisper-small": "small", "whisper-base": "base", "whisper-tiny": "tiny" } whisper_size = model_map.get(model_name, "small") logger.info(f"🔄 Loading STT model: {model_name} (size: {whisper_size})...") # Use GPU if available device = "cuda" if swapper.device == "cuda" else "cpu" compute_type = "float16" if device == "cuda" else "int8" stt_model = WhisperModel( whisper_size, device=device, compute_type=compute_type ) _stt_models[model_name] = stt_model logger.info(f"✅ STT model {model_name} loaded on {device}") return stt_model @app.get("/stt/models") async def list_stt_models(): """List available STT models""" stt_models = [ { "name": model.name, "hf_name": model.hf_name, "type": model.type, "size_gb": model.size_gb, "priority": model.priority, "status": model.status.value, "capabilities": model.capabilities, "loaded": model.name in swapper.hf_models } for model in swapper.models.values() if model.backend == ModelBackend.HUGGINGFACE and model.type == "stt" ] return { "stt_models": stt_models, "device": swapper.device } # ========== TTS (Text-to-Speech) API Endpoints ========== class TTSRequest(BaseModel): """TTS request""" model: str = "xtts-v2" text: str language: str = "uk" # Ukrainian by default speaker_wav_base64: Optional[str] = None # For voice cloning speed: float = 1.0 @app.post("/tts") async def tts_endpoint(request: TTSRequest): """ Text-to-Speech endpoint. Models: - xtts-v2: Best multilingual TTS with voice cloning (2GB VRAM) Languages: uk, en, es, fr, de, it, pt, pl, tr, ru, nl, cs, ar, zh-cn, ja, hu, ko """ try: model_name = request.model # Get model config tts_model_config = swapper.get_model_config(model_name, "tts") if not tts_model_config: raise HTTPException(status_code=400, detail=f"TTS model '{model_name}' not found") # Load XTTS model if not loaded if model_name not in swapper.hf_models: logger.info(f"🔊 Loading TTS model: {model_name}...") try: from TTS.api import TTS # XTTS-v2 model tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2") if swapper.device == "cuda": tts = tts.to(swapper.device) swapper.hf_models[model_name] = tts logger.info(f"✅ TTS model {model_name} loaded on {swapper.device}") except ImportError: logger.error("❌ TTS library not installed. Run: pip install TTS") raise HTTPException(status_code=503, detail="TTS library not installed") except Exception as e: logger.error(f"❌ Failed to load TTS model: {e}") raise HTTPException(status_code=500, detail=f"Failed to load TTS model: {e}") tts_model = swapper.hf_models[model_name] # Generate speech import tempfile import time start_time = time.time() with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file: tmp_path = tmp_file.name try: # Check if voice cloning is requested if request.speaker_wav_base64: # Decode speaker reference audio speaker_audio = base64.b64decode(request.speaker_wav_base64) with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as speaker_file: speaker_file.write(speaker_audio) speaker_path = speaker_file.name # Generate with voice cloning tts_model.tts_to_file( text=request.text, file_path=tmp_path, speaker_wav=speaker_path, language=request.language, speed=request.speed ) os.unlink(speaker_path) else: # Generate with default speaker tts_model.tts_to_file( text=request.text, file_path=tmp_path, language=request.language, speed=request.speed ) # Read and encode audio with open(tmp_path, "rb") as f: audio_data = f.read() audio_base64 = base64.b64encode(audio_data).decode("utf-8") generation_time_ms = (time.time() - start_time) * 1000 logger.info(f"✅ TTS generated in {generation_time_ms:.0f}ms, {len(audio_data)} bytes") return { "success": True, "model": model_name, "audio_base64": audio_base64, "language": request.language, "generation_time_ms": generation_time_ms, "device": swapper.device } finally: if os.path.exists(tmp_path): os.unlink(tmp_path) except HTTPException: raise except Exception as e: logger.error(f"❌ TTS endpoint error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.get("/tts/models") async def list_tts_models(): """List available TTS models""" tts_models = [ { "name": model.name, "hf_name": model.hf_name, "type": model.type, "size_gb": model.size_gb, "priority": model.priority, "status": model.status.value, "capabilities": model.capabilities, "loaded": model.name in swapper.hf_models } for model in swapper.models.values() if model.backend == ModelBackend.HUGGINGFACE and model.type == "tts" ] return { "tts_models": tts_models, "device": swapper.device } # ========== Document Processing API Endpoints ========== class DocumentRequest(BaseModel): """Document processing request""" model: str = "granite-docling" doc_base64: Optional[str] = None doc_url: Optional[str] = None output_format: str = "markdown" # markdown, json, doctags @app.post("/document") async def document_endpoint( file: UploadFile = File(None), model: str = Form("granite-docling"), output_format: str = Form("markdown") ): """ Document processing endpoint using Docling. Models: - granite-docling: IBM Granite for document structure (2.5GB VRAM) Output formats: - markdown: Clean markdown text - json: Structured JSON with document elements - text: Plain text extraction Supported files: PDF, DOCX, XLS/XLSX/XLSM/ODS, PPTX, TXT/MD/CSV/TSV, JSON/YAML/XML/HTML, RTF, ZIP, images. """ try: import time start_time = time.time() doc_data = None if file: doc_data = await file.read() if not doc_data: raise HTTPException(status_code=400, detail="No document provided") # Determine file type filename = file.filename if file else "document" file_ext = filename.split(".")[-1].lower() if "." in filename else "pdf" # Handle deterministic extraction for standard office/text formats if file_ext in [ "txt", "md", "markdown", "csv", "tsv", "xlsx", "xls", "xlsm", "ods", "json", "yaml", "yml", "xml", "html", "htm", "rtf", "pptx", "zip", ]: try: if file_ext == "zip": content = _zip_to_markdown(doc_data) output_format = "markdown" else: content = _extract_text_by_ext(filename, doc_data) output_format = ( "markdown" if file_ext in { "md", "markdown", "csv", "tsv", "xlsx", "xls", "xlsm", "ods", "json", "yaml", "yml", "xml", "html", "htm", "pptx", } else "text" ) processing_time_ms = (time.time() - start_time) * 1000 return { "success": True, "model": "text-extract", "output_format": output_format, "result": content, "filename": filename, "processing_time_ms": processing_time_ms, "device": swapper.device } except HTTPException: raise except Exception as e: logger.error(f"❌ Text extraction failed: {e}", exc_info=True) raise HTTPException(status_code=500, detail=f"Text extraction failed: {e}") # Save to temp file import tempfile with tempfile.NamedTemporaryFile(suffix=f".{file_ext}", delete=False) as tmp_file: tmp_file.write(doc_data) tmp_path = tmp_file.name try: # Try to use docling library try: from docling.document_converter import DocumentConverter from docling.datamodel.base_models import InputFormat logger.info(f"📄 Processing document: {filename} ({len(doc_data)} bytes)") # Initialize converter (lazy load) if "docling_converter" not in swapper.hf_models: logger.info("🔄 Initializing Docling converter...") converter = DocumentConverter() swapper.hf_models["docling_converter"] = converter logger.info("✅ Docling converter initialized") converter = swapper.hf_models["docling_converter"] # Convert document result = converter.convert(tmp_path) doc = result.document # Format output if output_format == "markdown": content = doc.export_to_markdown() elif output_format == "json": content = doc.export_to_dict() elif output_format == "text": content = doc.export_to_text() else: content = doc.export_to_markdown() processing_time_ms = (time.time() - start_time) * 1000 logger.info(f"✅ Document processed in {processing_time_ms:.0f}ms") return { "success": True, "model": model, "output_format": output_format, "result": content, "filename": filename, "processing_time_ms": processing_time_ms, "device": swapper.device } except ImportError: # Fallback to pdfplumber/OCR for simpler extraction logger.warning("⚠️ Docling not installed, using fallback extraction") # For images, use OCR if file_ext in ["png", "jpg", "jpeg", "gif", "webp"]: ocr_result = await swapper.ocr_process("got-ocr2", doc_data, "ocr") return { "success": True, "model": "got-ocr2 (fallback)", "output_format": "text", "result": ocr_result.get("text", ""), "filename": filename, "processing_time_ms": (time.time() - start_time) * 1000, "device": swapper.device } # For common office/text formats, try deterministic extractors. if file_ext in { "docx", "txt", "md", "markdown", "csv", "tsv", "xlsx", "xls", "xlsm", "ods", "pptx", "json", "yaml", "yml", "xml", "html", "htm", "rtf", }: try: content = _extract_text_by_ext(filename, doc_data) out_fmt = "markdown" if file_ext not in {"txt", "rtf"} else "text" return { "success": True, "model": "text-extract (fallback)", "output_format": out_fmt, "result": content, "filename": filename, "processing_time_ms": (time.time() - start_time) * 1000, "device": swapper.device } except Exception as e: logger.error(f"Text fallback failed for .{file_ext}: {e}") raise HTTPException(status_code=500, detail=f"Extraction failed for .{file_ext}") # For PDFs, try pdfplumber if file_ext == "pdf": try: import pdfplumber text_content = [] with pdfplumber.open(tmp_path) as pdf: for page in pdf.pages: text = page.extract_text() if text: text_content.append(text) content = "\n\n".join(text_content) return { "success": True, "model": "pdfplumber (fallback)", "output_format": "text", "result": content, "filename": filename, "processing_time_ms": (time.time() - start_time) * 1000, "device": "cpu" } except ImportError: pass # For other documents, return error raise HTTPException( status_code=503, detail="Document processing unavailable for this type. Supported: office/text/image/zip standard formats." ) finally: if os.path.exists(tmp_path): os.unlink(tmp_path) except HTTPException: raise except Exception as e: logger.error(f"❌ Document endpoint error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.get("/document/models") async def list_document_models(): """List available document processing models""" doc_models = [ { "name": model.name, "hf_name": model.hf_name, "type": model.type, "size_gb": model.size_gb, "priority": model.priority, "status": model.status.value, "capabilities": model.capabilities, "loaded": model.name in swapper.hf_models } for model in swapper.models.values() if model.backend == ModelBackend.HUGGINGFACE and model.type == "document" ] return { "document_models": doc_models, "device": swapper.device } # ========== Image Generation API Endpoints (FLUX) ========== class ImageGenerateRequest(BaseModel): """Image generation request""" model: str = "flux-klein-4b" prompt: str negative_prompt: str = "" num_inference_steps: int = 50 guidance_scale: float = 4.0 width: int = 1024 height: int = 1024 @app.post("/image/generate") async def image_generate_endpoint(request: ImageGenerateRequest): """ Generate image using diffusion model (lazy loaded). Models: - flux-klein-4b: FLUX.2 Klein 4B (15.4GB VRAM, lazy loaded on demand) The model is loaded on first request and unloaded when VRAM is needed for other models. """ try: result = await swapper.image_generate( model_name=request.model, prompt=request.prompt, negative_prompt=request.negative_prompt, num_inference_steps=request.num_inference_steps, guidance_scale=request.guidance_scale, width=request.width, height=request.height ) return result except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) except Exception as e: logger.error(f"❌ Image generate endpoint error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.get("/image/models") async def list_image_models(): """List available image generation models""" image_models = [ { "name": model.name, "hf_name": model.hf_name, "type": model.type, "size_gb": model.size_gb, "priority": model.priority, "status": model.status.value, "capabilities": model.capabilities, "loaded": model.name in swapper.hf_models } for model in swapper.models.values() if model.backend == ModelBackend.HUGGINGFACE and model.type == "image_generation" ] return { "image_models": image_models, "active_image_model": swapper.active_image_model, "device": swapper.device } @app.post("/image/models/{model_name}/load") async def load_image_model(model_name: str): """Pre-load an image generation model (optional, models are lazy loaded by default)""" if model_name not in swapper.models: raise HTTPException(status_code=404, detail=f"Model not found: {model_name}") model_info = swapper.models[model_name] if model_info.type != "image_generation": raise HTTPException(status_code=400, detail="Not an image generation model") async with swapper.loading_lock: success = await swapper._load_hf_model(model_name) if success: model_info.status = ModelStatus.LOADED model_info.loaded_at = datetime.now() swapper.active_image_model = model_name swapper.model_load_times[model_name] = datetime.now() return {"status": "success", "model": model_name, "message": f"Image model {model_name} loaded"} raise HTTPException(status_code=500, detail=f"Failed to load image model: {model_name}") @app.post("/image/models/{model_name}/unload") async def unload_image_model(model_name: str): """Unload an image generation model to free GPU memory""" if model_name not in swapper.hf_models: raise HTTPException(status_code=400, detail=f"Model not loaded: {model_name}") async with swapper.loading_lock: success = await swapper._unload_hf_model(model_name) if success: if model_name in swapper.models: swapper.models[model_name].status = ModelStatus.UNLOADED if swapper.active_image_model == model_name: swapper.active_image_model = None return {"status": "success", "model": model_name, "message": f"Image model {model_name} unloaded"} raise HTTPException(status_code=500, detail=f"Failed to unload image model: {model_name}") # ========== Web Scraping API Endpoints ========== class WebExtractRequest(BaseModel): """Web content extraction request""" url: str method: str = "auto" # auto, jina, trafilatura, crawl4ai include_links: bool = False include_images: bool = False class WebSearchRequest(BaseModel): """Web search request""" query: str max_results: int = 10 engine: str = "duckduckgo" # duckduckgo, google @app.post("/web/extract") async def web_extract(request: WebExtractRequest): """ Extract content from URL using multiple methods. Methods: - jina: Jina Reader API (free, JS support, cloud) - trafilatura: Local extraction (fast, no JS) - crawl4ai: Full crawling (JS support, local) - requires separate service - auto: Try jina first, fallback to trafilatura """ url = request.url method = request.method async def extract_with_jina(url: str) -> dict: """Extract using Jina Reader API (free)""" try: jina_url = f"https://r.jina.ai/{url}" async with httpx.AsyncClient(timeout=30.0) as client: response = await client.get(jina_url) if response.status_code == 200: return { "success": True, "method": "jina", "content": response.text, "url": url } else: return {"success": False, "error": f"Jina returned {response.status_code}"} except Exception as e: return {"success": False, "error": str(e)} async def extract_with_trafilatura(url: str) -> dict: """Extract using Trafilatura (local)""" try: import trafilatura downloaded = trafilatura.fetch_url(url) if downloaded: text = trafilatura.extract( downloaded, include_links=request.include_links, include_images=request.include_images ) return { "success": True, "method": "trafilatura", "content": text, "url": url } return {"success": False, "error": "Failed to download page"} except ImportError: return {"success": False, "error": "Trafilatura not installed"} except Exception as e: return {"success": False, "error": str(e)} async def extract_with_crawl4ai(url: str) -> dict: """Extract using Crawl4AI service""" try: crawl4ai_url = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235") async with httpx.AsyncClient(timeout=60.0) as client: response = await client.post( f"{crawl4ai_url}/crawl", json={"urls": [url], "word_count_threshold": 10} ) if response.status_code == 200: data = response.json() result = data.get("results", [{}])[0] # Get markdown - can be string or dict with raw_markdown markdown_data = result.get("markdown", "") if isinstance(markdown_data, dict): content = markdown_data.get("raw_markdown", "") or markdown_data.get("fit_markdown", "") else: content = markdown_data # Fallback to cleaned_html if not content: content = result.get("cleaned_html", "") or result.get("extracted_content", "") # Last resort: strip HTML tags if not content and result.get("html"): import re content = re.sub(r'<[^>]+>', ' ', result.get("html", "")) content = re.sub(r'\s+', ' ', content).strip() # Limit size for LLM context if len(content) > 50000: content = content[:50000] + "\n\n[... truncated ...]" return { "success": bool(content), "method": "crawl4ai", "content": content, "url": url } return {"success": False, "error": f"Crawl4AI returned {response.status_code}"} except Exception as e: return {"success": False, "error": str(e)} # Execute based on method if method == "jina": result = await extract_with_jina(url) elif method == "trafilatura": result = await extract_with_trafilatura(url) elif method == "crawl4ai": result = await extract_with_crawl4ai(url) elif method == "auto": # Try jina first (JS support), fallback to trafilatura result = await extract_with_jina(url) if not result.get("success"): logger.info(f"Jina failed, trying trafilatura for {url}") result = await extract_with_trafilatura(url) else: raise HTTPException(status_code=400, detail=f"Unknown method: {method}") if not result.get("success"): raise HTTPException(status_code=500, detail=result.get("error", "Extraction failed")) return result @app.post("/web/search") async def web_search(request: WebSearchRequest): """ Search the web using multiple engines with fallback. Priority: 1) DDGS (DuckDuckGo) 2) Google Search """ formatted_results = [] engine_used = "none" # Method 1: Try DDGS (new package name) try: from ddgs import DDGS ddgs = DDGS() results = list(ddgs.text( request.query, max_results=request.max_results, region="wt-wt" # Worldwide )) if results: for idx, result in enumerate(results): formatted_results.append({ "position": idx + 1, "title": result.get("title", ""), "url": result.get("href", result.get("link", "")), "snippet": result.get("body", result.get("snippet", "")) }) engine_used = "ddgs" logger.info(f"✅ DDGS search found {len(formatted_results)} results for: {request.query[:50]}") except ImportError: logger.warning("DDGS not installed, trying Google search") except Exception as e: logger.warning(f"DDGS search failed: {e}, trying Google search") # Method 2: Fallback to Google search if not formatted_results: try: from googlesearch import search as google_search results = list(google_search(request.query, num_results=request.max_results, lang="uk")) if results: for idx, url in enumerate(results): formatted_results.append({ "position": idx + 1, "title": url.split("/")[-1].replace("-", " ").replace("_", " ")[:60] or "Result", "url": url, "snippet": "" }) engine_used = "google" logger.info(f"✅ Google search found {len(formatted_results)} results for: {request.query[:50]}") except ImportError: logger.warning("Google search not installed") except Exception as e: logger.warning(f"Google search failed: {e}") # Return results or empty return { "success": len(formatted_results) > 0, "query": request.query, "results": formatted_results, "total": len(formatted_results), "engine": engine_used } @app.get("/web/read/{url:path}") async def web_read_simple(url: str): """ Simple GET endpoint to read a URL (uses Jina by default). Example: GET /web/read/https://example.com """ request = WebExtractRequest(url=url, method="auto") return await web_extract(request) @app.get("/web/status") async def web_status(): """Check availability of web scraping methods""" # Check Trafilatura try: import trafilatura trafilatura_available = True except ImportError: trafilatura_available = False # Check DuckDuckGo try: from duckduckgo_search import DDGS ddgs_available = True except ImportError: ddgs_available = False # Check Crawl4AI crawl4ai_url = os.getenv("CRAWL4AI_URL", "http://crawl4ai:11235") crawl4ai_available = False try: async with httpx.AsyncClient(timeout=5.0) as client: response = await client.get(f"{crawl4ai_url}/health") crawl4ai_available = response.status_code == 200 except: pass return { "methods": { "jina": {"available": True, "type": "cloud", "js_support": True}, "trafilatura": {"available": trafilatura_available, "type": "local", "js_support": False}, "crawl4ai": {"available": crawl4ai_available, "type": "local", "js_support": True} }, "search": { "duckduckgo": {"available": ddgs_available} } } # ========== Video Generation API (Grok xAI) ========== GROK_API_KEY = os.getenv("GROK_API_KEY", "") GROK_API_URL = "https://api.x.ai/v1" class VideoGenerateRequest(BaseModel): """Video generation request via Grok""" prompt: str duration: int = 6 # seconds (max 6 for Grok) style: str = "cinematic" # cinematic, anime, realistic, abstract aspect_ratio: str = "16:9" # 16:9, 9:16, 1:1 @app.post("/video/generate") async def video_generate(request: VideoGenerateRequest): """ Generate image using Grok (xAI) API. Note: Grok API currently supports image generation only (not video). For video-like content, generate multiple frames and combine externally. """ if not GROK_API_KEY: raise HTTPException(status_code=503, detail="GROK_API_KEY not configured") try: async with httpx.AsyncClient(timeout=120.0) as client: # Grok image generation endpoint response = await client.post( f"{GROK_API_URL}/images/generations", headers={ "Authorization": f"Bearer {GROK_API_KEY}", "Content-Type": "application/json" }, json={ "model": "grok-2-image-1212", # Correct model name "prompt": f"{request.prompt}, {request.style} style", "n": 1, "response_format": "url" } ) if response.status_code == 200: data = response.json() return { "success": True, "prompt": request.prompt, "style": request.style, "type": "image", # Note: video not available via API "result": data, "provider": "grok-xai", "note": "Grok API supports image generation. Video generation is available only in xAI app." } else: logger.error(f"Grok API error: {response.status_code} - {response.text}") raise HTTPException( status_code=response.status_code, detail=f"Grok API error: {response.text}" ) except httpx.TimeoutException: raise HTTPException(status_code=504, detail="Image generation timeout (>120s)") except Exception as e: logger.error(f"Image generation error: {e}", exc_info=True) raise HTTPException(status_code=500, detail=str(e)) @app.get("/video/status") async def video_status(): """Check Grok image/video generation service status""" return { "service": "grok-xai", "api_key_configured": bool(GROK_API_KEY), "capabilities": { "image_generation": True, # grok-2-image-1212 "video_generation": False, # Not available via API (only in xAI app) "vision_analysis": True # grok-2-vision-1212 }, "models": { "image": "grok-2-image-1212", "vision": "grok-2-vision-1212", "chat": ["grok-3", "grok-3-mini", "grok-4-0709"] }, "supported_styles": ["cinematic", "anime", "realistic", "abstract", "photorealistic"] } # ========== Multimodal Stack Summary ========== @app.get("/multimodal") async def get_multimodal_stack(): """Get full multimodal stack status""" def get_models_by_type(model_type: str): return [ { "name": m.name, "size_gb": m.size_gb, "status": m.status.value, "loaded": m.name in swapper.hf_models } for m in swapper.models.values() if m.type == model_type ] return { "device": swapper.device, "cuda_available": TORCH_AVAILABLE and torch.cuda.is_available(), "stack": { "llm": get_models_by_type("llm"), "vision": get_models_by_type("vision"), "math": get_models_by_type("math"), "ocr": get_models_by_type("ocr"), "document": get_models_by_type("document"), "stt": get_models_by_type("stt"), "tts": get_models_by_type("tts"), "embedding": get_models_by_type("embedding"), "image_generation": get_models_by_type("image_generation"), "video_generation": {"provider": "grok-xai", "available": bool(GROK_API_KEY)} }, "active_models": { "llm": swapper.active_model, "ocr": swapper.active_ocr_model } } if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8890)