1921 lines
76 KiB
Python
1921 lines
76 KiB
Python
"""
|
||
Universal Memory Retrieval Pipeline v4.0
|
||
|
||
This module implements the 4-layer memory retrieval for ALL agents:
|
||
- L0: Working Context (from request)
|
||
- L1: Session State Memory (SSM) - from Postgres
|
||
- L2: Platform Identity & Roles (PIR) - from Postgres + Neo4j
|
||
- L3: Organizational Memory (OM) - from Postgres + Neo4j + Qdrant
|
||
|
||
The pipeline generates a "memory brief" that is injected into the LLM context.
|
||
|
||
Collections per agent:
|
||
- {agent_id}_messages - chat history with embeddings
|
||
- {agent_id}_memory_items - facts, preferences
|
||
- {agent_id}_docs - knowledge base documents
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import logging
|
||
import re
|
||
import hashlib
|
||
from typing import Optional, Dict, Any, List
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
|
||
import httpx
|
||
import asyncpg
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# Configuration
|
||
POSTGRES_URL = os.getenv("DATABASE_URL", "postgresql://daarion:DaarionDB2026!@dagi-postgres:5432/daarion_memory")
|
||
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
|
||
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
|
||
COHERE_API_KEY = os.getenv("COHERE_API_KEY", "")
|
||
NEO4J_BOLT_URL = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
|
||
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
|
||
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4j")
|
||
PENDING_QUESTIONS_LIMIT = int(os.getenv("AGENT_PENDING_QUESTIONS_LIMIT", "5"))
|
||
SHARED_AGRO_LIBRARY_ENABLED = os.getenv("AGROMATRIX_SHARED_LIBRARY_ENABLED", "true").lower() == "true"
|
||
SHARED_AGRO_LIBRARY_REQUIRE_REVIEW = os.getenv("AGROMATRIX_SHARED_LIBRARY_REQUIRE_REVIEW", "true").lower() == "true"
|
||
DOC_VERSION_PREVIEW_CHARS = int(os.getenv("DOC_VERSION_PREVIEW_CHARS", "240"))
|
||
|
||
|
||
@dataclass
class UserIdentity:
    """Platform user identity (L2 layer).

    Resolved by MemoryRetrieval.resolve_identity(); when Postgres is
    unavailable, only the channel-level fields are populated and
    platform_user_id stays None.
    """
    # Cross-channel platform id (stringified DB id); None if unresolved.
    platform_user_id: Optional[str] = None
    # Origin channel name, e.g. "telegram".
    channel: str = "telegram"
    # User id as known to the channel (e.g. Telegram numeric id as string).
    channel_user_id: str = ""
    username: Optional[str] = None
    display_name: Optional[str] = None
    # Active platform role codes (revoked roles excluded).
    roles: List[str] = field(default_factory=list)
    # True when the DB-side is_mentor() check matches this user.
    is_mentor: bool = False
    # First time this identity was observed; not set by resolve_identity() here.
    first_seen: Optional[datetime] = None
|
||
|
||
|
||
@dataclass
class SessionState:
    """L1 Session State Memory.

    Per (channel, chat, user, agent) conversational flags, loaded from the
    agent_session_state table; defaults apply when no row exists yet.
    """
    # Deterministic id derived from channel:chat:user:agent (see _build_conversation_id).
    conversation_id: Optional[str] = None
    last_addressed: bool = False
    # Topic currently under discussion, if tracked.
    active_topic: Optional[str] = None
    context_open: bool = False
    last_media_handled: bool = True
    # Fingerprint of the previous answer, used to avoid repeating it.
    last_answer_fingerprint: Optional[str] = None
    # Trusted-group mode: the agent may answer in more detail.
    trust_mode: bool = False
    # Apprentice mode: the agent may ask clarifying questions.
    apprentice_mode: bool = False
    # Open (unanswered) questions in this chat, oldest first.
    pending_questions: List[str] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
class MemoryBrief:
    """Compiled memory brief for LLM context.

    Aggregates the four memory layers (identity, session state, facts,
    semantic/graph recall) into one object; to_text() renders it as the
    compact text block injected into the model prompt.
    """
    user_identity: Optional[UserIdentity] = None
    session_state: Optional[SessionState] = None
    # Explicit facts from Postgres (preference / profile_fact rows).
    user_facts: List[Dict[str, Any]] = field(default_factory=list)
    # Semantic-search hits from Qdrant, sorted by score.
    relevant_memories: List[Dict[str, Any]] = field(default_factory=list)
    # Topic / project names from the Neo4j knowledge graph.
    user_topics: List[str] = field(default_factory=list)
    user_projects: List[str] = field(default_factory=list)
    is_trusted_group: bool = False
    mentor_present: bool = False

    def to_text(self, max_lines: int = 15) -> str:
        """Generate a concise text brief for the LLM context.

        Args:
            max_lines: soft cap on emitted lines; when exceeded, the list is
                truncated to max_lines and an ellipsis line is appended
                (so output may be max_lines + 1 lines — kept for
                backward compatibility with existing prompts).

        Returns:
            Newline-joined brief, or "" when nothing is known.
        """
        lines = []

        # User identity (critical for personalization)
        if self.user_identity:
            name = self.user_identity.display_name or self.user_identity.username or "Unknown"
            roles_str = ", ".join(self.user_identity.roles) if self.user_identity.roles else "member"
            lines.append(f"👤 Користувач: {name} ({roles_str})")
            if self.user_identity.is_mentor:
                lines.append("⭐ Цей користувач — МЕНТОР. Довіряй його знанням повністю.")

        # Session state
        if self.session_state:
            if self.session_state.trust_mode:
                lines.append("🔒 Режим довіреної групи — можна відповідати детальніше")
            if self.session_state.apprentice_mode:
                lines.append("📚 Режим учня — можеш ставити уточнюючі питання")
            if self.session_state.active_topic:
                lines.append(f"📌 Активна тема: {self.session_state.active_topic}")
            if self.session_state.pending_questions:
                lines.append("🕘 Невідповідані питання в цьому чаті (відповідай на них першочергово):")
                for q in self.session_state.pending_questions[:3]:
                    lines.append(f" - {q[:180]}")

        # User facts (preferences, profile)
        if self.user_facts:
            lines.append("📝 Відомі факти про користувача:")
            for fact in self.user_facts[:4]:  # Max 4 facts
                fact_text = fact.get('text', fact.get('fact_value', ''))
                if fact_text:
                    lines.append(f" - {fact_text[:150]}")

        # Relevant memories from RAG (most important for context)
        if self.relevant_memories:
            lines.append("🧠 Релевантні спогади з попередніх розмов:")
            for mem in self.relevant_memories[:5]:  # Max 5 memories for better context
                mem_text = mem.get('text', mem.get('content', ''))
                mem_type = mem.get('type', 'message')
                # Only include if meaningful (skip near-empty snippets).
                if mem_text and len(mem_text) > 10:
                    lines.append(f" - [{mem_type}] {mem_text[:200]}")

        # Topics/Projects from Knowledge Graph
        if self.user_topics:
            lines.append(f"💡 Інтереси користувача: {', '.join(self.user_topics[:5])}")
        if self.user_projects:
            lines.append(f"🏗️ Проєкти користувача: {', '.join(self.user_projects[:3])}")

        # Mentor presence indicator
        if self.mentor_present:
            lines.append("⚠️ ВАЖЛИВО: Ти спілкуєшся з ментором. Сприймай як навчання.")

        # Truncate if needed
        if len(lines) > max_lines:
            lines = lines[:max_lines]
            lines.append("...")

        return "\n".join(lines) if lines else ""
|
||
|
||
|
||
class MemoryRetrieval:
|
||
"""Memory Retrieval Pipeline"""
|
||
|
||
def __init__(self):
|
||
self.pg_pool: Optional[asyncpg.Pool] = None
|
||
self.neo4j_driver = None
|
||
self.qdrant_client = None
|
||
self.http_client: Optional[httpx.AsyncClient] = None
|
||
|
||
    async def initialize(self):
        """Initialize database connections.

        Best-effort: each backend (Postgres, Neo4j, Qdrant) is attempted
        independently and a failure only logs a warning, leaving the handle
        as None — the retrieval methods degrade gracefully when a backend
        is absent. Always creates the HTTP client and then ensures the
        auxiliary Postgres tables exist.
        """
        # PostgreSQL
        try:
            self.pg_pool = await asyncpg.create_pool(
                POSTGRES_URL,
                min_size=2,
                max_size=10
            )
            logger.info("✅ Memory Retrieval: PostgreSQL connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: PostgreSQL not available: {e}")

        # Neo4j (imported lazily so the dependency is optional at runtime)
        try:
            from neo4j import AsyncGraphDatabase
            self.neo4j_driver = AsyncGraphDatabase.driver(
                NEO4J_BOLT_URL,
                auth=(NEO4J_USER, NEO4J_PASSWORD)
            )
            await self.neo4j_driver.verify_connectivity()
            logger.info("✅ Memory Retrieval: Neo4j connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: Neo4j not available: {e}")

        # Qdrant (sync client; lazy import for the same reason)
        try:
            from qdrant_client import QdrantClient
            self.qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
            self.qdrant_client.get_collections()  # Test connection
            logger.info("✅ Memory Retrieval: Qdrant connected")
        except Exception as e:
            logger.warning(f"⚠️ Memory Retrieval: Qdrant not available: {e}")

        # HTTP client for embeddings (Cohere API calls)
        self.http_client = httpx.AsyncClient(timeout=30.0)
        await self._ensure_aux_tables()
|
||
|
||
async def close(self):
|
||
"""Close connections"""
|
||
if self.pg_pool:
|
||
await self.pg_pool.close()
|
||
if self.neo4j_driver:
|
||
await self.neo4j_driver.close()
|
||
if self.http_client:
|
||
await self.http_client.aclose()
|
||
|
||
    async def _ensure_aux_tables(self):
        """Create auxiliary tables used by agent runtime policies.

        Idempotent (CREATE ... IF NOT EXISTS throughout); executed once at
        startup from initialize(). Creates:
        - agent_session_state: per (channel, chat, user, agent) flags (L1).
        - agent_pending_questions: open questions, with a unique index that
          prevents duplicate *pending* rows per fingerprint.
        - agent_document_versions: versioned hashes/previews of ingested docs.
        Any failure is logged and swallowed — callers tolerate missing tables.
        """
        if not self.pg_pool:
            return
        try:
            async with self.pg_pool.acquire() as conn:
                # Single multi-statement DDL batch; asyncpg executes it as one script.
                await conn.execute(
                    """
                    CREATE TABLE IF NOT EXISTS agent_session_state (
                        channel TEXT NOT NULL,
                        chat_id TEXT NOT NULL,
                        user_id TEXT NOT NULL,
                        agent_id TEXT NOT NULL,
                        conversation_id TEXT NOT NULL,
                        last_user_id TEXT,
                        last_user_nick TEXT,
                        active_topic TEXT,
                        context_open BOOLEAN NOT NULL DEFAULT FALSE,
                        last_media_handled BOOLEAN NOT NULL DEFAULT TRUE,
                        last_answer_fingerprint TEXT,
                        trust_mode BOOLEAN NOT NULL DEFAULT FALSE,
                        apprentice_mode BOOLEAN NOT NULL DEFAULT FALSE,
                        updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                        created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                        PRIMARY KEY (channel, chat_id, user_id, agent_id)
                    );
                    CREATE INDEX IF NOT EXISTS idx_agent_session_state_conv
                        ON agent_session_state (conversation_id);

                    CREATE TABLE IF NOT EXISTS agent_pending_questions (
                        id BIGSERIAL PRIMARY KEY,
                        channel TEXT NOT NULL,
                        chat_id TEXT NOT NULL,
                        user_id TEXT NOT NULL,
                        agent_id TEXT NOT NULL,
                        question_text TEXT NOT NULL,
                        question_fingerprint TEXT NOT NULL,
                        status TEXT NOT NULL DEFAULT 'pending',
                        created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                        answered_at TIMESTAMPTZ,
                        metadata JSONB NOT NULL DEFAULT '{}'::jsonb
                    );
                    CREATE INDEX IF NOT EXISTS idx_agent_pending_questions_scope
                        ON agent_pending_questions (agent_id, channel, chat_id, user_id, status, created_at DESC);
                    CREATE UNIQUE INDEX IF NOT EXISTS idx_agent_pending_questions_unique_open
                        ON agent_pending_questions (agent_id, channel, chat_id, user_id, question_fingerprint, status);

                    CREATE TABLE IF NOT EXISTS agent_document_versions (
                        id BIGSERIAL PRIMARY KEY,
                        agent_id TEXT NOT NULL,
                        doc_id TEXT NOT NULL,
                        version_no INTEGER NOT NULL,
                        text_hash TEXT NOT NULL,
                        text_len INTEGER NOT NULL DEFAULT 0,
                        text_preview TEXT,
                        file_name TEXT,
                        dao_id TEXT,
                        user_id TEXT,
                        storage_ref TEXT,
                        source TEXT NOT NULL DEFAULT 'ingest',
                        metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
                        created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                        UNIQUE (agent_id, doc_id, version_no)
                    );
                    CREATE INDEX IF NOT EXISTS idx_agent_document_versions_latest
                        ON agent_document_versions (agent_id, doc_id, version_no DESC);
                    """
                )
        except Exception as e:
            logger.warning(f"Aux tables init failed: {e}")
|
||
|
||
# =========================================================================
|
||
# L2: Platform Identity Resolution
|
||
# =========================================================================
|
||
|
||
    async def resolve_identity(
        self,
        channel: str,
        channel_user_id: str,
        username: Optional[str] = None,
        display_name: Optional[str] = None
    ) -> UserIdentity:
        """Resolve or create a platform user identity (L2).

        Delegates to the Postgres functions resolve_platform_user() and
        is_mentor() (defined DB-side; their exact semantics live in the
        schema, not here). Without a pool, or on any DB error, returns an
        identity carrying only the channel-level fields.

        Returns:
            UserIdentity — never raises; failures fall back silently.
        """
        identity = UserIdentity(
            channel=channel,
            channel_user_id=channel_user_id,
            username=username,
            display_name=display_name
        )

        if not self.pg_pool:
            return identity

        try:
            async with self.pg_pool.acquire() as conn:
                # Use the resolve_platform_user function (creates the user if new).
                platform_user_id = await conn.fetchval(
                    "SELECT resolve_platform_user($1, $2, $3, $4)",
                    channel, channel_user_id, username, display_name
                )
                identity.platform_user_id = str(platform_user_id) if platform_user_id else None

                # Get active (non-revoked) roles
                if platform_user_id:
                    roles = await conn.fetch("""
                        SELECT r.code FROM user_roles ur
                        JOIN platform_roles r ON ur.role_id = r.id
                        WHERE ur.platform_user_id = $1 AND ur.revoked_at IS NULL
                    """, platform_user_id)
                    identity.roles = [r['code'] for r in roles]

                # Check if mentor (keyed by channel id + username, not platform id)
                is_mentor = await conn.fetchval(
                    "SELECT is_mentor($1, $2)",
                    channel_user_id, username
                )
                identity.is_mentor = bool(is_mentor)

        except Exception as e:
            # Debug level: identity resolution is best-effort by design.
            logger.debug(f"Identity resolution fallback: {e}")

        return identity
|
||
|
||
# =========================================================================
|
||
# L1: Session State
|
||
# =========================================================================
|
||
|
||
    async def get_session_state(
        self,
        channel: str,
        chat_id: str,
        thread_id: Optional[str] = None,  # accepted for interface parity; currently unused
        agent_id: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> SessionState:
        """Get or create session state for a conversation (L1).

        When both agent_id and user_id are given, reads (or lazily inserts)
        the agent_session_state row and loads up to
        PENDING_QUESTIONS_LIMIT open questions. Otherwise only a
        deterministic conversation_id is derived. Returns a default
        SessionState on any failure — never raises.
        """
        state = SessionState()

        if not self.pg_pool:
            return state

        try:
            async with self.pg_pool.acquire() as conn:
                if agent_id and user_id:
                    # Deterministic id for this (channel, chat, user, agent) scope.
                    conv_id = self._build_conversation_id(channel, chat_id, user_id, agent_id)
                    row = await conn.fetchrow(
                        """
                        SELECT conversation_id, active_topic, context_open, last_media_handled,
                               last_answer_fingerprint, trust_mode, apprentice_mode
                        FROM agent_session_state
                        WHERE channel = $1
                          AND chat_id = $2
                          AND user_id = $3
                          AND agent_id = $4
                        """,
                        channel,
                        chat_id,
                        user_id,
                        agent_id,
                    )
                    if not row:
                        # First contact in this scope: create the row with defaults.
                        # ON CONFLICT DO NOTHING makes concurrent creation safe.
                        await conn.execute(
                            """
                            INSERT INTO agent_session_state
                                (channel, chat_id, user_id, agent_id, conversation_id)
                            VALUES ($1, $2, $3, $4, $5)
                            ON CONFLICT (channel, chat_id, user_id, agent_id) DO NOTHING
                            """,
                            channel,
                            chat_id,
                            user_id,
                            agent_id,
                            conv_id,
                        )
                        state.conversation_id = conv_id
                    else:
                        # Hydrate flags from the stored row, falling back to defaults.
                        state.conversation_id = str(row.get("conversation_id") or conv_id)
                        state.active_topic = row.get("active_topic")
                        state.context_open = bool(row.get("context_open", False))
                        state.last_media_handled = bool(row.get("last_media_handled", True))
                        state.last_answer_fingerprint = row.get("last_answer_fingerprint")
                        state.trust_mode = bool(row.get("trust_mode", False))
                        state.apprentice_mode = bool(row.get("apprentice_mode", False))
                else:
                    # Not enough scope to persist — derive an id with placeholders.
                    state.conversation_id = self._build_conversation_id(
                        channel,
                        chat_id,
                        user_id or "unknown",
                        agent_id or "agent",
                    )

                if agent_id and user_id:
                    # Oldest-first so long-waiting questions are surfaced first.
                    pending_rows = await conn.fetch(
                        """
                        SELECT question_text
                        FROM agent_pending_questions
                        WHERE channel = $1
                          AND chat_id = $2
                          AND user_id = $3
                          AND agent_id = $4
                          AND status = 'pending'
                        ORDER BY created_at ASC
                        LIMIT $5
                        """,
                        channel,
                        chat_id,
                        user_id,
                        agent_id,
                        PENDING_QUESTIONS_LIMIT,
                    )
                    state.pending_questions = [
                        str(r.get("question_text") or "").strip()
                        for r in pending_rows
                        if str(r.get("question_text") or "").strip()
                    ]

        except Exception as e:
            logger.warning(f"Session state retrieval failed: {e}")

        return state
|
||
|
||
# =========================================================================
|
||
# L3: Memory Retrieval (Facts + Semantic)
|
||
# =========================================================================
|
||
|
||
    async def get_user_facts(
        self,
        platform_user_id: Optional[str],
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """Get the user's explicit facts from PostgreSQL (L3).

        Only 'preference' and 'profile_fact' items that are not archived
        and not expired, highest-confidence first. NOTE(review): reads the
        helion_memory_items table regardless of which agent is asking —
        presumably facts are shared platform-wide; confirm if per-agent
        isolation is expected here.

        Returns:
            List of row dicts (type, text, summary, confidence, visibility);
            [] on missing pool, missing user id, or DB error.
        """
        if not self.pg_pool or not platform_user_id:
            return []

        try:
            async with self.pg_pool.acquire() as conn:
                rows = await conn.fetch("""
                    SELECT type, text, summary, confidence, visibility
                    FROM helion_memory_items
                    WHERE platform_user_id = $1
                      AND archived_at IS NULL
                      AND (expires_at IS NULL OR expires_at > NOW())
                      AND type IN ('preference', 'profile_fact')
                    ORDER BY confidence DESC, updated_at DESC
                    LIMIT $2
                """, platform_user_id, limit)

                return [dict(r) for r in rows]
        except Exception as e:
            logger.warning(f"User facts retrieval failed: {e}")
            return []
|
||
|
||
async def get_embedding(self, text: str) -> Optional[List[float]]:
|
||
"""Get embedding from Cohere API"""
|
||
if not COHERE_API_KEY or not self.http_client:
|
||
return None
|
||
|
||
try:
|
||
response = await self.http_client.post(
|
||
"https://api.cohere.ai/v1/embed",
|
||
headers={
|
||
"Authorization": f"Bearer {COHERE_API_KEY}",
|
||
"Content-Type": "application/json"
|
||
},
|
||
json={
|
||
"texts": [text],
|
||
"model": "embed-multilingual-v3.0",
|
||
"input_type": "search_query",
|
||
"truncate": "END"
|
||
}
|
||
)
|
||
if response.status_code == 200:
|
||
data = response.json()
|
||
return data.get("embeddings", [[]])[0]
|
||
except Exception as e:
|
||
logger.warning(f"Embedding generation failed: {e}")
|
||
|
||
return None
|
||
|
||
    async def search_memories(
        self,
        query: str,
        agent_id: str = "helion",
        platform_user_id: Optional[str] = None,
        chat_id: Optional[str] = None,
        user_id: Optional[str] = None,      # currently unused; kept for interface stability
        visibility: str = "platform",       # currently unused; kept for interface stability
        limit: int = 5
    ) -> List[Dict[str, Any]]:
        """Semantic search in Qdrant across agent-specific collections (L3).

        Searches up to four collections — {agent_id}_memory_items,
        {agent_id}_messages, {agent_id}_docs, and (for agromatrix plant
        queries) the shared agronomy library — then merges, score-sorts,
        deduplicates by canonical text key, and returns the top `limit`
        dicts with keys: text, type, score, source (+ confidence for
        memory items).

        Score thresholds and boosts are tuned per source: doc-style
        queries ("pdf", "каталог", ...) raise the bar for chat snippets,
        lower it for docs, and give docs a +0.12 boost so they outrank
        old chatter. Each collection search fails independently (debug
        log only); returns [] without Qdrant or without an embedding.
        """
        if not self.qdrant_client:
            return []

        # Get embedding (returns None/empty when Cohere is unavailable)
        embedding = await self.get_embedding(query)
        if not embedding:
            return []

        all_results = []

        q = (query or "").lower()
        # If user explicitly asks about documents/catalogs, prefer knowledge base docs over chat snippets.
        is_doc_query = any(k in q for k in ["pdf", "каталог", "каталоз", "документ", "файл", "стор", "page", "pages"])
        # Simple keyword gate to avoid irrelevant chat snippets dominating doc queries.
        # Example: when asking "з каталогу Defenda 2026 ... гліфосат", old "Бокаші" messages may match too well.
        topic_keywords: List[str] = []
        for kw in ["defenda", "ifagri", "bayer", "гліфосат", "glyphos", "глифос", "npk", "мінерал", "добрив", "гербіц", "фунгіц", "інсектиц"]:
            if kw in q:
                topic_keywords.append(kw)

        # Dynamic collection names based on agent_id
        memory_items_collection = f"{agent_id}_memory_items"
        messages_collection = f"{agent_id}_messages"
        docs_collection = f"{agent_id}_docs"

        try:
            from qdrant_client.http import models as qmodels

            # Search 1: {agent_id}_memory_items (facts, preferences)
            try:
                must_conditions = []
                if platform_user_id:
                    # Restrict facts to this user; without an id, facts are unscoped.
                    must_conditions.append(
                        qmodels.FieldCondition(
                            key="platform_user_id",
                            match=qmodels.MatchValue(value=platform_user_id)
                        )
                    )

                search_filter = qmodels.Filter(must=must_conditions) if must_conditions else None

                results = self.qdrant_client.search(
                    collection_name=memory_items_collection,
                    query_vector=embedding,
                    query_filter=search_filter,
                    limit=limit,
                    with_payload=True
                )

                for r in results:
                    if r.score > 0.3:  # Threshold for relevance
                        all_results.append({
                            "text": r.payload.get("text", ""),
                            "type": r.payload.get("type", "fact"),
                            "confidence": r.payload.get("confidence", 0.5),
                            "score": r.score,
                            "source": "memory_items"
                        })
            except Exception as e:
                logger.debug(f"{memory_items_collection} search: {e}")

            # Search 2: {agent_id}_messages (chat history)
            try:
                msg_filter = None
                if chat_id:
                    # Payload schema differs across ingesters: some use chat_id, others channel_id.
                    msg_filter = qmodels.Filter(
                        should=[
                            qmodels.FieldCondition(key="chat_id", match=qmodels.MatchValue(value=str(chat_id))),
                            qmodels.FieldCondition(key="channel_id", match=qmodels.MatchValue(value=str(chat_id))),
                        ]
                    )
                results = self.qdrant_client.search(
                    collection_name=messages_collection,
                    query_vector=embedding,
                    query_filter=msg_filter,
                    limit=limit,
                    with_payload=True
                )

                for r in results:
                    # Higher threshold for messages; even higher when user asks about docs to avoid pulling old chatter.
                    msg_thresh = 0.5 if is_doc_query else 0.4
                    if r.score > msg_thresh:
                        text = self._extract_message_text(r.payload)
                        # Skip very short or system messages
                        if len(text) > 20 and not text.startswith("<"):
                            if is_doc_query and topic_keywords:
                                # Keyword gate: doc queries only admit chat snippets
                                # that mention at least one query topic keyword.
                                tl = text.lower()
                                if not any(k in tl for k in topic_keywords):
                                    continue
                            all_results.append({
                                "text": text,
                                "type": "message",
                                "score": r.score,
                                "source": "messages"
                            })
            except Exception as e:
                logger.debug(f"{messages_collection} search: {e}")

            # Search 3: {agent_id}_docs (knowledge base) - optional
            try:
                results = self.qdrant_client.search(
                    collection_name=docs_collection,
                    query_vector=embedding,
                    limit=6 if is_doc_query else 3,  # Pull more docs for explicit doc queries
                    with_payload=True
                )

                for r in results:
                    # When user asks about PDF/catalogs, relax threshold so docs show up more reliably.
                    doc_thresh = 0.35 if is_doc_query else 0.5
                    if r.score > doc_thresh:
                        text = r.payload.get("text", r.payload.get("content", ""))
                        if len(text) > 30:
                            all_results.append({
                                "text": text[:500],  # Truncate long docs
                                "type": "knowledge",
                                # Slightly boost docs for doc queries so they win vs chat snippets.
                                "score": (r.score + 0.12) if is_doc_query else r.score,
                                "source": "docs"
                            })
            except Exception as e:
                logger.debug(f"{docs_collection} search: {e}")

            # Search 4: shared agronomy memory (reviewed, cross-chat, anonymized)
            if (
                SHARED_AGRO_LIBRARY_ENABLED
                and agent_id == "agromatrix"
                and self._is_plant_query(query)
            ):
                try:
                    results = self.qdrant_client.search(
                        collection_name="agromatrix_shared_library",
                        query_vector=embedding,
                        limit=3,
                        with_payload=True
                    )
                    for r in results:
                        if r.score > 0.45:
                            text = str(r.payload.get("text") or "").strip()
                            if len(text) > 20:
                                all_results.append({
                                    "text": text[:500],
                                    "type": "shared_agro_fact",
                                    "score": r.score + 0.05,
                                    "source": "shared_agronomy_library"
                                })
                except Exception as e:
                    logger.debug(f"agromatrix_shared_library search: {e}")

            # Sort by score and deduplicate
            all_results.sort(key=lambda x: x.get("score", 0), reverse=True)

            # Remove duplicates based on text similarity (first 220 normalized chars)
            seen_texts = set()
            unique_results = []
            for r in all_results:
                text_key = self._canonical_text_key(r.get("text", ""))
                if text_key not in seen_texts:
                    seen_texts.add(text_key)
                    unique_results.append(r)

            return unique_results[:limit]

        except Exception as e:
            logger.warning(f"Memory search failed for {agent_id}: {e}")
            return []
|
||
|
||
@staticmethod
|
||
def _extract_message_text(payload: Dict[str, Any]) -> str:
|
||
"""
|
||
Normalize text across both payload schemas:
|
||
- memory-service: content/text (+ role/channel_id)
|
||
- router: user_message + assistant_response (+ chat_id)
|
||
"""
|
||
if not payload:
|
||
return ""
|
||
|
||
text = (payload.get("text") or payload.get("content") or "").strip()
|
||
if text:
|
||
lower = text.lower()
|
||
marker = "\n\nassistant:"
|
||
idx = lower.rfind(marker)
|
||
if lower.startswith("user:") and idx != -1:
|
||
assistant_text = text[idx + len(marker):].strip()
|
||
if assistant_text:
|
||
return assistant_text
|
||
return text
|
||
|
||
user_text = (payload.get("user_message") or "").strip()
|
||
assistant_text = (payload.get("assistant_response") or "").strip()
|
||
if user_text and assistant_text:
|
||
return f"User: {user_text}\n\nAssistant: {assistant_text}"
|
||
return user_text or assistant_text
|
||
|
||
@staticmethod
|
||
def _canonical_text_key(text: str) -> str:
|
||
if not text:
|
||
return ""
|
||
normalized = re.sub(r"\s+", " ", text.strip().lower())
|
||
return normalized[:220]
|
||
|
||
@staticmethod
|
||
def _is_plant_query(text: str) -> bool:
|
||
q = (text or "").lower()
|
||
if not q:
|
||
return False
|
||
markers = [
|
||
"рослин", "культур", "лист", "стебл", "бур'ян", "хвороб", "шкідник",
|
||
"what plant", "identify plant", "crop", "species", "leaf", "stem",
|
||
"что за растение", "культура", "листок", "фото рослини"
|
||
]
|
||
return any(m in q for m in markers)
|
||
|
||
@staticmethod
|
||
def _question_fingerprint(question_text: str) -> str:
|
||
normalized = re.sub(r"\s+", " ", (question_text or "").strip().lower())
|
||
return hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:16]
|
||
|
||
@staticmethod
|
||
def _build_conversation_id(channel: str, chat_id: str, user_id: str, agent_id: str) -> str:
|
||
seed = f"{channel}:{chat_id}:{user_id}:{agent_id}"
|
||
return hashlib.sha1(seed.encode("utf-8")).hexdigest()[:24]
|
||
|
||
    async def get_user_graph_context(
        self,
        username: Optional[str] = None,
        telegram_user_id: Optional[str] = None,
        agent_id: str = "helion"
    ) -> Dict[str, List[str]]:
        """Get the user's topics and projects from Neo4j, filtered by agent_id.

        Matches the User node by username (with or without a leading "@")
        or by telegram id, then collects Topic names via ASKED_ABOUT and
        Project names via WORKS_ON. Relationships with no agent_id are
        included for backward compatibility with pre-tagging data.

        Returns:
            {"topics": [...], "projects": [...]} — empty lists when Neo4j
            is unavailable, no identifier was given, or the query fails.
        """
        context = {"topics": [], "projects": []}

        if not self.neo4j_driver:
            return context

        if not username and not telegram_user_id:
            return context

        try:
            async with self.neo4j_driver.session() as session:
                # Find user and their context, FILTERED BY AGENT_ID
                # Support both username formats: with and without @
                username_with_at = f"@{username}" if username and not username.startswith("@") else username
                username_without_at = username[1:] if username and username.startswith("@") else username

                # Query with agent_id filter on relationships
                query = """
                MATCH (u:User)
                WHERE u.username IN [$username, $username_with_at, $username_without_at]
                   OR u.telegram_user_id = $telegram_user_id
                   OR u.telegram_id = $telegram_user_id
                OPTIONAL MATCH (u)-[r1:ASKED_ABOUT]->(t:Topic)
                WHERE r1.agent_id = $agent_id OR r1.agent_id IS NULL
                OPTIONAL MATCH (u)-[r2:WORKS_ON]->(p:Project)
                WHERE r2.agent_id = $agent_id OR r2.agent_id IS NULL
                RETURN
                    collect(DISTINCT t.name) as topics,
                    collect(DISTINCT p.name) as projects
                """
                result = await session.run(
                    query,
                    username=username,
                    username_with_at=username_with_at,
                    username_without_at=username_without_at,
                    telegram_user_id=telegram_user_id,
                    agent_id=agent_id
                )
                record = await result.single()

                if record:
                    # Drop nulls that collect() may include from OPTIONAL MATCH.
                    context["topics"] = [t for t in record["topics"] if t]
                    context["projects"] = [p for p in record["projects"] if p]

                logger.debug(f"Graph context for {agent_id}: topics={len(context['topics'])}, projects={len(context['projects'])}")

        except Exception as e:
            logger.warning(f"Graph context retrieval failed for {agent_id}: {e}")

        return context
|
||
|
||
# =========================================================================
|
||
# Main Retrieval Pipeline
|
||
# =========================================================================
|
||
|
||
    async def retrieve(
        self,
        channel: str,
        chat_id: str,
        user_id: str,
        agent_id: str = "helion",
        username: Optional[str] = None,
        display_name: Optional[str] = None,
        message: Optional[str] = None,
        thread_id: Optional[str] = None
    ) -> MemoryBrief:
        """
        Main retrieval pipeline for any agent.

        1. Resolve user identity (L2)
        2. Get session state (L1)
        3. Get user facts (L3)
        4. Search relevant memories if message provided (L3)
        5. Get graph context (L3)
        6. Compile memory brief

        Each step degrades gracefully (see the individual methods), so a
        MemoryBrief is always returned even when every backend is down.

        Args:
            agent_id: Agent identifier for collection routing (e.g. "helion", "nutra", "greenfood")
            user_id: channel-level user id; doubles as channel_user_id for identity resolution
            message: current user message; enables the semantic-search step when given

        Returns:
            MemoryBrief ready for .to_text() injection into the LLM prompt.
        """
        brief = MemoryBrief()

        # L2: Identity (user_id serves as the channel_user_id)
        identity = await self.resolve_identity(channel, user_id, username, display_name)
        brief.user_identity = identity

        # L1: Session State
        session = await self.get_session_state(
            channel,
            chat_id,
            thread_id,
            agent_id=agent_id,
            user_id=user_id,
        )
        brief.session_state = session
        brief.is_trusted_group = session.trust_mode

        # L3: User Facts (only when identity resolved to a platform user)
        if identity.platform_user_id:
            facts = await self.get_user_facts(identity.platform_user_id)
            brief.user_facts = facts

        # L3: Semantic Search (if message provided) - agent-specific collections
        if message:
            memories = await self.search_memories(
                query=message,
                agent_id=agent_id,
                platform_user_id=identity.platform_user_id,
                chat_id=chat_id,
                user_id=user_id,
                limit=5
            )
            brief.relevant_memories = memories

        # L3: Graph Context (filtered by agent_id to prevent role mixing)
        graph_ctx = await self.get_user_graph_context(username, user_id, agent_id)
        brief.user_topics = graph_ctx.get("topics", [])
        brief.user_projects = graph_ctx.get("projects", [])

        # Check for mentor presence
        brief.mentor_present = identity.is_mentor

        return brief
|
||
|
||
# =========================================================================
|
||
# Memory Storage (write path)
|
||
# =========================================================================
|
||
|
||
    async def store_message(
        self,
        agent_id: str,
        user_id: str,
        username: Optional[str],
        message_text: str,
        response_text: str,
        chat_id: str,
        message_type: str = "conversation",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> bool:
        """
        Store a message exchange in the agent-specific Qdrant collection.

        This enables semantic retrieval of past conversations per agent.
        Creates {agent_id}_messages on first use (1024-dim cosine, matching
        Cohere embed-multilingual-v3.0), embeds the combined
        "User: .../Assistant: ..." text, and upserts one point. For
        agromatrix exchanges carrying a deterministic plant id it also
        forwards an anonymized copy to the shared agronomy library.

        Returns:
            True when the point was stored; False when Qdrant/Cohere is
            unavailable, embedding fails, or any error occurs (logged).
        """
        if not self.qdrant_client or not COHERE_API_KEY:
            logger.debug(f"Cannot store message: qdrant={bool(self.qdrant_client)}, cohere={bool(COHERE_API_KEY)}")
            return False

        messages_collection = f"{agent_id}_messages"

        try:
            from qdrant_client.http import models as qmodels
            import uuid

            # Ensure collection exists
            try:
                self.qdrant_client.get_collection(messages_collection)
            except Exception:
                # Create collection with Cohere embed-multilingual-v3.0 dimensions (1024)
                self.qdrant_client.create_collection(
                    collection_name=messages_collection,
                    vectors_config=qmodels.VectorParams(
                        size=1024,
                        distance=qmodels.Distance.COSINE
                    )
                )
                logger.info(f"✅ Created collection: {messages_collection}")

            # Combine user message and response for better context retrieval
            combined_text = f"User: {message_text}\n\nAssistant: {response_text}"

            # Get embedding
            embedding = await self.get_embedding(combined_text[:2000])  # Truncate for API limits
            if not embedding:
                logger.warning(f"Failed to get embedding for message storage")
                return False

            # Store in Qdrant (per-field truncation keeps the payload bounded)
            point_id = str(uuid.uuid4())
            payload = {
                "text": combined_text[:5000],  # Limit payload size
                "user_message": message_text[:2000],
                "assistant_response": response_text[:3000],
                "user_id": user_id,
                "username": username,
                "chat_id": chat_id,
                "agent_id": agent_id,
                "type": message_type,
                "timestamp": datetime.utcnow().isoformat()
            }
            if metadata and isinstance(metadata, dict):
                payload["metadata"] = metadata

            self.qdrant_client.upsert(
                collection_name=messages_collection,
                points=[
                    qmodels.PointStruct(
                        id=point_id,
                        vector=embedding,
                        payload=payload
                    )
                ]
            )

            # Optional shared agronomy memory:
            # - never stores user/chat identifiers
            # - supports review gate (pending vs approved)
            if (
                SHARED_AGRO_LIBRARY_ENABLED
                and agent_id == "agromatrix"
                and message_type in {"vision", "conversation"}
                and isinstance(metadata, dict)
                and metadata.get("deterministic_plant_id")
            ):
                await self._store_shared_agronomy_memory(
                    message_text=message_text,
                    response_text=response_text,
                    metadata=metadata,
                )

            logger.debug(f"✅ Stored message in {messages_collection}: {point_id[:8]}...")
            return True

        except Exception as e:
            logger.warning(f"Failed to store message in {messages_collection}: {e}")
            return False
|
||
|
||
async def _store_shared_agronomy_memory(
|
||
self,
|
||
message_text: str,
|
||
response_text: str,
|
||
metadata: Dict[str, Any],
|
||
) -> bool:
|
||
if not self.qdrant_client or not COHERE_API_KEY:
|
||
return False
|
||
try:
|
||
from qdrant_client.http import models as qmodels
|
||
import uuid
|
||
|
||
reviewed = bool(metadata.get("mentor_confirmed") or metadata.get("reviewed"))
|
||
collection = "agromatrix_shared_library"
|
||
if SHARED_AGRO_LIBRARY_REQUIRE_REVIEW and not reviewed:
|
||
collection = "agromatrix_shared_pending"
|
||
|
||
try:
|
||
self.qdrant_client.get_collection(collection)
|
||
except Exception:
|
||
self.qdrant_client.create_collection(
|
||
collection_name=collection,
|
||
vectors_config=qmodels.VectorParams(
|
||
size=1024,
|
||
distance=qmodels.Distance.COSINE,
|
||
),
|
||
)
|
||
|
||
compact = (
|
||
f"Plant case\nQuestion: {message_text[:800]}\n"
|
||
f"Answer: {response_text[:1200]}\n"
|
||
f"Candidates: {json.dumps(metadata.get('candidates', []), ensure_ascii=False)[:1200]}"
|
||
)
|
||
embedding = await self.get_embedding(compact[:2000])
|
||
if not embedding:
|
||
return False
|
||
|
||
payload = {
|
||
"text": compact[:3000],
|
||
"type": "plant_case",
|
||
"deterministic_plant_id": True,
|
||
"decision": metadata.get("decision"),
|
||
"confidence_threshold": metadata.get("confidence_threshold"),
|
||
"candidates": metadata.get("candidates", [])[:5],
|
||
"reviewed": reviewed,
|
||
"timestamp": datetime.utcnow().isoformat(),
|
||
}
|
||
self.qdrant_client.upsert(
|
||
collection_name=collection,
|
||
points=[qmodels.PointStruct(id=str(uuid.uuid4()), vector=embedding, payload=payload)],
|
||
)
|
||
return True
|
||
except Exception as e:
|
||
logger.debug(f"Shared agronomy memory store failed: {e}")
|
||
return False
|
||
|
||
async def register_pending_question(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str,
    question_text: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Record an open (unanswered) question for later follow-up.

    Inserts the question into agent_pending_questions, deduplicated via
    a fingerprint within the same open conversation, then trims the
    per-conversation pending queue to the configured limit (overflow
    items are marked dismissed).

    Returns:
        True on success, False when no DB pool, empty text, or error.
    """
    if not self.pg_pool:
        return False
    cleaned = (question_text or "").strip()
    if not cleaned:
        return False
    fingerprint = self._question_fingerprint(cleaned)
    try:
        async with self.pg_pool.acquire() as conn:
            # Dedup: an identical still-open question is ignored.
            await conn.execute(
                """
                INSERT INTO agent_pending_questions
                (channel, chat_id, user_id, agent_id, question_text, question_fingerprint, status, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7::jsonb)
                ON CONFLICT (agent_id, channel, chat_id, user_id, question_fingerprint, status)
                DO NOTHING
                """,
                channel,
                chat_id,
                user_id,
                agent_id,
                cleaned[:1200],
                fingerprint,
                json.dumps(metadata or {}, ensure_ascii=False),
            )
            # Keep only the newest N open items; dismiss the overflow.
            await conn.execute(
                """
                WITH ranked AS (
                    SELECT id, ROW_NUMBER() OVER (
                        PARTITION BY channel, chat_id, user_id, agent_id, status
                        ORDER BY created_at DESC
                    ) AS rn
                    FROM agent_pending_questions
                    WHERE channel = $1
                      AND chat_id = $2
                      AND user_id = $3
                      AND agent_id = $4
                      AND status = 'pending'
                )
                UPDATE agent_pending_questions p
                SET status = 'dismissed',
                    answered_at = NOW(),
                    metadata = COALESCE(p.metadata, '{}'::jsonb) || '{"reason":"overflow_trim"}'::jsonb
                FROM ranked r
                WHERE p.id = r.id
                  AND r.rn > $5
                """,
                channel,
                chat_id,
                user_id,
                agent_id,
                max(1, PENDING_QUESTIONS_LIMIT),
            )
        return True
    except Exception as e:
        logger.warning(f"register_pending_question failed: {e}")
        return False
|
||
|
||
async def resolve_pending_question(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str,
    answer_text: Optional[str] = None,
    reason: str = "answered",
) -> bool:
    """Close the oldest open question in a conversation.

    Marks it 'answered' (or 'dismissed' when reason == 'dismissed') and
    records the resolution reason plus a fingerprint of the answer in
    the row's metadata.

    Returns:
        True when a row was actually updated, False otherwise.
    """
    if not self.pg_pool:
        return False
    try:
        answer_fp = self._question_fingerprint(answer_text or "") if answer_text else ""
        async with self.pg_pool.acquire() as conn:
            updated = await conn.fetchrow(
                """
                WITH target AS (
                    SELECT id
                    FROM agent_pending_questions
                    WHERE channel = $1
                      AND chat_id = $2
                      AND user_id = $3
                      AND agent_id = $4
                      AND status = 'pending'
                    ORDER BY created_at ASC
                    LIMIT 1
                )
                UPDATE agent_pending_questions p
                SET status = CASE WHEN $5 = 'dismissed' THEN 'dismissed' ELSE 'answered' END,
                    answered_at = NOW(),
                    metadata = COALESCE(p.metadata, '{}'::jsonb)
                        || jsonb_build_object(
                            'resolution_reason', $5,
                            'answer_fingerprint', COALESCE($6, '')
                        )
                FROM target t
                WHERE p.id = t.id
                RETURNING p.id
                """,
                channel,
                chat_id,
                user_id,
                agent_id,
                reason,
                answer_fp,
            )
        return bool(updated)
    except Exception as e:
        logger.warning(f"resolve_pending_question failed: {e}")
        return False
|
||
|
||
@staticmethod
|
||
def _to_qdrant_point_id(raw_id: Any) -> Any:
|
||
if isinstance(raw_id, int):
|
||
return raw_id
|
||
if isinstance(raw_id, float) and raw_id.is_integer():
|
||
return int(raw_id)
|
||
if isinstance(raw_id, str):
|
||
v = raw_id.strip()
|
||
if not v:
|
||
return raw_id
|
||
if v.isdigit():
|
||
try:
|
||
return int(v)
|
||
except Exception:
|
||
return v
|
||
return v
|
||
return raw_id
|
||
|
||
async def list_shared_pending_cases(self, limit: int = 50) -> List[Dict[str, Any]]:
    """Return unreviewed shared-library cases, newest first.

    Reads up to ``limit`` (clamped to 1..200) points from the pending
    collection and projects each into a compact summary dict. Returns
    an empty list when Qdrant or the feature flag is unavailable.
    """
    if not self.qdrant_client or not SHARED_AGRO_LIBRARY_ENABLED:
        return []
    page_size = max(1, min(int(limit or 50), 200))
    try:
        points, _ = self.qdrant_client.scroll(
            collection_name="agromatrix_shared_pending",
            limit=page_size,
            with_payload=True,
            with_vectors=False,
        )
    except Exception as e:
        logger.debug(f"list_shared_pending_cases failed: {e}")
        return []

    cases: List[Dict[str, Any]] = []
    for point in points or []:
        data = getattr(point, "payload", {}) or {}
        raw_candidates = data.get("candidates")
        cases.append(
            {
                "point_id": str(getattr(point, "id", "")),
                "timestamp": data.get("timestamp") or "",
                "decision": data.get("decision"),
                "reviewed": bool(data.get("reviewed")),
                "excerpt": str(data.get("text") or "").strip()[:240],
                "candidates": (raw_candidates if isinstance(raw_candidates, list) else [])[:5],
            }
        )
    cases.sort(key=lambda c: c.get("timestamp") or "", reverse=True)
    return cases
|
||
|
||
async def review_shared_pending_case(
    self,
    point_id: str,
    approve: bool,
    reviewer: Optional[str] = None,
    note: Optional[str] = None,
) -> Dict[str, Any]:
    """Approve or reject one pending shared-library case.

    On approval the point — with review metadata attached — is copied
    into the approved library collection (re-embedding from stored text
    when the vector is missing). In both verdicts the pending point is
    removed afterwards.

    Returns:
        {"ok": True, "approved": ..., "point_id": ..., "library_point_id": ...}
        on success, or {"ok": False, "error": ...} on failure.
    """
    if not self.qdrant_client:
        return {"ok": False, "error": "qdrant_unavailable"}

    try:
        import uuid
        from qdrant_client.http import models as qmodels

        pid = self._to_qdrant_point_id(point_id)
        found = self.qdrant_client.retrieve(
            collection_name="agromatrix_shared_pending",
            ids=[pid],
            with_payload=True,
            with_vectors=True,
        )
        if not found:
            return {"ok": False, "error": "not_found"}

        record = found[0]
        payload = dict(getattr(record, "payload", {}) or {})
        reviewed_at = datetime.utcnow().isoformat()
        payload["reviewed"] = bool(approve)
        payload["review"] = {
            "reviewer": (reviewer or "system")[:120],
            "approved": bool(approve),
            "note": (note or "")[:500],
            "reviewed_at": reviewed_at,
        }

        library_point_id: Optional[str] = None
        if approve:
            vector = getattr(record, "vector", None)
            if isinstance(vector, dict):
                # Named-vectors mode: take the first available vector value.
                vector = next(iter(vector.values()), None)
            if not vector and COHERE_API_KEY:
                # Vector missing on the record: re-embed from stored text.
                basis = str(payload.get("text") or payload.get("assistant_response") or "")[:2000]
                vector = await self.get_embedding(basis)
            if not vector:
                return {"ok": False, "error": "missing_vector"}

            # Lazily create the approved library collection.
            try:
                self.qdrant_client.get_collection("agromatrix_shared_library")
            except Exception:
                self.qdrant_client.create_collection(
                    collection_name="agromatrix_shared_library",
                    vectors_config=qmodels.VectorParams(
                        size=len(vector),
                        distance=qmodels.Distance.COSINE,
                    ),
                )

            library_point_id = str(uuid.uuid4())
            payload["approved_at"] = reviewed_at
            self.qdrant_client.upsert(
                collection_name="agromatrix_shared_library",
                points=[
                    qmodels.PointStruct(
                        id=library_point_id,
                        vector=vector,
                        payload=payload,
                    )
                ],
            )

        # The pending point is removed regardless of the verdict.
        self.qdrant_client.delete(
            collection_name="agromatrix_shared_pending",
            points_selector=qmodels.PointIdsList(points=[pid]),
        )

        return {
            "ok": True,
            "approved": bool(approve),
            "point_id": str(getattr(record, "id", point_id)),
            "library_point_id": library_point_id,
        }
    except Exception as e:
        logger.warning(f"review_shared_pending_case failed: {e}")
        return {"ok": False, "error": str(e)}
|
||
|
||
def _chunk_document_text(
|
||
self,
|
||
text: str,
|
||
chunk_chars: int = 1200,
|
||
overlap_chars: int = 180,
|
||
) -> List[str]:
|
||
"""
|
||
Split document text into overlap-aware chunks for RAG indexing.
|
||
Keeps paragraph structure when possible.
|
||
"""
|
||
raw = re.sub(r"\r\n?", "\n", text or "").strip()
|
||
if not raw:
|
||
return []
|
||
|
||
paragraphs = [p.strip() for p in re.split(r"\n{2,}", raw) if p and p.strip()]
|
||
if not paragraphs:
|
||
return []
|
||
|
||
chunks: List[str] = []
|
||
current = ""
|
||
max_hard = max(chunk_chars, 600)
|
||
|
||
def _push_current() -> None:
|
||
nonlocal current
|
||
if current and len(current.strip()) >= 20:
|
||
chunks.append(current.strip())
|
||
current = ""
|
||
|
||
for para in paragraphs:
|
||
if len(para) > max_hard * 2:
|
||
_push_current()
|
||
i = 0
|
||
step = max_hard - max(80, min(overlap_chars, max_hard // 2))
|
||
while i < len(para):
|
||
part = para[i : i + max_hard]
|
||
if len(part.strip()) >= 20:
|
||
chunks.append(part.strip())
|
||
i += max(1, step)
|
||
continue
|
||
|
||
candidate = f"{current}\n\n{para}".strip() if current else para
|
||
if len(candidate) <= max_hard:
|
||
current = candidate
|
||
continue
|
||
|
||
_push_current()
|
||
if overlap_chars > 0 and chunks:
|
||
tail = chunks[-1][-overlap_chars:]
|
||
current = f"{tail}\n\n{para}".strip()
|
||
if len(current) > max_hard:
|
||
_push_current()
|
||
current = para
|
||
else:
|
||
current = para
|
||
|
||
_push_current()
|
||
return chunks
|
||
|
||
async def _next_document_version_no(
|
||
self,
|
||
agent_id: str,
|
||
doc_id: str,
|
||
) -> int:
|
||
if self.pg_pool:
|
||
try:
|
||
async with self.pg_pool.acquire() as conn:
|
||
value = await conn.fetchval(
|
||
"""
|
||
SELECT COALESCE(MAX(version_no), 0) + 1
|
||
FROM agent_document_versions
|
||
WHERE agent_id = $1
|
||
AND doc_id = $2
|
||
""",
|
||
(agent_id or "").lower(),
|
||
doc_id,
|
||
)
|
||
return max(1, int(value or 1))
|
||
except Exception as e:
|
||
logger.warning(f"next_document_version_no(pg) failed: {e}")
|
||
|
||
# Fallback: infer from existing chunk payloads in Qdrant.
|
||
if self.qdrant_client:
|
||
try:
|
||
from qdrant_client.http import models as qmodels
|
||
|
||
collection = f"{(agent_id or 'daarwizz').lower()}_docs"
|
||
points, _ = self.qdrant_client.scroll(
|
||
collection_name=collection,
|
||
scroll_filter=qmodels.Filter(
|
||
must=[
|
||
qmodels.FieldCondition(
|
||
key="doc_id",
|
||
match=qmodels.MatchValue(value=doc_id),
|
||
)
|
||
]
|
||
),
|
||
limit=256,
|
||
with_payload=True,
|
||
)
|
||
current_max = 0
|
||
for p in points or []:
|
||
payload = getattr(p, "payload", {}) or {}
|
||
ver = payload.get("version_no")
|
||
if isinstance(ver, int):
|
||
current_max = max(current_max, ver)
|
||
elif isinstance(ver, str) and ver.isdigit():
|
||
current_max = max(current_max, int(ver))
|
||
return current_max + 1 if current_max > 0 else 1
|
||
except Exception as e:
|
||
logger.debug(f"next_document_version_no(qdrant) fallback failed: {e}")
|
||
|
||
return 1
|
||
|
||
async def _latest_document_version_no(
|
||
self,
|
||
agent_id: str,
|
||
doc_id: str,
|
||
) -> int:
|
||
nxt = await self._next_document_version_no(agent_id=agent_id, doc_id=doc_id)
|
||
return max(0, int(nxt) - 1)
|
||
|
||
async def _record_document_version(
|
||
self,
|
||
agent_id: str,
|
||
doc_id: str,
|
||
version_no: int,
|
||
text: str,
|
||
file_name: Optional[str] = None,
|
||
dao_id: Optional[str] = None,
|
||
user_id: Optional[str] = None,
|
||
metadata: Optional[Dict[str, Any]] = None,
|
||
source: str = "ingest",
|
||
storage_ref: Optional[str] = None,
|
||
) -> Dict[str, Any]:
|
||
text_body = (text or "").strip()
|
||
text_hash = hashlib.sha256(text_body.encode("utf-8")).hexdigest() if text_body else ""
|
||
text_len = len(text_body)
|
||
preview = text_body[:DOC_VERSION_PREVIEW_CHARS] if text_body else ""
|
||
payload = metadata if isinstance(metadata, dict) else {}
|
||
|
||
if not self.pg_pool:
|
||
return {"ok": True, "version_no": int(version_no), "id": None}
|
||
|
||
try:
|
||
async with self.pg_pool.acquire() as conn:
|
||
row = await conn.fetchrow(
|
||
"""
|
||
INSERT INTO agent_document_versions
|
||
(agent_id, doc_id, version_no, text_hash, text_len, text_preview,
|
||
file_name, dao_id, user_id, storage_ref, source, metadata)
|
||
VALUES
|
||
($1, $2, $3, $4, $5, $6,
|
||
$7, $8, $9, $10, $11, $12::jsonb)
|
||
ON CONFLICT (agent_id, doc_id, version_no)
|
||
DO UPDATE SET
|
||
text_hash = EXCLUDED.text_hash,
|
||
text_len = EXCLUDED.text_len,
|
||
text_preview = EXCLUDED.text_preview,
|
||
file_name = EXCLUDED.file_name,
|
||
dao_id = EXCLUDED.dao_id,
|
||
user_id = EXCLUDED.user_id,
|
||
storage_ref = EXCLUDED.storage_ref,
|
||
source = EXCLUDED.source,
|
||
metadata = EXCLUDED.metadata
|
||
RETURNING id, version_no
|
||
""",
|
||
(agent_id or "").lower(),
|
||
doc_id,
|
||
int(version_no),
|
||
text_hash,
|
||
int(text_len),
|
||
preview,
|
||
file_name,
|
||
dao_id,
|
||
user_id,
|
||
storage_ref,
|
||
source,
|
||
json.dumps(payload),
|
||
)
|
||
return {
|
||
"ok": True,
|
||
"id": int(row["id"]) if row and row.get("id") is not None else None,
|
||
"version_no": int(row["version_no"]) if row and row.get("version_no") is not None else int(version_no),
|
||
}
|
||
except Exception as e:
|
||
logger.warning(f"record_document_version failed: {e}")
|
||
return {"ok": False, "error": str(e), "version_no": int(version_no)}
|
||
|
||
async def list_document_versions(
    self,
    agent_id: str,
    doc_id: str,
    limit: int = 20,
) -> List[Dict[str, Any]]:
    """List known versions of a document, newest version first.

    Primary source is the Postgres version ledger. When Postgres is
    unavailable or fails, a best-effort view is reconstructed from
    Qdrant chunk payloads (those entries carry no id/hash/preview).
    """
    results: List[Dict[str, Any]] = []
    capped = max(1, min(int(limit or 20), 200))

    if self.pg_pool:
        try:
            async with self.pg_pool.acquire() as conn:
                rows = await conn.fetch(
                    """
                    SELECT id, agent_id, doc_id, version_no, text_hash, text_len, text_preview,
                           file_name, dao_id, user_id, storage_ref, source, metadata, created_at
                    FROM agent_document_versions
                    WHERE agent_id = $1
                      AND doc_id = $2
                    ORDER BY version_no DESC
                    LIMIT $3
                    """,
                    (agent_id or "").lower(),
                    doc_id,
                    capped,
                )
                for record in rows:
                    raw_meta = record["metadata"]
                    # jsonb may arrive as dict or as text depending on codecs.
                    if isinstance(raw_meta, dict):
                        meta = raw_meta
                    elif isinstance(raw_meta, str):
                        try:
                            decoded = json.loads(raw_meta)
                            meta = decoded if isinstance(decoded, dict) else {"raw": decoded}
                        except Exception:
                            meta = {"raw": raw_meta}
                    else:
                        meta = {}
                    results.append(
                        {
                            "id": int(record["id"]),
                            "agent_id": record["agent_id"],
                            "doc_id": record["doc_id"],
                            "version_no": int(record["version_no"]),
                            "text_hash": record["text_hash"],
                            "text_len": int(record["text_len"] or 0),
                            "text_preview": record["text_preview"],
                            "file_name": record["file_name"],
                            "dao_id": record["dao_id"],
                            "user_id": record["user_id"],
                            "storage_ref": record["storage_ref"],
                            "source": record["source"],
                            "metadata": meta,
                            "created_at": record["created_at"].isoformat() if record["created_at"] else None,
                        }
                    )
                return results
        except Exception as e:
            logger.warning(f"list_document_versions failed: {e}")

    # PG unavailable fallback: aggregate distinct versions from Qdrant payloads.
    if self.qdrant_client:
        try:
            from qdrant_client.http import models as qmodels

            collection = f"{(agent_id or 'daarwizz').lower()}_docs"
            by_version: Dict[int, Dict[str, Any]] = {}
            cursor = None
            scanned = 0
            # Cap the scan so a huge document cannot stall this call.
            scan_budget = max(64, min(int(limit or 20) * 80, 4096))
            while scanned < scan_budget:
                points, next_cursor = self.qdrant_client.scroll(
                    collection_name=collection,
                    scroll_filter=qmodels.Filter(
                        must=[
                            qmodels.FieldCondition(
                                key="doc_id",
                                match=qmodels.MatchValue(value=doc_id),
                            )
                        ]
                    ),
                    offset=cursor,
                    limit=256,
                    with_payload=True,
                )
                if not points:
                    break
                scanned += len(points)
                for point in points:
                    data = getattr(point, "payload", {}) or {}
                    raw_ver = data.get("version_no")
                    if isinstance(raw_ver, int):
                        ver = raw_ver
                    elif isinstance(raw_ver, str) and raw_ver.isdigit():
                        ver = int(raw_ver)
                    else:
                        ver = 1

                    known = by_version.get(ver)
                    ts = data.get("timestamp")
                    # Keep the freshest chunk's payload as that version's summary.
                    if not known or (ts and str(ts) > str(known.get("created_at") or "")):
                        by_version[ver] = {
                            "id": None,
                            "agent_id": (agent_id or "").lower(),
                            "doc_id": doc_id,
                            "version_no": int(ver),
                            "text_hash": None,
                            "text_len": None,
                            "text_preview": None,
                            "file_name": data.get("file_name"),
                            "dao_id": data.get("dao_id"),
                            "user_id": data.get("user_id"),
                            "storage_ref": data.get("storage_ref"),
                            "source": data.get("source") or "ingest",
                            "metadata": data.get("metadata") or {},
                            "created_at": ts,
                        }
                if not next_cursor:
                    break
                cursor = next_cursor
            results = sorted(
                by_version.values(),
                key=lambda v: int(v.get("version_no") or 0),
                reverse=True,
            )[:capped]
        except Exception:
            pass

    return results
|
||
|
||
def _build_doc_filter(
    self,
    doc_id: str,
    dao_id: Optional[str] = None,
):
    """Build a Qdrant filter matching one document, optionally DAO-scoped."""
    from qdrant_client.http import models as qmodels

    conditions = [
        qmodels.FieldCondition(
            key="doc_id",
            match=qmodels.MatchValue(value=doc_id),
        )
    ]
    if dao_id:
        conditions.append(
            qmodels.FieldCondition(
                key="dao_id",
                match=qmodels.MatchValue(value=dao_id),
            )
        )
    return qmodels.Filter(must=conditions)
|
||
|
||
def _delete_document_points(
|
||
self,
|
||
collection: str,
|
||
doc_id: str,
|
||
dao_id: Optional[str] = None,
|
||
) -> bool:
|
||
if not self.qdrant_client:
|
||
return False
|
||
try:
|
||
from qdrant_client.http import models as qmodels
|
||
|
||
self.qdrant_client.delete(
|
||
collection_name=collection,
|
||
points_selector=qmodels.FilterSelector(
|
||
filter=self._build_doc_filter(doc_id=doc_id, dao_id=dao_id)
|
||
),
|
||
)
|
||
return True
|
||
except Exception as e:
|
||
logger.warning(f"delete_document_points failed for {collection}/{doc_id}: {e}")
|
||
return False
|
||
|
||
async def ingest_document_chunks(
    self,
    agent_id: str,
    doc_id: str,
    file_name: Optional[str],
    text: str,
    dao_id: Optional[str] = None,
    user_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    replace_existing: bool = False,
    version_no: Optional[int] = None,
    source: str = "ingest",
    storage_ref: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Ingest normalized document chunks into the {agent_id}_docs collection.

    Splits the text into overlapping chunks, embeds each, upserts the
    points into Qdrant and records a version ledger entry in Postgres.

    Args:
        replace_existing: when True, delete previously indexed points for
            this doc_id (and dao_id, if given) before upserting. Previously
            this flag was accepted but ignored while still being echoed
            back as "replaced_existing"; it is now honored.
        version_no: explicit version number; auto-incremented when omitted.

    Returns:
        A result dict with ok/doc_id/version/chunk counts, or an error marker.
    """
    if not self.qdrant_client:
        return {"ok": False, "error": "qdrant_unavailable"}
    if not COHERE_API_KEY:
        return {"ok": False, "error": "cohere_unavailable"}

    body = (text or "").strip()
    if not body:
        return {"ok": False, "error": "empty_document"}

    chunks = self._chunk_document_text(body)
    if not chunks:
        return {"ok": False, "error": "no_chunks"}

    collection = f"{(agent_id or 'daarwizz').lower()}_docs"
    stored_points = []

    try:
        from qdrant_client.http import models as qmodels
        import uuid

        # Lazily create the per-agent docs collection on first use.
        try:
            self.qdrant_client.get_collection(collection)
        except Exception:
            self.qdrant_client.create_collection(
                collection_name=collection,
                vectors_config=qmodels.VectorParams(
                    size=1024,
                    distance=qmodels.Distance.COSINE,
                ),
            )
            logger.info(f"✅ Created collection: {collection}")

        total = len(chunks)
        resolved_version_no = int(version_no or 0) or await self._next_document_version_no(agent_id=agent_id, doc_id=doc_id)
        for idx, chunk in enumerate(chunks):
            emb = await self.get_embedding(chunk[:2000])
            if not emb:
                # Skip chunks the embedding API failed on; counts reported below.
                continue
            payload: Dict[str, Any] = {
                "text": chunk[:6000],
                "doc_id": doc_id,
                "file_name": file_name,
                "agent_id": (agent_id or "").lower(),
                "dao_id": dao_id,
                "user_id": user_id,
                "chunk_index": idx,
                "chunks_total": total,
                "type": "document_chunk",
                "version_no": int(resolved_version_no),
                "source": source,
                "storage_ref": storage_ref,
                "timestamp": datetime.utcnow().isoformat(),
            }
            if isinstance(metadata, dict) and metadata:
                payload["metadata"] = metadata
            stored_points.append(
                qmodels.PointStruct(
                    id=str(uuid.uuid4()),
                    vector=emb,
                    payload=payload,
                )
            )

        if not stored_points:
            return {"ok": False, "error": "embedding_failed"}

        if replace_existing:
            # Honor the replacement request: drop older points for this doc
            # before upserting the new ones.
            self._delete_document_points(collection=collection, doc_id=doc_id, dao_id=dao_id)
        # Otherwise previous versions stay in the same collection; the
        # query path selects only the latest version_no for doc_id.

        self.qdrant_client.upsert(collection_name=collection, points=stored_points)
        version_row = await self._record_document_version(
            agent_id=agent_id,
            doc_id=doc_id,
            version_no=resolved_version_no,
            text=body,
            file_name=file_name,
            dao_id=dao_id,
            user_id=user_id,
            metadata=metadata,
            source=source,
            storage_ref=storage_ref,
        )
        return {
            "ok": True,
            "doc_id": doc_id,
            "version_no": int(resolved_version_no),
            "version_id": version_row.get("id"),
            "chunks_total": len(chunks),
            "chunks_stored": len(stored_points),
            "replaced_existing": bool(replace_existing),
            "collection": collection,
        }
    except Exception as e:
        logger.warning(f"ingest_document_chunks failed for {collection}: {e}")
        return {"ok": False, "error": str(e)}
|
||
|
||
async def update_document_chunks(
    self,
    agent_id: str,
    doc_id: str,
    file_name: Optional[str],
    text: str,
    dao_id: Optional[str] = None,
    user_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    storage_ref: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Re-ingest a document under the same logical doc_id with a bumped version.

    Delegates to ingest_document_chunks() with the next version number
    and source='update', then tags a successful result as an update.
    """
    bumped_version = await self._next_document_version_no(agent_id=agent_id, doc_id=doc_id)
    outcome = await self.ingest_document_chunks(
        agent_id=agent_id,
        doc_id=doc_id,
        file_name=file_name,
        text=text,
        dao_id=dao_id,
        user_id=user_id,
        metadata=metadata,
        replace_existing=False,
        version_no=bumped_version,
        source="update",
        storage_ref=storage_ref,
    )
    if outcome.get("ok"):
        outcome["updated"] = True
        outcome["replaced_existing"] = True
    return outcome
|
||
|
||
async def query_document_chunks(
    self,
    agent_id: str,
    question: str,
    doc_id: Optional[str] = None,
    dao_id: Optional[str] = None,
    limit: int = 5,
    min_score: float = 0.30,
) -> Dict[str, Any]:
    """
    Retrieve the top matching document chunks from {agent_id}_docs.

    When doc_id is given, results are additionally pinned to that
    document's latest known version. Hits scoring below ``min_score``
    (generalized from the previously hard-coded 0.30 threshold) or with
    near-empty text are dropped; the default preserves prior behavior.

    Returns:
        {"ok": bool, "chunks": [...], "collection": str, "doc_id": ...}
        (an "error" key is present on failure).
    """
    if not self.qdrant_client:
        return {"ok": False, "error": "qdrant_unavailable", "chunks": []}
    if not COHERE_API_KEY:
        return {"ok": False, "error": "cohere_unavailable", "chunks": []}

    q = (question or "").strip()
    if not q:
        return {"ok": False, "error": "empty_question", "chunks": []}

    embedding = await self.get_embedding(q[:2000])
    if not embedding:
        return {"ok": False, "error": "embedding_failed", "chunks": []}

    collection = f"{(agent_id or 'daarwizz').lower()}_docs"

    try:
        from qdrant_client.http import models as qmodels

        must_conditions = []
        if doc_id:
            # Pin to the latest ingested version of the requested document.
            latest_ver = await self._latest_document_version_no(agent_id=agent_id, doc_id=doc_id)
            must_conditions.append(
                qmodels.FieldCondition(
                    key="doc_id",
                    match=qmodels.MatchValue(value=doc_id),
                )
            )
            if latest_ver > 0:
                must_conditions.append(
                    qmodels.FieldCondition(
                        key="version_no",
                        match=qmodels.MatchValue(value=int(latest_ver)),
                    )
                )
        if dao_id:
            must_conditions.append(
                qmodels.FieldCondition(
                    key="dao_id",
                    match=qmodels.MatchValue(value=dao_id),
                )
            )
        query_filter = qmodels.Filter(must=must_conditions) if must_conditions else None

        rows = self.qdrant_client.search(
            collection_name=collection,
            query_vector=embedding,
            query_filter=query_filter,
            limit=max(1, min(int(limit or 5), 12)),
            with_payload=True,
        )
    except Exception as e:
        logger.debug(f"query_document_chunks search failed for {collection}: {e}")
        return {"ok": False, "error": "search_failed", "chunks": [], "collection": collection}

    hits: List[Dict[str, Any]] = []
    for row in rows or []:
        score = float(getattr(row, "score", 0.0) or 0.0)
        if score < float(min_score):
            continue
        payload = getattr(row, "payload", {}) or {}
        chunk_text = str(payload.get("text") or "").strip()
        if len(chunk_text) < 10:
            continue
        hits.append(
            {
                "text": chunk_text,
                "score": score,
                "doc_id": payload.get("doc_id"),
                "file_name": payload.get("file_name"),
                "chunk_index": payload.get("chunk_index"),
                "chunks_total": payload.get("chunks_total"),
                "version_no": payload.get("version_no"),
            }
        )

    return {
        "ok": bool(hits),
        "chunks": hits,
        "collection": collection,
        "doc_id": doc_id,
    }
|
||
|
||
async def store_interaction(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str,
    username: Optional[str],
    user_message: str,
    assistant_response: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Backward-compatible wrapper for older call sites.

    Delegates to store_message() with message_type='conversation'. The
    ``channel`` argument is accepted for signature compatibility but is
    not forwarded.
    """
    return await self.store_message(
        agent_id=agent_id,
        chat_id=chat_id,
        user_id=user_id,
        username=username,
        message_text=user_message,
        response_text=assistant_response,
        message_type="conversation",
        metadata=metadata,
    )
|
||
|
||
async def update_session_state(
    self,
    conversation_id: str,
    **updates
):
    """Persist whitelisted session-state fields after an interaction.

    Unknown keys in ``updates`` are silently ignored; failures are
    logged and swallowed because session state is best-effort.
    """
    if not self.pg_pool or not conversation_id:
        return

    # Whitelist of updatable columns; renamed loop variable below also
    # avoids shadowing dataclasses.field imported at module level.
    allowed_fields = (
        'last_user_id', 'last_user_nick',
        'active_topic', 'context_open',
        'last_media_handled', 'last_answer_fingerprint',
        'trust_mode', 'apprentice_mode'
    )

    try:
        async with self.pg_pool.acquire() as conn:
            # Build the SET clause dynamically; column names come only
            # from the whitelist above, so there is no injection surface.
            set_clauses = ["updated_at = NOW()"]
            params: list = [conversation_id]
            for column, value in updates.items():
                if column in allowed_fields:
                    set_clauses.append(f"{column} = ${len(params) + 1}")
                    params.append(value)

            query = f"""
                UPDATE agent_session_state
                SET {', '.join(set_clauses)}
                WHERE conversation_id = $1
            """
            await conn.execute(query, *params)

    except Exception as e:
        logger.warning(f"Session state update failed: {e}")
|
||
|
||
|
||
# Module-level singleton: importers share this single MemoryRetrieval instance.
memory_retrieval = MemoryRetrieval()
|