802 lines
32 KiB
Python
802 lines
32 KiB
Python
"""
|
||
Universal Memory Retrieval Pipeline v4.0
|
||
|
||
This module implements the 4-layer memory retrieval for ALL agents:
|
||
- L0: Working Context (from request)
|
||
- L1: Session State Memory (SSM) - from Postgres
|
||
- L2: Platform Identity & Roles (PIR) - from Postgres + Neo4j
|
||
- L3: Organizational Memory (OM) - from Postgres + Neo4j + Qdrant
|
||
|
||
The pipeline generates a "memory brief" that is injected into the LLM context.
|
||
|
||
Collections per agent:
|
||
- {agent_id}_messages - chat history with embeddings
|
||
- {agent_id}_memory_items - facts, preferences
|
||
- {agent_id}_docs - knowledge base documents
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import logging
|
||
import re
|
||
from typing import Optional, Dict, Any, List
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
|
||
import httpx
|
||
import asyncpg
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Configuration
# All settings are read from the environment, with development-oriented defaults.
# NOTE(review): the DATABASE_URL default embeds real-looking credentials in source;
# consider requiring the env var instead of hard-coding a password here.
POSTGRES_URL = os.getenv("DATABASE_URL", "postgresql://daarion:DaarionDB2026!@dagi-postgres:5432/daarion_memory")
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
# Required for embeddings; when empty, semantic search and message storage are skipped.
COHERE_API_KEY = os.getenv("COHERE_API_KEY", "")
NEO4J_BOLT_URL = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4j")
|
||
|
||
|
||
@dataclass
class UserIdentity:
    """Platform user identity (L2 layer): cross-channel view of who is talking."""
    # Canonical platform-wide id (stringified) resolved via Postgres; None until resolved.
    platform_user_id: Optional[str] = None
    # Source channel of the message; defaults to Telegram.
    channel: str = "telegram"
    # User id as issued by the channel itself (e.g. a Telegram user id).
    channel_user_id: str = ""
    username: Optional[str] = None
    display_name: Optional[str] = None
    # Active (non-revoked) platform role codes, loaded from user_roles/platform_roles.
    roles: List[str] = field(default_factory=list)
    # True when the server-side is_mentor() check matched this user.
    is_mentor: bool = False
    # NOTE(review): never populated within this module — presumably set elsewhere; verify.
    first_seen: Optional[datetime] = None
|
||
|
||
|
||
@dataclass
class SessionState:
    """L1 Session State Memory: per-conversation flags loaded from Postgres."""
    # Stringified id of the conversation row; None when Postgres is unavailable.
    conversation_id: Optional[str] = None
    # Mirrors the last_addressed_to_helion column.
    last_addressed: bool = False
    # Mirrors the active_topic_id column.
    active_topic: Optional[str] = None
    # Mirrors the active_context_open column.
    context_open: bool = False
    # Mirrors the last_media_handled column.
    last_media_handled: bool = True
    # Fingerprint of the previous answer (last_answer_fingerprint column);
    # presumably used to detect repeated answers — confirm at call sites.
    last_answer_fingerprint: Optional[str] = None
    # Trusted-group mode; overridden by the is_trusted_group() DB check in get_session_state().
    trust_mode: bool = False
    # Mirrors the apprentice_mode column.
    apprentice_mode: bool = False
|
||
|
||
|
||
@dataclass
class MemoryBrief:
    """Compiled memory brief injected into the LLM context.

    Aggregates all four memory layers and renders them as a short,
    bullet-per-line text block via :meth:`to_text`.
    """
    # L2: resolved platform identity of the requesting user.
    user_identity: Optional["UserIdentity"] = None
    # L1: per-conversation state flags.
    session_state: Optional["SessionState"] = None
    # L3: explicit facts (preferences, profile) from Postgres.
    user_facts: List[Dict[str, Any]] = field(default_factory=list)
    # L3: semantic-search hits (dicts with text/type/score/source).
    relevant_memories: List[Dict[str, Any]] = field(default_factory=list)
    # L3: topic / project names from the Neo4j knowledge graph.
    user_topics: List[str] = field(default_factory=list)
    user_projects: List[str] = field(default_factory=list)
    is_trusted_group: bool = False
    mentor_present: bool = False

    def to_text(self, max_lines: int = 15) -> str:
        """Generate a concise text brief for the LLM context.

        Args:
            max_lines: Hard cap on emitted lines; when exceeded, the output
                is truncated and a trailing "..." marker is appended.

        Returns:
            A newline-joined brief, or an empty string when nothing is known.
        """
        lines: List[str] = []

        # User identity (critical for personalization)
        if self.user_identity:
            name = self.user_identity.display_name or self.user_identity.username or "Unknown"
            roles_str = ", ".join(self.user_identity.roles) if self.user_identity.roles else "member"
            lines.append(f"👤 Користувач: {name} ({roles_str})")
            if self.user_identity.is_mentor:
                lines.append("⭐ Цей користувач — МЕНТОР. Довіряй його знанням повністю.")

        # Session state flags
        if self.session_state:
            if self.session_state.trust_mode:
                lines.append("🔒 Режим довіреної групи — можна відповідати детальніше")
            if self.session_state.apprentice_mode:
                lines.append("📚 Режим учня — можеш ставити уточнюючі питання")
            if self.session_state.active_topic:
                lines.append(f"📌 Активна тема: {self.session_state.active_topic}")

        # User facts (preferences, profile)
        if self.user_facts:
            lines.append("📝 Відомі факти про користувача:")
            for fact in self.user_facts[:4]:  # Max 4 facts
                fact_text = fact.get('text', fact.get('fact_value', ''))
                if fact_text:
                    lines.append(f" - {fact_text[:150]}")

        # Relevant memories from RAG (most important for context)
        if self.relevant_memories:
            lines.append("🧠 Релевантні спогади з попередніх розмов:")
            for mem in self.relevant_memories[:5]:  # Max 5 memories for better context
                mem_text = mem.get('text', mem.get('content', ''))
                mem_type = mem.get('type', 'message')
                # Only include meaningful snippets (skip near-empty payloads).
                # (Removed the unused `score` local that was read but never used.)
                if mem_text and len(mem_text) > 10:
                    lines.append(f" - [{mem_type}] {mem_text[:200]}")

        # Topics/Projects from Knowledge Graph
        if self.user_topics:
            lines.append(f"💡 Інтереси користувача: {', '.join(self.user_topics[:5])}")
        if self.user_projects:
            lines.append(f"🏗️ Проєкти користувача: {', '.join(self.user_projects[:3])}")

        # Mentor presence indicator
        if self.mentor_present:
            lines.append("⚠️ ВАЖЛИВО: Ти спілкуєшся з ментором. Сприймай як навчання.")

        # Truncate if needed (keeps max_lines lines plus the ellipsis marker)
        if len(lines) > max_lines:
            lines = lines[:max_lines]
            lines.append("...")

        return "\n".join(lines) if lines else ""
|
||
|
||
|
||
class MemoryRetrieval:
|
||
"""Memory Retrieval Pipeline"""
|
||
|
||
def __init__(self):
|
||
self.pg_pool: Optional[asyncpg.Pool] = None
|
||
self.neo4j_driver = None
|
||
self.qdrant_client = None
|
||
self.http_client: Optional[httpx.AsyncClient] = None
|
||
|
||
    async def initialize(self):
        """Open connections to all backends, best-effort.

        Each backend (PostgreSQL, Neo4j, Qdrant) is attempted independently;
        a failure only logs a warning and leaves that handle as None, so the
        pipeline degrades gracefully. The HTTP client for embeddings is
        always created.
        """
        # PostgreSQL — connection pool used by identity/session/facts lookups.
        try:
            self.pg_pool = await asyncpg.create_pool(
                POSTGRES_URL,
                min_size=2,
                max_size=10
            )
            logger.info("✅ Memory Retrieval: PostgreSQL connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: PostgreSQL not available: {e}")

        # Neo4j — imported lazily so the module loads without the driver installed.
        try:
            from neo4j import AsyncGraphDatabase
            self.neo4j_driver = AsyncGraphDatabase.driver(
                NEO4J_BOLT_URL,
                auth=(NEO4J_USER, NEO4J_PASSWORD)
            )
            await self.neo4j_driver.verify_connectivity()
            logger.info("✅ Memory Retrieval: Neo4j connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: Neo4j not available: {e}")

        # Qdrant — imported lazily for the same reason; the client is synchronous.
        try:
            from qdrant_client import QdrantClient
            self.qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
            self.qdrant_client.get_collections()  # Test connection
            logger.info("✅ Memory Retrieval: Qdrant connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: Qdrant not available: {e}")

        # HTTP client for embeddings (Cohere API calls in get_embedding()).
        self.http_client = httpx.AsyncClient(timeout=30.0)
|
||
|
||
async def close(self):
|
||
"""Close connections"""
|
||
if self.pg_pool:
|
||
await self.pg_pool.close()
|
||
if self.neo4j_driver:
|
||
await self.neo4j_driver.close()
|
||
if self.http_client:
|
||
await self.http_client.aclose()
|
||
|
||
# =========================================================================
|
||
# L2: Platform Identity Resolution
|
||
# =========================================================================
|
||
|
||
async def resolve_identity(
|
||
self,
|
||
channel: str,
|
||
channel_user_id: str,
|
||
username: Optional[str] = None,
|
||
display_name: Optional[str] = None
|
||
) -> UserIdentity:
|
||
"""Resolve or create platform user identity"""
|
||
identity = UserIdentity(
|
||
channel=channel,
|
||
channel_user_id=channel_user_id,
|
||
username=username,
|
||
display_name=display_name
|
||
)
|
||
|
||
if not self.pg_pool:
|
||
return identity
|
||
|
||
try:
|
||
async with self.pg_pool.acquire() as conn:
|
||
# Use the resolve_platform_user function
|
||
platform_user_id = await conn.fetchval(
|
||
"SELECT resolve_platform_user($1, $2, $3, $4)",
|
||
channel, channel_user_id, username, display_name
|
||
)
|
||
identity.platform_user_id = str(platform_user_id) if platform_user_id else None
|
||
|
||
# Get roles
|
||
if platform_user_id:
|
||
roles = await conn.fetch("""
|
||
SELECT r.code FROM user_roles ur
|
||
JOIN platform_roles r ON ur.role_id = r.id
|
||
WHERE ur.platform_user_id = $1 AND ur.revoked_at IS NULL
|
||
""", platform_user_id)
|
||
identity.roles = [r['code'] for r in roles]
|
||
|
||
# Check if mentor
|
||
is_mentor = await conn.fetchval(
|
||
"SELECT is_mentor($1, $2)",
|
||
channel_user_id, username
|
||
)
|
||
identity.is_mentor = bool(is_mentor)
|
||
|
||
except Exception as e:
|
||
logger.warning(f"Identity resolution failed: {e}")
|
||
|
||
return identity
|
||
|
||
# =========================================================================
|
||
# L1: Session State
|
||
# =========================================================================
|
||
|
||
    async def get_session_state(
        self,
        channel: str,
        chat_id: str,
        thread_id: Optional[str] = None
    ) -> SessionState:
        """Load (or lazily create) the L1 session state for a conversation.

        Resolves the conversation via the server-side
        get_or_create_conversation() function, maps the
        helion_conversation_state row onto a SessionState (inserting an empty
        row when none exists), then applies the is_trusted_group() check.
        Returns a default SessionState when Postgres is unavailable or any
        step fails.
        """
        state = SessionState()

        if not self.pg_pool:
            # No Postgres: every flag keeps its dataclass default.
            return state

        try:
            async with self.pg_pool.acquire() as conn:
                # Get or create conversation
                # (fourth argument is passed as NULL — confirm its meaning
                # against the SQL function definition).
                conv_id = await conn.fetchval(
                    "SELECT get_or_create_conversation($1, $2, $3, NULL)",
                    channel, chat_id, thread_id
                )
                state.conversation_id = str(conv_id) if conv_id else None

                # Get conversation state
                if conv_id:
                    row = await conn.fetchrow("""
                        SELECT * FROM helion_conversation_state
                        WHERE conversation_id = $1
                    """, conv_id)

                    if row:
                        # Map DB columns onto the in-memory dataclass fields.
                        state.last_addressed = row.get('last_addressed_to_helion', False)
                        state.active_topic = row.get('active_topic_id')
                        state.context_open = row.get('active_context_open', False)
                        state.last_media_handled = row.get('last_media_handled', True)
                        state.last_answer_fingerprint = row.get('last_answer_fingerprint')
                        state.trust_mode = row.get('group_trust_mode', False)
                        state.apprentice_mode = row.get('apprentice_mode', False)
                    else:
                        # First interaction in this conversation: create the
                        # row (idempotent via ON CONFLICT DO NOTHING).
                        await conn.execute("""
                            INSERT INTO helion_conversation_state (conversation_id)
                            VALUES ($1)
                            ON CONFLICT (conversation_id) DO NOTHING
                        """, conv_id)

                # Check if trusted group — runs after the row mapping, so the
                # DB-side check overrides any stored group_trust_mode value.
                is_trusted = await conn.fetchval(
                    "SELECT is_trusted_group($1, $2)",
                    channel, chat_id
                )
                state.trust_mode = bool(is_trusted)

        except Exception as e:
            logger.warning(f"Session state retrieval failed: {e}")

        return state
|
||
|
||
# =========================================================================
|
||
# L3: Memory Retrieval (Facts + Semantic)
|
||
# =========================================================================
|
||
|
||
async def get_user_facts(
|
||
self,
|
||
platform_user_id: Optional[str],
|
||
limit: int = 5
|
||
) -> List[Dict[str, Any]]:
|
||
"""Get user's explicit facts from PostgreSQL"""
|
||
if not self.pg_pool or not platform_user_id:
|
||
return []
|
||
|
||
try:
|
||
async with self.pg_pool.acquire() as conn:
|
||
rows = await conn.fetch("""
|
||
SELECT type, text, summary, confidence, visibility
|
||
FROM helion_memory_items
|
||
WHERE platform_user_id = $1
|
||
AND archived_at IS NULL
|
||
AND (expires_at IS NULL OR expires_at > NOW())
|
||
AND type IN ('preference', 'profile_fact')
|
||
ORDER BY confidence DESC, updated_at DESC
|
||
LIMIT $2
|
||
""", platform_user_id, limit)
|
||
|
||
return [dict(r) for r in rows]
|
||
except Exception as e:
|
||
logger.warning(f"User facts retrieval failed: {e}")
|
||
return []
|
||
|
||
async def get_embedding(self, text: str) -> Optional[List[float]]:
|
||
"""Get embedding from Cohere API"""
|
||
if not COHERE_API_KEY or not self.http_client:
|
||
return None
|
||
|
||
try:
|
||
response = await self.http_client.post(
|
||
"https://api.cohere.ai/v1/embed",
|
||
headers={
|
||
"Authorization": f"Bearer {COHERE_API_KEY}",
|
||
"Content-Type": "application/json"
|
||
},
|
||
json={
|
||
"texts": [text],
|
||
"model": "embed-multilingual-v3.0",
|
||
"input_type": "search_query",
|
||
"truncate": "END"
|
||
}
|
||
)
|
||
if response.status_code == 200:
|
||
data = response.json()
|
||
return data.get("embeddings", [[]])[0]
|
||
except Exception as e:
|
||
logger.warning(f"Embedding generation failed: {e}")
|
||
|
||
return None
|
||
|
||
    async def search_memories(
        self,
        query: str,
        agent_id: str = "helion",
        platform_user_id: Optional[str] = None,
        chat_id: Optional[str] = None,
        user_id: Optional[str] = None,
        visibility: str = "platform",
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """Semantic search across the agent-specific Qdrant collections.

        Queries three collections derived from ``agent_id``:
        ``{agent_id}_memory_items`` (facts/preferences, optionally filtered by
        ``platform_user_id``), ``{agent_id}_messages`` (chat history,
        optionally filtered by chat), and ``{agent_id}_docs`` (knowledge
        base). Hits are threshold-filtered per source, merged, sorted by
        score, deduplicated on normalized text, and capped at ``limit``.

        Note: ``user_id`` and ``visibility`` are accepted for interface
        compatibility but are not referenced by the current implementation.

        Returns:
            List of dicts with keys text, type, score, source (memory_items
            entries additionally carry confidence).
        """
        if not self.qdrant_client:
            return []

        # Get embedding for the query; without one no search is possible.
        embedding = await self.get_embedding(query)
        if not embedding:
            return []

        all_results = []

        q = (query or "").lower()
        # If user explicitly asks about documents/catalogs, prefer knowledge base docs over chat snippets.
        is_doc_query = any(k in q for k in ["pdf", "каталог", "каталоз", "документ", "файл", "стор", "page", "pages"])
        # Simple keyword gate to avoid irrelevant chat snippets dominating doc queries.
        # Example: when asking "з каталогу Defenda 2026 ... гліфосат", old "Бокаші" messages may match too well.
        topic_keywords: List[str] = []
        for kw in ["defenda", "ifagri", "bayer", "гліфосат", "glyphos", "глифос", "npk", "мінерал", "добрив", "гербіц", "фунгіц", "інсектиц"]:
            if kw in q:
                topic_keywords.append(kw)

        # Dynamic collection names based on agent_id
        memory_items_collection = f"{agent_id}_memory_items"
        messages_collection = f"{agent_id}_messages"
        docs_collection = f"{agent_id}_docs"

        try:
            # Imported lazily so the module loads without qdrant-client installed.
            from qdrant_client.http import models as qmodels

            # Search 1: {agent_id}_memory_items (facts, preferences)
            try:
                must_conditions = []
                if platform_user_id:
                    # Restrict facts to the requesting user.
                    must_conditions.append(
                        qmodels.FieldCondition(
                            key="platform_user_id",
                            match=qmodels.MatchValue(value=platform_user_id)
                        )
                    )

                search_filter = qmodels.Filter(must=must_conditions) if must_conditions else None

                results = self.qdrant_client.search(
                    collection_name=memory_items_collection,
                    query_vector=embedding,
                    query_filter=search_filter,
                    limit=limit,
                    with_payload=True
                )

                for r in results:
                    if r.score > 0.3:  # Threshold for relevance
                        all_results.append({
                            "text": r.payload.get("text", ""),
                            "type": r.payload.get("type", "fact"),
                            "confidence": r.payload.get("confidence", 0.5),
                            "score": r.score,
                            "source": "memory_items"
                        })
            except Exception as e:
                # Per-collection failures (e.g. collection missing) are non-fatal.
                logger.debug(f"{memory_items_collection} search: {e}")

            # Search 2: {agent_id}_messages (chat history)
            try:
                msg_filter = None
                if chat_id:
                    # Payload schema differs across ingesters: some use chat_id, others channel_id.
                    msg_filter = qmodels.Filter(
                        should=[
                            qmodels.FieldCondition(key="chat_id", match=qmodels.MatchValue(value=str(chat_id))),
                            qmodels.FieldCondition(key="channel_id", match=qmodels.MatchValue(value=str(chat_id))),
                        ]
                    )
                results = self.qdrant_client.search(
                    collection_name=messages_collection,
                    query_vector=embedding,
                    query_filter=msg_filter,
                    limit=limit,
                    with_payload=True
                )

                for r in results:
                    # Higher threshold for messages; even higher when user asks about docs to avoid pulling old chatter.
                    msg_thresh = 0.5 if is_doc_query else 0.4
                    if r.score > msg_thresh:
                        text = self._extract_message_text(r.payload)
                        # Skip very short or system messages
                        if len(text) > 20 and not text.startswith("<"):
                            if is_doc_query and topic_keywords:
                                # Keyword gate: drop chat hits unrelated to the doc topic.
                                tl = text.lower()
                                if not any(k in tl for k in topic_keywords):
                                    continue
                            all_results.append({
                                "text": text,
                                "type": "message",
                                "score": r.score,
                                "source": "messages"
                            })
            except Exception as e:
                logger.debug(f"{messages_collection} search: {e}")

            # Search 3: {agent_id}_docs (knowledge base) - optional
            try:
                results = self.qdrant_client.search(
                    collection_name=docs_collection,
                    query_vector=embedding,
                    limit=6 if is_doc_query else 3,  # Pull more docs for explicit doc queries
                    with_payload=True
                )

                for r in results:
                    # When user asks about PDF/catalogs, relax threshold so docs show up more reliably.
                    doc_thresh = 0.35 if is_doc_query else 0.5
                    if r.score > doc_thresh:
                        text = r.payload.get("text", r.payload.get("content", ""))
                        if len(text) > 30:
                            all_results.append({
                                "text": text[:500],  # Truncate long docs
                                "type": "knowledge",
                                # Slightly boost docs for doc queries so they win vs chat snippets.
                                "score": (r.score + 0.12) if is_doc_query else r.score,
                                "source": "docs"
                            })
            except Exception as e:
                logger.debug(f"{docs_collection} search: {e}")

            # Sort by score and deduplicate
            all_results.sort(key=lambda x: x.get("score", 0), reverse=True)

            # Remove duplicates based on text similarity (normalized-prefix key)
            seen_texts = set()
            unique_results = []
            for r in all_results:
                text_key = self._canonical_text_key(r.get("text", ""))
                if text_key not in seen_texts:
                    seen_texts.add(text_key)
                    unique_results.append(r)

            return unique_results[:limit]

        except Exception as e:
            logger.warning(f"Memory search failed for {agent_id}: {e}")
            return []
|
||
|
||
@staticmethod
|
||
def _extract_message_text(payload: Dict[str, Any]) -> str:
|
||
"""
|
||
Normalize text across both payload schemas:
|
||
- memory-service: content/text (+ role/channel_id)
|
||
- router: user_message + assistant_response (+ chat_id)
|
||
"""
|
||
if not payload:
|
||
return ""
|
||
|
||
text = (payload.get("text") or payload.get("content") or "").strip()
|
||
if text:
|
||
lower = text.lower()
|
||
marker = "\n\nassistant:"
|
||
idx = lower.rfind(marker)
|
||
if lower.startswith("user:") and idx != -1:
|
||
assistant_text = text[idx + len(marker):].strip()
|
||
if assistant_text:
|
||
return assistant_text
|
||
return text
|
||
|
||
user_text = (payload.get("user_message") or "").strip()
|
||
assistant_text = (payload.get("assistant_response") or "").strip()
|
||
if user_text and assistant_text:
|
||
return f"User: {user_text}\n\nAssistant: {assistant_text}"
|
||
return user_text or assistant_text
|
||
|
||
@staticmethod
|
||
def _canonical_text_key(text: str) -> str:
|
||
if not text:
|
||
return ""
|
||
normalized = re.sub(r"\s+", " ", text.strip().lower())
|
||
return normalized[:220]
|
||
|
||
    async def get_user_graph_context(
        self,
        username: Optional[str] = None,
        telegram_user_id: Optional[str] = None,
        agent_id: str = "helion"
    ) -> Dict[str, List[str]]:
        """Fetch the user's topics and projects from the Neo4j knowledge graph.

        Matches the User node by username (with or without a leading "@") or
        by telegram id, then collects Topic names via ASKED_ABOUT and Project
        names via WORKS_ON. Relationships are filtered by ``agent_id``;
        relationships with a NULL agent_id are kept (legacy data created
        before per-agent tagging, presumably — verify against the ingesters).

        Returns:
            {"topics": [...], "projects": [...]} — both empty when Neo4j is
            unavailable, no identifier was given, or the query fails.
        """
        context = {"topics": [], "projects": []}

        if not self.neo4j_driver:
            return context

        if not username and not telegram_user_id:
            # Nothing to match on.
            return context

        try:
            async with self.neo4j_driver.session() as session:
                # Find user and their context, FILTERED BY AGENT_ID
                # Support both username formats: with and without @
                username_with_at = f"@{username}" if username and not username.startswith("@") else username
                username_without_at = username[1:] if username and username.startswith("@") else username

                # Query with agent_id filter on relationships
                query = """
                MATCH (u:User)
                WHERE u.username IN [$username, $username_with_at, $username_without_at]
                   OR u.telegram_user_id = $telegram_user_id
                   OR u.telegram_id = $telegram_user_id
                OPTIONAL MATCH (u)-[r1:ASKED_ABOUT]->(t:Topic)
                WHERE r1.agent_id = $agent_id OR r1.agent_id IS NULL
                OPTIONAL MATCH (u)-[r2:WORKS_ON]->(p:Project)
                WHERE r2.agent_id = $agent_id OR r2.agent_id IS NULL
                RETURN
                    collect(DISTINCT t.name) as topics,
                    collect(DISTINCT p.name) as projects
                """
                result = await session.run(
                    query,
                    username=username,
                    username_with_at=username_with_at,
                    username_without_at=username_without_at,
                    telegram_user_id=telegram_user_id,
                    agent_id=agent_id
                )
                record = await result.single()

                if record:
                    # Drop NULL names produced by the OPTIONAL MATCHes.
                    context["topics"] = [t for t in record["topics"] if t]
                    context["projects"] = [p for p in record["projects"] if p]

                logger.debug(f"Graph context for {agent_id}: topics={len(context['topics'])}, projects={len(context['projects'])}")

        except Exception as e:
            logger.warning(f"Graph context retrieval failed for {agent_id}: {e}")

        return context
|
||
|
||
# =========================================================================
|
||
# Main Retrieval Pipeline
|
||
# =========================================================================
|
||
|
||
    async def retrieve(
        self,
        channel: str,
        chat_id: str,
        user_id: str,
        agent_id: str = "helion",
        username: Optional[str] = None,
        display_name: Optional[str] = None,
        message: Optional[str] = None,
        thread_id: Optional[str] = None
    ) -> MemoryBrief:
        """
        Main retrieval pipeline for any agent.

        1. Resolve user identity (L2)
        2. Get session state (L1)
        3. Get user facts (L3)
        4. Search relevant memories if message provided (L3)
        5. Get graph context (L3)
        6. Compile memory brief

        Args:
            channel: Source channel (e.g. "telegram").
            chat_id: Channel-level chat/group identifier.
            user_id: Channel-level user identifier.
            agent_id: Agent identifier for collection routing (e.g. "helion", "nutra", "greenfood")
            username: Optional channel username, used for identity and graph lookups.
            display_name: Optional human-readable name.
            message: Current message text; when given, triggers semantic search.
            thread_id: Optional thread within the chat.

        Returns:
            A MemoryBrief; every layer is best-effort, so missing backends
            simply leave their section of the brief empty.
        """
        brief = MemoryBrief()

        # L2: Identity (also yields platform_user_id used by later steps)
        identity = await self.resolve_identity(channel, user_id, username, display_name)
        brief.user_identity = identity

        # L1: Session State
        session = await self.get_session_state(channel, chat_id, thread_id)
        brief.session_state = session
        brief.is_trusted_group = session.trust_mode

        # L3: User Facts (only possible once the platform id is known)
        if identity.platform_user_id:
            facts = await self.get_user_facts(identity.platform_user_id)
            brief.user_facts = facts

        # L3: Semantic Search (if message provided) - agent-specific collections
        if message:
            memories = await self.search_memories(
                query=message,
                agent_id=agent_id,
                platform_user_id=identity.platform_user_id,
                chat_id=chat_id,
                user_id=user_id,
                limit=5
            )
            brief.relevant_memories = memories

        # L3: Graph Context (filtered by agent_id to prevent role mixing)
        graph_ctx = await self.get_user_graph_context(username, user_id, agent_id)
        brief.user_topics = graph_ctx.get("topics", [])
        brief.user_projects = graph_ctx.get("projects", [])

        # Check for mentor presence (mirrors the resolved identity flag)
        brief.mentor_present = identity.is_mentor

        return brief
|
||
|
||
# =========================================================================
|
||
# Memory Storage (write path)
|
||
# =========================================================================
|
||
|
||
    async def store_message(
        self,
        agent_id: str,
        user_id: str,
        username: Optional[str],
        message_text: str,
        response_text: str,
        chat_id: str,
        message_type: str = "conversation",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Store a message exchange in agent-specific Qdrant collection.

        The user message and the assistant response are embedded together
        (as one "User: ...\\n\\nAssistant: ..." text) so later semantic
        retrieval of past conversations works per agent. The collection is
        created on first use with 1024-dim cosine vectors (matching Cohere
        embed-multilingual-v3.0).

        Returns:
            True on success; False when Qdrant / the Cohere key is missing,
            embedding fails, or the upsert raises (logged, never raised).
        """
        if not self.qdrant_client or not COHERE_API_KEY:
            logger.debug(f"Cannot store message: qdrant={bool(self.qdrant_client)}, cohere={bool(COHERE_API_KEY)}")
            return False

        messages_collection = f"{agent_id}_messages"

        try:
            from qdrant_client.http import models as qmodels
            import uuid

            # Ensure collection exists (get_collection raises when it doesn't)
            try:
                self.qdrant_client.get_collection(messages_collection)
            except Exception:
                # Create collection with Cohere embed-multilingual-v3.0 dimensions (1024)
                self.qdrant_client.create_collection(
                    collection_name=messages_collection,
                    vectors_config=qmodels.VectorParams(
                        size=1024,
                        distance=qmodels.Distance.COSINE
                    )
                )
                logger.info(f"✅ Created collection: {messages_collection}")

            # Combine user message and response for better context retrieval
            combined_text = f"User: {message_text}\n\nAssistant: {response_text}"

            # Get embedding
            embedding = await self.get_embedding(combined_text[:2000])  # Truncate for API limits
            if not embedding:
                logger.warning(f"Failed to get embedding for message storage")
                return False

            # Store in Qdrant — payload fields are size-capped to keep points small.
            point_id = str(uuid.uuid4())
            payload = {
                "text": combined_text[:5000],  # Limit payload size
                "user_message": message_text[:2000],
                "assistant_response": response_text[:3000],
                "user_id": user_id,
                "username": username,
                "chat_id": chat_id,
                "agent_id": agent_id,
                "type": message_type,
                "timestamp": datetime.utcnow().isoformat()
            }
            if metadata and isinstance(metadata, dict):
                # Optional caller-supplied extras are nested, not flattened.
                payload["metadata"] = metadata

            self.qdrant_client.upsert(
                collection_name=messages_collection,
                points=[
                    qmodels.PointStruct(
                        id=point_id,
                        vector=embedding,
                        payload=payload
                    )
                ]
            )

            logger.debug(f"✅ Stored message in {messages_collection}: {point_id[:8]}...")
            return True

        except Exception as e:
            logger.warning(f"Failed to store message in {messages_collection}: {e}")
            return False
|
||
|
||
async def update_session_state(
|
||
self,
|
||
conversation_id: str,
|
||
**updates
|
||
):
|
||
"""Update session state after interaction"""
|
||
if not self.pg_pool or not conversation_id:
|
||
return
|
||
|
||
try:
|
||
async with self.pg_pool.acquire() as conn:
|
||
# Build dynamic update
|
||
set_clauses = ["updated_at = NOW()"]
|
||
params = [conversation_id]
|
||
param_idx = 2
|
||
|
||
allowed_fields = [
|
||
'last_addressed_to_helion', 'last_user_id', 'last_user_nick',
|
||
'active_topic_id', 'active_context_open', 'last_media_id',
|
||
'last_media_handled', 'last_answer_fingerprint', 'group_trust_mode',
|
||
'apprentice_mode', 'proactive_questions_today'
|
||
]
|
||
|
||
for field, value in updates.items():
|
||
if field in allowed_fields:
|
||
set_clauses.append(f"{field} = ${param_idx}")
|
||
params.append(value)
|
||
param_idx += 1
|
||
|
||
query = f"""
|
||
UPDATE helion_conversation_state
|
||
SET {', '.join(set_clauses)}
|
||
WHERE conversation_id = $1
|
||
"""
|
||
await conn.execute(query, *params)
|
||
|
||
except Exception as e:
|
||
logger.warning(f"Session state update failed: {e}")
|
||
|
||
|
||
# Global instance
# Shared process-wide singleton. Construction opens no connections —
# initialize() must be awaited before use (presumably at service startup;
# verify at the import sites).
memory_retrieval = MemoryRetrieval()
|