""" Chandra Inference Service Direct inference using HuggingFace model """ import logging import os from typing import Optional, Dict, Any from fastapi import FastAPI, HTTPException, File, UploadFile, Form from fastapi.responses import JSONResponse from pydantic import BaseModel from PIL import Image from io import BytesIO import base64 import torch logger = logging.getLogger(__name__) app = FastAPI(title="Chandra Inference Service") # Configuration # Using GOT-OCR2.0 - best open-source OCR for documents and tables # Alternative: microsoft/trocr-base-printed for simple text OCR_MODEL = os.getenv("OCR_MODEL", "stepfun-ai/GOT-OCR2_0") DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu") # Load model (lazy loading) model = None processor = None def load_model(): """Load OCR model from HuggingFace""" global model, processor if model is not None: return try: logger.info(f"Loading OCR model: {OCR_MODEL} on {DEVICE}") # Try GOT-OCR2.0 first (best for documents and tables) if "GOT-OCR" in OCR_MODEL or "got-ocr" in OCR_MODEL.lower(): from transformers import AutoModel, AutoTokenizer # GOT-OCR uses different loading tokenizer = AutoTokenizer.from_pretrained(OCR_MODEL, trust_remote_code=True) model = AutoModel.from_pretrained( OCR_MODEL, trust_remote_code=True, torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32, device_map="auto" if DEVICE == "cuda" else None ) processor = tokenizer logger.info(f"GOT-OCR2.0 model loaded on {DEVICE}") else: # Fallback to TrOCR for simple text OCR from transformers import TrOCRProcessor, VisionEncoderDecoderModel processor = TrOCRProcessor.from_pretrained(OCR_MODEL) model = VisionEncoderDecoderModel.from_pretrained( OCR_MODEL, torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32 ) if DEVICE == "cpu": model = model.to(DEVICE) else: model = model.cuda() if model: model.eval() logger.info(f"Model loaded successfully on {DEVICE}") except Exception as e: logger.error(f"Failed to load model: {e}", exc_info=True) logger.warning("Service will run in degraded mode without model") @app.on_event("startup") async def startup(): """Load model on startup""" try: load_model() except Exception as e: logger.error(f"Startup failed: {e}", exc_info=True) @app.get("/health") async def health(): """Health check endpoint""" try: if model is None: return { "status": "loading", "service": "ocr-inference", "model": OCR_MODEL, "device": DEVICE } return { "status": "healthy", "service": "ocr-inference", "model": OCR_MODEL, "device": DEVICE, "cuda_available": torch.cuda.is_available() } except Exception as e: return { "status": "unhealthy", "service": "ocr-inference", "error": str(e) } @app.post("/process") async def process_document( file: Optional[UploadFile] = File(None), doc_url: Optional[str] = Form(None), doc_base64: Optional[str] = Form(None), output_format: str = Form("markdown"), accurate_mode: bool = Form(False) ): """ Process a document using Chandra OCR. Args: file: Uploaded file (PDF, image) doc_url: URL to document doc_base64: Base64 encoded document output_format: markdown, html, or json accurate_mode: Use accurate mode (slower but more precise) """ try: # Ensure model is loaded if model is None: load_model() # Get image data image_data = None if file: image_data = await file.read() elif doc_base64: image_data = base64.b64decode(doc_base64) elif doc_url: import httpx async with httpx.AsyncClient(timeout=30.0) as client: response = await client.get(doc_url) if response.status_code == 200: image_data = response.content else: raise HTTPException( status_code=400, detail=f"Failed to download document: {response.status_code}" ) else: raise HTTPException( status_code=400, detail="No document provided" ) # Load image image = Image.open(BytesIO(image_data)).convert("RGB") # Process with OCR model if model is None or processor is None: raise HTTPException( status_code=503, detail="OCR model not loaded. Check logs for details." ) # Different processing for GOT-OCR vs TrOCR if "GOT-OCR" in OCR_MODEL or "got-ocr" in OCR_MODEL.lower(): # GOT-OCR2.0 processing with torch.no_grad(): result = model.chat(processor, image, ocr_type='ocr') generated_text = result if isinstance(result, str) else str(result) else: # TrOCR processing inputs = processor(images=image, return_tensors="pt").to(DEVICE) with torch.no_grad(): generated_ids = model.generate( inputs.pixel_values, max_length=512, num_beams=5 if accurate_mode else 3 ) generated_text = processor.batch_decode( generated_ids, skip_special_tokens=True )[0] # Format output based on requested format if output_format == "markdown": result = { "markdown": generated_text, "format": "markdown" } elif output_format == "html": # Convert markdown to HTML (simplified) result = { "html": generated_text.replace("\n", "
"), "format": "html" } else: # json result = { "text": generated_text, "format": "json", "metadata": { "model": CHANDRA_MODEL, "device": DEVICE, "accurate_mode": accurate_mode } } return { "success": True, "output_format": output_format, "result": result } except Exception as e: logger.error(f"Document processing failed: {e}", exc_info=True) raise HTTPException( status_code=500, detail=f"Processing failed: {str(e)}" ) @app.get("/models") async def list_models(): """List available models""" return { "current_model": OCR_MODEL, "available_models": [ { "name": "stepfun-ai/GOT-OCR2_0", "type": "document_ocr", "description": "Best for documents, tables, formulas, handwriting", "vram": "~8GB" }, { "name": "microsoft/trocr-base-printed", "type": "text_ocr", "description": "Fast OCR for printed text", "vram": "~2GB" }, { "name": "microsoft/trocr-base-handwritten", "type": "handwriting_ocr", "description": "OCR for handwritten text", "vram": "~2GB" } ], "note": "GOT-OCR2.0 recommended for documents and tables. TrOCR for simple text.", "device": DEVICE, "cuda_available": torch.cuda.is_available() } if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)