feat: integrate dots.ocr native prompt modes and 2-stage qa_pairs pipeline
Prompt Modes Integration: - Create local_runtime.py with DOTS_PROMPT_MAP - Map OutputMode to native dots.ocr prompt modes (prompt_layout_all_en, prompt_ocr, etc.) - Support dict_promptmode_to_prompt from dots.ocr with fallback prompts - Add layout_only and region modes to OutputMode enum 2-Stage Q&A Pipeline: - Create qa_builder.py for 2-stage qa_pairs generation - Stage 1: PARSER (dots.ocr) → raw JSON via prompt_layout_all_en - Stage 2: LLM (DAGI Router) → Q&A pairs via mode=qa_build - Update endpoints.py to use 2-stage pipeline for qa_pairs mode - Add ROUTER_BASE_URL and ROUTER_TIMEOUT to config Updates: - Update inference.py to use local_runtime with native prompts - Update ollama_client.py to use same prompt map - Add PROMPT_MODES.md documentation
This commit is contained in:
273
services/parser-service/app/runtime/local_runtime.py
Normal file
273
services/parser-service/app/runtime/local_runtime.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Local runtime for dots.ocr model with native prompt modes
|
||||
Maps OutputMode to dots.ocr prompt modes using dict_promptmode_to_prompt
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import logging
|
||||
from typing import Literal, Optional
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
from qwen_vl_utils import process_vision_info
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import dots.ocr prompt dictionary
|
||||
try:
|
||||
from dots_ocr.utils.prompts import dict_promptmode_to_prompt
|
||||
DOTS_PROMPTS_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"dots_ocr.utils.prompts not available. "
|
||||
"Using fallback prompts. Install dots.ocr package for native prompt modes."
|
||||
)
|
||||
DOTS_PROMPTS_AVAILABLE = False
|
||||
dict_promptmode_to_prompt = {}
|
||||
|
||||
# Map OutputMode to dots.ocr native prompt modes
|
||||
DOTS_PROMPT_MAP = {
|
||||
"raw_json": "prompt_layout_all_en", # Full JSON (layout + content)
|
||||
"markdown": "prompt_ocr", # Content-oriented OCR (Markdown)
|
||||
"qa_pairs": "prompt_layout_all_en", # Full JSON, then 2nd step LLM
|
||||
"chunks": "prompt_layout_all_en", # Full JSON for chunking
|
||||
"layout_only": "prompt_layout_only_en", # Layout only (bbox + categories, no text)
|
||||
"region": "prompt_grounding_ocr", # Targeted region parsing (grounding)
|
||||
}
|
||||
|
||||
# Fallback prompts if dict_promptmode_to_prompt is not available
|
||||
FALLBACK_PROMPTS = {
|
||||
"prompt_layout_all_en": (
|
||||
"You are a document OCR and layout parser. "
|
||||
"Extract all text, tables, formulas, and layout into a clean JSON structure with fields like "
|
||||
"`blocks`, `tables`, `reading_order`, including bounding boxes and page numbers. "
|
||||
"Respond with JSON only, no explanations."
|
||||
),
|
||||
"prompt_ocr": (
|
||||
"You are a document OCR and layout parser. "
|
||||
"Extract the document as Markdown, preserving headings, paragraphs, and tables. "
|
||||
"Tables should be proper GitHub-flavored Markdown tables. "
|
||||
"Respond with Markdown as plain text."
|
||||
),
|
||||
"prompt_layout_only_en": (
|
||||
"You are a document layout parser. "
|
||||
"Extract only the layout structure (bounding boxes, block types, reading order) "
|
||||
"without the text content. "
|
||||
"Respond with JSON containing only layout information (bbox, type, reading_order)."
|
||||
),
|
||||
"prompt_grounding_ocr": (
|
||||
"You are a document OCR assistant for targeted region parsing. "
|
||||
"Extract text and layout for the specified region of the document. "
|
||||
"Respond with JSON containing the parsed content for the region."
|
||||
),
|
||||
}
|
||||
|
||||
# Global model instance
|
||||
_model: Optional[dict] = None
|
||||
_processor: Optional[object] = None
|
||||
|
||||
# Model configuration
|
||||
MODEL_PATH = settings.PARSER_MODEL_NAME
|
||||
DEVICE = settings.PARSER_DEVICE
|
||||
DTYPE = torch.bfloat16 if DEVICE != "cpu" else torch.float32
|
||||
MAX_NEW_TOKENS = int(os.getenv("DOTS_OCR_MAX_NEW_TOKENS", "24000"))
|
||||
|
||||
|
||||
def load_model():
|
||||
"""Load dots.ocr model with lazy initialization"""
|
||||
global _model, _processor
|
||||
|
||||
if _model is not None and _processor is not None:
|
||||
return _model, _processor
|
||||
|
||||
logger.info(f"Loading dots.ocr model: {MODEL_PATH}")
|
||||
logger.info(f"Device: {DEVICE}")
|
||||
|
||||
try:
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
MODEL_PATH,
|
||||
attn_implementation="flash_attention_2",
|
||||
torch_dtype=DTYPE,
|
||||
device_map="auto",
|
||||
trust_remote_code=True,
|
||||
low_cpu_mem_usage=True,
|
||||
)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
MODEL_PATH,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
if DEVICE == "cuda" and torch.cuda.is_available():
|
||||
model.to("cuda")
|
||||
elif DEVICE == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
model.to("mps")
|
||||
|
||||
_model = model
|
||||
_processor = processor
|
||||
|
||||
logger.info(f"Model loaded successfully on {DEVICE}")
|
||||
return _model, _processor
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def get_model():
|
||||
"""Get loaded model instance"""
|
||||
if _model is None or _processor is None:
|
||||
return load_model()
|
||||
return _model, _processor
|
||||
|
||||
|
||||
def _build_prompt(output_mode: str) -> str:
|
||||
"""
|
||||
Build prompt for dots.ocr based on OutputMode
|
||||
|
||||
Args:
|
||||
output_mode: One of "raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"
|
||||
|
||||
Returns:
|
||||
Prompt string for dots.ocr
|
||||
"""
|
||||
prompt_key = DOTS_PROMPT_MAP.get(output_mode, "prompt_layout_all_en")
|
||||
|
||||
# Try to use native dots.ocr prompts
|
||||
if DOTS_PROMPTS_AVAILABLE and prompt_key in dict_promptmode_to_prompt:
|
||||
prompt = dict_promptmode_to_prompt[prompt_key]
|
||||
logger.debug(f"Using native dots.ocr prompt: {prompt_key}")
|
||||
return prompt
|
||||
|
||||
# Fallback to our prompts
|
||||
if prompt_key in FALLBACK_PROMPTS:
|
||||
logger.debug(f"Using fallback prompt: {prompt_key}")
|
||||
return FALLBACK_PROMPTS[prompt_key]
|
||||
|
||||
# Ultimate fallback
|
||||
logger.warning(f"Unknown prompt key: {prompt_key}, using default")
|
||||
return FALLBACK_PROMPTS["prompt_layout_all_en"]
|
||||
|
||||
|
||||
def _build_messages(image_path: str, prompt: str) -> list:
|
||||
"""Build messages for dots.ocr model"""
|
||||
return [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": image_path},
|
||||
{"type": "text", "text": prompt},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _generate_from_path(
|
||||
image_path: str,
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"]
|
||||
) -> str:
|
||||
"""
|
||||
Generate output from image path using dots.ocr model
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
output_mode: Output mode (maps to dots.ocr prompt mode)
|
||||
|
||||
Returns:
|
||||
Generated text from model
|
||||
"""
|
||||
model, processor = get_model()
|
||||
prompt = _build_prompt(output_mode)
|
||||
messages = _build_messages(image_path, prompt)
|
||||
|
||||
# Apply chat template
|
||||
text = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
# Process vision info
|
||||
image_inputs, video_inputs = process_vision_info(messages)
|
||||
|
||||
# Prepare inputs
|
||||
inputs = processor(
|
||||
text=[text],
|
||||
images=image_inputs,
|
||||
videos=video_inputs,
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Move to device
|
||||
device = DEVICE
|
||||
if device == "cuda" and torch.cuda.is_available():
|
||||
inputs = {k: v.to("cuda") if isinstance(v, torch.Tensor) else v
|
||||
for k, v in inputs.items()}
|
||||
elif device == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
|
||||
inputs = {k: v.to("mps") if isinstance(v, torch.Tensor) else v
|
||||
for k, v in inputs.items()}
|
||||
|
||||
# Generate
|
||||
with torch.inference_mode():
|
||||
generated_ids = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=MAX_NEW_TOKENS,
|
||||
)
|
||||
|
||||
# Trim input tokens
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids):]
|
||||
for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
|
||||
]
|
||||
|
||||
# Decode
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False,
|
||||
)
|
||||
|
||||
return output_text[0]
|
||||
|
||||
|
||||
def parse_document_with_local(
|
||||
image_bytes: bytes,
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"] = "raw_json"
|
||||
) -> str:
|
||||
"""
|
||||
Parse document from image bytes using local dots.ocr model
|
||||
|
||||
Args:
|
||||
image_bytes: Image bytes (PNG/JPEG)
|
||||
output_mode: Output mode (maps to dots.ocr prompt mode)
|
||||
- raw_json: Full JSON (layout + content) via prompt_layout_all_en
|
||||
- markdown: Markdown text via prompt_ocr
|
||||
- qa_pairs: Full JSON (same as raw_json), then 2nd step LLM
|
||||
- chunks: Full JSON for chunking
|
||||
- layout_only: Layout only (bbox + categories) via prompt_layout_only_en
|
||||
- region: Targeted region parsing via prompt_grounding_ocr
|
||||
|
||||
Returns:
|
||||
Generated text from model (JSON or Markdown depending on mode)
|
||||
|
||||
Note:
|
||||
For "qa_pairs" mode, this returns full JSON.
|
||||
The 2nd step (LLM Q&A generation) should be done separately.
|
||||
"""
|
||||
# Save to temporary file
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
|
||||
tmp_path = f.name
|
||||
f.write(image_bytes)
|
||||
|
||||
try:
|
||||
return _generate_from_path(tmp_path, output_mode)
|
||||
finally:
|
||||
try:
|
||||
os.remove(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user