Router Configuration:
- Add mode='qa_build' routing rule in router-config.yml
- Priority 8, uses local_qwen3_8b for Q&A generation

2-Stage Q&A Pipeline Tests:
- Create test_qa_pipeline.py with comprehensive tests
- Test prompt building, JSON parsing, router integration
- Mock DAGI Router responses for testing

Region Mode (Grounding OCR):
- Add region_bbox and region_page parameters to ParseRequest
- Support region mode in local_runtime with bbox in prompt
- Update endpoints to accept region parameters (x, y, width, height, page)
- Validate region parameters and filter pages for region mode
- Pass region_bbox through inference pipeline

Updates:
- Update local_runtime to support region_bbox in prompts
- Update inference.py to pass region_bbox to local_runtime
- Update endpoints.py to handle region mode parameters
287 lines
9.4 KiB
Python
287 lines
9.4 KiB
Python
"""
|
|
Local runtime for dots.ocr model with native prompt modes
|
|
Maps OutputMode to dots.ocr prompt modes using dict_promptmode_to_prompt
|
|
"""
|
|
|
|
import os
|
|
import tempfile
|
|
import logging
|
|
from typing import Literal, Optional, Dict, Any
|
|
|
|
import torch
|
|
from transformers import AutoModelForVision2Seq, AutoProcessor
|
|
from qwen_vl_utils import process_vision_info
|
|
|
|
from app.core.config import settings
|
|
|
|
logger = logging.getLogger(__name__)

# Prefer the prompt table shipped with the dots.ocr package. When the package
# is missing, expose an empty table and flag it so callers use local fallbacks.
try:
    from dots_ocr.utils.prompts import dict_promptmode_to_prompt
except ImportError:
    DOTS_PROMPTS_AVAILABLE = False
    dict_promptmode_to_prompt = {}
    logger.warning(
        "dots_ocr.utils.prompts not available. "
        "Using fallback prompts. Install dots.ocr package for native prompt modes."
    )
else:
    DOTS_PROMPTS_AVAILABLE = True
|
|
|
|
# Map OutputMode to dots.ocr native prompt modes
|
|
DOTS_PROMPT_MAP = {
|
|
"raw_json": "prompt_layout_all_en", # Full JSON (layout + content)
|
|
"markdown": "prompt_ocr", # Content-oriented OCR (Markdown)
|
|
"qa_pairs": "prompt_layout_all_en", # Full JSON, then 2nd step LLM
|
|
"chunks": "prompt_layout_all_en", # Full JSON for chunking
|
|
"layout_only": "prompt_layout_only_en", # Layout only (bbox + categories, no text)
|
|
"region": "prompt_grounding_ocr", # Targeted region parsing (grounding)
|
|
}
|
|
|
|
# Fallback prompts if dict_promptmode_to_prompt is not available
|
|
FALLBACK_PROMPTS = {
|
|
"prompt_layout_all_en": (
|
|
"You are a document OCR and layout parser. "
|
|
"Extract all text, tables, formulas, and layout into a clean JSON structure with fields like "
|
|
"`blocks`, `tables`, `reading_order`, including bounding boxes and page numbers. "
|
|
"Respond with JSON only, no explanations."
|
|
),
|
|
"prompt_ocr": (
|
|
"You are a document OCR and layout parser. "
|
|
"Extract the document as Markdown, preserving headings, paragraphs, and tables. "
|
|
"Tables should be proper GitHub-flavored Markdown tables. "
|
|
"Respond with Markdown as plain text."
|
|
),
|
|
"prompt_layout_only_en": (
|
|
"You are a document layout parser. "
|
|
"Extract only the layout structure (bounding boxes, block types, reading order) "
|
|
"without the text content. "
|
|
"Respond with JSON containing only layout information (bbox, type, reading_order)."
|
|
),
|
|
"prompt_grounding_ocr": (
|
|
"You are a document OCR assistant for targeted region parsing. "
|
|
"Extract text and layout for the specified region of the document. "
|
|
"Respond with JSON containing the parsed content for the region."
|
|
),
|
|
}
|
|
|
|
# Lazily-initialised model/processor singletons; load_model() populates them.
# They hold a transformers model and processor (not dicts), so type them as Any.
_model: Optional[Any] = None
_processor: Optional[Any] = None

# Model configuration (read once at import time)
MODEL_PATH = settings.PARSER_MODEL_NAME
DEVICE = settings.PARSER_DEVICE
# bfloat16 for any non-"cpu" device setting, float32 on CPU.
DTYPE = torch.bfloat16 if DEVICE != "cpu" else torch.float32
# Generation length cap per call; overridable via DOTS_OCR_MAX_NEW_TOKENS.
MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "24000"))
|
|
|
|
|
|
def load_model():
    """
    Load the dots.ocr model and processor once and cache them at module level.

    Later calls return the cached pair without reloading.

    Device handling:
      - DEVICE == "cuda" with CUDA available: weights are placed by accelerate
        (device_map="auto") and FlashAttention 2 is used.
      - Otherwise the model is loaded on CPU with eager attention, then moved to
        MPS if DEVICE == "mps" and MPS is available.
      - The DOTS_OCR_ATTN_IMPL environment variable overrides the attention
        implementation.

    Returns:
        Tuple of (model, processor).

    Raises:
        Re-raises any exception from model/processor loading after logging it.
    """
    global _model, _processor

    if _model is not None and _processor is not None:
        return _model, _processor

    logger.info(f"Loading dots.ocr model: {MODEL_PATH}")
    logger.info(f"Device: {DEVICE}")

    use_cuda = DEVICE == "cuda" and torch.cuda.is_available()
    use_mps = (
        DEVICE == "mps"
        and hasattr(torch.backends, "mps")
        and torch.backends.mps.is_available()
    )

    # FlashAttention 2 only runs on CUDA GPUs; requesting it on CPU/MPS breaks
    # loading or inference, so fall back to the eager implementation there.
    attn_implementation = os.getenv("DOTS_OCR_ATTN_IMPL") or (
        "flash_attention_2" if use_cuda else "eager"
    )

    try:
        model = AutoModelForVision2Seq.from_pretrained(
            MODEL_PATH,
            attn_implementation=attn_implementation,
            torch_dtype=DTYPE,
            # Only let accelerate auto-place weights for CUDA. With "auto" on a
            # CPU/MPS configuration, an available GPU would still receive the
            # weights while inputs stay elsewhere (device mismatch).
            device_map="auto" if use_cuda else None,
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )

        processor = AutoProcessor.from_pretrained(
            MODEL_PATH,
            trust_remote_code=True
        )

        # No .to("cuda"): a device_map-dispatched model is already placed, and
        # moving it is redundant (single GPU) or breaks multi-device/offloaded maps.
        if use_mps:
            model.to("mps")

        _model = model
        _processor = processor

        logger.info(f"Model loaded successfully on {DEVICE}")
        return _model, _processor

    except Exception as e:
        logger.error(f"Failed to load model: {e}", exc_info=True)
        raise
|
|
|
|
|
|
def get_model():
    """Return the cached ``(model, processor)`` pair, loading it on first use."""
    if _model is not None and _processor is not None:
        return _model, _processor
    return load_model()
|
|
|
|
|
|
def _build_prompt(output_mode: str, region_bbox: Optional[Dict[str, Any]] = None) -> str:
    """
    Build the prompt for dots.ocr based on OutputMode.

    Resolution order:
      1. ``output_mode == "region"`` with a non-empty bbox: the local grounding
         prompt followed by the bbox rendered as text.
      2. The native dots.ocr prompt for the mapped key, if the package is installed.
      3. The matching entry in FALLBACK_PROMPTS.
      4. The full-layout fallback prompt.

    Args:
        output_mode: One of "raw_json", "markdown", "qa_pairs", "chunks",
            "layout_only", "region". Unknown modes are logged and use the
            full-layout prompt ("prompt_layout_all_en").
        region_bbox: Optional bounding box for region mode
            {"x": float, "y": float, "width": float, "height": float};
            missing keys are rendered as 0.

    Returns:
        Prompt string for dots.ocr.
    """
    if output_mode not in DOTS_PROMPT_MAP:
        # Surface unexpected modes instead of silently using the full-layout prompt.
        logger.warning(f"Unknown output_mode: {output_mode}, using prompt_layout_all_en")
    prompt_key = DOTS_PROMPT_MAP.get(output_mode, "prompt_layout_all_en")

    if output_mode == "region":
        if region_bbox:
            # NOTE(review): this always uses the local grounding prompt with an
            # x/y/width/height description, even when native dots.ocr prompts are
            # available. The native grounding prompt presumably expects its own
            # bbox format — confirm before switching to it.
            base_prompt = FALLBACK_PROMPTS.get("prompt_grounding_ocr", "")
            region_info = (
                f"\n\nExtract content from the specified region:\n"
                f"Bounding box: x={region_bbox.get('x', 0)}, y={region_bbox.get('y', 0)}, "
                f"width={region_bbox.get('width', 0)}, height={region_bbox.get('height', 0)}"
            )
            return base_prompt + region_info
        # Without a bbox the grounding prompt refers to a region it never names.
        logger.warning("Region mode requested without region_bbox; prompt will contain no bounding box")

    # Try to use native dots.ocr prompts
    if DOTS_PROMPTS_AVAILABLE and prompt_key in dict_promptmode_to_prompt:
        prompt = dict_promptmode_to_prompt[prompt_key]
        logger.debug(f"Using native dots.ocr prompt: {prompt_key}")
        return prompt

    # Fallback to our prompts
    if prompt_key in FALLBACK_PROMPTS:
        logger.debug(f"Using fallback prompt: {prompt_key}")
        return FALLBACK_PROMPTS[prompt_key]

    # Ultimate fallback (only reachable if the two tables drift apart)
    logger.warning(f"Unknown prompt key: {prompt_key}, using default")
    return FALLBACK_PROMPTS["prompt_layout_all_en"]
|
|
|
|
|
|
def _build_messages(image_path: str, prompt: str) -> list:
|
|
"""Build messages for dots.ocr model"""
|
|
return [
|
|
{
|
|
"role": "user",
|
|
"content": [
|
|
{"type": "image", "image": image_path},
|
|
{"type": "text", "text": prompt},
|
|
],
|
|
}
|
|
]
|
|
|
|
|
|
def _generate_from_path(
    image_path: str,
    output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"],
    region_bbox: Optional[Dict[str, Any]] = None
) -> str:
    """
    Run the dots.ocr model on one image file and return the decoded output.

    Args:
        image_path: Filesystem path of the page image.
        output_mode: Output mode; selects the prompt via ``_build_prompt``.
        region_bbox: Optional region bounding box, forwarded to
            ``_build_prompt`` (used there when ``output_mode == "region"``).

    Returns:
        Text generated for the image, with the prompt tokens stripped and
        special tokens removed.
    """
    model, processor = get_model()
    messages = _build_messages(image_path, _build_prompt(output_mode, region_bbox))

    chat_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    images, videos = process_vision_info(messages)
    batch = processor(
        text=[chat_text],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    )

    # Use the accelerator named by DEVICE only when it is actually usable;
    # otherwise the tensors stay where the processor created them.
    target = None
    if DEVICE == "cuda" and torch.cuda.is_available():
        target = "cuda"
    elif DEVICE == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        target = "mps"
    if target is not None:
        batch = {
            name: value.to(target) if isinstance(value, torch.Tensor) else value
            for name, value in batch.items()
        }

    with torch.inference_mode():
        output_ids = model.generate(**batch, max_new_tokens=MAX_NEW_TOKENS)

    # Each output row starts with the prompt tokens; keep only what follows them.
    completions = [
        full[len(prompt_ids):]
        for prompt_ids, full in zip(batch["input_ids"], output_ids)
    ]
    decoded = processor.batch_decode(
        completions,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    return decoded[0]
|
|
|
|
|
|
def parse_document_with_local(
    image_bytes: bytes,
    output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"] = "raw_json",
    region_bbox: Optional[Dict[str, Any]] = None
) -> str:
    """
    Parse document from image bytes using local dots.ocr model

    The bytes are written to a temporary ".png" file, which is always removed
    afterwards, even if writing or generation fails.

    Args:
        image_bytes: Image bytes (PNG/JPEG)
        output_mode: Output mode (maps to dots.ocr prompt mode)
            - raw_json: Full JSON (layout + content) via prompt_layout_all_en
            - markdown: Markdown text via prompt_ocr
            - qa_pairs: Full JSON (same as raw_json), then 2nd step LLM
            - chunks: Full JSON for chunking
            - layout_only: Layout only (bbox + categories) via prompt_layout_only_en
            - region: Targeted region parsing via prompt_grounding_ocr
        region_bbox: Optional bounding box for region mode
            {"x": float, "y": float, "width": float, "height": float}

    Returns:
        Generated text from model (JSON or Markdown depending on mode)

    Note:
        For "qa_pairs" mode, this returns full JSON.
        The 2nd step (LLM Q&A generation) should be done separately.
    """
    # mkstemp creates the file and returns its path. The write happens inside
    # the try, so a failed write no longer leaks the temp file.
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    try:
        # Close the handle before generation so the image can be reopened by path.
        with os.fdopen(fd, "wb") as f:
            f.write(image_bytes)
        return _generate_from_path(tmp_path, output_mode, region_bbox)
    finally:
        try:
            os.remove(tmp_path)
        except OSError as exc:
            # Best-effort cleanup: never mask the real result/error, but record it.
            logger.warning(f"Failed to remove temporary file {tmp_path}: {exc}")
|
|
|