agromatrix: add pending-question memory, anti-repeat guard, and numeric contract

This commit is contained in:
NODA1 System
2026-02-21 12:47:23 +01:00
parent a87a1fe52c
commit d963c52fe5
2 changed files with 621 additions and 50 deletions

View File

@@ -11,6 +11,7 @@ import httpx
import logging import logging
import hashlib import hashlib
import time # For latency metrics import time # For latency metrics
from difflib import SequenceMatcher
# CrewAI Integration # CrewAI Integration
try: try:
@@ -262,12 +263,114 @@ def _build_agromatrix_deterministic_fallback(candidates: List[Dict[str, Any]]) -
EMPTY_ANSWER_GUARD_AGENTS = {"devtools", "monitor"} EMPTY_ANSWER_GUARD_AGENTS = {"devtools", "monitor"}
# Agents that must route plant questions through deterministic classifier
# tools (nature-id / plantnet) before any LLM guessing. Overridable via the
# DETERMINISTIC_PLANT_POLICY_AGENTS env var as a comma-separated list;
# entries are trimmed and lower-cased, empty entries dropped.
DETERMINISTIC_PLANT_POLICY_AGENTS = {
    part.strip().lower()
    for part in os.getenv(
        "DETERMINISTIC_PLANT_POLICY_AGENTS",
        "agromatrix,greenfood,nutra",
    ).split(",")
    if part.strip()
}
# difflib similarity ratio at or above which a new answer is considered a
# repeat of the previously stored answer fingerprint (used by the anti-repeat
# guard together with negative user feedback).
REPEAT_FINGERPRINT_MIN_SIMILARITY = float(os.getenv("AGENT_REPEAT_FINGERPRINT_MIN_SIMILARITY", "0.92"))
def _normalize_text_response(text: str) -> str: def _normalize_text_response(text: str) -> str:
return re.sub(r"\s+", " ", str(text or "")).strip() return re.sub(r"\s+", " ", str(text or "")).strip()
def _response_fingerprint(text: str) -> str:
normalized = _normalize_text_response(text).lower()
normalized = re.sub(r"[^a-zаіїєґ0-9%./:;,+\- ]+", " ", normalized)
normalized = re.sub(r"\s+", " ", normalized).strip()
return normalized[:240]
def _fingerprint_similarity(a: str, b: str) -> float:
if not a or not b:
return 0.0
return SequenceMatcher(None, a, b).ratio()
def _looks_like_user_question(text: str) -> bool:
t = (text or "").strip().lower()
if not t:
return False
if "?" in t:
return True
starters = (
"що", "як", "чому", "коли", "де", "скільки", "яка", "який", "які",
"what", "how", "why", "when", "where", "which", "can you",
"что", "как", "почему", "когда", "где", "сколько",
)
return any(t.startswith(s + " ") for s in starters)
def _looks_like_negative_feedback(text: str) -> bool:
t = (text or "").lower()
markers = (
"не вірно", "невірно", "неправильно", "помилка", "знову не так",
"це не так", "не релевантно", "повтор", "ти знову", "мимо",
"wrong", "incorrect", "not relevant", "repeat", "again wrong",
"неверно", "неправильно", "это ошибка", "снова не так",
)
return any(m in t for m in markers)
def _looks_like_numeric_request(text: str) -> bool:
t = (text or "").lower()
markers = (
"скільки", "сума", "витра", "cost", "total", "amount", "ціна",
"вартість", "дохід", "прибут", "маржа", "баланс", "unit cost",
"сколько", "сумма", "затрат", "стоимость", "расход",
)
return any(m in t for m in markers)
def _numeric_contract_present(text: str) -> bool:
t = _normalize_text_response(text)
low = t.lower()
if not re.search(r"\d", low):
return False
has_value_with_unit = re.search(
r"\b\d[\d\s.,]*\s*(грн|uah|usd|eur|kg|кг|т|л|га|шт|%|тон|літр|hectare|ha)\b",
low,
) is not None
has_explicit_source = any(
re.search(pattern, low) is not None
for pattern in (
r"\bsheet\s*[:#]?\s*[a-z0-9_]+",
r"\brow\s*[:#]?\s*\d+",
r"\bрядок\s*[:#]?\s*\d+",
r"\bлист\s*[:#]?\s*[a-zа-я0-9_]+",
r"\bcell\s*[:#]?\s*[a-z]+\d+",
r"\омірк[а-я]*\s*[:#]?\s*[a-zа-я]+\d+",
r"\bsource\s*[:#]",
r"\bджерел[оа]\s*[:#]",
)
)
return bool(has_value_with_unit and has_explicit_source)
def _build_numeric_contract_uncertain_response() -> str:
return (
"Не можу підтвердити точне число без джерела. "
"Щоб дати коректну відповідь, надішли таблицю/файл або уточни лист і діапазон. "
"Формат відповіді дам строго як: value + unit + source(sheet,row)."
)
def _response_is_uncertain_or_incomplete(text: str) -> bool:
low = _normalize_text_response(text).lower()
if not low:
return True
markers = (
"не впевнений", "не можу", "надішли", "уточни", "уточніть",
"потрібно більше", "insufficient", "need more", "please send",
"не уверен", "не могу", "уточни", "нужно больше",
)
return any(m in low for m in markers)
def _needs_empty_answer_recovery(text: str) -> bool: def _needs_empty_answer_recovery(text: str) -> bool:
normalized = _normalize_text_response(text) normalized = _normalize_text_response(text)
if not normalized: if not normalized:
@@ -1369,6 +1472,8 @@ async def agent_infer(agent_id: str, request: InferRequest):
# MEMORY RETRIEVAL (v4.0 - Universal for all agents) # MEMORY RETRIEVAL (v4.0 - Universal for all agents)
# ========================================================================= # =========================================================================
memory_brief_text = "" memory_brief_text = ""
brief: Optional[MemoryBrief] = None
session_state = None
# Extract metadata once for both retrieval and storage # Extract metadata once for both retrieval and storage
metadata = request.metadata or {} metadata = request.metadata or {}
channel = "telegram" # Default channel = "telegram" # Default
@@ -1382,7 +1487,32 @@ async def agent_infer(agent_id: str, request: InferRequest):
# IMPORTANT: inspect only the latest user text when provided by gateway, # IMPORTANT: inspect only the latest user text when provided by gateway,
# not the full context-augmented prompt. # not the full context-augmented prompt.
raw_user_text = str(metadata.get("raw_user_text", "") or "").strip() raw_user_text = str(metadata.get("raw_user_text", "") or "").strip()
image_guard_text = raw_user_text if raw_user_text else request.prompt incoming_user_text = raw_user_text if raw_user_text else request.prompt
image_guard_text = incoming_user_text
track_pending_question = _looks_like_user_question(incoming_user_text)
if (
MEMORY_RETRIEVAL_AVAILABLE
and memory_retrieval
and chat_id
and user_id
and track_pending_question
):
try:
await memory_retrieval.register_pending_question(
channel=channel,
chat_id=chat_id,
user_id=user_id,
agent_id=request_agent_id,
question_text=incoming_user_text,
metadata={
"source": "router_infer",
"has_images": bool(request.images),
},
)
except Exception as e:
logger.debug(f"Pending question register skipped: {e}")
if (not request.images) and _looks_like_image_question(image_guard_text): if (not request.images) and _looks_like_image_question(image_guard_text):
return InferResponse( return InferResponse(
response=( response=(
@@ -1405,6 +1535,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
username=username, username=username,
message=request.prompt message=request.prompt
) )
session_state = brief.session_state if brief else None
memory_brief_text = brief.to_text(max_lines=10) memory_brief_text = brief.to_text(max_lines=10)
if memory_brief_text: if memory_brief_text:
logger.info(f"🧠 Memory brief for {request_agent_id}: {len(memory_brief_text)} chars") logger.info(f"🧠 Memory brief for {request_agent_id}: {len(memory_brief_text)} chars")
@@ -1454,6 +1585,63 @@ async def agent_infer(agent_id: str, request: InferRequest):
f"🧩 Prompt meta for {agent_id}: source={system_prompt_source}, " f"🧩 Prompt meta for {agent_id}: source={system_prompt_source}, "
f"version={effective_metadata['system_prompt_version']}, hash={system_prompt_hash}" f"version={effective_metadata['system_prompt_version']}, hash={system_prompt_hash}"
) )
async def _finalize_response_text(text: str, backend_tag: str) -> str:
    """Apply post-generation policies to an outgoing answer.

    Runs, in order: the agromatrix numeric contract, the anti-repeat
    guard, pending-question resolution, and the session-state fingerprint
    update. Closure over agent_infer locals (request_agent_id,
    incoming_user_text, session_state, track_pending_question, channel,
    chat_id, user_id, username).

    Args:
        text: Raw response text from whichever backend produced it.
        backend_tag: Label of the producing backend (used in logs only).

    Returns:
        The possibly-replaced, whitespace-normalized response text.
    """
    final_text = _normalize_text_response(text)
    if not final_text:
        return final_text
    # Agro numeric contract: no numbers without unit + source marker.
    if request_agent_id == "agromatrix" and _looks_like_numeric_request(incoming_user_text):
        if not _numeric_contract_present(final_text):
            final_text = _build_numeric_contract_uncertain_response()
    # Anti-repeat guard: if user reports wrong answer and new answer is near-identical
    # to previous one, force non-repetitive recovery text.
    prev_fp = ""
    if session_state and getattr(session_state, "last_answer_fingerprint", None):
        prev_fp = str(session_state.last_answer_fingerprint or "")
    new_fp = _response_fingerprint(final_text)
    if prev_fp and new_fp:
        similarity = _fingerprint_similarity(prev_fp, new_fp)
        if similarity >= REPEAT_FINGERPRINT_MIN_SIMILARITY and _looks_like_negative_feedback(incoming_user_text):
            final_text = (
                "Прийняв, попередня відповідь була не по суті. Не повторюю її. "
                "Переформулюю коротко і по ділу: надішли 1 конкретне питання або файл/фото, "
                "і я дам перевірену відповідь із джерелом."
            )
            # Refresh the fingerprint so stored state reflects the recovery text.
            new_fp = _response_fingerprint(final_text)
            logger.warning(
                f"🔁 Repeat guard fired for {request_agent_id}: similarity={similarity:.3f}, backend={backend_tag}"
            )
    # Resolve oldest pending question only when answer is not uncertain.
    if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id:
        try:
            if track_pending_question and not _response_is_uncertain_or_incomplete(final_text):
                await memory_retrieval.resolve_pending_question(
                    channel=channel,
                    chat_id=chat_id,
                    user_id=user_id,
                    agent_id=request_agent_id,
                    answer_text=final_text,
                    reason="answered",
                )
        except Exception as e:
            # Best-effort: memory backend failures never block the reply.
            logger.debug(f"Pending question resolve skipped: {e}")
        try:
            # Persist the latest answer fingerprint for the next repeat-guard check.
            if session_state and getattr(session_state, "conversation_id", None):
                await memory_retrieval.update_session_state(
                    session_state.conversation_id,
                    last_answer_fingerprint=new_fp[:240],
                    last_user_id=user_id,
                    last_user_nick=username,
                )
        except Exception as e:
            logger.debug(f"Session fingerprint update skipped: {e}")
    return final_text
# Determine which backend to use # Determine which backend to use
# Use router config to get default model for agent, fallback to qwen3:8b # Use router config to get default model for agent, fallback to qwen3:8b
@@ -1601,6 +1789,8 @@ async def agent_infer(agent_id: str, request: InferRequest):
parts = re.split(r"(?<=[.!?])\s+", final_response_text.strip()) parts = re.split(r"(?<=[.!?])\s+", final_response_text.strip())
if len(parts) > 3: if len(parts) > 3:
final_response_text = " ".join(parts[:3]).strip() final_response_text = " ".join(parts[:3]).strip()
final_response_text = await _finalize_response_text(final_response_text, "crewai")
# Store interaction in memory # Store interaction in memory
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id: if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id:
@@ -1656,7 +1846,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
# 1) run plant classifiers first (nature-id / plantnet) # 1) run plant classifiers first (nature-id / plantnet)
# 2) apply confidence threshold # 2) apply confidence threshold
# 3) LLM only explains classifier result, no new guessing # 3) LLM only explains classifier result, no new guessing
if request_agent_id == "agromatrix" and plant_intent and TOOL_MANAGER_AVAILABLE and tool_manager: if request_agent_id in DETERMINISTIC_PLANT_POLICY_AGENTS and plant_intent and TOOL_MANAGER_AVAILABLE and tool_manager:
try: try:
image_inputs = _extract_image_inputs_for_plant_tools(request.images, metadata) image_inputs = _extract_image_inputs_for_plant_tools(request.images, metadata)
if image_inputs: if image_inputs:
@@ -1697,6 +1887,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
top_conf = float(candidates[0].get("confidence", 0.0)) if candidates else 0.0 top_conf = float(candidates[0].get("confidence", 0.0)) if candidates else 0.0
if (not candidates) or (top_conf < threshold): if (not candidates) or (top_conf < threshold):
response_text = _build_agromatrix_not_sure_response(candidates, threshold) response_text = _build_agromatrix_not_sure_response(candidates, threshold)
response_text = await _finalize_response_text(response_text, "plant-id-deterministic-uncertain")
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id: if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id:
asyncio.create_task( asyncio.create_task(
memory_retrieval.store_message( memory_retrieval.store_message(
@@ -1770,6 +1961,8 @@ async def agent_infer(agent_id: str, request: InferRequest):
if (top_name and top_name not in low) and (top_sci and top_sci not in low): if (top_name and top_name not in low) and (top_sci and top_sci not in low):
response_text = _build_agromatrix_deterministic_fallback(candidates) response_text = _build_agromatrix_deterministic_fallback(candidates)
response_text = await _finalize_response_text(response_text, llm_backend)
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id: if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id:
asyncio.create_task( asyncio.create_task(
memory_retrieval.store_message( memory_retrieval.store_message(
@@ -1916,7 +2109,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
# Plant identification safety gate: # Plant identification safety gate:
# avoid hard species claims when confidence is low or evidence is weak. # avoid hard species claims when confidence is low or evidence is weak.
if request_agent_id == "agromatrix" and plant_intent and (uncertain or len(vision_sources) < 2): if request_agent_id in DETERMINISTIC_PLANT_POLICY_AGENTS and plant_intent and (uncertain or len(vision_sources) < 2):
full_response = _build_cautious_plant_response(full_response or raw_response, len(vision_sources)) full_response = _build_cautious_plant_response(full_response or raw_response, len(vision_sources))
# Image quality gate: one soft retry if response looks empty/meta. # Image quality gate: one soft retry if response looks empty/meta.
@@ -1948,8 +2141,10 @@ async def agent_infer(agent_id: str, request: InferRequest):
if _image_response_needs_retry(full_response): if _image_response_needs_retry(full_response):
full_response = _build_image_fallback_response(request_agent_id, request.prompt) full_response = _build_image_fallback_response(request_agent_id, request.prompt)
elif request_agent_id == "agromatrix" and _vision_response_is_blurry(full_response): elif request_agent_id in DETERMINISTIC_PLANT_POLICY_AGENTS and _vision_response_is_blurry(full_response):
full_response = _build_image_fallback_response(request_agent_id, request.prompt) full_response = _build_image_fallback_response(request_agent_id, request.prompt)
full_response = await _finalize_response_text(full_response, "swapper-vision")
# Store vision message in agent-specific memory # Store vision message in agent-specific memory
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id and full_response: if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id and full_response:
@@ -1979,8 +2174,12 @@ async def agent_infer(agent_id: str, request: InferRequest):
) )
else: else:
logger.error(f"❌ Swapper vision error: {vision_resp.status_code} - {vision_resp.text[:200]}") logger.error(f"❌ Swapper vision error: {vision_resp.status_code} - {vision_resp.text[:200]}")
fallback_response = await _finalize_response_text(
_build_image_fallback_response(request_agent_id, request.prompt),
"swapper-vision-fallback",
)
return InferResponse( return InferResponse(
response=_build_image_fallback_response(request_agent_id, request.prompt), response=fallback_response,
model="qwen3-vl-8b", model="qwen3-vl-8b",
tokens_used=None, tokens_used=None,
backend="swapper-vision-fallback" backend="swapper-vision-fallback"
@@ -1988,8 +2187,12 @@ async def agent_infer(agent_id: str, request: InferRequest):
except Exception as e: except Exception as e:
logger.error(f"❌ Vision processing failed: {e}", exc_info=True) logger.error(f"❌ Vision processing failed: {e}", exc_info=True)
fallback_response = await _finalize_response_text(
_build_image_fallback_response(request_agent_id, request.prompt),
"swapper-vision-fallback",
)
return InferResponse( return InferResponse(
response=_build_image_fallback_response(request_agent_id, request.prompt), response=fallback_response,
model="qwen3-vl-8b", model="qwen3-vl-8b",
tokens_used=None, tokens_used=None,
backend="swapper-vision-fallback" backend="swapper-vision-fallback"
@@ -2435,6 +2638,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
logger.debug(f" Tool {tr['name']}: no image_base64") logger.debug(f" Tool {tr['name']}: no image_base64")
logger.info(f"{cloud['name'].upper()} response received, {tokens_used} tokens") logger.info(f"{cloud['name'].upper()} response received, {tokens_used} tokens")
response_text = await _finalize_response_text(response_text, f"{cloud['name']}-cloud")
# Store message in agent-specific memory (async, non-blocking) # Store message in agent-specific memory (async, non-blocking)
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id: if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id:
@@ -2563,6 +2767,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
"Я не отримав корисну відповідь з першої спроби. " "Я не отримав корисну відповідь з першої спроби. "
"Сформулюй запит коротко ще раз, і я відповім конкретно." "Сформулюй запит коротко ще раз, і я відповім конкретно."
) )
local_response = await _finalize_response_text(local_response, "swapper+ollama")
# Store in agent-specific memory # Store in agent-specific memory
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id and local_response: if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id and local_response:
@@ -2607,8 +2812,9 @@ async def agent_infer(agent_id: str, request: InferRequest):
if generate_resp.status_code == 200: if generate_resp.status_code == 200:
data = generate_resp.json() data = generate_resp.json()
fallback_text = await _finalize_response_text(data.get("response", ""), "ollama-direct")
return InferResponse( return InferResponse(
response=data.get("response", ""), response=fallback_text,
model=model, model=model,
tokens_used=data.get("eval_count", 0), tokens_used=data.get("eval_count", 0),
backend="ollama-direct" backend="ollama-direct"

View File

@@ -22,6 +22,7 @@ import re
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
import hashlib
import httpx import httpx
import asyncpg import asyncpg
@@ -36,6 +37,9 @@ COHERE_API_KEY = os.getenv("COHERE_API_KEY", "")
NEO4J_BOLT_URL = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687") NEO4J_BOLT_URL = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j") NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4j") NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4j")
# Max open (unanswered) questions kept per (channel, chat, user, agent) scope;
# older overflow gets dismissed by register_pending_question().
PENDING_QUESTIONS_LIMIT = int(os.getenv("AGENT_PENDING_QUESTIONS_LIMIT", "5"))
# Cross-chat, anonymized agronomy memory toggles. When review is required,
# unreviewed cases are written to a separate pending collection instead of
# the shared library collection.
SHARED_AGRO_LIBRARY_ENABLED = os.getenv("AGROMATRIX_SHARED_LIBRARY_ENABLED", "true").lower() == "true"
SHARED_AGRO_LIBRARY_REQUIRE_REVIEW = os.getenv("AGROMATRIX_SHARED_LIBRARY_REQUIRE_REVIEW", "true").lower() == "true"
@dataclass @dataclass
@@ -62,6 +66,7 @@ class SessionState:
last_answer_fingerprint: Optional[str] = None last_answer_fingerprint: Optional[str] = None
trust_mode: bool = False trust_mode: bool = False
apprentice_mode: bool = False apprentice_mode: bool = False
pending_questions: List[str] = field(default_factory=list)
@dataclass @dataclass
@@ -96,6 +101,10 @@ class MemoryBrief:
lines.append("📚 Режим учня — можеш ставити уточнюючі питання") lines.append("📚 Режим учня — можеш ставити уточнюючі питання")
if self.session_state.active_topic: if self.session_state.active_topic:
lines.append(f"📌 Активна тема: {self.session_state.active_topic}") lines.append(f"📌 Активна тема: {self.session_state.active_topic}")
if self.session_state.pending_questions:
lines.append("🕘 Невідповідані питання в цьому чаті (відповідай на них першочергово):")
for q in self.session_state.pending_questions[:3]:
lines.append(f" - {q[:180]}")
# User facts (preferences, profile) # User facts (preferences, profile)
if self.user_facts: if self.user_facts:
@@ -179,6 +188,7 @@ class MemoryRetrieval:
# HTTP client for embeddings # HTTP client for embeddings
self.http_client = httpx.AsyncClient(timeout=30.0) self.http_client = httpx.AsyncClient(timeout=30.0)
await self._ensure_aux_tables()
async def close(self): async def close(self):
"""Close connections""" """Close connections"""
@@ -188,6 +198,57 @@ class MemoryRetrieval:
await self.neo4j_driver.close() await self.neo4j_driver.close()
if self.http_client: if self.http_client:
await self.http_client.aclose() await self.http_client.aclose()
async def _ensure_aux_tables(self):
    """Create auxiliary tables used by agent runtime policies.

    Idempotent (CREATE ... IF NOT EXISTS) bootstrap of:
      * agent_session_state      - one row per (channel, chat, user, agent)
        scope holding the repeat-guard fingerprint and mode flags;
      * agent_pending_questions  - open user questions, deduplicated per
        scope by (question_fingerprint, status) via a unique index.

    Failures are logged as warnings and swallowed so startup proceeds even
    without DDL privileges on the database.
    """
    if not self.pg_pool:
        # No Postgres pool configured; runtime policies degrade gracefully.
        return
    try:
        async with self.pg_pool.acquire() as conn:
            # Multi-statement script: asyncpg runs parameterless execute()
            # via the simple query protocol, so this is one round-trip.
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS agent_session_state (
                    channel TEXT NOT NULL,
                    chat_id TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    agent_id TEXT NOT NULL,
                    conversation_id TEXT NOT NULL,
                    last_user_id TEXT,
                    last_user_nick TEXT,
                    active_topic TEXT,
                    context_open BOOLEAN NOT NULL DEFAULT FALSE,
                    last_media_handled BOOLEAN NOT NULL DEFAULT TRUE,
                    last_answer_fingerprint TEXT,
                    trust_mode BOOLEAN NOT NULL DEFAULT FALSE,
                    apprentice_mode BOOLEAN NOT NULL DEFAULT FALSE,
                    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    PRIMARY KEY (channel, chat_id, user_id, agent_id)
                );
                CREATE INDEX IF NOT EXISTS idx_agent_session_state_conv
                    ON agent_session_state (conversation_id);
                CREATE TABLE IF NOT EXISTS agent_pending_questions (
                    id BIGSERIAL PRIMARY KEY,
                    channel TEXT NOT NULL,
                    chat_id TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    agent_id TEXT NOT NULL,
                    question_text TEXT NOT NULL,
                    question_fingerprint TEXT NOT NULL,
                    status TEXT NOT NULL DEFAULT 'pending',
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    answered_at TIMESTAMPTZ,
                    metadata JSONB NOT NULL DEFAULT '{}'::jsonb
                );
                CREATE INDEX IF NOT EXISTS idx_agent_pending_questions_scope
                    ON agent_pending_questions (agent_id, channel, chat_id, user_id, status, created_at DESC);
                CREATE UNIQUE INDEX IF NOT EXISTS idx_agent_pending_questions_unique_open
                    ON agent_pending_questions (agent_id, channel, chat_id, user_id, question_fingerprint, status);
                """
            )
    except Exception as e:
        logger.warning(f"Aux tables init failed: {e}")
# ========================================================================= # =========================================================================
# L2: Platform Identity Resolution # L2: Platform Identity Resolution
@@ -237,7 +298,7 @@ class MemoryRetrieval:
identity.is_mentor = bool(is_mentor) identity.is_mentor = bool(is_mentor)
except Exception as e: except Exception as e:
logger.warning(f"Identity resolution failed: {e}") logger.debug(f"Identity resolution fallback: {e}")
return identity return identity
@@ -249,7 +310,9 @@ class MemoryRetrieval:
self, self,
channel: str, channel: str,
chat_id: str, chat_id: str,
thread_id: Optional[str] = None thread_id: Optional[str] = None,
agent_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> SessionState: ) -> SessionState:
"""Get or create session state for conversation""" """Get or create session state for conversation"""
state = SessionState() state = SessionState()
@@ -259,42 +322,78 @@ class MemoryRetrieval:
try: try:
async with self.pg_pool.acquire() as conn: async with self.pg_pool.acquire() as conn:
# Get or create conversation if agent_id and user_id:
conv_id = await conn.fetchval( conv_id = self._build_conversation_id(channel, chat_id, user_id, agent_id)
"SELECT get_or_create_conversation($1, $2, $3, NULL)", row = await conn.fetchrow(
channel, chat_id, thread_id """
) SELECT conversation_id, active_topic, context_open, last_media_handled,
state.conversation_id = str(conv_id) if conv_id else None last_answer_fingerprint, trust_mode, apprentice_mode
FROM agent_session_state
# Get conversation state WHERE channel = $1
if conv_id: AND chat_id = $2
row = await conn.fetchrow(""" AND user_id = $3
SELECT * FROM helion_conversation_state AND agent_id = $4
WHERE conversation_id = $1 """,
""", conv_id) channel,
chat_id,
if row: user_id,
state.last_addressed = row.get('last_addressed_to_helion', False) agent_id,
state.active_topic = row.get('active_topic_id') )
state.context_open = row.get('active_context_open', False) if not row:
state.last_media_handled = row.get('last_media_handled', True) await conn.execute(
state.last_answer_fingerprint = row.get('last_answer_fingerprint') """
state.trust_mode = row.get('group_trust_mode', False) INSERT INTO agent_session_state
state.apprentice_mode = row.get('apprentice_mode', False) (channel, chat_id, user_id, agent_id, conversation_id)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (channel, chat_id, user_id, agent_id) DO NOTHING
""",
channel,
chat_id,
user_id,
agent_id,
conv_id,
)
state.conversation_id = conv_id
else: else:
# Create initial state state.conversation_id = str(row.get("conversation_id") or conv_id)
await conn.execute(""" state.active_topic = row.get("active_topic")
INSERT INTO helion_conversation_state (conversation_id) state.context_open = bool(row.get("context_open", False))
VALUES ($1) state.last_media_handled = bool(row.get("last_media_handled", True))
ON CONFLICT (conversation_id) DO NOTHING state.last_answer_fingerprint = row.get("last_answer_fingerprint")
""", conv_id) state.trust_mode = bool(row.get("trust_mode", False))
state.apprentice_mode = bool(row.get("apprentice_mode", False))
# Check if trusted group else:
is_trusted = await conn.fetchval( state.conversation_id = self._build_conversation_id(
"SELECT is_trusted_group($1, $2)", channel,
channel, chat_id chat_id,
) user_id or "unknown",
state.trust_mode = bool(is_trusted) agent_id or "agent",
)
if agent_id and user_id:
pending_rows = await conn.fetch(
"""
SELECT question_text
FROM agent_pending_questions
WHERE channel = $1
AND chat_id = $2
AND user_id = $3
AND agent_id = $4
AND status = 'pending'
ORDER BY created_at ASC
LIMIT $5
""",
channel,
chat_id,
user_id,
agent_id,
PENDING_QUESTIONS_LIMIT,
)
state.pending_questions = [
str(r.get("question_text") or "").strip()
for r in pending_rows
if str(r.get("question_text") or "").strip()
]
except Exception as e: except Exception as e:
logger.warning(f"Session state retrieval failed: {e}") logger.warning(f"Session state retrieval failed: {e}")
@@ -494,6 +593,32 @@ class MemoryRetrieval:
}) })
except Exception as e: except Exception as e:
logger.debug(f"{docs_collection} search: {e}") logger.debug(f"{docs_collection} search: {e}")
# Search 4: shared agronomy memory (reviewed, cross-chat, anonymized)
if (
SHARED_AGRO_LIBRARY_ENABLED
and agent_id == "agromatrix"
and self._is_plant_query(query)
):
try:
results = self.qdrant_client.search(
collection_name="agromatrix_shared_library",
query_vector=embedding,
limit=3,
with_payload=True
)
for r in results:
if r.score > 0.45:
text = str(r.payload.get("text") or "").strip()
if len(text) > 20:
all_results.append({
"text": text[:500],
"type": "shared_agro_fact",
"score": r.score + 0.05,
"source": "shared_agronomy_library"
})
except Exception as e:
logger.debug(f"agromatrix_shared_library search: {e}")
# Sort by score and deduplicate # Sort by score and deduplicate
all_results.sort(key=lambda x: x.get("score", 0), reverse=True) all_results.sort(key=lambda x: x.get("score", 0), reverse=True)
@@ -546,6 +671,28 @@ class MemoryRetrieval:
return "" return ""
normalized = re.sub(r"\s+", " ", text.strip().lower()) normalized = re.sub(r"\s+", " ", text.strip().lower())
return normalized[:220] return normalized[:220]
@staticmethod
def _is_plant_query(text: str) -> bool:
q = (text or "").lower()
if not q:
return False
markers = [
"рослин", "культур", "лист", "стебл", "бур'ян", "хвороб", "шкідник",
"what plant", "identify plant", "crop", "species", "leaf", "stem",
"что за растение", "культура", "листок", "фото рослини"
]
return any(m in q for m in markers)
@staticmethod
def _question_fingerprint(question_text: str) -> str:
normalized = re.sub(r"\s+", " ", (question_text or "").strip().lower())
return hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:16]
@staticmethod
def _build_conversation_id(channel: str, chat_id: str, user_id: str, agent_id: str) -> str:
seed = f"{channel}:{chat_id}:{user_id}:{agent_id}"
return hashlib.sha1(seed.encode("utf-8")).hexdigest()[:24]
async def get_user_graph_context( async def get_user_graph_context(
self, self,
@@ -639,7 +786,13 @@ class MemoryRetrieval:
brief.user_identity = identity brief.user_identity = identity
# L1: Session State # L1: Session State
session = await self.get_session_state(channel, chat_id, thread_id) session = await self.get_session_state(
channel,
chat_id,
thread_id,
agent_id=agent_id,
user_id=user_id,
)
brief.session_state = session brief.session_state = session
brief.is_trusted_group = session.trust_mode brief.is_trusted_group = session.trust_mode
@@ -749,6 +902,22 @@ class MemoryRetrieval:
) )
] ]
) )
# Optional shared agronomy memory:
# - never stores user/chat identifiers
# - supports review gate (pending vs approved)
if (
SHARED_AGRO_LIBRARY_ENABLED
and agent_id == "agromatrix"
and message_type in {"vision", "conversation"}
and isinstance(metadata, dict)
and metadata.get("deterministic_plant_id")
):
await self._store_shared_agronomy_memory(
message_text=message_text,
response_text=response_text,
metadata=metadata,
)
logger.debug(f"✅ Stored message in {messages_collection}: {point_id[:8]}...") logger.debug(f"✅ Stored message in {messages_collection}: {point_id[:8]}...")
return True return True
@@ -756,6 +925,202 @@ class MemoryRetrieval:
except Exception as e: except Exception as e:
logger.warning(f"Failed to store message in {messages_collection}: {e}") logger.warning(f"Failed to store message in {messages_collection}: {e}")
return False return False
async def _store_shared_agronomy_memory(
    self,
    message_text: str,
    response_text: str,
    metadata: Dict[str, Any],
) -> bool:
    """Store an anonymized plant-ID case in the shared agronomy library.

    The payload deliberately carries no user/chat identifiers. When the
    review gate is enabled and the case is not mentor-confirmed, the point
    is written to 'agromatrix_shared_pending' instead of the shared
    'agromatrix_shared_library' collection.

    Args:
        message_text: Original user question (truncated to 800 chars).
        response_text: Final answer text (truncated to 1200 chars).
        metadata: Interaction metadata; reads mentor_confirmed/reviewed,
            decision, confidence_threshold, and candidates.

    Returns:
        True on successful upsert; False otherwise (errors logged at debug
        level and swallowed - this is best-effort enrichment).
    """
    if not self.qdrant_client or not COHERE_API_KEY:
        # Both the vector store and the embedding provider are required.
        return False
    try:
        from qdrant_client.http import models as qmodels
        import uuid
        reviewed = bool(metadata.get("mentor_confirmed") or metadata.get("reviewed"))
        collection = "agromatrix_shared_library"
        if SHARED_AGRO_LIBRARY_REQUIRE_REVIEW and not reviewed:
            collection = "agromatrix_shared_pending"
        # Lazily create the target collection on first use.
        # NOTE(review): 1024 dims presumably matches the embedding model used
        # by get_embedding() - confirm if the model ever changes.
        try:
            self.qdrant_client.get_collection(collection)
        except Exception:
            self.qdrant_client.create_collection(
                collection_name=collection,
                vectors_config=qmodels.VectorParams(
                    size=1024,
                    distance=qmodels.Distance.COSINE,
                ),
            )
        # Compact textual case: question + answer + classifier candidates.
        compact = (
            f"Plant case\nQuestion: {message_text[:800]}\n"
            f"Answer: {response_text[:1200]}\n"
            f"Candidates: {json.dumps(metadata.get('candidates', []), ensure_ascii=False)[:1200]}"
        )
        embedding = await self.get_embedding(compact[:2000])
        if not embedding:
            return False
        payload = {
            "text": compact[:3000],
            "type": "plant_case",
            "deterministic_plant_id": True,
            "decision": metadata.get("decision"),
            "confidence_threshold": metadata.get("confidence_threshold"),
            "candidates": metadata.get("candidates", [])[:5],
            "reviewed": reviewed,
            "timestamp": datetime.utcnow().isoformat(),
        }
        self.qdrant_client.upsert(
            collection_name=collection,
            points=[qmodels.PointStruct(id=str(uuid.uuid4()), vector=embedding, payload=payload)],
        )
        return True
    except Exception as e:
        logger.debug(f"Shared agronomy memory store failed: {e}")
        return False
async def register_pending_question(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str,
    question_text: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Record an open user question scoped to (channel, chat, user, agent).

    Deduplicates by normalized-text fingerprint through the unique index on
    (agent_id, channel, chat_id, user_id, question_fingerprint, status),
    then trims the backlog so at most PENDING_QUESTIONS_LIMIT rows remain
    'pending' (older overflow is marked 'dismissed' with an audit reason).

    Args:
        channel: Message channel (e.g. "telegram").
        chat_id: Chat scope identifier.
        user_id: Asking user's identifier.
        agent_id: Target agent.
        question_text: Raw question text (stored truncated to 1200 chars).
        metadata: Optional JSON-serializable context stored alongside.

    Returns:
        True when the insert + trim ran (duplicates count as success);
        False on missing pool, empty text, or DB error (logged, not raised).
    """
    if not self.pg_pool:
        return False
    text = (question_text or "").strip()
    if not text:
        return False
    fp = self._question_fingerprint(text)
    try:
        async with self.pg_pool.acquire() as conn:
            # Idempotent insert: an identical open question is silently skipped.
            await conn.execute(
                """
                INSERT INTO agent_pending_questions
                    (channel, chat_id, user_id, agent_id, question_text, question_fingerprint, status, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7::jsonb)
                ON CONFLICT (agent_id, channel, chat_id, user_id, question_fingerprint, status)
                DO NOTHING
                """,
                channel,
                chat_id,
                user_id,
                agent_id,
                text[:1200],
                fp,
                json.dumps(metadata or {}, ensure_ascii=False),
            )
            # Keep only last N open items.
            await conn.execute(
                """
                WITH ranked AS (
                    SELECT id, ROW_NUMBER() OVER (
                        PARTITION BY channel, chat_id, user_id, agent_id, status
                        ORDER BY created_at DESC
                    ) AS rn
                    FROM agent_pending_questions
                    WHERE channel = $1
                      AND chat_id = $2
                      AND user_id = $3
                      AND agent_id = $4
                      AND status = 'pending'
                )
                UPDATE agent_pending_questions p
                SET status = 'dismissed',
                    answered_at = NOW(),
                    metadata = COALESCE(p.metadata, '{}'::jsonb) || '{"reason":"overflow_trim"}'::jsonb
                FROM ranked r
                WHERE p.id = r.id
                  AND r.rn > $5
                """,
                channel,
                chat_id,
                user_id,
                agent_id,
                max(1, PENDING_QUESTIONS_LIMIT),
            )
            return True
    except Exception as e:
        logger.warning(f"register_pending_question failed: {e}")
        return False
async def resolve_pending_question(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str,
    answer_text: Optional[str] = None,
    reason: str = "answered",
) -> bool:
    """Close the oldest pending question for the given scope.

    Marks it 'dismissed' when reason == 'dismissed', otherwise 'answered',
    and merges the resolution reason plus the fingerprint of the answer
    text into the row's metadata for auditing.

    Args:
        channel: Message channel.
        chat_id: Chat scope identifier.
        user_id: Asking user's identifier.
        agent_id: Answering agent.
        answer_text: Optional final answer; only its fingerprint is stored.
        reason: Resolution reason; 'dismissed' closes without answering.

    Returns:
        True when a pending row was updated; False when none existed,
        Postgres is unavailable, or the update failed (logged, not raised).
    """
    if not self.pg_pool:
        return False
    try:
        async with self.pg_pool.acquire() as conn:
            # Single statement: CTE picks the oldest open row; RETURNING
            # tells us whether anything actually matched.
            row = await conn.fetchrow(
                """
                WITH target AS (
                    SELECT id
                    FROM agent_pending_questions
                    WHERE channel = $1
                      AND chat_id = $2
                      AND user_id = $3
                      AND agent_id = $4
                      AND status = 'pending'
                    ORDER BY created_at ASC
                    LIMIT 1
                )
                UPDATE agent_pending_questions p
                SET status = CASE WHEN $5 = 'dismissed' THEN 'dismissed' ELSE 'answered' END,
                    answered_at = NOW(),
                    metadata = COALESCE(p.metadata, '{}'::jsonb)
                        || jsonb_build_object(
                            'resolution_reason', $5,
                            'answer_fingerprint', COALESCE($6, '')
                        )
                FROM target t
                WHERE p.id = t.id
                RETURNING p.id
                """,
                channel,
                chat_id,
                user_id,
                agent_id,
                reason,
                self._question_fingerprint(answer_text or "") if answer_text else "",
            )
            return bool(row)
    except Exception as e:
        logger.warning(f"resolve_pending_question failed: {e}")
        return False
async def store_interaction(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str,
    username: Optional[str],
    user_message: str,
    assistant_response: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Backward-compatible wrapper for older call sites.

    Maps the legacy (user_message, assistant_response) signature onto
    store_message() with message_type='conversation'. The channel argument
    is accepted for signature compatibility but not forwarded.
    NOTE(review): confirm store_message derives/defaults the channel
    internally, otherwise channel information is lost here.

    Returns:
        Whatever store_message() returns (True on successful storage).
    """
    # Backward-compatible wrapper for older call sites.
    return await self.store_message(
        agent_id=agent_id,
        user_id=user_id,
        username=username,
        message_text=user_message,
        response_text=assistant_response,
        chat_id=chat_id,
        message_type="conversation",
        metadata=metadata,
    )
async def update_session_state( async def update_session_state(
self, self,
@@ -774,10 +1139,10 @@ class MemoryRetrieval:
param_idx = 2 param_idx = 2
allowed_fields = [ allowed_fields = [
'last_addressed_to_helion', 'last_user_id', 'last_user_nick', 'last_user_id', 'last_user_nick',
'active_topic_id', 'active_context_open', 'last_media_id', 'active_topic', 'context_open',
'last_media_handled', 'last_answer_fingerprint', 'group_trust_mode', 'last_media_handled', 'last_answer_fingerprint',
'apprentice_mode', 'proactive_questions_today' 'trust_mode', 'apprentice_mode'
] ]
for field, value in updates.items(): for field, value in updates.items():
@@ -787,7 +1152,7 @@ class MemoryRetrieval:
param_idx += 1 param_idx += 1
query = f""" query = f"""
UPDATE helion_conversation_state UPDATE agent_session_state
SET {', '.join(set_clauses)} SET {', '.join(set_clauses)}
WHERE conversation_id = $1 WHERE conversation_id = $1
""" """