NATS wildcards (node.*.capabilities.get) only work for subscriptions, not for publish. Switch to a dedicated broadcast subject (fabric.capabilities.discover) that all NCS instances subscribe to, enabling proper scatter-gather discovery across nodes. Made-with: Cursor
340 lines
11 KiB
Python
"""
|
|
DAARION Memory Service - Qdrant Vector Store
|
|
"""
|
|
from typing import List, Optional, Dict, Any
|
|
from uuid import UUID, uuid4
|
|
import structlog
|
|
from qdrant_client import QdrantClient
|
|
from qdrant_client.http import models as qmodels
|
|
|
|
from .config import get_settings
|
|
from .embedding import get_query_embedding, get_document_embeddings
|
|
from .models import MemoryCategory, QdrantPayload
|
|
|
|
logger = structlog.get_logger()
|
|
settings = get_settings()
|
|
|
|
|
|
class VectorStore:
    """Qdrant-backed vector store for semantic search over memories and summaries.

    Wraps a synchronous :class:`QdrantClient`; methods are declared ``async``
    to match the service interface, but the underlying client calls block.
    Connection details and collection names come from service settings.
    """

    def __init__(self):
        # Synchronous client; host/port come from service configuration.
        self.client = QdrantClient(
            host=settings.qdrant_host,
            port=settings.qdrant_port
        )
        self.memories_collection = settings.qdrant_collection_memories
        self.messages_collection = settings.qdrant_collection_messages

    async def initialize(self):
        """Create the memories and messages collections if they don't exist."""
        await self._ensure_collection(
            self.memories_collection,
            settings.embedding_dimensions
        )
        await self._ensure_collection(
            self.messages_collection,
            settings.embedding_dimensions
        )
        logger.info("vector_store_initialized")

    async def _ensure_collection(self, name: str, dimensions: int):
        """Create collection ``name`` (cosine distance) if missing, together
        with the keyword payload indexes used by search/delete filters.
        """
        collections = self.client.get_collections().collections
        if not any(c.name == name for c in collections):
            self.client.create_collection(
                collection_name=name,
                vectors_config=qmodels.VectorParams(
                    size=dimensions,
                    distance=qmodels.Distance.COSINE
                )
            )

            # Keyword indexes back the filters built in search_memories /
            # delete_memory. One loop instead of five copy-pasted calls.
            for field_name in ("org_id", "user_id", "agent_id", "type", "category"):
                self.client.create_payload_index(
                    collection_name=name,
                    field_name=field_name,
                    field_schema=qmodels.PayloadSchemaType.KEYWORD
                )

            logger.info("collection_created", name=name, dimensions=dimensions)

    async def index_memory(
        self,
        memory_id: UUID,
        text: str,
        org_id: UUID,
        user_id: UUID,
        category: MemoryCategory,
        agent_id: Optional[UUID] = None,
        workspace_id: Optional[UUID] = None,
        thread_id: Optional[UUID] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> str:
        """
        Embed ``text`` and upsert it as a memory point in Qdrant.

        Args:
            memory_id: Source memory identifier, stored in the payload.
            text: Raw text to embed and store.
            org_id: Organization scope (always stored).
            user_id: User scope (always stored).
            category: Memory category; stored as ``category.value``.
            agent_id: Optional agent scope, stored only when provided.
            workspace_id: Optional workspace scope, stored only when provided.
            thread_id: Optional thread scope, stored only when provided.
            metadata: Extra payload fields. Reserved keys (org_id, user_id,
                memory_id, type, category, text) cannot be overridden.

        Returns:
            The generated Qdrant point ID (UUID string).

        Raises:
            ValueError: If embedding generation returned nothing.
        """
        embeddings = await get_document_embeddings([text])
        if not embeddings:
            raise ValueError("Failed to generate embedding")

        vector = embeddings[0]
        point_id = str(uuid4())

        # Spread metadata FIRST so callers cannot clobber the reserved
        # tenancy/identity keys below. (The previous version used a mutable
        # `{}` default argument and spread metadata last, which allowed both
        # cross-call state sharing and key clobbering.)
        payload = {
            **(metadata or {}),
            "org_id": str(org_id),
            "user_id": str(user_id),
            "memory_id": str(memory_id),
            "type": "memory",
            "category": category.value,
            "text": text
        }

        if agent_id:
            payload["agent_id"] = str(agent_id)
        if workspace_id:
            payload["workspace_id"] = str(workspace_id)
        if thread_id:
            payload["thread_id"] = str(thread_id)

        self.client.upsert(
            collection_name=self.memories_collection,
            points=[
                qmodels.PointStruct(
                    id=point_id,
                    vector=vector,
                    payload=payload
                )
            ]
        )

        logger.info(
            "memory_indexed",
            memory_id=str(memory_id),
            point_id=point_id,
            category=category.value
        )

        return point_id

    async def index_summary(
        self,
        summary_id: UUID,
        text: str,
        thread_id: UUID,
        org_id: UUID,
        user_id: UUID,
        agent_id: Optional[UUID] = None,
        workspace_id: Optional[UUID] = None
    ) -> str:
        """
        Embed and index a thread summary in the memories collection.

        Returns:
            The generated Qdrant point ID (UUID string).

        Raises:
            ValueError: If embedding generation returned nothing.
        """
        embeddings = await get_document_embeddings([text])
        if not embeddings:
            raise ValueError("Failed to generate embedding")

        vector = embeddings[0]
        point_id = str(uuid4())

        payload = {
            "org_id": str(org_id),
            "user_id": str(user_id),
            "thread_id": str(thread_id),
            "summary_id": str(summary_id),
            "type": "summary",
            "text": text
        }

        if agent_id:
            payload["agent_id"] = str(agent_id)
        if workspace_id:
            payload["workspace_id"] = str(workspace_id)

        self.client.upsert(
            collection_name=self.memories_collection,
            points=[
                qmodels.PointStruct(
                    id=point_id,
                    vector=vector,
                    payload=payload
                )
            ]
        )

        logger.info("summary_indexed", summary_id=str(summary_id), point_id=point_id)
        return point_id

    async def search_memories(
        self,
        query: str,
        org_id: UUID,
        user_id: UUID,
        agent_id: Optional[UUID] = None,
        workspace_id: Optional[UUID] = None,
        categories: Optional[List[MemoryCategory]] = None,
        include_global: bool = True,
        top_k: int = 10
    ) -> List[Dict[str, Any]]:
        """
        Semantic search for memories.

        Args:
            query: Search query text.
            org_id: Organization filter (always applied).
            user_id: User filter (always applied).
            agent_id: Agent filter; see ``include_global``.
            workspace_id: Optional workspace filter.
            categories: Optional filter by memory categories.
            include_global: When an ``agent_id`` is given, also match points
                that have no ``agent_id`` payload field ("global" memories).
                When False, only that agent's points match.
            top_k: Maximum number of results.

        Returns:
            List of dicts: ``point_id``, ``score``, plus the point payload.
            Empty list when the query embedding could not be produced.
        """
        query_vector = await get_query_embedding(query)
        if not query_vector:
            return []

        # Tenancy filters are always enforced.
        must_conditions = [
            qmodels.FieldCondition(
                key="org_id",
                match=qmodels.MatchValue(value=str(org_id))
            ),
            qmodels.FieldCondition(
                key="user_id",
                match=qmodels.MatchValue(value=str(user_id))
            )
        ]

        if workspace_id:
            must_conditions.append(
                qmodels.FieldCondition(
                    key="workspace_id",
                    match=qmodels.MatchValue(value=str(workspace_id))
                )
            )

        if categories:
            must_conditions.append(
                qmodels.FieldCondition(
                    key="category",
                    match=qmodels.MatchAny(any=[c.value for c in categories])
                )
            )

        # Agent scoping. A Qdrant `should` clause matches when all `must`
        # conditions hold AND at least one `should` condition holds, which
        # expresses "this agent's memories OR global memories (no agent_id)".
        # (The previous version left include_global unimplemented, so other
        # agents' memories leaked into results.)
        should_conditions = None
        if agent_id and not include_global:
            must_conditions.append(
                qmodels.FieldCondition(
                    key="agent_id",
                    match=qmodels.MatchValue(value=str(agent_id))
                )
            )
        elif agent_id and include_global:
            should_conditions = [
                qmodels.FieldCondition(
                    key="agent_id",
                    match=qmodels.MatchValue(value=str(agent_id))
                ),
                qmodels.IsEmptyCondition(
                    is_empty=qmodels.PayloadField(key="agent_id")
                )
            ]

        search_filter = qmodels.Filter(
            must=must_conditions,
            should=should_conditions
        )

        results = self.client.search(
            collection_name=self.memories_collection,
            query_vector=query_vector,
            query_filter=search_filter,
            limit=top_k,
            with_payload=True
        )

        logger.info(
            "memory_search",
            query_preview=query[:50],
            results_count=len(results)
        )

        return [
            {
                "point_id": str(r.id),
                "score": r.score,
                **r.payload
            }
            for r in results
        ]

    async def delete_memory(self, memory_id: UUID):
        """Delete all points whose payload references ``memory_id``."""
        self.client.delete(
            collection_name=self.memories_collection,
            points_selector=qmodels.FilterSelector(
                filter=qmodels.Filter(
                    must=[
                        qmodels.FieldCondition(
                            key="memory_id",
                            match=qmodels.MatchValue(value=str(memory_id))
                        )
                    ]
                )
            )
        )
        logger.info("memory_deleted_from_index", memory_id=str(memory_id))

    async def get_collection_stats(self) -> Dict[str, Any]:
        """Return point/vector counts for the memories collection.

        Best-effort: on any client error the failure is logged and zeroed
        stats are returned rather than raising.
        """
        try:
            memories_info = self.client.get_collection(self.memories_collection)

            # getattr defaults keep this tolerant of qdrant-client versions
            # that omit some of these attributes.
            points_count = getattr(memories_info, 'points_count', 0)
            vectors_count = getattr(memories_info, 'vectors_count', points_count)
            indexed_vectors_count = getattr(memories_info, 'indexed_vectors_count', 0)

            return {
                "memories": {
                    "points_count": points_count,
                    "vectors_count": vectors_count,
                    "indexed_vectors_count": indexed_vectors_count
                }
            }
        except Exception as e:
            logger.error("get_collection_stats_failed", error=str(e))
            return {
                "memories": {
                    "points_count": 0,
                    "vectors_count": 0,
                    "indexed_vectors_count": 0
                }
            }
|
|
|
|
|
|
# Module-level singleton shared by the service.
# NOTE(review): the QdrantClient is constructed at import time; collections
# are only created when `await vector_store.initialize()` runs at startup.
vector_store = VectorStore()
|