feat: add qa_build mode, tests, and region mode support
Router Configuration:
- Add mode='qa_build' routing rule in router-config.yml
- Priority 8, uses local_qwen3_8b for Q&A generation

2-Stage Q&A Pipeline Tests:
- Create test_qa_pipeline.py with comprehensive tests
- Test prompt building, JSON parsing, router integration
- Mock DAGI Router responses for testing

Region Mode (Grounding OCR):
- Add region_bbox and region_page parameters to ParseRequest
- Support region mode in local_runtime with bbox in prompt
- Update endpoints to accept region parameters (x, y, width, height, page)
- Validate region parameters and filter pages for region mode
- Pass region_bbox through inference pipeline

Updates:
- Update local_runtime to support region_bbox in prompts
- Update inference.py to pass region_bbox to local_runtime
- Update endpoints.py to handle region mode parameters
This commit is contained in:
@@ -6,7 +6,7 @@ Maps OutputMode to dots.ocr prompt modes using dict_promptmode_to_prompt
|
||||
import os
|
||||
import tempfile
|
||||
import logging
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, Dict, Any
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
@@ -124,18 +124,29 @@ def get_model():
|
||||
return _model, _processor
|
||||
|
||||
|
||||
def _build_prompt(output_mode: str) -> str:
|
||||
def _build_prompt(output_mode: str, region_bbox: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
Build prompt for dots.ocr based on OutputMode
|
||||
|
||||
Args:
|
||||
output_mode: One of "raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"
|
||||
region_bbox: Optional bounding box for region mode {"x": float, "y": float, "width": float, "height": float}
|
||||
|
||||
Returns:
|
||||
Prompt string for dots.ocr
|
||||
"""
|
||||
prompt_key = DOTS_PROMPT_MAP.get(output_mode, "prompt_layout_all_en")
|
||||
|
||||
# For region mode, add bbox information to prompt
|
||||
if output_mode == "region" and region_bbox:
|
||||
base_prompt = FALLBACK_PROMPTS.get("prompt_grounding_ocr", "")
|
||||
region_info = (
|
||||
f"\n\nExtract content from the specified region:\n"
|
||||
f"Bounding box: x={region_bbox.get('x', 0)}, y={region_bbox.get('y', 0)}, "
|
||||
f"width={region_bbox.get('width', 0)}, height={region_bbox.get('height', 0)}"
|
||||
)
|
||||
return base_prompt + region_info
|
||||
|
||||
# Try to use native dots.ocr prompts
|
||||
if DOTS_PROMPTS_AVAILABLE and prompt_key in dict_promptmode_to_prompt:
|
||||
prompt = dict_promptmode_to_prompt[prompt_key]
|
||||
@@ -167,7 +178,8 @@ def _build_messages(image_path: str, prompt: str) -> list:
|
||||
|
||||
def _generate_from_path(
|
||||
image_path: str,
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"]
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"],
|
||||
region_bbox: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate output from image path using dots.ocr model
|
||||
@@ -180,7 +192,7 @@ def _generate_from_path(
|
||||
Generated text from model
|
||||
"""
|
||||
model, processor = get_model()
|
||||
prompt = _build_prompt(output_mode)
|
||||
prompt = _build_prompt(output_mode, region_bbox)
|
||||
messages = _build_messages(image_path, prompt)
|
||||
|
||||
# Apply chat template
|
||||
@@ -236,7 +248,8 @@ def _generate_from_path(
|
||||
|
||||
def parse_document_with_local(
|
||||
image_bytes: bytes,
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"] = "raw_json"
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"] = "raw_json",
|
||||
region_bbox: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Parse document from image bytes using local dots.ocr model
|
||||
@@ -264,7 +277,7 @@ def parse_document_with_local(
|
||||
f.write(image_bytes)
|
||||
|
||||
try:
|
||||
return _generate_from_path(tmp_path, output_mode)
|
||||
return _generate_from_path(tmp_path, output_mode, region_bbox)
|
||||
finally:
|
||||
try:
|
||||
os.remove(tmp_path)
|
||||
|
||||
Reference in New Issue
Block a user