1167 lines
47 KiB
Python
1167 lines
47 KiB
Python
"""
|
||
Universal Memory Retrieval Pipeline v4.0
|
||
|
||
This module implements the 4-layer memory retrieval for ALL agents:
|
||
- L0: Working Context (from request)
|
||
- L1: Session State Memory (SSM) - from Postgres
|
||
- L2: Platform Identity & Roles (PIR) - from Postgres + Neo4j
|
||
- L3: Organizational Memory (OM) - from Postgres + Neo4j + Qdrant
|
||
|
||
The pipeline generates a "memory brief" that is injected into the LLM context.
|
||
|
||
Collections per agent:
|
||
- {agent_id}_messages - chat history with embeddings
|
||
- {agent_id}_memory_items - facts, preferences
|
||
- {agent_id}_docs - knowledge base documents
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import logging
|
||
import re
|
||
from typing import Optional, Dict, Any, List
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
import hashlib
|
||
|
||
import httpx
|
||
import asyncpg
|
||
|
||
# Module-level logger, named after the package path.
logger = logging.getLogger(__name__)

# Configuration — every value is environment-driven with a local fallback.
# NOTE(review): the DATABASE_URL fallback embeds literal credentials; confirm
# this default is dev-only and never reachable in production.
POSTGRES_URL = os.getenv("DATABASE_URL", "postgresql://daarion:DaarionDB2026!@dagi-postgres:5432/daarion_memory")
# Qdrant vector store location (host + REST port).
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
# Cohere key for embedding generation; empty string disables embeddings.
COHERE_API_KEY = os.getenv("COHERE_API_KEY", "")
# Neo4j knowledge-graph connection settings.
NEO4J_BOLT_URL = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4j")
# Max number of open questions surfaced per (agent, chat, user) scope.
PENDING_QUESTIONS_LIMIT = int(os.getenv("AGENT_PENDING_QUESTIONS_LIMIT", "5"))
# Feature flags for the shared (cross-chat, anonymized) agronomy library.
SHARED_AGRO_LIBRARY_ENABLED = os.getenv("AGROMATRIX_SHARED_LIBRARY_ENABLED", "true").lower() == "true"
SHARED_AGRO_LIBRARY_REQUIRE_REVIEW = os.getenv("AGROMATRIX_SHARED_LIBRARY_REQUIRE_REVIEW", "true").lower() == "true"
|
||
|
||
|
||
@dataclass
class UserIdentity:
    """Platform user identity (L2 layer) resolved from a channel-level user."""
    platform_user_id: Optional[str] = None  # canonical platform id; None when unresolved
    channel: str = "telegram"               # origin channel of the message
    channel_user_id: str = ""               # user id as known to the channel
    username: Optional[str] = None          # channel handle, may be None
    display_name: Optional[str] = None      # human-readable name, may be None
    roles: List[str] = field(default_factory=list)  # platform role codes
    is_mentor: bool = False                 # set via the is_mentor() SQL function
    first_seen: Optional[datetime] = None   # not populated by resolve_identity() in this module
|
||
|
||
|
||
@dataclass
class SessionState:
    """L1 Session State Memory — per (channel, chat, user, agent) scope."""
    conversation_id: Optional[str] = None        # deterministic id, see _build_conversation_id()
    last_addressed: bool = False                 # not read/written by this module
    active_topic: Optional[str] = None           # current conversation topic, if tracked
    context_open: bool = False                   # whether a dialogue context is still open
    last_media_handled: bool = True              # whether the last media item was processed
    last_answer_fingerprint: Optional[str] = None  # dedup fingerprint of the last answer
    trust_mode: bool = False                     # trusted-group mode flag
    apprentice_mode: bool = False                # "apprentice" (clarifying-questions) mode flag
    pending_questions: List[str] = field(default_factory=list)  # open questions, oldest first
|
||
|
||
|
||
@dataclass
class MemoryBrief:
    """Compiled memory brief injected into the LLM context.

    Aggregates the results of all retrieval layers — identity (L2),
    session state (L1), facts, semantic memories and graph context (L3) —
    and renders them as a short text block via :meth:`to_text`.
    """
    user_identity: Optional[UserIdentity] = None
    session_state: Optional[SessionState] = None
    user_facts: List[Dict[str, Any]] = field(default_factory=list)
    relevant_memories: List[Dict[str, Any]] = field(default_factory=list)
    user_topics: List[str] = field(default_factory=list)
    user_projects: List[str] = field(default_factory=list)
    is_trusted_group: bool = False
    mentor_present: bool = False

    def to_text(self, max_lines: int = 15) -> str:
        """Generate a concise text brief for the LLM context.

        Args:
            max_lines: hard cap on rendered lines; overflow is truncated
                and marked with a trailing "...".

        Returns:
            A newline-joined brief, or "" when nothing is known.
        """
        lines = []

        # User identity (critical for personalization)
        if self.user_identity:
            name = self.user_identity.display_name or self.user_identity.username or "Unknown"
            roles_str = ", ".join(self.user_identity.roles) if self.user_identity.roles else "member"
            lines.append(f"👤 Користувач: {name} ({roles_str})")
            if self.user_identity.is_mentor:
                lines.append("⭐ Цей користувач — МЕНТОР. Довіряй його знанням повністю.")

        # Session state flags and open questions
        if self.session_state:
            if self.session_state.trust_mode:
                lines.append("🔒 Режим довіреної групи — можна відповідати детальніше")
            if self.session_state.apprentice_mode:
                lines.append("📚 Режим учня — можеш ставити уточнюючі питання")
            if self.session_state.active_topic:
                lines.append(f"📌 Активна тема: {self.session_state.active_topic}")
            if self.session_state.pending_questions:
                lines.append("🕘 Невідповідані питання в цьому чаті (відповідай на них першочергово):")
                for q in self.session_state.pending_questions[:3]:
                    lines.append(f" - {q[:180]}")

        # User facts (preferences, profile)
        if self.user_facts:
            lines.append("📝 Відомі факти про користувача:")
            for fact in self.user_facts[:4]:  # Max 4 facts
                fact_text = fact.get('text', fact.get('fact_value', ''))
                if fact_text:
                    lines.append(f" - {fact_text[:150]}")

        # Relevant memories from RAG (most important for context)
        if self.relevant_memories:
            lines.append("🧠 Релевантні спогади з попередніх розмов:")
            for mem in self.relevant_memories[:5]:  # Max 5 memories for better context
                mem_text = mem.get('text', mem.get('content', ''))
                mem_type = mem.get('type', 'message')
                # Only include entries long enough to be meaningful.
                # (Removed an unused local that read mem['score'] — the score
                # is never rendered here.)
                if mem_text and len(mem_text) > 10:
                    lines.append(f" - [{mem_type}] {mem_text[:200]}")

        # Topics/Projects from Knowledge Graph
        if self.user_topics:
            lines.append(f"💡 Інтереси користувача: {', '.join(self.user_topics[:5])}")
        if self.user_projects:
            lines.append(f"🏗️ Проєкти користувача: {', '.join(self.user_projects[:3])}")

        # Mentor presence indicator
        if self.mentor_present:
            lines.append("⚠️ ВАЖЛИВО: Ти спілкуєшся з ментором. Сприймай як навчання.")

        # Truncate if needed
        if len(lines) > max_lines:
            lines = lines[:max_lines]
            lines.append("...")

        return "\n".join(lines) if lines else ""
|
||
|
||
|
||
class MemoryRetrieval:
|
||
"""Memory Retrieval Pipeline"""
|
||
|
||
def __init__(self):
|
||
self.pg_pool: Optional[asyncpg.Pool] = None
|
||
self.neo4j_driver = None
|
||
self.qdrant_client = None
|
||
self.http_client: Optional[httpx.AsyncClient] = None
|
||
|
||
async def initialize(self):
    """Open connections to PostgreSQL, Neo4j and Qdrant, plus an HTTP client.

    Every backend is optional: a connection failure is logged as a warning
    and the corresponding handle stays None, so retrieval methods degrade
    gracefully instead of crashing.
    """
    # PostgreSQL — primary store for identity, session state and facts.
    try:
        self.pg_pool = await asyncpg.create_pool(
            POSTGRES_URL,
            min_size=2,
            max_size=10
        )
        logger.info("✅ Memory Retrieval: PostgreSQL connected")
    except Exception as e:
        logger.warning(f"⚠️ Memory Retrieval: PostgreSQL not available: {e}")

    # Neo4j — knowledge graph (topics / projects). Imported lazily so the
    # dependency is only needed when this backend is actually used.
    try:
        from neo4j import AsyncGraphDatabase
        self.neo4j_driver = AsyncGraphDatabase.driver(
            NEO4J_BOLT_URL,
            auth=(NEO4J_USER, NEO4J_PASSWORD)
        )
        await self.neo4j_driver.verify_connectivity()
        logger.info("✅ Memory Retrieval: Neo4j connected")
    except Exception as e:
        logger.warning(f"⚠️ Memory Retrieval: Neo4j not available: {e}")

    # Qdrant — vector store for semantic search.
    # NOTE(review): this is the synchronous client; its calls block the
    # event loop when used from async methods — confirm this is acceptable.
    try:
        from qdrant_client import QdrantClient
        self.qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
        self.qdrant_client.get_collections()  # Test connection
        logger.info("✅ Memory Retrieval: Qdrant connected")
    except Exception as e:
        logger.warning(f"⚠️ Memory Retrieval: Qdrant not available: {e}")

    # HTTP client for embeddings (always created, even if backends failed).
    self.http_client = httpx.AsyncClient(timeout=30.0)
    # Best-effort creation of auxiliary tables (no-op without Postgres).
    await self._ensure_aux_tables()
|
||
|
||
async def close(self):
|
||
"""Close connections"""
|
||
if self.pg_pool:
|
||
await self.pg_pool.close()
|
||
if self.neo4j_driver:
|
||
await self.neo4j_driver.close()
|
||
if self.http_client:
|
||
await self.http_client.aclose()
|
||
|
||
async def _ensure_aux_tables(self):
    """Create auxiliary tables used by agent runtime policies.

    Idempotent (IF NOT EXISTS throughout); a no-op when PostgreSQL is
    unavailable, and any DDL failure is logged but not raised.
    """
    if not self.pg_pool:
        return
    try:
        async with self.pg_pool.acquire() as conn:
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS agent_session_state (
                    channel TEXT NOT NULL,
                    chat_id TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    agent_id TEXT NOT NULL,
                    conversation_id TEXT NOT NULL,
                    last_user_id TEXT,
                    last_user_nick TEXT,
                    active_topic TEXT,
                    context_open BOOLEAN NOT NULL DEFAULT FALSE,
                    last_media_handled BOOLEAN NOT NULL DEFAULT TRUE,
                    last_answer_fingerprint TEXT,
                    trust_mode BOOLEAN NOT NULL DEFAULT FALSE,
                    apprentice_mode BOOLEAN NOT NULL DEFAULT FALSE,
                    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    PRIMARY KEY (channel, chat_id, user_id, agent_id)
                );
                CREATE INDEX IF NOT EXISTS idx_agent_session_state_conv
                    ON agent_session_state (conversation_id);

                CREATE TABLE IF NOT EXISTS agent_pending_questions (
                    id BIGSERIAL PRIMARY KEY,
                    channel TEXT NOT NULL,
                    chat_id TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    agent_id TEXT NOT NULL,
                    question_text TEXT NOT NULL,
                    question_fingerprint TEXT NOT NULL,
                    status TEXT NOT NULL DEFAULT 'pending',
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    answered_at TIMESTAMPTZ,
                    metadata JSONB NOT NULL DEFAULT '{}'::jsonb
                );
                CREATE INDEX IF NOT EXISTS idx_agent_pending_questions_scope
                    ON agent_pending_questions (agent_id, channel, chat_id, user_id, status, created_at DESC);
                -- Uniqueness on (scope, fingerprint, status) keeps at most one
                -- OPEN copy of a question per scope while still allowing an
                -- answered duplicate to exist.
                CREATE UNIQUE INDEX IF NOT EXISTS idx_agent_pending_questions_unique_open
                    ON agent_pending_questions (agent_id, channel, chat_id, user_id, question_fingerprint, status);
                """
            )
    except Exception as e:
        logger.warning(f"Aux tables init failed: {e}")
|
||
|
||
# =========================================================================
|
||
# L2: Platform Identity Resolution
|
||
# =========================================================================
|
||
|
||
async def resolve_identity(
    self,
    channel: str,
    channel_user_id: str,
    username: Optional[str] = None,
    display_name: Optional[str] = None
) -> UserIdentity:
    """Resolve or create a platform user identity (L2 layer).

    Delegates to the SQL functions ``resolve_platform_user`` and
    ``is_mentor`` and loads active (non-revoked) role codes. Always
    returns a UserIdentity; on any DB failure the bare channel-level
    identity is returned and the error is logged at debug level.
    """
    identity = UserIdentity(
        channel=channel,
        channel_user_id=channel_user_id,
        username=username,
        display_name=display_name
    )

    # Without Postgres we can only echo back the channel-level identity.
    if not self.pg_pool:
        return identity

    try:
        async with self.pg_pool.acquire() as conn:
            # Use the resolve_platform_user function (defined in the DB).
            platform_user_id = await conn.fetchval(
                "SELECT resolve_platform_user($1, $2, $3, $4)",
                channel, channel_user_id, username, display_name
            )
            identity.platform_user_id = str(platform_user_id) if platform_user_id else None

            # Get active role codes for the resolved platform user.
            if platform_user_id:
                roles = await conn.fetch("""
                    SELECT r.code FROM user_roles ur
                    JOIN platform_roles r ON ur.role_id = r.id
                    WHERE ur.platform_user_id = $1 AND ur.revoked_at IS NULL
                """, platform_user_id)
                identity.roles = [r['code'] for r in roles]

            # Mentor check is keyed on channel-level id + username
            # (independent of whether platform resolution succeeded).
            is_mentor = await conn.fetchval(
                "SELECT is_mentor($1, $2)",
                channel_user_id, username
            )
            identity.is_mentor = bool(is_mentor)

    except Exception as e:
        # Best-effort: fall back to the unresolved identity.
        logger.debug(f"Identity resolution fallback: {e}")

    return identity
|
||
|
||
# =========================================================================
|
||
# L1: Session State
|
||
# =========================================================================
|
||
|
||
async def get_session_state(
    self,
    channel: str,
    chat_id: str,
    thread_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    user_id: Optional[str] = None,
) -> SessionState:
    """Get or create session state (L1) for a conversation scope.

    When both agent_id and user_id are given, the state row is read from
    (or lazily inserted into) ``agent_session_state`` and open questions
    are loaded from ``agent_pending_questions``. Otherwise only a
    deterministic conversation_id is synthesized.

    NOTE(review): thread_id is accepted but currently unused — confirm
    whether thread-level scoping is planned or the parameter can go.
    """
    state = SessionState()

    # Without Postgres, return the default (empty) state.
    if not self.pg_pool:
        return state

    try:
        async with self.pg_pool.acquire() as conn:
            if agent_id and user_id:
                # Deterministic id for the (channel, chat, user, agent) scope.
                conv_id = self._build_conversation_id(channel, chat_id, user_id, agent_id)
                row = await conn.fetchrow(
                    """
                    SELECT conversation_id, active_topic, context_open, last_media_handled,
                           last_answer_fingerprint, trust_mode, apprentice_mode
                    FROM agent_session_state
                    WHERE channel = $1
                      AND chat_id = $2
                      AND user_id = $3
                      AND agent_id = $4
                    """,
                    channel,
                    chat_id,
                    user_id,
                    agent_id,
                )
                if not row:
                    # First contact for this scope: create the row.
                    # ON CONFLICT DO NOTHING guards against a concurrent insert.
                    await conn.execute(
                        """
                        INSERT INTO agent_session_state
                            (channel, chat_id, user_id, agent_id, conversation_id)
                        VALUES ($1, $2, $3, $4, $5)
                        ON CONFLICT (channel, chat_id, user_id, agent_id) DO NOTHING
                        """,
                        channel,
                        chat_id,
                        user_id,
                        agent_id,
                        conv_id,
                    )
                    state.conversation_id = conv_id
                else:
                    # Populate state from the stored row; fall back to the
                    # computed conv_id if the stored one is empty.
                    state.conversation_id = str(row.get("conversation_id") or conv_id)
                    state.active_topic = row.get("active_topic")
                    state.context_open = bool(row.get("context_open", False))
                    state.last_media_handled = bool(row.get("last_media_handled", True))
                    state.last_answer_fingerprint = row.get("last_answer_fingerprint")
                    state.trust_mode = bool(row.get("trust_mode", False))
                    state.apprentice_mode = bool(row.get("apprentice_mode", False))
            else:
                # Incomplete scope: synthesize an id with placeholders.
                state.conversation_id = self._build_conversation_id(
                    channel,
                    chat_id,
                    user_id or "unknown",
                    agent_id or "agent",
                )

            # Load open (unanswered) questions for this scope, oldest first.
            if agent_id and user_id:
                pending_rows = await conn.fetch(
                    """
                    SELECT question_text
                    FROM agent_pending_questions
                    WHERE channel = $1
                      AND chat_id = $2
                      AND user_id = $3
                      AND agent_id = $4
                      AND status = 'pending'
                    ORDER BY created_at ASC
                    LIMIT $5
                    """,
                    channel,
                    chat_id,
                    user_id,
                    agent_id,
                    PENDING_QUESTIONS_LIMIT,
                )
                # Keep only non-blank question texts.
                state.pending_questions = [
                    str(r.get("question_text") or "").strip()
                    for r in pending_rows
                    if str(r.get("question_text") or "").strip()
                ]

    except Exception as e:
        logger.warning(f"Session state retrieval failed: {e}")

    return state
|
||
|
||
# =========================================================================
|
||
# L3: Memory Retrieval (Facts + Semantic)
|
||
# =========================================================================
|
||
|
||
async def get_user_facts(
|
||
self,
|
||
platform_user_id: Optional[str],
|
||
limit: int = 5
|
||
) -> List[Dict[str, Any]]:
|
||
"""Get user's explicit facts from PostgreSQL"""
|
||
if not self.pg_pool or not platform_user_id:
|
||
return []
|
||
|
||
try:
|
||
async with self.pg_pool.acquire() as conn:
|
||
rows = await conn.fetch("""
|
||
SELECT type, text, summary, confidence, visibility
|
||
FROM helion_memory_items
|
||
WHERE platform_user_id = $1
|
||
AND archived_at IS NULL
|
||
AND (expires_at IS NULL OR expires_at > NOW())
|
||
AND type IN ('preference', 'profile_fact')
|
||
ORDER BY confidence DESC, updated_at DESC
|
||
LIMIT $2
|
||
""", platform_user_id, limit)
|
||
|
||
return [dict(r) for r in rows]
|
||
except Exception as e:
|
||
logger.warning(f"User facts retrieval failed: {e}")
|
||
return []
|
||
|
||
async def get_embedding(self, text: str) -> Optional[List[float]]:
    """Embed *text* via the Cohere v1 embed API (search_query mode).

    Args:
        text: query text; truncation is delegated to the API ("END").

    Returns:
        The embedding vector, or None when the API key / HTTP client is
        missing, the request fails, the status is not 200, or the response
        carries no embeddings. (Previously an empty "embeddings" list
        raised IndexError, and a bare [] could be returned as a "vector".)
    """
    if not COHERE_API_KEY or not self.http_client:
        return None

    try:
        response = await self.http_client.post(
            "https://api.cohere.ai/v1/embed",
            headers={
                "Authorization": f"Bearer {COHERE_API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "texts": [text],
                "model": "embed-multilingual-v3.0",
                "input_type": "search_query",
                "truncate": "END"
            }
        )
        if response.status_code == 200:
            data = response.json()
            embeddings = data.get("embeddings") or []
            # Guard against a missing/empty embeddings list so callers
            # always get either a non-empty vector or None.
            if embeddings and embeddings[0]:
                return embeddings[0]
    except Exception as e:
        logger.warning(f"Embedding generation failed: {e}")

    return None
|
||
|
||
async def search_memories(
    self,
    query: str,
    agent_id: str = "helion",
    platform_user_id: Optional[str] = None,
    chat_id: Optional[str] = None,
    user_id: Optional[str] = None,
    visibility: str = "platform",
    limit: int = 5
) -> List[Dict[str, Any]]:
    """Semantic search in Qdrant across agent-specific collections.

    Searches up to four collections ({agent_id}_memory_items,
    {agent_id}_messages, {agent_id}_docs, and — for agromatrix plant
    queries — agromatrix_shared_library), merges the hits, sorts by
    score and deduplicates by normalized text.

    Returns at most *limit* dicts with keys: text, type, score, source
    (plus confidence for memory_items hits). Returns [] when Qdrant or
    the embedding is unavailable, or on an unexpected failure.

    NOTE(review): user_id and visibility are accepted but currently
    unused in this method — confirm whether filtering on them is planned.
    NOTE(review): self.qdrant_client.search is a synchronous call inside
    an async method, so it blocks the event loop while searching.
    """
    if not self.qdrant_client:
        return []

    # Embed the query once; without an embedding nothing can be searched.
    embedding = await self.get_embedding(query)
    if not embedding:
        return []

    all_results = []

    q = (query or "").lower()
    # If user explicitly asks about documents/catalogs, prefer knowledge base docs over chat snippets.
    is_doc_query = any(k in q for k in ["pdf", "каталог", "каталоз", "документ", "файл", "стор", "page", "pages"])
    # Simple keyword gate to avoid irrelevant chat snippets dominating doc queries.
    # Example: when asking "з каталогу Defenda 2026 ... гліфосат", old "Бокаші" messages may match too well.
    topic_keywords: List[str] = []
    for kw in ["defenda", "ifagri", "bayer", "гліфосат", "glyphos", "глифос", "npk", "мінерал", "добрив", "гербіц", "фунгіц", "інсектиц"]:
        if kw in q:
            topic_keywords.append(kw)

    # Dynamic collection names based on agent_id
    memory_items_collection = f"{agent_id}_memory_items"
    messages_collection = f"{agent_id}_messages"
    docs_collection = f"{agent_id}_docs"

    try:
        from qdrant_client.http import models as qmodels

        # Search 1: {agent_id}_memory_items (facts, preferences).
        # Filtered to this platform user when the id is known.
        try:
            must_conditions = []
            if platform_user_id:
                must_conditions.append(
                    qmodels.FieldCondition(
                        key="platform_user_id",
                        match=qmodels.MatchValue(value=platform_user_id)
                    )
                )

            search_filter = qmodels.Filter(must=must_conditions) if must_conditions else None

            results = self.qdrant_client.search(
                collection_name=memory_items_collection,
                query_vector=embedding,
                query_filter=search_filter,
                limit=limit,
                with_payload=True
            )

            for r in results:
                if r.score > 0.3:  # Threshold for relevance
                    all_results.append({
                        "text": r.payload.get("text", ""),
                        "type": r.payload.get("type", "fact"),
                        "confidence": r.payload.get("confidence", 0.5),
                        "score": r.score,
                        "source": "memory_items"
                    })
        except Exception as e:
            # Missing collection is an expected condition; keep it quiet.
            logger.debug(f"{memory_items_collection} search: {e}")

        # Search 2: {agent_id}_messages (chat history).
        try:
            msg_filter = None
            if chat_id:
                # Payload schema differs across ingesters: some use chat_id, others channel_id.
                msg_filter = qmodels.Filter(
                    should=[
                        qmodels.FieldCondition(key="chat_id", match=qmodels.MatchValue(value=str(chat_id))),
                        qmodels.FieldCondition(key="channel_id", match=qmodels.MatchValue(value=str(chat_id))),
                    ]
                )
            results = self.qdrant_client.search(
                collection_name=messages_collection,
                query_vector=embedding,
                query_filter=msg_filter,
                limit=limit,
                with_payload=True
            )

            for r in results:
                # Higher threshold for messages; even higher when user asks about docs to avoid pulling old chatter.
                msg_thresh = 0.5 if is_doc_query else 0.4
                if r.score > msg_thresh:
                    text = self._extract_message_text(r.payload)
                    # Skip very short or system messages
                    if len(text) > 20 and not text.startswith("<"):
                        # For doc queries, require at least one topical keyword
                        # so unrelated chat snippets don't crowd out docs.
                        if is_doc_query and topic_keywords:
                            tl = text.lower()
                            if not any(k in tl for k in topic_keywords):
                                continue
                        all_results.append({
                            "text": text,
                            "type": "message",
                            "score": r.score,
                            "source": "messages"
                        })
        except Exception as e:
            logger.debug(f"{messages_collection} search: {e}")

        # Search 3: {agent_id}_docs (knowledge base) - optional
        try:
            results = self.qdrant_client.search(
                collection_name=docs_collection,
                query_vector=embedding,
                limit=6 if is_doc_query else 3,  # Pull more docs for explicit doc queries
                with_payload=True
            )

            for r in results:
                # When user asks about PDF/catalogs, relax threshold so docs show up more reliably.
                doc_thresh = 0.35 if is_doc_query else 0.5
                if r.score > doc_thresh:
                    text = r.payload.get("text", r.payload.get("content", ""))
                    if len(text) > 30:
                        all_results.append({
                            "text": text[:500],  # Truncate long docs
                            "type": "knowledge",
                            # Slightly boost docs for doc queries so they win vs chat snippets.
                            "score": (r.score + 0.12) if is_doc_query else r.score,
                            "source": "docs"
                        })
        except Exception as e:
            logger.debug(f"{docs_collection} search: {e}")

        # Search 4: shared agronomy memory (reviewed, cross-chat, anonymized)
        if (
            SHARED_AGRO_LIBRARY_ENABLED
            and agent_id == "agromatrix"
            and self._is_plant_query(query)
        ):
            try:
                results = self.qdrant_client.search(
                    collection_name="agromatrix_shared_library",
                    query_vector=embedding,
                    limit=3,
                    with_payload=True
                )
                for r in results:
                    if r.score > 0.45:
                        text = str(r.payload.get("text") or "").strip()
                        if len(text) > 20:
                            all_results.append({
                                "text": text[:500],
                                "type": "shared_agro_fact",
                                # Small boost so shared cases rank just above
                                # equally-scored chat snippets.
                                "score": r.score + 0.05,
                                "source": "shared_agronomy_library"
                            })
            except Exception as e:
                logger.debug(f"agromatrix_shared_library search: {e}")

        # Sort by score and deduplicate
        all_results.sort(key=lambda x: x.get("score", 0), reverse=True)

        # Remove duplicates based on text similarity
        seen_texts = set()
        unique_results = []
        for r in all_results:
            text_key = self._canonical_text_key(r.get("text", ""))
            if text_key not in seen_texts:
                seen_texts.add(text_key)
                unique_results.append(r)

        return unique_results[:limit]

    except Exception as e:
        logger.warning(f"Memory search failed for {agent_id}: {e}")
        return []
|
||
|
||
@staticmethod
|
||
def _extract_message_text(payload: Dict[str, Any]) -> str:
|
||
"""
|
||
Normalize text across both payload schemas:
|
||
- memory-service: content/text (+ role/channel_id)
|
||
- router: user_message + assistant_response (+ chat_id)
|
||
"""
|
||
if not payload:
|
||
return ""
|
||
|
||
text = (payload.get("text") or payload.get("content") or "").strip()
|
||
if text:
|
||
lower = text.lower()
|
||
marker = "\n\nassistant:"
|
||
idx = lower.rfind(marker)
|
||
if lower.startswith("user:") and idx != -1:
|
||
assistant_text = text[idx + len(marker):].strip()
|
||
if assistant_text:
|
||
return assistant_text
|
||
return text
|
||
|
||
user_text = (payload.get("user_message") or "").strip()
|
||
assistant_text = (payload.get("assistant_response") or "").strip()
|
||
if user_text and assistant_text:
|
||
return f"User: {user_text}\n\nAssistant: {assistant_text}"
|
||
return user_text or assistant_text
|
||
|
||
@staticmethod
|
||
def _canonical_text_key(text: str) -> str:
|
||
if not text:
|
||
return ""
|
||
normalized = re.sub(r"\s+", " ", text.strip().lower())
|
||
return normalized[:220]
|
||
|
||
@staticmethod
|
||
def _is_plant_query(text: str) -> bool:
|
||
q = (text or "").lower()
|
||
if not q:
|
||
return False
|
||
markers = [
|
||
"рослин", "культур", "лист", "стебл", "бур'ян", "хвороб", "шкідник",
|
||
"what plant", "identify plant", "crop", "species", "leaf", "stem",
|
||
"что за растение", "культура", "листок", "фото рослини"
|
||
]
|
||
return any(m in q for m in markers)
|
||
|
||
@staticmethod
|
||
def _question_fingerprint(question_text: str) -> str:
|
||
normalized = re.sub(r"\s+", " ", (question_text or "").strip().lower())
|
||
return hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:16]
|
||
|
||
@staticmethod
|
||
def _build_conversation_id(channel: str, chat_id: str, user_id: str, agent_id: str) -> str:
|
||
seed = f"{channel}:{chat_id}:{user_id}:{agent_id}"
|
||
return hashlib.sha1(seed.encode("utf-8")).hexdigest()[:24]
|
||
|
||
async def get_user_graph_context(
    self,
    username: Optional[str] = None,
    telegram_user_id: Optional[str] = None,
    agent_id: str = "helion"
) -> Dict[str, List[str]]:
    """Get the user's topics and projects from Neo4j, filtered by agent_id.

    Matches the User node by username (with or without a leading "@") or
    by telegram id, then collects Topic names via ASKED_ABOUT and Project
    names via WORKS_ON where the relationship belongs to *agent_id* (or
    carries no agent_id at all, for legacy edges).

    Returns:
        {"topics": [...], "projects": [...]}; both lists are empty when
        Neo4j is unavailable, no identifier was given, or the query fails.
    """
    context = {"topics": [], "projects": []}

    if not self.neo4j_driver:
        return context

    # Need at least one identifier to match a User node.
    if not username and not telegram_user_id:
        return context

    try:
        async with self.neo4j_driver.session() as session:
            # Find user and their context, FILTERED BY AGENT_ID
            # Support both username formats: with and without @
            username_with_at = f"@{username}" if username and not username.startswith("@") else username
            username_without_at = username[1:] if username and username.startswith("@") else username

            # Query with agent_id filter on relationships
            query = """
            MATCH (u:User)
            WHERE u.username IN [$username, $username_with_at, $username_without_at]
               OR u.telegram_user_id = $telegram_user_id
               OR u.telegram_id = $telegram_user_id
            OPTIONAL MATCH (u)-[r1:ASKED_ABOUT]->(t:Topic)
            WHERE r1.agent_id = $agent_id OR r1.agent_id IS NULL
            OPTIONAL MATCH (u)-[r2:WORKS_ON]->(p:Project)
            WHERE r2.agent_id = $agent_id OR r2.agent_id IS NULL
            RETURN
                collect(DISTINCT t.name) as topics,
                collect(DISTINCT p.name) as projects
            """
            result = await session.run(
                query,
                username=username,
                username_with_at=username_with_at,
                username_without_at=username_without_at,
                telegram_user_id=telegram_user_id,
                agent_id=agent_id
            )
            record = await result.single()

            if record:
                # Drop null entries produced by the OPTIONAL MATCHes.
                context["topics"] = [t for t in record["topics"] if t]
                context["projects"] = [p for p in record["projects"] if p]

            logger.debug(f"Graph context for {agent_id}: topics={len(context['topics'])}, projects={len(context['projects'])}")

    except Exception as e:
        logger.warning(f"Graph context retrieval failed for {agent_id}: {e}")

    return context
|
||
|
||
# =========================================================================
|
||
# Main Retrieval Pipeline
|
||
# =========================================================================
|
||
|
||
async def retrieve(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str = "helion",
    username: Optional[str] = None,
    display_name: Optional[str] = None,
    message: Optional[str] = None,
    thread_id: Optional[str] = None
) -> MemoryBrief:
    """Main retrieval pipeline for any agent.

    1. Resolve user identity (L2)
    2. Get session state (L1)
    3. Get user facts (L3)
    4. Search relevant memories if message provided (L3)
    5. Get graph context (L3)
    6. Compile memory brief

    Args:
        channel: origin channel (e.g. "telegram").
        chat_id: channel-level chat identifier.
        user_id: channel-level user identifier (also passed to identity
            resolution as the channel_user_id).
        agent_id: Agent identifier for collection routing (e.g. "helion", "nutra", "greenfood")
        username / display_name: optional channel-level user metadata.
        message: when given, triggers the semantic memory search.
        thread_id: forwarded to get_session_state (currently unused there).

    Returns:
        A populated MemoryBrief; every sub-step degrades to empty data on
        backend failure rather than raising.
    """
    brief = MemoryBrief()

    # L2: Identity (user_id doubles as the channel-level user id here).
    identity = await self.resolve_identity(channel, user_id, username, display_name)
    brief.user_identity = identity

    # L1: Session State
    session = await self.get_session_state(
        channel,
        chat_id,
        thread_id,
        agent_id=agent_id,
        user_id=user_id,
    )
    brief.session_state = session
    brief.is_trusted_group = session.trust_mode

    # L3: User Facts (only possible with a resolved platform user).
    if identity.platform_user_id:
        facts = await self.get_user_facts(identity.platform_user_id)
        brief.user_facts = facts

    # L3: Semantic Search (if message provided) - agent-specific collections
    if message:
        memories = await self.search_memories(
            query=message,
            agent_id=agent_id,
            platform_user_id=identity.platform_user_id,
            chat_id=chat_id,
            user_id=user_id,
            limit=5
        )
        brief.relevant_memories = memories

    # L3: Graph Context (filtered by agent_id to prevent role mixing)
    graph_ctx = await self.get_user_graph_context(username, user_id, agent_id)
    brief.user_topics = graph_ctx.get("topics", [])
    brief.user_projects = graph_ctx.get("projects", [])

    # Check for mentor presence
    brief.mentor_present = identity.is_mentor

    return brief
|
||
|
||
# =========================================================================
|
||
# Memory Storage (write path)
|
||
# =========================================================================
|
||
|
||
async def store_message(
    self,
    agent_id: str,
    user_id: str,
    username: Optional[str],
    message_text: str,
    response_text: str,
    chat_id: str,
    message_type: str = "conversation",
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Store a message exchange in the agent-specific Qdrant collection.

    Embeds the combined "User/Assistant" transcript and upserts it into
    {agent_id}_messages (created on first use). For deterministic
    agromatrix plant identifications it also feeds the anonymized shared
    agronomy library.

    Returns:
        True on success, False when Qdrant/Cohere are unavailable, the
        embedding fails, or the upsert raises.
    """
    if not self.qdrant_client or not COHERE_API_KEY:
        logger.debug(f"Cannot store message: qdrant={bool(self.qdrant_client)}, cohere={bool(COHERE_API_KEY)}")
        return False

    messages_collection = f"{agent_id}_messages"

    try:
        from qdrant_client.http import models as qmodels
        import uuid

        # Ensure collection exists (probe, create on miss).
        try:
            self.qdrant_client.get_collection(messages_collection)
        except Exception:
            # Create collection with Cohere embed-multilingual-v3.0 dimensions (1024)
            self.qdrant_client.create_collection(
                collection_name=messages_collection,
                vectors_config=qmodels.VectorParams(
                    size=1024,
                    distance=qmodels.Distance.COSINE
                )
            )
            logger.info(f"✅ Created collection: {messages_collection}")

        # Combine user message and response for better context retrieval
        combined_text = f"User: {message_text}\n\nAssistant: {response_text}"

        # Get embedding
        embedding = await self.get_embedding(combined_text[:2000])  # Truncate for API limits
        if not embedding:
            logger.warning(f"Failed to get embedding for message storage")
            return False

        # Store in Qdrant
        point_id = str(uuid.uuid4())
        payload = {
            "text": combined_text[:5000],  # Limit payload size
            "user_message": message_text[:2000],
            "assistant_response": response_text[:3000],
            "user_id": user_id,
            "username": username,
            "chat_id": chat_id,
            "agent_id": agent_id,
            "type": message_type,
            # NOTE(review): utcnow() is naive (no tz offset) and deprecated;
            # switching to datetime.now(timezone.utc) would change the stored
            # timestamp format — confirm downstream consumers before changing.
            "timestamp": datetime.utcnow().isoformat()
        }
        if metadata and isinstance(metadata, dict):
            payload["metadata"] = metadata

        self.qdrant_client.upsert(
            collection_name=messages_collection,
            points=[
                qmodels.PointStruct(
                    id=point_id,
                    vector=embedding,
                    payload=payload
                )
            ]
        )

        # Optional shared agronomy memory:
        # - never stores user/chat identifiers
        # - supports review gate (pending vs approved)
        if (
            SHARED_AGRO_LIBRARY_ENABLED
            and agent_id == "agromatrix"
            and message_type in {"vision", "conversation"}
            and isinstance(metadata, dict)
            and metadata.get("deterministic_plant_id")
        ):
            await self._store_shared_agronomy_memory(
                message_text=message_text,
                response_text=response_text,
                metadata=metadata,
            )

        logger.debug(f"✅ Stored message in {messages_collection}: {point_id[:8]}...")
        return True

    except Exception as e:
        logger.warning(f"Failed to store message in {messages_collection}: {e}")
        return False
|
||
|
||
async def _store_shared_agronomy_memory(
    self,
    message_text: str,
    response_text: str,
    metadata: Dict[str, Any],
) -> bool:
    """Store an anonymized plant case in the shared agronomy library.

    Builds a compact "Plant case" text (question, answer, candidate list —
    no user or chat identifiers), embeds it, and upserts into
    "agromatrix_shared_library". When review is required and the case is
    not yet mentor-confirmed, it lands in "agromatrix_shared_pending"
    instead. Best-effort: returns False (debug-logged) on any failure.
    """
    if not self.qdrant_client or not COHERE_API_KEY:
        return False
    try:
        from qdrant_client.http import models as qmodels
        import uuid

        # Review gate: confirmed cases go to the public library,
        # unreviewed ones to the pending collection (when gating is on).
        reviewed = bool(metadata.get("mentor_confirmed") or metadata.get("reviewed"))
        collection = "agromatrix_shared_library"
        if SHARED_AGRO_LIBRARY_REQUIRE_REVIEW and not reviewed:
            collection = "agromatrix_shared_pending"

        # Ensure the target collection exists (probe, create on miss).
        try:
            self.qdrant_client.get_collection(collection)
        except Exception:
            self.qdrant_client.create_collection(
                collection_name=collection,
                vectors_config=qmodels.VectorParams(
                    size=1024,
                    distance=qmodels.Distance.COSINE,
                ),
            )

        # Compact, identifier-free text representation of the case.
        compact = (
            f"Plant case\nQuestion: {message_text[:800]}\n"
            f"Answer: {response_text[:1200]}\n"
            f"Candidates: {json.dumps(metadata.get('candidates', []), ensure_ascii=False)[:1200]}"
        )
        embedding = await self.get_embedding(compact[:2000])
        if not embedding:
            return False

        payload = {
            "text": compact[:3000],
            "type": "plant_case",
            "deterministic_plant_id": True,
            "decision": metadata.get("decision"),
            "confidence_threshold": metadata.get("confidence_threshold"),
            "candidates": metadata.get("candidates", [])[:5],
            "reviewed": reviewed,
            # NOTE(review): naive utcnow(), same caveat as store_message.
            "timestamp": datetime.utcnow().isoformat(),
        }
        self.qdrant_client.upsert(
            collection_name=collection,
            points=[qmodels.PointStruct(id=str(uuid.uuid4()), vector=embedding, payload=payload)],
        )
        return True
    except Exception as e:
        logger.debug(f"Shared agronomy memory store failed: {e}")
        return False
|
||
|
||
async def register_pending_question(
|
||
self,
|
||
channel: str,
|
||
chat_id: str,
|
||
user_id: str,
|
||
agent_id: str,
|
||
question_text: str,
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
) -> bool:
|
||
if not self.pg_pool:
|
||
return False
|
||
text = (question_text or "").strip()
|
||
if not text:
|
||
return False
|
||
fp = self._question_fingerprint(text)
|
||
try:
|
||
async with self.pg_pool.acquire() as conn:
|
||
await conn.execute(
|
||
"""
|
||
INSERT INTO agent_pending_questions
|
||
(channel, chat_id, user_id, agent_id, question_text, question_fingerprint, status, metadata)
|
||
VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7::jsonb)
|
||
ON CONFLICT (agent_id, channel, chat_id, user_id, question_fingerprint, status)
|
||
DO NOTHING
|
||
""",
|
||
channel,
|
||
chat_id,
|
||
user_id,
|
||
agent_id,
|
||
text[:1200],
|
||
fp,
|
||
json.dumps(metadata or {}, ensure_ascii=False),
|
||
)
|
||
# Keep only last N open items.
|
||
await conn.execute(
|
||
"""
|
||
WITH ranked AS (
|
||
SELECT id, ROW_NUMBER() OVER (
|
||
PARTITION BY channel, chat_id, user_id, agent_id, status
|
||
ORDER BY created_at DESC
|
||
) AS rn
|
||
FROM agent_pending_questions
|
||
WHERE channel = $1
|
||
AND chat_id = $2
|
||
AND user_id = $3
|
||
AND agent_id = $4
|
||
AND status = 'pending'
|
||
)
|
||
UPDATE agent_pending_questions p
|
||
SET status = 'dismissed',
|
||
answered_at = NOW(),
|
||
metadata = COALESCE(p.metadata, '{}'::jsonb) || '{"reason":"overflow_trim"}'::jsonb
|
||
FROM ranked r
|
||
WHERE p.id = r.id
|
||
AND r.rn > $5
|
||
""",
|
||
channel,
|
||
chat_id,
|
||
user_id,
|
||
agent_id,
|
||
max(1, PENDING_QUESTIONS_LIMIT),
|
||
)
|
||
return True
|
||
except Exception as e:
|
||
logger.warning(f"register_pending_question failed: {e}")
|
||
return False
|
||
|
||
async def resolve_pending_question(
|
||
self,
|
||
channel: str,
|
||
chat_id: str,
|
||
user_id: str,
|
||
agent_id: str,
|
||
answer_text: Optional[str] = None,
|
||
reason: str = "answered",
|
||
) -> bool:
|
||
if not self.pg_pool:
|
||
return False
|
||
try:
|
||
async with self.pg_pool.acquire() as conn:
|
||
row = await conn.fetchrow(
|
||
"""
|
||
WITH target AS (
|
||
SELECT id
|
||
FROM agent_pending_questions
|
||
WHERE channel = $1
|
||
AND chat_id = $2
|
||
AND user_id = $3
|
||
AND agent_id = $4
|
||
AND status = 'pending'
|
||
ORDER BY created_at ASC
|
||
LIMIT 1
|
||
)
|
||
UPDATE agent_pending_questions p
|
||
SET status = CASE WHEN $5 = 'dismissed' THEN 'dismissed' ELSE 'answered' END,
|
||
answered_at = NOW(),
|
||
metadata = COALESCE(p.metadata, '{}'::jsonb)
|
||
|| jsonb_build_object(
|
||
'resolution_reason', $5,
|
||
'answer_fingerprint', COALESCE($6, '')
|
||
)
|
||
FROM target t
|
||
WHERE p.id = t.id
|
||
RETURNING p.id
|
||
""",
|
||
channel,
|
||
chat_id,
|
||
user_id,
|
||
agent_id,
|
||
reason,
|
||
self._question_fingerprint(answer_text or "") if answer_text else "",
|
||
)
|
||
return bool(row)
|
||
except Exception as e:
|
||
logger.warning(f"resolve_pending_question failed: {e}")
|
||
return False
|
||
|
||
async def store_interaction(
|
||
self,
|
||
channel: str,
|
||
chat_id: str,
|
||
user_id: str,
|
||
agent_id: str,
|
||
username: Optional[str],
|
||
user_message: str,
|
||
assistant_response: str,
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
) -> bool:
|
||
# Backward-compatible wrapper for older call sites.
|
||
return await self.store_message(
|
||
agent_id=agent_id,
|
||
user_id=user_id,
|
||
username=username,
|
||
message_text=user_message,
|
||
response_text=assistant_response,
|
||
chat_id=chat_id,
|
||
message_type="conversation",
|
||
metadata=metadata,
|
||
)
|
||
|
||
async def update_session_state(
|
||
self,
|
||
conversation_id: str,
|
||
**updates
|
||
):
|
||
"""Update session state after interaction"""
|
||
if not self.pg_pool or not conversation_id:
|
||
return
|
||
|
||
try:
|
||
async with self.pg_pool.acquire() as conn:
|
||
# Build dynamic update
|
||
set_clauses = ["updated_at = NOW()"]
|
||
params = [conversation_id]
|
||
param_idx = 2
|
||
|
||
allowed_fields = [
|
||
'last_user_id', 'last_user_nick',
|
||
'active_topic', 'context_open',
|
||
'last_media_handled', 'last_answer_fingerprint',
|
||
'trust_mode', 'apprentice_mode'
|
||
]
|
||
|
||
for field, value in updates.items():
|
||
if field in allowed_fields:
|
||
set_clauses.append(f"{field} = ${param_idx}")
|
||
params.append(value)
|
||
param_idx += 1
|
||
|
||
query = f"""
|
||
UPDATE agent_session_state
|
||
SET {', '.join(set_clauses)}
|
||
WHERE conversation_id = $1
|
||
"""
|
||
await conn.execute(query, *params)
|
||
|
||
except Exception as e:
|
||
logger.warning(f"Session state update failed: {e}")
|
||
|
||
|
||
# Module-level singleton created at import time; call sites import this
# shared instance instead of constructing MemoryRetrieval themselves.
memory_retrieval = MemoryRetrieval()
|