"""
DAARION Memory Service - PostgreSQL Database Layer
"""

from typing import List, Optional, Dict, Any
from uuid import UUID, uuid4
from datetime import datetime

import structlog
import asyncpg

from .config import get_settings
from .models import EventType, MessageRole, MemoryCategory, RetentionPolicy, FeedbackAction

logger = structlog.get_logger()
settings = get_settings()


class Database:
    """PostgreSQL database operations for the memory service.

    Every query method borrows a connection from ``self.pool`` (an asyncpg
    pool), so ``connect()`` must be awaited first. Rows are returned as plain
    ``dict`` objects built from ``asyncpg.Record``.

    NOTE(review): several methods pass Python dicts (``metadata``,
    ``payload``, ``tool_input``/``tool_output``, ``state``) as query
    arguments. By default asyncpg encodes json/jsonb parameters from ``str``
    only. If those columns are json/jsonb, register a codec (for example via
    the ``init=`` callback of ``create_pool`` + ``Connection.set_type_codec``).
    Confirm against the schema.
    """

    def __init__(self):
        # Set by connect(); reset to None by disconnect().
        self.pool: Optional[asyncpg.Pool] = None

    async def connect(self):
        """Open the connection pool using the service settings.

        A second call while a pool is already open is a no-op, so the
        existing pool is not leaked.
        """
        if self.pool is not None:
            return
        self.pool = await asyncpg.create_pool(
            host=settings.postgres_host,
            port=settings.postgres_port,
            user=settings.postgres_user,
            password=settings.postgres_password,
            database=settings.postgres_db,
            min_size=5,
            max_size=20,
        )
        logger.info("database_connected")

    async def disconnect(self):
        """Close the connection pool, if one is open, and forget it."""
        if self.pool is None:
            return
        try:
            await self.pool.close()
        finally:
            # Drop the reference even if close() raises, so a later
            # connect() builds a fresh pool instead of keeping a closed one.
            self.pool = None
        logger.info("database_disconnected")

    # ========================================================================
    # THREADS
    # ========================================================================

    async def create_thread(
        self,
        org_id: UUID,
        user_id: UUID,
        workspace_id: Optional[UUID] = None,
        agent_id: Optional[UUID] = None,
        title: Optional[str] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[dict] = None,
    ) -> Dict[str, Any]:
        """Insert a new row into ``conversation_threads``.

        Args:
            org_id: Organisation ID stored on the thread.
            user_id: User ID stored on the thread.
            workspace_id: Optional workspace ID.
            agent_id: Optional agent ID.
            title: Optional title.
            tags: Tags to store; ``None`` means an empty list.
            metadata: Metadata to store; ``None`` means an empty dict.

        Returns:
            The inserted row (``RETURNING *``) as a dict.
        """
        # None defaults avoid one mutable list/dict being shared by all calls.
        tags = [] if tags is None else tags
        metadata = {} if metadata is None else metadata
        thread_id = uuid4()

        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO conversation_threads
                    (thread_id, org_id, workspace_id, user_id, agent_id, title, tags, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
                RETURNING *
                """,
                thread_id, org_id, workspace_id, user_id, agent_id, title, tags, metadata,
            )

        logger.info("thread_created", thread_id=str(thread_id))
        return dict(row)

    async def get_thread(self, thread_id: UUID) -> Optional[Dict[str, Any]]:
        """Return the thread with ``thread_id`` as a dict, or ``None`` if absent."""
        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT * FROM conversation_threads WHERE thread_id = $1
                """,
                thread_id,
            )
        return dict(row) if row else None

    async def list_threads(
        self,
        org_id: UUID,
        user_id: UUID,
        workspace_id: Optional[UUID] = None,
        agent_id: Optional[UUID] = None,
        limit: int = 20,
    ) -> List[Dict[str, Any]]:
        """List a user's threads whose ``status`` is ``'active'``.

        Args:
            org_id: Organisation to filter on.
            user_id: User to filter on.
            workspace_id: If given, only threads in this workspace.
            agent_id: If given, only threads with this agent.
            limit: Maximum number of rows returned.

        Returns:
            Rows ordered by ``last_activity_at`` descending.
        """
        async with self.pool.acquire() as conn:
            query = """
                SELECT * FROM conversation_threads
                WHERE org_id = $1 AND user_id = $2 AND status = 'active'
            """
            params = [org_id, user_id]

            # Only placeholder numbers are interpolated; values stay bound
            # parameters, so this string building is injection-safe.
            if workspace_id is not None:
                query += f" AND workspace_id = ${len(params) + 1}"
                params.append(workspace_id)

            if agent_id is not None:
                query += f" AND agent_id = ${len(params) + 1}"
                params.append(agent_id)

            query += f" ORDER BY last_activity_at DESC LIMIT ${len(params) + 1}"
            params.append(limit)

            rows = await conn.fetch(query, *params)
        return [dict(row) for row in rows]

    # ========================================================================
    # EVENTS
    # ========================================================================

    async def add_event(
        self,
        thread_id: UUID,
        event_type: EventType,
        role: Optional[MessageRole] = None,
        content: Optional[str] = None,
        tool_name: Optional[str] = None,
        tool_input: Optional[dict] = None,
        tool_output: Optional[dict] = None,
        payload: Optional[dict] = None,
        token_count: Optional[int] = None,
        model_used: Optional[str] = None,
        latency_ms: Optional[int] = None,
        metadata: Optional[dict] = None,
    ) -> Dict[str, Any]:
        """Insert an event into ``conversation_events``.

        Enum arguments (``event_type``, ``role``) are stored by ``.value``.
        ``payload`` and ``metadata`` default to empty dicts when ``None``.

        Returns:
            The inserted row (``RETURNING *``) as a dict.
        """
        payload = {} if payload is None else payload
        metadata = {} if metadata is None else metadata
        event_id = uuid4()

        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO conversation_events
                    (event_id, thread_id, event_type, role, content, tool_name,
                     tool_input, tool_output, payload, token_count, model_used,
                     latency_ms, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
                RETURNING *
                """,
                event_id, thread_id, event_type.value,
                role.value if role else None,
                content, tool_name, tool_input, tool_output, payload,
                token_count, model_used, latency_ms, metadata,
            )

        logger.info("event_added", event_id=str(event_id), type=event_type.value)
        return dict(row)

    async def get_events(
        self,
        thread_id: UUID,
        limit: int = 50,
        offset: int = 0,
    ) -> List[Dict[str, Any]]:
        """Return a page of a thread's events, newest first (``sequence_num`` DESC)."""
        async with self.pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT * FROM conversation_events
                WHERE thread_id = $1
                ORDER BY sequence_num DESC
                LIMIT $2 OFFSET $3
                """,
                thread_id, limit, offset,
            )
        return [dict(row) for row in rows]

    async def get_events_for_summary(
        self,
        thread_id: UUID,
        after_seq: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """Return a thread's events in ascending ``sequence_num`` order.

        Args:
            thread_id: Thread to read.
            after_seq: If given (including 0), only events with
                ``sequence_num > after_seq`` are returned. ``None`` returns
                every event.
        """
        async with self.pool.acquire() as conn:
            # `is not None`, so that a cursor of 0 is honoured rather than
            # treated as "no cursor".
            if after_seq is not None:
                rows = await conn.fetch(
                    """
                    SELECT * FROM conversation_events
                    WHERE thread_id = $1 AND sequence_num > $2
                    ORDER BY sequence_num ASC
                    """,
                    thread_id, after_seq,
                )
            else:
                rows = await conn.fetch(
                    """
                    SELECT * FROM conversation_events
                    WHERE thread_id = $1
                    ORDER BY sequence_num ASC
                    """,
                    thread_id,
                )
        return [dict(row) for row in rows]

    # ========================================================================
    # MEMORIES
    # ========================================================================

    async def create_memory(
        self,
        org_id: UUID,
        user_id: UUID,
        category: MemoryCategory,
        fact_text: str,
        workspace_id: Optional[UUID] = None,
        agent_id: Optional[UUID] = None,
        confidence: float = 0.8,
        source_event_id: Optional[UUID] = None,
        source_thread_id: Optional[UUID] = None,
        extraction_method: str = "explicit",
        is_sensitive: bool = False,
        retention: RetentionPolicy = RetentionPolicy.UNTIL_REVOKED,
        ttl_days: Optional[int] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[dict] = None,
    ) -> Dict[str, Any]:
        """Insert a row into ``long_term_memory_items``.

        Enum arguments (``category``, ``retention``) are stored by ``.value``.
        ``tags``/``metadata`` default to an empty list/dict when ``None``.

        Returns:
            The inserted row (``RETURNING *``) as a dict.
        """
        tags = [] if tags is None else tags
        metadata = {} if metadata is None else metadata
        memory_id = uuid4()

        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                INSERT INTO long_term_memory_items
                    (memory_id, org_id, workspace_id, user_id, agent_id, category,
                     fact_text, confidence, source_event_id, source_thread_id,
                     extraction_method, is_sensitive, retention, ttl_days, tags, metadata)
                VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16)
                RETURNING *
                """,
                memory_id, org_id, workspace_id, user_id, agent_id,
                category.value, fact_text, confidence,
                source_event_id, source_thread_id, extraction_method,
                is_sensitive, retention.value, ttl_days, tags, metadata,
            )

        logger.info("memory_created", memory_id=str(memory_id), category=category.value)
        return dict(row)

    async def get_memory(self, memory_id: UUID) -> Optional[Dict[str, Any]]:
        """Return a memory by ID, or ``None`` if it is absent or invalidated (``valid_to`` set)."""
        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT * FROM long_term_memory_items
                WHERE memory_id = $1 AND valid_to IS NULL
                """,
                memory_id,
            )
        return dict(row) if row else None

    async def list_memories(
        self,
        org_id: UUID,
        user_id: UUID,
        agent_id: Optional[UUID] = None,
        workspace_id: Optional[UUID] = None,
        category: Optional[MemoryCategory] = None,
        include_global: bool = True,
        limit: int = 50,
    ) -> List[Dict[str, Any]]:
        """List a user's active memories at or above ``settings.memory_min_confidence``.

        Args:
            org_id: Organisation to filter on.
            user_id: User to filter on.
            agent_id: If given, match this agent. When ``include_global`` is
                true, rows with a NULL ``agent_id`` are also included.
            workspace_id: If given, match this workspace or a NULL workspace.
            category: If given, only memories in this category.
            include_global: Only consulted when ``agent_id`` is given.
            limit: Maximum number of rows returned.

        Returns:
            Rows ordered by confidence, then ``last_used_at`` (NULLs last).
        """
        async with self.pool.acquire() as conn:
            query = """
                SELECT * FROM long_term_memory_items
                WHERE org_id = $1 AND user_id = $2
                  AND valid_to IS NULL
                  AND confidence >= $3
            """
            params = [org_id, user_id, settings.memory_min_confidence]

            if workspace_id is not None:
                query += f" AND (workspace_id = ${len(params) + 1} OR workspace_id IS NULL)"
                params.append(workspace_id)

            if agent_id is not None:
                if include_global:
                    query += f" AND (agent_id = ${len(params) + 1} OR agent_id IS NULL)"
                else:
                    query += f" AND agent_id = ${len(params) + 1}"
                params.append(agent_id)

            if category is not None:
                query += f" AND category = ${len(params) + 1}"
                params.append(category.value)

            query += f" ORDER BY confidence DESC, last_used_at DESC NULLS LAST LIMIT ${len(params) + 1}"
            params.append(limit)

            rows = await conn.fetch(query, *params)
        return [dict(row) for row in rows]

    async def update_memory_embedding_id(self, memory_id: UUID, embedding_id: str):
        """Store the Qdrant point ID for a memory in ``fact_embedding_id``."""
        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE long_term_memory_items
                SET fact_embedding_id = $2
                WHERE memory_id = $1
                """,
                memory_id, embedding_id,
            )

    async def update_memory_confidence(
        self,
        memory_id: UUID,
        confidence: float,
        verified: bool = False,
    ):
        """Set a memory's confidence.

        When ``verified`` is true, this call also sets ``is_verified``,
        increments ``verification_count`` and stamps ``last_confirmed_at``
        with ``NOW()``.
        """
        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE long_term_memory_items
                SET confidence = $2,
                    is_verified = CASE WHEN $3 THEN true ELSE is_verified END,
                    verification_count = verification_count + CASE WHEN $3 THEN 1 ELSE 0 END,
                    last_confirmed_at = CASE WHEN $3 THEN NOW() ELSE last_confirmed_at END
                WHERE memory_id = $1
                """,
                memory_id, confidence, verified,
            )

    async def update_memory_text(self, memory_id: UUID, new_text: str):
        """Replace a memory's ``fact_text``."""
        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE long_term_memory_items
                SET fact_text = $2
                WHERE memory_id = $1
                """,
                memory_id, new_text,
            )

    async def invalidate_memory(self, memory_id: UUID):
        """Soft-delete a memory by setting ``valid_to = NOW()``."""
        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE long_term_memory_items
                SET valid_to = NOW()
                WHERE memory_id = $1
                """,
                memory_id,
            )
        logger.info("memory_invalidated", memory_id=str(memory_id))

    async def increment_memory_usage(self, memory_id: UUID):
        """Increment ``use_count`` and set ``last_used_at = NOW()`` for a memory."""
        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE long_term_memory_items
                SET use_count = use_count + 1, last_used_at = NOW()
                WHERE memory_id = $1
                """,
                memory_id,
            )

    # ========================================================================
    # FEEDBACK
    # ========================================================================

    async def add_memory_feedback(
        self,
        memory_id: UUID,
        user_id: UUID,
        action: FeedbackAction,
        old_value: Optional[str] = None,
        new_value: Optional[str] = None,
        reason: Optional[str] = None,
    ):
        """Insert a ``memory_feedback`` row; ``action`` is stored by ``.value``."""
        feedback_id = uuid4()
        async with self.pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO memory_feedback
                    (feedback_id, memory_id, user_id, action, old_value, new_value, reason)
                VALUES ($1, $2, $3, $4, $5, $6, $7)
                """,
                feedback_id, memory_id, user_id, action.value, old_value, new_value, reason,
            )
        logger.info("feedback_recorded", memory_id=str(memory_id), action=action.value)

    # ========================================================================
    # SUMMARIES
    # ========================================================================

    async def create_summary(
        self,
        thread_id: UUID,
        summary_text: str,
        state: dict,
        events_from_seq: int,
        events_to_seq: int,
        events_count: int,
        original_tokens: Optional[int] = None,
        summary_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Insert the next version of a thread summary.

        The version is one more than the thread's current maximum version, or 1 if
        the thread has no summaries. ``compression_ratio`` is
        ``summary_tokens / original_tokens`` when both are known and
        ``original_tokens`` is non-zero; otherwise it is ``None``.

        Returns:
            The inserted ``thread_summaries`` row (``RETURNING *``) as a dict.
        """
        summary_id = uuid4()

        # Skip the ratio only when it cannot be computed; a summary_tokens of
        # 0 is a real measurement and yields 0.0.
        compression_ratio = None
        if original_tokens and summary_tokens is not None:
            compression_ratio = summary_tokens / original_tokens

        async with self.pool.acquire() as conn:
            # The next version is read (MAX + 1) and then written. Serialise
            # that per thread with a transaction-scoped advisory lock, so
            # concurrent calls cannot read the same MAX and insert duplicate
            # versions. The lock is released at commit or rollback.
            async with conn.transaction():
                await conn.execute(
                    "SELECT pg_advisory_xact_lock(hashtext($1))",
                    str(thread_id),
                )
                version = await conn.fetchval(
                    """
                    SELECT COALESCE(MAX(version), 0) + 1 as next_version
                    FROM thread_summaries
                    WHERE thread_id = $1
                    """,
                    thread_id,
                )
                row = await conn.fetchrow(
                    """
                    INSERT INTO thread_summaries
                        (summary_id, thread_id, version, summary_text, state,
                         events_from_seq, events_to_seq, events_count,
                         original_tokens, summary_tokens, compression_ratio)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                    RETURNING *
                    """,
                    summary_id, thread_id, version, summary_text, state,
                    events_from_seq, events_to_seq, events_count,
                    original_tokens, summary_tokens, compression_ratio,
                )

        logger.info("summary_created", summary_id=str(summary_id), version=version)
        return dict(row)

    async def get_latest_summary(self, thread_id: UUID) -> Optional[Dict[str, Any]]:
        """Return the highest-version summary for a thread, or ``None`` if there is none."""
        async with self.pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT * FROM thread_summaries
                WHERE thread_id = $1
                ORDER BY version DESC
                LIMIT 1
                """,
                thread_id,
            )
        return dict(row) if row else None

    # ========================================================================
    # STATS
    # ========================================================================

    async def get_stats(self) -> Dict[str, Any]:
        """Return row counts for threads, events, active memories and summaries."""
        async with self.pool.acquire() as conn:
            threads = await conn.fetchval("SELECT COUNT(*) FROM conversation_threads")
            events = await conn.fetchval("SELECT COUNT(*) FROM conversation_events")
            memories = await conn.fetchval(
                "SELECT COUNT(*) FROM long_term_memory_items WHERE valid_to IS NULL"
            )
            summaries = await conn.fetchval("SELECT COUNT(*) FROM thread_summaries")

        return {
            "threads": threads,
            "events": events,
            "active_memories": memories,
            "summaries": summaries,
        }


# Global instance
db = Database()