"""
DAARION Memory Service - FastAPI Application

Three-tier agent memory:
- Short-term: conversation events (working buffer)
- Mid-term: thread summaries (session/topic level)
- Long-term: memory items (personal/project level)
"""

from contextlib import asynccontextmanager
from typing import List, Optional
from uuid import UUID

import structlog
from fastapi import Depends, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware

from .config import get_settings
from .models import (
    CreateThreadRequest, AddEventRequest, CreateMemoryRequest,
    MemoryFeedbackRequest, RetrievalRequest, SummaryRequest,
    ThreadResponse, EventResponse, MemoryResponse,
    SummaryResponse, RetrievalResponse, RetrievalResult,
    ContextResponse, MemoryCategory, FeedbackAction
)
from .vector_store import vector_store
from .database import db
from .auth import get_current_service, get_current_service_optional

logger = structlog.get_logger()
settings = get_settings()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage startup and shutdown of the service's backing stores.

    On startup the database is connected and the vector store is
    initialized. On shutdown the database is disconnected.

    Args:
        app: The FastAPI application this lifespan is attached to.
    """
    # Startup
    logger.info("starting_memory_service")
    await db.connect()
    try:
        # Inside the try so that a failed vector-store initialization (or an
        # exception propagated through the yield) still releases the
        # database connection.
        await vector_store.initialize()
        yield
    finally:
        # Shutdown
        await db.disconnect()
        logger.info("memory_service_stopped")


app = FastAPI(
    title="DAARION Memory Service",
    description="Agent memory management with PostgreSQL + Qdrant + Cohere",
    version="1.0.0",
    lifespan=lifespan
)

# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive - confirm this is intended outside of development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ============================================================================
# HEALTH
# ============================================================================

@app.get("/health")
async def health():
    """Health check.

    Returns:
        A dict with a fixed ``"healthy"`` status, the configured service
        name, and the collection stats reported by the vector store.
    """
    return {
        "status": "healthy",
        "service": settings.service_name,
        "vector_store": await vector_store.get_collection_stats()
    }


# ============================================================================
# THREADS (Conversations)
# ============================================================================

@app.post("/threads", response_model=ThreadResponse)
async def create_thread(
    request: CreateThreadRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Create a new conversation thread.

    Args:
        request: Thread attributes (org, user, workspace, agent, title,
            tags, metadata).
        service: Result of the optional service-auth dependency. Not read
            in the body; declared so the dependency runs for this route.

    Returns:
        The thread record returned by ``db.create_thread``.
    """
    # Auth is optional: if a JWT is provided it is verified; if not, the
    # request is allowed (dev mode).
    thread = await db.create_thread(
        org_id=request.org_id,
        user_id=request.user_id,
        workspace_id=request.workspace_id,
        agent_id=request.agent_id,
        title=request.title,
        tags=request.tags,
        metadata=request.metadata
    )
    return thread


@app.get("/threads/{thread_id}", response_model=ThreadResponse)
async def get_thread(thread_id: UUID):
    """Get a thread by ID.

    Raises:
        HTTPException: 404 if no thread with this ID exists.
    """
    thread = await db.get_thread(thread_id)
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")
    return thread


@app.get("/threads", response_model=List[ThreadResponse])
async def list_threads(
    user_id: UUID = Query(...),
    org_id: UUID = Query(...),
    workspace_id: Optional[UUID] = None,
    agent_id: Optional[UUID] = None,
    limit: int = Query(default=20, ge=1, le=100)
):
    """List threads for a user, optionally narrowed by workspace and agent.

    ``limit`` must be between 1 and 100; out-of-range values are rejected
    by request validation instead of reaching the database.
    """
    threads = await db.list_threads(
        org_id=org_id,
        user_id=user_id,
        workspace_id=workspace_id,
        agent_id=agent_id,
        limit=limit
    )
    return threads


# ============================================================================
# EVENTS (Short-term Memory)
# ============================================================================

@app.post("/events", response_model=EventResponse)
async def add_event(
    request: AddEventRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Add an event to a conversation (message, tool call, etc.).

    Args:
        request: Event fields; all of them are forwarded to ``db.add_event``.
        service: Result of the optional service-auth dependency. Not read
            in the body; declared so the dependency runs for this route.

    Returns:
        The event record returned by ``db.add_event``.
    """
    event = await db.add_event(
        thread_id=request.thread_id,
        event_type=request.event_type,
        role=request.role,
        content=request.content,
        tool_name=request.tool_name,
        tool_input=request.tool_input,
        tool_output=request.tool_output,
        payload=request.payload,
        token_count=request.token_count,
        model_used=request.model_used,
        latency_ms=request.latency_ms,
        metadata=request.metadata
    )
    return event


@app.get("/threads/{thread_id}/events", response_model=List[EventResponse])
async def get_events(
    thread_id: UUID,
    limit: int = Query(default=50, ge=1, le=200),
    offset: int = Query(default=0, ge=0)
):
    """Get events for a thread (most recent first).

    ``limit`` must be between 1 and 200 and ``offset`` must not be negative;
    out-of-range values are rejected by request validation.
    """
    events = await db.get_events(thread_id, limit=limit, offset=offset)
    return events


# ============================================================================
# MEMORIES (Long-term Memory)
# ============================================================================

@app.post("/memories", response_model=MemoryResponse)
async def create_memory(
    request: CreateMemoryRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Create a long-term memory item.

    The item is first stored in the database, then indexed in the vector
    store, and finally the database row is updated with the ID returned by
    the vector store.

    Args:
        request: Memory attributes; see the keyword arguments below.
        service: Result of the optional service-auth dependency. Not read
            in the body; declared so the dependency runs for this route.

    Returns:
        The memory record as returned by ``db.create_memory`` (i.e. as it
        was before the embedding ID was written back).
    """
    # Create in PostgreSQL
    memory = await db.create_memory(
        org_id=request.org_id,
        user_id=request.user_id,
        workspace_id=request.workspace_id,
        agent_id=request.agent_id,
        category=request.category,
        fact_text=request.fact_text,
        confidence=request.confidence,
        source_event_id=request.source_event_id,
        source_thread_id=request.source_thread_id,
        extraction_method=request.extraction_method,
        is_sensitive=request.is_sensitive,
        retention=request.retention,
        ttl_days=request.ttl_days,
        tags=request.tags,
        metadata=request.metadata
    )

    # Index in Qdrant
    point_id = await vector_store.index_memory(
        memory_id=memory["memory_id"],
        text=request.fact_text,
        org_id=request.org_id,
        user_id=request.user_id,
        category=request.category,
        agent_id=request.agent_id,
        workspace_id=request.workspace_id,
        thread_id=request.source_thread_id
    )

    # Update memory with embedding ID
    await db.update_memory_embedding_id(memory["memory_id"], point_id)

    return memory


@app.get("/memories/{memory_id}", response_model=MemoryResponse)
async def get_memory(memory_id: UUID):
    """Get a memory by ID.

    Raises:
        HTTPException: 404 if no memory with this ID exists.
    """
    memory = await db.get_memory(memory_id)
    if not memory:
        raise HTTPException(status_code=404, detail="Memory not found")
    return memory


@app.get("/memories", response_model=List[MemoryResponse])
async def list_memories(
    user_id: UUID = Query(...),
    org_id: UUID = Query(...),
    agent_id: Optional[UUID] = None,
    workspace_id: Optional[UUID] = None,
    category: Optional[MemoryCategory] = None,
    include_global: bool = True,
    limit: int = Query(default=50, ge=1, le=200)
):
    """List memories for a user, optionally filtered by agent, workspace
    and category.

    ``limit`` must be between 1 and 200; out-of-range values are rejected
    by request validation.
    """
    memories = await db.list_memories(
        org_id=org_id,
        user_id=user_id,
        agent_id=agent_id,
        workspace_id=workspace_id,
        category=category,
        include_global=include_global,
        limit=limit
    )
    return memories


@app.post("/memories/{memory_id}/feedback")
async def memory_feedback(memory_id: UUID, request: MemoryFeedbackRequest):
    """Apply user feedback to a memory (confirm/reject/edit/delete).

    The feedback is recorded first, then applied:

    - CONFIRM raises confidence by ``settings.memory_confirm_boost``
      (capped at 1.0) and marks the memory verified.
    - REJECT lowers confidence by ``settings.memory_reject_penalty``
      (floored at 0.0); if the result falls below
      ``settings.memory_min_confidence`` the memory is invalidated and
      removed from the vector store.
    - EDIT replaces the text and re-indexes the memory.
    - DELETE invalidates the memory and removes it from the vector store.

    Returns:
        ``{"status": "ok", "action": <action value>}``.

    Raises:
        HTTPException: 404 if the memory does not exist; 400 if the action
            is EDIT and no ``new_value`` was supplied.
    """
    memory = await db.get_memory(memory_id)
    if not memory:
        raise HTTPException(status_code=404, detail="Memory not found")

    # An edit without replacement text cannot be applied. Reject it before
    # recording feedback so a no-op is never logged and reported as "ok".
    if request.action == FeedbackAction.EDIT and not request.new_value:
        raise HTTPException(
            status_code=400,
            detail="new_value is required for edit feedback"
        )

    # Record feedback
    await db.add_memory_feedback(
        memory_id=memory_id,
        user_id=request.user_id,
        action=request.action,
        old_value=memory["fact_text"],
        new_value=request.new_value,
        reason=request.reason
    )

    # Apply action
    if request.action == FeedbackAction.CONFIRM:
        new_confidence = min(1.0, memory["confidence"] + settings.memory_confirm_boost)
        await db.update_memory_confidence(memory_id, new_confidence, verified=True)

    elif request.action == FeedbackAction.REJECT:
        new_confidence = max(0.0, memory["confidence"] - settings.memory_reject_penalty)
        if new_confidence < settings.memory_min_confidence:
            # Mark as invalid
            await db.invalidate_memory(memory_id)
            await vector_store.delete_memory(memory_id)
        else:
            await db.update_memory_confidence(memory_id, new_confidence)

    elif request.action == FeedbackAction.EDIT:
        await db.update_memory_text(memory_id, request.new_value)
        # Re-index with new text
        await vector_store.delete_memory(memory_id)
        point_id = await vector_store.index_memory(
            memory_id=memory_id,
            text=request.new_value,
            org_id=memory["org_id"],
            user_id=memory["user_id"],
            category=memory["category"],
            agent_id=memory.get("agent_id"),
            workspace_id=memory.get("workspace_id"),
            # Carry the originating thread over, as create_memory does.
            # NOTE(review): assumes the stored record exposes it under
            # "source_thread_id" - confirm against db.get_memory.
            thread_id=memory.get("source_thread_id")
        )
        # Keep the stored embedding ID in step with the new vector point,
        # mirroring what create_memory does after indexing.
        await db.update_memory_embedding_id(memory_id, point_id)

    elif request.action == FeedbackAction.DELETE:
        await db.invalidate_memory(memory_id)
        await vector_store.delete_memory(memory_id)

    return {"status": "ok", "action": request.action.value}


# ============================================================================
# RETRIEVAL (Semantic Search)
# ============================================================================

@app.post("/retrieve", response_model=RetrievalResponse)
async def retrieve_memories(request: RetrievalRequest):
    """
    Semantic retrieval of relevant memories.

    Runs one vector search per query, deduplicates hits by memory ID
    (keeping the highest-scoring hit for each memory), drops memories whose
    stored confidence is below ``request.min_confidence``, and returns the
    ``request.top_k`` most relevant ones.

    Returns:
        A ``RetrievalResponse`` with the returned results, the number of
        queries, and the number of candidates that passed the confidence
        filter before truncation to ``top_k``.
    """
    # memory_id -> best hit seen so far. A dict keeps first-seen insertion
    # order, so ties in the sort below resolve as they would by arrival.
    best_hits = {}

    for query in request.queries:
        results = await vector_store.search_memories(
            query=query,
            org_id=request.org_id,
            user_id=request.user_id,
            agent_id=request.agent_id,
            workspace_id=request.workspace_id,
            categories=request.categories,
            include_global=request.include_global,
            top_k=request.top_k
        )

        for hit in results:
            memory_id = hit.get("memory_id")
            if not memory_id:
                continue
            known = best_hits.get(memory_id)
            # The same memory can match several queries; rank it by its
            # best score rather than by whichever query ran first.
            if known is None or hit["score"] > known["score"]:
                best_hits[memory_id] = hit

    all_results = []
    for memory_id, hit in best_hits.items():
        # Get full memory from DB for confidence check
        memory = await db.get_memory(UUID(memory_id))
        if not memory or memory["confidence"] < request.min_confidence:
            continue
        all_results.append(RetrievalResult(
            memory_id=UUID(memory_id),
            fact_text=hit["text"],
            category=MemoryCategory(hit["category"]),
            confidence=memory["confidence"],
            relevance_score=hit["score"],
            agent_id=UUID(hit["agent_id"]) if hit.get("agent_id") else None,
            is_global=hit.get("agent_id") is None
        ))

    # Sort by relevance
    all_results.sort(key=lambda x: x.relevance_score, reverse=True)
    returned = all_results[:request.top_k]

    # Update usage stats only for memories actually handed back to the
    # caller, not for candidates cut off by top_k.
    for item in returned:
        await db.increment_memory_usage(item.memory_id)

    return RetrievalResponse(
        results=returned,
        query_count=len(request.queries),
        total_results=len(all_results)
    )


# ============================================================================
# SUMMARIES (Mid-term Memory)
# ============================================================================

@app.post("/threads/{thread_id}/summarize", response_model=SummaryResponse)
async def create_summary(thread_id: UUID, request: SummaryRequest):
    """
    Generate rolling summary for thread.

    Stores a summary covering the events returned by
    ``db.get_events_for_summary`` and indexes it in the vector store.
    The summary text is currently a placeholder; LLM summarization is
    still a TODO.

    Returns:
        The summary record returned by ``db.create_summary``.

    Raises:
        HTTPException: 404 if the thread does not exist; 400 if
            ``request.force`` is false and the thread's token count is
            below ``settings.summary_trigger_tokens``.
    """
    thread = await db.get_thread(thread_id)
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    # Check if summary is needed
    if not request.force and thread["total_tokens"] < settings.summary_trigger_tokens:
        raise HTTPException(
            status_code=400,
            detail=f"Token count ({thread['total_tokens']}) below threshold ({settings.summary_trigger_tokens})"
        )

    # Get events to summarize
    events = await db.get_events_for_summary(thread_id)

    # TODO: Call LLM to generate summary
    # For now, create a placeholder
    summary_text = f"Summary of {len(events)} events. [Implement LLM summarization]"
    state = {
        "goals": [],
        "decisions": [],
        "open_questions": [],
        "next_steps": [],
        "key_facts": []
    }

    # Create summary; the sequence range falls back to 0-0 when there are
    # no events to summarize.
    summary = await db.create_summary(
        thread_id=thread_id,
        summary_text=summary_text,
        state=state,
        events_from_seq=events[0]["sequence_num"] if events else 0,
        events_to_seq=events[-1]["sequence_num"] if events else 0,
        events_count=len(events)
    )

    # Index summary in Qdrant
    await vector_store.index_summary(
        summary_id=summary["summary_id"],
        text=summary_text,
        thread_id=thread_id,
        org_id=thread["org_id"],
        user_id=thread["user_id"],
        agent_id=thread.get("agent_id"),
        workspace_id=thread.get("workspace_id")
    )

    return summary


@app.get("/threads/{thread_id}/summary", response_model=Optional[SummaryResponse])
async def get_latest_summary(thread_id: UUID):
    """Get the latest summary for a thread.

    Returns whatever ``db.get_latest_summary`` yields; the response model
    is optional, so an empty result is passed through rather than turned
    into a 404.
    """
    summary = await db.get_latest_summary(thread_id)
    return summary


# ============================================================================
# CONTEXT (Full context for agent)
# ============================================================================

@app.get("/threads/{thread_id}/context", response_model=ContextResponse)
async def get_context(
    thread_id: UUID,
    queries: List[str] = Query(default=[]),
    top_k: int = Query(default=10, ge=1)
):
    """
    Get full context for agent prompt.

    Combines:
    - Latest summary (mid-term)
    - Recent messages (short-term)
    - Retrieved memories (long-term)

    Args:
        thread_id: Thread whose context is assembled.
        queries: Retrieval queries; memories are only retrieved when at
            least one query is given.
        top_k: Maximum number of memories to retrieve. Must be at least 1,
            so a negative value cannot reach the result slice in
            ``retrieve_memories`` and silently drop results from the end.

    Raises:
        HTTPException: 404 if the thread does not exist.
    """
    thread = await db.get_thread(thread_id)
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    # Get summary
    summary = await db.get_latest_summary(thread_id)

    # Get recent messages
    recent = await db.get_events(
        thread_id,
        limit=settings.short_term_window_messages
    )

    # Retrieve memories if queries provided
    retrieved = []
    if queries:
        # Reuses the /retrieve handler directly, scoped to the thread's
        # org, user, agent and workspace.
        retrieval_response = await retrieve_memories(RetrievalRequest(
            org_id=thread["org_id"],
            user_id=thread["user_id"],
            agent_id=thread.get("agent_id"),
            workspace_id=thread.get("workspace_id"),
            queries=queries,
            top_k=top_k,
            include_global=True
        ))
        retrieved = retrieval_response.results

    # Estimate tokens; missing or null counts are treated as 0.
    token_estimate = sum(e.get("token_count", 0) or 0 for e in recent)
    if summary:
        token_estimate += summary.get("summary_tokens", 0) or 0

    return ContextResponse(
        thread_id=thread_id,
        summary=summary,
        recent_messages=recent,
        retrieved_memories=retrieved,
        token_estimate=token_estimate
    )


# ============================================================================
# ADMIN
# ============================================================================

@app.get("/stats")
async def get_stats():
    """Get service statistics from the vector store and the database."""
    return {
        "vector_store": await vector_store.get_collection_stats(),
        "database": await db.get_stats()
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)