feat: add tests and integrate dots.ocr model

G.2.5 - Tests:
- Add pytest test suite with fixtures
- test_preprocessing.py - PDF/image loading, normalization, validation
- test_postprocessing.py - chunks, QA pairs, markdown generation
- test_inference.py - dummy parser and inference functions
- test_api.py - API endpoint tests
- Add pytest.ini configuration

G.1.3 - dots.ocr Integration:
- Update model_loader.py with real model loading code
  - Support for AutoModelForVision2Seq and AutoProcessor
  - Device handling (CUDA/CPU/MPS) with fallback
  - Error handling with dummy fallback option
- Update inference.py with real model inference
  - Process images through model
  - Generate and decode outputs
  - Parse model output to blocks
- Add model_output_parser.py
  - Parse JSON or plain text model output
  - Convert to structured blocks
  - Layout detection support (placeholder)

Dependencies:
- Add pytest, pytest-asyncio, httpx for testing
This commit is contained in:
Apple
2025-11-15 13:25:01 -08:00
parent 62cb1d2108
commit 2a353040f6
11 changed files with 848 additions and 47 deletions

View File

@@ -6,6 +6,7 @@ import logging
from typing import Literal, Optional, List
from pathlib import Path
import torch
from PIL import Image
from app.schemas import ParsedDocument, ParsedPage, ParsedBlock, BBox
@@ -14,6 +15,7 @@ from app.runtime.preprocessing import (
convert_pdf_to_images, load_image, prepare_images_for_model
)
from app.runtime.postprocessing import build_parsed_document
from app.runtime.model_output_parser import parse_model_output_to_blocks
from app.core.config import settings
logger = logging.getLogger(__name__)
@@ -63,36 +65,46 @@ def parse_document_from_images(
for idx, image in enumerate(prepared_images, start=1):
try:
# TODO: Implement actual inference with dots.ocr
# Example:
# inputs = model["processor"](images=image, return_tensors="pt")
# outputs = model["model"].generate(**inputs)
# text = model["processor"].decode(outputs[0], skip_special_tokens=True)
#
# # Parse model output into blocks
# blocks = parse_model_output_to_blocks(text, image.size)
#
# pages_data.append({
# "blocks": blocks,
# "width": image.width,
# "height": image.height
# })
# Prepare inputs for model
inputs = model["processor"](images=image, return_tensors="pt")
# Move inputs to device
device = model["device"]
if device != "cpu":
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v
for k, v in inputs.items()}
# Generate output
with torch.no_grad():
outputs = model["model"].generate(
**inputs,
max_new_tokens=2048, # Adjust based on model capabilities
do_sample=False # Deterministic output
)
# Decode output
generated_text = model["processor"].decode(
outputs[0],
skip_special_tokens=True
)
logger.debug(f"Model output for page {idx}: {generated_text[:100]}...")
# Parse model output into blocks
blocks = parse_model_output_to_blocks(
generated_text,
image.size,
page_num=idx
)
# For now, use dummy for each page
logger.debug(f"Processing page {idx} with model (placeholder)")
pages_data.append({
"blocks": [
{
"type": "paragraph",
"text": f"Page {idx} content (model output placeholder)",
"bbox": {"x": 0, "y": 0, "width": image.width, "height": image.height},
"reading_order": 1
}
],
"blocks": blocks,
"width": image.width,
"height": image.height
})
logger.info(f"Processed page {idx}/{len(prepared_images)}")
except Exception as e:
logger.error(f"Error processing page {idx}: {e}", exc_info=True)
# Continue with other pages

View File

@@ -36,30 +36,62 @@ def load_model() -> Optional[object]:
logger.info(f"Device: {settings.PARSER_DEVICE}")
try:
# TODO: Implement actual model loading
# Example for dots.ocr (adjust based on actual model structure):
# from transformers import AutoModelForVision2Seq, AutoProcessor
#
# processor = AutoProcessor.from_pretrained(settings.PARSER_MODEL_NAME)
# model = AutoModelForVision2Seq.from_pretrained(
# settings.PARSER_MODEL_NAME,
# device_map=settings.PARSER_DEVICE if settings.PARSER_DEVICE != "cpu" else None,
# torch_dtype=torch.float16 if settings.PARSER_DEVICE != "cpu" else torch.float32
# )
#
# if settings.PARSER_DEVICE == "cpu":
# model = model.to("cpu")
#
# _model = {
# "model": model,
# "processor": processor,
# "device": settings.PARSER_DEVICE
# }
#
# logger.info("Model loaded successfully")
# Load dots.ocr model
# Note: Adjust imports and model class based on actual dots.ocr implementation
# This is a template that should work with most Vision-Language models
# For now, return None (will use dummy parser)
logger.warning("Model loading not yet implemented, will use dummy parser")
try:
from transformers import AutoModelForVision2Seq, AutoProcessor
import torch
except ImportError:
logger.error("transformers or torch not installed. Install with: pip install transformers torch")
if not settings.ALLOW_DUMMY_FALLBACK:
raise
return None
logger.info(f"Loading model from: {settings.PARSER_MODEL_NAME}")
# Load processor
processor = AutoProcessor.from_pretrained(
settings.PARSER_MODEL_NAME,
trust_remote_code=True # If model has custom code
)
# Determine device and dtype
device = settings.PARSER_DEVICE
if device == "cuda" and not torch.cuda.is_available():
logger.warning("CUDA not available, falling back to CPU")
device = "cpu"
elif device == "mps" and not hasattr(torch.backends, "mps") or not torch.backends.mps.is_available():
logger.warning("MPS not available, falling back to CPU")
device = "cpu"
dtype = torch.float16 if device != "cpu" else torch.float32
# Load model
model = AutoModelForVision2Seq.from_pretrained(
settings.PARSER_MODEL_NAME,
device_map=device if device != "cpu" else None,
torch_dtype=dtype,
trust_remote_code=True
)
if device == "cpu":
model = model.to("cpu")
# Store model and processor
_model = {
"model": model,
"processor": processor,
"device": device
}
logger.info(f"Model loaded successfully on device: {device}")
except Exception as e:
logger.error(f"Failed to load model: {e}", exc_info=True)
if not settings.ALLOW_DUMMY_FALLBACK:
raise
_model = None
except ImportError as e:

View File

@@ -0,0 +1,150 @@
"""
Parser for dots.ocr model output
Converts model output to structured blocks
"""
import logging
import json
from typing import List, Dict, Any, Optional
from PIL import Image
logger = logging.getLogger(__name__)
def parse_model_output_to_blocks(
    model_output: str,
    image_size: tuple[int, int],
    page_num: int,
) -> List[Dict[str, Any]]:
    """
    Parse dots.ocr model output into structured blocks.

    The model may emit structured JSON (a dict with a "blocks" key, or a
    bare list of blocks) or plain text.  JSON blocks are passed through
    as-is; plain text is segmented heuristically: lines starting with '#'
    become heading blocks, consecutive other lines merge into paragraphs.

    Args:
        model_output: Raw text output from the model (JSON or plain text).
        image_size: (width, height) of the page image in pixels.
        page_num: 1-based page number (currently unused; kept for API
            stability and future layout support).

    Returns:
        List of block dicts with "type", "text", "bbox" and
        "reading_order" keys.  Never empty: if nothing can be segmented,
        a single full-page paragraph holding the raw output is returned.
    """
    try:
        json_blocks = _blocks_from_json(model_output)
        if json_blocks is not None:
            return json_blocks

        blocks = _blocks_from_plain_text(model_output, image_size)
        if blocks:
            return blocks
        # Nothing segmented (e.g. empty output): one full-page paragraph.
        return [_fallback_block(model_output, image_size)]
    except Exception as e:
        logger.error(f"Error parsing model output: {e}", exc_info=True)
        # Fallback: never propagate parse errors; return the raw output.
        return [_fallback_block(model_output, image_size)]


def _blocks_from_json(model_output: str) -> Optional[List[Dict[str, Any]]]:
    """Return the block list if *model_output* is structured JSON, else None."""
    try:
        data = json.loads(model_output)
    except json.JSONDecodeError:
        return None
    if isinstance(data, dict) and "blocks" in data:
        # Model emitted {"blocks": [...], ...}.
        return data["blocks"]
    if isinstance(data, list):
        # Model emitted a bare list of blocks.
        return data
    # Valid JSON but not a recognized shape; treat as plain text upstream.
    return None


def _blocks_from_plain_text(
    model_output: str,
    image_size: tuple[int, int],
) -> List[Dict[str, Any]]:
    """
    Heuristically segment plain text into heading/paragraph blocks.

    Simple heuristic — adjust once the actual dots.ocr output format is
    known: '#'-prefixed lines are headings; runs of other non-empty lines
    are joined into a single paragraph.
    """
    blocks: List[Dict[str, Any]] = []
    current: Optional[Dict[str, Any]] = None
    reading_order = 1
    width = image_size[0]

    for raw_line in model_output.strip().split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('#'):
            # Heading always starts a new block.
            if current:
                blocks.append(current)
            current = _make_block(
                "heading", line.lstrip('#').strip(), width, reading_order
            )
            reading_order += 1
        elif current and current["type"] == "paragraph":
            # Continuation line: merge into the open paragraph.
            current["text"] += " " + line
        else:
            if current:
                blocks.append(current)
            current = _make_block("paragraph", line, width, reading_order)
            reading_order += 1

    if current:
        blocks.append(current)
    return blocks


def _make_block(
    block_type: str, text: str, width: int, order: int
) -> Dict[str, Any]:
    """Build a block dict with a synthetic 30px-tall bbox (no real layout yet)."""
    return {
        "type": block_type,
        "text": text,
        "bbox": {"x": 0, "y": order * 30, "width": width, "height": 30},
        "reading_order": order,
    }


def _fallback_block(
    model_output: str, image_size: tuple[int, int]
) -> Dict[str, Any]:
    """Single full-page paragraph used when segmentation yields nothing."""
    return {
        "type": "paragraph",
        "text": model_output.strip() if model_output else "",
        "bbox": {
            "x": 0,
            "y": 0,
            "width": image_size[0],
            "height": image_size[1],
        },
        "reading_order": 1,
    }
def extract_layout_info(model_output: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Pull layout metadata out of a model output dictionary.

    Placeholder: the real dots.ocr layout schema is not wired up yet, so
    no layout information is ever produced.

    Args:
        model_output: Raw model output dictionary.

    Returns:
        None — layout extraction is not implemented.
    """
    # Customize once the actual dots.ocr output format is known.
    return None