microdao-daarion/services/parser-service/app/runtime/model_loader.py

"""
Model loader for dots.ocr
Handles lazy loading and GPU/CPU fallback
"""
import logging
from typing import Optional, Literal
from pathlib import Path

from app.core.config import settings

logger = logging.getLogger(__name__)

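# NOTE (assumption): app.core.config.settings is expected to expose at least the
# fields referenced in this module:
#   USE_DUMMY_PARSER: bool      -- skip model loading and run in dummy-parser mode
#   PARSER_MODEL_NAME: str      -- Hugging Face model id or local path to dots.ocr
#   PARSER_DEVICE: str          -- "cuda", "mps", or "cpu"
#   ALLOW_DUMMY_FALLBACK: bool  -- return None instead of raising when loading fails
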
# Global model instance
_model: Optional[object] = None


def load_model() -> Optional[object]:
    """
    Load dots.ocr model

    Returns:
        Loaded model instance or None if loading fails
    """
    global _model

    if _model is not None:
        return _model

    # Check if dummy mode is enabled
    if settings.USE_DUMMY_PARSER:
        logger.info("Dummy parser mode enabled, skipping model loading")
        return None

    logger.info(f"Loading model: {settings.PARSER_MODEL_NAME}")
    logger.info(f"Device: {settings.PARSER_DEVICE}")

    try:
        # Load dots.ocr model
        # Note: Adjust imports and model class based on actual dots.ocr implementation.
        # This is a template that should work with most Vision-Language models
        # (see the usage sketch at the end of this file).
        try:
            from transformers import AutoModelForVision2Seq, AutoProcessor
            import torch
        except ImportError:
            logger.error("transformers or torch not installed. Install with: pip install transformers torch")
            if not settings.ALLOW_DUMMY_FALLBACK:
                raise
            return None
logger.info(f"Loading model from: {settings.PARSER_MODEL_NAME}")
# Load processor
processor = AutoProcessor.from_pretrained(
settings.PARSER_MODEL_NAME,
trust_remote_code=True # If model has custom code
)
# Determine device and dtype
device = settings.PARSER_DEVICE
if device == "cuda" and not torch.cuda.is_available():
logger.warning("CUDA not available, falling back to CPU")
device = "cpu"
elif device == "mps" and not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available():
logger.warning("MPS not available, falling back to CPU")
device = "cpu"
dtype = torch.float16 if device != "cpu" else torch.float32

        # Load model
        model = AutoModelForVision2Seq.from_pretrained(
            settings.PARSER_MODEL_NAME,
            device_map=device if device != "cpu" else None,
            torch_dtype=dtype,
            trust_remote_code=True
        )
        if device == "cpu":
            model = model.to("cpu")

        # Store model and processor
        _model = {
            "model": model,
            "processor": processor,
            "device": device
        }
        logger.info(f"Model loaded successfully on device: {device}")

    except ImportError as e:
        logger.error(f"Required packages not installed: {e}")
        if not settings.ALLOW_DUMMY_FALLBACK:
            raise
        _model = None
    except Exception as e:
        logger.error(f"Failed to load model: {e}", exc_info=True)
        if not settings.ALLOW_DUMMY_FALLBACK:
            raise
        _model = None

    return _model


def get_model() -> Optional[object]:
    """Get loaded model instance"""
    if _model is None:
        return load_model()
    return _model


def unload_model():
    """Unload model from memory"""
    global _model
    if _model is not None:
        # Drop the cached reference and, best effort, release cached GPU memory.
        _model = None
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except ImportError:
            pass
        logger.info("Model unloaded")