"""
DAARION Memory Service - FastAPI Application

Three-tier agent memory:
- Short-term: conversation events (working buffer)
- Mid-term: thread summaries (session/topic level)
- Long-term: memory items (personal/project level)
"""
from contextlib import asynccontextmanager
from typing import List, Optional
from uuid import UUID

import structlog
from fastapi import Depends, FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware

from .config import get_settings
from .models import (
    CreateThreadRequest, AddEventRequest, CreateMemoryRequest, MemoryFeedbackRequest,
    RetrievalRequest, SummaryRequest,
    ThreadResponse, EventResponse, MemoryResponse, SummaryResponse,
    RetrievalResponse, RetrievalResult, ContextResponse,
    MemoryCategory, FeedbackAction
)
from .vector_store import vector_store
from .database import db
from .auth import get_current_service, get_current_service_optional

logger = structlog.get_logger()
settings = get_settings()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown steps around the application's lifetime.

    Startup connects the database and initializes the vector store.
    Shutdown disconnects the database.

    Args:
        app: The FastAPI application this lifespan is attached to (unused in
            the body; required by the lifespan protocol).

    Yields:
        Nothing; control is handed to the running application.
    """
    # Startup
    logger.info("starting_memory_service")
    await db.connect()
    try:
        await vector_store.initialize()
        yield
    finally:
        # Shutdown: runs even when vector-store initialization fails or the
        # application exits through an exception, so the database connection
        # opened above is never left dangling.
        await db.disconnect()
        logger.info("memory_service_stopped")


app = FastAPI(
    title="DAARION Memory Service",
    description="Agent memory management with PostgreSQL + Qdrant + Cohere",
    version="1.0.0",
    lifespan=lifespan
)

# NOTE(review): wildcard origins combined with allow_credentials=True is a
# permissive CORS setup; consider listing explicit origins for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ============================================================================
# HEALTH
# ============================================================================

@app.get("/health")
async def health():
    """Health check.

    Returns:
        A dict with a fixed ``"healthy"`` status, the configured service name
        and the statistics reported by ``vector_store.get_collection_stats()``.
    """
    return {
        "status": "healthy",
        "service": settings.service_name,
        "vector_store": await vector_store.get_collection_stats()
    }


# ============================================================================
# THREADS (Conversations)
# ============================================================================

@app.post("/threads", response_model=ThreadResponse)
async def create_thread(
    request: CreateThreadRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Create a new conversation thread.

    Args:
        request: Thread attributes (org, user, workspace, agent, title, tags,
            metadata) forwarded to ``db.create_thread``.
        service: Result of the optional service-auth dependency. It is not
            read in the body; declaring it makes the dependency run.

    Returns:
        The record returned by ``db.create_thread``.
    """
    # Original author's note: auth is optional -- if a JWT is provided it is
    # verified; if not, the request is allowed (dev mode).
    thread = await db.create_thread(
        org_id=request.org_id,
        user_id=request.user_id,
        workspace_id=request.workspace_id,
        agent_id=request.agent_id,
        title=request.title,
        tags=request.tags,
        metadata=request.metadata
    )
    return thread


@app.get("/threads/{thread_id}", response_model=ThreadResponse)
async def get_thread(thread_id: UUID):
    """Get a thread by ID.

    Raises:
        HTTPException: 404 when ``db.get_thread`` returns nothing.
    """
    thread = await db.get_thread(thread_id)
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")
    return thread


@app.get("/threads", response_model=List[ThreadResponse])
async def list_threads(
    user_id: UUID = Query(...),
    org_id: UUID = Query(...),
    workspace_id: Optional[UUID] = None,
    agent_id: Optional[UUID] = None,
    limit: int = Query(default=20, ge=1, le=100)
):
    """List threads for a user.

    Args:
        user_id: Required user filter.
        org_id: Required organization filter.
        workspace_id: Optional workspace filter.
        agent_id: Optional agent filter.
        limit: Maximum number of threads to return (1-100).

    Returns:
        Whatever ``db.list_threads`` returns for these filters.
    """
    threads = await db.list_threads(
        org_id=org_id,
        user_id=user_id,
        workspace_id=workspace_id,
        agent_id=agent_id,
        limit=limit
    )
    return threads


# ============================================================================
# EVENTS (Short-term Memory)
# ============================================================================

@app.post("/events", response_model=EventResponse)
async def add_event(
    request: AddEventRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Add an event to a conversation (message, tool call, etc.).

    Args:
        request: Event fields forwarded one-to-one to ``db.add_event``.
        service: Result of the optional service-auth dependency (unused in
            the body).

    Returns:
        The record returned by ``db.add_event``.
    """
    event = await db.add_event(
        thread_id=request.thread_id,
        event_type=request.event_type,
        role=request.role,
        content=request.content,
        tool_name=request.tool_name,
        tool_input=request.tool_input,
        tool_output=request.tool_output,
        payload=request.payload,
        token_count=request.token_count,
        model_used=request.model_used,
        latency_ms=request.latency_ms,
        metadata=request.metadata
    )
    return event


@app.get("/threads/{thread_id}/events", response_model=List[EventResponse])
async def get_events(
    thread_id: UUID,
    limit: int = Query(default=50, ge=1, le=200),
    offset: int = Query(default=0, ge=0)
):
    """Get events for a thread (most recent first).

    Args:
        thread_id: Thread whose events are requested.
        limit: Page size (1-200).
        offset: Number of events to skip; must not be negative.
    """
    events = await db.get_events(thread_id, limit=limit, offset=offset)
    return events


# ============================================================================
# MEMORIES (Long-term Memory)
# ============================================================================

@app.post("/memories", response_model=MemoryResponse)
async def create_memory(
    request: CreateMemoryRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Create a long-term memory item.

    The item is stored through ``db.create_memory``, indexed through
    ``vector_store.index_memory``, and the returned point ID is then written
    back with ``db.update_memory_embedding_id``.

    Returns:
        The record returned by ``db.create_memory``.
    """
    # Create in PostgreSQL
    memory = await db.create_memory(
        org_id=request.org_id,
        user_id=request.user_id,
        workspace_id=request.workspace_id,
        agent_id=request.agent_id,
        category=request.category,
        fact_text=request.fact_text,
        confidence=request.confidence,
        source_event_id=request.source_event_id,
        source_thread_id=request.source_thread_id,
        extraction_method=request.extraction_method,
        is_sensitive=request.is_sensitive,
        retention=request.retention,
        ttl_days=request.ttl_days,
        tags=request.tags,
        metadata=request.metadata
    )

    # Index in Qdrant
    point_id = await vector_store.index_memory(
        memory_id=memory["memory_id"],
        text=request.fact_text,
        org_id=request.org_id,
        user_id=request.user_id,
        category=request.category,
        agent_id=request.agent_id,
        workspace_id=request.workspace_id,
        thread_id=request.source_thread_id
    )

    # Update memory with embedding ID
    await db.update_memory_embedding_id(memory["memory_id"], point_id)

    return memory


@app.get("/memories/{memory_id}", response_model=MemoryResponse)
async def get_memory(memory_id: UUID):
    """Get a memory by ID.

    Raises:
        HTTPException: 404 when ``db.get_memory`` returns nothing.
    """
    memory = await db.get_memory(memory_id)
    if not memory:
        raise HTTPException(status_code=404, detail="Memory not found")
    return memory


@app.get("/memories", response_model=List[MemoryResponse])
async def list_memories(
    user_id: UUID = Query(...),
    org_id: UUID = Query(...),
    agent_id: Optional[UUID] = None,
    workspace_id: Optional[UUID] = None,
    category: Optional[MemoryCategory] = None,
    include_global: bool = True,
    limit: int = Query(default=50, ge=1, le=200)
):
    """List memories for a user.

    All filters are forwarded unchanged to ``db.list_memories``; ``limit``
    is bounded to 1-200.
    """
    memories = await db.list_memories(
        org_id=org_id,
        user_id=user_id,
        agent_id=agent_id,
        workspace_id=workspace_id,
        category=category,
        include_global=include_global,
        limit=limit
    )
    return memories


@app.post("/memories/{memory_id}/feedback")
async def memory_feedback(memory_id: UUID, request: MemoryFeedbackRequest):
    """Apply user feedback to a memory (confirm/reject/edit/delete).

    The feedback is recorded first, then the action is applied:

    - CONFIRM: raise confidence by ``settings.memory_confirm_boost`` (capped
      at 1.0) and mark the memory as verified.
    - REJECT: lower confidence by ``settings.memory_reject_penalty`` (floored
      at 0.0); when the result falls below ``settings.memory_min_confidence``
      the memory is invalidated and removed from the vector store.
    - EDIT: replace the text and re-index it in the vector store.
    - DELETE: invalidate the memory and remove it from the vector store.

    Returns:
        ``{"status": "ok", "action": <action value>}``.

    Raises:
        HTTPException: 404 when the memory does not exist; 400 when an EDIT
            is requested without a ``new_value``.
    """
    memory = await db.get_memory(memory_id)
    if not memory:
        raise HTTPException(status_code=404, detail="Memory not found")

    # An edit without replacement text cannot be applied; reject it up front
    # instead of recording feedback and reporting "ok" for a no-op.
    if request.action == FeedbackAction.EDIT and not request.new_value:
        raise HTTPException(
            status_code=400,
            detail="new_value is required for the edit action"
        )

    # Record feedback
    await db.add_memory_feedback(
        memory_id=memory_id,
        user_id=request.user_id,
        action=request.action,
        old_value=memory["fact_text"],
        new_value=request.new_value,
        reason=request.reason
    )

    # Apply action
    if request.action == FeedbackAction.CONFIRM:
        new_confidence = min(1.0, memory["confidence"] + settings.memory_confirm_boost)
        await db.update_memory_confidence(memory_id, new_confidence, verified=True)

    elif request.action == FeedbackAction.REJECT:
        new_confidence = max(0.0, memory["confidence"] - settings.memory_reject_penalty)
        if new_confidence < settings.memory_min_confidence:
            # Mark as invalid
            await db.invalidate_memory(memory_id)
            await vector_store.delete_memory(memory_id)
        else:
            await db.update_memory_confidence(memory_id, new_confidence)

    elif request.action == FeedbackAction.EDIT:
        await db.update_memory_text(memory_id, request.new_value)
        # Re-index with new text
        await vector_store.delete_memory(memory_id)
        point_id = await vector_store.index_memory(
            memory_id=memory_id,
            text=request.new_value,
            org_id=memory["org_id"],
            user_id=memory["user_id"],
            category=memory["category"],
            agent_id=memory.get("agent_id"),
            workspace_id=memory.get("workspace_id")
        )
        # Keep the stored embedding ID in step with the fresh vector point,
        # mirroring what create_memory does after indexing.
        await db.update_memory_embedding_id(memory_id, point_id)

    elif request.action == FeedbackAction.DELETE:
        await db.invalidate_memory(memory_id)
        await vector_store.delete_memory(memory_id)

    return {"status": "ok", "action": request.action.value}


# ============================================================================
# RETRIEVAL (Semantic Search)
# ============================================================================

@app.post("/retrieve", response_model=RetrievalResponse)
async def retrieve_memories(request: RetrievalRequest):
    """
    Semantic retrieval of relevant memories.

    Runs one vector search per query, keeps the first hit seen for each
    memory ID, drops memories that are missing from the database or whose
    confidence is below ``request.min_confidence``, and returns the
    ``request.top_k`` most relevant survivors.

    Returns:
        A ``RetrievalResponse`` with the trimmed results, the number of
        queries executed, and the number of results before trimming.
    """
    scored = []  # (memory UUID, RetrievalResult) pairs
    seen_ids = set()

    for query in request.queries:
        hits = await vector_store.search_memories(
            query=query,
            org_id=request.org_id,
            user_id=request.user_id,
            agent_id=request.agent_id,
            workspace_id=request.workspace_id,
            categories=request.categories,
            include_global=request.include_global,
            top_k=request.top_k
        )

        for hit in hits:
            raw_id = hit.get("memory_id")
            if not raw_id or raw_id in seen_ids:
                continue
            seen_ids.add(raw_id)

            # Parse once and reuse for the DB lookup, the result and the
            # usage counter.
            memory_uuid = UUID(raw_id)

            # Get full memory from DB for confidence check
            memory = await db.get_memory(memory_uuid)
            if memory and memory["confidence"] >= request.min_confidence:
                agent_ref = hit.get("agent_id")
                scored.append((memory_uuid, RetrievalResult(
                    memory_id=memory_uuid,
                    fact_text=hit["text"],
                    category=MemoryCategory(hit["category"]),
                    confidence=memory["confidence"],
                    relevance_score=hit["score"],
                    agent_id=UUID(agent_ref) if agent_ref else None,
                    is_global=agent_ref is None
                )))

    # Sort by relevance
    scored.sort(key=lambda pair: pair[1].relevance_score, reverse=True)
    returned = scored[:request.top_k]

    # Update usage stats only for memories actually handed back to the
    # caller, so entries trimmed by top_k do not inflate their counters.
    for memory_uuid, _ in returned:
        await db.increment_memory_usage(memory_uuid)

    return RetrievalResponse(
        results=[result for _, result in returned],
        query_count=len(request.queries),
        total_results=len(scored)
    )


# ============================================================================
# SUMMARIES (Mid-term Memory)
# ============================================================================
@app.post("/threads/{thread_id}/summarize", response_model=SummaryResponse)
async def create_summary(thread_id: UUID, request: SummaryRequest):
    """
    Generate rolling summary for thread.

    Compresses old events into a structured summary. The summary text is
    currently a placeholder (see the TODO below); the summary is stored via
    ``db.create_summary`` and indexed via ``vector_store.index_summary``.

    Raises:
        HTTPException: 404 when the thread does not exist; 400 when
            ``request.force`` is false and the thread's token count is below
            ``settings.summary_trigger_tokens``.
    """
    thread = await db.get_thread(thread_id)
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    # Check if summary is needed
    if not request.force and thread["total_tokens"] < settings.summary_trigger_tokens:
        raise HTTPException(
            status_code=400,
            detail=f"Token count ({thread['total_tokens']}) below threshold ({settings.summary_trigger_tokens})"
        )

    # Get events to summarize
    events = await db.get_events_for_summary(thread_id)

    # TODO: Call LLM to generate summary
    # For now, create a placeholder
    summary_text = f"Summary of {len(events)} events. [Implement LLM summarization]"
    state = {
        "goals": [],
        "decisions": [],
        "open_questions": [],
        "next_steps": [],
        "key_facts": []
    }

    # Create summary (sequence bounds fall back to 0 when there are no events)
    summary = await db.create_summary(
        thread_id=thread_id,
        summary_text=summary_text,
        state=state,
        events_from_seq=events[0]["sequence_num"] if events else 0,
        events_to_seq=events[-1]["sequence_num"] if events else 0,
        events_count=len(events)
    )

    # Index summary in Qdrant
    await vector_store.index_summary(
        summary_id=summary["summary_id"],
        text=summary_text,
        thread_id=thread_id,
        org_id=thread["org_id"],
        user_id=thread["user_id"],
        agent_id=thread.get("agent_id"),
        workspace_id=thread.get("workspace_id")
    )

    return summary


@app.get("/threads/{thread_id}/summary", response_model=Optional[SummaryResponse])
async def get_latest_summary(thread_id: UUID):
    """Get the latest summary for a thread (whatever ``db.get_latest_summary`` returns)."""
    summary = await db.get_latest_summary(thread_id)
    return summary


# ============================================================================
# CONTEXT (Full context for agent)
# ============================================================================

@app.get("/threads/{thread_id}/context", response_model=ContextResponse)
async def get_context(
    thread_id: UUID,
    queries: List[str] = Query(default=[]),
    top_k: int = Query(default=10, ge=1)
):
    """
    Get full context for agent prompt.

    Combines:
    - Latest summary (mid-term)
    - Recent messages (short-term)
    - Retrieved memories (long-term)

    Args:
        thread_id: Thread to build the context for.
        queries: Retrieval queries; memories are only retrieved when at
            least one is given.
        top_k: Maximum number of retrieved memories; must be at least 1 so
            the result slice can never run from the wrong end.

    Raises:
        HTTPException: 404 when the thread does not exist.
    """
    thread = await db.get_thread(thread_id)
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    # Get summary
    summary = await db.get_latest_summary(thread_id)

    # Get recent messages
    recent = await db.get_events(
        thread_id,
        limit=settings.short_term_window_messages
    )

    # Retrieve memories if queries provided
    retrieved = []
    if queries:
        retrieval_response = await retrieve_memories(RetrievalRequest(
            org_id=thread["org_id"],
            user_id=thread["user_id"],
            agent_id=thread.get("agent_id"),
            workspace_id=thread.get("workspace_id"),
            queries=queries,
            top_k=top_k,
            include_global=True
        ))
        retrieved = retrieval_response.results

    # Estimate tokens (missing or null counts are treated as 0)
    token_estimate = sum(e.get("token_count", 0) or 0 for e in recent)
    if summary:
        token_estimate += summary.get("summary_tokens", 0) or 0

    return ContextResponse(
        thread_id=thread_id,
        summary=summary,
        recent_messages=recent,
        retrieved_memories=retrieved,
        token_estimate=token_estimate
    )


# ============================================================================
# FACTS (Simple Key-Value storage for Gateway compatibility)
# ============================================================================

from pydantic import BaseModel
from typing import Any


class FactUpsertRequest(BaseModel):
    """Request to upsert a user fact"""
    user_id: str
    fact_key: str
    fact_value: Optional[str] = None
    fact_value_json: Optional[dict] = None
    team_id: Optional[str] = None


@app.post("/facts/upsert")
async def upsert_fact(request: FactUpsertRequest):
    """
    Create or update a user fact.

    This is a simple key-value store for Gateway compatibility.
    Facts are stored in PostgreSQL without vector indexing.

    Returns:
        ``{"status": "ok", "fact_id": ...}`` where ``fact_id`` is taken from
        the upsert result, or ``None`` when the result is empty.

    Raises:
        HTTPException: 500 carrying the error text when anything fails.
    """
    try:
        # Ensure facts table exists (will be created on first call)
        await db.ensure_facts_table()

        # Upsert the fact
        result = await db.upsert_fact(
            user_id=request.user_id,
            fact_key=request.fact_key,
            fact_value=request.fact_value,
            fact_value_json=request.fact_value_json,
            team_id=request.team_id
        )

        logger.info("fact_upserted", user_id=request.user_id, fact_key=request.fact_key)
        return {"status": "ok", "fact_id": result.get("fact_id") if result else None}

    except Exception as e:
        logger.error("fact_upsert_failed", error=str(e), user_id=request.user_id)
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/facts/{fact_key}")
async def get_fact(
    fact_key: str,
    user_id: str = Query(...),
    team_id: Optional[str] = None
):
    """Get a specific fact for a user.

    Raises:
        HTTPException: 404 when no fact is found; 500 carrying the error
            text on any other failure.
    """
    try:
        fact = await db.get_fact(user_id=user_id, fact_key=fact_key, team_id=team_id)
        if not fact:
            raise HTTPException(status_code=404, detail="Fact not found")
        return fact
    except HTTPException:
        # Let the 404 above pass through untouched.
        raise
    except Exception as e:
        logger.error("fact_get_failed", error=str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/facts")
async def list_facts(
    user_id: str = Query(...),
    team_id: Optional[str] = None
):
    """List all facts for a user.

    Returns:
        ``{"facts": ...}`` wrapping the result of ``db.list_facts``.

    Raises:
        HTTPException: 500 carrying the error text when the lookup fails.
    """
    try:
        facts = await db.list_facts(user_id=user_id, team_id=team_id)
        return {"facts": facts}
    except Exception as e:
        logger.error("facts_list_failed", error=str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.delete("/facts/{fact_key}")
async def delete_fact(
    fact_key: str,
    user_id: str = Query(...),
    team_id: Optional[str] = None
):
    """Delete a fact.

    Raises:
        HTTPException: 404 when ``db.delete_fact`` reports nothing deleted;
            500 carrying the error text on any other failure.
    """
    try:
        deleted = await db.delete_fact(user_id=user_id, fact_key=fact_key, team_id=team_id)
        if not deleted:
            raise HTTPException(status_code=404, detail="Fact not found")
        return {"status": "ok", "deleted": True}
    except HTTPException:
        # Let the 404 above pass through untouched.
        raise
    except Exception as e:
        logger.error("fact_delete_failed", error=str(e))
        raise HTTPException(status_code=500, detail=str(e)) from e
# ============================================================================
# ADMIN
# ============================================================================

@app.get("/stats")
async def get_stats():
    """Report service statistics.

    Returns:
        A dict with the vector store's collection statistics under
        ``"vector_store"`` and the database statistics under ``"database"``.
    """
    # Gathered in the same order as the keys appear in the response.
    vector_stats = await vector_store.get_collection_stats()
    database_stats = await db.get_stats()
    return {"vector_store": vector_stats, "database": database_stats}


if __name__ == "__main__":
    # Direct-execution entry point: serve the app on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)