feat: add qa_build mode, tests, and region mode support
Router Configuration: - Add mode='qa_build' routing rule in router-config.yml - Priority 8, uses local_qwen3_8b for Q&A generation 2-Stage Q&A Pipeline Tests: - Create test_qa_pipeline.py with comprehensive tests - Test prompt building, JSON parsing, router integration - Mock DAGI Router responses for testing Region Mode (Grounding OCR): - Add region_bbox and region_page parameters to ParseRequest - Support region mode in local_runtime with bbox in prompt - Update endpoints to accept region parameters (x, y, width, height, page) - Validate region parameters and filter pages for region mode - Pass region_bbox through inference pipeline Updates: - Update local_runtime to support region_bbox in prompts - Update inference.py to pass region_bbox to local_runtime - Update endpoints.py to handle region mode parameters
This commit is contained in:
@@ -95,6 +95,14 @@ routing:
|
||||
use_llm: local_qwen3_8b
|
||||
description: "microDAO chat → local LLM with RBAC context"
|
||||
|
||||
# Q&A Builder mode (for parser-service 2-stage pipeline)
|
||||
- id: qa_build_mode
|
||||
priority: 8
|
||||
when:
|
||||
mode: qa_build
|
||||
use_llm: local_qwen3_8b
|
||||
description: "Q&A generation from parsed documents → local LLM"
|
||||
|
||||
# NEW: CrewAI workflow orchestration
|
||||
- id: crew_mode
|
||||
priority: 3
|
||||
|
||||
@@ -34,7 +34,12 @@ async def parse_document_endpoint(
|
||||
doc_url: Optional[str] = Form(None),
|
||||
output_mode: str = Form("raw_json"),
|
||||
dao_id: Optional[str] = Form(None),
|
||||
doc_id: Optional[str] = Form(None)
|
||||
doc_id: Optional[str] = Form(None),
|
||||
region_bbox_x: Optional[float] = Form(None),
|
||||
region_bbox_y: Optional[float] = Form(None),
|
||||
region_bbox_width: Optional[float] = Form(None),
|
||||
region_bbox_height: Optional[float] = Form(None),
|
||||
region_page: Optional[int] = Form(None)
|
||||
):
|
||||
"""
|
||||
Parse document (PDF or image) using dots.ocr
|
||||
@@ -81,6 +86,30 @@ async def parse_document_endpoint(
|
||||
image = load_image(content)
|
||||
images = [image]
|
||||
|
||||
# For region mode, validate and prepare region bbox
|
||||
region_bbox = None
|
||||
if output_mode == "region":
|
||||
if not all([region_bbox_x is not None, region_bbox_y is not None,
|
||||
region_bbox_width is not None, region_bbox_height is not None]):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="region mode requires region_bbox_x, region_bbox_y, region_bbox_width, region_bbox_height"
|
||||
)
|
||||
region_bbox = {
|
||||
"x": float(region_bbox_x),
|
||||
"y": float(region_bbox_y),
|
||||
"width": float(region_bbox_width),
|
||||
"height": float(region_bbox_height)
|
||||
}
|
||||
# If region_page specified, only process that page
|
||||
if region_page is not None:
|
||||
if region_page < 1 or region_page > len(images):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"region_page {region_page} out of range (1-{len(images)})"
|
||||
)
|
||||
images = [images[region_page - 1]]
|
||||
|
||||
else:
|
||||
# TODO: Download from doc_url
|
||||
raise HTTPException(
|
||||
@@ -99,14 +128,16 @@ async def parse_document_endpoint(
|
||||
images=images,
|
||||
output_mode=output_mode,
|
||||
doc_id=doc_id or str(uuid.uuid4()),
|
||||
doc_type=doc_type
|
||||
doc_type=doc_type,
|
||||
region_bbox=region_bbox
|
||||
)
|
||||
else:
|
||||
parsed_doc = parse_document_from_images(
|
||||
images=images,
|
||||
output_mode=output_mode,
|
||||
doc_id=doc_id or str(uuid.uuid4()),
|
||||
doc_type=doc_type
|
||||
doc_type=doc_type,
|
||||
region_bbox=region_bbox
|
||||
)
|
||||
|
||||
# Build response based on output_mode
|
||||
|
||||
@@ -29,7 +29,8 @@ async def parse_document_with_ollama(
|
||||
images: List[Image.Image],
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"] = "raw_json",
|
||||
doc_id: Optional[str] = None,
|
||||
doc_type: Literal["pdf", "image"] = "image"
|
||||
doc_type: Literal["pdf", "image"] = "image",
|
||||
region_bbox: Optional[dict] = None
|
||||
) -> ParsedDocument:
|
||||
"""
|
||||
Parse document using Ollama API
|
||||
@@ -109,7 +110,8 @@ def parse_document_from_images(
|
||||
images: List[Image.Image],
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"] = "raw_json",
|
||||
doc_id: Optional[str] = None,
|
||||
doc_type: Literal["pdf", "image"] = "image"
|
||||
doc_type: Literal["pdf", "image"] = "image",
|
||||
region_bbox: Optional[dict] = None
|
||||
) -> ParsedDocument:
|
||||
"""
|
||||
Parse document from list of images using dots.ocr model
|
||||
@@ -159,7 +161,7 @@ def parse_document_from_images(
|
||||
image_bytes = buf.getvalue()
|
||||
|
||||
# Use local_runtime with native prompt modes
|
||||
generated_text = parse_document_with_local(image_bytes, output_mode)
|
||||
generated_text = parse_document_with_local(image_bytes, output_mode, region_bbox)
|
||||
|
||||
logger.debug(f"Model output for page {idx}: {generated_text[:100]}...")
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ Maps OutputMode to dots.ocr prompt modes using dict_promptmode_to_prompt
|
||||
import os
|
||||
import tempfile
|
||||
import logging
|
||||
from typing import Literal, Optional
|
||||
from typing import Literal, Optional, Dict, Any
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
@@ -124,18 +124,29 @@ def get_model():
|
||||
return _model, _processor
|
||||
|
||||
|
||||
def _build_prompt(output_mode: str) -> str:
|
||||
def _build_prompt(output_mode: str, region_bbox: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""
|
||||
Build prompt for dots.ocr based on OutputMode
|
||||
|
||||
Args:
|
||||
output_mode: One of "raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"
|
||||
region_bbox: Optional bounding box for region mode {"x": float, "y": float, "width": float, "height": float}
|
||||
|
||||
Returns:
|
||||
Prompt string for dots.ocr
|
||||
"""
|
||||
prompt_key = DOTS_PROMPT_MAP.get(output_mode, "prompt_layout_all_en")
|
||||
|
||||
# For region mode, add bbox information to prompt
|
||||
if output_mode == "region" and region_bbox:
|
||||
base_prompt = FALLBACK_PROMPTS.get("prompt_grounding_ocr", "")
|
||||
region_info = (
|
||||
f"\n\nExtract content from the specified region:\n"
|
||||
f"Bounding box: x={region_bbox.get('x', 0)}, y={region_bbox.get('y', 0)}, "
|
||||
f"width={region_bbox.get('width', 0)}, height={region_bbox.get('height', 0)}"
|
||||
)
|
||||
return base_prompt + region_info
|
||||
|
||||
# Try to use native dots.ocr prompts
|
||||
if DOTS_PROMPTS_AVAILABLE and prompt_key in dict_promptmode_to_prompt:
|
||||
prompt = dict_promptmode_to_prompt[prompt_key]
|
||||
@@ -167,7 +178,8 @@ def _build_messages(image_path: str, prompt: str) -> list:
|
||||
|
||||
def _generate_from_path(
|
||||
image_path: str,
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"]
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"],
|
||||
region_bbox: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate output from image path using dots.ocr model
|
||||
@@ -180,7 +192,7 @@ def _generate_from_path(
|
||||
Generated text from model
|
||||
"""
|
||||
model, processor = get_model()
|
||||
prompt = _build_prompt(output_mode)
|
||||
prompt = _build_prompt(output_mode, region_bbox)
|
||||
messages = _build_messages(image_path, prompt)
|
||||
|
||||
# Apply chat template
|
||||
@@ -236,7 +248,8 @@ def _generate_from_path(
|
||||
|
||||
def parse_document_with_local(
|
||||
image_bytes: bytes,
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"] = "raw_json"
|
||||
output_mode: Literal["raw_json", "markdown", "qa_pairs", "chunks", "layout_only", "region"] = "raw_json",
|
||||
region_bbox: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Parse document from image bytes using local dots.ocr model
|
||||
@@ -264,7 +277,7 @@ def parse_document_with_local(
|
||||
f.write(image_bytes)
|
||||
|
||||
try:
|
||||
return _generate_from_path(tmp_path, output_mode)
|
||||
return _generate_from_path(tmp_path, output_mode, region_bbox)
|
||||
finally:
|
||||
try:
|
||||
os.remove(tmp_path)
|
||||
|
||||
@@ -120,6 +120,9 @@ class ParseRequest(BaseModel):
|
||||
)
|
||||
dao_id: Optional[str] = Field(None, description="DAO ID")
|
||||
doc_id: Optional[str] = Field(None, description="Document ID")
|
||||
# Region mode parameters (for grounding OCR)
|
||||
region_bbox: Optional[BBox] = Field(None, description="Bounding box for region mode (x, y, width, height)")
|
||||
region_page: Optional[int] = Field(None, description="Page number for region mode")
|
||||
|
||||
|
||||
class ParseResponse(BaseModel):
|
||||
|
||||
197
services/parser-service/tests/test_qa_pipeline.py
Normal file
197
services/parser-service/tests/test_qa_pipeline.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Tests for 2-stage Q&A pipeline (PARSER → LLM)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, patch, MagicMock
|
||||
from PIL import Image
|
||||
import io
|
||||
import json
|
||||
|
||||
from app.schemas import ParsedDocument, ParsedPage, ParsedBlock, BBox
|
||||
from app.runtime.qa_builder import build_qa_pairs_via_router, _build_qa_prompt, _parse_qa_response
|
||||
|
||||
|
||||
class TestQABuilder:
|
||||
"""Tests for Q&A builder (2-stage pipeline)"""
|
||||
|
||||
def test_build_qa_prompt(self):
|
||||
"""Test prompt building for Q&A generation"""
|
||||
# Create mock parsed document
|
||||
doc = ParsedDocument(
|
||||
doc_id="test-doc",
|
||||
doc_type="pdf",
|
||||
pages=[
|
||||
ParsedPage(
|
||||
page_num=1,
|
||||
blocks=[
|
||||
ParsedBlock(
|
||||
type="heading",
|
||||
text="Test Document",
|
||||
bbox=BBox(x=0, y=0, width=800, height=50),
|
||||
reading_order=1,
|
||||
page_num=1
|
||||
),
|
||||
ParsedBlock(
|
||||
type="paragraph",
|
||||
text="This is a test document with some content.",
|
||||
bbox=BBox(x=0, y=60, width=800, height=100),
|
||||
reading_order=2,
|
||||
page_num=1
|
||||
)
|
||||
],
|
||||
width=800,
|
||||
height=600
|
||||
)
|
||||
],
|
||||
metadata={}
|
||||
)
|
||||
|
||||
prompt = _build_qa_prompt(doc)
|
||||
|
||||
# Check prompt structure
|
||||
assert "OCR-документу" in prompt
|
||||
assert "JSON-масив" in prompt
|
||||
assert "question" in prompt
|
||||
assert "answer" in prompt
|
||||
assert "source_page" in prompt
|
||||
assert "Test Document" in prompt or "test document" in prompt.lower()
|
||||
|
||||
def test_parse_qa_response_valid_json(self):
|
||||
"""Test parsing valid JSON response from LLM"""
|
||||
doc = ParsedDocument(
|
||||
doc_id="test-doc",
|
||||
doc_type="pdf",
|
||||
pages=[ParsedPage(page_num=1, blocks=[], width=800, height=600)],
|
||||
metadata={}
|
||||
)
|
||||
|
||||
response_text = json.dumps([
|
||||
{
|
||||
"question": "Що це за документ?",
|
||||
"answer": "Це тестовий документ",
|
||||
"source_page": 1,
|
||||
"confidence": 0.9
|
||||
},
|
||||
{
|
||||
"question": "Який контент?",
|
||||
"answer": "Тестовий контент",
|
||||
"source_page": 1
|
||||
}
|
||||
])
|
||||
|
||||
qa_pairs = _parse_qa_response(response_text, doc)
|
||||
|
||||
assert len(qa_pairs) == 2
|
||||
assert qa_pairs[0].question == "Що це за документ?"
|
||||
assert qa_pairs[0].answer == "Це тестовий документ"
|
||||
assert qa_pairs[0].source_page == 1
|
||||
assert qa_pairs[0].confidence == 0.9
|
||||
|
||||
def test_parse_qa_response_markdown_code_block(self):
|
||||
"""Test parsing JSON from markdown code block"""
|
||||
doc = ParsedDocument(
|
||||
doc_id="test-doc",
|
||||
doc_type="pdf",
|
||||
pages=[ParsedPage(page_num=1, blocks=[], width=800, height=600)],
|
||||
metadata={}
|
||||
)
|
||||
|
||||
response_text = "```json\n" + json.dumps([
|
||||
{
|
||||
"question": "Тест?",
|
||||
"answer": "Відповідь"
|
||||
}
|
||||
]) + "\n```"
|
||||
|
||||
qa_pairs = _parse_qa_response(response_text, doc)
|
||||
|
||||
assert len(qa_pairs) == 1
|
||||
assert qa_pairs[0].question == "Тест?"
|
||||
|
||||
def test_parse_qa_response_invalid_json(self):
|
||||
"""Test parsing invalid JSON (should return empty list)"""
|
||||
doc = ParsedDocument(
|
||||
doc_id="test-doc",
|
||||
doc_type="pdf",
|
||||
pages=[ParsedPage(page_num=1, blocks=[], width=800, height=600)],
|
||||
metadata={}
|
||||
)
|
||||
|
||||
response_text = "This is not JSON"
|
||||
|
||||
qa_pairs = _parse_qa_response(response_text, doc)
|
||||
|
||||
assert len(qa_pairs) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_qa_pairs_via_router_success(self):
|
||||
"""Test successful Q&A generation via DAGI Router"""
|
||||
# Create mock parsed document
|
||||
doc = ParsedDocument(
|
||||
doc_id="test-doc",
|
||||
doc_type="pdf",
|
||||
pages=[
|
||||
ParsedPage(
|
||||
page_num=1,
|
||||
blocks=[
|
||||
ParsedBlock(
|
||||
type="paragraph",
|
||||
text="Test content",
|
||||
bbox=BBox(x=0, y=0, width=800, height=100),
|
||||
reading_order=1,
|
||||
page_num=1
|
||||
)
|
||||
],
|
||||
width=800,
|
||||
height=600
|
||||
)
|
||||
],
|
||||
metadata={}
|
||||
)
|
||||
|
||||
# Mock router response
|
||||
mock_response = {
|
||||
"ok": True,
|
||||
"data": {
|
||||
"text": json.dumps([
|
||||
{
|
||||
"question": "Що це?",
|
||||
"answer": "Тест",
|
||||
"source_page": 1
|
||||
}
|
||||
])
|
||||
}
|
||||
}
|
||||
|
||||
with patch("app.runtime.qa_builder.httpx.AsyncClient") as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.post = AsyncMock(
|
||||
return_value=MagicMock(
|
||||
raise_for_status=MagicMock(),
|
||||
json=lambda: mock_response
|
||||
)
|
||||
)
|
||||
|
||||
qa_pairs = await build_qa_pairs_via_router(doc, dao_id="test-dao")
|
||||
|
||||
assert len(qa_pairs) == 1
|
||||
assert qa_pairs[0].question == "Що це?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_qa_pairs_via_router_failure(self):
|
||||
"""Test Q&A generation failure (should raise exception)"""
|
||||
doc = ParsedDocument(
|
||||
doc_id="test-doc",
|
||||
doc_type="pdf",
|
||||
pages=[ParsedPage(page_num=1, blocks=[], width=800, height=600)],
|
||||
metadata={}
|
||||
)
|
||||
|
||||
with patch("app.runtime.qa_builder.httpx.AsyncClient") as mock_client:
|
||||
mock_client.return_value.__aenter__.return_value.post = AsyncMock(
|
||||
side_effect=Exception("Router error")
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await build_qa_pairs_via_router(doc, dao_id="test-dao")
|
||||
|
||||
Reference in New Issue
Block a user