agromatrix: add pending-question memory, anti-repeat guard, and numeric contract
This commit is contained in:
@@ -11,6 +11,7 @@ import httpx
|
||||
import logging
|
||||
import hashlib
|
||||
import time # For latency metrics
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
# CrewAI Integration
|
||||
try:
|
||||
@@ -262,12 +263,114 @@ def _build_agromatrix_deterministic_fallback(candidates: List[Dict[str, Any]]) -
|
||||
|
||||
|
||||
# Agents whose empty/meta answers trigger the empty-answer recovery path.
EMPTY_ANSWER_GUARD_AGENTS = {"devtools", "monitor"}

# Agents that must follow the deterministic plant-identification policy.
# Configured via a comma-separated env var; normalized to a lowercase set.
_raw_plant_policy_agents = os.getenv(
    "DETERMINISTIC_PLANT_POLICY_AGENTS",
    "agromatrix,greenfood,nutra",
)
DETERMINISTIC_PLANT_POLICY_AGENTS = {
    name.strip().lower()
    for name in _raw_plant_policy_agents.split(",")
    if name.strip()
}

# Similarity ratio at/above which a new answer counts as a repeat of the previous one.
REPEAT_FINGERPRINT_MIN_SIMILARITY = float(
    os.getenv("AGENT_REPEAT_FINGERPRINT_MIN_SIMILARITY", "0.92")
)
|
||||
|
||||
|
||||
def _normalize_text_response(text: str) -> str:
|
||||
return re.sub(r"\s+", " ", str(text or "")).strip()
|
||||
|
||||
|
||||
def _response_fingerprint(text: str) -> str:
    """Build a compact, comparable fingerprint of a response text.

    Lowercases the normalized text, strips characters outside a small
    Latin/Cyrillic/numeric/punctuation alphabet, collapses whitespace,
    and truncates to 240 characters.
    """
    base = _normalize_text_response(text).lower()
    cleaned = re.sub(r"[^a-zа-яіїєґ0-9%./:;,+\- ]+", " ", base)
    compact = re.sub(r"\s+", " ", cleaned).strip()
    return compact[:240]
|
||||
|
||||
|
||||
def _fingerprint_similarity(a: str, b: str) -> float:
|
||||
if not a or not b:
|
||||
return 0.0
|
||||
return SequenceMatcher(None, a, b).ratio()
|
||||
|
||||
|
||||
def _looks_like_user_question(text: str) -> bool:
|
||||
t = (text or "").strip().lower()
|
||||
if not t:
|
||||
return False
|
||||
if "?" in t:
|
||||
return True
|
||||
starters = (
|
||||
"що", "як", "чому", "коли", "де", "скільки", "яка", "який", "які",
|
||||
"what", "how", "why", "when", "where", "which", "can you",
|
||||
"что", "как", "почему", "когда", "где", "сколько",
|
||||
)
|
||||
return any(t.startswith(s + " ") for s in starters)
|
||||
|
||||
|
||||
def _looks_like_negative_feedback(text: str) -> bool:
|
||||
t = (text or "").lower()
|
||||
markers = (
|
||||
"не вірно", "невірно", "неправильно", "помилка", "знову не так",
|
||||
"це не так", "не релевантно", "повтор", "ти знову", "мимо",
|
||||
"wrong", "incorrect", "not relevant", "repeat", "again wrong",
|
||||
"неверно", "неправильно", "это ошибка", "снова не так",
|
||||
)
|
||||
return any(m in t for m in markers)
|
||||
|
||||
|
||||
def _looks_like_numeric_request(text: str) -> bool:
|
||||
t = (text or "").lower()
|
||||
markers = (
|
||||
"скільки", "сума", "витра", "cost", "total", "amount", "ціна",
|
||||
"вартість", "дохід", "прибут", "маржа", "баланс", "unit cost",
|
||||
"сколько", "сумма", "затрат", "стоимость", "расход",
|
||||
)
|
||||
return any(m in t for m in markers)
|
||||
|
||||
|
||||
def _numeric_contract_present(text: str) -> bool:
    """Check the 'numeric contract': a number with a unit AND an explicit source.

    Returns True only when the text contains a digit, a value followed by a
    recognized unit (currency/weight/area/percent), and at least one explicit
    source marker (sheet/row/cell/source in EN or UA).
    """
    low = _normalize_text_response(text).lower()
    # No digits at all -> the contract cannot possibly be satisfied.
    if re.search(r"\d", low) is None:
        return False
    value_with_unit = r"\b\d[\d\s.,]*\s*(грн|uah|usd|eur|kg|кг|т|л|га|шт|%|тон|літр|hectare|ha)\b"
    if re.search(value_with_unit, low) is None:
        return False
    source_patterns = (
        r"\bsheet\s*[:#]?\s*[a-z0-9_]+",
        r"\brow\s*[:#]?\s*\d+",
        r"\bрядок\s*[:#]?\s*\d+",
        r"\bлист\s*[:#]?\s*[a-zа-я0-9_]+",
        r"\bcell\s*[:#]?\s*[a-z]+\d+",
        r"\bкомірк[а-я]*\s*[:#]?\s*[a-zа-я]+\d+",
        r"\bsource\s*[:#]",
        r"\bджерел[оа]\s*[:#]",
    )
    return any(re.search(p, low) is not None for p in source_patterns)
|
||||
|
||||
|
||||
def _build_numeric_contract_uncertain_response() -> str:
|
||||
return (
|
||||
"Не можу підтвердити точне число без джерела. "
|
||||
"Щоб дати коректну відповідь, надішли таблицю/файл або уточни лист і діапазон. "
|
||||
"Формат відповіді дам строго як: value + unit + source(sheet,row)."
|
||||
)
|
||||
|
||||
|
||||
def _response_is_uncertain_or_incomplete(text: str) -> bool:
    """Heuristic: does the answer admit uncertainty or ask for more input?

    An empty/whitespace-only answer counts as incomplete. Otherwise a
    case-insensitive substring match against UA/EN/RU hedging markers.
    """
    low = _normalize_text_response(text).lower()
    if not low:
        return True
    # NOTE: "уточни" is spelled identically in Ukrainian and Russian;
    # the original listed it twice — deduplicated with no behavior change.
    markers = (
        "не впевнений", "не можу", "надішли", "уточни", "уточніть",
        "потрібно більше", "insufficient", "need more", "please send",
        "не уверен", "не могу", "нужно больше",
    )
    return any(m in low for m in markers)
|
||||
|
||||
|
||||
def _needs_empty_answer_recovery(text: str) -> bool:
|
||||
normalized = _normalize_text_response(text)
|
||||
if not normalized:
|
||||
@@ -1369,6 +1472,8 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
# MEMORY RETRIEVAL (v4.0 - Universal for all agents)
|
||||
# =========================================================================
|
||||
memory_brief_text = ""
|
||||
brief: Optional[MemoryBrief] = None
|
||||
session_state = None
|
||||
# Extract metadata once for both retrieval and storage
|
||||
metadata = request.metadata or {}
|
||||
channel = "telegram" # Default
|
||||
@@ -1382,7 +1487,32 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
# IMPORTANT: inspect only the latest user text when provided by gateway,
|
||||
# not the full context-augmented prompt.
|
||||
raw_user_text = str(metadata.get("raw_user_text", "") or "").strip()
|
||||
image_guard_text = raw_user_text if raw_user_text else request.prompt
|
||||
incoming_user_text = raw_user_text if raw_user_text else request.prompt
|
||||
image_guard_text = incoming_user_text
|
||||
track_pending_question = _looks_like_user_question(incoming_user_text)
|
||||
|
||||
if (
|
||||
MEMORY_RETRIEVAL_AVAILABLE
|
||||
and memory_retrieval
|
||||
and chat_id
|
||||
and user_id
|
||||
and track_pending_question
|
||||
):
|
||||
try:
|
||||
await memory_retrieval.register_pending_question(
|
||||
channel=channel,
|
||||
chat_id=chat_id,
|
||||
user_id=user_id,
|
||||
agent_id=request_agent_id,
|
||||
question_text=incoming_user_text,
|
||||
metadata={
|
||||
"source": "router_infer",
|
||||
"has_images": bool(request.images),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Pending question register skipped: {e}")
|
||||
|
||||
if (not request.images) and _looks_like_image_question(image_guard_text):
|
||||
return InferResponse(
|
||||
response=(
|
||||
@@ -1405,6 +1535,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
username=username,
|
||||
message=request.prompt
|
||||
)
|
||||
session_state = brief.session_state if brief else None
|
||||
memory_brief_text = brief.to_text(max_lines=10)
|
||||
if memory_brief_text:
|
||||
logger.info(f"🧠 Memory brief for {request_agent_id}: {len(memory_brief_text)} chars")
|
||||
@@ -1454,6 +1585,63 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
f"🧩 Prompt meta for {agent_id}: source={system_prompt_source}, "
|
||||
f"version={effective_metadata['system_prompt_version']}, hash={system_prompt_hash}"
|
||||
)
|
||||
|
||||
async def _finalize_response_text(text: str, backend_tag: str) -> str:
    """Apply post-generation policies to an answer before it is returned.

    Closure over the enclosing ``agent_infer`` scope: reads
    ``request_agent_id``, ``incoming_user_text``, ``session_state``,
    ``track_pending_question``, ``channel``/``chat_id``/``user_id``/
    ``username`` and the memory-retrieval globals.

    Pipeline: whitespace normalization -> agro numeric contract ->
    anti-repeat guard -> pending-question resolution -> fingerprint
    persistence. Returns the (possibly replaced) final text.
    """
    final_text = _normalize_text_response(text)
    if not final_text:
        # Nothing to police; empty answers are handled by other guards.
        return final_text

    # Agro numeric contract: no numbers without unit + source marker.
    # If the user asked for a number but the answer lacks evidence,
    # replace it with the fixed "cannot confirm" response.
    if request_agent_id == "agromatrix" and _looks_like_numeric_request(incoming_user_text):
        if not _numeric_contract_present(final_text):
            final_text = _build_numeric_contract_uncertain_response()

    # Anti-repeat guard: if user reports wrong answer and new answer is near-identical
    # to previous one, force non-repetitive recovery text.
    prev_fp = ""
    if session_state and getattr(session_state, "last_answer_fingerprint", None):
        prev_fp = str(session_state.last_answer_fingerprint or "")
    new_fp = _response_fingerprint(final_text)
    if prev_fp and new_fp:
        similarity = _fingerprint_similarity(prev_fp, new_fp)
        if similarity >= REPEAT_FINGERPRINT_MIN_SIMILARITY and _looks_like_negative_feedback(incoming_user_text):
            final_text = (
                "Прийняв, попередня відповідь була не по суті. Не повторюю її. "
                "Переформулюю коротко і по ділу: надішли 1 конкретне питання або файл/фото, "
                "і я дам перевірену відповідь із джерелом."
            )
            # Re-fingerprint so the stored fingerprint matches what was sent.
            new_fp = _response_fingerprint(final_text)
            logger.warning(
                f"🔁 Repeat guard fired for {request_agent_id}: similarity={similarity:.3f}, backend={backend_tag}"
            )

    # Resolve oldest pending question only when answer is not uncertain.
    if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id:
        try:
            if track_pending_question and not _response_is_uncertain_or_incomplete(final_text):
                await memory_retrieval.resolve_pending_question(
                    channel=channel,
                    chat_id=chat_id,
                    user_id=user_id,
                    agent_id=request_agent_id,
                    answer_text=final_text,
                    reason="answered",
                )
        except Exception as e:
            # Best-effort: memory bookkeeping must never break the reply.
            logger.debug(f"Pending question resolve skipped: {e}")

        try:
            if session_state and getattr(session_state, "conversation_id", None):
                # Persist the answer fingerprint so the next turn can detect repeats.
                await memory_retrieval.update_session_state(
                    session_state.conversation_id,
                    last_answer_fingerprint=new_fp[:240],
                    last_user_id=user_id,
                    last_user_nick=username,
                )
        except Exception as e:
            logger.debug(f"Session fingerprint update skipped: {e}")

    return final_text
|
||||
|
||||
# Determine which backend to use
|
||||
# Use router config to get default model for agent, fallback to qwen3:8b
|
||||
@@ -1601,6 +1789,8 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
parts = re.split(r"(?<=[.!?])\s+", final_response_text.strip())
|
||||
if len(parts) > 3:
|
||||
final_response_text = " ".join(parts[:3]).strip()
|
||||
|
||||
final_response_text = await _finalize_response_text(final_response_text, "crewai")
|
||||
|
||||
# Store interaction in memory
|
||||
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id:
|
||||
@@ -1656,7 +1846,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
# 1) run plant classifiers first (nature-id / plantnet)
|
||||
# 2) apply confidence threshold
|
||||
# 3) LLM only explains classifier result, no new guessing
|
||||
if request_agent_id == "agromatrix" and plant_intent and TOOL_MANAGER_AVAILABLE and tool_manager:
|
||||
if request_agent_id in DETERMINISTIC_PLANT_POLICY_AGENTS and plant_intent and TOOL_MANAGER_AVAILABLE and tool_manager:
|
||||
try:
|
||||
image_inputs = _extract_image_inputs_for_plant_tools(request.images, metadata)
|
||||
if image_inputs:
|
||||
@@ -1697,6 +1887,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
top_conf = float(candidates[0].get("confidence", 0.0)) if candidates else 0.0
|
||||
if (not candidates) or (top_conf < threshold):
|
||||
response_text = _build_agromatrix_not_sure_response(candidates, threshold)
|
||||
response_text = await _finalize_response_text(response_text, "plant-id-deterministic-uncertain")
|
||||
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id:
|
||||
asyncio.create_task(
|
||||
memory_retrieval.store_message(
|
||||
@@ -1770,6 +1961,8 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
if (top_name and top_name not in low) and (top_sci and top_sci not in low):
|
||||
response_text = _build_agromatrix_deterministic_fallback(candidates)
|
||||
|
||||
response_text = await _finalize_response_text(response_text, llm_backend)
|
||||
|
||||
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id:
|
||||
asyncio.create_task(
|
||||
memory_retrieval.store_message(
|
||||
@@ -1916,7 +2109,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
|
||||
# Plant identification safety gate:
|
||||
# avoid hard species claims when confidence is low or evidence is weak.
|
||||
if request_agent_id == "agromatrix" and plant_intent and (uncertain or len(vision_sources) < 2):
|
||||
if request_agent_id in DETERMINISTIC_PLANT_POLICY_AGENTS and plant_intent and (uncertain or len(vision_sources) < 2):
|
||||
full_response = _build_cautious_plant_response(full_response or raw_response, len(vision_sources))
|
||||
|
||||
# Image quality gate: one soft retry if response looks empty/meta.
|
||||
@@ -1948,8 +2141,10 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
|
||||
if _image_response_needs_retry(full_response):
|
||||
full_response = _build_image_fallback_response(request_agent_id, request.prompt)
|
||||
elif request_agent_id == "agromatrix" and _vision_response_is_blurry(full_response):
|
||||
elif request_agent_id in DETERMINISTIC_PLANT_POLICY_AGENTS and _vision_response_is_blurry(full_response):
|
||||
full_response = _build_image_fallback_response(request_agent_id, request.prompt)
|
||||
|
||||
full_response = await _finalize_response_text(full_response, "swapper-vision")
|
||||
|
||||
# Store vision message in agent-specific memory
|
||||
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id and full_response:
|
||||
@@ -1979,8 +2174,12 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
)
|
||||
else:
|
||||
logger.error(f"❌ Swapper vision error: {vision_resp.status_code} - {vision_resp.text[:200]}")
|
||||
fallback_response = await _finalize_response_text(
|
||||
_build_image_fallback_response(request_agent_id, request.prompt),
|
||||
"swapper-vision-fallback",
|
||||
)
|
||||
return InferResponse(
|
||||
response=_build_image_fallback_response(request_agent_id, request.prompt),
|
||||
response=fallback_response,
|
||||
model="qwen3-vl-8b",
|
||||
tokens_used=None,
|
||||
backend="swapper-vision-fallback"
|
||||
@@ -1988,8 +2187,12 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Vision processing failed: {e}", exc_info=True)
|
||||
fallback_response = await _finalize_response_text(
|
||||
_build_image_fallback_response(request_agent_id, request.prompt),
|
||||
"swapper-vision-fallback",
|
||||
)
|
||||
return InferResponse(
|
||||
response=_build_image_fallback_response(request_agent_id, request.prompt),
|
||||
response=fallback_response,
|
||||
model="qwen3-vl-8b",
|
||||
tokens_used=None,
|
||||
backend="swapper-vision-fallback"
|
||||
@@ -2435,6 +2638,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
logger.debug(f" Tool {tr['name']}: no image_base64")
|
||||
|
||||
logger.info(f"✅ {cloud['name'].upper()} response received, {tokens_used} tokens")
|
||||
response_text = await _finalize_response_text(response_text, f"{cloud['name']}-cloud")
|
||||
|
||||
# Store message in agent-specific memory (async, non-blocking)
|
||||
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id:
|
||||
@@ -2563,6 +2767,7 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
"Я не отримав корисну відповідь з першої спроби. "
|
||||
"Сформулюй запит коротко ще раз, і я відповім конкретно."
|
||||
)
|
||||
local_response = await _finalize_response_text(local_response, "swapper+ollama")
|
||||
|
||||
# Store in agent-specific memory
|
||||
if MEMORY_RETRIEVAL_AVAILABLE and memory_retrieval and chat_id and user_id and local_response:
|
||||
@@ -2607,8 +2812,9 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
|
||||
if generate_resp.status_code == 200:
|
||||
data = generate_resp.json()
|
||||
fallback_text = await _finalize_response_text(data.get("response", ""), "ollama-direct")
|
||||
return InferResponse(
|
||||
response=data.get("response", ""),
|
||||
response=fallback_text,
|
||||
model=model,
|
||||
tokens_used=data.get("eval_count", 0),
|
||||
backend="ollama-direct"
|
||||
|
||||
@@ -22,6 +22,7 @@ import re
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
import hashlib
|
||||
|
||||
import httpx
|
||||
import asyncpg
|
||||
@@ -36,6 +37,9 @@ COHERE_API_KEY = os.getenv("COHERE_API_KEY", "")
|
||||
NEO4J_BOLT_URL = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
|
||||
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
|
||||
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "neo4j")
|
||||
PENDING_QUESTIONS_LIMIT = int(os.getenv("AGENT_PENDING_QUESTIONS_LIMIT", "5"))
|
||||
SHARED_AGRO_LIBRARY_ENABLED = os.getenv("AGROMATRIX_SHARED_LIBRARY_ENABLED", "true").lower() == "true"
|
||||
SHARED_AGRO_LIBRARY_REQUIRE_REVIEW = os.getenv("AGROMATRIX_SHARED_LIBRARY_REQUIRE_REVIEW", "true").lower() == "true"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -62,6 +66,7 @@ class SessionState:
|
||||
last_answer_fingerprint: Optional[str] = None
|
||||
trust_mode: bool = False
|
||||
apprentice_mode: bool = False
|
||||
pending_questions: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -96,6 +101,10 @@ class MemoryBrief:
|
||||
lines.append("📚 Режим учня — можеш ставити уточнюючі питання")
|
||||
if self.session_state.active_topic:
|
||||
lines.append(f"📌 Активна тема: {self.session_state.active_topic}")
|
||||
if self.session_state.pending_questions:
|
||||
lines.append("🕘 Невідповідані питання в цьому чаті (відповідай на них першочергово):")
|
||||
for q in self.session_state.pending_questions[:3]:
|
||||
lines.append(f" - {q[:180]}")
|
||||
|
||||
# User facts (preferences, profile)
|
||||
if self.user_facts:
|
||||
@@ -179,6 +188,7 @@ class MemoryRetrieval:
|
||||
|
||||
# HTTP client for embeddings
|
||||
self.http_client = httpx.AsyncClient(timeout=30.0)
|
||||
await self._ensure_aux_tables()
|
||||
|
||||
async def close(self):
|
||||
"""Close connections"""
|
||||
@@ -188,6 +198,57 @@ class MemoryRetrieval:
|
||||
await self.neo4j_driver.close()
|
||||
if self.http_client:
|
||||
await self.http_client.aclose()
|
||||
|
||||
async def _ensure_aux_tables(self):
    """Create auxiliary tables used by agent runtime policies.

    Best-effort: skipped when no Postgres pool is configured, and any DDL
    failure is logged as a warning rather than raised, so startup never
    breaks on schema initialization.
    """
    if not self.pg_pool:
        return
    try:
        async with self.pg_pool.acquire() as conn:
            # Two tables:
            #  - agent_session_state: one row per (channel, chat, user, agent)
            #    scope holding the last-answer fingerprint and mode flags.
            #  - agent_pending_questions: queue of unanswered questions; the
            #    unique index prevents duplicate open entries per scope.
            await conn.execute(
                """
                CREATE TABLE IF NOT EXISTS agent_session_state (
                    channel TEXT NOT NULL,
                    chat_id TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    agent_id TEXT NOT NULL,
                    conversation_id TEXT NOT NULL,
                    last_user_id TEXT,
                    last_user_nick TEXT,
                    active_topic TEXT,
                    context_open BOOLEAN NOT NULL DEFAULT FALSE,
                    last_media_handled BOOLEAN NOT NULL DEFAULT TRUE,
                    last_answer_fingerprint TEXT,
                    trust_mode BOOLEAN NOT NULL DEFAULT FALSE,
                    apprentice_mode BOOLEAN NOT NULL DEFAULT FALSE,
                    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    PRIMARY KEY (channel, chat_id, user_id, agent_id)
                );
                CREATE INDEX IF NOT EXISTS idx_agent_session_state_conv
                    ON agent_session_state (conversation_id);

                CREATE TABLE IF NOT EXISTS agent_pending_questions (
                    id BIGSERIAL PRIMARY KEY,
                    channel TEXT NOT NULL,
                    chat_id TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    agent_id TEXT NOT NULL,
                    question_text TEXT NOT NULL,
                    question_fingerprint TEXT NOT NULL,
                    status TEXT NOT NULL DEFAULT 'pending',
                    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    answered_at TIMESTAMPTZ,
                    metadata JSONB NOT NULL DEFAULT '{}'::jsonb
                );
                CREATE INDEX IF NOT EXISTS idx_agent_pending_questions_scope
                    ON agent_pending_questions (agent_id, channel, chat_id, user_id, status, created_at DESC);
                CREATE UNIQUE INDEX IF NOT EXISTS idx_agent_pending_questions_unique_open
                    ON agent_pending_questions (agent_id, channel, chat_id, user_id, question_fingerprint, status);
                """
            )
    except Exception as e:
        logger.warning(f"Aux tables init failed: {e}")
|
||||
|
||||
# =========================================================================
|
||||
# L2: Platform Identity Resolution
|
||||
@@ -237,7 +298,7 @@ class MemoryRetrieval:
|
||||
identity.is_mentor = bool(is_mentor)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Identity resolution failed: {e}")
|
||||
logger.debug(f"Identity resolution fallback: {e}")
|
||||
|
||||
return identity
|
||||
|
||||
@@ -249,7 +310,9 @@ class MemoryRetrieval:
|
||||
self,
|
||||
channel: str,
|
||||
chat_id: str,
|
||||
thread_id: Optional[str] = None
|
||||
thread_id: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> SessionState:
|
||||
"""Get or create session state for conversation"""
|
||||
state = SessionState()
|
||||
@@ -259,42 +322,78 @@ class MemoryRetrieval:
|
||||
|
||||
try:
|
||||
async with self.pg_pool.acquire() as conn:
|
||||
# Get or create conversation
|
||||
conv_id = await conn.fetchval(
|
||||
"SELECT get_or_create_conversation($1, $2, $3, NULL)",
|
||||
channel, chat_id, thread_id
|
||||
)
|
||||
state.conversation_id = str(conv_id) if conv_id else None
|
||||
|
||||
# Get conversation state
|
||||
if conv_id:
|
||||
row = await conn.fetchrow("""
|
||||
SELECT * FROM helion_conversation_state
|
||||
WHERE conversation_id = $1
|
||||
""", conv_id)
|
||||
|
||||
if row:
|
||||
state.last_addressed = row.get('last_addressed_to_helion', False)
|
||||
state.active_topic = row.get('active_topic_id')
|
||||
state.context_open = row.get('active_context_open', False)
|
||||
state.last_media_handled = row.get('last_media_handled', True)
|
||||
state.last_answer_fingerprint = row.get('last_answer_fingerprint')
|
||||
state.trust_mode = row.get('group_trust_mode', False)
|
||||
state.apprentice_mode = row.get('apprentice_mode', False)
|
||||
if agent_id and user_id:
|
||||
conv_id = self._build_conversation_id(channel, chat_id, user_id, agent_id)
|
||||
row = await conn.fetchrow(
|
||||
"""
|
||||
SELECT conversation_id, active_topic, context_open, last_media_handled,
|
||||
last_answer_fingerprint, trust_mode, apprentice_mode
|
||||
FROM agent_session_state
|
||||
WHERE channel = $1
|
||||
AND chat_id = $2
|
||||
AND user_id = $3
|
||||
AND agent_id = $4
|
||||
""",
|
||||
channel,
|
||||
chat_id,
|
||||
user_id,
|
||||
agent_id,
|
||||
)
|
||||
if not row:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO agent_session_state
|
||||
(channel, chat_id, user_id, agent_id, conversation_id)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (channel, chat_id, user_id, agent_id) DO NOTHING
|
||||
""",
|
||||
channel,
|
||||
chat_id,
|
||||
user_id,
|
||||
agent_id,
|
||||
conv_id,
|
||||
)
|
||||
state.conversation_id = conv_id
|
||||
else:
|
||||
# Create initial state
|
||||
await conn.execute("""
|
||||
INSERT INTO helion_conversation_state (conversation_id)
|
||||
VALUES ($1)
|
||||
ON CONFLICT (conversation_id) DO NOTHING
|
||||
""", conv_id)
|
||||
|
||||
# Check if trusted group
|
||||
is_trusted = await conn.fetchval(
|
||||
"SELECT is_trusted_group($1, $2)",
|
||||
channel, chat_id
|
||||
)
|
||||
state.trust_mode = bool(is_trusted)
|
||||
state.conversation_id = str(row.get("conversation_id") or conv_id)
|
||||
state.active_topic = row.get("active_topic")
|
||||
state.context_open = bool(row.get("context_open", False))
|
||||
state.last_media_handled = bool(row.get("last_media_handled", True))
|
||||
state.last_answer_fingerprint = row.get("last_answer_fingerprint")
|
||||
state.trust_mode = bool(row.get("trust_mode", False))
|
||||
state.apprentice_mode = bool(row.get("apprentice_mode", False))
|
||||
else:
|
||||
state.conversation_id = self._build_conversation_id(
|
||||
channel,
|
||||
chat_id,
|
||||
user_id or "unknown",
|
||||
agent_id or "agent",
|
||||
)
|
||||
|
||||
if agent_id and user_id:
|
||||
pending_rows = await conn.fetch(
|
||||
"""
|
||||
SELECT question_text
|
||||
FROM agent_pending_questions
|
||||
WHERE channel = $1
|
||||
AND chat_id = $2
|
||||
AND user_id = $3
|
||||
AND agent_id = $4
|
||||
AND status = 'pending'
|
||||
ORDER BY created_at ASC
|
||||
LIMIT $5
|
||||
""",
|
||||
channel,
|
||||
chat_id,
|
||||
user_id,
|
||||
agent_id,
|
||||
PENDING_QUESTIONS_LIMIT,
|
||||
)
|
||||
state.pending_questions = [
|
||||
str(r.get("question_text") or "").strip()
|
||||
for r in pending_rows
|
||||
if str(r.get("question_text") or "").strip()
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Session state retrieval failed: {e}")
|
||||
@@ -494,6 +593,32 @@ class MemoryRetrieval:
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"{docs_collection} search: {e}")
|
||||
|
||||
# Search 4: shared agronomy memory (reviewed, cross-chat, anonymized)
|
||||
if (
|
||||
SHARED_AGRO_LIBRARY_ENABLED
|
||||
and agent_id == "agromatrix"
|
||||
and self._is_plant_query(query)
|
||||
):
|
||||
try:
|
||||
results = self.qdrant_client.search(
|
||||
collection_name="agromatrix_shared_library",
|
||||
query_vector=embedding,
|
||||
limit=3,
|
||||
with_payload=True
|
||||
)
|
||||
for r in results:
|
||||
if r.score > 0.45:
|
||||
text = str(r.payload.get("text") or "").strip()
|
||||
if len(text) > 20:
|
||||
all_results.append({
|
||||
"text": text[:500],
|
||||
"type": "shared_agro_fact",
|
||||
"score": r.score + 0.05,
|
||||
"source": "shared_agronomy_library"
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"agromatrix_shared_library search: {e}")
|
||||
|
||||
# Sort by score and deduplicate
|
||||
all_results.sort(key=lambda x: x.get("score", 0), reverse=True)
|
||||
@@ -546,6 +671,28 @@ class MemoryRetrieval:
|
||||
return ""
|
||||
normalized = re.sub(r"\s+", " ", text.strip().lower())
|
||||
return normalized[:220]
|
||||
|
||||
@staticmethod
|
||||
def _is_plant_query(text: str) -> bool:
|
||||
q = (text or "").lower()
|
||||
if not q:
|
||||
return False
|
||||
markers = [
|
||||
"рослин", "культур", "лист", "стебл", "бур'ян", "хвороб", "шкідник",
|
||||
"what plant", "identify plant", "crop", "species", "leaf", "stem",
|
||||
"что за растение", "культура", "листок", "фото рослини"
|
||||
]
|
||||
return any(m in q for m in markers)
|
||||
|
||||
@staticmethod
|
||||
def _question_fingerprint(question_text: str) -> str:
|
||||
normalized = re.sub(r"\s+", " ", (question_text or "").strip().lower())
|
||||
return hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:16]
|
||||
|
||||
@staticmethod
|
||||
def _build_conversation_id(channel: str, chat_id: str, user_id: str, agent_id: str) -> str:
|
||||
seed = f"{channel}:{chat_id}:{user_id}:{agent_id}"
|
||||
return hashlib.sha1(seed.encode("utf-8")).hexdigest()[:24]
|
||||
|
||||
async def get_user_graph_context(
|
||||
self,
|
||||
@@ -639,7 +786,13 @@ class MemoryRetrieval:
|
||||
brief.user_identity = identity
|
||||
|
||||
# L1: Session State
|
||||
session = await self.get_session_state(channel, chat_id, thread_id)
|
||||
session = await self.get_session_state(
|
||||
channel,
|
||||
chat_id,
|
||||
thread_id,
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
)
|
||||
brief.session_state = session
|
||||
brief.is_trusted_group = session.trust_mode
|
||||
|
||||
@@ -749,6 +902,22 @@ class MemoryRetrieval:
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
# Optional shared agronomy memory:
|
||||
# - never stores user/chat identifiers
|
||||
# - supports review gate (pending vs approved)
|
||||
if (
|
||||
SHARED_AGRO_LIBRARY_ENABLED
|
||||
and agent_id == "agromatrix"
|
||||
and message_type in {"vision", "conversation"}
|
||||
and isinstance(metadata, dict)
|
||||
and metadata.get("deterministic_plant_id")
|
||||
):
|
||||
await self._store_shared_agronomy_memory(
|
||||
message_text=message_text,
|
||||
response_text=response_text,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
logger.debug(f"✅ Stored message in {messages_collection}: {point_id[:8]}...")
|
||||
return True
|
||||
@@ -756,6 +925,202 @@ class MemoryRetrieval:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to store message in {messages_collection}: {e}")
|
||||
return False
|
||||
|
||||
async def _store_shared_agronomy_memory(
    self,
    message_text: str,
    response_text: str,
    metadata: Dict[str, Any],
) -> bool:
    """Store an anonymized plant-identification case in the shared Qdrant library.

    No user/chat identifiers are written to the payload. Returns True on
    success, False when prerequisites are missing or any step fails
    (best-effort: failures are logged at debug level, never raised).
    """
    if not self.qdrant_client or not COHERE_API_KEY:
        return False
    try:
        from qdrant_client.http import models as qmodels
        import uuid

        # Review gate: unreviewed cases go to the "pending" collection
        # until a mentor confirms them (when the gate is enabled).
        reviewed = bool(metadata.get("mentor_confirmed") or metadata.get("reviewed"))
        collection = "agromatrix_shared_library"
        if SHARED_AGRO_LIBRARY_REQUIRE_REVIEW and not reviewed:
            collection = "agromatrix_shared_pending"

        # Lazily create the target collection on first use.
        try:
            self.qdrant_client.get_collection(collection)
        except Exception:
            self.qdrant_client.create_collection(
                collection_name=collection,
                vectors_config=qmodels.VectorParams(
                    size=1024,
                    distance=qmodels.Distance.COSINE,
                ),
            )

        # Compact textual form used both for embedding and as the payload text.
        compact = (
            f"Plant case\nQuestion: {message_text[:800]}\n"
            f"Answer: {response_text[:1200]}\n"
            f"Candidates: {json.dumps(metadata.get('candidates', []), ensure_ascii=False)[:1200]}"
        )
        embedding = await self.get_embedding(compact[:2000])
        if not embedding:
            return False

        payload = {
            "text": compact[:3000],
            "type": "plant_case",
            "deterministic_plant_id": True,
            "decision": metadata.get("decision"),
            "confidence_threshold": metadata.get("confidence_threshold"),
            "candidates": metadata.get("candidates", [])[:5],
            "reviewed": reviewed,
            "timestamp": datetime.utcnow().isoformat(),
        }
        self.qdrant_client.upsert(
            collection_name=collection,
            points=[qmodels.PointStruct(id=str(uuid.uuid4()), vector=embedding, payload=payload)],
        )
        return True
    except Exception as e:
        logger.debug(f"Shared agronomy memory store failed: {e}")
        return False
|
||||
|
||||
async def register_pending_question(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str,
    question_text: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Record an unanswered user question for later follow-up.

    Deduplicates by (scope, question fingerprint, status) via the unique
    index, then trims the open queue to the configured limit by marking
    overflow rows 'dismissed'. Returns True on success; False when no
    Postgres pool exists, the text is empty, or the write fails.
    """
    if not self.pg_pool:
        return False
    text = (question_text or "").strip()
    if not text:
        return False
    fp = self._question_fingerprint(text)
    try:
        async with self.pg_pool.acquire() as conn:
            # Idempotent insert: a re-asked identical pending question is a no-op.
            await conn.execute(
                """
                INSERT INTO agent_pending_questions
                    (channel, chat_id, user_id, agent_id, question_text, question_fingerprint, status, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, 'pending', $7::jsonb)
                ON CONFLICT (agent_id, channel, chat_id, user_id, question_fingerprint, status)
                DO NOTHING
                """,
                channel,
                chat_id,
                user_id,
                agent_id,
                text[:1200],
                fp,
                json.dumps(metadata or {}, ensure_ascii=False),
            )
            # Keep only last N open items: everything past the N most recent
            # pending rows is dismissed with an "overflow_trim" reason.
            await conn.execute(
                """
                WITH ranked AS (
                    SELECT id, ROW_NUMBER() OVER (
                        PARTITION BY channel, chat_id, user_id, agent_id, status
                        ORDER BY created_at DESC
                    ) AS rn
                    FROM agent_pending_questions
                    WHERE channel = $1
                      AND chat_id = $2
                      AND user_id = $3
                      AND agent_id = $4
                      AND status = 'pending'
                )
                UPDATE agent_pending_questions p
                SET status = 'dismissed',
                    answered_at = NOW(),
                    metadata = COALESCE(p.metadata, '{}'::jsonb) || '{"reason":"overflow_trim"}'::jsonb
                FROM ranked r
                WHERE p.id = r.id
                  AND r.rn > $5
                """,
                channel,
                chat_id,
                user_id,
                agent_id,
                max(1, PENDING_QUESTIONS_LIMIT),
            )
        return True
    except Exception as e:
        logger.warning(f"register_pending_question failed: {e}")
        return False
|
||||
|
||||
async def resolve_pending_question(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str,
    answer_text: Optional[str] = None,
    reason: str = "answered",
) -> bool:
    """Close the oldest pending question in the given scope.

    Marks it 'dismissed' when reason == 'dismissed', otherwise 'answered',
    and records the resolution reason plus a fingerprint of the answer in
    the row metadata. Returns True when a row was updated, False when there
    was no pending question, no Postgres pool, or the update failed.
    """
    if not self.pg_pool:
        return False
    try:
        async with self.pg_pool.acquire() as conn:
            # FIFO: the CTE targets the single oldest pending row in scope.
            row = await conn.fetchrow(
                """
                WITH target AS (
                    SELECT id
                    FROM agent_pending_questions
                    WHERE channel = $1
                      AND chat_id = $2
                      AND user_id = $3
                      AND agent_id = $4
                      AND status = 'pending'
                    ORDER BY created_at ASC
                    LIMIT 1
                )
                UPDATE agent_pending_questions p
                SET status = CASE WHEN $5 = 'dismissed' THEN 'dismissed' ELSE 'answered' END,
                    answered_at = NOW(),
                    metadata = COALESCE(p.metadata, '{}'::jsonb)
                        || jsonb_build_object(
                            'resolution_reason', $5,
                            'answer_fingerprint', COALESCE($6, '')
                        )
                FROM target t
                WHERE p.id = t.id
                RETURNING p.id
                """,
                channel,
                chat_id,
                user_id,
                agent_id,
                reason,
                # Fingerprint the answer (not the question) for traceability.
                self._question_fingerprint(answer_text or "") if answer_text else "",
            )
            return bool(row)
    except Exception as e:
        logger.warning(f"resolve_pending_question failed: {e}")
        return False
|
||||
|
||||
async def store_interaction(
    self,
    channel: str,
    chat_id: str,
    user_id: str,
    agent_id: str,
    username: Optional[str],
    user_message: str,
    assistant_response: str,
    metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """Backward-compatible wrapper for older call sites; delegates to store_message.

    NOTE(review): ``channel`` is accepted only for signature compatibility
    and is not forwarded — confirm store_message does not need it.
    """
    forwarded = dict(
        agent_id=agent_id,
        user_id=user_id,
        username=username,
        message_text=user_message,
        response_text=assistant_response,
        chat_id=chat_id,
        message_type="conversation",
        metadata=metadata,
    )
    return await self.store_message(**forwarded)
|
||||
|
||||
async def update_session_state(
|
||||
self,
|
||||
@@ -774,10 +1139,10 @@ class MemoryRetrieval:
|
||||
param_idx = 2
|
||||
|
||||
allowed_fields = [
|
||||
'last_addressed_to_helion', 'last_user_id', 'last_user_nick',
|
||||
'active_topic_id', 'active_context_open', 'last_media_id',
|
||||
'last_media_handled', 'last_answer_fingerprint', 'group_trust_mode',
|
||||
'apprentice_mode', 'proactive_questions_today'
|
||||
'last_user_id', 'last_user_nick',
|
||||
'active_topic', 'context_open',
|
||||
'last_media_handled', 'last_answer_fingerprint',
|
||||
'trust_mode', 'apprentice_mode'
|
||||
]
|
||||
|
||||
for field, value in updates.items():
|
||||
@@ -787,7 +1152,7 @@ class MemoryRetrieval:
|
||||
param_idx += 1
|
||||
|
||||
query = f"""
|
||||
UPDATE helion_conversation_state
|
||||
UPDATE agent_session_state
|
||||
SET {', '.join(set_clauses)}
|
||||
WHERE conversation_id = $1
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user