"""
Universal Memory Retrieval Pipeline v4.0

This module implements the 4-layer memory retrieval for ALL agents:
- L0: Working Context (from request)
- L1: Session State Memory (SSM) - from Postgres
- L2: Platform Identity & Roles (PIR) - from Postgres + Neo4j
- L3: Organizational Memory (OM) - from Postgres + Neo4j + Qdrant

The pipeline generates a "memory brief" that is injected into the LLM context.

Collections per agent:
- {agent_id}_messages - chat history with embeddings
- {agent_id}_memory_items - facts, preferences
- {agent_id}_docs - knowledge base documents
"""

import os
import json
import logging
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
from datetime import datetime, timezone

import httpx
import asyncpg

logger = logging.getLogger(__name__)

# Configuration
# NOTE(review): the fallback DSN embeds a database password in source control;
# consider requiring DATABASE_URL to be set explicitly in every environment.
POSTGRES_URL = os.getenv("DATABASE_URL", "postgresql://daarion:DaarionDB2026!@dagi-postgres:5432/daarion_memory")
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
COHERE_API_KEY = os.getenv("COHERE_API_KEY", "")
NEO4J_BOLT_URL = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4j")

# Substrings (matched against the lower-cased query) that mark a request as being
# about documents/catalogs, so knowledge-base docs are preferred over chat snippets.
# NOTE(review): "стор" is also a substring of unrelated words (e.g. "історія");
# consider tightening it.
_DOC_QUERY_KEYWORDS = ("pdf", "каталог", "каталоз", "документ", "файл", "стор", "page", "pages")

# Domain keywords used as a relevance gate on chat snippets: for a document query
# that mentions any of these, a message hit is kept only if it mentions one too.
# Example: when asking "з каталогу Defenda 2026 ... гліфосат", old "Бокаші"
# messages may otherwise match too well.
_TOPIC_KEYWORDS = (
    "defenda", "ifagri", "bayer", "гліфосат", "glyphos", "глифос",
    "npk", "мінерал", "добрив", "гербіц", "фунгіц", "інсектиц",
)

# Columns of helion_conversation_state that update_session_state may write.
# Only names from this allow-list are ever interpolated into the UPDATE statement.
_SESSION_STATE_COLUMNS = frozenset({
    "last_addressed_to_helion",
    "last_user_id",
    "last_user_nick",
    "active_topic_id",
    "active_context_open",
    "last_media_id",
    "last_media_handled",
    "last_answer_fingerprint",
    "group_trust_mode",
    "apprentice_mode",
    "proactive_questions_today",
})


def _as_text(value: Any) -> str:
    """Coerce a Qdrant payload value to ``str``; ``None`` becomes ``""``."""
    return "" if value is None else str(value)


def _column(row: Any, name: str, default: Any) -> Any:
    """Return ``row[name]``, or *default* when the column is missing or NULL.

    ``asyncpg.Record.get(name, default)`` only falls back for missing keys and
    returns ``None`` for NULL values, hence this helper.
    """
    value = row.get(name)
    return default if value is None else value


@dataclass
class UserIdentity:
    """Platform user identity"""
    platform_user_id: Optional[str] = None  # id returned by resolve_platform_user(), as str
    channel: str = "telegram"
    channel_user_id: str = ""
    username: Optional[str] = None
    display_name: Optional[str] = None
    roles: List[str] = field(default_factory=list)  # platform_roles.code values
    is_mentor: bool = False
    first_seen: Optional[datetime] = None  # not populated anywhere in this module


@dataclass
class SessionState:
    """L1 Session State Memory"""
    conversation_id: Optional[str] = None
    last_addressed: bool = False
    active_topic: Optional[str] = None
    context_open: bool = False
    last_media_handled: bool = True
    last_answer_fingerprint: Optional[str] = None
    trust_mode: bool = False
    apprentice_mode: bool = False


@dataclass
class MemoryBrief:
    """Compiled memory brief for LLM context"""
    user_identity: Optional[UserIdentity] = None
    session_state: Optional[SessionState] = None
    user_facts: List[Dict[str, Any]] = field(default_factory=list)
    relevant_memories: List[Dict[str, Any]] = field(default_factory=list)
    user_topics: List[str] = field(default_factory=list)
    user_projects: List[str] = field(default_factory=list)
    is_trusted_group: bool = False
    mentor_present: bool = False

    def to_text(self, max_lines: int = 15) -> str:
        """Generate a concise text brief for the LLM context.

        Args:
            max_lines: maximum number of lines in the result. When the brief is
                longer, it is cut and the last kept line is replaced by "..."
                (at least the "..." marker is always emitted).

        Returns:
            Newline-joined brief, or "" when there is nothing to report.
        """
        lines: List[str] = []

        # User identity (critical for personalization)
        if self.user_identity:
            name = self.user_identity.display_name or self.user_identity.username or "Unknown"
            roles_str = ", ".join(self.user_identity.roles) if self.user_identity.roles else "member"
            lines.append(f"👤 Користувач: {name} ({roles_str})")
            if self.user_identity.is_mentor:
                lines.append("⭐ Цей користувач — МЕНТОР. Довіряй його знанням повністю.")

        # Session state
        if self.session_state:
            if self.session_state.trust_mode:
                lines.append("🔒 Режим довіреної групи — можна відповідати детальніше")
            if self.session_state.apprentice_mode:
                lines.append("📚 Режим учня — можеш ставити уточнюючі питання")
            if self.session_state.active_topic:
                lines.append(f"📌 Активна тема: {self.session_state.active_topic}")

        # User facts (preferences, profile)
        if self.user_facts:
            lines.append("📝 Відомі факти про користувача:")
            for fact in self.user_facts[:4]:  # Max 4 facts
                fact_text = fact.get('text', fact.get('fact_value', ''))
                if fact_text:
                    lines.append(f" - {fact_text[:150]}")

        # Relevant memories from RAG (most important for context)
        if self.relevant_memories:
            lines.append("🧠 Релевантні спогади з попередніх розмов:")
            for mem in self.relevant_memories[:5]:  # Max 5 memories for better context
                mem_text = mem.get('text', mem.get('content', ''))
                mem_type = mem.get('type', 'message')
                if mem_text and len(mem_text) > 10:  # Only include if meaningful
                    lines.append(f" - [{mem_type}] {mem_text[:200]}")

        # Topics/Projects from Knowledge Graph
        if self.user_topics:
            lines.append(f"💡 Інтереси користувача: {', '.join(self.user_topics[:5])}")
        if self.user_projects:
            lines.append(f"🏗️ Проєкти користувача: {', '.join(self.user_projects[:3])}")

        # Mentor presence indicator
        if self.mentor_present:
            lines.append("⚠️ ВАЖЛИВО: Ти спілкуєшся з ментором. Сприймай як навчання.")

        # Truncate if needed, reserving one slot for the "..." marker so the
        # result really stays within max_lines.
        if len(lines) > max_lines:
            lines = lines[:max(max_lines - 1, 0)]
            lines.append("...")

        return "\n".join(lines)


class MemoryRetrieval:
    """Memory Retrieval Pipeline.

    Holds the backend handles (Postgres pool, Neo4j driver, Qdrant client, HTTP
    client for Cohere embeddings). Every read method degrades to an empty or
    default result when its backend handle is missing or a query fails.
    """

    def __init__(self):
        self.pg_pool: Optional[asyncpg.Pool] = None
        self.neo4j_driver = None
        self.qdrant_client = None
        self.http_client: Optional[httpx.AsyncClient] = None

    async def initialize(self):
        """Initialize database connections.

        Each backend is attempted independently and failures are only logged.
        Note that the Neo4j driver and Qdrant client handles are assigned before
        their connectivity probe, so they stay set even when the probe fails;
        later calls will try to use them (and log on failure).
        """
        # PostgreSQL
        try:
            self.pg_pool = await asyncpg.create_pool(
                POSTGRES_URL,
                min_size=2,
                max_size=10
            )
            logger.info("✅ Memory Retrieval: PostgreSQL connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: PostgreSQL not available: {e}")

        # Neo4j
        try:
            from neo4j import AsyncGraphDatabase
            self.neo4j_driver = AsyncGraphDatabase.driver(
                NEO4J_BOLT_URL,
                auth=(NEO4J_USER, NEO4J_PASSWORD)
            )
            await self.neo4j_driver.verify_connectivity()
            logger.info("✅ Memory Retrieval: Neo4j connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: Neo4j not available: {e}")

        # Qdrant
        try:
            from qdrant_client import QdrantClient
            self.qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
            self.qdrant_client.get_collections()  # Test connection
            logger.info("✅ Memory Retrieval: Qdrant connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: Qdrant not available: {e}")

        # HTTP client for embeddings
        self.http_client = httpx.AsyncClient(timeout=30.0)

    async def close(self):
        """Close every open backend handle.

        Each close is isolated, so a failure in one does not leave the others
        open. Handles are reset to None, which makes repeated calls harmless.
        """
        if self.pg_pool:
            try:
                await self.pg_pool.close()
            except Exception as e:
                logger.warning(f"PostgreSQL pool close failed: {e}")
            self.pg_pool = None

        if self.neo4j_driver:
            try:
                await self.neo4j_driver.close()
            except Exception as e:
                logger.warning(f"Neo4j driver close failed: {e}")
            self.neo4j_driver = None

        if self.qdrant_client:
            # Older qdrant-client releases have no close(); guard for them.
            close_qdrant = getattr(self.qdrant_client, "close", None)
            if callable(close_qdrant):
                try:
                    close_qdrant()
                except Exception as e:
                    logger.warning(f"Qdrant client close failed: {e}")
            self.qdrant_client = None

        if self.http_client:
            try:
                await self.http_client.aclose()
            except Exception as e:
                logger.warning(f"HTTP client close failed: {e}")
            self.http_client = None

    # =========================================================================
    # L2: Platform Identity Resolution
    # =========================================================================

    async def resolve_identity(
        self,
        channel: str,
        channel_user_id: str,
        username: Optional[str] = None,
        display_name: Optional[str] = None
    ) -> UserIdentity:
        """Resolve or create platform user identity.

        Calls the SQL functions resolve_platform_user() and is_mentor(), and
        loads non-revoked role codes from user_roles/platform_roles.

        Returns:
            A UserIdentity. Without Postgres, or on any query error (logged),
            only the passed-in channel fields are filled.
        """
        identity = UserIdentity(
            channel=channel,
            channel_user_id=channel_user_id,
            username=username,
            display_name=display_name
        )

        if not self.pg_pool:
            return identity

        try:
            async with self.pg_pool.acquire() as conn:
                # Use the resolve_platform_user function
                platform_user_id = await conn.fetchval(
                    "SELECT resolve_platform_user($1, $2, $3, $4)",
                    channel, channel_user_id, username, display_name
                )
                identity.platform_user_id = str(platform_user_id) if platform_user_id else None

                # Get roles
                if platform_user_id:
                    roles = await conn.fetch("""
                        SELECT r.code FROM user_roles ur
                        JOIN platform_roles r ON ur.role_id = r.id
                        WHERE ur.platform_user_id = $1 AND ur.revoked_at IS NULL
                    """, platform_user_id)
                    identity.roles = [r['code'] for r in roles]

                # Check if mentor
                is_mentor = await conn.fetchval(
                    "SELECT is_mentor($1, $2)",
                    channel_user_id, username
                )
                identity.is_mentor = bool(is_mentor)

        except Exception as e:
            logger.warning(f"Identity resolution failed: {e}")

        return identity

    # =========================================================================
    # L1: Session State
    # =========================================================================

    async def get_session_state(
        self,
        channel: str,
        chat_id: str,
        thread_id: Optional[str] = None
    ) -> SessionState:
        """Get or create session state for a conversation.

        Resolves the conversation via get_or_create_conversation(), reads its
        helion_conversation_state row (inserting an empty row when missing), and
        finally sets trust_mode from is_trusted_group(channel, chat_id). That
        overrides the group_trust_mode value read from the row.

        Returns:
            A SessionState; defaults when Postgres is unavailable or a query fails.
        """
        state = SessionState()

        if not self.pg_pool:
            return state

        try:
            async with self.pg_pool.acquire() as conn:
                # Get or create conversation
                conv_id = await conn.fetchval(
                    "SELECT get_or_create_conversation($1, $2, $3, NULL)",
                    channel, chat_id, thread_id
                )
                state.conversation_id = str(conv_id) if conv_id else None

                # Get conversation state
                if conv_id:
                    row = await conn.fetchrow("""
                        SELECT * FROM helion_conversation_state
                        WHERE conversation_id = $1
                    """, conv_id)

                    if row:
                        # NULL columns fall back to the dataclass defaults.
                        state.last_addressed = _column(row, 'last_addressed_to_helion', False)
                        state.active_topic = row.get('active_topic_id')
                        state.context_open = _column(row, 'active_context_open', False)
                        state.last_media_handled = _column(row, 'last_media_handled', True)
                        state.last_answer_fingerprint = row.get('last_answer_fingerprint')
                        state.trust_mode = _column(row, 'group_trust_mode', False)
                        state.apprentice_mode = _column(row, 'apprentice_mode', False)
                    else:
                        # Create initial state
                        await conn.execute("""
                            INSERT INTO helion_conversation_state (conversation_id)
                            VALUES ($1)
                            ON CONFLICT (conversation_id) DO NOTHING
                        """, conv_id)

                # Check if trusted group
                is_trusted = await conn.fetchval(
                    "SELECT is_trusted_group($1, $2)",
                    channel, chat_id
                )
                state.trust_mode = bool(is_trusted)

        except Exception as e:
            logger.warning(f"Session state retrieval failed: {e}")

        return state

    # =========================================================================
    # L3: Memory Retrieval (Facts + Semantic)
    # =========================================================================

    async def get_user_facts(
        self,
        platform_user_id: Optional[str],
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """Get the user's explicit facts from PostgreSQL.

        Returns up to ``limit`` non-archived, non-expired 'preference' or
        'profile_fact' rows from helion_memory_items, ordered by confidence and
        then by recency, as dicts. Returns [] without Postgres, without a user
        id, or on error.
        """
        if not self.pg_pool or not platform_user_id:
            return []

        try:
            async with self.pg_pool.acquire() as conn:
                rows = await conn.fetch("""
                    SELECT type, text, summary, confidence, visibility
                    FROM helion_memory_items
                    WHERE platform_user_id = $1
                      AND archived_at IS NULL
                      AND (expires_at IS NULL OR expires_at > NOW())
                      AND type IN ('preference', 'profile_fact')
                    ORDER BY confidence DESC, updated_at DESC
                    LIMIT $2
                """, platform_user_id, limit)

                return [dict(r) for r in rows]
        except Exception as e:
            logger.warning(f"User facts retrieval failed: {e}")
            return []

    async def get_embedding(self, text: str) -> Optional[List[float]]:
        """Get a search-query embedding from the Cohere v1 embed API.

        Returns:
            The embedding vector, or None when no API key/client is configured,
            the request fails or returns a non-200 status, or the response has
            no embedding.
        """
        if not COHERE_API_KEY or not self.http_client:
            return None

        try:
            response = await self.http_client.post(
                "https://api.cohere.ai/v1/embed",
                headers={
                    "Authorization": f"Bearer {COHERE_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "texts": [text],
                    "model": "embed-multilingual-v3.0",
                    "input_type": "search_query",
                    "truncate": "END"
                }
            )
            if response.status_code != 200:
                logger.warning(f"Embedding request failed: HTTP {response.status_code}")
                return None

            embeddings = response.json().get("embeddings") or []
            # An empty list would be useless downstream; normalise it to None.
            return (embeddings[0] or None) if embeddings else None
        except Exception as e:
            logger.warning(f"Embedding generation failed: {e}")

        return None

    async def search_memories(
        self,
        query: str,
        agent_id: str = "helion",
        platform_user_id: Optional[str] = None,
        chat_id: Optional[str] = None,
        user_id: Optional[str] = None,
        visibility: str = "platform",
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """Semantic search in Qdrant across agent-specific collections.

        Searches ``{agent_id}_memory_items``, ``{agent_id}_messages`` and
        ``{agent_id}_docs``. Results are merged, sorted by score (doc hits get a
        +0.12 boost for document-style queries), and de-duplicated on the first
        50 lower-cased characters of their text.

        Args:
            query: user text to embed and search with.
            agent_id: prefix of the collections to search.
            platform_user_id: when set, restricts memory-item hits to this user.
            chat_id: when set, restricts message hits to payloads whose
                ``chat_id`` or ``channel_id`` equals ``str(chat_id)``.
            user_id: currently unused.
            visibility: currently unused.
            limit: search limit for memory items and messages, and the maximum
                number of results returned.

        Returns:
            Up to ``limit`` dicts with "text", "type", "score" and "source" keys;
            [] when Qdrant or embeddings are unavailable or the search fails.
            A failure in a single collection only drops that collection's hits.
        """
        if not self.qdrant_client:
            return []

        # Get embedding
        embedding = await self.get_embedding(query)
        if not embedding:
            return []

        all_results: List[Dict[str, Any]] = []
        q = (query or "").lower()

        # If the user explicitly asks about documents/catalogs, prefer knowledge
        # base docs over chat snippets.
        is_doc_query = any(k in q for k in _DOC_QUERY_KEYWORDS)
        # Simple keyword gate to avoid irrelevant chat snippets dominating doc queries.
        topic_keywords = [kw for kw in _TOPIC_KEYWORDS if kw in q]

        # Dynamic collection names based on agent_id
        memory_items_collection = f"{agent_id}_memory_items"
        messages_collection = f"{agent_id}_messages"
        docs_collection = f"{agent_id}_docs"

        # NOTE(review): QdrantClient is the synchronous client, so each search
        # below blocks the event loop while it runs.
        try:
            from qdrant_client.http import models as qmodels

            # Search 1: {agent_id}_memory_items (facts, preferences)
            try:
                must_conditions = []
                if platform_user_id:
                    must_conditions.append(
                        qmodels.FieldCondition(
                            key="platform_user_id",
                            match=qmodels.MatchValue(value=platform_user_id)
                        )
                    )

                search_filter = qmodels.Filter(must=must_conditions) if must_conditions else None

                results = self.qdrant_client.search(
                    collection_name=memory_items_collection,
                    query_vector=embedding,
                    query_filter=search_filter,
                    limit=limit,
                    with_payload=True
                )

                for r in results:
                    if r.score > 0.3:  # Threshold for relevance
                        payload = r.payload or {}
                        all_results.append({
                            # Coerce to str so a NULL payload value cannot break dedup below.
                            "text": _as_text(payload.get("text", "")),
                            "type": payload.get("type", "fact"),
                            "confidence": payload.get("confidence", 0.5),
                            "score": r.score,
                            "source": "memory_items"
                        })
            except Exception as e:
                logger.debug(f"{memory_items_collection} search: {e}")

            # Search 2: {agent_id}_messages (chat history)
            try:
                msg_filter = None
                if chat_id:
                    # Payload schema differs across ingesters: some use chat_id, others channel_id.
                    msg_filter = qmodels.Filter(
                        should=[
                            qmodels.FieldCondition(key="chat_id", match=qmodels.MatchValue(value=str(chat_id))),
                            qmodels.FieldCondition(key="channel_id", match=qmodels.MatchValue(value=str(chat_id))),
                        ]
                    )

                results = self.qdrant_client.search(
                    collection_name=messages_collection,
                    query_vector=embedding,
                    query_filter=msg_filter,
                    limit=limit,
                    with_payload=True
                )

                # Higher threshold for messages; even higher when the user asks
                # about docs, to avoid pulling old chatter.
                msg_thresh = 0.5 if is_doc_query else 0.4
                for r in results:
                    if r.score <= msg_thresh:
                        continue
                    payload = r.payload or {}
                    text = _as_text(payload.get("text", payload.get("content", "")))
                    # Skip very short or system messages
                    if len(text) <= 20 or text.startswith("<"):
                        continue
                    if is_doc_query and topic_keywords:
                        lowered = text.lower()
                        if not any(k in lowered for k in topic_keywords):
                            continue
                    all_results.append({
                        "text": text,
                        "type": "message",
                        "score": r.score,
                        "source": "messages"
                    })
            except Exception as e:
                logger.debug(f"{messages_collection} search: {e}")

            # Search 3: {agent_id}_docs (knowledge base) - optional
            try:
                results = self.qdrant_client.search(
                    collection_name=docs_collection,
                    query_vector=embedding,
                    limit=6 if is_doc_query else 3,  # Pull more docs for explicit doc queries
                    with_payload=True
                )

                # When the user asks about PDF/catalogs, relax the threshold so
                # docs show up more reliably.
                doc_thresh = 0.35 if is_doc_query else 0.5
                for r in results:
                    if r.score <= doc_thresh:
                        continue
                    payload = r.payload or {}
                    text = _as_text(payload.get("text", payload.get("content", "")))
                    if len(text) > 30:
                        all_results.append({
                            "text": text[:500],  # Truncate long docs
                            "type": "knowledge",
                            # Slightly boost docs for doc queries so they win vs chat snippets.
                            "score": (r.score + 0.12) if is_doc_query else r.score,
                            "source": "docs"
                        })
            except Exception as e:
                logger.debug(f"{docs_collection} search: {e}")

            # Sort by score and deduplicate
            all_results.sort(key=lambda x: x.get("score", 0), reverse=True)

            # Remove duplicates based on the first 50 characters of text
            seen_texts = set()
            unique_results = []
            for r in all_results:
                text_key = r.get("text", "")[:50].lower()
                if text_key not in seen_texts:
                    seen_texts.add(text_key)
                    unique_results.append(r)

            return unique_results[:limit]

        except Exception as e:
            logger.warning(f"Memory search failed for {agent_id}: {e}")
            return []

    async def get_user_graph_context(
        self,
        username: Optional[str] = None,
        telegram_user_id: Optional[str] = None,
        agent_id: str = "helion"
    ) -> Dict[str, List[str]]:
        """Get the user's topics and projects from Neo4j, filtered by agent_id.

        Matches User nodes by username (with or without a leading "@") or by
        telegram_user_id/telegram_id. It then collects ASKED_ABOUT Topic names
        and WORKS_ON Project names whose relationship agent_id equals
        ``agent_id`` or is unset.

        Returns:
            {"topics": [...], "projects": [...]}; empty lists when Neo4j is
            unavailable, no identifiers are given, or the query fails.
        """
        context = {"topics": [], "projects": []}

        if not self.neo4j_driver:
            return context

        if not username and not telegram_user_id:
            return context

        try:
            async with self.neo4j_driver.session() as session:
                # Find user and their context, FILTERED BY AGENT_ID
                # Support both username formats: with and without @
                username_with_at = f"@{username}" if username and not username.startswith("@") else username
                username_without_at = username[1:] if username and username.startswith("@") else username

                # Query with agent_id filter on relationships
                query = """
                MATCH (u:User)
                WHERE u.username IN [$username, $username_with_at, $username_without_at]
                   OR u.telegram_user_id = $telegram_user_id
                   OR u.telegram_id = $telegram_user_id
                OPTIONAL MATCH (u)-[r1:ASKED_ABOUT]->(t:Topic)
                WHERE r1.agent_id = $agent_id OR r1.agent_id IS NULL
                OPTIONAL MATCH (u)-[r2:WORKS_ON]->(p:Project)
                WHERE r2.agent_id = $agent_id OR r2.agent_id IS NULL
                RETURN
                    collect(DISTINCT t.name) as topics,
                    collect(DISTINCT p.name) as projects
                """
                result = await session.run(
                    query,
                    username=username,
                    username_with_at=username_with_at,
                    username_without_at=username_without_at,
                    telegram_user_id=telegram_user_id,
                    agent_id=agent_id
                )
                record = await result.single()

                if record:
                    context["topics"] = [t for t in record["topics"] if t]
                    context["projects"] = [p for p in record["projects"] if p]

                logger.debug(f"Graph context for {agent_id}: topics={len(context['topics'])}, projects={len(context['projects'])}")

        except Exception as e:
            logger.warning(f"Graph context retrieval failed for {agent_id}: {e}")

        return context

    # =========================================================================
    # Main Retrieval Pipeline
    # =========================================================================

    async def retrieve(
        self,
        channel: str,
        chat_id: str,
        user_id: str,
        agent_id: str = "helion",
        username: Optional[str] = None,
        display_name: Optional[str] = None,
        message: Optional[str] = None,
        thread_id: Optional[str] = None
    ) -> MemoryBrief:
        """
        Main retrieval pipeline for any agent.

        1. Resolve user identity (L2)
        2. Get session state (L1)
        3. Get user facts (L3)
        4. Search relevant memories if message provided (L3)
        5. Get graph context (L3)
        6. Compile memory brief

        Args:
            channel: messaging channel name (e.g. "telegram").
            chat_id: chat identifier used for session state lookup.
            user_id: channel user id; also used as the Neo4j telegram user id.
            agent_id: Agent identifier for collection routing (e.g. "helion", "nutra", "greenfood")
            username: optional channel username.
            display_name: optional display name.
            message: current user message; enables semantic search when given.
            thread_id: optional thread identifier for the conversation.

        Returns:
            A populated MemoryBrief (sections stay empty when a backend is unavailable).
        """
        brief = MemoryBrief()

        # L2: Identity
        identity = await self.resolve_identity(channel, user_id, username, display_name)
        brief.user_identity = identity

        # L1: Session State
        session = await self.get_session_state(channel, chat_id, thread_id)
        brief.session_state = session
        brief.is_trusted_group = session.trust_mode

        # L3: User Facts
        if identity.platform_user_id:
            facts = await self.get_user_facts(identity.platform_user_id)
            brief.user_facts = facts

        # L3: Semantic Search (if message provided) - agent-specific collections
        # NOTE(review): chat_id is not forwarded here, so message hits are not
        # restricted to the current chat; confirm this cross-chat recall is intended.
        if message:
            memories = await self.search_memories(
                query=message,
                agent_id=agent_id,
                platform_user_id=identity.platform_user_id,
                limit=5
            )
            brief.relevant_memories = memories

        # L3: Graph Context (filtered by agent_id to prevent role mixing)
        graph_ctx = await self.get_user_graph_context(username, user_id, agent_id)
        brief.user_topics = graph_ctx.get("topics", [])
        brief.user_projects = graph_ctx.get("projects", [])

        # Check for mentor presence
        brief.mentor_present = identity.is_mentor

        return brief

    # =========================================================================
    # Memory Storage (write path)
    # =========================================================================

    async def store_message(
        self,
        agent_id: str,
        user_id: str,
        username: Optional[str],
        message_text: str,
        response_text: str,
        chat_id: str,
        message_type: str = "conversation",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Store a message exchange in agent-specific Qdrant collection.

        This enables semantic retrieval of past conversations per agent. The
        collection ``{agent_id}_messages`` is created on demand (1024-dim,
        cosine). The stored text is "User: ...\\n\\nAssistant: ..." and is
        embedded from its first 2000 characters.

        NOTE(review): chat_id is stored as given, while search_memories filters
        on str(chat_id); a non-str chat_id here would not match that filter.

        Returns:
            True when the point was upserted; False when Qdrant or the Cohere
            key is missing, embedding fails, or any Qdrant call raises.
        """
        if not self.qdrant_client or not COHERE_API_KEY:
            logger.debug(f"Cannot store message: qdrant={bool(self.qdrant_client)}, cohere={bool(COHERE_API_KEY)}")
            return False

        messages_collection = f"{agent_id}_messages"

        try:
            from qdrant_client.http import models as qmodels
            import uuid

            # Ensure collection exists
            try:
                self.qdrant_client.get_collection(messages_collection)
            except Exception:
                # Create collection with Cohere embed-multilingual-v3.0 dimensions (1024)
                self.qdrant_client.create_collection(
                    collection_name=messages_collection,
                    vectors_config=qmodels.VectorParams(
                        size=1024,
                        distance=qmodels.Distance.COSINE
                    )
                )
                logger.info(f"✅ Created collection: {messages_collection}")

            # Combine user message and response for better context retrieval
            combined_text = f"User: {message_text}\n\nAssistant: {response_text}"

            # Get embedding
            embedding = await self.get_embedding(combined_text[:2000])  # Truncate for API limits
            if not embedding:
                logger.warning("Failed to get embedding for message storage")
                return False

            # Store in Qdrant
            point_id = str(uuid.uuid4())
            payload = {
                "text": combined_text[:5000],  # Limit payload size
                "user_message": message_text[:2000],
                "assistant_response": response_text[:3000],
                "user_id": user_id,
                "username": username,
                "chat_id": chat_id,
                "agent_id": agent_id,
                "type": message_type,
                # Naive UTC ISO string, same format datetime.utcnow().isoformat()
                # produced, without the deprecated call.
                "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
            }
            if metadata and isinstance(metadata, dict):
                payload["metadata"] = metadata

            self.qdrant_client.upsert(
                collection_name=messages_collection,
                points=[
                    qmodels.PointStruct(
                        id=point_id,
                        vector=embedding,
                        payload=payload
                    )
                ]
            )

            logger.debug(f"✅ Stored message in {messages_collection}: {point_id[:8]}...")
            return True

        except Exception as e:
            logger.warning(f"Failed to store message in {messages_collection}: {e}")
            return False

    async def update_session_state(
        self,
        conversation_id: str,
        **updates
    ):
        """Update session state after interaction.

        Always sets ``updated_at = NOW()``. Each keyword in ``updates`` whose
        name is in the allowed column list is written as a bound parameter;
        other keywords are ignored (logged at debug level). Errors are logged,
        not raised.

        Args:
            conversation_id: helion_conversation_state.conversation_id to update.
            **updates: column name -> new value.
        """
        if not self.pg_pool or not conversation_id:
            return

        try:
            async with self.pg_pool.acquire() as conn:
                # Build dynamic update; only allow-listed column names reach the SQL text,
                # values are always passed as $N parameters.
                set_clauses = ["updated_at = NOW()"]
                params: List[Any] = [conversation_id]

                for column, value in updates.items():
                    if column not in _SESSION_STATE_COLUMNS:
                        logger.debug(f"Ignoring unknown session state field: {column}")
                        continue
                    params.append(value)
                    set_clauses.append(f"{column} = ${len(params)}")

                query = f"""
                    UPDATE helion_conversation_state
                    SET {', '.join(set_clauses)}
                    WHERE conversation_id = $1
                """
                await conn.execute(query, *params)

        except Exception as e:
            logger.warning(f"Session state update failed: {e}")


# Global instance
memory_retrieval = MemoryRetrieval()