Features:
- Three-tier memory architecture (short/mid/long-term)
- PostgreSQL schema for conversations, events, memories
- Qdrant vector database for semantic search
- Cohere embeddings (embed-multilingual-v3.0, 1024 dims)
- FastAPI Memory Service with full CRUD
- External Secrets integration with Vault
- Kubernetes deployment manifests

Components:
- infrastructure/database/agent-memory-schema.sql
- infrastructure/kubernetes/apps/qdrant/
- infrastructure/kubernetes/apps/memory-service/
- services/memory-service/ (FastAPI app)

Also includes:
- External Secrets Operator
- Traefik Ingress Controller
- Cert-Manager with Let's Encrypt
- ArgoCD for GitOps
431 lines
16 KiB
Python
431 lines
16 KiB
Python
"""
DAARION Memory Service - PostgreSQL Database Layer
"""
|
|
from typing import List, Optional, Dict, Any
|
|
from uuid import UUID, uuid4
|
|
from datetime import datetime
|
|
import structlog
|
|
import asyncpg
|
|
|
|
from .config import get_settings
|
|
from .models import EventType, MessageRole, MemoryCategory, RetentionPolicy, FeedbackAction
|
|
|
|
# Module-wide structured logger; call sites emit a snake_case event name
# plus keyword context (e.g. thread_id=...).
logger = structlog.get_logger()
# Service settings are resolved once, when this module is imported.
settings = get_settings()
|
|
|
|
|
|
class Database:
    """PostgreSQL database operations for the memory service.

    Wraps an ``asyncpg`` connection pool and exposes coroutine helpers for
    conversation threads, conversation events, long-term memory items,
    memory feedback, thread summaries and simple table counts.  Helpers that
    return rows convert the ``asyncpg`` records into plain dictionaries.

    ``connect()`` must be awaited before any query helper is used.

    NOTE(review): dict arguments (``metadata``, ``payload``, ``state``,
    ``tool_input``, ``tool_output``) are handed to asyncpg unchanged.  If the
    target columns are json/jsonb, asyncpg expects ``str`` unless a type
    codec is registered on the connections -- confirm against the schema and
    the pool setup.
    """

    def __init__(self):
        # Created by connect(); None means "not connected".
        self.pool: Optional[asyncpg.Pool] = None

    def _acquire(self):
        """Return the pool's acquire context, failing fast when not connected.

        Raises:
            RuntimeError: If no pool exists (``connect()`` was never awaited
                or ``disconnect()`` already closed it).
        """
        if self.pool is None:
            raise RuntimeError(
                "Database pool is not initialized; call connect() first"
            )
        return self.pool.acquire()

    async def connect(self):
        """Create the connection pool from the service settings.

        Awaiting this again while a pool already exists is a no-op, so a
        repeated call cannot overwrite (and thereby leak) the first pool.
        """
        if self.pool is not None:
            return

        self.pool = await asyncpg.create_pool(
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            database=settings.postgres_db,
            min_size=5,
            max_size=20,
        )
        logger.info("database_connected")

    async def disconnect(self):
        """Close the connection pool, if one is open."""
        if self.pool is None:
            return

        await self.pool.close()
        # Drop the reference so later calls fail clearly instead of using a
        # closed pool, and so connect() can build a fresh one.
        self.pool = None
        logger.info("database_disconnected")

    # ========================================================================
    # THREADS
    # ========================================================================

    async def create_thread(
        self,
        org_id: UUID,
        user_id: UUID,
        workspace_id: Optional[UUID] = None,
        agent_id: Optional[UUID] = None,
        title: Optional[str] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[dict] = None,
    ) -> Dict[str, Any]:
        """Create a new conversation thread.

        Args:
            org_id: Owning organisation.
            user_id: Owning user.
            workspace_id: Optional workspace the thread belongs to.
            agent_id: Optional agent the thread belongs to.
            title: Optional thread title.
            tags: Tags to store; ``None`` is stored as an empty list.
            metadata: Metadata to store; ``None`` is stored as an empty dict.

        Returns:
            The inserted ``conversation_threads`` row as a dictionary.
        """
        # None defaults avoid sharing one mutable list/dict across calls.
        tags = [] if tags is None else tags
        metadata = {} if metadata is None else metadata
        thread_id = uuid4()

        async with self._acquire() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO conversation_threads
                    (thread_id, org_id, workspace_id, user_id, agent_id, title, tags, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                RETURNING *
                """,
                thread_id, org_id, workspace_id, user_id, agent_id, title,
                tags, metadata,
            )

        logger.info("thread_created", thread_id=str(thread_id))
        return dict(row)

    async def get_thread(self, thread_id: UUID) -> Optional[Dict[str, Any]]:
        """Get a thread by ID.

        Returns:
            The thread row as a dictionary, or ``None`` when no row matches.
        """
        async with self._acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT * FROM conversation_threads WHERE thread_id = $1
                """,
                thread_id,
            )
        return dict(row) if row else None

    async def list_threads(
        self,
        org_id: UUID,
        user_id: UUID,
        workspace_id: Optional[UUID] = None,
        agent_id: Optional[UUID] = None,
        limit: int = 20,
    ) -> List[Dict[str, Any]]:
        """List a user's active threads, most recently active first.

        Args:
            org_id: Organisation to filter on.
            user_id: User to filter on.
            workspace_id: When given, only threads of this workspace.
            agent_id: When given, only threads of this agent.
            limit: Maximum number of rows to return.

        Returns:
            Thread rows with ``status = 'active'`` as dictionaries, ordered
            by ``last_activity_at`` descending.
        """
        query = """
            SELECT * FROM conversation_threads
            WHERE org_id = $1 AND user_id = $2 AND status = 'active'
        """
        params: List[Any] = [org_id, user_id]

        # Placeholders are numbered from the current parameter count so the
        # optional filters can be appended in any combination.
        if workspace_id is not None:
            query += f" AND workspace_id = ${len(params) + 1}"
            params.append(workspace_id)
        if agent_id is not None:
            query += f" AND agent_id = ${len(params) + 1}"
            params.append(agent_id)

        query += f" ORDER BY last_activity_at DESC LIMIT ${len(params) + 1}"
        params.append(limit)

        async with self._acquire() as conn:
            rows = await conn.fetch(query, *params)

        return [dict(row) for row in rows]

    # ========================================================================
    # EVENTS
    # ========================================================================

    async def add_event(
        self,
        thread_id: UUID,
        event_type: EventType,
        role: Optional[MessageRole] = None,
        content: Optional[str] = None,
        tool_name: Optional[str] = None,
        tool_input: Optional[dict] = None,
        tool_output: Optional[dict] = None,
        payload: Optional[dict] = None,
        token_count: Optional[int] = None,
        model_used: Optional[str] = None,
        latency_ms: Optional[int] = None,
        metadata: Optional[dict] = None,
    ) -> Dict[str, Any]:
        """Add an event to a conversation thread.

        Args:
            thread_id: Thread the event belongs to.
            event_type: Event type; its ``.value`` is stored.
            role: Optional message role; its ``.value`` is stored when given.
            content: Optional text content.
            tool_name: Optional tool name.
            tool_input: Optional tool input.
            tool_output: Optional tool output.
            payload: Event payload; ``None`` is stored as an empty dict.
            token_count: Optional token count.
            model_used: Optional model identifier.
            latency_ms: Optional latency in milliseconds.
            metadata: Event metadata; ``None`` is stored as an empty dict.

        Returns:
            The inserted ``conversation_events`` row as a dictionary.
        """
        # None defaults avoid sharing one mutable dict across calls.
        payload = {} if payload is None else payload
        metadata = {} if metadata is None else metadata
        event_id = uuid4()

        async with self._acquire() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO conversation_events
                    (event_id, thread_id, event_type, role, content, tool_name,
                     tool_input, tool_output, payload, token_count, model_used,
                     latency_ms, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
                RETURNING *
                """,
                event_id, thread_id, event_type.value,
                role.value if role else None, content, tool_name,
                tool_input, tool_output, payload, token_count, model_used,
                latency_ms, metadata,
            )

        logger.info("event_added", event_id=str(event_id), type=event_type.value)
        return dict(row)

    async def get_events(
        self,
        thread_id: UUID,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """Get a page of events for a thread, newest first.

        Args:
            thread_id: Thread whose events are read.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip.

        Returns:
            Event rows as dictionaries, ordered by ``sequence_num``
            descending.
        """
        async with self._acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT * FROM conversation_events
                WHERE thread_id = $1
                ORDER BY sequence_num DESC
                LIMIT $2 OFFSET $3
                """,
                thread_id, limit, offset,
            )

        return [dict(row) for row in rows]

    async def get_events_for_summary(
        self,
        thread_id: UUID,
        after_seq: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Get events for summarization, oldest first.

        Args:
            thread_id: Thread whose events are read.
            after_seq: When given (including ``0``), only events with
                ``sequence_num`` strictly greater than this value.

        Returns:
            Event rows as dictionaries, ordered by ``sequence_num``
            ascending.
        """
        query = "SELECT * FROM conversation_events WHERE thread_id = $1"
        params: List[Any] = [thread_id]

        # Compare against None so that after_seq=0 is still applied as a
        # lower bound instead of being treated as "no bound".
        if after_seq is not None:
            query += " AND sequence_num > $2"
            params.append(after_seq)

        query += " ORDER BY sequence_num ASC"

        async with self._acquire() as conn:
            rows = await conn.fetch(query, *params)

        return [dict(row) for row in rows]

    # ========================================================================
    # MEMORIES
    # ========================================================================

    async def create_memory(
        self,
        org_id: UUID,
        user_id: UUID,
        category: MemoryCategory,
        fact_text: str,
        workspace_id: Optional[UUID] = None,
        agent_id: Optional[UUID] = None,
        confidence: float = 0.8,
        source_event_id: Optional[UUID] = None,
        source_thread_id: Optional[UUID] = None,
        extraction_method: str = "explicit",
        is_sensitive: bool = False,
        retention: RetentionPolicy = RetentionPolicy.UNTIL_REVOKED,
        ttl_days: Optional[int] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[dict] = None,
    ) -> Dict[str, Any]:
        """Create a long-term memory item.

        Args:
            org_id: Owning organisation.
            user_id: Owning user.
            category: Memory category; its ``.value`` is stored.
            fact_text: The remembered fact.
            workspace_id: Optional workspace scope.
            agent_id: Optional agent scope.
            confidence: Confidence score stored with the item.
            source_event_id: Optional event the fact was taken from.
            source_thread_id: Optional thread the fact was taken from.
            extraction_method: How the fact was obtained.
            is_sensitive: Whether the item is flagged as sensitive.
            retention: Retention policy; its ``.value`` is stored.
            ttl_days: Optional time-to-live in days.
            tags: Tags to store; ``None`` is stored as an empty list.
            metadata: Metadata to store; ``None`` is stored as an empty dict.

        Returns:
            The inserted ``long_term_memory_items`` row as a dictionary.
        """
        # None defaults avoid sharing one mutable list/dict across calls.
        tags = [] if tags is None else tags
        metadata = {} if metadata is None else metadata
        memory_id = uuid4()

        async with self._acquire() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO long_term_memory_items
                    (memory_id, org_id, workspace_id, user_id, agent_id, category,
                     fact_text, confidence, source_event_id, source_thread_id,
                     extraction_method, is_sensitive, retention, ttl_days, tags, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
                RETURNING *
                """,
                memory_id, org_id, workspace_id, user_id, agent_id, category.value,
                fact_text, confidence, source_event_id, source_thread_id,
                extraction_method, is_sensitive, retention.value, ttl_days,
                tags, metadata,
            )

        logger.info("memory_created", memory_id=str(memory_id), category=category.value)
        return dict(row)

    async def get_memory(self, memory_id: UUID) -> Optional[Dict[str, Any]]:
        """Get a memory by ID.

        Only rows whose ``valid_to`` is NULL are considered, so invalidated
        memories are not returned.

        Returns:
            The memory row as a dictionary, or ``None`` when no row matches.
        """
        async with self._acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT * FROM long_term_memory_items
                WHERE memory_id = $1 AND valid_to IS NULL
                """,
                memory_id,
            )
        return dict(row) if row else None

    async def list_memories(
        self,
        org_id: UUID,
        user_id: UUID,
        agent_id: Optional[UUID] = None,
        workspace_id: Optional[UUID] = None,
        category: Optional[MemoryCategory] = None,
        include_global: bool = True,
        limit: int = 50,
    ) -> List[Dict[str, Any]]:
        """List a user's valid memories at or above the minimum confidence.

        Args:
            org_id: Organisation to filter on.
            user_id: User to filter on.
            agent_id: When given, restrict to this agent's memories.
            workspace_id: When given, restrict to this workspace's memories
                plus memories with no workspace.
            category: When given, restrict to this category.
            include_global: With ``agent_id`` set, also include memories
                that have no agent.  Ignored when ``agent_id`` is not given.
            limit: Maximum number of rows to return.

        Returns:
            Memory rows (``valid_to IS NULL`` and ``confidence`` not below
            ``settings.memory_min_confidence``) as dictionaries, ordered by
            confidence and then ``last_used_at``, both descending, with
            never-used rows last.
        """
        query = """
            SELECT * FROM long_term_memory_items
            WHERE org_id = $1 AND user_id = $2 AND valid_to IS NULL
            AND confidence >= $3
        """
        params: List[Any] = [org_id, user_id, settings.memory_min_confidence]

        if workspace_id is not None:
            query += f" AND (workspace_id = ${len(params) + 1} OR workspace_id IS NULL)"
            params.append(workspace_id)

        if agent_id is not None:
            if include_global:
                query += f" AND (agent_id = ${len(params) + 1} OR agent_id IS NULL)"
            else:
                query += f" AND agent_id = ${len(params) + 1}"
            params.append(agent_id)

        if category is not None:
            query += f" AND category = ${len(params) + 1}"
            params.append(category.value)

        query += f" ORDER BY confidence DESC, last_used_at DESC NULLS LAST LIMIT ${len(params) + 1}"
        params.append(limit)

        async with self._acquire() as conn:
            rows = await conn.fetch(query, *params)

        return [dict(row) for row in rows]

    async def update_memory_embedding_id(self, memory_id: UUID, embedding_id: str):
        """Store the Qdrant point ID on a memory (``fact_embedding_id``)."""
        async with self._acquire() as conn:
            await conn.execute(
                """
                UPDATE long_term_memory_items
                SET fact_embedding_id = $2
                WHERE memory_id = $1
                """,
                memory_id, embedding_id,
            )

    async def update_memory_confidence(
        self,
        memory_id: UUID,
        confidence: float,
        verified: bool = False,
    ):
        """Update a memory's confidence and, optionally, mark it verified.

        Args:
            memory_id: Memory to update.
            confidence: New confidence value.
            verified: When true, additionally set ``is_verified``, increment
                ``verification_count`` and set ``last_confirmed_at`` to the
                current time.  When false those columns are left unchanged.
        """
        async with self._acquire() as conn:
            await conn.execute(
                """
                UPDATE long_term_memory_items
                SET confidence = $2,
                    is_verified = CASE WHEN $3 THEN true ELSE is_verified END,
                    verification_count = verification_count + CASE WHEN $3 THEN 1 ELSE 0 END,
                    last_confirmed_at = CASE WHEN $3 THEN NOW() ELSE last_confirmed_at END
                WHERE memory_id = $1
                """,
                memory_id, confidence, verified,
            )

    async def update_memory_text(self, memory_id: UUID, new_text: str):
        """Replace a memory's ``fact_text``."""
        async with self._acquire() as conn:
            await conn.execute(
                """
                UPDATE long_term_memory_items
                SET fact_text = $2
                WHERE memory_id = $1
                """,
                memory_id, new_text,
            )

    async def invalidate_memory(self, memory_id: UUID):
        """Mark a memory as invalid (soft delete) by setting ``valid_to``."""
        async with self._acquire() as conn:
            await conn.execute(
                """
                UPDATE long_term_memory_items
                SET valid_to = NOW()
                WHERE memory_id = $1
                """,
                memory_id,
            )
        logger.info("memory_invalidated", memory_id=str(memory_id))

    async def increment_memory_usage(self, memory_id: UUID):
        """Increment a memory's ``use_count`` and refresh ``last_used_at``."""
        async with self._acquire() as conn:
            await conn.execute(
                """
                UPDATE long_term_memory_items
                SET use_count = use_count + 1, last_used_at = NOW()
                WHERE memory_id = $1
                """,
                memory_id,
            )

    # ========================================================================
    # FEEDBACK
    # ========================================================================

    async def add_memory_feedback(
        self,
        memory_id: UUID,
        user_id: UUID,
        action: FeedbackAction,
        old_value: Optional[str] = None,
        new_value: Optional[str] = None,
        reason: Optional[str] = None,
    ):
        """Record user feedback on a memory.

        Args:
            memory_id: Memory the feedback refers to.
            user_id: User giving the feedback.
            action: Feedback action; its ``.value`` is stored.
            old_value: Optional previous value.
            new_value: Optional replacement value.
            reason: Optional free-text reason.
        """
        feedback_id = uuid4()

        async with self._acquire() as conn:
            await conn.execute(
                """
                INSERT INTO memory_feedback
                    (feedback_id, memory_id, user_id, action, old_value, new_value, reason)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                """,
                feedback_id, memory_id, user_id, action.value,
                old_value, new_value, reason,
            )

        logger.info("feedback_recorded", memory_id=str(memory_id), action=action.value)

    # ========================================================================
    # SUMMARIES
    # ========================================================================

    async def create_summary(
        self,
        thread_id: UUID,
        summary_text: str,
        state: dict,
        events_from_seq: int,
        events_to_seq: int,
        events_count: int,
        original_tokens: Optional[int] = None,
        summary_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Create the next summary version for a thread.

        The version is one more than the highest existing version for the
        thread (``1`` when there is none).  ``compression_ratio`` is
        ``summary_tokens / original_tokens`` when both are non-zero, else
        NULL.

        NOTE(review): the version lookup and the insert are separate
        statements, so two concurrent calls for the same thread could compute
        the same version; presumably a unique constraint on
        (thread_id, version) guards this -- confirm against the schema.

        Args:
            thread_id: Thread being summarised.
            summary_text: Summary text.
            state: State stored alongside the summary.
            events_from_seq: First event sequence number covered.
            events_to_seq: Last event sequence number covered.
            events_count: Number of events covered.
            original_tokens: Optional token count of the summarised events.
            summary_tokens: Optional token count of the summary.

        Returns:
            The inserted ``thread_summaries`` row as a dictionary.
        """
        summary_id = uuid4()

        # Truthiness check also skips the division when original_tokens is 0.
        compression_ratio = None
        if original_tokens and summary_tokens:
            compression_ratio = summary_tokens / original_tokens

        async with self._acquire() as conn:
            # Get next version
            version_row = await conn.fetchrow(
                """
                SELECT COALESCE(MAX(version), 0) + 1 as next_version
                FROM thread_summaries WHERE thread_id = $1
                """,
                thread_id,
            )
            version = version_row["next_version"]

            row = await conn.fetchrow(
                """
                INSERT INTO thread_summaries
                    (summary_id, thread_id, version, summary_text, state,
                     events_from_seq, events_to_seq, events_count,
                     original_tokens, summary_tokens, compression_ratio)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                RETURNING *
                """,
                summary_id, thread_id, version, summary_text, state,
                events_from_seq, events_to_seq, events_count,
                original_tokens, summary_tokens, compression_ratio,
            )

        logger.info("summary_created", summary_id=str(summary_id), version=version)
        return dict(row)

    async def get_latest_summary(self, thread_id: UUID) -> Optional[Dict[str, Any]]:
        """Get the highest-version summary for a thread.

        Returns:
            The summary row as a dictionary, or ``None`` when the thread has
            no summaries.
        """
        async with self._acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT * FROM thread_summaries
                WHERE thread_id = $1
                ORDER BY version DESC
                LIMIT 1
                """,
                thread_id,
            )
        return dict(row) if row else None

    # ========================================================================
    # STATS
    # ========================================================================

    async def get_stats(self) -> Dict[str, Any]:
        """Get row counts for the main tables.

        Returns:
            A dictionary with ``threads``, ``events``, ``active_memories``
            (memories whose ``valid_to`` is NULL) and ``summaries`` counts.
        """
        async with self._acquire() as conn:
            threads = await conn.fetchval("SELECT COUNT(*) FROM conversation_threads")
            events = await conn.fetchval("SELECT COUNT(*) FROM conversation_events")
            memories = await conn.fetchval(
                "SELECT COUNT(*) FROM long_term_memory_items WHERE valid_to IS NULL"
            )
            summaries = await conn.fetchval("SELECT COUNT(*) FROM thread_summaries")

        return {
            "threads": threads,
            "events": events,
            "active_memories": memories,
            "summaries": summaries,
        }
|
|
|
|
|
|
# Global instance shared by importers of this module.  Constructing it opens
# no connection; the pool is only created when connect() is awaited.
db = Database()
|