"""
RAG Service - FastAPI application
Retrieval-Augmented Generation for MicroDAO
"""

import asyncio
import logging
import os
import threading
from contextlib import asynccontextmanager, closing
from typing import Any, Dict

import psycopg2
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware

from app.core.config import settings
from app.document_store import get_document_store, _make_document
from app.embedding import get_text_embedder
from app.models import (
    IngestRequest,
    IngestResponse,
    QueryRequest,
    QueryResponse,
    UpsertRequest,
    UpsertResponse,
    DeleteByFingerprintRequest,
    DeleteResponse,
)
from app.ingest_pipeline import ingest_parsed_document
from app.query_pipeline import answer_query
from app.event_worker import close_subscriptions, event_worker

logger = logging.getLogger(__name__)


def _ensure_pgvector_extension() -> None:
    """Create the pgvector extension in the configured database if it is missing.

    The ``postgresql+psycopg2`` driver prefix in ``settings.PG_DSN`` is
    rewritten to plain ``postgresql`` so psycopg2 accepts the DSN.

    Raises:
        Exception: any error raised by psycopg2 while connecting or executing;
            the caller logs it and aborts startup.
    """
    dsn = settings.PG_DSN.replace("postgresql+psycopg2", "postgresql")
    # psycopg2's `with connection:` only ends the transaction; it does NOT close
    # the connection. closing() releases it here instead of leaving it referenced
    # by the suspended lifespan generator for the whole process lifetime.
    with closing(psycopg2.connect(dsn)) as conn:
        with conn.cursor() as cur:
            cur.execute("create extension if not exists vector;")
        conn.commit()


def _run_event_worker() -> None:
    """Thread target: run ``event_worker()`` on a fresh event loop in this thread."""
    try:
        asyncio.run(event_worker())
    except Exception:
        # Without this, a crash in the worker thread only reaches
        # threading.excepthook (stderr) and never hits the service logger.
        logger.exception("RAG Event Worker crashed")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan events: startup and shutdown.

    Startup ensures the pgvector extension exists (re-raising, and so failing
    startup, if it cannot) and starts the event worker in a daemon thread whose
    handle is stored on ``app.state.event_worker_thread``. Shutdown awaits
    ``close_subscriptions()``.
    """
    # Startup
    logger.info("Starting RAG Service...")
    try:
        _ensure_pgvector_extension()
        logger.info("pgvector extension ensured")
    except Exception as e:
        logger.error(f"Failed to ensure pgvector extension: {e}", exc_info=True)
        raise

    # daemon=True: the thread does not block interpreter exit.
    event_worker_thread = threading.Thread(target=_run_event_worker, daemon=True)
    event_worker_thread.start()
    logger.info("RAG Event Worker started in background thread")
    app.state.event_worker_thread = event_worker_thread

    yield

    # Shutdown
    logger.info("Shutting down RAG Service...")
    # NOTE(review): this awaits on the server's event loop, while event_worker()
    # runs on the worker thread's own loop (asyncio.run). If the subscriptions
    # hold loop-bound resources this may need to run on the worker's loop —
    # confirm against app.event_worker.
    await close_subscriptions()
    if event_worker_thread.is_alive():
        logger.info("Event Worker is still running, will shut down automatically")


# FastAPI app
app = FastAPI(
    title="RAG Service",
    description="Retrieval-Augmented Generation service for MicroDAO",
    version="1.0.0",
    lifespan=lifespan,
)

# Read once at import time; gate the /debug/* routes below.
NODE_ENV = os.getenv("NODE_ENV", "production").lower()
DEBUG_ENDPOINTS = os.getenv("DEBUG_ENDPOINTS", "false").lower() == "true"

# CORS middleware
# NOTE(review): allow_origins=["*"] together with allow_credentials=True causes
# Starlette to echo the request's Origin back on credentialed requests, so any
# site can make credentialed calls. Consider an explicit origin allow-list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
async def health():
    """Health check endpoint (static payload; does not probe dependencies)."""
    return {
        "status": "healthy",
        "service": "rag-service",
        "version": "1.0.0",
    }


@app.post("/ingest", response_model=IngestResponse)
async def ingest_endpoint(request: IngestRequest):
    """
    Ingest parsed document from PARSER service into RAG

    Body:
    - dao_id: DAO identifier
    - doc_id: Document identifier
    - parsed_json: ParsedDocument JSON from PARSER service
    - user_id: Optional user identifier

    Any failure is logged and returned as HTTP 500 with the error text.
    """
    try:
        result = await ingest_parsed_document(
            dao_id=request.dao_id,
            doc_id=request.doc_id,
            parsed_json=request.parsed_json,
            user_id=request.user_id,
        )
        return IngestResponse(**result)
    except Exception as e:
        logger.error(f"Ingest endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/query", response_model=QueryResponse)
async def query_endpoint(request: QueryRequest):
    """
    Answer query using RAG pipeline

    Body:
    - dao_id: DAO identifier
    - question: User question
    - top_k: Optional number of documents to retrieve
    - user_id: Optional user identifier

    Any failure is logged and returned as HTTP 500 with the error text.
    """
    try:
        result = await answer_query(
            dao_id=request.dao_id,
            question=request.question,
            top_k=request.top_k,
            user_id=request.user_id,
        )
        return QueryResponse(**result)
    except Exception as e:
        logger.error(f"Query endpoint error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/index/upsert", response_model=UpsertResponse)
async def index_upsert(request: UpsertRequest):
    """Embed the supplied chunks and write them to the document store.

    Returns ``status="error"`` (with HTTP 200) when no chunks are supplied.
    A chunk with no corresponding embedding is stored with ``embedding=None``.
    Any other failure is logged and returned as HTTP 500.
    """
    # NOTE(review): embedder.run() and write_documents() are synchronous and are
    # called from an `async def` route, so they block the event loop while they
    # run. Consider a plain `def` route or run_in_threadpool if they are slow.
    try:
        if not request.chunks:
            return UpsertResponse(status="error", indexed_count=0, message="No chunks provided")

        embedder = get_text_embedder()
        document_store = get_document_store()

        texts = [chunk.content for chunk in request.chunks]
        embeddings_result = embedder.run(texts=texts)
        embeddings = embeddings_result.get("embeddings") or []
        if len(embeddings) != len(texts):
            # Previously silent: surplus chunks were stored without a vector.
            logger.warning(
                "Embedding count mismatch: got %d embeddings for %d chunks",
                len(embeddings),
                len(texts),
            )

        documents = []
        for idx, chunk in enumerate(request.chunks):
            embedding = embeddings[idx] if idx < len(embeddings) else None
            documents.append(_make_document(content=chunk.content, meta=chunk.meta, embedding=embedding))

        document_store.write_documents(documents)
        return UpsertResponse(status="success", indexed_count=len(documents))
    except Exception as e:
        logger.error(f"Index upsert failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/index/delete_by_fingerprint", response_model=DeleteResponse)
async def delete_by_fingerprint(request: DeleteByFingerprintRequest):
    """Delete every stored document whose ``index_fingerprint`` equals the request's.

    Failures are logged and returned as HTTP 500.
    """
    try:
        document_store = get_document_store()
        document_store.delete_documents(filters={"index_fingerprint": {"$eq": request.fingerprint}})
        # NOTE(review): deleted_count is always reported as 0 — the store's
        # delete call does not surface a count here.
        return DeleteResponse(status="success", deleted_count=0, message="Delete requested")
    except Exception as e:
        logger.error(f"Delete by fingerprint failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/debug/chunks")
async def debug_chunks(artifact_id: str, limit: int = 3) -> Dict[str, Any]:
    """Return up to ``limit`` stored chunks for ``artifact_id`` (200-char previews + meta).

    Hidden (HTTP 404) when NODE_ENV is "production" unless DEBUG_ENDPOINTS=true.
    """
    if NODE_ENV == "production" and not DEBUG_ENDPOINTS:
        raise HTTPException(status_code=404, detail="Not Found")

    try:
        document_store = get_document_store()
        docs = document_store.filter_documents(
            filters={"artifact_id": artifact_id},
            top_k=limit,
            return_embedding=False,
        )
        items = [{"content_preview": doc.content[:200], "meta": doc.meta} for doc in docs]
        return {"items": items, "count": len(items)}
    except Exception as e:
        logger.error(f"Debug chunks failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    import uvicorn

    # Local development entry point (auto-reload enabled).
    uvicorn.run(
        "app.main:app",
        host=settings.API_HOST,
        port=settings.API_PORT,
        reload=True,
    )