""" DAARION Memory Service - FastAPI Application Трирівнева пам'ять агентів: - Short-term: conversation events (робочий буфер) - Mid-term: thread summaries (сесійна/тематична) - Long-term: memory items (персональна/проектна) """ from contextlib import asynccontextmanager from typing import List, Optional from uuid import UUID import structlog from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from .config import get_settings from .models import ( CreateThreadRequest, AddEventRequest, CreateMemoryRequest, MemoryFeedbackRequest, RetrievalRequest, SummaryRequest, ThreadResponse, EventResponse, MemoryResponse, SummaryResponse, RetrievalResponse, RetrievalResult, ContextResponse, MemoryCategory, FeedbackAction ) from .vector_store import vector_store from .database import db logger = structlog.get_logger() settings = get_settings() @asynccontextmanager async def lifespan(app: FastAPI): """Startup and shutdown events""" # Startup logger.info("starting_memory_service") await db.connect() await vector_store.initialize() yield # Shutdown await db.disconnect() logger.info("memory_service_stopped") app = FastAPI( title="DAARION Memory Service", description="Agent memory management with PostgreSQL + Qdrant + Cohere", version="1.0.0", lifespan=lifespan ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ============================================================================ # HEALTH # ============================================================================ @app.get("/health") async def health(): """Health check""" return { "status": "healthy", "service": settings.service_name, "vector_store": await vector_store.get_collection_stats() } # ============================================================================ # THREADS (Conversations) # ============================================================================ @app.post("/threads", response_model=ThreadResponse) async def create_thread(request: CreateThreadRequest): """Create new conversation thread""" thread = await db.create_thread( org_id=request.org_id, user_id=request.user_id, workspace_id=request.workspace_id, agent_id=request.agent_id, title=request.title, tags=request.tags, metadata=request.metadata ) return thread @app.get("/threads/{thread_id}", response_model=ThreadResponse) async def get_thread(thread_id: UUID): """Get thread by ID""" thread = await db.get_thread(thread_id) if not thread: raise HTTPException(status_code=404, detail="Thread not found") return thread @app.get("/threads", response_model=List[ThreadResponse]) async def list_threads( user_id: UUID = Query(...), org_id: UUID = Query(...), workspace_id: Optional[UUID] = None, agent_id: Optional[UUID] = None, limit: int = Query(default=20, le=100) ): """List threads for user""" threads = await db.list_threads( org_id=org_id, user_id=user_id, workspace_id=workspace_id, agent_id=agent_id, limit=limit ) return threads # ============================================================================ # EVENTS (Short-term Memory) # ============================================================================ @app.post("/events", response_model=EventResponse) async def add_event(request: AddEventRequest): """Add event to conversation (message, tool call, etc.)""" event = await db.add_event( thread_id=request.thread_id, event_type=request.event_type, role=request.role, content=request.content, tool_name=request.tool_name, tool_input=request.tool_input, tool_output=request.tool_output, payload=request.payload, token_count=request.token_count, model_used=request.model_used, latency_ms=request.latency_ms, metadata=request.metadata ) return event @app.get("/threads/{thread_id}/events", response_model=List[EventResponse]) async def get_events( thread_id: UUID, limit: int = Query(default=50, le=200), offset: int = Query(default=0) ): """Get events for thread (most recent first)""" events = await db.get_events(thread_id, limit=limit, offset=offset) return events # ============================================================================ # MEMORIES (Long-term Memory) # ============================================================================ @app.post("/memories", response_model=MemoryResponse) async def create_memory(request: CreateMemoryRequest): """Create long-term memory item""" # Create in PostgreSQL memory = await db.create_memory( org_id=request.org_id, user_id=request.user_id, workspace_id=request.workspace_id, agent_id=request.agent_id, category=request.category, fact_text=request.fact_text, confidence=request.confidence, source_event_id=request.source_event_id, source_thread_id=request.source_thread_id, extraction_method=request.extraction_method, is_sensitive=request.is_sensitive, retention=request.retention, ttl_days=request.ttl_days, tags=request.tags, metadata=request.metadata ) # Index in Qdrant point_id = await vector_store.index_memory( memory_id=memory["memory_id"], text=request.fact_text, org_id=request.org_id, user_id=request.user_id, category=request.category, agent_id=request.agent_id, workspace_id=request.workspace_id, thread_id=request.source_thread_id ) # Update memory with embedding ID await db.update_memory_embedding_id(memory["memory_id"], point_id) return memory @app.get("/memories/{memory_id}", response_model=MemoryResponse) async def get_memory(memory_id: UUID): """Get memory by ID""" memory = await db.get_memory(memory_id) if not memory: raise HTTPException(status_code=404, detail="Memory not found") return memory @app.get("/memories", response_model=List[MemoryResponse]) async def list_memories( user_id: UUID = Query(...), org_id: UUID = Query(...), agent_id: Optional[UUID] = None, workspace_id: Optional[UUID] = None, category: Optional[MemoryCategory] = None, include_global: bool = True, limit: int = Query(default=50, le=200) ): """List memories for user""" memories = await db.list_memories( org_id=org_id, user_id=user_id, agent_id=agent_id, workspace_id=workspace_id, category=category, include_global=include_global, limit=limit ) return memories @app.post("/memories/{memory_id}/feedback") async def memory_feedback(memory_id: UUID, request: MemoryFeedbackRequest): """User feedback on memory (confirm/reject/edit/delete)""" memory = await db.get_memory(memory_id) if not memory: raise HTTPException(status_code=404, detail="Memory not found") # Record feedback await db.add_memory_feedback( memory_id=memory_id, user_id=request.user_id, action=request.action, old_value=memory["fact_text"], new_value=request.new_value, reason=request.reason ) # Apply action if request.action == FeedbackAction.CONFIRM: new_confidence = min(1.0, memory["confidence"] + settings.memory_confirm_boost) await db.update_memory_confidence(memory_id, new_confidence, verified=True) elif request.action == FeedbackAction.REJECT: new_confidence = max(0.0, memory["confidence"] - settings.memory_reject_penalty) if new_confidence < settings.memory_min_confidence: # Mark as invalid await db.invalidate_memory(memory_id) await vector_store.delete_memory(memory_id) else: await db.update_memory_confidence(memory_id, new_confidence) elif request.action == FeedbackAction.EDIT: if request.new_value: await db.update_memory_text(memory_id, request.new_value) # Re-index with new text await vector_store.delete_memory(memory_id) await vector_store.index_memory( memory_id=memory_id, text=request.new_value, org_id=memory["org_id"], user_id=memory["user_id"], category=memory["category"], agent_id=memory.get("agent_id"), workspace_id=memory.get("workspace_id") ) elif request.action == FeedbackAction.DELETE: await db.invalidate_memory(memory_id) await vector_store.delete_memory(memory_id) return {"status": "ok", "action": request.action.value} # ============================================================================ # RETRIEVAL (Semantic Search) # ============================================================================ @app.post("/retrieve", response_model=RetrievalResponse) async def retrieve_memories(request: RetrievalRequest): """ Semantic retrieval of relevant memories. Performs multiple queries and deduplicates results. """ all_results = [] seen_ids = set() for query in request.queries: results = await vector_store.search_memories( query=query, org_id=request.org_id, user_id=request.user_id, agent_id=request.agent_id, workspace_id=request.workspace_id, categories=request.categories, include_global=request.include_global, top_k=request.top_k ) for r in results: memory_id = r.get("memory_id") if memory_id and memory_id not in seen_ids: seen_ids.add(memory_id) # Get full memory from DB for confidence check memory = await db.get_memory(UUID(memory_id)) if memory and memory["confidence"] >= request.min_confidence: all_results.append(RetrievalResult( memory_id=UUID(memory_id), fact_text=r["text"], category=MemoryCategory(r["category"]), confidence=memory["confidence"], relevance_score=r["score"], agent_id=UUID(r["agent_id"]) if r.get("agent_id") else None, is_global=r.get("agent_id") is None )) # Update usage stats await db.increment_memory_usage(UUID(memory_id)) # Sort by relevance all_results.sort(key=lambda x: x.relevance_score, reverse=True) return RetrievalResponse( results=all_results[:request.top_k], query_count=len(request.queries), total_results=len(all_results) ) # ============================================================================ # SUMMARIES (Mid-term Memory) # ============================================================================ @app.post("/threads/{thread_id}/summarize", response_model=SummaryResponse) async def create_summary(thread_id: UUID, request: SummaryRequest): """ Generate rolling summary for thread. Compresses old events into a structured summary. """ thread = await db.get_thread(thread_id) if not thread: raise HTTPException(status_code=404, detail="Thread not found") # Check if summary is needed if not request.force and thread["total_tokens"] < settings.summary_trigger_tokens: raise HTTPException( status_code=400, detail=f"Token count ({thread['total_tokens']}) below threshold ({settings.summary_trigger_tokens})" ) # Get events to summarize events = await db.get_events_for_summary(thread_id) # TODO: Call LLM to generate summary # For now, create a placeholder summary_text = f"Summary of {len(events)} events. [Implement LLM summarization]" state = { "goals": [], "decisions": [], "open_questions": [], "next_steps": [], "key_facts": [] } # Create summary summary = await db.create_summary( thread_id=thread_id, summary_text=summary_text, state=state, events_from_seq=events[0]["sequence_num"] if events else 0, events_to_seq=events[-1]["sequence_num"] if events else 0, events_count=len(events) ) # Index summary in Qdrant await vector_store.index_summary( summary_id=summary["summary_id"], text=summary_text, thread_id=thread_id, org_id=thread["org_id"], user_id=thread["user_id"], agent_id=thread.get("agent_id"), workspace_id=thread.get("workspace_id") ) return summary @app.get("/threads/{thread_id}/summary", response_model=Optional[SummaryResponse]) async def get_latest_summary(thread_id: UUID): """Get latest summary for thread""" summary = await db.get_latest_summary(thread_id) return summary # ============================================================================ # CONTEXT (Full context for agent) # ============================================================================ @app.get("/threads/{thread_id}/context", response_model=ContextResponse) async def get_context( thread_id: UUID, queries: List[str] = Query(default=[]), top_k: int = Query(default=10) ): """ Get full context for agent prompt. Combines: - Latest summary (mid-term) - Recent messages (short-term) - Retrieved memories (long-term) """ thread = await db.get_thread(thread_id) if not thread: raise HTTPException(status_code=404, detail="Thread not found") # Get summary summary = await db.get_latest_summary(thread_id) # Get recent messages recent = await db.get_events( thread_id, limit=settings.short_term_window_messages ) # Retrieve memories if queries provided retrieved = [] if queries: retrieval_response = await retrieve_memories(RetrievalRequest( org_id=thread["org_id"], user_id=thread["user_id"], agent_id=thread.get("agent_id"), workspace_id=thread.get("workspace_id"), queries=queries, top_k=top_k, include_global=True )) retrieved = retrieval_response.results # Estimate tokens token_estimate = sum(e.get("token_count", 0) or 0 for e in recent) if summary: token_estimate += summary.get("summary_tokens", 0) or 0 return ContextResponse( thread_id=thread_id, summary=summary, recent_messages=recent, retrieved_memories=retrieved, token_estimate=token_estimate ) # ============================================================================ # ADMIN # ============================================================================ @app.get("/stats") async def get_stats(): """Get service statistics""" return { "vector_store": await vector_store.get_collection_stats(), "database": await db.get_stats() } if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)