#!/usr/bin/env python3
"""
Qdrant Parity Check: Compare Legacy vs Canonical Collections

Sanity-checks a legacy -> canonical migration by:
1. Reporting point counts of the legacy collections
2. Comparing sample search results (top-1 similarity score)
3. Checking payload field presence in the canonical collection

Usage:
    python qdrant_parity_check.py --agents helion,nutra,druid
    python qdrant_parity_check.py --all
"""

import argparse
import logging
import os
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

try:
    from qdrant_client import QdrantClient
except ImportError:
    print("Error: qdrant-client not installed. Run: pip install qdrant-client")
    sys.exit(1)

# Add parent to path for imports
sys.path.insert(0, str(Path(__file__).parent.parent))

from services.memory.qdrant.collections import get_canonical_collection_name  # noqa: E402

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s"
)
logger = logging.getLogger(__name__)

# Agent -> legacy collection patterns
AGENT_LEGACY_COLLECTIONS = {
    "helion": ["helion_docs", "helion_messages"],
    "nutra": ["nutra_docs", "nutra_messages", "nutra_food_knowledge"],
    "druid": ["druid_docs", "druid_messages", "druid_legal_kb"],
    "greenfood": ["greenfood_docs", "greenfood_messages"],
    "agromatrix": ["agromatrix_docs", "agromatrix_messages"],
    "daarwizz": ["daarwizz_docs", "daarwizz_messages"],
}

# Payload keys every sampled canonical point must carry.
_REQUIRED_PAYLOAD_FIELDS = (
    "schema_version",
    "tenant_id",
    "scope",
    "visibility",
    "indexed",
    "source_id",
    "chunk",
    "fingerprint",
    "created_at",
)

# The only schema_version value accepted by the payload check.
_EXPECTED_SCHEMA_VERSION = "cm_payload_v1"

# Maximum allowed gap between the legacy and canonical top-1 scores.
_SCORE_TOLERANCE = 0.1

# Number of canonical points sampled for the payload schema check.
_SCHEMA_SAMPLE_SIZE = 10


class ParityStats:
    """Track parity check statistics.

    Warnings are collected but do not count as failures; every error
    increments ``checks_failed``.
    """

    def __init__(self) -> None:
        self.checks_passed = 0
        self.checks_failed = 0
        self.warnings: List[str] = []
        self.errors: List[str] = []

    def add_warning(self, msg: str) -> None:
        """Record and log a non-fatal warning."""
        self.warnings.append(msg)
        logger.warning(msg)

    def add_error(self, msg: str) -> None:
        """Record and log an error and count it as a failed check."""
        self.errors.append(msg)
        self.checks_failed += 1
        logger.error(msg)

    def add_pass(self, msg: str) -> None:
        """Count and log a passed check."""
        self.checks_passed += 1
        logger.info("✓ %s", msg)

    def summary(self) -> Dict[str, Any]:
        """Return the counters plus at most the first 10 error messages."""
        return {
            "passed": self.checks_passed,
            "failed": self.checks_failed,
            "warnings": len(self.warnings),
            "errors": self.errors[:10],
        }


def get_collection_count(client: QdrantClient, collection_name: str) -> Optional[int]:
    """Get point count for a collection.

    Returns ``None`` when the collection info cannot be fetched for any
    reason; callers treat that as "collection not found".
    """
    try:
        info = client.get_collection(collection_name)
    except Exception as e:
        # Best effort by design, but keep the real cause visible: a
        # connection or auth failure would otherwise be indistinguishable
        # from a missing collection.
        logger.debug("Could not fetch collection %s: %s", collection_name, e)
        return None
    return info.points_count


def get_sample_vectors(
    client: QdrantClient,
    collection_name: str,
    limit: int = 5
) -> List[Tuple[str, List[float]]]:
    """Get sample vectors from a collection.

    Returns ``(point_id, vector)`` pairs taken from the first scroll page.
    Returns an empty list (after logging a warning) when the scroll fails.
    """
    try:
        points, _ = client.scroll(
            collection_name=collection_name,
            limit=limit,
            with_vectors=True,
            with_payload=False,
        )
    except Exception as e:
        logger.warning("Could not get samples from %s: %s", collection_name, e)
        return []

    # A point without a stored vector cannot be used as a search query,
    # so leave it out instead of handing None to the caller.
    return [(str(p.id), p.vector) for p in points if p.vector is not None]


def search_in_collection(
    client: QdrantClient,
    collection_name: str,
    query_vector: List[float],
    limit: int = 10,
) -> List[Dict[str, Any]]:
    """Search in a collection and return results.

    Each hit is reduced to its id, score and the list of payload keys.
    Returns an empty list (after logging a warning) when the search fails.
    """
    # NOTE(review): newer qdrant-client releases favour `query_points` over
    # `search` -- confirm against the client version this project pins.
    try:
        hits = client.search(
            collection_name=collection_name,
            query_vector=query_vector,
            limit=limit,
            with_payload=True,
        )
        return [
            {
                "id": str(hit.id),
                "score": hit.score,
                "payload_keys": list(hit.payload.keys()) if hit.payload else [],
            }
            for hit in hits
        ]
    except Exception as e:
        logger.warning("Search failed in %s: %s", collection_name, e)
        return []


def check_point_counts(
    client: QdrantClient,
    agent: str,
    canonical_collection: str,
    stats: ParityStats,
) -> None:
    """Report the total number of points in an agent's legacy collections.

    The canonical collection is not counted here: ``canonical_collection``
    is accepted for interface symmetry with the other checks but is not
    used. A pass is recorded when the legacy total is non-zero; missing
    legacy collections and unknown agents are recorded as warnings.
    """
    legacy_collections = AGENT_LEGACY_COLLECTIONS.get(agent, [])

    if not legacy_collections:
        stats.add_warning(f"No known legacy collections for agent: {agent}")
        return

    # Count legacy points
    legacy_total = 0
    for legacy_col in legacy_collections:
        count = get_collection_count(client, legacy_col)
        if count is None:
            stats.add_warning(f"  Legacy collection not found: {legacy_col}")
            continue
        legacy_total += count
        logger.info("  Legacy %s: %s points", legacy_col, count)

    # Per-agent counting in the canonical collection would need a scroll or
    # a filtered count, so only the legacy total is reported here.
    logger.info("  Legacy total: %s points", legacy_total)

    if legacy_total > 0:
        stats.add_pass(f"{agent}: {legacy_total} points in legacy collections")


def check_search_parity(
    client: QdrantClient,
    agent: str,
    canonical_collection: str,
    stats: ParityStats,
    num_samples: int = 3,
    topk: int = 5,
) -> None:
    """Check that search results are similar between legacy and canonical.

    For each legacy collection of ``agent``, up to ``num_samples`` stored
    vectors are used as queries against both the legacy and the canonical
    collection, and the top-1 scores are compared.

    Outcomes recorded on ``stats``:
      * no legacy results          -> warning
      * no canonical results       -> error
      * top score gap > tolerance  -> warning
      * otherwise                  -> pass
    """
    for legacy_col in AGENT_LEGACY_COLLECTIONS.get(agent, []):
        # Get sample vectors from legacy
        samples = get_sample_vectors(client, legacy_col, limit=num_samples)
        if not samples:
            continue

        logger.info("  Checking %s with %s sample queries", legacy_col, len(samples))

        for point_id, query_vector in samples:
            legacy_results = search_in_collection(
                client, legacy_col, query_vector, limit=topk
            )
            # Search in canonical (would need agent filter in production)
            canonical_results = search_in_collection(
                client, canonical_collection, query_vector, limit=topk
            )

            if not legacy_results:
                stats.add_warning(f"  No results from legacy for point {point_id}")
                continue

            if not canonical_results:
                stats.add_error(f"  No results from canonical for point {point_id}")
                continue

            # Only the best hit of each side is compared.
            legacy_top_score = legacy_results[0]["score"]
            canonical_top_score = canonical_results[0]["score"]
            score_diff = abs(legacy_top_score - canonical_top_score)

            if score_diff > _SCORE_TOLERANCE:
                stats.add_warning(
                    f"  Score difference for {point_id}: "
                    f"legacy={legacy_top_score:.4f}, canonical={canonical_top_score:.4f}"
                )
            else:
                stats.add_pass(
                    f"{legacy_col} point {point_id}: score diff {score_diff:.4f}"
                )


def check_payload_schema(
    client: QdrantClient,
    canonical_collection: str,
    stats: ParityStats,
) -> None:
    """Check that canonical payloads have required fields.

    Samples the first scroll page of the canonical collection and records,
    per point, an error for missing required fields or an unexpected
    ``schema_version``, and a pass otherwise. When the collection cannot
    be sampled (scroll fails or returns nothing) a warning is recorded.
    """
    # One payload-only scroll is enough: it tells us both whether the
    # collection can be sampled and what the payloads look like, and a
    # failure is downgraded to a warning instead of aborting the run.
    try:
        points, _ = client.scroll(
            collection_name=canonical_collection,
            limit=_SCHEMA_SAMPLE_SIZE,
            with_payload=True,
            with_vectors=False,
        )
    except Exception as e:
        logger.warning("Could not get samples from %s: %s", canonical_collection, e)
        points = []

    if not points:
        stats.add_warning("Could not sample canonical collection for schema check")
        return

    for point in points:
        payload = point.payload or {}
        missing = [f for f in _REQUIRED_PAYLOAD_FIELDS if f not in payload]

        if missing:
            stats.add_error(
                f"Point {point.id} missing required fields: {missing}"
            )
        elif payload.get("schema_version") != _EXPECTED_SCHEMA_VERSION:
            stats.add_error(
                f"Point {point.id} has invalid schema_version: "
                f"{payload.get('schema_version')}"
            )
        else:
            stats.add_pass(f"Point {point.id} has valid schema")


def main():
    """CLI entry point: run all parity checks and print a summary.

    Exits with status 1 when the canonical collection is missing or when
    at least one check failed.
    """
    parser = argparse.ArgumentParser(
        description="Check parity between legacy and canonical Qdrant collections"
    )
    parser.add_argument(
        "--host",
        default=os.getenv("QDRANT_HOST", "localhost"),
        help="Qdrant host"
    )
    parser.add_argument(
        "--port",
        type=int,
        # Keep the default as a string: argparse runs string defaults
        # through `type`, so a malformed QDRANT_PORT yields a normal
        # argparse error instead of a ValueError traceback.
        default=os.getenv("QDRANT_PORT", "6333"),
        help="Qdrant port"
    )
    parser.add_argument(
        "--agents",
        help="Comma-separated list of agents to check"
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="Check all known agents"
    )
    parser.add_argument(
        "--dim",
        type=int,
        default=1024,
        help="Vector dimension for canonical collection"
    )

    args = parser.parse_args()

    # Determine agents to check
    if args.agents:
        # Drop empty entries produced by stray commas ("helion,,nutra,").
        normalized = (name.strip().lower() for name in args.agents.split(","))
        agents = [name for name in normalized if name]
    elif args.all:
        agents = list(AGENT_LEGACY_COLLECTIONS)
    else:
        agents = []

    if not agents:
        print("Available agents:", ", ".join(AGENT_LEGACY_COLLECTIONS.keys()))
        print("\nUse --agents or --all to run parity check")
        return

    # Connect to Qdrant
    logger.info("Connecting to Qdrant at %s:%s", args.host, args.port)
    client = QdrantClient(host=args.host, port=args.port)

    canonical_collection = get_canonical_collection_name("text", args.dim)
    logger.info("Canonical collection: %s", canonical_collection)

    # Check if canonical collection exists
    canonical_count = get_collection_count(client, canonical_collection)
    if canonical_count is None:
        logger.error(f"Canonical collection {canonical_collection} not found!")
        sys.exit(1)

    logger.info("Canonical collection has %s points", canonical_count)

    # Run checks
    stats = ParityStats()

    for agent in agents:
        logger.info(f"\n=== Checking agent: {agent} ===")
        check_point_counts(client, agent, canonical_collection, stats)
        check_search_parity(client, agent, canonical_collection, stats)

    # Schema check (once)
    logger.info("\n=== Checking payload schema ===")
    check_payload_schema(client, canonical_collection, stats)

    # Summary
    print("\n" + "=" * 50)
    print("PARITY CHECK SUMMARY")
    print("=" * 50)
    summary = stats.summary()
    print(f"Checks passed: {summary['passed']}")
    print(f"Checks failed: {summary['failed']}")
    print(f"Warnings: {summary['warnings']}")

    if summary['errors']:
        print("\nErrors:")
        for err in summary['errors']:
            print(f"  - {err}")

    if summary['failed'] > 0:
        sys.exit(1)


if __name__ == "__main__":
    main()