Files
microdao-daarion/services/chandra-inference/main.py
Apple 5290287058 feat: implement TTS, Document processing, and Memory Service /facts API
- TTS: xtts-v2 integration with voice cloning support
- Document: docling integration for PDF/DOCX/PPTX processing
- Memory Service: added /facts/upsert, /facts/{key}, /facts endpoints
- Added required dependencies (TTS, docling)
2026-01-17 08:16:37 -08:00

266 lines
8.2 KiB
Python

"""
Chandra Inference Service
Direct inference using HuggingFace model
"""
import logging
import os
from typing import Optional, Dict, Any
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from PIL import Image
from io import BytesIO
import base64
import torch
logger = logging.getLogger(__name__)
app = FastAPI(title="Chandra Inference Service")
# Configuration
# Using GOT-OCR2.0 - best open-source OCR for documents and tables
# Alternative: microsoft/trocr-base-printed for simple text
OCR_MODEL = os.getenv("OCR_MODEL", "stepfun-ai/GOT-OCR2_0")
DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
# Load model (lazy loading)
model = None
processor = None
def load_model():
    """Load the OCR model and its processor/tokenizer from HuggingFace.

    Populates the module-level ``model`` and ``processor`` globals. Idempotent:
    returns immediately when a model is already loaded. On failure the globals
    remain None and the service runs in degraded mode (the error surfaces as a
    503 from /process).
    """
    global model, processor
    if model is not None:
        return
    try:
        logger.info(f"Loading OCR model: {OCR_MODEL} on {DEVICE}")
        # GOT-OCR2.0 ships custom modelling code, so it must be loaded via
        # AutoModel with trust_remote_code; its "processor" is its tokenizer.
        # (case-insensitive check covers both "GOT-OCR2_0" and "got-ocr" ids)
        if "got-ocr" in OCR_MODEL.lower():
            from transformers import AutoModel, AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(OCR_MODEL, trust_remote_code=True)
            model = AutoModel.from_pretrained(
                OCR_MODEL,
                trust_remote_code=True,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
                device_map="auto" if DEVICE == "cuda" else None,
            )
            processor = tokenizer
            logger.info(f"GOT-OCR2.0 model loaded on {DEVICE}")
        else:
            # Fallback to TrOCR for simple text OCR
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel
            processor = TrOCRProcessor.from_pretrained(OCR_MODEL)
            model = VisionEncoderDecoderModel.from_pretrained(
                OCR_MODEL,
                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
            )
            # device_map is not used here, so place the model explicitly.
            model = model.to(DEVICE) if DEVICE == "cpu" else model.cuda()
        if model is not None:
            model.eval()  # inference only: disables dropout/batch-norm updates
            logger.info(f"Model loaded successfully on {DEVICE}")
    except Exception as e:
        # Deliberate best-effort: never crash the process on a load failure.
        logger.error(f"Failed to load model: {e}", exc_info=True)
        logger.warning("Service will run in degraded mode without model")
@app.on_event("startup")
async def startup():
"""Load model on startup"""
try:
load_model()
except Exception as e:
logger.error(f"Startup failed: {e}", exc_info=True)
@app.get("/health")
async def health():
"""Health check endpoint"""
try:
if model is None:
return {
"status": "loading",
"service": "ocr-inference",
"model": OCR_MODEL,
"device": DEVICE
}
return {
"status": "healthy",
"service": "ocr-inference",
"model": OCR_MODEL,
"device": DEVICE,
"cuda_available": torch.cuda.is_available()
}
except Exception as e:
return {
"status": "unhealthy",
"service": "ocr-inference",
"error": str(e)
}
@app.post("/process")
async def process_document(
file: Optional[UploadFile] = File(None),
doc_url: Optional[str] = Form(None),
doc_base64: Optional[str] = Form(None),
output_format: str = Form("markdown"),
accurate_mode: bool = Form(False)
):
"""
Process a document using Chandra OCR.
Args:
file: Uploaded file (PDF, image)
doc_url: URL to document
doc_base64: Base64 encoded document
output_format: markdown, html, or json
accurate_mode: Use accurate mode (slower but more precise)
"""
try:
# Ensure model is loaded
if model is None:
load_model()
# Get image data
image_data = None
if file:
image_data = await file.read()
elif doc_base64:
image_data = base64.b64decode(doc_base64)
elif doc_url:
import httpx
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.get(doc_url)
if response.status_code == 200:
image_data = response.content
else:
raise HTTPException(
status_code=400,
detail=f"Failed to download document: {response.status_code}"
)
else:
raise HTTPException(
status_code=400,
detail="No document provided"
)
# Load image
image = Image.open(BytesIO(image_data)).convert("RGB")
# Process with OCR model
if model is None or processor is None:
raise HTTPException(
status_code=503,
detail="OCR model not loaded. Check logs for details."
)
# Different processing for GOT-OCR vs TrOCR
if "GOT-OCR" in OCR_MODEL or "got-ocr" in OCR_MODEL.lower():
# GOT-OCR2.0 processing
with torch.no_grad():
result = model.chat(processor, image, ocr_type='ocr')
generated_text = result if isinstance(result, str) else str(result)
else:
# TrOCR processing
inputs = processor(images=image, return_tensors="pt").to(DEVICE)
with torch.no_grad():
generated_ids = model.generate(
inputs.pixel_values,
max_length=512,
num_beams=5 if accurate_mode else 3
)
generated_text = processor.batch_decode(
generated_ids,
skip_special_tokens=True
)[0]
# Format output based on requested format
if output_format == "markdown":
result = {
"markdown": generated_text,
"format": "markdown"
}
elif output_format == "html":
# Convert markdown to HTML (simplified)
result = {
"html": generated_text.replace("\n", "<br>"),
"format": "html"
}
else: # json
result = {
"text": generated_text,
"format": "json",
"metadata": {
"model": CHANDRA_MODEL,
"device": DEVICE,
"accurate_mode": accurate_mode
}
}
return {
"success": True,
"output_format": output_format,
"result": result
}
except Exception as e:
logger.error(f"Document processing failed: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail=f"Processing failed: {str(e)}"
)
@app.get("/models")
async def list_models():
"""List available models"""
return {
"current_model": OCR_MODEL,
"available_models": [
{
"name": "stepfun-ai/GOT-OCR2_0",
"type": "document_ocr",
"description": "Best for documents, tables, formulas, handwriting",
"vram": "~8GB"
},
{
"name": "microsoft/trocr-base-printed",
"type": "text_ocr",
"description": "Fast OCR for printed text",
"vram": "~2GB"
},
{
"name": "microsoft/trocr-base-handwritten",
"type": "handwriting_ocr",
"description": "OCR for handwritten text",
"vram": "~2GB"
}
],
"note": "GOT-OCR2.0 recommended for documents and tables. TrOCR for simple text.",
"device": DEVICE,
"cuda_available": torch.cuda.is_available()
}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)