""" Co-Memory Qdrant Client High-level client for canonical Qdrant operations with validation and filtering. """ import logging import os from typing import Any, Dict, List, Optional, Tuple from uuid import uuid4 from .payload_validation import validate_payload, PayloadValidationError from .collections import ensure_collection, get_canonical_collection_name, list_legacy_collections from .filters import AccessContext, build_qdrant_filter try: from qdrant_client import QdrantClient from qdrant_client.models import ( PointStruct, Filter, SearchRequest, ScoredPoint, ) HAS_QDRANT = True except ImportError: HAS_QDRANT = False QdrantClient = None logger = logging.getLogger(__name__) class CoMemoryQdrantClient: """ High-level Qdrant client for Co-Memory operations. Features: - Automatic payload validation - Canonical collection management - Access-controlled filtering - Dual-write/dual-read support for migration """ def __init__( self, host: str = "localhost", port: int = 6333, url: Optional[str] = None, api_key: Optional[str] = None, text_dim: int = 1024, text_metric: str = "cosine", ): """ Initialize Co-Memory Qdrant client. Args: host: Qdrant host port: Qdrant port url: Full Qdrant URL (overrides host/port) api_key: Qdrant API key (optional) text_dim: Default text embedding dimension text_metric: Default distance metric """ if not HAS_QDRANT: raise ImportError("qdrant-client not installed. Run: pip install qdrant-client") # Load from env if not provided host = host or os.getenv("QDRANT_HOST", "localhost") port = port or int(os.getenv("QDRANT_PORT", "6333")) url = url or os.getenv("QDRANT_URL") api_key = api_key or os.getenv("QDRANT_API_KEY") if url: self._client = QdrantClient(url=url, api_key=api_key) else: self._client = QdrantClient(host=host, port=port, api_key=api_key) self.text_dim = text_dim self.text_metric = text_metric self.text_collection = get_canonical_collection_name("text", text_dim) # Feature flags self.dual_write_enabled = os.getenv("DUAL_WRITE_OLD", "false").lower() == "true" self.dual_read_enabled = os.getenv("DUAL_READ_OLD", "false").lower() == "true" logger.info(f"CoMemoryQdrantClient initialized: {self.text_collection}") @property def client(self) -> "QdrantClient": """Get underlying Qdrant client.""" return self._client def ensure_collections(self) -> None: """Ensure canonical collections exist.""" ensure_collection( self._client, self.text_collection, self.text_dim, self.text_metric, ) logger.info(f"Ensured collection: {self.text_collection}") def upsert_text( self, points: List[Dict[str, Any]], validate: bool = True, collection_name: Optional[str] = None, ) -> Dict[str, Any]: """ Upsert text embeddings to canonical collection. Args: points: List of dicts with 'id', 'vector', 'payload' validate: Validate payloads before upsert collection_name: Override collection name (for migration) Returns: Upsert result summary """ collection = collection_name or self.text_collection valid_points = [] errors = [] for point in points: point_id = point.get("id") or str(uuid4()) vector = point.get("vector") payload = point.get("payload", {}) if not vector: errors.append({"id": point_id, "error": "Missing vector"}) continue # Validate payload if validate: try: validate_payload(payload) except PayloadValidationError as e: errors.append({"id": point_id, "error": str(e), "details": e.errors}) continue valid_points.append(PointStruct( id=point_id, vector=vector, payload=payload, )) if valid_points: self._client.upsert( collection_name=collection, points=valid_points, ) logger.info(f"Upserted {len(valid_points)} points to {collection}") # Dual-write to legacy collections if enabled if self.dual_write_enabled and collection_name is None: self._dual_write_legacy(valid_points) return { "upserted": len(valid_points), "errors": len(errors), "error_details": errors if errors else None, } def _dual_write_legacy(self, points: List["PointStruct"]) -> None: """Write to legacy collections for migration compatibility.""" # Group points by legacy collection legacy_points: Dict[str, List[PointStruct]] = {} for point in points: payload = point.payload agent_id = payload.get("agent_id") scope = payload.get("scope") if agent_id and scope: # Map to legacy collection name agent_slug = agent_id.replace("agt_", "") legacy_name = f"{agent_slug}_{scope}" if legacy_name not in legacy_points: legacy_points[legacy_name] = [] legacy_points[legacy_name].append(point) # Write to legacy collections for legacy_collection, pts in legacy_points.items(): try: self._client.upsert( collection_name=legacy_collection, points=pts, ) logger.debug(f"Dual-write: {len(pts)} points to {legacy_collection}") except Exception as e: logger.warning(f"Dual-write to {legacy_collection} failed: {e}") def search_text( self, query_vector: List[float], ctx: AccessContext, limit: int = 10, scope: Optional[str] = None, scopes: Optional[List[str]] = None, tags: Optional[List[str]] = None, score_threshold: Optional[float] = None, with_payload: bool = True, collection_name: Optional[str] = None, ) -> List[Dict[str, Any]]: """ Search text embeddings with access control. Args: query_vector: Query embedding vector ctx: Access context for filtering limit: Maximum results scope: Filter by scope (docs, messages, etc.) scopes: Filter by multiple scopes tags: Filter by tags score_threshold: Minimum similarity score with_payload: Include payload in results collection_name: Override collection name Returns: List of search results """ collection = collection_name or self.text_collection # Build access-controlled filter filter_dict = build_qdrant_filter( ctx=ctx, scope=scope, scopes=scopes, tags=tags, ) # Convert to Qdrant Filter qdrant_filter = self._dict_to_filter(filter_dict) # Search results = self._client.search( collection_name=collection, query_vector=query_vector, query_filter=qdrant_filter, limit=limit, score_threshold=score_threshold, with_payload=with_payload, ) # Convert results output = [] for result in results: output.append({ "id": result.id, "score": result.score, "payload": result.payload if with_payload else None, }) # Dual-read from legacy if enabled and no results if self.dual_read_enabled and not output: legacy_results = self._dual_read_legacy( query_vector, ctx, limit, scope, scopes, tags, score_threshold ) output.extend(legacy_results) return output def _dual_read_legacy( self, query_vector: List[float], ctx: AccessContext, limit: int, scope: Optional[str], scopes: Optional[List[str]], tags: Optional[List[str]], score_threshold: Optional[float], ) -> List[Dict[str, Any]]: """Fallback read from legacy collections.""" results = [] # Determine which legacy collections to search legacy_collections = [] if ctx.agent_id: agent_slug = ctx.agent_id.replace("agt_", "") if scope: legacy_collections.append(f"{agent_slug}_{scope}") elif scopes: for s in scopes: legacy_collections.append(f"{agent_slug}_{s}") else: legacy_collections.append(f"{agent_slug}_docs") legacy_collections.append(f"{agent_slug}_messages") for legacy_collection in legacy_collections: try: legacy_results = self._client.search( collection_name=legacy_collection, query_vector=query_vector, limit=limit, score_threshold=score_threshold, with_payload=True, ) for result in legacy_results: results.append({ "id": result.id, "score": result.score, "payload": result.payload, "_legacy_collection": legacy_collection, }) except Exception as e: logger.debug(f"Legacy read from {legacy_collection} failed: {e}") return results def _dict_to_filter(self, filter_dict: Dict[str, Any]) -> "Filter": """Convert filter dictionary to Qdrant Filter object.""" from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchAny def build_condition(cond: Dict[str, Any]): key = cond.get("key") match = cond.get("match", {}) if "value" in match: return FieldCondition( key=key, match=MatchValue(value=match["value"]) ) elif "any" in match: return FieldCondition( key=key, match=MatchAny(any=match["any"]) ) return None def build_conditions(conditions: List[Dict]) -> List: result = [] for cond in conditions: if "must" in cond: # Nested filter nested = Filter(must=[build_condition(c) for c in cond["must"] if build_condition(c)]) if "must_not" in cond: nested.must_not = [build_condition(c) for c in cond["must_not"] if build_condition(c)] result.append(nested) else: built = build_condition(cond) if built: result.append(built) return result must = build_conditions(filter_dict.get("must", [])) should = build_conditions(filter_dict.get("should", [])) must_not = build_conditions(filter_dict.get("must_not", [])) return Filter( must=must if must else None, should=should if should else None, must_not=must_not if must_not else None, ) def delete_points( self, point_ids: List[str], collection_name: Optional[str] = None, ) -> int: """ Delete points by IDs. Args: point_ids: List of point IDs to delete collection_name: Override collection name Returns: Number of points deleted """ collection = collection_name or self.text_collection self._client.delete( collection_name=collection, points_selector=point_ids, ) return len(point_ids) def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict[str, Any]: """Get collection statistics.""" collection = collection_name or self.text_collection try: info = self._client.get_collection(collection) return { "name": collection, "vectors_count": info.vectors_count, "points_count": info.points_count, "status": info.status.value, } except Exception as e: return {"error": str(e)} def list_all_collections(self) -> Dict[str, List[str]]: """List all collections categorized as canonical or legacy.""" collections = self._client.get_collections().collections canonical = [] legacy = [] for col in collections: if col.name.startswith("cm_"): canonical.append(col.name) else: legacy.append(col.name) return { "canonical": canonical, "legacy": legacy, }