""" DAARION Memory Service - FastAPI Application Трирівнева пам'ять агентів: - Short-term: conversation events (робочий буфер) - Mid-term: thread summaries (сесійна/тематична) - Long-term: memory items (персональна/проектна) """ from contextlib import asynccontextmanager from typing import List, Optional from fastapi import Depends, BackgroundTasks from uuid import UUID import structlog from fastapi import FastAPI, HTTPException, Query from fastapi.middleware.cors import CORSMiddleware from .config import get_settings from .models import ( CreateThreadRequest, AddEventRequest, CreateMemoryRequest, MemoryFeedbackRequest, RetrievalRequest, SummaryRequest, ThreadResponse, EventResponse, MemoryResponse, SummaryResponse, RetrievalResponse, RetrievalResult, ContextResponse, MemoryCategory, FeedbackAction ) from .vector_store import vector_store from .database import db from .auth import get_current_service, get_current_service_optional logger = structlog.get_logger() settings = get_settings() @asynccontextmanager async def lifespan(app: FastAPI): """Startup and shutdown events""" # Startup logger.info("starting_memory_service") await db.connect() await vector_store.initialize() yield # Shutdown await db.disconnect() logger.info("memory_service_stopped") app = FastAPI( title="DAARION Memory Service", description="Agent memory management with PostgreSQL + Qdrant + Cohere", version="1.0.0", lifespan=lifespan ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ============================================================================ # HEALTH # ============================================================================ @app.get("/health") async def health(): """Health check""" return { "status": "healthy", "service": settings.service_name, "vector_store": await vector_store.get_collection_stats() } # ============================================================================ # THREADS 
# ============================================================================
# THREADS (Conversations)
# ============================================================================

@app.post("/threads", response_model=ThreadResponse)
async def create_thread(
    request: CreateThreadRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Create a new conversation thread."""
    # Auth is optional: if a JWT is provided it is checked; if not, the request
    # is allowed (dev mode). ``service`` itself is not used below.
    thread = await db.create_thread(
        org_id=request.org_id,
        user_id=request.user_id,
        workspace_id=request.workspace_id,
        agent_id=request.agent_id,
        title=request.title,
        tags=request.tags,
        metadata=request.metadata
    )
    return thread


@app.get("/threads/{thread_id}", response_model=ThreadResponse)
async def get_thread(thread_id: UUID):
    """Return a single thread.

    Raises:
        HTTPException: 404 if the thread does not exist.
    """
    thread = await db.get_thread(thread_id)
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")
    return thread


@app.get("/threads", response_model=List[ThreadResponse])
async def list_threads(
    user_id: UUID = Query(...),
    org_id: UUID = Query(...),
    workspace_id: Optional[UUID] = None,
    agent_id: Optional[UUID] = None,
    limit: int = Query(default=20, ge=1, le=100)
):
    """List a user's threads, optionally narrowed by workspace and agent.

    ``limit`` must be between 1 and 100. Non-positive values are rejected with
    422 instead of being forwarded to the database.
    """
    threads = await db.list_threads(
        org_id=org_id,
        user_id=user_id,
        workspace_id=workspace_id,
        agent_id=agent_id,
        limit=limit
    )
    return threads


# ============================================================================
# EVENTS (Short-term Memory)
# ============================================================================

@app.post("/events", response_model=EventResponse)
async def add_event(
    request: AddEventRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Add an event (message, tool call, etc.) to a conversation thread."""
    event = await db.add_event(
        thread_id=request.thread_id,
        event_type=request.event_type,
        role=request.role,
        content=request.content,
        tool_name=request.tool_name,
        tool_input=request.tool_input,
        tool_output=request.tool_output,
        payload=request.payload,
        token_count=request.token_count,
        model_used=request.model_used,
        latency_ms=request.latency_ms,
        metadata=request.metadata
    )
    return event


@app.get("/threads/{thread_id}/events", response_model=List[EventResponse])
async def get_events(
    thread_id: UUID,
    limit: int = Query(default=50, ge=1, le=200),
    offset: int = Query(default=0, ge=0)
):
    """Get a page of events for a thread (most recent first).

    ``limit`` must be between 1 and 200, and ``offset`` must be non-negative.
    Out-of-range values are rejected with 422.
    """
    events = await db.get_events(thread_id, limit=limit, offset=offset)
    return events


# ============================================================================
# MEMORIES (Long-term Memory)
# ============================================================================

@app.post("/memories", response_model=MemoryResponse)
async def create_memory(
    request: CreateMemoryRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Create a long-term memory item and index it for semantic search.

    The row is written to PostgreSQL first and then indexed in the vector
    store. The point id returned by the vector store is saved back on the row.
    """
    memory = await db.create_memory(
        org_id=request.org_id,
        user_id=request.user_id,
        workspace_id=request.workspace_id,
        agent_id=request.agent_id,
        category=request.category,
        fact_text=request.fact_text,
        confidence=request.confidence,
        source_event_id=request.source_event_id,
        source_thread_id=request.source_thread_id,
        extraction_method=request.extraction_method,
        is_sensitive=request.is_sensitive,
        retention=request.retention,
        ttl_days=request.ttl_days,
        tags=request.tags,
        metadata=request.metadata
    )

    point_id = await vector_store.index_memory(
        memory_id=memory["memory_id"],
        text=request.fact_text,
        org_id=request.org_id,
        user_id=request.user_id,
        category=request.category,
        agent_id=request.agent_id,
        workspace_id=request.workspace_id,
        thread_id=request.source_thread_id
    )

    await db.update_memory_embedding_id(memory["memory_id"], point_id)

    return memory


@app.get("/memories/{memory_id}", response_model=MemoryResponse)
async def get_memory(memory_id: UUID):
    """Return a single memory item.

    Raises:
        HTTPException: 404 if the memory does not exist.
    """
    memory = await db.get_memory(memory_id)
    if not memory:
        raise HTTPException(status_code=404, detail="Memory not found")
    return memory


@app.get("/memories", response_model=List[MemoryResponse])
async def list_memories(
    user_id: UUID = Query(...),
    org_id: UUID = Query(...),
    agent_id: Optional[UUID] = None,
    workspace_id: Optional[UUID] = None,
    category: Optional[MemoryCategory] = None,
    include_global: bool = True,
    limit: int = Query(default=50, ge=1, le=200)
):
    """List a user's memories, optionally filtered by agent, workspace and category.

    ``limit`` must be between 1 and 200. Out-of-range values are rejected
    with 422.
    """
    memories = await db.list_memories(
        org_id=org_id,
        user_id=user_id,
        agent_id=agent_id,
        workspace_id=workspace_id,
        category=category,
        include_global=include_global,
        limit=limit
    )
    return memories


@app.post("/memories/{memory_id}/feedback")
async def memory_feedback(memory_id: UUID, request: MemoryFeedbackRequest):
    """Apply user feedback (confirm / reject / edit / delete) to a memory.

    The feedback is always recorded first. Its effect is then applied:

    - CONFIRM raises confidence by ``settings.memory_confirm_boost`` (capped
      at 1.0) and marks the memory verified.
    - REJECT lowers confidence by ``settings.memory_reject_penalty`` (floored
      at 0.0). Once it drops below ``settings.memory_min_confidence``, the
      memory is invalidated and removed from the vector store.
    - EDIT replaces the text and re-indexes it. Without ``new_value`` the
      feedback is recorded but nothing else changes.
    - DELETE invalidates the memory and removes it from the vector store.

    Raises:
        HTTPException: 404 if the memory does not exist.
    """
    memory = await db.get_memory(memory_id)
    if not memory:
        raise HTTPException(status_code=404, detail="Memory not found")

    await db.add_memory_feedback(
        memory_id=memory_id,
        user_id=request.user_id,
        action=request.action,
        old_value=memory["fact_text"],
        new_value=request.new_value,
        reason=request.reason
    )

    if request.action == FeedbackAction.CONFIRM:
        new_confidence = min(1.0, memory["confidence"] + settings.memory_confirm_boost)
        await db.update_memory_confidence(memory_id, new_confidence, verified=True)

    elif request.action == FeedbackAction.REJECT:
        new_confidence = max(0.0, memory["confidence"] - settings.memory_reject_penalty)
        if new_confidence < settings.memory_min_confidence:
            # Confidence too low to keep: drop it from both stores.
            await db.invalidate_memory(memory_id)
            await vector_store.delete_memory(memory_id)
        else:
            await db.update_memory_confidence(memory_id, new_confidence)

    elif request.action == FeedbackAction.EDIT:
        if request.new_value:
            await db.update_memory_text(memory_id, request.new_value)
            # Re-index under the new text, using the same scope filters as
            # create_memory (including the source thread). Then store the new
            # point id so the row does not keep pointing at the deleted point.
            await vector_store.delete_memory(memory_id)
            point_id = await vector_store.index_memory(
                memory_id=memory_id,
                text=request.new_value,
                org_id=memory["org_id"],
                user_id=memory["user_id"],
                category=memory["category"],
                agent_id=memory.get("agent_id"),
                workspace_id=memory.get("workspace_id"),
                thread_id=memory.get("source_thread_id")
            )
            await db.update_memory_embedding_id(memory_id, point_id)

    elif request.action == FeedbackAction.DELETE:
        await db.invalidate_memory(memory_id)
        await vector_store.delete_memory(memory_id)

    return {"status": "ok", "action": request.action.value}


# ============================================================================
# RETRIEVAL (Semantic Search)
# ============================================================================

@app.post("/retrieve", response_model=RetrievalResponse)
async def retrieve_memories(request: RetrievalRequest):
    """Semantic retrieval of relevant memories across several queries.

    Each query is searched independently. Hits are de-duplicated by memory id
    and dropped if the stored memory is missing or its confidence is below
    ``request.min_confidence``. The rest are sorted by relevance and truncated
    to ``request.top_k``. Usage counters are incremented only for the memories
    actually returned.
    """
    all_results: List[RetrievalResult] = []
    seen_ids = set()

    for query in request.queries:
        results = await vector_store.search_memories(
            query=query,
            org_id=request.org_id,
            user_id=request.user_id,
            agent_id=request.agent_id,
            workspace_id=request.workspace_id,
            categories=request.categories,
            include_global=request.include_global,
            top_k=request.top_k
        )

        for r in results:
            raw_id = r.get("memory_id")
            if not raw_id or raw_id in seen_ids:
                continue
            seen_ids.add(raw_id)

            # str() tolerates payloads that already hold a UUID object.
            memory_uuid = UUID(str(raw_id))

            # The DB row is the source of truth for confidence.
            memory = await db.get_memory(memory_uuid)
            if not memory or memory["confidence"] < request.min_confidence:
                continue

            hit_agent_id = r.get("agent_id")
            all_results.append(RetrievalResult(
                memory_id=memory_uuid,
                fact_text=r["text"],
                category=MemoryCategory(r["category"]),
                confidence=memory["confidence"],
                relevance_score=r["score"],
                agent_id=UUID(str(hit_agent_id)) if hit_agent_id else None,
                is_global=hit_agent_id is None
            ))

    all_results.sort(key=lambda x: x.relevance_score, reverse=True)
    returned = all_results[:request.top_k]

    # Count usage only for memories that are actually handed back to the caller.
    for result in returned:
        await db.increment_memory_usage(result.memory_id)

    return RetrievalResponse(
        results=returned,
        query_count=len(request.queries),
        total_results=len(all_results)
    )


# ============================================================================
# SUMMARIES (Mid-term Memory)
# ============================================================================
@app.post("/threads/{thread_id}/summarize", response_model=SummaryResponse) async def create_summary(thread_id: UUID, request: SummaryRequest): """ Generate rolling summary for thread. Compresses old events into a structured summary. """ thread = await db.get_thread(thread_id) if not thread: raise HTTPException(status_code=404, detail="Thread not found") # Check if summary is needed if not request.force and thread["total_tokens"] < settings.summary_trigger_tokens: raise HTTPException( status_code=400, detail=f"Token count ({thread['total_tokens']}) below threshold ({settings.summary_trigger_tokens})" ) # Get events to summarize events = await db.get_events_for_summary(thread_id) # TODO: Call LLM to generate summary # For now, create a placeholder summary_text = f"Summary of {len(events)} events. [Implement LLM summarization]" state = { "goals": [], "decisions": [], "open_questions": [], "next_steps": [], "key_facts": [] } # Create summary summary = await db.create_summary( thread_id=thread_id, summary_text=summary_text, state=state, events_from_seq=events[0]["sequence_num"] if events else 0, events_to_seq=events[-1]["sequence_num"] if events else 0, events_count=len(events) ) # Index summary in Qdrant await vector_store.index_summary( summary_id=summary["summary_id"], text=summary_text, thread_id=thread_id, org_id=thread["org_id"], user_id=thread["user_id"], agent_id=thread.get("agent_id"), workspace_id=thread.get("workspace_id") ) return summary @app.get("/threads/{thread_id}/summary", response_model=Optional[SummaryResponse]) async def get_latest_summary(thread_id: UUID): """Get latest summary for thread""" summary = await db.get_latest_summary(thread_id) return summary # ============================================================================ # CONTEXT (Full context for agent) # ============================================================================ @app.get("/threads/{thread_id}/context", response_model=ContextResponse) async def get_context( 
thread_id: UUID, queries: List[str] = Query(default=[]), top_k: int = Query(default=10) ): """ Get full context for agent prompt. Combines: - Latest summary (mid-term) - Recent messages (short-term) - Retrieved memories (long-term) """ thread = await db.get_thread(thread_id) if not thread: raise HTTPException(status_code=404, detail="Thread not found") # Get summary summary = await db.get_latest_summary(thread_id) # Get recent messages recent = await db.get_events( thread_id, limit=settings.short_term_window_messages ) # Retrieve memories if queries provided retrieved = [] if queries: retrieval_response = await retrieve_memories(RetrievalRequest( org_id=thread["org_id"], user_id=thread["user_id"], agent_id=thread.get("agent_id"), workspace_id=thread.get("workspace_id"), queries=queries, top_k=top_k, include_global=True )) retrieved = retrieval_response.results # Estimate tokens token_estimate = sum(e.get("token_count", 0) or 0 for e in recent) if summary: token_estimate += summary.get("summary_tokens", 0) or 0 return ContextResponse( thread_id=thread_id, summary=summary, recent_messages=recent, retrieved_memories=retrieved, token_estimate=token_estimate ) # ============================================================================ # FACTS (Simple Key-Value storage for Gateway compatibility) # ============================================================================ from pydantic import BaseModel from typing import Any class FactUpsertRequest(BaseModel): """Request to upsert a user fact""" user_id: str fact_key: str fact_value: Optional[str] = None fact_value_json: Optional[dict] = None team_id: Optional[str] = None @app.post("/facts/upsert") async def upsert_fact(request: FactUpsertRequest): """ Create or update a user fact. This is a simple key-value store for Gateway compatibility. Facts are stored in PostgreSQL without vector indexing. 
""" try: # Ensure facts table exists (will be created on first call) await db.ensure_facts_table() # Upsert the fact result = await db.upsert_fact( user_id=request.user_id, fact_key=request.fact_key, fact_value=request.fact_value, fact_value_json=request.fact_value_json, team_id=request.team_id ) logger.info(f"fact_upserted", user_id=request.user_id, fact_key=request.fact_key) return {"status": "ok", "fact_id": result.get("fact_id") if result else None} except Exception as e: logger.error(f"fact_upsert_failed", error=str(e), user_id=request.user_id) raise HTTPException(status_code=500, detail=str(e)) @app.get("/facts/{fact_key}") async def get_fact( fact_key: str, user_id: str = Query(...), team_id: Optional[str] = None ): """Get a specific fact for a user""" try: fact = await db.get_fact(user_id=user_id, fact_key=fact_key, team_id=team_id) if not fact: raise HTTPException(status_code=404, detail="Fact not found") return fact except HTTPException: raise except Exception as e: logger.error(f"fact_get_failed", error=str(e)) raise HTTPException(status_code=500, detail=str(e)) @app.get("/facts") async def list_facts( user_id: str = Query(...), team_id: Optional[str] = None ): """List all facts for a user""" try: facts = await db.list_facts(user_id=user_id, team_id=team_id) return {"facts": facts} except Exception as e: logger.error(f"facts_list_failed", error=str(e)) raise HTTPException(status_code=500, detail=str(e)) @app.delete("/facts/{fact_key}") async def delete_fact( fact_key: str, user_id: str = Query(...), team_id: Optional[str] = None ): """Delete a fact""" try: deleted = await db.delete_fact(user_id=user_id, fact_key=fact_key, team_id=team_id) if not deleted: raise HTTPException(status_code=404, detail="Fact not found") return {"status": "ok", "deleted": True} except HTTPException: raise except Exception as e: logger.error(f"fact_delete_failed", error=str(e)) raise HTTPException(status_code=500, detail=str(e)) # 
# ============================================================================
# AGENT MEMORY (Gateway compatibility endpoint)
# ============================================================================

class AgentMemoryRequest(BaseModel):
    """Request format from Gateway for saving chat history."""
    agent_id: str
    team_id: Optional[str] = None
    channel_id: Optional[str] = None
    user_id: str
    # Both formats are supported: new (content) and gateway (body_text).
    content: Optional[str] = None
    body_text: Optional[str] = None
    role: str = "user"  # user, assistant, system
    # Both formats are supported: metadata and body_json.
    metadata: Optional[dict] = None
    body_json: Optional[dict] = None
    context: Optional[str] = None
    scope: Optional[str] = None
    kind: Optional[str] = None  # "message", "event", etc.

    def get_content(self) -> str:
        """Return ``content``, falling back to ``body_text``, then ``""``."""
        return self.content or self.body_text or ""

    def get_metadata(self) -> dict:
        """Return ``metadata``, falling back to ``body_json``, then ``{}``."""
        return self.metadata or self.body_json or {}


@app.post("/agents/{agent_id}/memory")
async def save_agent_memory(agent_id: str, request: AgentMemoryRequest, background_tasks: BackgroundTasks):
    """Save a chat turn to memory through the ingestion pipeline.

    1. Save to PostgreSQL (facts table, isolated by ``agent_id``).
    2. Schedule embedding + Qdrant indexing as a background task.
    3. Schedule a Neo4j knowledge-graph update as a background task.

    Empty messages and messages starting with ``[Photo:`` are skipped.

    Raises:
        HTTPException: 500 if the PostgreSQL write fails.
    """
    try:
        from datetime import datetime, timezone
        from uuid import uuid4

        # Naive UTC ISO timestamp: same format as the deprecated datetime.utcnow().
        timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        message_id = str(uuid4())
        fact_key = f"chat_event:{request.channel_id}:{timestamp}"

        content = request.get_content()
        metadata = request.get_metadata()

        if not content or content.startswith("[Photo:"):
            logger.debug("skipping_empty_or_photo_message", content=content[:50] if content else "")
            return {"status": "ok", "event_id": None, "indexed": False}

        # An explicit agent_response payload overrides the declared role.
        role = request.role
        if request.body_json and request.body_json.get("type") == "agent_response":
            role = "assistant"

        event_data = {
            "message_id": message_id,
            "agent_id": agent_id,
            "team_id": request.team_id,
            "channel_id": request.channel_id,
            "user_id": request.user_id,
            "role": role,
            "content": content,
            "metadata": metadata,
            "scope": request.scope,
            "kind": request.kind,
            "timestamp": timestamp
        }

        # 1. PostgreSQL (isolated by agent_id).
        await db.ensure_facts_table()
        result = await db.upsert_fact(
            user_id=request.user_id,
            fact_key=fact_key,
            fact_value_json=event_data,
            team_id=request.team_id,
            agent_id=agent_id  # Agent isolation
        )

        logger.info("agent_memory_saved",
                    agent_id=agent_id,
                    user_id=request.user_id,
                    role=role,
                    channel_id=request.channel_id,
                    content_len=len(content))

        # 2. Qdrant indexing (runs after the response is sent).
        background_tasks.add_task(
            index_message_in_qdrant,
            message_id=message_id,
            content=content,
            agent_id=agent_id,
            user_id=request.user_id,
            channel_id=request.channel_id,
            role=role,
            timestamp=timestamp
        )

        # 3. Neo4j graph update (runs after the response is sent).
        background_tasks.add_task(
            update_neo4j_graph,
            message_id=message_id,
            content=content,
            agent_id=agent_id,
            user_id=request.user_id,
            channel_id=request.channel_id,
            role=role
        )

        return {
            "status": "ok",
            "event_id": result.get("fact_id") if result else None,
            "message_id": message_id,
            "indexed": True
        }

    except Exception as e:
        logger.error("agent_memory_save_failed", error=str(e), agent_id=agent_id)
        raise HTTPException(status_code=500, detail=str(e))


def _ensure_message_collection(client, collection_name: str, vector_size: int) -> None:
    """Create the per-agent message collection if it does not exist (blocking)."""
    from qdrant_client.http import models as qmodels

    try:
        client.get_collection(collection_name)
    except Exception:
        try:
            client.create_collection(
                collection_name=collection_name,
                vectors_config=qmodels.VectorParams(
                    size=vector_size,
                    distance=qmodels.Distance.COSINE
                )
            )
        except Exception:
            # A concurrent task may have created it first; re-check and let the
            # error propagate only if the collection is still unavailable.
            client.get_collection(collection_name)
        else:
            logger.info("created_collection", collection=collection_name)


async def index_message_in_qdrant(
    message_id: str,
    content: str,
    agent_id: str,
    user_id: str,
    channel_id: Optional[str],
    role: str,
    timestamp: str
):
    """Index one message in the agent's own Qdrant collection (``{agent_id}_messages``).

    Intended to run as a background task, so failures are logged, not raised.
    Messages shorter than 10 characters, and messages whose embedding comes
    back empty, are skipped.
    """
    try:
        from fastapi.concurrency import run_in_threadpool
        from qdrant_client.http import models as qmodels

        from .embedding import get_document_embeddings

        if len(content) < 10:
            return

        embeddings = await get_document_embeddings([content])
        if not embeddings or not embeddings[0]:
            logger.warning("embedding_failed", message_id=message_id)
            return
        vector = embeddings[0]

        # Agent-specific collection gives per-agent isolation.
        collection_name = f"{agent_id}_messages"
        client = vector_store.client

        # vector_store.client is called without ``await`` here, i.e. it is a
        # synchronous client. Its blocking calls are moved to a worker thread
        # so they do not stall the event loop.
        await run_in_threadpool(_ensure_message_collection, client, collection_name, len(vector))

        point = qmodels.PointStruct(
            id=message_id,
            vector=vector,
            payload={
                "message_id": message_id,
                "agent_id": agent_id,
                "user_id": user_id,
                "channel_id": channel_id,
                "role": role,
                "content": content,
                "timestamp": timestamp,
                "type": "chat_message"
            }
        )
        await run_in_threadpool(client.upsert, collection_name=collection_name, points=[point])

        logger.info("message_indexed_qdrant",
                    message_id=message_id,
                    collection=collection_name,
                    content_len=len(content),
                    vector_dim=len(vector))

    except Exception as e:
        logger.error("qdrant_indexing_failed", error=str(e), message_id=message_id)


async def update_neo4j_graph(
    message_id: str,
    content: str,
    agent_id: str,
    user_id: str,
    channel_id: Optional[str],
    role: str
):
    """Record a message in the Neo4j knowledge graph (with agent isolation).

    MERGEs the User, Channel and Agent nodes and creates a Message node.
    ``agent_id`` is set on the relationships so they can be filtered per agent.

    Connection settings come from ``NEO4J_HTTP_URL``, ``NEO4J_USER`` and
    ``NEO4J_PASSWORD``. If ``NEO4J_PASSWORD`` is unset, the update is skipped
    with a warning; no credential is hard-coded. Intended to run as a
    background task, so failures are logged, not raised.
    """
    try:
        import os

        import httpx

        neo4j_url = os.getenv("NEO4J_HTTP_URL", "http://neo4j:7474")
        neo4j_user = os.getenv("NEO4J_USER", "neo4j")
        neo4j_password = os.getenv("NEO4J_PASSWORD")
        if neo4j_password is None:
            logger.warning("neo4j_password_not_configured", message_id=message_id)
            return

        cypher = """
        MERGE (u:User {user_id: $user_id})
        ON CREATE SET u.created_at = datetime()
        ON MATCH SET u.last_seen = datetime()

        MERGE (ch:Channel {channel_id: $channel_id})
        ON CREATE SET ch.created_at = datetime()

        MERGE (a:Agent {agent_id: $agent_id})
        ON CREATE SET a.created_at = datetime()

        MERGE (u)-[p:PARTICIPATES_IN {agent_id: $agent_id}]->(ch)
        ON CREATE SET p.first_seen = datetime()
        ON MATCH SET p.last_seen = datetime()

        CREATE (m:Message {
            message_id: $message_id,
            role: $role,
            content_preview: $content_preview,
            agent_id: $agent_id,
            created_at: datetime()
        })
        CREATE (u)-[:SENT {agent_id: $agent_id}]->(m)
        CREATE (m)-[:IN_CHANNEL {agent_id: $agent_id}]->(ch)
        CREATE (m)-[:HANDLED_BY]->(a)

        RETURN m.message_id as id
        """

        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.post(
                f"{neo4j_url}/db/neo4j/tx/commit",
                auth=(neo4j_user, neo4j_password),
                json={
                    "statements": [{
                        "statement": cypher,
                        "parameters": {
                            "user_id": user_id,
                            "channel_id": channel_id,
                            "message_id": message_id,
                            "role": role,
                            "content_preview": content[:200] if content else "",
                            "agent_id": agent_id
                        }
                    }]
                }
            )

        # The transactional HTTP endpoint returns 200 even when a statement
        # fails (e.g. MERGE on a null channel_id); failures are listed in the
        # body's "errors" array, so both must be checked.
        cypher_errors = response.json().get("errors") if response.status_code == 200 else None
        if response.status_code == 200 and not cypher_errors:
            logger.info("neo4j_graph_updated", message_id=message_id, user_id=user_id, agent_id=agent_id)
        else:
            logger.warning("neo4j_update_failed", status=response.status_code, response=response.text[:200])

    except Exception as e:
        logger.error("neo4j_update_error", error=str(e), message_id=message_id)


@app.get("/agents/{agent_id}/memory")
async def get_agent_memory(
    agent_id: str,
    user_id: str = Query(...),
    channel_id: Optional[str] = None,
    limit: int = Query(default=20, ge=1, le=100)
):
    """Get recent chat events for an agent/user (isolated by ``agent_id``).

    Only facts whose key starts with ``chat_event:`` are returned. When
    ``channel_id`` is given, only events from that channel are kept.
    ``fact_value_json`` may arrive as a dict or as a JSON string; anything
    unparsable becomes ``{}``.

    Raises:
        HTTPException: 500 on storage failure.
    """
    import json as json_lib

    try:
        # NOTE(review): ``limit`` is applied in the DB query before the
        # prefix/channel filtering below, so fewer than ``limit`` events may be
        # returned even when more matching events exist.
        facts = await db.list_facts(user_id=user_id, agent_id=agent_id, limit=limit)

        events = []
        for fact in facts:
            if not (fact.get("fact_key") or "").startswith("chat_event:"):
                continue

            event_data = fact.get("fact_value_json", {})
            if isinstance(event_data, str):
                try:
                    event_data = json_lib.loads(event_data)
                except ValueError:  # includes json.JSONDecodeError
                    event_data = {}
            if not isinstance(event_data, dict):
                event_data = {}

            if channel_id is None or event_data.get("channel_id") == channel_id:
                events.append(event_data)

        return {"events": events[:limit]}

    except Exception as e:
        logger.error("agent_memory_get_failed", error=str(e), agent_id=agent_id)
        raise HTTPException(status_code=500, detail=str(e))


# ============================================================================
# ADMIN
# ============================================================================

@app.get("/stats")
async def get_stats():
    """Return vector-store and database statistics."""
    return {
        "vector_store": await vector_store.get_collection_stats(),
        "database": await db.get_stats()
    }


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)