feat: add tests and integrate dots.ocr model
G.2.5 - Tests: - Add pytest test suite with fixtures - test_preprocessing.py - PDF/image loading, normalization, validation - test_postprocessing.py - chunks, QA pairs, markdown generation - test_inference.py - dummy parser and inference functions - test_api.py - API endpoint tests - Add pytest.ini configuration G.1.3 - dots.ocr Integration: - Update model_loader.py with real model loading code - Support for AutoModelForVision2Seq and AutoProcessor - Device handling (CUDA/CPU/MPS) with fallback - Error handling with dummy fallback option - Update inference.py with real model inference - Process images through model - Generate and decode outputs - Parse model output to blocks - Add model_output_parser.py - Parse JSON or plain text model output - Convert to structured blocks - Layout detection support (placeholder) Dependencies: - Add pytest, pytest-asyncio, httpx for testing
This commit is contained in:
@@ -6,6 +6,7 @@ import logging
|
||||
from typing import Literal, Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from app.schemas import ParsedDocument, ParsedPage, ParsedBlock, BBox
|
||||
@@ -14,6 +15,7 @@ from app.runtime.preprocessing import (
|
||||
convert_pdf_to_images, load_image, prepare_images_for_model
|
||||
)
|
||||
from app.runtime.postprocessing import build_parsed_document
|
||||
from app.runtime.model_output_parser import parse_model_output_to_blocks
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -63,36 +65,46 @@ def parse_document_from_images(
|
||||
|
||||
for idx, image in enumerate(prepared_images, start=1):
|
||||
try:
|
||||
# TODO: Implement actual inference with dots.ocr
|
||||
# Example:
|
||||
# inputs = model["processor"](images=image, return_tensors="pt")
|
||||
# outputs = model["model"].generate(**inputs)
|
||||
# text = model["processor"].decode(outputs[0], skip_special_tokens=True)
|
||||
#
|
||||
# # Parse model output into blocks
|
||||
# blocks = parse_model_output_to_blocks(text, image.size)
|
||||
#
|
||||
# pages_data.append({
|
||||
# "blocks": blocks,
|
||||
# "width": image.width,
|
||||
# "height": image.height
|
||||
# })
|
||||
# Prepare inputs for model
|
||||
inputs = model["processor"](images=image, return_tensors="pt")
|
||||
|
||||
# Move inputs to device
|
||||
device = model["device"]
|
||||
if device != "cpu":
|
||||
inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in inputs.items()}
|
||||
|
||||
# Generate output
|
||||
with torch.no_grad():
|
||||
outputs = model["model"].generate(
|
||||
**inputs,
|
||||
max_new_tokens=2048, # Adjust based on model capabilities
|
||||
do_sample=False # Deterministic output
|
||||
)
|
||||
|
||||
# Decode output
|
||||
generated_text = model["processor"].decode(
|
||||
outputs[0],
|
||||
skip_special_tokens=True
|
||||
)
|
||||
|
||||
logger.debug(f"Model output for page {idx}: {generated_text[:100]}...")
|
||||
|
||||
# Parse model output into blocks
|
||||
blocks = parse_model_output_to_blocks(
|
||||
generated_text,
|
||||
image.size,
|
||||
page_num=idx
|
||||
)
|
||||
|
||||
# For now, use dummy for each page
|
||||
logger.debug(f"Processing page {idx} with model (placeholder)")
|
||||
pages_data.append({
|
||||
"blocks": [
|
||||
{
|
||||
"type": "paragraph",
|
||||
"text": f"Page {idx} content (model output placeholder)",
|
||||
"bbox": {"x": 0, "y": 0, "width": image.width, "height": image.height},
|
||||
"reading_order": 1
|
||||
}
|
||||
],
|
||||
"blocks": blocks,
|
||||
"width": image.width,
|
||||
"height": image.height
|
||||
})
|
||||
|
||||
logger.info(f"Processed page {idx}/{len(prepared_images)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing page {idx}: {e}", exc_info=True)
|
||||
# Continue with other pages
|
||||
|
||||
@@ -36,30 +36,62 @@ def load_model() -> Optional[object]:
|
||||
logger.info(f"Device: {settings.PARSER_DEVICE}")
|
||||
|
||||
try:
|
||||
# TODO: Implement actual model loading
|
||||
# Example for dots.ocr (adjust based on actual model structure):
|
||||
# from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
#
|
||||
# processor = AutoProcessor.from_pretrained(settings.PARSER_MODEL_NAME)
|
||||
# model = AutoModelForVision2Seq.from_pretrained(
|
||||
# settings.PARSER_MODEL_NAME,
|
||||
# device_map=settings.PARSER_DEVICE if settings.PARSER_DEVICE != "cpu" else None,
|
||||
# torch_dtype=torch.float16 if settings.PARSER_DEVICE != "cpu" else torch.float32
|
||||
# )
|
||||
#
|
||||
# if settings.PARSER_DEVICE == "cpu":
|
||||
# model = model.to("cpu")
|
||||
#
|
||||
# _model = {
|
||||
# "model": model,
|
||||
# "processor": processor,
|
||||
# "device": settings.PARSER_DEVICE
|
||||
# }
|
||||
#
|
||||
# logger.info("Model loaded successfully")
|
||||
# Load dots.ocr model
|
||||
# Note: Adjust imports and model class based on actual dots.ocr implementation
|
||||
# This is a template that should work with most Vision-Language models
|
||||
|
||||
# For now, return None (will use dummy parser)
|
||||
logger.warning("Model loading not yet implemented, will use dummy parser")
|
||||
try:
|
||||
from transformers import AutoModelForVision2Seq, AutoProcessor
|
||||
import torch
|
||||
except ImportError:
|
||||
logger.error("transformers or torch not installed. Install with: pip install transformers torch")
|
||||
if not settings.ALLOW_DUMMY_FALLBACK:
|
||||
raise
|
||||
return None
|
||||
|
||||
logger.info(f"Loading model from: {settings.PARSER_MODEL_NAME}")
|
||||
|
||||
# Load processor
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
settings.PARSER_MODEL_NAME,
|
||||
trust_remote_code=True # If model has custom code
|
||||
)
|
||||
|
||||
# Determine device and dtype
|
||||
device = settings.PARSER_DEVICE
|
||||
# Fall back to CPU when the requested accelerator is not usable on this host.
if device == "cuda" and not torch.cuda.is_available():
    logger.warning("CUDA not available, falling back to CPU")
    device = "cpu"
elif device == "mps" and not (
    # Group the MPS probe explicitly: `and` binds tighter than `or`, so the
    # ungrouped form downgraded *any* device (including a working CUDA one)
    # to CPU whenever MPS was unavailable, and touched torch.backends.mps
    # even when the attribute does not exist.
    hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
):
    logger.warning("MPS not available, falling back to CPU")
    device = "cpu"
|
||||
|
||||
dtype = torch.float16 if device != "cpu" else torch.float32
|
||||
|
||||
# Load model
|
||||
model = AutoModelForVision2Seq.from_pretrained(
|
||||
settings.PARSER_MODEL_NAME,
|
||||
device_map=device if device != "cpu" else None,
|
||||
torch_dtype=dtype,
|
||||
trust_remote_code=True
|
||||
)
|
||||
|
||||
if device == "cpu":
|
||||
model = model.to("cpu")
|
||||
|
||||
# Store model and processor
|
||||
_model = {
|
||||
"model": model,
|
||||
"processor": processor,
|
||||
"device": device
|
||||
}
|
||||
|
||||
logger.info(f"Model loaded successfully on device: {device}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load model: {e}", exc_info=True)
|
||||
if not settings.ALLOW_DUMMY_FALLBACK:
|
||||
raise
|
||||
_model = None
|
||||
|
||||
except ImportError as e:
|
||||
|
||||
150
services/parser-service/app/runtime/model_output_parser.py
Normal file
150
services/parser-service/app/runtime/model_output_parser.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Parser for dots.ocr model output
|
||||
Converts model output to structured blocks
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Dict, Any, Optional
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _make_line_block(
    block_type: str,
    text: str,
    image_width: int,
    reading_order: int,
    line_height: int = 30,
) -> Dict[str, Any]:
    """Build one heuristic text block that spans the full image width."""
    return {
        "type": block_type,
        "text": text,
        "bbox": {
            "x": 0,
            # No real layout is known for plain text, so blocks are stacked
            # in fixed-height rows by reading order.
            "y": reading_order * line_height,
            "width": image_width,
            "height": line_height,
        },
        "reading_order": reading_order,
    }


def _make_full_page_block(text: str, image_size: tuple[int, int]) -> Dict[str, Any]:
    """Build a single paragraph block that covers the whole page."""
    return {
        "type": "paragraph",
        "text": text,
        "bbox": {
            "x": 0,
            "y": 0,
            "width": image_size[0],
            "height": image_size[1],
        },
        "reading_order": 1,
    }


def _blocks_from_json(model_output: str) -> Optional[List[Dict[str, Any]]]:
    """Return the block list if the output is structured JSON, else None."""
    try:
        data = json.loads(model_output)
    except json.JSONDecodeError:
        return None

    candidate = data.get("blocks") if isinstance(data, dict) else data
    # Only accept a list of dicts. Text that merely happens to be valid JSON
    # (e.g. "[1, 2]" or '{"blocks": "n/a"}') is parsed as plain text instead
    # of being handed downstream as malformed "blocks".
    if isinstance(candidate, list) and all(isinstance(item, dict) for item in candidate):
        return candidate
    return None


def parse_model_output_to_blocks(
    model_output: str,
    image_size: tuple[int, int],
    page_num: int
) -> List[Dict[str, Any]]:
    """
    Parse dots.ocr model output into structured blocks

    Structured JSON output (a list of block dicts, or a dict with a "blocks"
    list of dicts) is returned as-is. Anything else is treated as plain text:
    lines starting with '#' become heading blocks, consecutive other lines
    are merged into paragraph blocks, and blank lines are ignored.

    Args:
        model_output: Raw text output from model (may be JSON or plain text)
        image_size: (width, height) of the image
        page_num: Page number (currently unused; kept for the caller's interface)

    Returns:
        List of block dictionaries. If parsing fails, the error is logged and
        a single full-page paragraph block holding the raw output is returned.
    """
    try:
        structured = _blocks_from_json(model_output)
        if structured is not None:
            return structured

        # Plain-text heuristic - adjust based on actual dots.ocr output format
        blocks: List[Dict[str, Any]] = []
        current_block: Optional[Dict[str, Any]] = None
        reading_order = 1

        for raw_line in model_output.strip().split('\n'):
            line = raw_line.strip()
            if not line:
                continue

            if line.startswith('#'):
                block_type, text = "heading", line.lstrip('#').strip()
            elif current_block is not None and current_block["type"] == "paragraph":
                # Continuation of the paragraph that is already open
                current_block["text"] += " " + line
                continue
            else:
                block_type, text = "paragraph", line

            # A new block starts here: flush the previous one first
            if current_block is not None:
                blocks.append(current_block)
            current_block = _make_line_block(block_type, text, image_size[0], reading_order)
            reading_order += 1

        if current_block is not None:
            blocks.append(current_block)

        # Nothing usable found (e.g. empty output): emit one page-sized paragraph
        if not blocks:
            blocks.append(_make_full_page_block(model_output.strip(), image_size))

        return blocks

    except Exception as e:
        # Deliberate best-effort: a parsing failure must not lose the page.
        logger.error(f"Error parsing model output: {e}", exc_info=True)
        raw_text = model_output.strip() if isinstance(model_output, str) else ""
        return [_make_full_page_block(raw_text, image_size)]
|
||||
|
||||
|
||||
def extract_layout_info(model_output: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """
    Placeholder hook for pulling layout details out of the model output

    Args:
        model_output: Model output dictionary (not inspected yet)

    Returns:
        Always None for now, i.e. no layout info is reported
    """
    # To be customized once the actual dots.ocr output format is known.
    layout_info: Optional[Dict[str, Any]] = None
    return layout_info
|
||||
|
||||
14
services/parser-service/pytest.ini
Normal file
14
services/parser-service/pytest.ini
Normal file
@@ -0,0 +1,14 @@
|
||||
[pytest]
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
addopts =
|
||||
-v
|
||||
--tb=short
|
||||
--strict-markers
|
||||
markers =
|
||||
unit: Unit tests
|
||||
integration: Integration tests
|
||||
slow: Slow running tests
|
||||
|
||||
@@ -20,3 +20,8 @@ opencv-python>=4.8.0 # Optional, for advanced image processing
|
||||
# Utilities
|
||||
python-dotenv>=1.0.1
|
||||
|
||||
# Testing
|
||||
pytest>=7.4.0
|
||||
pytest-asyncio>=0.21.0
|
||||
httpx>=0.25.0 # For TestClient
|
||||
|
||||
|
||||
4
services/parser-service/tests/__init__.py
Normal file
4
services/parser-service/tests/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""
|
||||
Tests for PARSER Service
|
||||
"""
|
||||
|
||||
106
services/parser-service/tests/conftest.py
Normal file
106
services/parser-service/tests/conftest.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Pytest configuration and fixtures
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import os
|
||||
from pathlib import Path
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
# Test fixtures directory
|
||||
FIXTURES_DIR = Path(__file__).parent / "fixtures"
|
||||
DOCS_DIR = FIXTURES_DIR / "docs"
|
||||
|
||||
|
||||
@pytest.fixture
def fixtures_dir():
    """Location of the test fixtures directory"""
    fixtures_path = FIXTURES_DIR
    return fixtures_path
|
||||
|
||||
|
||||
@pytest.fixture
def docs_dir():
    """Location of the directory holding test documents"""
    documents_path = DOCS_DIR
    return documents_path
|
||||
|
||||
|
||||
@pytest.fixture
def sample_image_bytes():
    """PNG-encoded bytes of a blank white 800x600 RGB image, built in memory"""
    png_stream = io.BytesIO()
    blank_page = Image.new('RGB', (800, 600), color='white')
    blank_page.save(png_stream, format='PNG')
    png_bytes = png_stream.getvalue()
    return png_bytes
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_pdf_bytes():
|
||||
"""Return the raw bytes of a hand-written one-page PDF 1.4 document (for testing)"""
|
||||
# Hand-assembled PDF: catalog, page tree, one 612x792 page drawing "Test PDF" in 12pt Helvetica. NOTE(review): the /Length, xref offsets and startxref values are hard-coded - confirm they match the real byte layout.
|
||||
pdf_content = b"""%PDF-1.4
|
||||
1 0 obj
|
||||
<<
|
||||
/Type /Catalog
|
||||
/Pages 2 0 R
|
||||
>>
|
||||
endobj
|
||||
2 0 obj
|
||||
<<
|
||||
/Type /Pages
|
||||
/Kids [3 0 R]
|
||||
/Count 1
|
||||
>>
|
||||
endobj
|
||||
3 0 obj
|
||||
<<
|
||||
/Type /Page
|
||||
/Parent 2 0 R
|
||||
/MediaBox [0 0 612 792]
|
||||
/Contents 4 0 R
|
||||
/Resources <<
|
||||
/Font <<
|
||||
/F1 <<
|
||||
/Type /Font
|
||||
/Subtype /Type1
|
||||
/BaseFont /Helvetica
|
||||
>>
|
||||
>>
|
||||
>>
|
||||
>>
|
||||
endobj
|
||||
4 0 obj
|
||||
<<
|
||||
/Length 44
|
||||
>>
|
||||
stream
|
||||
BT
|
||||
/F1 12 Tf
|
||||
100 700 Td
|
||||
(Test PDF) Tj
|
||||
ET
|
||||
endstream
|
||||
endobj
|
||||
xref
|
||||
0 5
|
||||
0000000000 65535 f
|
||||
0000000009 00000 n
|
||||
0000000058 00000 n
|
||||
0000000115 00000 n
|
||||
0000000306 00000 n
|
||||
trailer
|
||||
<<
|
||||
/Size 5
|
||||
/Root 1 0 R
|
||||
>>
|
||||
startxref
|
||||
400
|
||||
%%EOF"""
|
||||
return pdf_content
|
||||
|
||||
|
||||
@pytest.fixture
def temp_dir(tmp_path):
    """Scratch directory for test files (the tmp_path value handed in by pytest)"""
    scratch_dir = tmp_path
    return scratch_dir
|
||||
|
||||
109
services/parser-service/tests/test_api.py
Normal file
109
services/parser-service/tests/test_api.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Tests for API endpoints
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from app.main import app
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
|
||||
class TestHealthEndpoint:
    """Health check endpoint behaviour"""

    def test_health(self):
        """GET /health answers 200 and reports a healthy parser-service"""
        health_response = client.get("/health")
        assert health_response.status_code == 200

        payload = health_response.json()
        assert payload["status"] == "healthy"
        assert payload["service"] == "parser-service"
|
||||
|
||||
|
||||
class TestParseEndpoint:
    """Behaviour of the /ocr/parse endpoint"""

    @staticmethod
    def _post_blank_png(form_fields):
        """Upload a blank white 800x600 PNG to /ocr/parse with the given form fields"""
        png_stream = io.BytesIO()
        Image.new('RGB', (800, 600), color='white').save(png_stream, format='PNG')
        png_stream.seek(0)

        return client.post(
            "/ocr/parse",
            files={"file": ("test.png", png_stream, "image/png")},
            data=form_fields
        )

    def test_parse_no_file(self):
        """A request without an uploaded file is rejected with 400"""
        reply = client.post("/ocr/parse")
        assert reply.status_code == 400

    def test_parse_image(self):
        """An image parsed in raw_json mode yields one of the known payload keys"""
        reply = self._post_blank_png({"output_mode": "raw_json"})

        assert reply.status_code == 200
        body = reply.json()
        assert any(key in body for key in ("document", "chunks", "markdown"))

    def test_parse_chunks_mode(self):
        """chunks mode returns a "chunks" entry"""
        reply = self._post_blank_png({"output_mode": "chunks", "dao_id": "test-dao"})

        assert reply.status_code == 200
        body = reply.json()
        assert "chunks" in body

    def test_parse_markdown_mode(self):
        """markdown mode returns a "markdown" entry"""
        reply = self._post_blank_png({"output_mode": "markdown"})

        assert reply.status_code == 200
        body = reply.json()
        assert "markdown" in body
|
||||
|
||||
|
||||
class TestParseChunksEndpoint:
    """Behaviour of the /ocr/parse_chunks endpoint"""

    def test_parse_chunks(self):
        """Uploading an image returns chunks, their count and the echoed dao_id"""
        png_stream = io.BytesIO()
        Image.new('RGB', (800, 600), color='white').save(png_stream, format='PNG')
        png_stream.seek(0)

        upload = {"file": ("test.png", png_stream, "image/png")}
        reply = client.post(
            "/ocr/parse_chunks",
            files=upload,
            data={"dao_id": "test-dao"}
        )

        assert reply.status_code == 200
        body = reply.json()
        assert "chunks" in body
        assert "total_chunks" in body
        assert body["dao_id"] == "test-dao"
|
||||
|
||||
53
services/parser-service/tests/test_inference.py
Normal file
53
services/parser-service/tests/test_inference.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""
|
||||
Tests for inference functions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from app.runtime.inference import (
|
||||
parse_document_from_images,
|
||||
dummy_parse_document_from_images
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestDummyParser:
    """Dummy parser behaviour"""

    def test_dummy_parse_document_from_images(self):
        """Two blank pages give a two-page document with blocks and page sizes"""
        blank_pages = [
            Image.new('RGB', (800, 600), color='white')
            for _ in range(2)
        ]

        doc = dummy_parse_document_from_images(blank_pages, doc_id="test-doc")

        assert doc.doc_id == "test-doc"
        assert len(doc.pages) == 2
        non_empty = [len(page.blocks) > 0 for page in doc.pages]
        assert all(non_empty)
        widths = [page.width for page in doc.pages]
        assert all(width == 800 for width in widths)
        heights = [page.height for page in doc.pages]
        assert all(height == 600 for height in heights)
|
||||
|
||||
|
||||
class TestParseDocumentFromImages:
    """Tests for parse_document_from_images"""

    def test_parse_document_from_images_dummy_mode(self, monkeypatch):
        """Test parsing with dummy mode enabled"""
        # The env var alone only affects Settings objects created afterwards;
        # the shared module-level `settings` already exists by the time this
        # test runs, so the flag is also patched on that instance.
        monkeypatch.setenv("USE_DUMMY_PARSER", "true")
        # NOTE(review): assumes the parser consults settings.USE_DUMMY_PARSER
        # on the shared instance - confirm against app.core.config.
        monkeypatch.setattr(settings, "USE_DUMMY_PARSER", True)

        images = [Image.new('RGB', (800, 600), color='white')]
        doc = parse_document_from_images(images, doc_id="test-doc")

        assert doc.doc_id == "test-doc"
        assert len(doc.pages) == 1

    def test_parse_document_from_images_empty(self):
        """Test parsing with empty images list"""
        with pytest.raises(ValueError, match="No valid images"):
            parse_document_from_images([], doc_id="test-doc")
|
||||
|
||||
193
services/parser-service/tests/test_postprocessing.py
Normal file
193
services/parser-service/tests/test_postprocessing.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
Tests for postprocessing functions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from app.runtime.postprocessing import (
|
||||
normalize_text,
|
||||
build_parsed_document,
|
||||
build_chunks,
|
||||
build_qa_pairs,
|
||||
build_markdown
|
||||
)
|
||||
from app.schemas import ParsedDocument, ParsedPage, ParsedBlock, BBox
|
||||
|
||||
|
||||
class TestTextNormalization:
    """Behaviour of normalize_text"""

    def test_normalize_text_whitespace(self):
        """Surrounding and repeated whitespace is removed"""
        cleaned = normalize_text(" hello world ")
        assert cleaned == "hello world"

    def test_normalize_text_newlines(self):
        """Newlines are replaced so the words end up on one line"""
        cleaned = normalize_text("hello\n\nworld")
        assert cleaned == "hello world"

    def test_normalize_text_empty(self):
        """Empty and whitespace-only input both normalize to an empty string"""
        expected = ""
        assert normalize_text("") == expected
        assert normalize_text(" ") == expected
|
||||
|
||||
|
||||
class TestBuildParsedDocument:
    """Building a ParsedDocument from raw page data"""

    def test_build_parsed_document(self):
        """One page with a heading and a paragraph becomes a normalized document"""
        heading_block = {
            "type": "heading",
            "text": " Title ",
            "bbox": {"x": 0, "y": 0, "width": 100, "height": 20},
            "reading_order": 1
        }
        paragraph_block = {
            "type": "paragraph",
            "text": " Content ",
            "bbox": {"x": 0, "y": 30, "width": 100, "height": 50},
            "reading_order": 2
        }
        pages_data = [
            {"blocks": [heading_block, paragraph_block], "width": 800, "height": 1200}
        ]

        doc = build_parsed_document(pages_data, "test-doc", "pdf")

        assert doc.doc_id == "test-doc"
        assert doc.doc_type == "pdf"
        assert len(doc.pages) == 1

        first_page = doc.pages[0]
        assert len(first_page.blocks) == 2
        first_block = first_page.blocks[0]
        assert first_block.text == "Title"  # Surrounding whitespace normalized away
        assert first_block.type == "heading"
|
||||
|
||||
|
||||
class TestBuildChunks:
    """Chunk construction from a ParsedDocument"""

    def test_build_chunks(self):
        """Chunks from a one-page document carry the page, dao_id and doc_id"""
        heading = ParsedBlock(
            type="heading",
            text="Section 1",
            bbox=BBox(x=0, y=0, width=100, height=20),
            reading_order=1,
            page_num=1
        )
        body = ParsedBlock(
            type="paragraph",
            text="Content of section 1",
            bbox=BBox(x=0, y=30, width=100, height=50),
            reading_order=2,
            page_num=1
        )
        page = ParsedPage(page_num=1, blocks=[heading, body], width=800, height=1200)
        doc = ParsedDocument(doc_id="test-doc", doc_type="pdf", pages=[page])

        chunks = build_chunks(doc, dao_id="test-dao")

        assert len(chunks) > 0
        assert all(chunk.page == 1 for chunk in chunks)
        dao_ids = [chunk.metadata.get("dao_id") for chunk in chunks]
        assert all(dao_id == "test-dao" for dao_id in dao_ids)
        doc_ids = [chunk.metadata.get("doc_id") for chunk in chunks]
        assert all(doc_id == "test-doc" for doc_id in doc_ids)
|
||||
|
||||
|
||||
class TestBuildQAPairs:
    """Q&A pair construction from a ParsedDocument"""

    def test_build_qa_pairs(self):
        """A heading/paragraph page yields string Q&A pairs sourced from page 1"""
        question_block = ParsedBlock(
            type="heading",
            text="What is X?",
            bbox=BBox(x=0, y=0, width=100, height=20),
            reading_order=1,
            page_num=1
        )
        answer_block = ParsedBlock(
            type="paragraph",
            text="X is a test",
            bbox=BBox(x=0, y=30, width=100, height=50),
            reading_order=2,
            page_num=1
        )
        page = ParsedPage(
            page_num=1,
            blocks=[question_block, answer_block],
            width=800,
            height=1200
        )
        doc = ParsedDocument(doc_id="test-doc", doc_type="pdf", pages=[page])

        qa_pairs = build_qa_pairs(doc, max_pairs=5)

        assert len(qa_pairs) > 0
        assert all(isinstance(pair.question, str) for pair in qa_pairs)
        assert all(isinstance(pair.answer, str) for pair in qa_pairs)
        assert all(pair.source_page == 1 for pair in qa_pairs)
|
||||
|
||||
|
||||
class TestBuildMarkdown:
    """Markdown rendering of a ParsedDocument"""

    def test_build_markdown(self):
        """Rendered Markdown is a string containing both texts and a heading marker"""
        title_block = ParsedBlock(
            type="heading",
            text="Title",
            bbox=BBox(x=0, y=0, width=100, height=20),
            reading_order=1,
            page_num=1
        )
        content_block = ParsedBlock(
            type="paragraph",
            text="Content",
            bbox=BBox(x=0, y=30, width=100, height=50),
            reading_order=2,
            page_num=1
        )
        page = ParsedPage(
            page_num=1,
            blocks=[title_block, content_block],
            width=800,
            height=1200
        )
        doc = ParsedDocument(doc_id="test-doc", doc_type="pdf", pages=[page])

        markdown = build_markdown(doc)

        assert isinstance(markdown, str)
        for expected_text in ("Title", "Content"):
            assert expected_text in markdown
        # Any "####" marker also contains "###", so one membership check covers both
        assert "###" in markdown
|
||||
|
||||
123
services/parser-service/tests/test_preprocessing.py
Normal file
123
services/parser-service/tests/test_preprocessing.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Tests for preprocessing functions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
import io
|
||||
|
||||
from app.runtime.preprocessing import (
|
||||
convert_pdf_to_images,
|
||||
load_image,
|
||||
normalize_image,
|
||||
prepare_images_for_model,
|
||||
detect_file_type,
|
||||
validate_file_size
|
||||
)
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class TestImageLoading:
    """Behaviour of load_image"""

    def test_load_image_png(self, sample_image_bytes):
        """PNG bytes load into an 800x600 PIL image"""
        loaded = load_image(sample_image_bytes)
        assert isinstance(loaded, Image.Image)
        assert loaded.size == (800, 600)

    def test_load_image_invalid(self):
        """Bytes that are not an image raise ValueError"""
        garbage = b"not an image"
        with pytest.raises(ValueError, match="Image loading failed"):
            load_image(garbage)
|
||||
|
||||
|
||||
class TestPDFConversion:
    """Behaviour of convert_pdf_to_images"""

    def test_convert_pdf_to_images(self, sample_pdf_bytes):
        """A PDF converts into a non-empty list of PIL images"""
        pages = convert_pdf_to_images(sample_pdf_bytes, dpi=150, max_pages=1)
        assert len(pages) > 0
        for page_image in pages:
            assert isinstance(page_image, Image.Image)

    def test_convert_pdf_max_pages(self, sample_pdf_bytes):
        """No more than max_pages images are returned"""
        pages = convert_pdf_to_images(sample_pdf_bytes, max_pages=1)
        assert len(pages) <= 1
|
||||
|
||||
|
||||
class TestImageNormalization:
    """Tests for image normalization"""

    def test_normalize_image_rgb(self, sample_image_bytes):
        """Test image is converted to RGB"""
        image = load_image(sample_image_bytes)
        normalized = normalize_image(image)
        assert normalized.mode == 'RGB'

    def test_normalize_image_resize(self):
        """Test image is resized if too large"""
        # 3000x2000: only the width exceeds the limit
        large_img = Image.new('RGB', (3000, 2000), color='white')
        normalized = normalize_image(large_img, max_size=2048)
        # Both sides must fit. The previous `width <= 2048 or height <= 2048`
        # was already true for the untouched input (height 2000), so the test
        # could not detect a missing resize.
        # NOTE(review): assumes max_size caps the longest side - confirm
        # against normalize_image.
        assert max(normalized.size) <= 2048

    def test_normalize_image_small(self):
        """Test small image is not resized"""
        small_img = Image.new('RGB', (500, 400), color='white')
        normalized = normalize_image(small_img, max_size=2048)
        assert normalized.size == small_img.size
|
||||
|
||||
|
||||
class TestFileTypeDetection:
    """Behaviour of detect_file_type"""

    def test_detect_pdf(self, sample_pdf_bytes):
        """PDF bytes are detected with and without a filename"""
        without_name = detect_file_type(sample_pdf_bytes)
        assert without_name == "pdf"
        with_name = detect_file_type(sample_pdf_bytes, "test.pdf")
        assert with_name == "pdf"

    def test_detect_image(self, sample_image_bytes):
        """PNG bytes are detected as an image with and without a filename"""
        without_name = detect_file_type(sample_image_bytes)
        assert without_name == "image"
        with_name = detect_file_type(sample_image_bytes, "test.png")
        assert with_name == "image"

    def test_detect_unsupported(self):
        """Unknown content with an unknown extension raises ValueError"""
        unknown_payload = b"random bytes"
        with pytest.raises(ValueError, match="Unsupported file type"):
            detect_file_type(unknown_payload, "test.xyz")
|
||||
|
||||
|
||||
class TestFileSizeValidation:
    """Behaviour of validate_file_size"""

    def test_validate_file_size_ok(self):
        """A 10 MB payload passes validation"""
        ten_megabytes = 10 * 1024 * 1024
        payload = b"x" * ten_megabytes
        validate_file_size(payload)  # Passing means no exception is raised

    def test_validate_file_size_too_large(self):
        """A 100 MB payload is rejected with ValueError"""
        hundred_megabytes = 100 * 1024 * 1024
        payload = b"x" * hundred_megabytes
        with pytest.raises(ValueError, match="exceeds maximum"):
            validate_file_size(payload)
|
||||
|
||||
|
||||
class TestPrepareImages:
    """Behaviour of prepare_images_for_model"""

    def test_prepare_images_for_model(self, sample_image_bytes):
        """A single loaded image comes back as one RGB PIL image"""
        source_image = load_image(sample_image_bytes)
        prepared = prepare_images_for_model([source_image])

        assert len(prepared) == 1
        only_image = prepared[0]
        assert isinstance(only_image, Image.Image)
        assert only_image.mode == 'RGB'

    def test_prepare_images_empty(self):
        """An empty input list gives an empty result"""
        no_images = []
        prepared = prepare_images_for_model(no_images)
        assert len(prepared) == 0
|
||||
|
||||
Reference in New Issue
Block a user