- TTS: xtts-v2 integration with voice cloning support
- Document: docling integration for PDF/DOCX/PPTX processing
- Memory Service: added /facts/upsert, /facts/{key}, /facts endpoints
- Added required dependencies (TTS, docling)
266 lines · 8.2 KiB · Python
"""
|
|
Chandra Inference Service
|
|
Direct inference using HuggingFace model
|
|
"""
|
|
import logging
|
|
import os
|
|
from typing import Optional, Dict, Any
|
|
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
|
|
from fastapi.responses import JSONResponse
|
|
from pydantic import BaseModel
|
|
from PIL import Image
|
|
from io import BytesIO
|
|
import base64
|
|
import torch
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
app = FastAPI(title="Chandra Inference Service")
|
|
|
|
# Configuration
|
|
# Using GOT-OCR2.0 - best open-source OCR for documents and tables
|
|
# Alternative: microsoft/trocr-base-printed for simple text
|
|
OCR_MODEL = os.getenv("OCR_MODEL", "stepfun-ai/GOT-OCR2_0")
|
|
DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
# Load model (lazy loading)
|
|
model = None
|
|
processor = None
|
|
|
|
|
|
def load_model():
|
|
"""Load OCR model from HuggingFace"""
|
|
global model, processor
|
|
|
|
if model is not None:
|
|
return
|
|
|
|
try:
|
|
logger.info(f"Loading OCR model: {OCR_MODEL} on {DEVICE}")
|
|
|
|
# Try GOT-OCR2.0 first (best for documents and tables)
|
|
if "GOT-OCR" in OCR_MODEL or "got-ocr" in OCR_MODEL.lower():
|
|
from transformers import AutoModel, AutoTokenizer
|
|
|
|
# GOT-OCR uses different loading
|
|
tokenizer = AutoTokenizer.from_pretrained(OCR_MODEL, trust_remote_code=True)
|
|
model = AutoModel.from_pretrained(
|
|
OCR_MODEL,
|
|
trust_remote_code=True,
|
|
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
|
|
device_map="auto" if DEVICE == "cuda" else None
|
|
)
|
|
processor = tokenizer
|
|
logger.info(f"GOT-OCR2.0 model loaded on {DEVICE}")
|
|
|
|
else:
|
|
# Fallback to TrOCR for simple text OCR
|
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
|
|
|
|
processor = TrOCRProcessor.from_pretrained(OCR_MODEL)
|
|
model = VisionEncoderDecoderModel.from_pretrained(
|
|
OCR_MODEL,
|
|
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32
|
|
)
|
|
|
|
if DEVICE == "cpu":
|
|
model = model.to(DEVICE)
|
|
else:
|
|
model = model.cuda()
|
|
|
|
if model:
|
|
model.eval()
|
|
logger.info(f"Model loaded successfully on {DEVICE}")
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to load model: {e}", exc_info=True)
|
|
logger.warning("Service will run in degraded mode without model")
|
|
|
|
|
|
@app.on_event("startup")
|
|
async def startup():
|
|
"""Load model on startup"""
|
|
try:
|
|
load_model()
|
|
except Exception as e:
|
|
logger.error(f"Startup failed: {e}", exc_info=True)
|
|
|
|
|
|
@app.get("/health")
|
|
async def health():
|
|
"""Health check endpoint"""
|
|
try:
|
|
if model is None:
|
|
return {
|
|
"status": "loading",
|
|
"service": "ocr-inference",
|
|
"model": OCR_MODEL,
|
|
"device": DEVICE
|
|
}
|
|
|
|
return {
|
|
"status": "healthy",
|
|
"service": "ocr-inference",
|
|
"model": OCR_MODEL,
|
|
"device": DEVICE,
|
|
"cuda_available": torch.cuda.is_available()
|
|
}
|
|
except Exception as e:
|
|
return {
|
|
"status": "unhealthy",
|
|
"service": "ocr-inference",
|
|
"error": str(e)
|
|
}
|
|
|
|
|
|
@app.post("/process")
|
|
async def process_document(
|
|
file: Optional[UploadFile] = File(None),
|
|
doc_url: Optional[str] = Form(None),
|
|
doc_base64: Optional[str] = Form(None),
|
|
output_format: str = Form("markdown"),
|
|
accurate_mode: bool = Form(False)
|
|
):
|
|
"""
|
|
Process a document using Chandra OCR.
|
|
|
|
Args:
|
|
file: Uploaded file (PDF, image)
|
|
doc_url: URL to document
|
|
doc_base64: Base64 encoded document
|
|
output_format: markdown, html, or json
|
|
accurate_mode: Use accurate mode (slower but more precise)
|
|
"""
|
|
try:
|
|
# Ensure model is loaded
|
|
if model is None:
|
|
load_model()
|
|
|
|
# Get image data
|
|
image_data = None
|
|
|
|
if file:
|
|
image_data = await file.read()
|
|
elif doc_base64:
|
|
image_data = base64.b64decode(doc_base64)
|
|
elif doc_url:
|
|
import httpx
|
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
|
response = await client.get(doc_url)
|
|
if response.status_code == 200:
|
|
image_data = response.content
|
|
else:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Failed to download document: {response.status_code}"
|
|
)
|
|
else:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail="No document provided"
|
|
)
|
|
|
|
# Load image
|
|
image = Image.open(BytesIO(image_data)).convert("RGB")
|
|
|
|
# Process with OCR model
|
|
if model is None or processor is None:
|
|
raise HTTPException(
|
|
status_code=503,
|
|
detail="OCR model not loaded. Check logs for details."
|
|
)
|
|
|
|
# Different processing for GOT-OCR vs TrOCR
|
|
if "GOT-OCR" in OCR_MODEL or "got-ocr" in OCR_MODEL.lower():
|
|
# GOT-OCR2.0 processing
|
|
with torch.no_grad():
|
|
result = model.chat(processor, image, ocr_type='ocr')
|
|
generated_text = result if isinstance(result, str) else str(result)
|
|
else:
|
|
# TrOCR processing
|
|
inputs = processor(images=image, return_tensors="pt").to(DEVICE)
|
|
|
|
with torch.no_grad():
|
|
generated_ids = model.generate(
|
|
inputs.pixel_values,
|
|
max_length=512,
|
|
num_beams=5 if accurate_mode else 3
|
|
)
|
|
|
|
generated_text = processor.batch_decode(
|
|
generated_ids,
|
|
skip_special_tokens=True
|
|
)[0]
|
|
|
|
# Format output based on requested format
|
|
if output_format == "markdown":
|
|
result = {
|
|
"markdown": generated_text,
|
|
"format": "markdown"
|
|
}
|
|
elif output_format == "html":
|
|
# Convert markdown to HTML (simplified)
|
|
result = {
|
|
"html": generated_text.replace("\n", "<br>"),
|
|
"format": "html"
|
|
}
|
|
else: # json
|
|
result = {
|
|
"text": generated_text,
|
|
"format": "json",
|
|
"metadata": {
|
|
"model": CHANDRA_MODEL,
|
|
"device": DEVICE,
|
|
"accurate_mode": accurate_mode
|
|
}
|
|
}
|
|
|
|
return {
|
|
"success": True,
|
|
"output_format": output_format,
|
|
"result": result
|
|
}
|
|
|
|
except Exception as e:
|
|
logger.error(f"Document processing failed: {e}", exc_info=True)
|
|
raise HTTPException(
|
|
status_code=500,
|
|
detail=f"Processing failed: {str(e)}"
|
|
)
|
|
|
|
|
|
@app.get("/models")
|
|
async def list_models():
|
|
"""List available models"""
|
|
return {
|
|
"current_model": OCR_MODEL,
|
|
"available_models": [
|
|
{
|
|
"name": "stepfun-ai/GOT-OCR2_0",
|
|
"type": "document_ocr",
|
|
"description": "Best for documents, tables, formulas, handwriting",
|
|
"vram": "~8GB"
|
|
},
|
|
{
|
|
"name": "microsoft/trocr-base-printed",
|
|
"type": "text_ocr",
|
|
"description": "Fast OCR for printed text",
|
|
"vram": "~2GB"
|
|
},
|
|
{
|
|
"name": "microsoft/trocr-base-handwritten",
|
|
"type": "handwriting_ocr",
|
|
"description": "OCR for handwritten text",
|
|
"vram": "~2GB"
|
|
}
|
|
],
|
|
"note": "GOT-OCR2.0 recommended for documents and tables. TrOCR for simple text.",
|
|
"device": DEVICE,
|
|
"cuda_available": torch.cuda.is_available()
|
|
}
|
|
|
|
|
|
if __name__ == "__main__":
|
|
import uvicorn
|
|
uvicorn.run(app, host="0.0.0.0", port=8000)
|