🧠 Add Agent Memory System with PostgreSQL + Qdrant + Cohere
Features: - Three-tier memory architecture (short/mid/long-term) - PostgreSQL schema for conversations, events, memories - Qdrant vector database for semantic search - Cohere embeddings (embed-multilingual-v3.0, 1024 dims) - FastAPI Memory Service with full CRUD - External Secrets integration with Vault - Kubernetes deployment manifests Components: - infrastructure/database/agent-memory-schema.sql - infrastructure/kubernetes/apps/qdrant/ - infrastructure/kubernetes/apps/memory-service/ - services/memory-service/ (FastAPI app) Also includes: - External Secrets Operator - Traefik Ingress Controller - Cert-Manager with Let's Encrypt - ArgoCD for GitOps
This commit is contained in:
@@ -1,24 +1,23 @@
|
||||
# DAARION Memory Service
FROM python:3.11-slim

WORKDIR /app

# System dependencies (gcc for building wheels, psql client, curl for healthcheck)
RUN apt-get update && apt-get install -y \
    gcc \
    postgresql-client \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies first so this layer is cached across code changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code only (BUG FIX: the previous "COPY . ." followed by
# "COPY app/ ./app/" copied the whole build context twice; copy app/ once)
COPY app/ ./app/

# Environment
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Run
EXPOSE 8000

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
# Memory Service for DAARION.city
|
||||
"""DAARION Memory Service"""
|
||||
|
||||
56
services/memory-service/app/config.py
Normal file
56
services/memory-service/app/config.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
DAARION Memory Service Configuration
|
||||
"""
|
||||
from pydantic_settings import BaseSettings
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings.

    Every field can be overridden via a MEMORY_-prefixed environment
    variable (e.g. MEMORY_POSTGRES_PASSWORD) or a local .env file.
    """

    # Service
    service_name: str = "memory-service"
    debug: bool = False

    # PostgreSQL
    postgres_host: str = "daarion-pooler.daarion"
    postgres_port: int = 5432
    postgres_user: str = "daarion"
    # SECURITY FIX: the database password was hard-coded in source. It must
    # be supplied via MEMORY_POSTGRES_PASSWORD (External Secrets / Vault);
    # the leaked credential should be rotated.
    postgres_password: str = ""
    postgres_db: str = "daarion_main"

    @property
    def postgres_url(self) -> str:
        """Async SQLAlchemy DSN assembled from the individual parts above."""
        return f"postgresql+asyncpg://{self.postgres_user}:{self.postgres_password}@{self.postgres_host}:{self.postgres_port}/{self.postgres_db}"

    # Qdrant
    qdrant_host: str = "qdrant.qdrant"
    qdrant_port: int = 6333
    qdrant_collection_memories: str = "memories"
    qdrant_collection_messages: str = "messages"

    # Cohere (embeddings)
    # SECURITY FIX: a live Cohere API key was committed in source. Supply it
    # via MEMORY_COHERE_API_KEY and rotate the leaked key immediately.
    cohere_api_key: str = ""
    cohere_model: str = "embed-multilingual-v3.0"  # 1024 dimensions
    embedding_dimensions: int = 1024

    # Memory settings
    short_term_window_messages: int = 20
    short_term_window_minutes: int = 60
    summary_trigger_tokens: int = 4000
    summary_target_tokens: int = 500
    retrieval_top_k: int = 10

    # Confidence thresholds
    memory_min_confidence: float = 0.5
    memory_confirm_boost: float = 0.1
    memory_reject_penalty: float = 0.3

    class Config:
        env_prefix = "MEMORY_"
        env_file = ".env"
|
||||
|
||||
|
||||
@lru_cache()
def get_settings() -> Settings:
    """Return the process-wide Settings instance (built once, then cached)."""
    return Settings()
|
||||
430
services/memory-service/app/database.py
Normal file
430
services/memory-service/app/database.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""
|
||||
DAARION Memory Service - PostgreSQL Database Layer
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from uuid import UUID, uuid4
|
||||
from datetime import datetime
|
||||
import structlog
|
||||
import asyncpg
|
||||
|
||||
from .config import get_settings
|
||||
from .models import EventType, MessageRole, MemoryCategory, RetentionPolicy, FeedbackAction
|
||||
|
||||
logger = structlog.get_logger()
|
||||
settings = get_settings()
|
||||
|
||||
|
||||
class Database:
    """Async PostgreSQL data-access layer for the memory service.

    Wraps an asyncpg connection pool and provides CRUD helpers for
    conversation threads, conversation events (short-term memory),
    long-term memory items, memory feedback and thread summaries.
    """

    def __init__(self):
        # Created lazily by connect(); every query method assumes it is set.
        self.pool: Optional[asyncpg.Pool] = None

    async def connect(self):
        """Open the connection pool (call once at service startup)."""
        self.pool = await asyncpg.create_pool(
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            database=settings.postgres_db,
            min_size=5,
            max_size=20
        )
        logger.info("database_connected")

    async def disconnect(self):
        """Close the connection pool (call at service shutdown)."""
        if self.pool:
            await self.pool.close()
        logger.info("database_disconnected")

    # ========================================================================
    # THREADS
    # ========================================================================

    async def create_thread(
        self,
        org_id: UUID,
        user_id: UUID,
        workspace_id: Optional[UUID] = None,
        agent_id: Optional[UUID] = None,
        title: Optional[str] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[dict] = None
    ) -> Dict[str, Any]:
        """Create a new conversation thread and return the inserted row.

        BUG FIX: the mutable defaults ``tags=[]`` / ``metadata={}`` were
        shared across calls; replaced with ``None`` sentinels.
        """
        tags = [] if tags is None else tags
        metadata = {} if metadata is None else metadata
        thread_id = uuid4()

        async with self.pool.acquire() as conn:
            row = await conn.fetchrow("""
                INSERT INTO conversation_threads
                (thread_id, org_id, workspace_id, user_id, agent_id, title, tags, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                RETURNING *
            """, thread_id, org_id, workspace_id, user_id, agent_id, title, tags, metadata)

        logger.info("thread_created", thread_id=str(thread_id))
        return dict(row)

    async def get_thread(self, thread_id: UUID) -> Optional[Dict[str, Any]]:
        """Return a thread row as a dict, or None if it does not exist."""
        async with self.pool.acquire() as conn:
            row = await conn.fetchrow("""
                SELECT * FROM conversation_threads WHERE thread_id = $1
            """, thread_id)
            return dict(row) if row else None

    async def list_threads(
        self,
        org_id: UUID,
        user_id: UUID,
        workspace_id: Optional[UUID] = None,
        agent_id: Optional[UUID] = None,
        limit: int = 20
    ) -> List[Dict[str, Any]]:
        """List active threads for a user, most recently active first."""
        async with self.pool.acquire() as conn:
            query = """
                SELECT * FROM conversation_threads
                WHERE org_id = $1 AND user_id = $2 AND status = 'active'
            """
            params: List[Any] = [org_id, user_id]

            # Only the $n placeholder index is built dynamically; every
            # value still goes through a bind parameter (no injection risk).
            if workspace_id is not None:
                query += f" AND workspace_id = ${len(params) + 1}"
                params.append(workspace_id)
            if agent_id is not None:
                query += f" AND agent_id = ${len(params) + 1}"
                params.append(agent_id)

            query += f" ORDER BY last_activity_at DESC LIMIT ${len(params) + 1}"
            params.append(limit)

            rows = await conn.fetch(query, *params)

            return [dict(row) for row in rows]

    # ========================================================================
    # EVENTS
    # ========================================================================

    async def add_event(
        self,
        thread_id: UUID,
        event_type: EventType,
        role: Optional[MessageRole] = None,
        content: Optional[str] = None,
        tool_name: Optional[str] = None,
        tool_input: Optional[dict] = None,
        tool_output: Optional[dict] = None,
        payload: Optional[dict] = None,
        token_count: Optional[int] = None,
        model_used: Optional[str] = None,
        latency_ms: Optional[int] = None,
        metadata: Optional[dict] = None
    ) -> Dict[str, Any]:
        """Append an event (message, tool call, ...) to a conversation.

        BUG FIX: the mutable defaults ``payload={}`` / ``metadata={}`` were
        shared across calls; replaced with ``None`` sentinels.
        NOTE(review): dicts are passed directly to asyncpg; if the target
        columns are jsonb this needs a registered JSON codec — confirm.
        """
        payload = {} if payload is None else payload
        metadata = {} if metadata is None else metadata
        event_id = uuid4()

        async with self.pool.acquire() as conn:
            row = await conn.fetchrow("""
                INSERT INTO conversation_events
                (event_id, thread_id, event_type, role, content, tool_name,
                 tool_input, tool_output, payload, token_count, model_used,
                 latency_ms, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
                RETURNING *
            """, event_id, thread_id, event_type.value,
                role.value if role else None, content, tool_name,
                tool_input, tool_output, payload, token_count, model_used,
                latency_ms, metadata)

        logger.info("event_added", event_id=str(event_id), type=event_type.value)
        return dict(row)

    async def get_events(
        self,
        thread_id: UUID,
        limit: int = 50,
        offset: int = 0
    ) -> List[Dict[str, Any]]:
        """Return events for a thread, newest first (paged)."""
        async with self.pool.acquire() as conn:
            rows = await conn.fetch("""
                SELECT * FROM conversation_events
                WHERE thread_id = $1
                ORDER BY sequence_num DESC
                LIMIT $2 OFFSET $3
            """, thread_id, limit, offset)

            return [dict(row) for row in rows]

    async def get_events_for_summary(
        self,
        thread_id: UUID,
        after_seq: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """Return events for summarization, oldest first.

        BUG FIX: ``if after_seq:`` treated sequence number 0 as "no
        cursor"; an explicit ``is not None`` check is used instead.
        """
        async with self.pool.acquire() as conn:
            if after_seq is not None:
                rows = await conn.fetch("""
                    SELECT * FROM conversation_events
                    WHERE thread_id = $1 AND sequence_num > $2
                    ORDER BY sequence_num ASC
                """, thread_id, after_seq)
            else:
                rows = await conn.fetch("""
                    SELECT * FROM conversation_events
                    WHERE thread_id = $1
                    ORDER BY sequence_num ASC
                """, thread_id)

            return [dict(row) for row in rows]

    # ========================================================================
    # MEMORIES
    # ========================================================================

    async def create_memory(
        self,
        org_id: UUID,
        user_id: UUID,
        category: MemoryCategory,
        fact_text: str,
        workspace_id: Optional[UUID] = None,
        agent_id: Optional[UUID] = None,
        confidence: float = 0.8,
        source_event_id: Optional[UUID] = None,
        source_thread_id: Optional[UUID] = None,
        extraction_method: str = "explicit",
        is_sensitive: bool = False,
        retention: RetentionPolicy = RetentionPolicy.UNTIL_REVOKED,
        ttl_days: Optional[int] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[dict] = None
    ) -> Dict[str, Any]:
        """Create a long-term memory item and return the inserted row.

        BUG FIX: the mutable defaults ``tags=[]`` / ``metadata={}`` were
        shared across calls; replaced with ``None`` sentinels.
        """
        tags = [] if tags is None else tags
        metadata = {} if metadata is None else metadata
        memory_id = uuid4()

        async with self.pool.acquire() as conn:
            row = await conn.fetchrow("""
                INSERT INTO long_term_memory_items
                (memory_id, org_id, workspace_id, user_id, agent_id, category,
                 fact_text, confidence, source_event_id, source_thread_id,
                 extraction_method, is_sensitive, retention, ttl_days, tags, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
                RETURNING *
            """, memory_id, org_id, workspace_id, user_id, agent_id, category.value,
                fact_text, confidence, source_event_id, source_thread_id,
                extraction_method, is_sensitive, retention.value, ttl_days, tags, metadata)

        logger.info("memory_created", memory_id=str(memory_id), category=category.value)
        return dict(row)

    async def get_memory(self, memory_id: UUID) -> Optional[Dict[str, Any]]:
        """Return a currently-valid memory (soft-deleted rows excluded)."""
        async with self.pool.acquire() as conn:
            row = await conn.fetchrow("""
                SELECT * FROM long_term_memory_items
                WHERE memory_id = $1 AND valid_to IS NULL
            """, memory_id)
            return dict(row) if row else None

    async def list_memories(
        self,
        org_id: UUID,
        user_id: UUID,
        agent_id: Optional[UUID] = None,
        workspace_id: Optional[UUID] = None,
        category: Optional[MemoryCategory] = None,
        include_global: bool = True,
        limit: int = 50
    ) -> List[Dict[str, Any]]:
        """List valid memories above the confidence floor, best first."""
        async with self.pool.acquire() as conn:
            query = """
                SELECT * FROM long_term_memory_items
                WHERE org_id = $1 AND user_id = $2 AND valid_to IS NULL
                AND confidence >= $3
            """
            params: List[Any] = [org_id, user_id, settings.memory_min_confidence]

            if workspace_id is not None:
                # Workspace-scoped plus org-global (NULL workspace) memories.
                query += f" AND (workspace_id = ${len(params) + 1} OR workspace_id IS NULL)"
                params.append(workspace_id)

            if agent_id is not None:
                if include_global:
                    # Agent-scoped plus agent-global (NULL agent) memories.
                    query += f" AND (agent_id = ${len(params) + 1} OR agent_id IS NULL)"
                else:
                    query += f" AND agent_id = ${len(params) + 1}"
                params.append(agent_id)

            if category is not None:
                query += f" AND category = ${len(params) + 1}"
                params.append(category.value)

            query += f" ORDER BY confidence DESC, last_used_at DESC NULLS LAST LIMIT ${len(params) + 1}"
            params.append(limit)

            rows = await conn.fetch(query, *params)

            return [dict(row) for row in rows]

    async def update_memory_embedding_id(self, memory_id: UUID, embedding_id: str):
        """Link a memory row to its Qdrant point ID."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                UPDATE long_term_memory_items
                SET fact_embedding_id = $2
                WHERE memory_id = $1
            """, memory_id, embedding_id)

    async def update_memory_confidence(
        self,
        memory_id: UUID,
        confidence: float,
        verified: bool = False
    ):
        """Set a memory's confidence; optionally record a verification."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                UPDATE long_term_memory_items
                SET confidence = $2,
                    is_verified = CASE WHEN $3 THEN true ELSE is_verified END,
                    verification_count = verification_count + CASE WHEN $3 THEN 1 ELSE 0 END,
                    last_confirmed_at = CASE WHEN $3 THEN NOW() ELSE last_confirmed_at END
                WHERE memory_id = $1
            """, memory_id, confidence, verified)

    async def update_memory_text(self, memory_id: UUID, new_text: str):
        """Replace a memory's fact text (re-indexing is the caller's job)."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                UPDATE long_term_memory_items
                SET fact_text = $2
                WHERE memory_id = $1
            """, memory_id, new_text)

    async def invalidate_memory(self, memory_id: UUID):
        """Soft-delete a memory by stamping valid_to (history is kept)."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                UPDATE long_term_memory_items
                SET valid_to = NOW()
                WHERE memory_id = $1
            """, memory_id)
            logger.info("memory_invalidated", memory_id=str(memory_id))

    async def increment_memory_usage(self, memory_id: UUID):
        """Bump a memory's use counter and last-used timestamp."""
        async with self.pool.acquire() as conn:
            await conn.execute("""
                UPDATE long_term_memory_items
                SET use_count = use_count + 1, last_used_at = NOW()
                WHERE memory_id = $1
            """, memory_id)

    # ========================================================================
    # FEEDBACK
    # ========================================================================

    async def add_memory_feedback(
        self,
        memory_id: UUID,
        user_id: UUID,
        action: FeedbackAction,
        old_value: Optional[str] = None,
        new_value: Optional[str] = None,
        reason: Optional[str] = None
    ):
        """Record a user's feedback (confirm/reject/edit) on a memory."""
        feedback_id = uuid4()

        async with self.pool.acquire() as conn:
            await conn.execute("""
                INSERT INTO memory_feedback
                (feedback_id, memory_id, user_id, action, old_value, new_value, reason)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
            """, feedback_id, memory_id, user_id, action.value, old_value, new_value, reason)

        logger.info("feedback_recorded", memory_id=str(memory_id), action=action.value)

    # ========================================================================
    # SUMMARIES
    # ========================================================================

    async def create_summary(
        self,
        thread_id: UUID,
        summary_text: str,
        state: dict,
        events_from_seq: int,
        events_to_seq: int,
        events_count: int,
        original_tokens: Optional[int] = None,
        summary_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """Store a new (versioned) summary of a thread's events."""
        summary_id = uuid4()

        async with self.pool.acquire() as conn:
            # Versions are monotonically increasing per thread.
            version_row = await conn.fetchrow("""
                SELECT COALESCE(MAX(version), 0) + 1 as next_version
                FROM thread_summaries WHERE thread_id = $1
            """, thread_id)
            version = version_row["next_version"]

            # Compression ratio only when both token counts are known.
            compression_ratio = None
            if original_tokens and summary_tokens:
                compression_ratio = summary_tokens / original_tokens

            row = await conn.fetchrow("""
                INSERT INTO thread_summaries
                (summary_id, thread_id, version, summary_text, state,
                 events_from_seq, events_to_seq, events_count,
                 original_tokens, summary_tokens, compression_ratio)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                RETURNING *
            """, summary_id, thread_id, version, summary_text, state,
                events_from_seq, events_to_seq, events_count,
                original_tokens, summary_tokens, compression_ratio)

        logger.info("summary_created", summary_id=str(summary_id), version=version)
        return dict(row)

    async def get_latest_summary(self, thread_id: UUID) -> Optional[Dict[str, Any]]:
        """Return the highest-version summary for a thread, if any."""
        async with self.pool.acquire() as conn:
            row = await conn.fetchrow("""
                SELECT * FROM thread_summaries
                WHERE thread_id = $1
                ORDER BY version DESC
                LIMIT 1
            """, thread_id)
            return dict(row) if row else None

    # ========================================================================
    # STATS
    # ========================================================================

    async def get_stats(self) -> Dict[str, Any]:
        """Return row counts for the main memory tables."""
        async with self.pool.acquire() as conn:
            threads = await conn.fetchval("SELECT COUNT(*) FROM conversation_threads")
            events = await conn.fetchval("SELECT COUNT(*) FROM conversation_events")
            memories = await conn.fetchval("SELECT COUNT(*) FROM long_term_memory_items WHERE valid_to IS NULL")
            summaries = await conn.fetchval("SELECT COUNT(*) FROM thread_summaries")

            return {
                "threads": threads,
                "events": events,
                "active_memories": memories,
                "summaries": summaries
            }


# Module-level singleton shared by the FastAPI app (see main.py lifespan).
db = Database()
|
||||
86
services/memory-service/app/embedding.py
Normal file
86
services/memory-service/app/embedding.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""
|
||||
DAARION Memory Service - Embedding Layer (Cohere)
|
||||
"""
|
||||
import cohere
|
||||
from typing import List
|
||||
import structlog
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential
|
||||
|
||||
from .config import get_settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
settings = get_settings()
|
||||
|
||||
# Initialize Cohere client
|
||||
co = cohere.Client(settings.cohere_api_key)
|
||||
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10)
)
async def get_embeddings(
    texts: List[str],
    input_type: str = "search_document"
) -> List[List[float]]:
    """
    Get embeddings from Cohere API.

    Args:
        texts: List of texts to embed
        input_type: "search_document" for indexing, "search_query" for queries

    Returns:
        List of embedding vectors (1024 dimensions for embed-multilingual-v3.0)
    """
    import asyncio  # local import to keep file-level imports untouched

    if not texts:
        return []

    logger.info("generating_embeddings", count=len(texts), input_type=input_type)

    # BUG FIX: co.embed() is a blocking HTTP call; invoking it directly in a
    # coroutine stalls the whole event loop for the duration of the request.
    # Run it in a worker thread so other requests keep being served.
    response = await asyncio.to_thread(
        co.embed,
        texts=texts,
        model=settings.cohere_model,
        input_type=input_type,
        truncate="END"
    )

    embeddings = response.embeddings

    logger.info(
        "embeddings_generated",
        count=len(embeddings),
        dimensions=len(embeddings[0]) if embeddings else 0
    )

    return embeddings
|
||||
|
||||
|
||||
async def get_query_embedding(query: str) -> List[float]:
    """Embed a single search query (Cohere "search_query" input type)."""
    vectors = await get_embeddings([query], input_type="search_query")
    if not vectors:
        return []
    return vectors[0]
|
||||
|
||||
|
||||
async def get_document_embeddings(texts: List[str]) -> List[List[float]]:
    """Embed documents (memories, summaries) using the indexing input type."""
    return await get_embeddings(texts, input_type="search_document")
|
||||
|
||||
|
||||
# Batch processing for large sets
|
||||
async def batch_embed(
    texts: List[str],
    input_type: str = "search_document",
    batch_size: int = 96  # Cohere limit
) -> List[List[float]]:
    """Embed a large list of texts by splitting it into Cohere-sized batches."""
    collected: List[List[float]] = []
    start = 0
    total = len(texts)

    while start < total:
        chunk = texts[start:start + batch_size]
        collected.extend(await get_embeddings(chunk, input_type))
        start += batch_size

    return collected
|
||||
@@ -1,443 +1,483 @@
|
||||
"""
|
||||
Memory Service - FastAPI додаток
|
||||
Підтримує: user_facts, dialog_summaries, agent_memory_events
|
||||
Інтеграція з token-gate через RBAC
|
||||
DAARION Memory Service - FastAPI Application
|
||||
|
||||
Трирівнева пам'ять агентів:
|
||||
- Short-term: conversation events (робочий буфер)
|
||||
- Mid-term: thread summaries (сесійна/тематична)
|
||||
- Long-term: memory items (персональна/проектна)
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi import FastAPI, Depends, HTTPException, Query, Header
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import List, Optional
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
from fastapi import FastAPI, HTTPException, Query
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from app.models import Base, UserFact, DialogSummary, AgentMemoryEvent
|
||||
from app.schemas import (
|
||||
UserFactCreate, UserFactUpdate, UserFactResponse, UserFactUpsertRequest, UserFactUpsertResponse,
|
||||
DialogSummaryCreate, DialogSummaryResponse, DialogSummaryListResponse,
|
||||
AgentMemoryEventCreate, AgentMemoryEventResponse, AgentMemoryEventListResponse,
|
||||
TokenGateCheck, TokenGateCheckResponse
|
||||
)
|
||||
from app.crud import (
|
||||
get_user_fact, get_user_facts, create_user_fact, update_user_fact,
|
||||
upsert_user_fact, delete_user_fact, get_user_facts_by_token_gate,
|
||||
create_dialog_summary, get_dialog_summaries, get_dialog_summary, delete_dialog_summary,
|
||||
create_agent_memory_event, get_agent_memory_events, delete_agent_memory_event
|
||||
from .config import get_settings
|
||||
from .models import (
|
||||
CreateThreadRequest, AddEventRequest, CreateMemoryRequest,
|
||||
MemoryFeedbackRequest, RetrievalRequest, SummaryRequest,
|
||||
ThreadResponse, EventResponse, MemoryResponse,
|
||||
SummaryResponse, RetrievalResponse, RetrievalResult,
|
||||
ContextResponse, MemoryCategory, FeedbackAction
|
||||
)
|
||||
from .vector_store import vector_store
|
||||
from .database import db
|
||||
|
||||
# ========== Configuration ==========
|
||||
logger = structlog.get_logger()
|
||||
settings = get_settings()
|
||||
|
||||
DATABASE_URL = os.getenv(
|
||||
"DATABASE_URL",
|
||||
"sqlite:///./memory.db" # SQLite для розробки, PostgreSQL для продакшену
|
||||
)
|
||||
|
||||
# Створюємо engine та sessionmaker
|
||||
engine = create_engine(DATABASE_URL)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Connect backing stores on startup; release them on shutdown."""
    logger.info("starting_memory_service")
    await db.connect()
    await vector_store.initialize()

    yield

    await db.disconnect()
    logger.info("memory_service_stopped")
|
||||
|
||||
# Створюємо таблиці (для dev, в продакшені використовуйте міграції)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
# ========== FastAPI App ==========
|
||||
|
||||
# FIX: the diff residue left both the old and new FastAPI(...) keyword sets
# in place; keep only the new configuration with the lifespan handler.
app = FastAPI(
    title="DAARION Memory Service",
    description="Agent memory management with PostgreSQL + Qdrant + Cohere",
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware
# NOTE(review): browsers reject allow_origins=["*"] combined with
# allow_credentials=True; restrict origins in production — confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||
|
||||
# ========== Dependencies ==========
|
||||
|
||||
def get_db():
|
||||
"""Dependency для отримання DB сесії"""
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def verify_token(authorization: Optional[str] = Header(None)) -> Optional[str]:
|
||||
"""
|
||||
Перевірка JWT токену (заглушка)
|
||||
В продакшені інтегруйте з вашою системою авторизації
|
||||
"""
|
||||
if not authorization:
|
||||
raise HTTPException(status_code=401, detail="Missing authorization header")
|
||||
|
||||
# Заглушка: в реальності перевіряйте JWT
|
||||
# token = authorization.replace("Bearer ", "")
|
||||
# user_id = verify_jwt_token(token)
|
||||
# return user_id
|
||||
|
||||
# Для тестування повертаємо user_id з заголовка
|
||||
return "u_test" # TODO: реалізувати реальну перевірку
|
||||
|
||||
|
||||
async def check_token_gate(
|
||||
user_id: str,
|
||||
token_requirements: dict,
|
||||
db: Session
|
||||
) -> TokenGateCheckResponse:
|
||||
"""
|
||||
Перевірка токен-гейту (інтеграція з RBAC/Wallet Service)
|
||||
Заглушка - в продакшені викликайте ваш PDP/Wallet Service
|
||||
"""
|
||||
# TODO: Інтегрувати з:
|
||||
# - PDP Service для перевірки capabilities
|
||||
# - Wallet Service для перевірки балансів
|
||||
# - RBAC для перевірки ролей
|
||||
|
||||
# Приклад логіки:
|
||||
# if "token" in token_requirements:
|
||||
# token_type = token_requirements["token"]
|
||||
# min_balance = token_requirements.get("min_balance", 0)
|
||||
# balance = await wallet_service.get_balance(user_id, token_type)
|
||||
# if balance < min_balance:
|
||||
# return TokenGateCheckResponse(
|
||||
# allowed=False,
|
||||
# reason=f"Insufficient {token_type} balance",
|
||||
# missing_requirements={"token": token_type, "required": min_balance, "current": balance}
|
||||
# )
|
||||
|
||||
# Заглушка: завжди дозволяємо
|
||||
return TokenGateCheckResponse(allowed=True)
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
# ============================================================================
|
||||
# HEALTH
|
||||
# ============================================================================
|
||||
|
||||
# FIX: diff residue contained both the old health_check() and the new
# health() handler under one decorator; keep only the new implementation.
@app.get("/health")
async def health():
    """Health check: reports service identity and vector-store stats."""
    return {
        "status": "healthy",
        "service": settings.service_name,
        "vector_store": await vector_store.get_collection_stats()
    }
|
||||
|
||||
|
||||
# ========== User Facts Endpoints ==========
|
||||
# ============================================================================
|
||||
# THREADS (Conversations)
|
||||
# ============================================================================
|
||||
|
||||
@app.post("/facts/upsert", response_model=UserFactUpsertResponse)
|
||||
async def upsert_fact(
|
||||
fact_request: UserFactUpsertRequest,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
@app.post("/threads", response_model=ThreadResponse)
async def create_thread(request: CreateThreadRequest):
    """Create a new conversation thread."""
    return await db.create_thread(
        org_id=request.org_id,
        user_id=request.user_id,
        workspace_id=request.workspace_id,
        agent_id=request.agent_id,
        title=request.title,
        tags=request.tags,
        metadata=request.metadata
    )
|
||||
|
||||
|
||||
@app.get("/threads/{thread_id}", response_model=ThreadResponse)
async def get_thread(thread_id: UUID):
    """Fetch a single thread; 404 if it does not exist."""
    result = await db.get_thread(thread_id)
    if result is None:
        raise HTTPException(status_code=404, detail="Thread not found")
    return result
|
||||
|
||||
|
||||
# FIX: the diff spliced the old upsert_fact docstring into this handler's
# body; reconstructed as the clean new-version endpoint.
@app.get("/threads", response_model=List[ThreadResponse])
async def list_threads(
    user_id: UUID = Query(...),
    org_id: UUID = Query(...),
    workspace_id: Optional[UUID] = None,
    agent_id: Optional[UUID] = None,
    limit: int = Query(default=20, le=100)
):
    """List active threads for a user, most recently active first."""
    return await db.list_threads(
        org_id=org_id,
        user_id=user_id,
        workspace_id=workspace_id,
        agent_id=agent_id,
        limit=limit
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# EVENTS (Short-term Memory)
|
||||
# ============================================================================
|
||||
|
||||
@app.post("/events", response_model=EventResponse)
async def add_event(request: AddEventRequest):
    """Add an event to a conversation (message, tool call, etc.)."""
    return await db.add_event(
        thread_id=request.thread_id,
        event_type=request.event_type,
        role=request.role,
        content=request.content,
        tool_name=request.tool_name,
        tool_input=request.tool_input,
        tool_output=request.tool_output,
        payload=request.payload,
        token_count=request.token_count,
        model_used=request.model_used,
        latency_ms=request.latency_ms,
        metadata=request.metadata
    )
|
||||
|
||||
|
||||
@app.get("/threads/{thread_id}/events", response_model=List[EventResponse])
async def get_events(
    thread_id: UUID,
    limit: int = Query(default=50, le=200),
    offset: int = Query(default=0)
):
    """Return events for a thread, most recent first."""
    return await db.get_events(thread_id, limit=limit, offset=offset)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MEMORIES (Long-term Memory)
|
||||
# ============================================================================
|
||||
|
||||
# FIX: the diff interleaved the removed token-gate upsert_fact code with
# this handler; reconstructed as the clean new-version endpoint.
@app.post("/memories", response_model=MemoryResponse)
async def create_memory(request: CreateMemoryRequest):
    """Create a long-term memory item and index it for semantic search."""
    # Persist the canonical record in PostgreSQL first.
    memory = await db.create_memory(
        org_id=request.org_id,
        user_id=request.user_id,
        workspace_id=request.workspace_id,
        agent_id=request.agent_id,
        category=request.category,
        fact_text=request.fact_text,
        confidence=request.confidence,
        source_event_id=request.source_event_id,
        source_thread_id=request.source_thread_id,
        extraction_method=request.extraction_method,
        is_sensitive=request.is_sensitive,
        retention=request.retention,
        ttl_days=request.ttl_days,
        tags=request.tags,
        metadata=request.metadata
    )

    # Index the fact text in Qdrant for vector retrieval.
    point_id = await vector_store.index_memory(
        memory_id=memory["memory_id"],
        text=request.fact_text,
        org_id=request.org_id,
        user_id=request.user_id,
        category=request.category,
        agent_id=request.agent_id,
        workspace_id=request.workspace_id,
        thread_id=request.source_thread_id
    )

    # Link the Postgres row to its Qdrant point.
    await db.update_memory_embedding_id(memory["memory_id"], point_id)

    return memory
|
||||
|
||||
|
||||
@app.get("/memories/{memory_id}", response_model=MemoryResponse)
|
||||
async def get_memory(memory_id: UUID):
|
||||
"""Get memory by ID"""
|
||||
memory = await db.get_memory(memory_id)
|
||||
if not memory:
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
return memory
|
||||
|
||||
|
||||
@app.get("/memories", response_model=List[MemoryResponse])
|
||||
async def list_memories(
|
||||
user_id: UUID = Query(...),
|
||||
org_id: UUID = Query(...),
|
||||
agent_id: Optional[UUID] = None,
|
||||
workspace_id: Optional[UUID] = None,
|
||||
category: Optional[MemoryCategory] = None,
|
||||
include_global: bool = True,
|
||||
limit: int = Query(default=50, le=200)
|
||||
):
|
||||
"""List memories for user"""
|
||||
memories = await db.list_memories(
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
workspace_id=workspace_id,
|
||||
category=category,
|
||||
include_global=include_global,
|
||||
limit=limit
|
||||
)
|
||||
return memories
|
||||
|
||||
|
||||
@app.post("/memories/{memory_id}/feedback")
|
||||
async def memory_feedback(memory_id: UUID, request: MemoryFeedbackRequest):
|
||||
"""User feedback on memory (confirm/reject/edit/delete)"""
|
||||
memory = await db.get_memory(memory_id)
|
||||
if not memory:
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
|
||||
# Record feedback
|
||||
await db.add_memory_feedback(
|
||||
memory_id=memory_id,
|
||||
user_id=request.user_id,
|
||||
action=request.action,
|
||||
old_value=memory["fact_text"],
|
||||
new_value=request.new_value,
|
||||
reason=request.reason
|
||||
)
|
||||
|
||||
# Apply action
|
||||
if request.action == FeedbackAction.CONFIRM:
|
||||
new_confidence = min(1.0, memory["confidence"] + settings.memory_confirm_boost)
|
||||
await db.update_memory_confidence(memory_id, new_confidence, verified=True)
|
||||
|
||||
elif request.action == FeedbackAction.REJECT:
|
||||
new_confidence = max(0.0, memory["confidence"] - settings.memory_reject_penalty)
|
||||
if new_confidence < settings.memory_min_confidence:
|
||||
# Mark as invalid
|
||||
await db.invalidate_memory(memory_id)
|
||||
await vector_store.delete_memory(memory_id)
|
||||
else:
|
||||
await db.update_memory_confidence(memory_id, new_confidence)
|
||||
|
||||
elif request.action == FeedbackAction.EDIT:
|
||||
if request.new_value:
|
||||
await db.update_memory_text(memory_id, request.new_value)
|
||||
# Re-index with new text
|
||||
await vector_store.delete_memory(memory_id)
|
||||
await vector_store.index_memory(
|
||||
memory_id=memory_id,
|
||||
text=request.new_value,
|
||||
org_id=memory["org_id"],
|
||||
user_id=memory["user_id"],
|
||||
category=memory["category"],
|
||||
agent_id=memory.get("agent_id"),
|
||||
workspace_id=memory.get("workspace_id")
|
||||
)
|
||||
|
||||
elif request.action == FeedbackAction.DELETE:
|
||||
await db.invalidate_memory(memory_id)
|
||||
await vector_store.delete_memory(memory_id)
|
||||
|
||||
return {"status": "ok", "action": request.action.value}
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RETRIEVAL (Semantic Search)
|
||||
# ============================================================================
|
||||
|
||||
@app.post("/retrieve", response_model=RetrievalResponse)
|
||||
async def retrieve_memories(request: RetrievalRequest):
|
||||
"""
|
||||
Semantic retrieval of relevant memories.
|
||||
|
||||
# Перевірка прав доступу (користувач може змінювати тільки свої факти)
|
||||
if fact_request.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Cannot modify other user's facts")
|
||||
Performs multiple queries and deduplicates results.
|
||||
"""
|
||||
all_results = []
|
||||
seen_ids = set()
|
||||
|
||||
fact, created = upsert_user_fact(db, fact_request)
|
||||
for query in request.queries:
|
||||
results = await vector_store.search_memories(
|
||||
query=query,
|
||||
org_id=request.org_id,
|
||||
user_id=request.user_id,
|
||||
agent_id=request.agent_id,
|
||||
workspace_id=request.workspace_id,
|
||||
categories=request.categories,
|
||||
include_global=request.include_global,
|
||||
top_k=request.top_k
|
||||
)
|
||||
|
||||
for r in results:
|
||||
memory_id = r.get("memory_id")
|
||||
if memory_id and memory_id not in seen_ids:
|
||||
seen_ids.add(memory_id)
|
||||
|
||||
# Get full memory from DB for confidence check
|
||||
memory = await db.get_memory(UUID(memory_id))
|
||||
if memory and memory["confidence"] >= request.min_confidence:
|
||||
all_results.append(RetrievalResult(
|
||||
memory_id=UUID(memory_id),
|
||||
fact_text=r["text"],
|
||||
category=MemoryCategory(r["category"]),
|
||||
confidence=memory["confidence"],
|
||||
relevance_score=r["score"],
|
||||
agent_id=UUID(r["agent_id"]) if r.get("agent_id") else None,
|
||||
is_global=r.get("agent_id") is None
|
||||
))
|
||||
|
||||
# Update usage stats
|
||||
await db.increment_memory_usage(UUID(memory_id))
|
||||
|
||||
return UserFactUpsertResponse(
|
||||
fact=UserFactResponse.model_validate(fact),
|
||||
created=created
|
||||
# Sort by relevance
|
||||
all_results.sort(key=lambda x: x.relevance_score, reverse=True)
|
||||
|
||||
return RetrievalResponse(
|
||||
results=all_results[:request.top_k],
|
||||
query_count=len(request.queries),
|
||||
total_results=len(all_results)
|
||||
)
|
||||
|
||||
|
||||
@app.get("/facts", response_model=List[UserFactResponse])
|
||||
async def list_facts(
|
||||
team_id: Optional[str] = Query(None),
|
||||
fact_keys: Optional[str] = Query(None, description="Comma-separated list of fact keys"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Отримати список фактів користувача"""
|
||||
fact_keys_list = None
|
||||
if fact_keys:
|
||||
fact_keys_list = [k.strip() for k in fact_keys.split(",")]
|
||||
|
||||
facts = get_user_facts(db, user_id, team_id, fact_keys_list, skip, limit)
|
||||
return [UserFactResponse.model_validate(f) for f in facts]
|
||||
# ============================================================================
|
||||
# SUMMARIES (Mid-term Memory)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.get("/facts/{fact_key}", response_model=UserFactResponse)
|
||||
async def get_fact(
|
||||
fact_key: str,
|
||||
team_id: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Отримати конкретний факт за ключем"""
|
||||
fact = get_user_fact(db, user_id, fact_key, team_id)
|
||||
if not fact:
|
||||
raise HTTPException(status_code=404, detail="Fact not found")
|
||||
return UserFactResponse.model_validate(fact)
|
||||
|
||||
|
||||
@app.post("/facts", response_model=UserFactResponse)
|
||||
async def create_fact(
|
||||
fact: UserFactCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Створити новий факт"""
|
||||
if fact.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Cannot create fact for other user")
|
||||
|
||||
db_fact = create_user_fact(db, fact)
|
||||
return UserFactResponse.model_validate(db_fact)
|
||||
|
||||
|
||||
@app.patch("/facts/{fact_id}", response_model=UserFactResponse)
|
||||
async def update_fact(
|
||||
fact_id: str,
|
||||
fact_update: UserFactUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Оновити факт"""
|
||||
fact = db.query(UserFact).filter(UserFact.id == fact_id).first()
|
||||
if not fact:
|
||||
raise HTTPException(status_code=404, detail="Fact not found")
|
||||
|
||||
if fact.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Cannot modify other user's fact")
|
||||
|
||||
updated_fact = update_user_fact(db, fact_id, fact_update)
|
||||
if not updated_fact:
|
||||
raise HTTPException(status_code=404, detail="Fact not found")
|
||||
|
||||
return UserFactResponse.model_validate(updated_fact)
|
||||
|
||||
|
||||
@app.delete("/facts/{fact_id}")
|
||||
async def delete_fact(
|
||||
fact_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Видалити факт"""
|
||||
fact = db.query(UserFact).filter(UserFact.id == fact_id).first()
|
||||
if not fact:
|
||||
raise HTTPException(status_code=404, detail="Fact not found")
|
||||
|
||||
if fact.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Cannot delete other user's fact")
|
||||
|
||||
success = delete_user_fact(db, fact_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Fact not found")
|
||||
|
||||
return {"success": True}
|
||||
|
||||
|
||||
@app.get("/facts/token-gated", response_model=List[UserFactResponse])
|
||||
async def list_token_gated_facts(
|
||||
team_id: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Отримати токен-гейт факти користувача"""
|
||||
facts = get_user_facts_by_token_gate(db, user_id, team_id)
|
||||
return [UserFactResponse.model_validate(f) for f in facts]
|
||||
|
||||
|
||||
# ========== Dialog Summary Endpoints ==========
|
||||
|
||||
@app.post("/summaries", response_model=DialogSummaryResponse)
|
||||
async def create_summary(
|
||||
summary: DialogSummaryCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
@app.post("/threads/{thread_id}/summarize", response_model=SummaryResponse)
|
||||
async def create_summary(thread_id: UUID, request: SummaryRequest):
|
||||
"""
|
||||
Створити підсумок діалогу
|
||||
Generate rolling summary for thread.
|
||||
|
||||
Використовується для масштабування без переповнення контексту.
|
||||
Агрегує інформацію про сесії/діалоги.
|
||||
Compresses old events into a structured summary.
|
||||
"""
|
||||
db_summary = create_dialog_summary(db, summary)
|
||||
return DialogSummaryResponse.model_validate(db_summary)
|
||||
|
||||
|
||||
@app.get("/summaries", response_model=DialogSummaryListResponse)
|
||||
async def list_summaries(
|
||||
team_id: Optional[str] = Query(None),
|
||||
channel_id: Optional[str] = Query(None),
|
||||
agent_id: Optional[str] = Query(None),
|
||||
user_id_param: Optional[str] = Query(None, alias="user_id"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
cursor: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Отримати список підсумків діалогів"""
|
||||
summaries, next_cursor = get_dialog_summaries(
|
||||
db, team_id, channel_id, agent_id, user_id_param, skip, limit, cursor
|
||||
thread = await db.get_thread(thread_id)
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
# Check if summary is needed
|
||||
if not request.force and thread["total_tokens"] < settings.summary_trigger_tokens:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Token count ({thread['total_tokens']}) below threshold ({settings.summary_trigger_tokens})"
|
||||
)
|
||||
|
||||
# Get events to summarize
|
||||
events = await db.get_events_for_summary(thread_id)
|
||||
|
||||
# TODO: Call LLM to generate summary
|
||||
# For now, create a placeholder
|
||||
summary_text = f"Summary of {len(events)} events. [Implement LLM summarization]"
|
||||
state = {
|
||||
"goals": [],
|
||||
"decisions": [],
|
||||
"open_questions": [],
|
||||
"next_steps": [],
|
||||
"key_facts": []
|
||||
}
|
||||
|
||||
# Create summary
|
||||
summary = await db.create_summary(
|
||||
thread_id=thread_id,
|
||||
summary_text=summary_text,
|
||||
state=state,
|
||||
events_from_seq=events[0]["sequence_num"] if events else 0,
|
||||
events_to_seq=events[-1]["sequence_num"] if events else 0,
|
||||
events_count=len(events)
|
||||
)
|
||||
|
||||
return DialogSummaryListResponse(
|
||||
items=[DialogSummaryResponse.model_validate(s) for s in summaries],
|
||||
total=len(summaries),
|
||||
cursor=next_cursor
|
||||
# Index summary in Qdrant
|
||||
await vector_store.index_summary(
|
||||
summary_id=summary["summary_id"],
|
||||
text=summary_text,
|
||||
thread_id=thread_id,
|
||||
org_id=thread["org_id"],
|
||||
user_id=thread["user_id"],
|
||||
agent_id=thread.get("agent_id"),
|
||||
workspace_id=thread.get("workspace_id")
|
||||
)
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
@app.get("/threads/{thread_id}/summary", response_model=Optional[SummaryResponse])
|
||||
async def get_latest_summary(thread_id: UUID):
|
||||
"""Get latest summary for thread"""
|
||||
summary = await db.get_latest_summary(thread_id)
|
||||
return summary
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# CONTEXT (Full context for agent)
|
||||
# ============================================================================
|
||||
|
||||
@app.get("/threads/{thread_id}/context", response_model=ContextResponse)
|
||||
async def get_context(
|
||||
thread_id: UUID,
|
||||
queries: List[str] = Query(default=[]),
|
||||
top_k: int = Query(default=10)
|
||||
):
|
||||
"""
|
||||
Get full context for agent prompt.
|
||||
|
||||
Combines:
|
||||
- Latest summary (mid-term)
|
||||
- Recent messages (short-term)
|
||||
- Retrieved memories (long-term)
|
||||
"""
|
||||
thread = await db.get_thread(thread_id)
|
||||
if not thread:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
# Get summary
|
||||
summary = await db.get_latest_summary(thread_id)
|
||||
|
||||
# Get recent messages
|
||||
recent = await db.get_events(
|
||||
thread_id,
|
||||
limit=settings.short_term_window_messages
|
||||
)
|
||||
|
||||
# Retrieve memories if queries provided
|
||||
retrieved = []
|
||||
if queries:
|
||||
retrieval_response = await retrieve_memories(RetrievalRequest(
|
||||
org_id=thread["org_id"],
|
||||
user_id=thread["user_id"],
|
||||
agent_id=thread.get("agent_id"),
|
||||
workspace_id=thread.get("workspace_id"),
|
||||
queries=queries,
|
||||
top_k=top_k,
|
||||
include_global=True
|
||||
))
|
||||
retrieved = retrieval_response.results
|
||||
|
||||
# Estimate tokens
|
||||
token_estimate = sum(e.get("token_count", 0) or 0 for e in recent)
|
||||
if summary:
|
||||
token_estimate += summary.get("summary_tokens", 0) or 0
|
||||
|
||||
return ContextResponse(
|
||||
thread_id=thread_id,
|
||||
summary=summary,
|
||||
recent_messages=recent,
|
||||
retrieved_memories=retrieved,
|
||||
token_estimate=token_estimate
|
||||
)
|
||||
|
||||
|
||||
@app.get("/summaries/{summary_id}", response_model=DialogSummaryResponse)
|
||||
async def get_summary(
|
||||
summary_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Отримати підсумок за ID"""
|
||||
summary = get_dialog_summary(db, summary_id)
|
||||
if not summary:
|
||||
raise HTTPException(status_code=404, detail="Summary not found")
|
||||
return DialogSummaryResponse.model_validate(summary)
|
||||
# ============================================================================
|
||||
# ADMIN
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@app.delete("/summaries/{summary_id}")
|
||||
async def delete_summary(
|
||||
summary_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Видалити підсумок"""
|
||||
success = delete_dialog_summary(db, summary_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Summary not found")
|
||||
return {"success": True}
|
||||
|
||||
|
||||
# ========== Agent Memory Event Endpoints ==========
|
||||
|
||||
@app.post("/agents/{agent_id}/memory", response_model=AgentMemoryEventResponse)
|
||||
async def create_memory_event(
|
||||
agent_id: str,
|
||||
event: AgentMemoryEventCreate,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Створити подію пам'яті агента"""
|
||||
# Перевірка що agent_id збігається
|
||||
if event.agent_id != agent_id:
|
||||
raise HTTPException(status_code=400, detail="agent_id mismatch")
|
||||
|
||||
db_event = create_agent_memory_event(db, event)
|
||||
return AgentMemoryEventResponse.model_validate(db_event)
|
||||
|
||||
|
||||
@app.get("/agents/{agent_id}/memory", response_model=AgentMemoryEventListResponse)
|
||||
async def list_memory_events(
|
||||
agent_id: str,
|
||||
team_id: Optional[str] = Query(None),
|
||||
channel_id: Optional[str] = Query(None),
|
||||
scope: Optional[str] = Query(None, description="short_term | mid_term | long_term"),
|
||||
kind: Optional[str] = Query(None, description="message | fact | summary | note"),
|
||||
skip: int = Query(0, ge=0),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
cursor: Optional[str] = Query(None),
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Отримати список подій пам'яті агента"""
|
||||
events, next_cursor = get_agent_memory_events(
|
||||
db, agent_id, team_id, channel_id, scope, kind, skip, limit, cursor
|
||||
)
|
||||
|
||||
return AgentMemoryEventListResponse(
|
||||
items=[AgentMemoryEventResponse.model_validate(e) for e in events],
|
||||
total=len(events),
|
||||
cursor=next_cursor
|
||||
)
|
||||
|
||||
|
||||
@app.delete("/agents/{agent_id}/memory/{event_id}")
|
||||
async def delete_memory_event(
|
||||
agent_id: str,
|
||||
event_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""Видалити подію пам'яті"""
|
||||
success = delete_agent_memory_event(db, event_id)
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Memory event not found")
|
||||
return {"success": True}
|
||||
|
||||
|
||||
# ========== Monitor Events Endpoints (Batch Processing) ==========
|
||||
|
||||
from app.monitor_events import MonitorEventBatch, MonitorEventResponse, save_monitor_events_batch, save_monitor_event_single
|
||||
|
||||
@app.post("/api/memory/monitor-events/batch", response_model=MonitorEventResponse)
|
||||
async def save_monitor_events_batch_endpoint(
|
||||
batch: MonitorEventBatch,
|
||||
db: Session = Depends(get_db),
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
Зберегти батч подій Monitor Agent
|
||||
Оптимізовано для збору метрик з багатьох нод
|
||||
"""
|
||||
return await save_monitor_events_batch(batch, db, authorization)
|
||||
|
||||
@app.post("/api/memory/monitor-events/{node_id}", response_model=AgentMemoryEventResponse)
|
||||
async def save_monitor_event_endpoint(
|
||||
node_id: str,
|
||||
event: Dict[str, Any],
|
||||
db: Session = Depends(get_db),
|
||||
authorization: Optional[str] = Header(None)
|
||||
):
|
||||
"""
|
||||
Зберегти одну подію Monitor Agent
|
||||
"""
|
||||
return await save_monitor_event_single(node_id, event, db, authorization)
|
||||
|
||||
|
||||
# ========== Token Gate Integration Endpoint ==========
|
||||
|
||||
@app.post("/token-gate/check", response_model=TokenGateCheckResponse)
|
||||
async def check_token_gate_endpoint(
|
||||
check: TokenGateCheck,
|
||||
db: Session = Depends(get_db),
|
||||
user_id: str = Depends(verify_token)
|
||||
):
|
||||
"""
|
||||
Перевірка токен-гейту для факту
|
||||
|
||||
Інтеграція з RBAC/Wallet Service для перевірки доступу
|
||||
"""
|
||||
if check.user_id != user_id:
|
||||
raise HTTPException(status_code=403, detail="Cannot check token gate for other user")
|
||||
|
||||
return await check_token_gate(user_id, check.token_requirements, db)
|
||||
@app.get("/stats")
|
||||
async def get_stats():
|
||||
"""Get service statistics"""
|
||||
return {
|
||||
"vector_store": await vector_store.get_collection_stats(),
|
||||
"database": await db.get_stats()
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
|
||||
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
|
||||
@@ -1,178 +1,249 @@
|
||||
"""
|
||||
SQLAlchemy моделі для Memory Service
|
||||
Підтримує: user_facts, dialog_summaries, agent_memory_events
|
||||
DAARION Memory Service - Pydantic Models
|
||||
"""
|
||||
|
||||
from sqlalchemy import (
|
||||
Column, String, Text, JSON, TIMESTAMP,
|
||||
CheckConstraint, Index, Boolean, Integer
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
import os
|
||||
|
||||
# Перевірка типу бази даних
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./memory.db")
|
||||
IS_SQLITE = "sqlite" in DATABASE_URL.lower()
|
||||
|
||||
if IS_SQLITE:
|
||||
# Для SQLite використовуємо стандартні типи
|
||||
from sqlalchemy import JSON as JSONB_TYPE
|
||||
UUID_TYPE = String # SQLite не має UUID, використовуємо String
|
||||
else:
|
||||
# Для PostgreSQL використовуємо специфічні типи
|
||||
from sqlalchemy.dialects.postgresql import UUID, JSONB
|
||||
UUID_TYPE = UUID
|
||||
JSONB_TYPE = JSONB
|
||||
|
||||
try:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
HAS_PGVECTOR = True
|
||||
except ImportError:
|
||||
HAS_PGVECTOR = False
|
||||
# Заглушка для SQLite
|
||||
class Vector:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
Base = declarative_base()
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Any
|
||||
from uuid import UUID
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class UserFact(Base):
|
||||
"""
|
||||
Довгострокові факти про користувача
|
||||
Використовується для контрольованої довгострокової пам'яті
|
||||
(мови, вподобання, тип користувача, токен-статуси)
|
||||
"""
|
||||
__tablename__ = "user_facts"
|
||||
# ============================================================================
|
||||
# ENUMS
|
||||
# ============================================================================
|
||||
|
||||
id = Column(UUID_TYPE(as_uuid=False) if not IS_SQLITE else String, primary_key=True, server_default=func.gen_random_uuid() if not IS_SQLITE else None)
|
||||
user_id = Column(String, nullable=False, index=True) # Без FK constraint для тестування
|
||||
team_id = Column(String, nullable=True, index=True) # Без FK constraint, оскільки teams може не існувати
|
||||
|
||||
# Ключ факту (наприклад: "language", "is_donor", "is_validator", "top_contributor")
|
||||
fact_key = Column(String, nullable=False, index=True)
|
||||
|
||||
# Значення факту (може бути текст, число, boolean, JSON)
|
||||
fact_value = Column(Text, nullable=True)
|
||||
fact_value_json = Column(JSONB_TYPE, nullable=True)
|
||||
|
||||
# Метадані: джерело, впевненість, термін дії
|
||||
meta = Column(JSONB_TYPE, nullable=False, server_default="{}")
|
||||
|
||||
# Токен-гейт: чи залежить факт від токенів/активності
|
||||
token_gated = Column(Boolean, nullable=False, server_default="false")
|
||||
token_requirements = Column(JSONB_TYPE, nullable=True) # {"token": "DAAR", "min_balance": 1}
|
||||
|
||||
created_at = Column(TIMESTAMP(timezone=True) if not IS_SQLITE else TIMESTAMP, nullable=False, server_default=func.now())
|
||||
updated_at = Column(TIMESTAMP(timezone=True) if not IS_SQLITE else TIMESTAMP, nullable=True, onupdate=func.now())
|
||||
expires_at = Column(TIMESTAMP(timezone=True) if not IS_SQLITE else TIMESTAMP, nullable=True) # Для тимчасових фактів
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_user_facts_user_key", "user_id", "fact_key"),
|
||||
Index("idx_user_facts_team", "team_id"),
|
||||
Index("idx_user_facts_token_gated", "token_gated"),
|
||||
)
|
||||
class EventType(str, Enum):
|
||||
MESSAGE = "message"
|
||||
TOOL_CALL = "tool_call"
|
||||
TOOL_RESULT = "tool_result"
|
||||
DECISION = "decision"
|
||||
SUMMARY = "summary"
|
||||
MEMORY_WRITE = "memory_write"
|
||||
MEMORY_RETRACT = "memory_retract"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class DialogSummary(Base):
|
||||
"""
|
||||
Підсумки діалогів для масштабування без переповнення контексту
|
||||
Зберігає агреговану інформацію про сесії/діалоги
|
||||
"""
|
||||
__tablename__ = "dialog_summaries"
|
||||
|
||||
id = Column(UUID_TYPE(as_uuid=False) if not IS_SQLITE else String, primary_key=True, server_default=func.gen_random_uuid() if not IS_SQLITE else None)
|
||||
|
||||
# Контекст діалогу (без FK constraints для тестування)
|
||||
team_id = Column(String, nullable=False, index=True)
|
||||
channel_id = Column(String, nullable=True, index=True)
|
||||
agent_id = Column(String, nullable=True, index=True)
|
||||
user_id = Column(String, nullable=True, index=True)
|
||||
|
||||
# Період, який охоплює підсумок
|
||||
period_start = Column(TIMESTAMP(timezone=True) if not IS_SQLITE else TIMESTAMP, nullable=False)
|
||||
period_end = Column(TIMESTAMP(timezone=True) if not IS_SQLITE else TIMESTAMP, nullable=False)
|
||||
|
||||
# Підсумок
|
||||
summary_text = Column(Text, nullable=False)
|
||||
summary_json = Column(JSONB_TYPE, nullable=True) # Структуровані дані
|
||||
|
||||
# Статистика
|
||||
message_count = Column(Integer, nullable=False, server_default="0")
|
||||
participant_count = Column(Integer, nullable=False, server_default="0")
|
||||
|
||||
# Ключові теми/теги
|
||||
topics = Column(JSONB_TYPE, nullable=True) # ["project-planning", "bug-fix", ...]
|
||||
|
||||
# Метадані
|
||||
meta = Column(JSONB_TYPE, nullable=False, server_default="{}")
|
||||
|
||||
created_at = Column(TIMESTAMP(timezone=True) if not IS_SQLITE else TIMESTAMP, nullable=False, server_default=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_dialog_summaries_team_period", "team_id", "period_start", "period_end"),
|
||||
Index("idx_dialog_summaries_channel", "channel_id"),
|
||||
Index("idx_dialog_summaries_agent", "agent_id"),
|
||||
)
|
||||
class MessageRole(str, Enum):
|
||||
USER = "user"
|
||||
ASSISTANT = "assistant"
|
||||
SYSTEM = "system"
|
||||
TOOL = "tool"
|
||||
|
||||
|
||||
class AgentMemoryEvent(Base):
|
||||
"""
|
||||
Події пам'яті агентів (short-term, mid-term, long-term)
|
||||
Базується на документації: docs/cursor/13_agent_memory_system.md
|
||||
"""
|
||||
__tablename__ = "agent_memory_events"
|
||||
|
||||
id = Column(UUID_TYPE(as_uuid=False) if not IS_SQLITE else String, primary_key=True, server_default=func.gen_random_uuid() if not IS_SQLITE else None)
|
||||
|
||||
# Без FK constraints для тестування
|
||||
agent_id = Column(String, nullable=False, index=True)
|
||||
team_id = Column(String, nullable=False, index=True)
|
||||
channel_id = Column(String, nullable=True, index=True)
|
||||
user_id = Column(String, nullable=True, index=True)
|
||||
|
||||
# Scope: short_term, mid_term, long_term
|
||||
scope = Column(String, nullable=False)
|
||||
|
||||
# Kind: message, fact, summary, note
|
||||
kind = Column(String, nullable=False)
|
||||
|
||||
# Тіло події
|
||||
body_text = Column(Text, nullable=True)
|
||||
body_json = Column(JSONB_TYPE, nullable=True)
|
||||
|
||||
created_at = Column(TIMESTAMP(timezone=True) if not IS_SQLITE else TIMESTAMP, nullable=False, server_default=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
CheckConstraint("scope IN ('short_term', 'mid_term', 'long_term')", name="ck_agent_memory_scope"),
|
||||
CheckConstraint("kind IN ('message', 'fact', 'summary', 'note')", name="ck_agent_memory_kind"),
|
||||
Index("idx_agent_memory_events_agent_team_scope", "agent_id", "team_id", "scope"),
|
||||
Index("idx_agent_memory_events_channel", "agent_id", "channel_id"),
|
||||
Index("idx_agent_memory_events_created_at", "created_at"),
|
||||
)
|
||||
class MemoryCategory(str, Enum):
|
||||
PREFERENCE = "preference"
|
||||
IDENTITY = "identity"
|
||||
CONSTRAINT = "constraint"
|
||||
PROJECT_FACT = "project_fact"
|
||||
RELATIONSHIP = "relationship"
|
||||
SKILL = "skill"
|
||||
GOAL = "goal"
|
||||
CONTEXT = "context"
|
||||
FEEDBACK = "feedback"
|
||||
|
||||
|
||||
class AgentMemoryFactsVector(Base):
|
||||
"""
|
||||
Векторні представлення фактів для RAG (Retrieval-Augmented Generation)
|
||||
"""
|
||||
__tablename__ = "agent_memory_facts_vector"
|
||||
class RetentionPolicy(str, Enum):
|
||||
PERMANENT = "permanent"
|
||||
SESSION = "session"
|
||||
TTL_DAYS = "ttl_days"
|
||||
UNTIL_REVOKED = "until_revoked"
|
||||
|
||||
id = Column(UUID_TYPE(as_uuid=False) if not IS_SQLITE else String, primary_key=True, server_default=func.gen_random_uuid() if not IS_SQLITE else None)
|
||||
|
||||
# Без FK constraints для тестування
|
||||
agent_id = Column(String, nullable=False, index=True)
|
||||
team_id = Column(String, nullable=False, index=True)
|
||||
|
||||
fact_text = Column(Text, nullable=False)
|
||||
embedding = Column(Vector(1536), nullable=True) if HAS_PGVECTOR else Column(Text, nullable=True) # OpenAI ada-002 embedding size
|
||||
|
||||
meta = Column(JSONB_TYPE, nullable=False, server_default="{}")
|
||||
|
||||
created_at = Column(TIMESTAMP(timezone=True) if not IS_SQLITE else TIMESTAMP, nullable=False, server_default=func.now())
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_agent_memory_facts_vector_agent_team", "agent_id", "team_id"),
|
||||
)
|
||||
class FeedbackAction(str, Enum):
|
||||
CONFIRM = "confirm"
|
||||
REJECT = "reject"
|
||||
EDIT = "edit"
|
||||
DELETE = "delete"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# REQUEST MODELS
|
||||
# ============================================================================
|
||||
|
||||
class CreateThreadRequest(BaseModel):
|
||||
"""Create new conversation thread"""
|
||||
org_id: UUID
|
||||
workspace_id: Optional[UUID] = None
|
||||
user_id: UUID
|
||||
agent_id: Optional[UUID] = None
|
||||
title: Optional[str] = None
|
||||
tags: List[str] = []
|
||||
metadata: dict = {}
|
||||
|
||||
|
||||
class AddEventRequest(BaseModel):
    """Request body: append an event (message / tool call / etc.) to a thread.

    ``role``/``content`` apply to message events; the ``tool_*`` fields apply
    to tool-call events — presumably mutually exclusive, TODO confirm against
    the route handler.
    """
    thread_id: UUID
    event_type: EventType
    role: Optional[MessageRole] = None
    content: Optional[str] = None
    tool_name: Optional[str] = None
    tool_input: Optional[dict] = None
    tool_output: Optional[dict] = None
    payload: dict = {}  # free-form extra data (Pydantic deep-copies the default)
    token_count: Optional[int] = None   # observability fields, optional
    model_used: Optional[str] = None
    latency_ms: Optional[int] = None
    metadata: dict = {}
|
||||
|
||||
|
||||
class CreateMemoryRequest(BaseModel):
    """Request body: create a long-term memory item."""
    org_id: UUID
    workspace_id: Optional[UUID] = None
    user_id: UUID
    agent_id: Optional[UUID] = None  # null = global
    category: MemoryCategory
    fact_text: str
    confidence: float = Field(default=0.8, ge=0, le=1)  # clamped to [0, 1]
    source_event_id: Optional[UUID] = None   # provenance: event the fact came from
    source_thread_id: Optional[UUID] = None  # provenance: originating thread
    extraction_method: str = "explicit"
    is_sensitive: bool = False
    retention: RetentionPolicy = RetentionPolicy.UNTIL_REVOKED
    ttl_days: Optional[int] = None  # only meaningful with retention=TTL_DAYS — TODO confirm it is validated
    tags: List[str] = []
    metadata: dict = {}
|
||||
|
||||
|
||||
class MemoryFeedbackRequest(BaseModel):
    """Request body: user feedback on a stored memory."""
    memory_id: UUID
    user_id: UUID
    action: FeedbackAction
    new_value: Optional[str] = None  # presumably required when action=EDIT — verify in handler
    reason: Optional[str] = None
|
||||
|
||||
|
||||
class RetrievalRequest(BaseModel):
    """Request body: semantic retrieval of memories for one or more queries."""
    org_id: UUID
    user_id: UUID
    agent_id: Optional[UUID] = None
    workspace_id: Optional[UUID] = None
    queries: List[str]              # one embedding lookup per query string
    top_k: int = 10                 # results per query
    min_confidence: float = 0.5     # drop memories below this confidence
    include_global: bool = True     # also return memories with no agent_id
    categories: Optional[List[MemoryCategory]] = None  # None = all categories
|
||||
|
||||
|
||||
class SummaryRequest(BaseModel):
    """Request body: generate a summary for a thread."""
    thread_id: UUID
    force: bool = False  # force even if under token threshold
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# RESPONSE MODELS
|
||||
# ============================================================================
|
||||
|
||||
class ThreadResponse(BaseModel):
    """API representation of a conversation thread."""
    thread_id: UUID
    org_id: UUID
    workspace_id: Optional[UUID]
    user_id: UUID
    agent_id: Optional[UUID]
    title: Optional[str]
    status: str
    message_count: int
    total_tokens: int
    created_at: datetime
    last_activity_at: datetime
|
||||
|
||||
|
||||
class EventResponse(BaseModel):
    """API representation of a single conversation event."""
    event_id: UUID
    thread_id: UUID
    event_type: EventType
    role: Optional[MessageRole]
    content: Optional[str]
    tool_name: Optional[str]
    tool_input: Optional[dict]
    tool_output: Optional[dict]
    payload: dict
    token_count: Optional[int]
    sequence_num: int  # ordering of the event within its thread
    created_at: datetime
|
||||
|
||||
|
||||
class MemoryResponse(BaseModel):
    """API representation of a long-term memory item."""
    memory_id: UUID
    org_id: UUID
    workspace_id: Optional[UUID]
    user_id: UUID
    agent_id: Optional[UUID]  # None = global memory
    category: MemoryCategory
    fact_text: str
    confidence: float
    is_verified: bool
    is_sensitive: bool
    retention: RetentionPolicy
    valid_from: datetime
    valid_to: Optional[datetime]      # None = still valid
    last_used_at: Optional[datetime]  # None = never retrieved yet
    use_count: int
    created_at: datetime
|
||||
|
||||
|
||||
class SummaryResponse(BaseModel):
    """API representation of a thread summary."""
    summary_id: UUID
    thread_id: UUID
    version: int  # summaries are versioned per thread
    summary_text: str
    state: dict
    events_count: int  # how many events the summary covers
    compression_ratio: Optional[float]
    created_at: datetime
|
||||
|
||||
|
||||
class RetrievalResult(BaseModel):
    """Single retrieval result."""
    memory_id: UUID
    fact_text: str
    category: MemoryCategory
    confidence: float       # stored confidence of the memory itself
    relevance_score: float  # vector-similarity score for this query
    agent_id: Optional[UUID]
    is_global: bool         # True when the memory is not bound to an agent
|
||||
|
||||
|
||||
class RetrievalResponse(BaseModel):
    """Retrieval response with results."""
    results: List[RetrievalResult]
    query_count: int    # number of input queries processed
    total_results: int  # size of `results` across all queries
|
||||
|
||||
|
||||
class ContextResponse(BaseModel):
    """Full assembled context for an agent prompt: summary + recent turns + memories."""
    thread_id: UUID
    summary: Optional[SummaryResponse]  # None when no summary exists yet
    recent_messages: List[EventResponse]
    retrieved_memories: List[RetrievalResult]
    token_estimate: int  # rough token budget consumed by this context
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# INTERNAL MODELS
|
||||
# ============================================================================
|
||||
|
||||
class EmbeddingRequest(BaseModel):
    """Internal embedding request (not exposed on the public API)."""
    texts: List[str]
    input_type: str = "search_document"  # or "search_query"
|
||||
|
||||
|
||||
class QdrantPayload(BaseModel):
    """Payload stored alongside each Qdrant point (all IDs serialized as strings)."""
    org_id: str
    workspace_id: Optional[str]
    user_id: str
    agent_id: Optional[str]   # None = global memory
    thread_id: Optional[str]
    memory_id: Optional[str]
    event_id: Optional[str]
    type: str  # "memory", "summary", "message"
    category: Optional[str]
    text: str
    created_at: str  # ISO-formatted timestamp — TODO confirm at write site
|
||||
|
||||
325
services/memory-service/app/vector_store.py
Normal file
325
services/memory-service/app/vector_store.py
Normal file
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
DAARION Memory Service - Qdrant Vector Store
|
||||
"""
|
||||
from typing import List, Optional, Dict, Any
|
||||
from uuid import UUID, uuid4
|
||||
import structlog
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models as qmodels
|
||||
|
||||
from .config import get_settings
|
||||
from .embedding import get_query_embedding, get_document_embeddings
|
||||
from .models import MemoryCategory, QdrantPayload
|
||||
|
||||
# Module-level structlog logger and cached service settings shared by VectorStore.
logger = structlog.get_logger()
settings = get_settings()
|
||||
|
||||
|
||||
class VectorStore:
    """Qdrant-backed vector store for semantic search over memories and summaries.

    Wraps a synchronous :class:`QdrantClient`; collection names and embedding
    dimensionality come from the service settings. Call :meth:`initialize`
    once at startup before indexing or searching.
    """

    def __init__(self):
        self.client = QdrantClient(
            host=settings.qdrant_host,
            port=settings.qdrant_port
        )
        self.memories_collection = settings.qdrant_collection_memories
        self.messages_collection = settings.qdrant_collection_messages

    async def initialize(self):
        """Create the memories/messages collections if they don't exist."""
        await self._ensure_collection(
            self.memories_collection,
            settings.embedding_dimensions
        )
        await self._ensure_collection(
            self.messages_collection,
            settings.embedding_dimensions
        )
        logger.info("vector_store_initialized")

    async def _ensure_collection(self, name: str, dimensions: int):
        """Create collection ``name`` (cosine distance) and its payload indexes if missing."""
        collections = self.client.get_collections().collections
        if any(c.name == name for c in collections):
            return

        self.client.create_collection(
            collection_name=name,
            vectors_config=qmodels.VectorParams(
                size=dimensions,
                distance=qmodels.Distance.COSINE
            )
        )

        # Keyword payload indexes for every field we filter on during search
        # (deduplicated from five copy-pasted create_payload_index calls).
        for field_name in ("org_id", "user_id", "agent_id", "type", "category"):
            self.client.create_payload_index(
                collection_name=name,
                field_name=field_name,
                field_schema=qmodels.PayloadSchemaType.KEYWORD
            )

        logger.info("collection_created", name=name, dimensions=dimensions)

    async def index_memory(
        self,
        memory_id: UUID,
        text: str,
        org_id: UUID,
        user_id: UUID,
        category: MemoryCategory,
        agent_id: Optional[UUID] = None,
        workspace_id: Optional[UUID] = None,
        thread_id: Optional[UUID] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Index a memory item in Qdrant.

        Args:
            memory_id: Database identifier of the memory (stored in the payload
                so the point can later be deleted by memory_id).
            text: Fact text to embed and store.
            org_id / user_id: Tenancy identifiers, serialized to strings.
            category: Memory category; its ``.value`` goes into the payload.
            agent_id / workspace_id / thread_id: Optional scoping identifiers.
            metadata: Extra payload keys merged into the point payload.
                (Bug fix: was a mutable ``{}`` default argument.)

        Returns:
            The newly generated Qdrant point ID.

        Raises:
            ValueError: If the embedding backend returns no vectors.
        """
        extra = metadata if metadata is not None else {}

        embeddings = await get_document_embeddings([text])
        if not embeddings:
            raise ValueError("Failed to generate embedding")

        vector = embeddings[0]
        point_id = str(uuid4())

        # Build payload; metadata keys are merged in last and may override.
        payload = {
            "org_id": str(org_id),
            "user_id": str(user_id),
            "memory_id": str(memory_id),
            "type": "memory",
            "category": category.value,
            "text": text,
            **extra
        }

        if agent_id:
            payload["agent_id"] = str(agent_id)
        if workspace_id:
            payload["workspace_id"] = str(workspace_id)
        if thread_id:
            payload["thread_id"] = str(thread_id)

        # Upsert point
        self.client.upsert(
            collection_name=self.memories_collection,
            points=[
                qmodels.PointStruct(
                    id=point_id,
                    vector=vector,
                    payload=payload
                )
            ]
        )

        logger.info(
            "memory_indexed",
            memory_id=str(memory_id),
            point_id=point_id,
            category=category.value
        )

        return point_id

    async def index_summary(
        self,
        summary_id: UUID,
        text: str,
        thread_id: UUID,
        org_id: UUID,
        user_id: UUID,
        agent_id: Optional[UUID] = None,
        workspace_id: Optional[UUID] = None
    ) -> str:
        """Index a thread summary (type="summary") in the memories collection.

        Returns the new Qdrant point ID; raises ValueError if no embedding
        could be generated.
        """
        embeddings = await get_document_embeddings([text])
        if not embeddings:
            raise ValueError("Failed to generate embedding")

        vector = embeddings[0]
        point_id = str(uuid4())

        payload = {
            "org_id": str(org_id),
            "user_id": str(user_id),
            "thread_id": str(thread_id),
            "summary_id": str(summary_id),
            "type": "summary",
            "text": text
        }

        if agent_id:
            payload["agent_id"] = str(agent_id)
        if workspace_id:
            payload["workspace_id"] = str(workspace_id)

        self.client.upsert(
            collection_name=self.memories_collection,
            points=[
                qmodels.PointStruct(
                    id=point_id,
                    vector=vector,
                    payload=payload
                )
            ]
        )

        logger.info("summary_indexed", summary_id=str(summary_id), point_id=point_id)
        return point_id

    async def search_memories(
        self,
        query: str,
        org_id: UUID,
        user_id: UUID,
        agent_id: Optional[UUID] = None,
        workspace_id: Optional[UUID] = None,
        categories: Optional[List[MemoryCategory]] = None,
        include_global: bool = True,
        top_k: int = 10
    ) -> List[Dict[str, Any]]:
        """
        Semantic search for memories.

        Args:
            query: Search query
            org_id: Organization filter
            user_id: User filter
            agent_id: Agent filter (None = no agent restriction)
            workspace_id: Workspace filter
            categories: Filter by memory categories
            include_global: Include memories without agent_id
            top_k: Number of results

        Returns:
            List of payload dicts, each augmented with "point_id" and "score".
        """
        # Get query embedding; bail out with no results if the backend fails.
        query_vector = await get_query_embedding(query)
        if not query_vector:
            return []

        # Mandatory tenancy filters.
        must_conditions = [
            qmodels.FieldCondition(
                key="org_id",
                match=qmodels.MatchValue(value=str(org_id))
            ),
            qmodels.FieldCondition(
                key="user_id",
                match=qmodels.MatchValue(value=str(user_id))
            )
        ]

        if workspace_id:
            must_conditions.append(
                qmodels.FieldCondition(
                    key="workspace_id",
                    match=qmodels.MatchValue(value=str(workspace_id))
                )
            )

        if categories:
            must_conditions.append(
                qmodels.FieldCondition(
                    key="category",
                    match=qmodels.MatchAny(any=[c.value for c in categories])
                )
            )

        # Agent filter with global option.
        if agent_id:
            if include_global:
                # BUG FIX: this branch previously did nothing (`pass`), so
                # agent-specific searches returned every agent's memories.
                # Match "this agent OR no agent_id set" via a nested should.
                must_conditions.append(
                    qmodels.Filter(
                        should=[
                            qmodels.FieldCondition(
                                key="agent_id",
                                match=qmodels.MatchValue(value=str(agent_id))
                            ),
                            qmodels.IsEmptyCondition(
                                is_empty=qmodels.PayloadField(key="agent_id")
                            )
                        ]
                    )
                )
            else:
                must_conditions.append(
                    qmodels.FieldCondition(
                        key="agent_id",
                        match=qmodels.MatchValue(value=str(agent_id))
                    )
                )

        search_filter = qmodels.Filter(must=must_conditions)

        # Search
        results = self.client.search(
            collection_name=self.memories_collection,
            query_vector=query_vector,
            query_filter=search_filter,
            limit=top_k,
            with_payload=True
        )

        logger.info(
            "memory_search",
            query_preview=query[:50],
            results_count=len(results)
        )

        return [
            {
                "point_id": str(r.id),
                "score": r.score,
                **r.payload
            }
            for r in results
        ]

    async def delete_memory(self, memory_id: UUID):
        """Delete all points carrying this memory_id from the memories collection."""
        self.client.delete(
            collection_name=self.memories_collection,
            points_selector=qmodels.FilterSelector(
                filter=qmodels.Filter(
                    must=[
                        qmodels.FieldCondition(
                            key="memory_id",
                            match=qmodels.MatchValue(value=str(memory_id))
                        )
                    ]
                )
            )
        )
        logger.info("memory_deleted_from_index", memory_id=str(memory_id))

    async def get_collection_stats(self) -> Dict[str, Any]:
        """Return point/vector counts for the memories collection."""
        memories_info = self.client.get_collection(self.memories_collection)

        return {
            "memories": {
                "points_count": memories_info.points_count,
                "vectors_count": memories_info.vectors_count,
                "indexed_vectors_count": memories_info.indexed_vectors_count
            }
        }
|
||||
|
||||
|
||||
# Global instance — shared module-level singleton; callers must await
# vector_store.initialize() at startup before indexing or searching.
vector_store = VectorStore()
|
||||
@@ -1,6 +1,32 @@
|
||||
fastapi>=0.104.0
|
||||
uvicorn[standard]>=0.24.0
|
||||
sqlalchemy>=2.0.0
|
||||
psycopg2-binary>=2.9.0
|
||||
pydantic>=2.0.0
|
||||
python-dotenv>=1.0.0
|
||||
# DAARION Memory Service
|
||||
# Agent memory management with PostgreSQL + Qdrant + Cohere
|
||||
|
||||
# Web framework
|
||||
fastapi==0.109.0
|
||||
uvicorn[standard]==0.27.0
|
||||
pydantic==2.5.3
|
||||
pydantic-settings==2.1.0
|
||||
|
||||
# Database
|
||||
asyncpg==0.29.0
|
||||
sqlalchemy[asyncio]==2.0.25
|
||||
alembic==1.13.1
|
||||
|
||||
# Vector database
|
||||
qdrant-client==1.7.3
|
||||
|
||||
# Embeddings
|
||||
cohere==4.44
|
||||
|
||||
# Utilities
|
||||
python-dotenv==1.0.0
|
||||
httpx==0.26.0
|
||||
tenacity==8.2.3
|
||||
structlog==24.1.0
|
||||
|
||||
# Token counting
|
||||
tiktoken==0.5.2
|
||||
|
||||
# Testing
|
||||
pytest==7.4.4
|
||||
pytest-asyncio==0.23.3
|
||||
|
||||
Reference in New Issue
Block a user