Files
microdao-daarion/services/router/memory_retrieval.py
Apple e9dedffa48 feat(production): sync all modified production files to git
Includes updates across gateway, router, node-worker, memory-service,
aurora-service, swapper, sofiia-console UI and node2 infrastructure:

- gateway-bot: Dockerfile, http_api.py, druid/aistalk prompts, doc_service
- services/router: main.py, router-config.yml, fabric_metrics, memory_retrieval,
  offload_client, prompt_builder
- services/node-worker: worker.py, main.py, config.py, fabric_metrics
- services/memory-service: Dockerfile, database.py, main.py, requirements
- services/aurora-service: main.py (+399), kling.py, quality_report.py
- services/swapper-service: main.py, swapper_config_node2.yaml
- services/sofiia-console: static/index.html (console UI update)
- config: agent_registry, crewai_agents/teams, router_agents
- ops/fabric_preflight.sh: updated preflight checks
- router-config.yml, docker-compose.node2.yml: infra updates
- docs: NODA1-AGENT-ARCHITECTURE, fabric_contract updated

Made-with: Cursor
2026-03-03 07:13:29 -08:00

1968 lines
78 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Universal Memory Retrieval Pipeline v4.0
This module implements the 4-layer memory retrieval for ALL agents:
- L0: Working Context (from request)
- L1: Session State Memory (SSM) - from Postgres
- L2: Platform Identity & Roles (PIR) - from Postgres + Neo4j
- L3: Organizational Memory (OM) - from Postgres + Neo4j + Qdrant
The pipeline generates a "memory brief" that is injected into the LLM context.
Collections per agent:
- {agent_id}_messages - chat history with embeddings
- {agent_id}_memory_items - facts, preferences
- {agent_id}_docs - knowledge base documents
"""
import os
import json
import logging
import re
import hashlib
from time import monotonic
from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field
from datetime import datetime
import httpx
import asyncpg
logger = logging.getLogger(__name__)

# --- Configuration (environment-driven, with development defaults) ---
# SECURITY NOTE(review): the fallback DSN below embeds a hard-coded password;
# production deployments should always supply DATABASE_URL explicitly.
POSTGRES_URL = os.getenv("DATABASE_URL", "postgresql://daarion:DaarionDB2026!@dagi-postgres:5432/daarion_memory")
QDRANT_HOST = os.getenv("QDRANT_HOST", "qdrant")
QDRANT_PORT = int(os.getenv("QDRANT_PORT", "6333"))
# Cohere key for embed-multilingual-v3.0; empty string disables embedding calls.
COHERE_API_KEY = os.getenv("COHERE_API_KEY", "")
NEO4J_BOLT_URL = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4j")
# Max open questions surfaced per (channel, chat, user, agent) scope.
PENDING_QUESTIONS_LIMIT = int(os.getenv("AGENT_PENDING_QUESTIONS_LIMIT", "5"))
# Shared agronomy library: cross-chat, anonymized plant cases (agromatrix agent only).
SHARED_AGRO_LIBRARY_ENABLED = os.getenv("AGROMATRIX_SHARED_LIBRARY_ENABLED", "true").lower() == "true"
# When review is required, unreviewed cases land in a separate "pending" collection.
SHARED_AGRO_LIBRARY_REQUIRE_REVIEW = os.getenv("AGROMATRIX_SHARED_LIBRARY_REQUIRE_REVIEW", "true").lower() == "true"
DOC_VERSION_PREVIEW_CHARS = int(os.getenv("DOC_VERSION_PREVIEW_CHARS", "240"))
# Repeated-warning throttle window in seconds; values <= 0 disable throttling.
# The trailing `or "60"` guards against an empty-string env value.
WARNING_THROTTLE_SECONDS = float(os.getenv("MEMORY_RETRIEVAL_WARNING_THROTTLE_S", "60") or "60")
# Per-key timestamp (time.monotonic) of the last emitted throttled warning.
_warning_last_ts: Dict[str, float] = {}
def _warning_throttled(key: str, message: str) -> None:
    """Log *message* as a warning for *key* at most once per throttle window.

    With throttling disabled (WARNING_THROTTLE_SECONDS <= 0) every call logs.
    Timestamps are kept per key in the module-level ``_warning_last_ts`` map.
    """
    if WARNING_THROTTLE_SECONDS <= 0:
        # Throttling disabled: always emit.
        logger.warning(message)
        return
    current = monotonic()
    previous = _warning_last_ts.get(key, 0.0)
    if current - previous < WARNING_THROTTLE_SECONDS:
        return  # still inside the quiet window for this key
    _warning_last_ts[key] = current
    logger.warning(message)
@dataclass
class UserIdentity:
    """Platform user identity (L2 layer).

    Populated by ``MemoryRetrieval.resolve_identity``; when Postgres is
    unavailable only the channel-provided fields below are filled in.
    """
    platform_user_id: Optional[str] = None  # canonical platform-wide id (stringified UUID/int), if resolved
    channel: str = "telegram"               # originating channel of the message
    channel_user_id: str = ""               # user id within that channel
    username: Optional[str] = None          # channel handle, may be None
    display_name: Optional[str] = None      # human-readable name, may be None
    roles: List[str] = field(default_factory=list)  # platform role codes; empty list means plain member
    is_mentor: bool = False                 # mentors are given elevated trust in prompts
    first_seen: Optional[datetime] = None   # not set by resolve_identity in this module
@dataclass
class SessionState:
    """L1 Session State Memory — per (channel, chat, user, agent) conversation state.

    Backed by the ``agent_session_state`` Postgres table; defaults here apply
    when the row does not exist yet or the database is unavailable.
    """
    conversation_id: Optional[str] = None   # deterministic id from _build_conversation_id
    last_addressed: bool = False            # not loaded from DB in this module
    active_topic: Optional[str] = None      # current conversation topic, if tracked
    context_open: bool = False              # whether a multi-turn context is open
    last_media_handled: bool = True         # whether the last media message was processed
    last_answer_fingerprint: Optional[str] = None  # fingerprint of the previous answer (dedup)
    trust_mode: bool = False                # trusted-group mode (more detailed answers)
    apprentice_mode: bool = False           # agent may ask clarifying questions
    pending_questions: List[str] = field(default_factory=list)  # open questions, oldest first
@dataclass
class MemoryBrief:
    """Compiled memory brief for LLM context.

    Aggregates all retrieval layers (identity, session state, facts, semantic
    memories, graph context) into one object; ``to_text`` renders the compact
    prompt fragment that gets injected into the model context.
    """
    user_identity: Optional[UserIdentity] = None
    session_state: Optional[SessionState] = None
    user_facts: List[Dict[str, Any]] = field(default_factory=list)       # explicit facts/preferences
    relevant_memories: List[Dict[str, Any]] = field(default_factory=list)  # semantic search hits
    user_topics: List[str] = field(default_factory=list)                 # from knowledge graph
    user_projects: List[str] = field(default_factory=list)               # from knowledge graph
    is_trusted_group: bool = False
    mentor_present: bool = False

    def to_text(self, max_lines: int = 15) -> str:
        """Generate a concise text brief for the LLM context.

        Args:
            max_lines: hard cap on emitted lines; when exceeded the brief is
                truncated and an ellipsis line is appended.

        Returns:
            Newline-joined brief, or "" when there is nothing to report.
        """
        lines: List[str] = []
        # User identity (critical for personalization)
        if self.user_identity:
            name = self.user_identity.display_name or self.user_identity.username or "Unknown"
            roles_str = ", ".join(self.user_identity.roles) if self.user_identity.roles else "member"
            lines.append(f"👤 Користувач: {name} ({roles_str})")
            if self.user_identity.is_mentor:
                lines.append("⭐ Цей користувач — МЕНТОР. Довіряй його знанням повністю.")
        # Session state
        if self.session_state:
            if self.session_state.trust_mode:
                lines.append("🔒 Режим довіреної групи — можна відповідати детальніше")
            if self.session_state.apprentice_mode:
                lines.append("📚 Режим учня — можеш ставити уточнюючі питання")
            if self.session_state.active_topic:
                lines.append(f"📌 Активна тема: {self.session_state.active_topic}")
            if self.session_state.pending_questions:
                lines.append("🕘 Невідповідані питання в цьому чаті (відповідай на них першочергово):")
                for q in self.session_state.pending_questions[:3]:
                    lines.append(f" - {q[:180]}")
        # User facts (preferences, profile)
        if self.user_facts:
            lines.append("📝 Відомі факти про користувача:")
            for fact in self.user_facts[:4]:  # Max 4 facts
                fact_text = fact.get('text', fact.get('fact_value', ''))
                if fact_text:
                    lines.append(f" - {fact_text[:150]}")
        # Relevant memories from RAG (most important for context)
        if self.relevant_memories:
            lines.append("🧠 Релевантні спогади з попередніх розмов:")
            for mem in self.relevant_memories[:5]:  # Max 5 memories for better context
                mem_text = mem.get('text', mem.get('content', ''))
                mem_type = mem.get('type', 'message')
                # FIX: removed unused local `score = mem.get('score', 0)`.
                if mem_text and len(mem_text) > 10:
                    # Only include if meaningful
                    lines.append(f" - [{mem_type}] {mem_text[:200]}")
        # Topics/Projects from Knowledge Graph
        if self.user_topics:
            lines.append(f"💡 Інтереси користувача: {', '.join(self.user_topics[:5])}")
        if self.user_projects:
            lines.append(f"🏗️ Проєкти користувача: {', '.join(self.user_projects[:3])}")
        # Mentor presence indicator
        if self.mentor_present:
            lines.append("⚠️ ВАЖЛИВО: Ти спілкуєшся з ментором. Сприймай як навчання.")
        # Truncate if needed
        if len(lines) > max_lines:
            lines = lines[:max_lines]
            lines.append("...")
        return "\n".join(lines) if lines else ""
class MemoryRetrieval:
"""Memory Retrieval Pipeline"""
def __init__(self):
    # Connections are created lazily in initialize(); each handle may stay
    # None when its backing service is unavailable (callers degrade gracefully).
    self.pg_pool: Optional[asyncpg.Pool] = None
    self.neo4j_driver = None   # neo4j.AsyncDriver, imported lazily in initialize()
    self.qdrant_client = None  # qdrant_client.QdrantClient (synchronous client)
    self.http_client: Optional[httpx.AsyncClient] = None  # used for Cohere embedding calls
async def initialize(self):
    """Initialize database connections.

    Every backend is optional: connection failures are logged as warnings and
    the corresponding handle is left as None, so retrieval degrades gracefully
    instead of failing at startup.
    """
    # PostgreSQL
    try:
        self.pg_pool = await asyncpg.create_pool(
            POSTGRES_URL,
            min_size=2,
            max_size=10
        )
        logger.info("✅ Memory Retrieval: PostgreSQL connected")
    except Exception as e:
        logger.warning(f"⚠️ Memory Retrieval: PostgreSQL not available: {e}")
    # Neo4j (driver imported lazily so the dependency stays optional)
    try:
        from neo4j import AsyncGraphDatabase
        self.neo4j_driver = AsyncGraphDatabase.driver(
            NEO4J_BOLT_URL,
            auth=(NEO4J_USER, NEO4J_PASSWORD)
        )
        await self.neo4j_driver.verify_connectivity()
        logger.info("✅ Memory Retrieval: Neo4j connected")
    except Exception as e:
        logger.warning(f"⚠️ Memory Retrieval: Neo4j not available: {e}")
    # Qdrant (synchronous client; its calls block the event loop briefly)
    try:
        from qdrant_client import QdrantClient
        self.qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
        self.qdrant_client.get_collections()  # Test connection
        logger.info("✅ Memory Retrieval: Qdrant connected")
    except Exception as e:
        logger.warning(f"⚠️ Memory Retrieval: Qdrant not available: {e}")
    # HTTP client for embeddings
    self.http_client = httpx.AsyncClient(timeout=30.0)
    # Best-effort creation of auxiliary tables (no-op when Postgres is down).
    await self._ensure_aux_tables()
async def close(self):
    """Release every backend connection that initialize() may have opened."""
    pool = self.pg_pool
    if pool:
        await pool.close()
    driver = self.neo4j_driver
    if driver:
        await driver.close()
    client = self.http_client
    if client:
        await client.aclose()
async def _ensure_aux_tables(self):
    """Create auxiliary tables used by agent runtime policies.

    Idempotent (CREATE ... IF NOT EXISTS); invoked once from initialize().
    Tables:
      - agent_session_state: one row per (channel, chat, user, agent) scope.
      - agent_pending_questions: open-question tracking with a uniqueness
        guard on (scope, fingerprint, status).
      - agent_document_versions: version bookkeeping for ingested documents.
    Failures are logged and swallowed so startup never hard-fails on DDL.
    """
    if not self.pg_pool:
        return
    try:
        async with self.pg_pool.acquire() as conn:
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS agent_session_state (
                    channel TEXT NOT NULL,
                    chat_id TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    agent_id TEXT NOT NULL,
                    conversation_id TEXT NOT NULL,
                    last_user_id TEXT,
                    last_user_nick TEXT,
                    active_topic TEXT,
                    context_open BOOLEAN NOT NULL DEFAULT FALSE,
                    last_media_handled BOOLEAN NOT NULL DEFAULT TRUE,
                    last_answer_fingerprint TEXT,
                    trust_mode BOOLEAN NOT NULL DEFAULT FALSE,
                    apprentice_mode BOOLEAN NOT NULL DEFAULT FALSE,
                    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    PRIMARY KEY (channel, chat_id, user_id, agent_id)
                );
                CREATE INDEX IF NOT EXISTS idx_agent_session_state_conv
                    ON agent_session_state (conversation_id);
                CREATE TABLE IF NOT EXISTS agent_pending_questions (
                    id BIGSERIAL PRIMARY KEY,
                    channel TEXT NOT NULL,
                    chat_id TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    agent_id TEXT NOT NULL,
                    question_text TEXT NOT NULL,
                    question_fingerprint TEXT NOT NULL,
                    status TEXT NOT NULL DEFAULT 'pending',
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    answered_at TIMESTAMPTZ,
                    metadata JSONB NOT NULL DEFAULT '{}'::jsonb
                );
                CREATE INDEX IF NOT EXISTS idx_agent_pending_questions_scope
                    ON agent_pending_questions (agent_id, channel, chat_id, user_id, status, created_at DESC);
                CREATE UNIQUE INDEX IF NOT EXISTS idx_agent_pending_questions_unique_open
                    ON agent_pending_questions (agent_id, channel, chat_id, user_id, question_fingerprint, status);
                CREATE TABLE IF NOT EXISTS agent_document_versions (
                    id BIGSERIAL PRIMARY KEY,
                    agent_id TEXT NOT NULL,
                    doc_id TEXT NOT NULL,
                    version_no INTEGER NOT NULL,
                    text_hash TEXT NOT NULL,
                    text_len INTEGER NOT NULL DEFAULT 0,
                    text_preview TEXT,
                    file_name TEXT,
                    dao_id TEXT,
                    user_id TEXT,
                    storage_ref TEXT,
                    source TEXT NOT NULL DEFAULT 'ingest',
                    metadata JSONB NOT NULL DEFAULT '{}'::jsonb,
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    UNIQUE (agent_id, doc_id, version_no)
                );
                CREATE INDEX IF NOT EXISTS idx_agent_document_versions_latest
                    ON agent_document_versions (agent_id, doc_id, version_no DESC);
                """
            )
    except Exception as e:
        logger.warning(f"Aux tables init failed: {e}")
# =========================================================================
# L2: Platform Identity Resolution
# =========================================================================
async def resolve_identity(
    self,
    channel: str,
    channel_user_id: str,
    username: Optional[str] = None,
    display_name: Optional[str] = None
) -> UserIdentity:
    """Resolve or create the platform user identity (L2 layer).

    Delegates to the SQL functions ``resolve_platform_user`` and ``is_mentor``
    (defined in the database, not in this module), then loads active role
    codes. On any DB failure the channel-provided identity is returned as-is.

    Returns:
        UserIdentity — always non-None; ``platform_user_id`` is None when
        resolution was not possible.
    """
    identity = UserIdentity(
        channel=channel,
        channel_user_id=channel_user_id,
        username=username,
        display_name=display_name
    )
    if not self.pg_pool:
        return identity
    try:
        async with self.pg_pool.acquire() as conn:
            # Use the resolve_platform_user function
            platform_user_id = await conn.fetchval(
                "SELECT resolve_platform_user($1, $2, $3, $4)",
                channel, channel_user_id, username, display_name
            )
            identity.platform_user_id = str(platform_user_id) if platform_user_id else None
            # Get roles (only non-revoked grants)
            if platform_user_id:
                roles = await conn.fetch("""
                    SELECT r.code FROM user_roles ur
                    JOIN platform_roles r ON ur.role_id = r.id
                    WHERE ur.platform_user_id = $1 AND ur.revoked_at IS NULL
                """, platform_user_id)
                identity.roles = [r['code'] for r in roles]
            # Check if mentor (keyed on channel user id / username, not platform id)
            is_mentor = await conn.fetchval(
                "SELECT is_mentor($1, $2)",
                channel_user_id, username
            )
            identity.is_mentor = bool(is_mentor)
    except Exception as e:
        # Debug level: missing SQL helpers is an expected degraded mode.
        logger.debug(f"Identity resolution fallback: {e}")
    return identity
# =========================================================================
# L1: Session State
# =========================================================================
async def get_session_state(
    self,
    channel: str,
    chat_id: str,
    thread_id: Optional[str] = None,  # accepted for API symmetry; not used in queries
    agent_id: Optional[str] = None,
    user_id: Optional[str] = None,
) -> SessionState:
    """Get or create session state for a conversation (L1 layer).

    When both agent_id and user_id are given, the per-scope row is read from
    ``agent_session_state`` (inserted on first contact) and open questions
    are loaded from ``agent_pending_questions``. Otherwise only a
    deterministic conversation_id is synthesized. Any DB error returns the
    defaults accumulated so far.
    """
    state = SessionState()
    if not self.pg_pool:
        return state
    try:
        async with self.pg_pool.acquire() as conn:
            if agent_id and user_id:
                conv_id = self._build_conversation_id(channel, chat_id, user_id, agent_id)
                row = await conn.fetchrow(
                    """
                    SELECT conversation_id, active_topic, context_open, last_media_handled,
                           last_answer_fingerprint, trust_mode, apprentice_mode
                    FROM agent_session_state
                    WHERE channel = $1
                      AND chat_id = $2
                      AND user_id = $3
                      AND agent_id = $4
                    """,
                    channel,
                    chat_id,
                    user_id,
                    agent_id,
                )
                if not row:
                    # First contact for this scope: create the row. DO NOTHING
                    # makes concurrent first-contact inserts race-safe.
                    await conn.execute(
                        """
                        INSERT INTO agent_session_state
                            (channel, chat_id, user_id, agent_id, conversation_id)
                        VALUES ($1, $2, $3, $4, $5)
                        ON CONFLICT (channel, chat_id, user_id, agent_id) DO NOTHING
                        """,
                        channel,
                        chat_id,
                        user_id,
                        agent_id,
                        conv_id,
                    )
                    state.conversation_id = conv_id
                else:
                    # Prefer the stored id; fall back to the derived one.
                    state.conversation_id = str(row.get("conversation_id") or conv_id)
                    state.active_topic = row.get("active_topic")
                    state.context_open = bool(row.get("context_open", False))
                    state.last_media_handled = bool(row.get("last_media_handled", True))
                    state.last_answer_fingerprint = row.get("last_answer_fingerprint")
                    state.trust_mode = bool(row.get("trust_mode", False))
                    state.apprentice_mode = bool(row.get("apprentice_mode", False))
            else:
                # Incomplete scope: synthesize an id with placeholder parts.
                state.conversation_id = self._build_conversation_id(
                    channel,
                    chat_id,
                    user_id or "unknown",
                    agent_id or "agent",
                )
            if agent_id and user_id:
                # Oldest-first so the agent answers in arrival order.
                pending_rows = await conn.fetch(
                    """
                    SELECT question_text
                    FROM agent_pending_questions
                    WHERE channel = $1
                      AND chat_id = $2
                      AND user_id = $3
                      AND agent_id = $4
                      AND status = 'pending'
                    ORDER BY created_at ASC
                    LIMIT $5
                    """,
                    channel,
                    chat_id,
                    user_id,
                    agent_id,
                    PENDING_QUESTIONS_LIMIT,
                )
                state.pending_questions = [
                    str(r.get("question_text") or "").strip()
                    for r in pending_rows
                    if str(r.get("question_text") or "").strip()
                ]
    except Exception as e:
        logger.warning(f"Session state retrieval failed: {e}")
    return state
# =========================================================================
# L3: Memory Retrieval (Facts + Semantic)
# =========================================================================
async def get_user_facts(
    self,
    platform_user_id: Optional[str],
    limit: int = 5
) -> List[Dict[str, Any]]:
    """Get the user's explicit facts from PostgreSQL (L3 layer).

    Only live (non-archived, non-expired) rows of type 'preference' or
    'profile_fact' are returned, highest-confidence first.

    Returns:
        List of row dicts (type, text, summary, confidence, visibility);
        empty when the pool or user id is missing, or on error.
    """
    if not self.pg_pool or not platform_user_id:
        return []
    try:
        async with self.pg_pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT type, text, summary, confidence, visibility
                FROM helion_memory_items
                WHERE platform_user_id = $1
                  AND archived_at IS NULL
                  AND (expires_at IS NULL OR expires_at > NOW())
                  AND type IN ('preference', 'profile_fact')
                ORDER BY confidence DESC, updated_at DESC
                LIMIT $2
            """, platform_user_id, limit)
            return [dict(r) for r in rows]
    except Exception as e:
        logger.warning(f"User facts retrieval failed: {e}")
        return []
async def get_embedding(self, text: str) -> Optional[List[float]]:
    """Get a query embedding from the Cohere API.

    Uses embed-multilingual-v3.0 with input_type=search_query.

    Returns:
        The embedding vector for *text*, or None when the API key/HTTP client
        is missing, the request fails, or the response has no embeddings.
    """
    if not COHERE_API_KEY or not self.http_client:
        return None
    try:
        response = await self.http_client.post(
            "https://api.cohere.ai/v1/embed",
            headers={
                "Authorization": f"Bearer {COHERE_API_KEY}",
                "Content-Type": "application/json"
            },
            json={
                "texts": [text],
                "model": "embed-multilingual-v3.0",
                "input_type": "search_query",
                "truncate": "END"
            }
        )
        if response.status_code == 200:
            data = response.json()
            embeddings = data.get("embeddings") or []
            # FIX: the previous `data.get("embeddings", [[]])[0]` raised
            # IndexError when the API returned an empty "embeddings" array.
            if embeddings:
                return embeddings[0]
            _warning_throttled(
                "cohere_embed_empty",
                "Cohere embed returned an empty embeddings array",
            )
        else:
            # FIX: non-200 responses used to be silently ignored.
            _warning_throttled(
                "cohere_embed_http_error",
                f"Cohere embed request failed: HTTP {response.status_code}",
            )
    except Exception as e:
        logger.warning(f"Embedding generation failed: {e}")
    return None
async def search_memories(
    self,
    query: str,
    agent_id: str = "helion",
    platform_user_id: Optional[str] = None,
    chat_id: Optional[str] = None,
    user_id: Optional[str] = None,  # accepted but not used in any filter below
    visibility: str = "platform",   # NOTE(review): currently unused — confirm intent
    limit: int = 5
) -> List[Dict[str, Any]]:
    """Semantic search in Qdrant across agent-specific collections (L3 layer).

    Queries up to four collections ({agent_id}_memory_items, _messages, _docs,
    and the shared agronomy library for agromatrix plant queries), applies
    per-source relevance thresholds, then merges, sorts by score, dedupes by
    normalized text prefix, and returns at most *limit* dicts with keys
    text/type/score/source (+confidence for memory items).
    """
    if not self.qdrant_client:
        return []
    # Get embedding; without one no vector search is possible.
    embedding = await self.get_embedding(query)
    if not embedding:
        return []
    all_results = []
    q = (query or "").lower()
    # If user explicitly asks about documents/catalogs, prefer knowledge base docs over chat snippets.
    is_doc_query = any(k in q for k in ["pdf", "каталог", "каталоз", "документ", "файл", "стор", "page", "pages"])
    # Simple keyword gate to avoid irrelevant chat snippets dominating doc queries.
    # Example: when asking "з каталогу Defenda 2026 ... гліфосат", old "Бокаші" messages may match too well.
    topic_keywords: List[str] = []
    for kw in ["defenda", "ifagri", "bayer", "гліфосат", "glyphos", "глифос", "npk", "мінерал", "добрив", "гербіц", "фунгіц", "інсектиц"]:
        if kw in q:
            topic_keywords.append(kw)
    # Dynamic collection names based on agent_id
    memory_items_collection = f"{agent_id}_memory_items"
    messages_collection = f"{agent_id}_messages"
    docs_collection = f"{agent_id}_docs"
    try:
        from qdrant_client.http import models as qmodels
        # Search 1: {agent_id}_memory_items (facts, preferences)
        try:
            must_conditions = []
            if platform_user_id:
                # Scope facts to this user only.
                must_conditions.append(
                    qmodels.FieldCondition(
                        key="platform_user_id",
                        match=qmodels.MatchValue(value=platform_user_id)
                    )
                )
            search_filter = qmodels.Filter(must=must_conditions) if must_conditions else None
            results = self.qdrant_client.search(
                collection_name=memory_items_collection,
                query_vector=embedding,
                query_filter=search_filter,
                limit=limit,
                with_payload=True
            )
            for r in results:
                if r.score > 0.3:  # Threshold for relevance
                    all_results.append({
                        "text": r.payload.get("text", ""),
                        "type": r.payload.get("type", "fact"),
                        "confidence": r.payload.get("confidence", 0.5),
                        "score": r.score,
                        "source": "memory_items"
                    })
        except Exception as e:
            # Missing collections are normal for new agents → debug level.
            logger.debug(f"{memory_items_collection} search: {e}")
        # Search 2: {agent_id}_messages (chat history)
        try:
            msg_filter = None
            if chat_id:
                # Payload schema differs across ingesters: some use chat_id, others channel_id.
                msg_filter = qmodels.Filter(
                    should=[
                        qmodels.FieldCondition(key="chat_id", match=qmodels.MatchValue(value=str(chat_id))),
                        qmodels.FieldCondition(key="channel_id", match=qmodels.MatchValue(value=str(chat_id))),
                    ]
                )
            results = self.qdrant_client.search(
                collection_name=messages_collection,
                query_vector=embedding,
                query_filter=msg_filter,
                limit=limit,
                with_payload=True
            )
            for r in results:
                # Higher threshold for messages; even higher when user asks about docs to avoid pulling old chatter.
                msg_thresh = 0.5 if is_doc_query else 0.4
                if r.score > msg_thresh:
                    text = self._extract_message_text(r.payload)
                    # Skip very short or system messages
                    if len(text) > 20 and not text.startswith("<"):
                        if is_doc_query and topic_keywords:
                            # Keyword gate: drop chat hits unrelated to the doc topic.
                            tl = text.lower()
                            if not any(k in tl for k in topic_keywords):
                                continue
                        all_results.append({
                            "text": text,
                            "type": "message",
                            "score": r.score,
                            "source": "messages"
                        })
        except Exception as e:
            logger.debug(f"{messages_collection} search: {e}")
        # Search 3: {agent_id}_docs (knowledge base) - optional
        try:
            results = self.qdrant_client.search(
                collection_name=docs_collection,
                query_vector=embedding,
                limit=6 if is_doc_query else 3,  # Pull more docs for explicit doc queries
                with_payload=True
            )
            for r in results:
                # When user asks about PDF/catalogs, relax threshold so docs show up more reliably.
                doc_thresh = 0.35 if is_doc_query else 0.5
                if r.score > doc_thresh:
                    text = r.payload.get("text", r.payload.get("content", ""))
                    if len(text) > 30:
                        all_results.append({
                            "text": text[:500],  # Truncate long docs
                            "type": "knowledge",
                            # Slightly boost docs for doc queries so they win vs chat snippets.
                            "score": (r.score + 0.12) if is_doc_query else r.score,
                            "source": "docs"
                        })
        except Exception as e:
            logger.debug(f"{docs_collection} search: {e}")
        # Search 4: shared agronomy memory (reviewed, cross-chat, anonymized)
        if (
            SHARED_AGRO_LIBRARY_ENABLED
            and agent_id == "agromatrix"
            and self._is_plant_query(query)
        ):
            try:
                results = self.qdrant_client.search(
                    collection_name="agromatrix_shared_library",
                    query_vector=embedding,
                    limit=3,
                    with_payload=True
                )
                for r in results:
                    if r.score > 0.45:
                        text = str(r.payload.get("text") or "").strip()
                        if len(text) > 20:
                            all_results.append({
                                "text": text[:500],
                                "type": "shared_agro_fact",
                                # Small boost: curated cases outrank raw chat.
                                "score": r.score + 0.05,
                                "source": "shared_agronomy_library"
                            })
            except Exception as e:
                logger.debug(f"agromatrix_shared_library search: {e}")
        # Sort by score and deduplicate
        all_results.sort(key=lambda x: x.get("score", 0), reverse=True)
        # Remove duplicates based on text similarity (normalized 220-char prefix)
        seen_texts = set()
        unique_results = []
        for r in all_results:
            text_key = self._canonical_text_key(r.get("text", ""))
            if text_key not in seen_texts:
                seen_texts.add(text_key)
                unique_results.append(r)
        return unique_results[:limit]
    except Exception as e:
        logger.warning(f"Memory search failed for {agent_id}: {e}")
        return []
@staticmethod
def _extract_message_text(payload: Dict[str, Any]) -> str:
"""
Normalize text across both payload schemas:
- memory-service: content/text (+ role/channel_id)
- router: user_message + assistant_response (+ chat_id)
"""
if not payload:
return ""
text = (payload.get("text") or payload.get("content") or "").strip()
if text:
lower = text.lower()
marker = "\n\nassistant:"
idx = lower.rfind(marker)
if lower.startswith("user:") and idx != -1:
assistant_text = text[idx + len(marker):].strip()
if assistant_text:
return assistant_text
return text
user_text = (payload.get("user_message") or "").strip()
assistant_text = (payload.get("assistant_response") or "").strip()
if user_text and assistant_text:
return f"User: {user_text}\n\nAssistant: {assistant_text}"
return user_text or assistant_text
@staticmethod
def _canonical_text_key(text: str) -> str:
if not text:
return ""
normalized = re.sub(r"\s+", " ", text.strip().lower())
return normalized[:220]
@staticmethod
def _is_plant_query(text: str) -> bool:
q = (text or "").lower()
if not q:
return False
markers = [
"рослин", "культур", "лист", "стебл", "бур'ян", "хвороб", "шкідник",
"what plant", "identify plant", "crop", "species", "leaf", "stem",
"что за растение", "культура", "листок", "фото рослини"
]
return any(m in q for m in markers)
@staticmethod
def _question_fingerprint(question_text: str) -> str:
normalized = re.sub(r"\s+", " ", (question_text or "").strip().lower())
return hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:16]
@staticmethod
def _build_conversation_id(channel: str, chat_id: str, user_id: str, agent_id: str) -> str:
seed = f"{channel}:{chat_id}:{user_id}:{agent_id}"
return hashlib.sha1(seed.encode("utf-8")).hexdigest()[:24]
async def get_user_graph_context(
    self,
    username: Optional[str] = None,
    telegram_user_id: Optional[str] = None,
    agent_id: str = "helion"
) -> Dict[str, List[str]]:
    """Get the user's topics and projects from Neo4j, filtered by agent_id.

    Relationship filters keep agent contexts separate (role mixing), while
    still accepting legacy relationships that carry no agent_id.

    Returns:
        {"topics": [...], "projects": [...]} — both empty when Neo4j is
        unavailable, no user key was given, or on error.
    """
    context = {"topics": [], "projects": []}
    if not self.neo4j_driver:
        return context
    if not username and not telegram_user_id:
        return context
    try:
        async with self.neo4j_driver.session() as session:
            # Find user and their context, FILTERED BY AGENT_ID
            # Support both username formats: with and without @
            username_with_at = f"@{username}" if username and not username.startswith("@") else username
            username_without_at = username[1:] if username and username.startswith("@") else username
            # Query with agent_id filter on relationships
            query = """
                MATCH (u:User)
                WHERE u.username IN [$username, $username_with_at, $username_without_at]
                   OR u.telegram_user_id = $telegram_user_id
                   OR u.telegram_id = $telegram_user_id
                OPTIONAL MATCH (u)-[r1:ASKED_ABOUT]->(t:Topic)
                WHERE r1.agent_id = $agent_id OR r1.agent_id IS NULL
                OPTIONAL MATCH (u)-[r2:WORKS_ON]->(p:Project)
                WHERE r2.agent_id = $agent_id OR r2.agent_id IS NULL
                RETURN
                    collect(DISTINCT t.name) as topics,
                    collect(DISTINCT p.name) as projects
            """
            result = await session.run(
                query,
                username=username,
                username_with_at=username_with_at,
                username_without_at=username_without_at,
                telegram_user_id=telegram_user_id,
                agent_id=agent_id
            )
            record = await result.single()
            if record:
                # Filter out None entries produced by the OPTIONAL MATCHes.
                context["topics"] = [t for t in record["topics"] if t]
                context["projects"] = [p for p in record["projects"] if p]
                logger.debug(f"Graph context for {agent_id}: topics={len(context['topics'])}, projects={len(context['projects'])}")
    except Exception as e:
        logger.warning(f"Graph context retrieval failed for {agent_id}: {e}")
    return context
# =========================================================================
# Main Retrieval Pipeline
# =========================================================================
async def retrieve(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str = "helion",
    username: Optional[str] = None,
    display_name: Optional[str] = None,
    message: Optional[str] = None,
    thread_id: Optional[str] = None
) -> MemoryBrief:
    """
    Main retrieval pipeline for any agent.
    1. Resolve user identity (L2)
    2. Get session state (L1)
    3. Get user facts (L3)
    4. Search relevant memories if message provided (L3)
    5. Get graph context (L3)
    6. Compile memory brief

    Args:
        agent_id: Agent identifier for collection routing (e.g. "helion", "nutra", "greenfood")
        message: current user message; when None the semantic search step is skipped.

    Returns:
        MemoryBrief — never None; empty sections simply stay at their defaults.
    """
    brief = MemoryBrief()
    # L2: Identity (user_id here is the channel-level user id)
    identity = await self.resolve_identity(channel, user_id, username, display_name)
    brief.user_identity = identity
    # L1: Session State
    session = await self.get_session_state(
        channel,
        chat_id,
        thread_id,
        agent_id=agent_id,
        user_id=user_id,
    )
    brief.session_state = session
    brief.is_trusted_group = session.trust_mode
    # L3: User Facts (requires a resolved platform identity)
    if identity.platform_user_id:
        facts = await self.get_user_facts(identity.platform_user_id)
        brief.user_facts = facts
    # L3: Semantic Search (if message provided) - agent-specific collections
    if message:
        memories = await self.search_memories(
            query=message,
            agent_id=agent_id,
            platform_user_id=identity.platform_user_id,
            chat_id=chat_id,
            user_id=user_id,
            limit=5
        )
        brief.relevant_memories = memories
    # L3: Graph Context (filtered by agent_id to prevent role mixing)
    graph_ctx = await self.get_user_graph_context(username, user_id, agent_id)
    brief.user_topics = graph_ctx.get("topics", [])
    brief.user_projects = graph_ctx.get("projects", [])
    # Check for mentor presence
    brief.mentor_present = identity.is_mentor
    return brief
# =========================================================================
# Memory Storage (write path)
# =========================================================================
async def store_message(
    self,
    agent_id: str,
    user_id: str,
    username: Optional[str],
    message_text: str,
    response_text: str,
    chat_id: str,
    message_type: str = "conversation",
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Store a message exchange in agent-specific Qdrant collection.
    This enables semantic retrieval of past conversations per agent.

    Creates the {agent_id}_messages collection on first use (1024-dim cosine,
    matching Cohere embed-multilingual-v3.0). Returns True on success; all
    failures are logged and reported as False (best-effort write path).
    """
    if not self.qdrant_client or not COHERE_API_KEY:
        logger.debug(f"Cannot store message: qdrant={bool(self.qdrant_client)}, cohere={bool(COHERE_API_KEY)}")
        return False
    messages_collection = f"{agent_id}_messages"
    try:
        from qdrant_client.http import models as qmodels
        import uuid
        # Ensure collection exists (get_collection raises when missing)
        try:
            self.qdrant_client.get_collection(messages_collection)
        except Exception:
            # Create collection with Cohere embed-multilingual-v3.0 dimensions (1024)
            self.qdrant_client.create_collection(
                collection_name=messages_collection,
                vectors_config=qmodels.VectorParams(
                    size=1024,
                    distance=qmodels.Distance.COSINE
                )
            )
            logger.info(f"✅ Created collection: {messages_collection}")
        # Combine user message and response for better context retrieval
        combined_text = f"User: {message_text}\n\nAssistant: {response_text}"
        # Get embedding
        embedding = await self.get_embedding(combined_text[:2000])  # Truncate for API limits
        if not embedding:
            logger.warning(f"Failed to get embedding for message storage")
            return False
        # Store in Qdrant
        point_id = str(uuid.uuid4())
        payload = {
            "text": combined_text[:5000],  # Limit payload size
            "user_message": message_text[:2000],
            "assistant_response": response_text[:3000],
            "user_id": user_id,
            "username": username,
            "chat_id": chat_id,
            "agent_id": agent_id,
            "type": message_type,
            # NOTE(review): datetime.utcnow() is naive UTC and deprecated in
            # Python 3.12+; consider datetime.now(timezone.utc) going forward.
            "timestamp": datetime.utcnow().isoformat()
        }
        if metadata and isinstance(metadata, dict):
            payload["metadata"] = metadata
        self.qdrant_client.upsert(
            collection_name=messages_collection,
            points=[
                qmodels.PointStruct(
                    id=point_id,
                    vector=embedding,
                    payload=payload
                )
            ]
        )
        # Optional shared agronomy memory:
        # - never stores user/chat identifiers
        # - supports review gate (pending vs approved)
        if (
            SHARED_AGRO_LIBRARY_ENABLED
            and agent_id == "agromatrix"
            and message_type in {"vision", "conversation"}
            and isinstance(metadata, dict)
            and metadata.get("deterministic_plant_id")
        ):
            await self._store_shared_agronomy_memory(
                message_text=message_text,
                response_text=response_text,
                metadata=metadata,
            )
        logger.debug(f"✅ Stored message in {messages_collection}: {point_id[:8]}...")
        return True
    except Exception as e:
        logger.warning(f"Failed to store message in {messages_collection}: {e}")
        return False
async def _store_shared_agronomy_memory(
    self,
    message_text: str,
    response_text: str,
    metadata: Dict[str, Any],
) -> bool:
    """Store an anonymized plant case in the shared agronomy library.

    No user/chat identifiers are persisted. Unreviewed cases go to the
    'agromatrix_shared_pending' collection when review is required; reviewed
    (mentor-confirmed) cases go straight to 'agromatrix_shared_library'.
    Returns True on success; failures are logged at debug level only.
    """
    if not self.qdrant_client or not COHERE_API_KEY:
        return False
    try:
        from qdrant_client.http import models as qmodels
        import uuid
        reviewed = bool(metadata.get("mentor_confirmed") or metadata.get("reviewed"))
        collection = "agromatrix_shared_library"
        if SHARED_AGRO_LIBRARY_REQUIRE_REVIEW and not reviewed:
            # Review gate: park unreviewed cases separately.
            collection = "agromatrix_shared_pending"
        try:
            self.qdrant_client.get_collection(collection)
        except Exception:
            # Same vector config as the message collections (Cohere v3, 1024-dim).
            self.qdrant_client.create_collection(
                collection_name=collection,
                vectors_config=qmodels.VectorParams(
                    size=1024,
                    distance=qmodels.Distance.COSINE,
                ),
            )
        # Compact, anonymized textual representation of the case.
        compact = (
            f"Plant case\nQuestion: {message_text[:800]}\n"
            f"Answer: {response_text[:1200]}\n"
            f"Candidates: {json.dumps(metadata.get('candidates', []), ensure_ascii=False)[:1200]}"
        )
        embedding = await self.get_embedding(compact[:2000])
        if not embedding:
            return False
        payload = {
            "text": compact[:3000],
            "type": "plant_case",
            "deterministic_plant_id": True,
            "decision": metadata.get("decision"),
            "confidence_threshold": metadata.get("confidence_threshold"),
            "candidates": metadata.get("candidates", [])[:5],
            "reviewed": reviewed,
            # NOTE(review): naive UTC timestamp (datetime.utcnow is deprecated).
            "timestamp": datetime.utcnow().isoformat(),
        }
        self.qdrant_client.upsert(
            collection_name=collection,
            points=[qmodels.PointStruct(id=str(uuid.uuid4()), vector=embedding, payload=payload)],
        )
        return True
    except Exception as e:
        logger.debug(f"Shared agronomy memory store failed: {e}")
        return False
async def register_pending_question(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str,
    question_text: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Record an unanswered question for later follow-up.

    Deduplicates per scope via the question fingerprint + unique index
    (ON CONFLICT DO NOTHING), then trims the open set to the newest
    PENDING_QUESTIONS_LIMIT entries, dismissing the overflow.
    Returns True when the statements executed, False on error/empty input.
    """
    if not self.pg_pool:
        return False
    text = (question_text or "").strip()
    if not text:
        return False
    fp = self._question_fingerprint(text)
    try:
        async with self.pg_pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO agent_pending_questions
                    (channel, chat_id, user_id, agent_id, question_text, question_fingerprint, status, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7::jsonb)
                ON CONFLICT (agent_id, channel, chat_id, user_id, question_fingerprint, status)
                DO NOTHING
                """,
                channel,
                chat_id,
                user_id,
                agent_id,
                text[:1200],
                fp,
                json.dumps(metadata or {}, ensure_ascii=False),
            )
            # Keep only last N open items.
            await conn.execute(
                """
                WITH ranked AS (
                    SELECT id, ROW_NUMBER() OVER (
                        PARTITION BY channel, chat_id, user_id, agent_id, status
                        ORDER BY created_at DESC
                    ) AS rn
                    FROM agent_pending_questions
                    WHERE channel = $1
                      AND chat_id = $2
                      AND user_id = $3
                      AND agent_id = $4
                      AND status = 'pending'
                )
                UPDATE agent_pending_questions p
                SET status = 'dismissed',
                    answered_at = NOW(),
                    metadata = COALESCE(p.metadata, '{}'::jsonb) || '{"reason":"overflow_trim"}'::jsonb
                FROM ranked r
                WHERE p.id = r.id
                  AND r.rn > $5
                """,
                channel,
                chat_id,
                user_id,
                agent_id,
                max(1, PENDING_QUESTIONS_LIMIT),
            )
        return True
    except Exception as e:
        # Throttled: this path can fire on every message when the DB is down.
        _warning_throttled("register_pending_question_failed", f"register_pending_question failed: {e}")
        return False
    async def resolve_pending_question(
        self,
        channel: str,
        chat_id: str,
        user_id: str,
        agent_id: str,
        answer_text: Optional[str] = None,
        reason: str = "answered",
    ) -> bool:
        """
        Close the oldest pending question for this (channel, chat, user, agent) scope.

        The whole decision is made in a single SQL statement:
        - CTE ``target`` selects the oldest row still in status 'pending';
        - CTE ``decision`` picks the next status: a caller-supplied 'dismissed'
          always wins; otherwise, if an identical question (same fingerprint)
          was already answered earlier, the row is dismissed with resolution
          reason 'duplicate_answered'; else it is marked 'answered' with
          ``reason``.
        The resolution reason plus a fingerprint of the answer text are merged
        into the row's JSONB metadata.

        Returns True when a pending row was updated; False when there is no DB
        pool, no pending row matched, or the query failed.
        """
        if not self.pg_pool:
            return False
        try:
            async with self.pg_pool.acquire() as conn:
                row = await conn.fetchrow(
                    """
                    WITH target AS (
                        SELECT id, question_fingerprint
                        FROM agent_pending_questions
                        WHERE channel = $1
                          AND chat_id = $2
                          AND user_id = $3
                          AND agent_id = $4
                          AND status = 'pending'
                        ORDER BY created_at ASC
                        LIMIT 1
                    ), decision AS (
                        SELECT
                            t.id,
                            CASE
                                WHEN $5 = 'dismissed' THEN 'dismissed'
                                WHEN EXISTS (
                                    SELECT 1
                                    FROM agent_pending_questions q
                                    WHERE q.channel = $1
                                      AND q.chat_id = $2
                                      AND q.user_id = $3
                                      AND q.agent_id = $4
                                      AND q.status = 'answered'
                                      AND q.question_fingerprint = t.question_fingerprint
                                ) THEN 'dismissed'
                                ELSE 'answered'
                            END AS next_status,
                            CASE
                                WHEN $5 = 'dismissed' THEN $5
                                WHEN EXISTS (
                                    SELECT 1
                                    FROM agent_pending_questions q
                                    WHERE q.channel = $1
                                      AND q.chat_id = $2
                                      AND q.user_id = $3
                                      AND q.agent_id = $4
                                      AND q.status = 'answered'
                                      AND q.question_fingerprint = t.question_fingerprint
                                ) THEN 'duplicate_answered'
                                ELSE $5
                            END AS resolution_reason
                        FROM target t
                    )
                    UPDATE agent_pending_questions p
                    SET status = d.next_status,
                        answered_at = NOW(),
                        metadata = COALESCE(p.metadata, '{}'::jsonb)
                            || jsonb_build_object(
                                'resolution_reason', d.resolution_reason,
                                'answer_fingerprint', COALESCE($6, '')
                            )
                    FROM decision d
                    WHERE p.id = d.id
                    RETURNING p.id
                    """,
                    channel,
                    chat_id,
                    user_id,
                    agent_id,
                    reason,
                    # $6: fingerprint of the answer text ('' when no answer given).
                    self._question_fingerprint(answer_text or "") if answer_text else "",
                )
                # fetchrow yields the updated row (RETURNING p.id) or None.
                return bool(row)
        except Exception as e:
            _warning_throttled("resolve_pending_question_failed", f"resolve_pending_question failed: {e}")
            return False
@staticmethod
def _to_qdrant_point_id(raw_id: Any) -> Any:
if isinstance(raw_id, int):
return raw_id
if isinstance(raw_id, float) and raw_id.is_integer():
return int(raw_id)
if isinstance(raw_id, str):
v = raw_id.strip()
if not v:
return raw_id
if v.isdigit():
try:
return int(v)
except Exception:
return v
return v
return raw_id
async def list_shared_pending_cases(self, limit: int = 50) -> List[Dict[str, Any]]:
if not self.qdrant_client or not SHARED_AGRO_LIBRARY_ENABLED:
return []
size = max(1, min(int(limit or 50), 200))
try:
points, _ = self.qdrant_client.scroll(
collection_name="agromatrix_shared_pending",
limit=size,
with_payload=True,
with_vectors=False,
)
except Exception as e:
logger.debug(f"list_shared_pending_cases failed: {e}")
return []
items: List[Dict[str, Any]] = []
for p in points or []:
payload = getattr(p, "payload", {}) or {}
text = str(payload.get("text") or "").strip()
timestamp = payload.get("timestamp") or ""
candidates = payload.get("candidates") if isinstance(payload.get("candidates"), list) else []
items.append(
{
"point_id": str(getattr(p, "id", "")),
"timestamp": timestamp,
"decision": payload.get("decision"),
"reviewed": bool(payload.get("reviewed")),
"excerpt": text[:240],
"candidates": candidates[:5],
}
)
items.sort(key=lambda x: x.get("timestamp") or "", reverse=True)
return items
    async def review_shared_pending_case(
        self,
        point_id: str,
        approve: bool,
        reviewer: Optional[str] = None,
        note: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Approve or reject one case from the shared agronomy pending queue.

        On approval, the point (vector plus review-annotated payload) is
        promoted into the 'agromatrix_shared_library' collection, which is
        created on demand with the vector's own dimensionality. In both
        outcomes the original point is deleted from 'agromatrix_shared_pending'.

        Returns {"ok": True, "approved": ..., "point_id": ...,
        "library_point_id": ...} on success, or {"ok": False, "error": <reason>}
        when Qdrant is unavailable, the point is missing, no vector can be
        obtained, or any Qdrant call fails.
        """
        if not self.qdrant_client:
            return {"ok": False, "error": "qdrant_unavailable"}
        try:
            from qdrant_client.http import models as qmodels
            import uuid
            pid = self._to_qdrant_point_id(point_id)
            records = self.qdrant_client.retrieve(
                collection_name="agromatrix_shared_pending",
                ids=[pid],
                with_payload=True,
                with_vectors=True,
            )
            if not records:
                return {"ok": False, "error": "not_found"}
            point = records[0]
            payload = dict(getattr(point, "payload", {}) or {})
            now_iso = datetime.utcnow().isoformat()
            # Stamp the review verdict into the payload before promotion/deletion.
            payload["reviewed"] = bool(approve)
            payload["review"] = {
                "reviewer": (reviewer or "system")[:120],
                "approved": bool(approve),
                "note": (note or "")[:500],
                "reviewed_at": now_iso,
            }
            library_point_id: Optional[str] = None
            if approve:
                vector = getattr(point, "vector", None)
                if isinstance(vector, dict):
                    # Named vectors mode: pick first vector value.
                    vector = next(iter(vector.values()), None)
                if not vector and COHERE_API_KEY:
                    # Fall back to re-embedding the stored text when no vector came back.
                    basis = str(payload.get("text") or payload.get("assistant_response") or "")[:2000]
                    vector = await self.get_embedding(basis)
                if not vector:
                    return {"ok": False, "error": "missing_vector"}
                # Ensure the library collection exists; size follows this vector.
                try:
                    self.qdrant_client.get_collection("agromatrix_shared_library")
                except Exception:
                    self.qdrant_client.create_collection(
                        collection_name="agromatrix_shared_library",
                        vectors_config=qmodels.VectorParams(
                            size=len(vector),
                            distance=qmodels.Distance.COSINE,
                        ),
                    )
                library_point_id = str(uuid.uuid4())
                payload["approved_at"] = now_iso
                self.qdrant_client.upsert(
                    collection_name="agromatrix_shared_library",
                    points=[
                        qmodels.PointStruct(
                            id=library_point_id,
                            vector=vector,
                            payload=payload,
                        )
                    ],
                )
            # Remove from the pending queue regardless of the verdict.
            self.qdrant_client.delete(
                collection_name="agromatrix_shared_pending",
                points_selector=qmodels.PointIdsList(points=[pid]),
            )
            return {
                "ok": True,
                "approved": bool(approve),
                "point_id": str(getattr(point, "id", point_id)),
                "library_point_id": library_point_id,
            }
        except Exception as e:
            logger.warning(f"review_shared_pending_case failed: {e}")
            return {"ok": False, "error": str(e)}
def _chunk_document_text(
self,
text: str,
chunk_chars: int = 1200,
overlap_chars: int = 180,
) -> List[str]:
"""
Split document text into overlap-aware chunks for RAG indexing.
Keeps paragraph structure when possible.
"""
raw = re.sub(r"\r\n?", "\n", text or "").strip()
if not raw:
return []
paragraphs = [p.strip() for p in re.split(r"\n{2,}", raw) if p and p.strip()]
if not paragraphs:
return []
chunks: List[str] = []
current = ""
max_hard = max(chunk_chars, 600)
def _push_current() -> None:
nonlocal current
if current and len(current.strip()) >= 20:
chunks.append(current.strip())
current = ""
for para in paragraphs:
if len(para) > max_hard * 2:
_push_current()
i = 0
step = max_hard - max(80, min(overlap_chars, max_hard // 2))
while i < len(para):
part = para[i : i + max_hard]
if len(part.strip()) >= 20:
chunks.append(part.strip())
i += max(1, step)
continue
candidate = f"{current}\n\n{para}".strip() if current else para
if len(candidate) <= max_hard:
current = candidate
continue
_push_current()
if overlap_chars > 0 and chunks:
tail = chunks[-1][-overlap_chars:]
current = f"{tail}\n\n{para}".strip()
if len(current) > max_hard:
_push_current()
current = para
else:
current = para
_push_current()
return chunks
async def _next_document_version_no(
self,
agent_id: str,
doc_id: str,
) -> int:
if self.pg_pool:
try:
async with self.pg_pool.acquire() as conn:
value = await conn.fetchval(
"""
SELECT COALESCE(MAX(version_no), 0) + 1
FROM agent_document_versions
WHERE agent_id = $1
AND doc_id = $2
""",
(agent_id or "").lower(),
doc_id,
)
return max(1, int(value or 1))
except Exception as e:
logger.warning(f"next_document_version_no(pg) failed: {e}")
# Fallback: infer from existing chunk payloads in Qdrant.
if self.qdrant_client:
try:
from qdrant_client.http import models as qmodels
collection = f"{(agent_id or 'daarwizz').lower()}_docs"
points, _ = self.qdrant_client.scroll(
collection_name=collection,
scroll_filter=qmodels.Filter(
must=[
qmodels.FieldCondition(
key="doc_id",
match=qmodels.MatchValue(value=doc_id),
)
]
),
limit=256,
with_payload=True,
)
current_max = 0
for p in points or []:
payload = getattr(p, "payload", {}) or {}
ver = payload.get("version_no")
if isinstance(ver, int):
current_max = max(current_max, ver)
elif isinstance(ver, str) and ver.isdigit():
current_max = max(current_max, int(ver))
return current_max + 1 if current_max > 0 else 1
except Exception as e:
logger.debug(f"next_document_version_no(qdrant) fallback failed: {e}")
return 1
async def _latest_document_version_no(
self,
agent_id: str,
doc_id: str,
) -> int:
nxt = await self._next_document_version_no(agent_id=agent_id, doc_id=doc_id)
return max(0, int(nxt) - 1)
async def _record_document_version(
self,
agent_id: str,
doc_id: str,
version_no: int,
text: str,
file_name: Optional[str] = None,
dao_id: Optional[str] = None,
user_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
source: str = "ingest",
storage_ref: Optional[str] = None,
) -> Dict[str, Any]:
text_body = (text or "").strip()
text_hash = hashlib.sha256(text_body.encode("utf-8")).hexdigest() if text_body else ""
text_len = len(text_body)
preview = text_body[:DOC_VERSION_PREVIEW_CHARS] if text_body else ""
payload = metadata if isinstance(metadata, dict) else {}
if not self.pg_pool:
return {"ok": True, "version_no": int(version_no), "id": None}
try:
async with self.pg_pool.acquire() as conn:
row = await conn.fetchrow(
"""
INSERT INTO agent_document_versions
(agent_id, doc_id, version_no, text_hash, text_len, text_preview,
file_name, dao_id, user_id, storage_ref, source, metadata)
VALUES
($1, $2, $3, $4, $5, $6,
$7, $8, $9, $10, $11, $12::jsonb)
ON CONFLICT (agent_id, doc_id, version_no)
DO UPDATE SET
text_hash = EXCLUDED.text_hash,
text_len = EXCLUDED.text_len,
text_preview = EXCLUDED.text_preview,
file_name = EXCLUDED.file_name,
dao_id = EXCLUDED.dao_id,
user_id = EXCLUDED.user_id,
storage_ref = EXCLUDED.storage_ref,
source = EXCLUDED.source,
metadata = EXCLUDED.metadata
RETURNING id, version_no
""",
(agent_id or "").lower(),
doc_id,
int(version_no),
text_hash,
int(text_len),
preview,
file_name,
dao_id,
user_id,
storage_ref,
source,
json.dumps(payload),
)
return {
"ok": True,
"id": int(row["id"]) if row and row.get("id") is not None else None,
"version_no": int(row["version_no"]) if row and row.get("version_no") is not None else int(version_no),
}
except Exception as e:
logger.warning(f"record_document_version failed: {e}")
return {"ok": False, "error": str(e), "version_no": int(version_no)}
    async def list_document_versions(
        self,
        agent_id: str,
        doc_id: str,
        limit: int = 20,
    ) -> List[Dict[str, Any]]:
        """
        List recorded versions of a document, newest version first.

        Primary source is the agent_document_versions table in Postgres. When
        Postgres is unavailable or the query fails, falls back to scanning the
        per-agent Qdrant docs collection and aggregating distinct version_no
        values from chunk payloads (hash/len/preview are unknown there and
        returned as None). `limit` is clamped to [1, 200].
        """
        rows_out: List[Dict[str, Any]] = []
        if self.pg_pool:
            try:
                async with self.pg_pool.acquire() as conn:
                    rows = await conn.fetch(
                        """
                        SELECT id, agent_id, doc_id, version_no, text_hash, text_len, text_preview,
                               file_name, dao_id, user_id, storage_ref, source, metadata, created_at
                        FROM agent_document_versions
                        WHERE agent_id = $1
                          AND doc_id = $2
                        ORDER BY version_no DESC
                        LIMIT $3
                        """,
                        (agent_id or "").lower(),
                        doc_id,
                        max(1, min(int(limit or 20), 200)),
                    )
                    for r in rows:
                        # metadata may arrive as dict, JSON string, or other; normalize to dict.
                        meta_raw = r["metadata"]
                        if isinstance(meta_raw, dict):
                            meta_obj = meta_raw
                        elif isinstance(meta_raw, str):
                            try:
                                parsed = json.loads(meta_raw)
                                meta_obj = parsed if isinstance(parsed, dict) else {"raw": parsed}
                            except Exception:
                                meta_obj = {"raw": meta_raw}
                        else:
                            meta_obj = {}
                        rows_out.append(
                            {
                                "id": int(r["id"]),
                                "agent_id": r["agent_id"],
                                "doc_id": r["doc_id"],
                                "version_no": int(r["version_no"]),
                                "text_hash": r["text_hash"],
                                "text_len": int(r["text_len"] or 0),
                                "text_preview": r["text_preview"],
                                "file_name": r["file_name"],
                                "dao_id": r["dao_id"],
                                "user_id": r["user_id"],
                                "storage_ref": r["storage_ref"],
                                "source": r["source"],
                                "metadata": meta_obj,
                                "created_at": r["created_at"].isoformat() if r["created_at"] else None,
                            }
                        )
                    return rows_out
            except Exception as e:
                logger.warning(f"list_document_versions failed: {e}")
        # PG unavailable fallback: aggregate distinct versions from Qdrant payloads.
        if self.qdrant_client:
            try:
                from qdrant_client.http import models as qmodels
                collection = f"{(agent_id or 'daarwizz').lower()}_docs"
                offset = None
                seen: Dict[int, Dict[str, Any]] = {}
                # Cap the total number of points scanned so huge docs stay bounded.
                max_points = max(64, min(int(limit or 20) * 80, 4096))
                fetched = 0
                while fetched < max_points:
                    points, next_offset = self.qdrant_client.scroll(
                        collection_name=collection,
                        scroll_filter=qmodels.Filter(
                            must=[
                                qmodels.FieldCondition(
                                    key="doc_id",
                                    match=qmodels.MatchValue(value=doc_id),
                                )
                            ]
                        ),
                        offset=offset,
                        limit=256,
                        with_payload=True,
                    )
                    if not points:
                        break
                    fetched += len(points)
                    for p in points:
                        payload = getattr(p, "payload", {}) or {}
                        # version_no may be stored as int or numeric string; default to 1.
                        ver_raw = payload.get("version_no")
                        if isinstance(ver_raw, int):
                            ver = ver_raw
                        elif isinstance(ver_raw, str) and ver_raw.isdigit():
                            ver = int(ver_raw)
                        else:
                            ver = 1
                        existing = seen.get(ver)
                        ts = payload.get("timestamp")
                        # Keep the most recently timestamped representative per version.
                        if not existing or (ts and str(ts) > str(existing.get("created_at") or "")):
                            seen[ver] = {
                                "id": None,
                                "agent_id": (agent_id or "").lower(),
                                "doc_id": doc_id,
                                "version_no": int(ver),
                                "text_hash": None,
                                "text_len": None,
                                "text_preview": None,
                                "file_name": payload.get("file_name"),
                                "dao_id": payload.get("dao_id"),
                                "user_id": payload.get("user_id"),
                                "storage_ref": payload.get("storage_ref"),
                                "source": payload.get("source") or "ingest",
                                "metadata": payload.get("metadata") or {},
                                "created_at": ts,
                            }
                    if not next_offset:
                        break
                    offset = next_offset
                rows_out = sorted(seen.values(), key=lambda x: int(x.get("version_no") or 0), reverse=True)[: max(1, min(int(limit or 20), 200))]
            except Exception:
                pass
        return rows_out
def _build_doc_filter(
self,
doc_id: str,
dao_id: Optional[str] = None,
):
from qdrant_client.http import models as qmodels
must_conditions = [
qmodels.FieldCondition(
key="doc_id",
match=qmodels.MatchValue(value=doc_id),
)
]
if dao_id:
must_conditions.append(
qmodels.FieldCondition(
key="dao_id",
match=qmodels.MatchValue(value=dao_id),
)
)
return qmodels.Filter(must=must_conditions)
def _delete_document_points(
self,
collection: str,
doc_id: str,
dao_id: Optional[str] = None,
) -> bool:
if not self.qdrant_client:
return False
try:
from qdrant_client.http import models as qmodels
self.qdrant_client.delete(
collection_name=collection,
points_selector=qmodels.FilterSelector(
filter=self._build_doc_filter(doc_id=doc_id, dao_id=dao_id)
),
)
return True
except Exception as e:
logger.warning(f"delete_document_points failed for {collection}/{doc_id}: {e}")
return False
    async def ingest_document_chunks(
        self,
        agent_id: str,
        doc_id: str,
        file_name: Optional[str],
        text: str,
        dao_id: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        replace_existing: bool = False,
        version_no: Optional[int] = None,
        source: str = "ingest",
        storage_ref: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Ingest normalized document chunks into {agent_id}_docs collection.

        Pipeline: split the text into overlapping chunks, embed each chunk,
        upsert the resulting points into Qdrant, then record the version in
        the Postgres ledger via _record_document_version.

        Returns {"ok": True, doc_id, version_no, version_id, chunks_total,
        chunks_stored, replaced_existing, collection} on success, or
        {"ok": False, "error": <reason>} on failure.

        NOTE(review): `replace_existing` does not delete old points here —
        previous versions are kept on purpose and readers filter on the
        latest version_no (see comment before the upsert).
        """
        # Hard prerequisites: vector store and embedding provider.
        if not self.qdrant_client:
            return {"ok": False, "error": "qdrant_unavailable"}
        if not COHERE_API_KEY:
            return {"ok": False, "error": "cohere_unavailable"}
        body = (text or "").strip()
        if not body:
            return {"ok": False, "error": "empty_document"}
        chunks = self._chunk_document_text(body)
        if not chunks:
            return {"ok": False, "error": "no_chunks"}
        collection = f"{(agent_id or 'daarwizz').lower()}_docs"
        stored_points = []
        try:
            from qdrant_client.http import models as qmodels
            import uuid
            # Lazily create the per-agent docs collection (1024-dim, cosine).
            try:
                self.qdrant_client.get_collection(collection)
            except Exception:
                self.qdrant_client.create_collection(
                    collection_name=collection,
                    vectors_config=qmodels.VectorParams(
                        size=1024,
                        distance=qmodels.Distance.COSINE,
                    ),
                )
                logger.info(f"✅ Created collection: {collection}")
            total = len(chunks)
            # Explicit caller-supplied version wins; otherwise allocate the next one.
            resolved_version_no = int(version_no or 0) or await self._next_document_version_no(agent_id=agent_id, doc_id=doc_id)
            for idx, chunk in enumerate(chunks):
                emb = await self.get_embedding(chunk[:2000])
                if not emb:
                    # Skip chunks whose embedding failed rather than aborting the doc.
                    continue
                payload: Dict[str, Any] = {
                    "text": chunk[:6000],
                    "doc_id": doc_id,
                    "file_name": file_name,
                    "agent_id": (agent_id or "").lower(),
                    "dao_id": dao_id,
                    "user_id": user_id,
                    "chunk_index": idx,
                    "chunks_total": total,
                    "type": "document_chunk",
                    "version_no": int(resolved_version_no),
                    "source": source,
                    "storage_ref": storage_ref,
                    "timestamp": datetime.utcnow().isoformat(),
                }
                if isinstance(metadata, dict) and metadata:
                    payload["metadata"] = metadata
                stored_points.append(
                    qmodels.PointStruct(
                        id=str(uuid.uuid4()),
                        vector=emb,
                        payload=payload,
                    )
                )
            if not stored_points:
                return {"ok": False, "error": "embedding_failed"}
            # Keep previous versions in the same collection when updating.
            # Query path will select only the latest version_no for doc_id.
            self.qdrant_client.upsert(collection_name=collection, points=stored_points)
            # Record the snapshot in the Postgres version ledger (best-effort).
            version_row = await self._record_document_version(
                agent_id=agent_id,
                doc_id=doc_id,
                version_no=resolved_version_no,
                text=body,
                file_name=file_name,
                dao_id=dao_id,
                user_id=user_id,
                metadata=metadata,
                source=source,
                storage_ref=storage_ref,
            )
            return {
                "ok": True,
                "doc_id": doc_id,
                "version_no": int(resolved_version_no),
                "version_id": version_row.get("id"),
                "chunks_total": len(chunks),
                "chunks_stored": len(stored_points),
                "replaced_existing": bool(replace_existing),
                "collection": collection,
            }
        except Exception as e:
            logger.warning(f"ingest_document_chunks failed for {collection}: {e}")
            return {"ok": False, "error": str(e)}
async def update_document_chunks(
self,
agent_id: str,
doc_id: str,
file_name: Optional[str],
text: str,
dao_id: Optional[str] = None,
user_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
storage_ref: Optional[str] = None,
) -> Dict[str, Any]:
"""
Update existing document content with version bump.
Keeps the same logical doc_id and replaces indexed chunks.
"""
next_version = await self._next_document_version_no(agent_id=agent_id, doc_id=doc_id)
result = await self.ingest_document_chunks(
agent_id=agent_id,
doc_id=doc_id,
file_name=file_name,
text=text,
dao_id=dao_id,
user_id=user_id,
metadata=metadata,
replace_existing=False,
version_no=next_version,
source="update",
storage_ref=storage_ref,
)
if result.get("ok"):
result["updated"] = True
result["replaced_existing"] = True
return result
async def query_document_chunks(
self,
agent_id: str,
question: str,
doc_id: Optional[str] = None,
dao_id: Optional[str] = None,
limit: int = 5,
) -> Dict[str, Any]:
"""
Retrieve top document chunks from {agent_id}_docs for a question.
"""
if not self.qdrant_client:
return {"ok": False, "error": "qdrant_unavailable", "chunks": []}
if not COHERE_API_KEY:
return {"ok": False, "error": "cohere_unavailable", "chunks": []}
q = (question or "").strip()
if not q:
return {"ok": False, "error": "empty_question", "chunks": []}
embedding = await self.get_embedding(q[:2000])
if not embedding:
return {"ok": False, "error": "embedding_failed", "chunks": []}
collection = f"{(agent_id or 'daarwizz').lower()}_docs"
try:
from qdrant_client.http import models as qmodels
must_conditions = []
if doc_id:
latest_ver = await self._latest_document_version_no(agent_id=agent_id, doc_id=doc_id)
must_conditions.append(
qmodels.FieldCondition(
key="doc_id",
match=qmodels.MatchValue(value=doc_id),
)
)
if latest_ver > 0:
must_conditions.append(
qmodels.FieldCondition(
key="version_no",
match=qmodels.MatchValue(value=int(latest_ver)),
)
)
if dao_id:
must_conditions.append(
qmodels.FieldCondition(
key="dao_id",
match=qmodels.MatchValue(value=dao_id),
)
)
query_filter = qmodels.Filter(must=must_conditions) if must_conditions else None
rows = self.qdrant_client.search(
collection_name=collection,
query_vector=embedding,
query_filter=query_filter,
limit=max(1, min(int(limit or 5), 12)),
with_payload=True,
)
except Exception as e:
logger.debug(f"query_document_chunks search failed for {collection}: {e}")
return {"ok": False, "error": "search_failed", "chunks": [], "collection": collection}
hits: List[Dict[str, Any]] = []
for row in rows or []:
score = float(getattr(row, "score", 0.0) or 0.0)
if score < 0.30:
continue
payload = getattr(row, "payload", {}) or {}
text = str(payload.get("text") or "").strip()
if len(text) < 10:
continue
hits.append(
{
"text": text,
"score": score,
"doc_id": payload.get("doc_id"),
"file_name": payload.get("file_name"),
"chunk_index": payload.get("chunk_index"),
"chunks_total": payload.get("chunks_total"),
"version_no": payload.get("version_no"),
}
)
return {
"ok": bool(hits),
"chunks": hits,
"collection": collection,
"doc_id": doc_id,
}
async def store_interaction(
self,
channel: str,
chat_id: str,
user_id: str,
agent_id: str,
username: Optional[str],
user_message: str,
assistant_response: str,
metadata: Optional[Dict[str, Any]] = None,
) -> bool:
# Backward-compatible wrapper for older call sites.
return await self.store_message(
agent_id=agent_id,
user_id=user_id,
username=username,
message_text=user_message,
response_text=assistant_response,
chat_id=chat_id,
message_type="conversation",
metadata=metadata,
)
async def update_session_state(
self,
conversation_id: str,
**updates
):
"""Update session state after interaction"""
if not self.pg_pool or not conversation_id:
return
try:
async with self.pg_pool.acquire() as conn:
# Build dynamic update
set_clauses = ["updated_at = NOW()"]
params = [conversation_id]
param_idx = 2
allowed_fields = [
'last_user_id', 'last_user_nick',
'active_topic', 'context_open',
'last_media_handled', 'last_answer_fingerprint',
'trust_mode', 'apprentice_mode'
]
for field, value in updates.items():
if field in allowed_fields:
set_clauses.append(f"{field} = ${param_idx}")
params.append(value)
param_idx += 1
query = f"""
UPDATE agent_session_state
SET {', '.join(set_clauses)}
WHERE conversation_id = $1
"""
await conn.execute(query, *params)
except Exception as e:
logger.warning(f"Session state update failed: {e}")
# Global instance: module-level singleton shared by all importers of this
# module (constructed at import time as a deliberate side effect).
memory_retrieval = MemoryRetrieval()