feat: add tests and integrate dots.ocr model
G.2.5 - Tests:
- Add pytest test suite with fixtures
  - test_preprocessing.py - PDF/image loading, normalization, validation
  - test_postprocessing.py - chunks, QA pairs, markdown generation
  - test_inference.py - dummy parser and inference functions
  - test_api.py - API endpoint tests
- Add pytest.ini configuration

G.1.3 - dots.ocr Integration:
- Update model_loader.py with real model loading code
  - Support for AutoModelForVision2Seq and AutoProcessor
  - Device handling (CUDA/CPU/MPS) with fallback
  - Error handling with dummy fallback option
- Update inference.py with real model inference
  - Process images through model
  - Generate and decode outputs
  - Parse model output to blocks
- Add model_output_parser.py
  - Parse JSON or plain text model output
  - Convert to structured blocks
  - Layout detection support (placeholder)

Dependencies:
- Add pytest, pytest-asyncio, httpx for testing
This commit is contained in:
@@ -6,6 +6,7 @@ import logging
|
||||
from typing import Literal, Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from app.schemas import ParsedDocument, ParsedPage, ParsedBlock, BBox
|
||||
@@ -14,6 +15,7 @@ from app.runtime.preprocessing import (
|
||||
convert_pdf_to_images, load_image, prepare_images_for_model
|
||||
)
|
||||
from app.runtime.postprocessing import build_parsed_document
|
||||
from app.runtime.model_output_parser import parse_model_output_to_blocks
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -63,36 +65,46 @@ def parse_document_from_images(
|
||||
|
||||
for idx, image in enumerate(prepared_images, start=1):
|
||||
try:
|
||||
# TODO: Implement actual inference with dots.ocr
|
||||
# Example:
|
||||
# inputs = model["processor"](images=image, return_tensors="pt")
|
||||
# outputs = model["model"].generate(**inputs)
|
||||
# text = model["processor"].decode(outputs[0], skip_special_tokens=True)
|
||||
#
|
||||
# # Parse model output into blocks
|
||||
# blocks = parse_model_output_to_blocks(text, image.size)
|
||||
#
|
||||
# pages_data.append({
|
||||
# "blocks": blocks,
|
||||
# "width": image.width,
|
||||
# "height": image.height
|
||||
# })
|
||||
# Prepare inputs for model
|
||||
inputs = model["processor"](images=image, return_tensors="pt")
|
||||
|
||||
# Move inputs to device
|
||||
device = model["device"]
|
||||
if device != "cpu":
|
||||
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in inputs.items()}
|
||||
|
||||
# Generate output
|
||||
with torch.no_grad():
|
||||
outputs = model["model"].generate(
|
||||
**inputs,
|
||||
max_new_tokens=2048, # Adjust based on model capabilities
|
||||
do_sample=False # Deterministic output
|
||||
)
|
||||
|
||||
# Decode output
|
||||
generated_text = model["processor"].decode(
|
||||
outputs[0],
|
||||
skip_special_tokens=True
|
||||
)
|
||||
|
||||
logger.debug(f"Model output for page {idx}: {generated_text[:100]}...")
|
||||
|
||||
# Parse model output into blocks
|
||||
blocks = parse_model_output_to_blocks(
|
||||
generated_text,
|
||||
image.size,
|
||||
page_num=idx
|
||||
)
|
||||
|
||||
# For now, use dummy for each page
|
||||
logger.debug(f"Processing page {idx} with model (placeholder)")
|
||||
pages_data.append({
|
||||
"blocks": [
|
||||
{
|
||||
"type": "paragraph",
|
||||
"text": f"Page {idx} content (model output placeholder)",
|
||||
"bbox": {"x": 0, "y": 0, "width": image.width, "height": image.height},
|
||||
"reading_order": 1
|
||||
}
|
||||
],
|
||||
"blocks": blocks,
|
||||
"width": image.width,
|
||||
"height": image.height
|
||||
})
|
||||
|
||||
logger.info(f"Processed page {idx}/{len(prepared_images)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing page {idx}: {e}", exc_info=True)
|
||||
# Continue with other pages
|
||||
|
||||
@@ -36,30 +36,62 @@ def load_model() -> Optional[object]:
|
||||
logger.info(f"Device: {settings.PARSER_DEVICE}")
|
||||
|
||||
try:
|
||||
# TODO: Implement actual model loading
|
||||
# Example for dots.ocr (adjust based on actual model structure):
|
||||
# from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
#
|
||||
# processor = AutoProcessor.from_pretrained(settings.PARSER_MODEL_NAME)
|
||||
# model = AutoModelForVision2Seq.from_pretrained(
|
||||
# settings.PARSER_MODEL_NAME,
|
||||
# device_map=settings.PARSER_DEVICE if settings.PARSER_DEVICE != "cpu" else None,
|
||||
# torch_dtype=torch.float16 if settings.PARSER_DEVICE != "cpu" else torch.float32
|
||||
# )
|
||||
#
|
||||
# if settings.PARSER_DEVICE == "cpu":
|
||||
# model = model.to("cpu")
|
||||
#
|
||||
# _model = {
|
||||
# "model": model,
|
||||
# "processor": processor,
|
||||
# "device": settings.PARSER_DEVICE
|
||||
# }
|
||||
#
|
||||
# logger.info("Model loaded successfully")
|
||||
# Load dots.ocr model
|
||||
# Note: Adjust imports and model class based on actual dots.ocr implementation
|
||||
# This is a template that should work with most Vision-Language models
|
||||
|
||||
# For now, return None (will use dummy parser)
|
||||
logger.warning("Model loading not yet implemented, will use dummy parser")
|
||||
try:
|
||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
import torch
|
||||
except ImportError:
|
||||
logger.error("transformers or torch not installed. Install with: pip install transformers torch")
|
||||
if not settings.ALLOW_DUMMY_FALLBACK:
|
||||
raise
|
||||
return None
|
||||
|
||||
logger.info(f"Loading model from: {settings.PARSER_MODEL_NAME}")
|
||||
|
||||
# Load processor
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
settings.PARSER_MODEL_NAME,
|
||||
trust_remote_code=True # If model has custom code
|
||||
)
|
||||
|
||||
# Determine device and dtype
|
||||
device = settings.PARSER_DEVICE
|
||||
if device == "cuda" and not torch.cuda.is_available():
|
||||
logger.warning("CUDA not available, falling back to CPU")
|
||||
device = "cpu"
|
||||
elif device == "mps" and not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available():
|
||||
logger.warning("MPS not available, falling back to CPU")
|
||||
device = "cpu"
|
||||
|
||||
dtype = torch.float16 if device != "cpu" else torch.float32
|
||||
|
||||
# Load model
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
settings.PARSER_MODEL_NAME,
|
||||
device_map=device if device != "cpu" else None,
|
||||
torch_dtype=dtype,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
if device == "cpu":
|
||||
model = model.to("cpu")
|
||||
|
||||
# Store model and processor
|
||||
_model = {
|
||||
"model": model,
|
||||
"processor": processor,
|
||||
"device": device
|
||||
}
|
||||
|
||||
logger.info(f"Model loaded successfully on device: {device}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}", exc_info=True)
|
||||
if not settings.ALLOW_DUMMY_FALLBACK:
|
||||
raise
|
||||
_model = None
|
||||
|
||||
except ImportError as e:
|
||||
|
||||
150
services/parser-service/app/runtime/model_output_parser.py
Normal file
150
services/parser-service/app/runtime/model_output_parser.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Parser for dots.ocr model output
|
||||
Converts model output to structured blocks
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_model_output_to_blocks(
    model_output: str,
    image_size: tuple[int, int],
    page_num: int
) -> List[Dict[str, Any]]:
    """
    Parse dots.ocr model output into structured blocks.

    The model may emit structured JSON or plain text.  JSON is tried first;
    if the output is not a recognized JSON structure, a simple line-based
    heuristic is applied ('#'-prefixed lines become headings, consecutive
    plain lines are merged into paragraphs).  Any unexpected failure falls
    back to a single full-page paragraph so a parse error never crashes the
    page pipeline.

    Args:
        model_output: Raw text output from model (may be JSON or plain text)
        image_size: (width, height) of the image
        page_num: Page number (currently unused; kept for callers and
            future per-page metadata)

    Returns:
        List of block dictionaries with "type", "text", "bbox" and
        "reading_order" keys
    """
    try:
        # Try structured JSON first (if the model outputs it directly).
        json_blocks = _blocks_from_json(model_output)
        if json_blocks is not None:
            return json_blocks

        # Plain-text heuristic parse.
        # NOTE(review): heuristic placeholder — adjust once the actual
        # dots.ocr output format is confirmed.
        blocks = _blocks_from_plain_text(model_output, image_size)

        # If nothing was extracted, keep all the raw text in one block.
        if not blocks:
            blocks = [_fallback_block(model_output, image_size)]

    except Exception as e:
        logger.error(f"Error parsing model output: {e}", exc_info=True)
        # Defensive fallback: single block with the raw output.
        blocks = [_fallback_block(model_output, image_size)]

    return blocks


def _blocks_from_json(model_output: str) -> Optional[List[Dict[str, Any]]]:
    """Return blocks if *model_output* is recognized structured JSON, else None."""
    try:
        data = json.loads(model_output)
    except json.JSONDecodeError:
        return None
    if isinstance(data, dict) and "blocks" in data:
        # Model outputs a {"blocks": [...]} envelope.
        return data["blocks"]
    if isinstance(data, list):
        # Model outputs a bare list of blocks.
        return data
    # Valid JSON but not a recognized shape: treat as plain text.
    return None


def _blocks_from_plain_text(
    model_output: str,
    image_size: tuple[int, int],
) -> List[Dict[str, Any]]:
    """
    Heuristic line parser: lines starting with '#' open heading blocks,
    other lines open or extend paragraph blocks.
    """
    blocks: List[Dict[str, Any]] = []
    current_block: Optional[Dict[str, Any]] = None
    reading_order = 1

    for raw_line in model_output.strip().split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        if line.startswith('#'):
            # Close any open block and start a heading.
            if current_block:
                blocks.append(current_block)
            current_block = _make_block(
                "heading", line.lstrip('#').strip(), image_size, reading_order
            )
            reading_order += 1
        elif current_block and current_block["type"] == "paragraph":
            # Continuation line: merge into the open paragraph.
            current_block["text"] += " " + line
        else:
            # Close any open block and start a paragraph.
            if current_block:
                blocks.append(current_block)
            current_block = _make_block("paragraph", line, image_size, reading_order)
            reading_order += 1

    # Flush the last open block.
    if current_block:
        blocks.append(current_block)
    return blocks


def _make_block(
    block_type: str,
    text: str,
    image_size: tuple[int, int],
    reading_order: int,
) -> Dict[str, Any]:
    """Build a block dict with the placeholder bbox heuristic (y = order * 30)."""
    return {
        "type": block_type,
        "text": text,
        "bbox": {
            "x": 0,
            "y": reading_order * 30,  # placeholder vertical position
            "width": image_size[0],
            "height": 30
        },
        "reading_order": reading_order
    }


def _fallback_block(
    model_output: Optional[str],
    image_size: tuple[int, int],
) -> Dict[str, Any]:
    """Single full-page paragraph holding the raw output (or empty text)."""
    return {
        "type": "paragraph",
        "text": model_output.strip() if model_output else "",
        "bbox": {
            "x": 0,
            "y": 0,
            "width": image_size[0],
            "height": image_size[1]
        },
        "reading_order": 1
    }
|
||||
|
||||
|
||||
def extract_layout_info(model_output: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Placeholder for extracting layout information from structured model output.

    Args:
        model_output: Model output dictionary

    Returns:
        Layout info dictionary, or None when no layout data is available

    NOTE: the real dots.ocr output schema is not known here; customize this
    function once the actual format is pinned down.  Until then no layout
    information is reported.
    """
    return None
|
||||
|
||||
Reference in New Issue
Block a user