Files
microdao-daarion/services/router/memory_retrieval.py

765 lines
31 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Universal Memory Retrieval Pipeline v4.0
This module implements the 4-layer memory retrieval for ALL agents:
- L0: Working Context (from request)
- L1: Session State Memory (SSM) - from Postgres
- L2: Platform Identity & Roles (PIR) - from Postgres + Neo4j
- L3: Organizational Memory (OM) - from Postgres + Neo4j + Qdrant
The pipeline generates a "memory brief" that is injected into the LLM context.
Collections per agent:
- {agent_id}_messages - chat history with embeddings
- {agent_id}_memory_items - facts, preferences
- {agent_id}_docs - knowledge base documents
"""
import os
import json
import logging
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
from datetime import datetime
import httpx
import asyncpg
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Configuration
# Each setting is read from the environment once, at import time; the second
# argument to os.getenv() is the fallback used when the variable is unset.
# SECURITY NOTE(review): the DATABASE_URL fallback below embeds a database
# username and password in source code. Consider requiring DATABASE_URL to be
# set explicitly and rotating the credential shown here.
POSTGRES_URL = os.getenv("DATABASE_URL", "postgresql://daarion:DaarionDB2026!@dagi-postgres:5432/daarion_memory")
# Qdrant host and port; the port is converted to int at import time.
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
# Cohere API key; the empty-string fallback disables embedding requests
# (get_embedding returns None when the key is falsy).
COHERE_API_KEY = os.getenv("COHERE_API_KEY", "")
# Neo4j Bolt endpoint and credentials.
# NOTE(review): the NEO4J_PASSWORD fallback is the literal "neo4j".
NEO4J_BOLT_URL = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4j")
@dataclass
class UserIdentity:
    """Platform user identity.

    Plain data record describing who sent a message: the channel-level
    identifiers supplied with the request plus the platform-level id, roles
    and mentor flag that are filled in when the identity is resolved.
    """

    # Platform-wide user id as a string; None while unresolved.
    platform_user_id: Optional[str] = None
    # Messaging channel the user was seen on.
    channel: str = "telegram"
    # User id inside that channel.
    channel_user_id: str = ""
    # Channel username, if the channel provides one.
    username: Optional[str] = None
    # Human-readable name, preferred over username when rendering.
    display_name: Optional[str] = None
    # Role codes assigned to the user (a fresh list per instance).
    roles: List[str] = field(default_factory=list)
    # True when the user is recognised as a mentor.
    is_mentor: bool = False
    # Timestamp of the first sighting, when known.
    first_seen: Optional[datetime] = None
@dataclass
class SessionState:
    """L1 Session State Memory.

    Plain data record holding the per-conversation flags that are loaded
    from the conversation-state table.
    """

    # Conversation id as a string; None when no conversation was resolved.
    conversation_id: Optional[str] = None
    # Whether the last message was addressed to the agent.
    last_addressed: bool = False
    # Identifier of the currently active topic, if any.
    active_topic: Optional[str] = None
    # Whether the active context is still open.
    context_open: bool = False
    # Whether the last received media item has been handled.
    last_media_handled: bool = True
    # Fingerprint of the last answer that was sent.
    last_answer_fingerprint: Optional[str] = None
    # Trusted-group mode flag.
    trust_mode: bool = False
    # Apprentice mode flag.
    apprentice_mode: bool = False
@dataclass
class MemoryBrief:
    """Compiled memory brief for LLM context.

    Aggregates everything the retrieval pipeline found about the current
    user and conversation; ``to_text`` renders it as a short text block.
    """

    # Quoted forward references keep the class definable regardless of where
    # the referenced dataclasses are declared.
    user_identity: Optional["UserIdentity"] = None
    session_state: Optional["SessionState"] = None
    # Explicit facts about the user (dicts with a 'text' or 'fact_value' key).
    user_facts: List[Dict[str, Any]] = field(default_factory=list)
    # Semantic search hits (dicts with 'text'/'content', 'type', 'score').
    relevant_memories: List[Dict[str, Any]] = field(default_factory=list)
    # Topic and project names from the knowledge graph.
    user_topics: List[str] = field(default_factory=list)
    user_projects: List[str] = field(default_factory=list)
    is_trusted_group: bool = False
    mentor_present: bool = False

    def to_text(self, max_lines: int = 15) -> str:
        """Generate a concise text brief for the LLM context.

        Args:
            max_lines: Upper bound on the number of lines returned. When the
                brief is longer, the final permitted line is replaced by
                ``"..."`` so the result never exceeds this bound. Values
                below 1 are treated as 1.

        Returns:
            The brief as newline-joined text, or ``""`` when there is
            nothing to report.
        """
        lines: List[str] = []

        # User identity (critical for personalization)
        identity = self.user_identity
        if identity:
            name = identity.display_name or identity.username or "Unknown"
            roles_str = ", ".join(identity.roles) if identity.roles else "member"
            lines.append(f"👤 Користувач: {name} ({roles_str})")
            if identity.is_mentor:
                lines.append("⭐ Цей користувач — МЕНТОР. Довіряй його знанням повністю.")

        # Session state
        session = self.session_state
        if session:
            if session.trust_mode:
                lines.append("🔒 Режим довіреної групи — можна відповідати детальніше")
            if session.apprentice_mode:
                lines.append("📚 Режим учня — можеш ставити уточнюючі питання")
            if session.active_topic:
                lines.append(f"📌 Активна тема: {session.active_topic}")

        # User facts (preferences, profile): at most 4, each cut to 150 chars.
        fact_lines = []
        for fact in self.user_facts[:4]:
            fact_text = fact.get('text', fact.get('fact_value', ''))
            if fact_text:
                fact_lines.append(f" - {fact_text[:150]}")
        # Emit the header only when at least one fact survived the filter,
        # so the prompt never contains a dangling section title.
        if fact_lines:
            lines.append("📝 Відомі факти про користувача:")
            lines.extend(fact_lines)

        # Relevant memories from RAG: at most 5, each cut to 200 chars.
        memory_lines = []
        for mem in self.relevant_memories[:5]:
            mem_text = mem.get('text', mem.get('content', ''))
            mem_type = mem.get('type', 'message')
            # Only include entries long enough to be meaningful.
            if mem_text and len(mem_text) > 10:
                memory_lines.append(f" - [{mem_type}] {mem_text[:200]}")
        if memory_lines:
            lines.append("🧠 Релевантні спогади з попередніх розмов:")
            lines.extend(memory_lines)

        # Topics/Projects from Knowledge Graph
        if self.user_topics:
            lines.append(f"💡 Інтереси користувача: {', '.join(self.user_topics[:5])}")
        if self.user_projects:
            lines.append(f"🏗️ Проєкти користувача: {', '.join(self.user_projects[:3])}")

        # Mentor presence indicator
        if self.mentor_present:
            lines.append("⚠️ ВАЖЛИВО: Ти спілкуєшся з ментором. Сприймай як навчання.")

        # Truncate if needed, reserving the last slot for the ellipsis marker
        # so the total stays within the requested budget.
        budget = max(max_lines, 1)
        if len(lines) > budget:
            lines = lines[:budget - 1] + ["..."]

        return "\n".join(lines)
class MemoryRetrieval:
    """Memory Retrieval Pipeline.

    Holds the optional backend handles (PostgreSQL pool, Neo4j driver,
    Qdrant client, HTTP client) and implements the read path (``retrieve``
    and its per-layer helpers) and the write path (``store_message``,
    ``update_session_state``). Every backend is optional: when a handle is
    missing the corresponding method returns an empty/default result.
    """

    # Substrings that mark a query as being about documents/catalogs.
    _DOC_QUERY_MARKERS = ("pdf", "каталог", "каталоз", "документ", "файл", "стор", "page", "pages")
    # Topic keywords used to gate chat snippets on document queries.
    _TOPIC_KEYWORDS = (
        "defenda", "ifagri", "bayer", "гліфосат", "glyphos", "глифос",
        "npk", "мінерал", "добрив", "гербіц", "фунгіц", "інсектиц",
    )
    # Columns of helion_conversation_state that callers may update. Column
    # names are interpolated into SQL, so this allowlist is the injection guard.
    _SESSION_STATE_FIELDS = frozenset({
        'last_addressed_to_helion', 'last_user_id', 'last_user_nick',
        'active_topic_id', 'active_context_open', 'last_media_id',
        'last_media_handled', 'last_answer_fingerprint', 'group_trust_mode',
        'apprentice_mode', 'proactive_questions_today',
    })

    def __init__(self):
        """Create an instance with no backend connected."""
        self.pg_pool: Optional[asyncpg.Pool] = None
        self.neo4j_driver = None
        self.qdrant_client = None
        self.http_client: Optional[httpx.AsyncClient] = None

    async def initialize(self):
        """Initialize database connections.

        Each backend is attempted independently; a failure is logged as a
        warning and does not prevent the remaining backends from starting.
        """
        # PostgreSQL
        try:
            self.pg_pool = await asyncpg.create_pool(
                POSTGRES_URL,
                min_size=2,
                max_size=10
            )
            logger.info("✅ Memory Retrieval: PostgreSQL connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: PostgreSQL not available: {e}")

        # Neo4j
        # The driver is assigned before connectivity is verified, so it stays
        # set on the instance even when verification fails.
        try:
            from neo4j import AsyncGraphDatabase
            self.neo4j_driver = AsyncGraphDatabase.driver(
                NEO4J_BOLT_URL,
                auth=(NEO4J_USER, NEO4J_PASSWORD)
            )
            await self.neo4j_driver.verify_connectivity()
            logger.info("✅ Memory Retrieval: Neo4j connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: Neo4j not available: {e}")

        # Qdrant (same assign-then-verify order as Neo4j)
        try:
            from qdrant_client import QdrantClient
            self.qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
            self.qdrant_client.get_collections()  # Test connection
            logger.info("✅ Memory Retrieval: Qdrant connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: Qdrant not available: {e}")

        # HTTP client for embeddings; reuse an existing one so repeated
        # initialize() calls do not leak clients.
        if self.http_client is None:
            self.http_client = httpx.AsyncClient(timeout=30.0)

    async def close(self):
        """Close all backend connections and reset the handles to None.

        Closing is best-effort: a failure while closing one backend is
        logged and the remaining backends are still closed.
        """
        # Detach the handles first so the instance never keeps a reference
        # to a closed (or half-closed) backend.
        pg_pool, self.pg_pool = self.pg_pool, None
        neo4j_driver, self.neo4j_driver = self.neo4j_driver, None
        qdrant_client, self.qdrant_client = self.qdrant_client, None
        http_client, self.http_client = self.http_client, None

        if pg_pool is not None:
            try:
                await pg_pool.close()
            except Exception as e:
                logger.warning(f"Failed to close PostgreSQL pool: {e}")
        if neo4j_driver is not None:
            try:
                await neo4j_driver.close()
            except Exception as e:
                logger.warning(f"Failed to close Neo4j driver: {e}")
        if qdrant_client is not None:
            # Looked up defensively in case the installed client has no close().
            close_qdrant = getattr(qdrant_client, "close", None)
            if callable(close_qdrant):
                try:
                    close_qdrant()
                except Exception as e:
                    logger.warning(f"Failed to close Qdrant client: {e}")
        if http_client is not None:
            try:
                await http_client.aclose()
            except Exception as e:
                logger.warning(f"Failed to close HTTP client: {e}")

    @staticmethod
    def _row_value(row, key: str, default: Any = None) -> Any:
        """Return ``row.get(key)``, using *default* when missing or None."""
        value = row.get(key, default)
        return default if value is None else value

    # =========================================================================
    # L2: Platform Identity Resolution
    # =========================================================================
    async def resolve_identity(
        self,
        channel: str,
        channel_user_id: str,
        username: Optional[str] = None,
        display_name: Optional[str] = None
    ) -> "UserIdentity":
        """Resolve or create the platform user identity.

        Args:
            channel: Channel name the user was seen on.
            channel_user_id: User id inside that channel.
            username: Optional channel username.
            display_name: Optional human-readable name.

        Returns:
            A ``UserIdentity`` built from the arguments. When PostgreSQL is
            available it is enriched with the platform user id, role codes
            and mentor flag; on any database error the partially filled
            identity is returned and a warning is logged.
        """
        identity = UserIdentity(
            channel=channel,
            channel_user_id=channel_user_id,
            username=username,
            display_name=display_name
        )
        if not self.pg_pool:
            return identity
        try:
            async with self.pg_pool.acquire() as conn:
                # Use the resolve_platform_user SQL function
                platform_user_id = await conn.fetchval(
                    "SELECT resolve_platform_user($1, $2, $3, $4)",
                    channel, channel_user_id, username, display_name
                )
                identity.platform_user_id = str(platform_user_id) if platform_user_id else None
                # Get roles that have not been revoked
                if platform_user_id:
                    roles = await conn.fetch("""
                        SELECT r.code FROM user_roles ur
                        JOIN platform_roles r ON ur.role_id = r.id
                        WHERE ur.platform_user_id = $1 AND ur.revoked_at IS NULL
                    """, platform_user_id)
                    identity.roles = [r['code'] for r in roles]
                # Check if mentor
                is_mentor = await conn.fetchval(
                    "SELECT is_mentor($1, $2)",
                    channel_user_id, username
                )
                identity.is_mentor = bool(is_mentor)
        except Exception as e:
            logger.warning(f"Identity resolution failed: {e}")
        return identity

    # =========================================================================
    # L1: Session State
    # =========================================================================
    async def get_session_state(
        self,
        channel: str,
        chat_id: str,
        thread_id: Optional[str] = None
    ) -> "SessionState":
        """Get or create the session state for a conversation.

        Args:
            channel: Channel name.
            chat_id: Chat identifier inside the channel.
            thread_id: Optional thread identifier.

        Returns:
            A ``SessionState``. Defaults are returned when PostgreSQL is
            unavailable; on a database error the state filled so far is
            returned and a warning is logged.
        """
        state = SessionState()
        if not self.pg_pool:
            return state
        try:
            async with self.pg_pool.acquire() as conn:
                # Get or create conversation
                conv_id = await conn.fetchval(
                    "SELECT get_or_create_conversation($1, $2, $3, NULL)",
                    channel, chat_id, thread_id
                )
                state.conversation_id = str(conv_id) if conv_id else None
                # Get conversation state
                if conv_id:
                    row = await conn.fetchrow("""
                        SELECT * FROM helion_conversation_state
                        WHERE conversation_id = $1
                    """, conv_id)
                    if row:
                        # NULL columns fall back to the dataclass defaults
                        # instead of storing None in boolean fields.
                        state.last_addressed = self._row_value(row, 'last_addressed_to_helion', False)
                        state.active_topic = row.get('active_topic_id')
                        state.context_open = self._row_value(row, 'active_context_open', False)
                        state.last_media_handled = self._row_value(row, 'last_media_handled', True)
                        state.last_answer_fingerprint = row.get('last_answer_fingerprint')
                        state.trust_mode = self._row_value(row, 'group_trust_mode', False)
                        state.apprentice_mode = self._row_value(row, 'apprentice_mode', False)
                    else:
                        # Create initial state
                        await conn.execute("""
                            INSERT INTO helion_conversation_state (conversation_id)
                            VALUES ($1)
                            ON CONFLICT (conversation_id) DO NOTHING
                        """, conv_id)
                # Check if trusted group; this result replaces the stored
                # group_trust_mode value read above.
                is_trusted = await conn.fetchval(
                    "SELECT is_trusted_group($1, $2)",
                    channel, chat_id
                )
                state.trust_mode = bool(is_trusted)
        except Exception as e:
            logger.warning(f"Session state retrieval failed: {e}")
        return state

    # =========================================================================
    # L3: Memory Retrieval (Facts + Semantic)
    # =========================================================================
    async def get_user_facts(
        self,
        platform_user_id: Optional[str],
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """Get the user's explicit facts from PostgreSQL.

        Args:
            platform_user_id: Platform user id; a falsy value yields ``[]``.
            limit: Maximum number of rows to return.

        Returns:
            Rows of type 'preference' or 'profile_fact' that are neither
            archived nor expired, as dicts, ordered by confidence and then
            recency. ``[]`` when PostgreSQL is unavailable or the query fails.
        """
        if not self.pg_pool or not platform_user_id:
            return []
        try:
            async with self.pg_pool.acquire() as conn:
                rows = await conn.fetch("""
                    SELECT type, text, summary, confidence, visibility
                    FROM helion_memory_items
                    WHERE platform_user_id = $1
                    AND archived_at IS NULL
                    AND (expires_at IS NULL OR expires_at > NOW())
                    AND type IN ('preference', 'profile_fact')
                    ORDER BY confidence DESC, updated_at DESC
                    LIMIT $2
                """, platform_user_id, limit)
                return [dict(r) for r in rows]
        except Exception as e:
            logger.warning(f"User facts retrieval failed: {e}")
            return []

    async def get_embedding(
        self,
        text: str,
        input_type: str = "search_query"
    ) -> Optional[List[float]]:
        """Get an embedding from the Cohere API.

        Args:
            text: Text to embed.
            input_type: Cohere ``input_type`` value. Defaults to
                ``"search_query"``; pass ``"search_document"`` when the
                vector will be stored for later retrieval.

        Returns:
            The embedding vector, or ``None`` when no API key or HTTP client
            is configured, the API answers with a non-200 status, the
            response has no embedding, or the request raises.
        """
        if not COHERE_API_KEY or not self.http_client:
            return None
        try:
            response = await self.http_client.post(
                "https://api.cohere.ai/v1/embed",
                headers={
                    "Authorization": f"Bearer {COHERE_API_KEY}",
                    "Content-Type": "application/json"
                },
                json={
                    "texts": [text],
                    "model": "embed-multilingual-v3.0",
                    "input_type": input_type,
                    "truncate": "END"
                }
            )
            if response.status_code != 200:
                # Surface API errors instead of dropping them silently.
                logger.warning(f"Embedding request failed: HTTP {response.status_code}")
                return None
            embeddings = response.json().get("embeddings") or []
            if embeddings and embeddings[0]:
                return embeddings[0]
        except Exception as e:
            logger.warning(f"Embedding generation failed: {e}")
        return None

    async def search_memories(
        self,
        query: str,
        agent_id: str = "helion",
        platform_user_id: Optional[str] = None,
        chat_id: Optional[str] = None,
        user_id: Optional[str] = None,
        visibility: str = "platform",
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """Semantic search in Qdrant across agent-specific collections.

        Searches ``{agent_id}_memory_items``, ``{agent_id}_messages`` and
        ``{agent_id}_docs``; a failure in one collection is logged at debug
        level and does not affect the others.

        Args:
            query: User query to embed and search for.
            agent_id: Agent identifier used to build the collection names.
            platform_user_id: When given, memory items are filtered to it.
            chat_id: When given, messages are filtered to payloads whose
                ``chat_id`` or ``channel_id`` equals it.
            user_id: Accepted for interface compatibility; not used here.
            visibility: Accepted for interface compatibility; not used here.
            limit: Per-collection fetch size for memory items and messages,
                and the maximum number of results returned.

        Returns:
            Result dicts (``text``, ``type``, ``score``, ``source``, plus
            ``confidence`` for memory items) sorted by score descending and
            de-duplicated on the first 50 characters of their text. ``[]``
            when Qdrant or the embedding is unavailable or the search fails.
        """
        if not self.qdrant_client:
            return []
        # Get embedding
        embedding = await self.get_embedding(query)
        if not embedding:
            return []

        all_results: List[Dict[str, Any]] = []
        q = (query or "").lower()
        # If the user explicitly asks about documents/catalogs, prefer
        # knowledge base docs over chat snippets.
        is_doc_query = any(marker in q for marker in self._DOC_QUERY_MARKERS)
        # Simple keyword gate so irrelevant chat snippets do not dominate
        # document queries.
        topic_keywords = [kw for kw in self._TOPIC_KEYWORDS if kw in q]

        # Dynamic collection names based on agent_id
        memory_items_collection = f"{agent_id}_memory_items"
        messages_collection = f"{agent_id}_messages"
        docs_collection = f"{agent_id}_docs"

        try:
            from qdrant_client.http import models as qmodels

            # Search 1: {agent_id}_memory_items (facts, preferences)
            try:
                must_conditions = []
                if platform_user_id:
                    must_conditions.append(
                        qmodels.FieldCondition(
                            key="platform_user_id",
                            match=qmodels.MatchValue(value=platform_user_id)
                        )
                    )
                search_filter = qmodels.Filter(must=must_conditions) if must_conditions else None
                hits = self.qdrant_client.search(
                    collection_name=memory_items_collection,
                    query_vector=embedding,
                    query_filter=search_filter,
                    limit=limit,
                    with_payload=True
                )
                for hit in hits:
                    if hit.score <= 0.3:  # Threshold for relevance
                        continue
                    # A point may carry no payload; treat it as empty.
                    payload = hit.payload or {}
                    all_results.append({
                        "text": payload.get("text") or "",
                        "type": payload.get("type", "fact"),
                        "confidence": payload.get("confidence", 0.5),
                        "score": hit.score,
                        "source": "memory_items"
                    })
            except Exception as e:
                logger.debug(f"{memory_items_collection} search: {e}")

            # Search 2: {agent_id}_messages (chat history)
            try:
                msg_filter = None
                if chat_id:
                    # Payload schema differs across ingesters: some use
                    # chat_id, others channel_id.
                    msg_filter = qmodels.Filter(
                        should=[
                            qmodels.FieldCondition(key="chat_id", match=qmodels.MatchValue(value=str(chat_id))),
                            qmodels.FieldCondition(key="channel_id", match=qmodels.MatchValue(value=str(chat_id))),
                        ]
                    )
                hits = self.qdrant_client.search(
                    collection_name=messages_collection,
                    query_vector=embedding,
                    query_filter=msg_filter,
                    limit=limit,
                    with_payload=True
                )
                # Higher threshold for messages; even higher when the user
                # asks about docs, to avoid pulling old chatter.
                msg_thresh = 0.5 if is_doc_query else 0.4
                for hit in hits:
                    if hit.score <= msg_thresh:
                        continue
                    payload = hit.payload or {}
                    text = payload.get("text") or payload.get("content") or ""
                    # Skip very short or system messages
                    if len(text) <= 20 or text.startswith("<"):
                        continue
                    if is_doc_query and topic_keywords:
                        lowered = text.lower()
                        if not any(k in lowered for k in topic_keywords):
                            continue
                    all_results.append({
                        "text": text,
                        "type": "message",
                        "score": hit.score,
                        "source": "messages"
                    })
            except Exception as e:
                logger.debug(f"{messages_collection} search: {e}")

            # Search 3: {agent_id}_docs (knowledge base) - optional
            try:
                hits = self.qdrant_client.search(
                    collection_name=docs_collection,
                    query_vector=embedding,
                    limit=6 if is_doc_query else 3,  # Pull more docs for explicit doc queries
                    with_payload=True
                )
                # When the user asks about PDF/catalogs, relax the threshold
                # so docs show up more reliably.
                doc_thresh = 0.35 if is_doc_query else 0.5
                for hit in hits:
                    if hit.score <= doc_thresh:
                        continue
                    payload = hit.payload or {}
                    text = payload.get("text") or payload.get("content") or ""
                    if len(text) <= 30:
                        continue
                    all_results.append({
                        "text": text[:500],  # Truncate long docs
                        "type": "knowledge",
                        # Slightly boost docs for doc queries so they win
                        # against chat snippets.
                        "score": (hit.score + 0.12) if is_doc_query else hit.score,
                        "source": "docs"
                    })
            except Exception as e:
                logger.debug(f"{docs_collection} search: {e}")

            # Sort by score, then drop entries whose first 50 characters
            # (case-insensitive) were already seen.
            all_results.sort(key=lambda item: item.get("score", 0), reverse=True)
            seen_texts = set()
            unique_results = []
            for item in all_results:
                text_key = (item.get("text") or "")[:50].lower()
                if text_key not in seen_texts:
                    seen_texts.add(text_key)
                    unique_results.append(item)
            return unique_results[:limit]
        except Exception as e:
            logger.warning(f"Memory search failed for {agent_id}: {e}")
            return []

    async def get_user_graph_context(
        self,
        username: Optional[str] = None,
        telegram_user_id: Optional[str] = None,
        agent_id: str = "helion"
    ) -> Dict[str, List[str]]:
        """Get the user's topics and projects from Neo4j, filtered by agent_id.

        Args:
            username: Username, with or without a leading ``@``.
            telegram_user_id: Telegram user id.
            agent_id: Only relationships tagged with this agent id, or with
                no agent id at all, are considered.

        Returns:
            ``{"topics": [...], "projects": [...]}``; both lists are empty
            when Neo4j is unavailable, no identifier was given, or the
            query fails.
        """
        context: Dict[str, List[str]] = {"topics": [], "projects": []}
        if not self.neo4j_driver:
            return context
        if not username and not telegram_user_id:
            return context
        try:
            async with self.neo4j_driver.session() as session:
                # Support both username formats: with and without @
                username_with_at = f"@{username}" if username and not username.startswith("@") else username
                username_without_at = username[1:] if username and username.startswith("@") else username
                # Query with agent_id filter on relationships
                query = """
                MATCH (u:User)
                WHERE u.username IN [$username, $username_with_at, $username_without_at]
                OR u.telegram_user_id = $telegram_user_id
                OR u.telegram_id = $telegram_user_id
                OPTIONAL MATCH (u)-[r1:ASKED_ABOUT]->(t:Topic)
                WHERE r1.agent_id = $agent_id OR r1.agent_id IS NULL
                OPTIONAL MATCH (u)-[r2:WORKS_ON]->(p:Project)
                WHERE r2.agent_id = $agent_id OR r2.agent_id IS NULL
                RETURN
                collect(DISTINCT t.name) as topics,
                collect(DISTINCT p.name) as projects
                """
                result = await session.run(
                    query,
                    username=username,
                    username_with_at=username_with_at,
                    username_without_at=username_without_at,
                    telegram_user_id=telegram_user_id,
                    agent_id=agent_id
                )
                record = await result.single()
                if record:
                    # Drop null/empty names
                    context["topics"] = [t for t in record["topics"] if t]
                    context["projects"] = [p for p in record["projects"] if p]
                logger.debug(f"Graph context for {agent_id}: topics={len(context['topics'])}, projects={len(context['projects'])}")
        except Exception as e:
            logger.warning(f"Graph context retrieval failed for {agent_id}: {e}")
        return context

    # =========================================================================
    # Main Retrieval Pipeline
    # =========================================================================
    async def retrieve(
        self,
        channel: str,
        chat_id: str,
        user_id: str,
        agent_id: str = "helion",
        username: Optional[str] = None,
        display_name: Optional[str] = None,
        message: Optional[str] = None,
        thread_id: Optional[str] = None
    ) -> "MemoryBrief":
        """Main retrieval pipeline for any agent.

        1. Resolve user identity (L2)
        2. Get session state (L1)
        3. Get user facts (L3)
        4. Search relevant memories if message provided (L3)
        5. Get graph context (L3)
        6. Compile memory brief

        Args:
            channel: Channel name.
            chat_id: Chat identifier; also used to scope message search.
            user_id: Channel-level user id.
            agent_id: Agent identifier for collection routing
                (e.g. "helion", "nutra", "greenfood").
            username: Optional channel username.
            display_name: Optional human-readable name.
            message: Current user message; semantic search is skipped when
                it is empty or None.
            thread_id: Optional thread identifier.

        Returns:
            The compiled ``MemoryBrief``.
        """
        brief = MemoryBrief()

        # L2: Identity
        identity = await self.resolve_identity(channel, user_id, username, display_name)
        brief.user_identity = identity

        # L1: Session State
        session = await self.get_session_state(channel, chat_id, thread_id)
        brief.session_state = session
        brief.is_trusted_group = session.trust_mode

        # L3: User Facts
        if identity.platform_user_id:
            brief.user_facts = await self.get_user_facts(identity.platform_user_id)

        # L3: Semantic Search (if message provided) - agent-specific collections
        if message:
            # chat_id is forwarded so message recall is scoped to this chat
            # rather than every chat stored in the agent's collection.
            brief.relevant_memories = await self.search_memories(
                query=message,
                agent_id=agent_id,
                platform_user_id=identity.platform_user_id,
                chat_id=chat_id,
                user_id=user_id,
                limit=5
            )

        # L3: Graph Context (filtered by agent_id to prevent role mixing)
        graph_ctx = await self.get_user_graph_context(username, user_id, agent_id)
        brief.user_topics = graph_ctx.get("topics", [])
        brief.user_projects = graph_ctx.get("projects", [])

        # Check for mentor presence
        brief.mentor_present = identity.is_mentor
        return brief

    # =========================================================================
    # Memory Storage (write path)
    # =========================================================================
    async def store_message(
        self,
        agent_id: str,
        user_id: str,
        username: Optional[str],
        message_text: str,
        response_text: str,
        chat_id: str,
        message_type: str = "conversation",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """Store a message exchange in the agent-specific Qdrant collection.

        This enables semantic retrieval of past conversations per agent.
        The collection ``{agent_id}_messages`` is created on first use.

        Args:
            agent_id: Agent identifier used to build the collection name.
            user_id: Channel-level user id stored in the payload.
            username: Optional username stored in the payload.
            message_text: The user's message.
            response_text: The agent's response.
            chat_id: Chat identifier stored in the payload.
            message_type: Payload ``type`` value.
            metadata: Optional dict stored under the ``metadata`` key.

        Returns:
            True when the point was stored; False when Qdrant or the Cohere
            key is unavailable, no embedding could be obtained, or storing
            raised.
        """
        if not self.qdrant_client or not COHERE_API_KEY:
            logger.debug(f"Cannot store message: qdrant={bool(self.qdrant_client)}, cohere={bool(COHERE_API_KEY)}")
            return False
        messages_collection = f"{agent_id}_messages"
        try:
            from qdrant_client.http import models as qmodels
            import uuid

            # Ensure collection exists
            try:
                self.qdrant_client.get_collection(messages_collection)
            except Exception:
                # Create collection with Cohere embed-multilingual-v3.0
                # dimensions (1024)
                self.qdrant_client.create_collection(
                    collection_name=messages_collection,
                    vectors_config=qmodels.VectorParams(
                        size=1024,
                        distance=qmodels.Distance.COSINE
                    )
                )
                logger.info(f"✅ Created collection: {messages_collection}")

            # Combine user message and response for better context retrieval
            combined_text = f"User: {message_text}\n\nAssistant: {response_text}"

            # Stored vectors are embedded as documents; queries use the
            # default "search_query" type. Text is truncated for API limits.
            embedding = await self.get_embedding(
                combined_text[:2000],
                input_type="search_document"
            )
            if not embedding:
                logger.warning("Failed to get embedding for message storage")
                return False

            # Store in Qdrant
            point_id = str(uuid.uuid4())
            payload = {
                "text": combined_text[:5000],  # Limit payload size
                "user_message": message_text[:2000],
                "assistant_response": response_text[:3000],
                "user_id": user_id,
                "username": username,
                "chat_id": chat_id,
                "agent_id": agent_id,
                "type": message_type,
                "timestamp": datetime.utcnow().isoformat()
            }
            if metadata and isinstance(metadata, dict):
                payload["metadata"] = metadata
            self.qdrant_client.upsert(
                collection_name=messages_collection,
                points=[
                    qmodels.PointStruct(
                        id=point_id,
                        vector=embedding,
                        payload=payload
                    )
                ]
            )
            logger.debug(f"✅ Stored message in {messages_collection}: {point_id[:8]}...")
            return True
        except Exception as e:
            logger.warning(f"Failed to store message in {messages_collection}: {e}")
            return False

    async def update_session_state(
        self,
        conversation_id: str,
        **updates
    ):
        """Update session state after an interaction.

        Args:
            conversation_id: Conversation whose state row is updated; a
                falsy value makes the call a no-op.
            **updates: Column/value pairs. Only allowlisted columns are
                applied; any other key is ignored and reported at debug
                level. ``updated_at`` is always set to ``NOW()``.

        Returns:
            None. Database errors are logged as warnings, not raised.
        """
        if not self.pg_pool or not conversation_id:
            return
        try:
            async with self.pg_pool.acquire() as conn:
                # Build dynamic update; $1 is always the conversation id.
                set_clauses = ["updated_at = NOW()"]
                params: List[Any] = [conversation_id]
                ignored = []
                for column, value in updates.items():
                    if column not in self._SESSION_STATE_FIELDS:
                        ignored.append(column)
                        continue
                    params.append(value)
                    # The placeholder index is the value's 1-based position.
                    set_clauses.append(f"{column} = ${len(params)}")
                if ignored:
                    logger.debug(f"Session state update ignored unknown fields: {ignored}")
                query = f"""
                    UPDATE helion_conversation_state
                    SET {', '.join(set_clauses)}
                    WHERE conversation_id = $1
                """
                await conn.execute(query, *params)
        except Exception as e:
            logger.warning(f"Session state update failed: {e}")
# Global instance
# Module-level instance created at import time. The constructor only sets the
# backend handles to None; nothing is connected until initialize() is awaited.
memory_retrieval = MemoryRetrieval()