feat: Add Alateya, Clan, Eonarch agents + fix gateway-router connection
## Agents Added - Alateya: R&D, biotech, innovations - Clan (Spirit): Community spirit agent - Eonarch: Consciousness evolution agent ## Changes - docker-compose.node1.yml: Added tokens for all 3 new agents - gateway-bot/http_api.py: Added configs and webhook endpoints - gateway-bot/clan_prompt.txt: New prompt file - gateway-bot/eonarch_prompt.txt: New prompt file ## Fixes - Fixed ROUTER_URL from :9102 to :8000 (internal container port) - All 9 Telegram agents now working ## Documentation - Created PROJECT-MASTER-INDEX.md - single entry point - Added various status documents and scripts Tokens configured: - Helion, NUTRA, Agromatrix (existing) - Alateya, Clan, Eonarch (new) - Druid, GreenFood, DAARWIZZ (configured)
This commit is contained in:
43
services/memory/qdrant/__init__.py
Normal file
43
services/memory/qdrant/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
Co-Memory Qdrant Module
|
||||
|
||||
Canonical Qdrant client with payload validation and filter building.
|
||||
|
||||
Security Invariants:
|
||||
- tenant_id is ALWAYS required in filters
|
||||
- indexed=true is default for search
|
||||
- Empty should clause is NEVER allowed (would match everything)
|
||||
- visibility=private is ONLY accessible by owner
|
||||
"""
|
||||
|
||||
from .payload_validation import validate_payload, PayloadValidationError
|
||||
from .collections import ensure_collection, get_canonical_collection_name
|
||||
from .filters import (
|
||||
build_qdrant_filter,
|
||||
build_agent_only_filter,
|
||||
build_multi_agent_filter,
|
||||
build_project_filter,
|
||||
build_tag_filter,
|
||||
AccessContext,
|
||||
FilterSecurityError,
|
||||
)
|
||||
from .client import CoMemoryQdrantClient
|
||||
|
||||
__all__ = [
|
||||
# Validation
|
||||
"validate_payload",
|
||||
"PayloadValidationError",
|
||||
# Collections
|
||||
"ensure_collection",
|
||||
"get_canonical_collection_name",
|
||||
# Filters
|
||||
"build_qdrant_filter",
|
||||
"build_agent_only_filter",
|
||||
"build_multi_agent_filter",
|
||||
"build_project_filter",
|
||||
"build_tag_filter",
|
||||
"AccessContext",
|
||||
"FilterSecurityError",
|
||||
# Client
|
||||
"CoMemoryQdrantClient",
|
||||
]
|
||||
413
services/memory/qdrant/client.py
Normal file
413
services/memory/qdrant/client.py
Normal file
@@ -0,0 +1,413 @@
|
||||
"""
|
||||
Co-Memory Qdrant Client
|
||||
|
||||
High-level client for canonical Qdrant operations with validation and filtering.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from uuid import uuid4
|
||||
|
||||
from .payload_validation import validate_payload, PayloadValidationError
|
||||
from .collections import ensure_collection, get_canonical_collection_name, list_legacy_collections
|
||||
from .filters import AccessContext, build_qdrant_filter
|
||||
|
||||
try:
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
PointStruct,
|
||||
Filter,
|
||||
SearchRequest,
|
||||
ScoredPoint,
|
||||
)
|
||||
HAS_QDRANT = True
|
||||
except ImportError:
|
||||
HAS_QDRANT = False
|
||||
QdrantClient = None
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CoMemoryQdrantClient:
|
||||
"""
|
||||
High-level Qdrant client for Co-Memory operations.
|
||||
|
||||
Features:
|
||||
- Automatic payload validation
|
||||
- Canonical collection management
|
||||
- Access-controlled filtering
|
||||
- Dual-write/dual-read support for migration
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: str = "localhost",
|
||||
port: int = 6333,
|
||||
url: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
text_dim: int = 1024,
|
||||
text_metric: str = "cosine",
|
||||
):
|
||||
"""
|
||||
Initialize Co-Memory Qdrant client.
|
||||
|
||||
Args:
|
||||
host: Qdrant host
|
||||
port: Qdrant port
|
||||
url: Full Qdrant URL (overrides host/port)
|
||||
api_key: Qdrant API key (optional)
|
||||
text_dim: Default text embedding dimension
|
||||
text_metric: Default distance metric
|
||||
"""
|
||||
if not HAS_QDRANT:
|
||||
raise ImportError("qdrant-client not installed. Run: pip install qdrant-client")
|
||||
|
||||
# Load from env if not provided
|
||||
host = host or os.getenv("QDRANT_HOST", "localhost")
|
||||
port = port or int(os.getenv("QDRANT_PORT", "6333"))
|
||||
url = url or os.getenv("QDRANT_URL")
|
||||
api_key = api_key or os.getenv("QDRANT_API_KEY")
|
||||
|
||||
if url:
|
||||
self._client = QdrantClient(url=url, api_key=api_key)
|
||||
else:
|
||||
self._client = QdrantClient(host=host, port=port, api_key=api_key)
|
||||
|
||||
self.text_dim = text_dim
|
||||
self.text_metric = text_metric
|
||||
self.text_collection = get_canonical_collection_name("text", text_dim)
|
||||
|
||||
# Feature flags
|
||||
self.dual_write_enabled = os.getenv("DUAL_WRITE_OLD", "false").lower() == "true"
|
||||
self.dual_read_enabled = os.getenv("DUAL_READ_OLD", "false").lower() == "true"
|
||||
|
||||
logger.info(f"CoMemoryQdrantClient initialized: {self.text_collection}")
|
||||
|
||||
@property
|
||||
def client(self) -> "QdrantClient":
|
||||
"""Get underlying Qdrant client."""
|
||||
return self._client
|
||||
|
||||
def ensure_collections(self) -> None:
|
||||
"""Ensure canonical collections exist."""
|
||||
ensure_collection(
|
||||
self._client,
|
||||
self.text_collection,
|
||||
self.text_dim,
|
||||
self.text_metric,
|
||||
)
|
||||
logger.info(f"Ensured collection: {self.text_collection}")
|
||||
|
||||
def upsert_text(
|
||||
self,
|
||||
points: List[Dict[str, Any]],
|
||||
validate: bool = True,
|
||||
collection_name: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Upsert text embeddings to canonical collection.
|
||||
|
||||
Args:
|
||||
points: List of dicts with 'id', 'vector', 'payload'
|
||||
validate: Validate payloads before upsert
|
||||
collection_name: Override collection name (for migration)
|
||||
|
||||
Returns:
|
||||
Upsert result summary
|
||||
"""
|
||||
collection = collection_name or self.text_collection
|
||||
valid_points = []
|
||||
errors = []
|
||||
|
||||
for point in points:
|
||||
point_id = point.get("id") or str(uuid4())
|
||||
vector = point.get("vector")
|
||||
payload = point.get("payload", {})
|
||||
|
||||
if not vector:
|
||||
errors.append({"id": point_id, "error": "Missing vector"})
|
||||
continue
|
||||
|
||||
# Validate payload
|
||||
if validate:
|
||||
try:
|
||||
validate_payload(payload)
|
||||
except PayloadValidationError as e:
|
||||
errors.append({"id": point_id, "error": str(e), "details": e.errors})
|
||||
continue
|
||||
|
||||
valid_points.append(PointStruct(
|
||||
id=point_id,
|
||||
vector=vector,
|
||||
payload=payload,
|
||||
))
|
||||
|
||||
if valid_points:
|
||||
self._client.upsert(
|
||||
collection_name=collection,
|
||||
points=valid_points,
|
||||
)
|
||||
logger.info(f"Upserted {len(valid_points)} points to {collection}")
|
||||
|
||||
# Dual-write to legacy collections if enabled
|
||||
if self.dual_write_enabled and collection_name is None:
|
||||
self._dual_write_legacy(valid_points)
|
||||
|
||||
return {
|
||||
"upserted": len(valid_points),
|
||||
"errors": len(errors),
|
||||
"error_details": errors if errors else None,
|
||||
}
|
||||
|
||||
def _dual_write_legacy(self, points: List["PointStruct"]) -> None:
|
||||
"""Write to legacy collections for migration compatibility."""
|
||||
# Group points by legacy collection
|
||||
legacy_points: Dict[str, List[PointStruct]] = {}
|
||||
|
||||
for point in points:
|
||||
payload = point.payload
|
||||
agent_id = payload.get("agent_id")
|
||||
scope = payload.get("scope")
|
||||
|
||||
if agent_id and scope:
|
||||
# Map to legacy collection name
|
||||
agent_slug = agent_id.replace("agt_", "")
|
||||
legacy_name = f"{agent_slug}_{scope}"
|
||||
|
||||
if legacy_name not in legacy_points:
|
||||
legacy_points[legacy_name] = []
|
||||
legacy_points[legacy_name].append(point)
|
||||
|
||||
# Write to legacy collections
|
||||
for legacy_collection, pts in legacy_points.items():
|
||||
try:
|
||||
self._client.upsert(
|
||||
collection_name=legacy_collection,
|
||||
points=pts,
|
||||
)
|
||||
logger.debug(f"Dual-write: {len(pts)} points to {legacy_collection}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Dual-write to {legacy_collection} failed: {e}")
|
||||
|
||||
def search_text(
|
||||
self,
|
||||
query_vector: List[float],
|
||||
ctx: AccessContext,
|
||||
limit: int = 10,
|
||||
scope: Optional[str] = None,
|
||||
scopes: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
score_threshold: Optional[float] = None,
|
||||
with_payload: bool = True,
|
||||
collection_name: Optional[str] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search text embeddings with access control.
|
||||
|
||||
Args:
|
||||
query_vector: Query embedding vector
|
||||
ctx: Access context for filtering
|
||||
limit: Maximum results
|
||||
scope: Filter by scope (docs, messages, etc.)
|
||||
scopes: Filter by multiple scopes
|
||||
tags: Filter by tags
|
||||
score_threshold: Minimum similarity score
|
||||
with_payload: Include payload in results
|
||||
collection_name: Override collection name
|
||||
|
||||
Returns:
|
||||
List of search results
|
||||
"""
|
||||
collection = collection_name or self.text_collection
|
||||
|
||||
# Build access-controlled filter
|
||||
filter_dict = build_qdrant_filter(
|
||||
ctx=ctx,
|
||||
scope=scope,
|
||||
scopes=scopes,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
# Convert to Qdrant Filter
|
||||
qdrant_filter = self._dict_to_filter(filter_dict)
|
||||
|
||||
# Search
|
||||
results = self._client.search(
|
||||
collection_name=collection,
|
||||
query_vector=query_vector,
|
||||
query_filter=qdrant_filter,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
with_payload=with_payload,
|
||||
)
|
||||
|
||||
# Convert results
|
||||
output = []
|
||||
for result in results:
|
||||
output.append({
|
||||
"id": result.id,
|
||||
"score": result.score,
|
||||
"payload": result.payload if with_payload else None,
|
||||
})
|
||||
|
||||
# Dual-read from legacy if enabled and no results
|
||||
if self.dual_read_enabled and not output:
|
||||
legacy_results = self._dual_read_legacy(
|
||||
query_vector, ctx, limit, scope, scopes, tags, score_threshold
|
||||
)
|
||||
output.extend(legacy_results)
|
||||
|
||||
return output
|
||||
|
||||
def _dual_read_legacy(
|
||||
self,
|
||||
query_vector: List[float],
|
||||
ctx: AccessContext,
|
||||
limit: int,
|
||||
scope: Optional[str],
|
||||
scopes: Optional[List[str]],
|
||||
tags: Optional[List[str]],
|
||||
score_threshold: Optional[float],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Fallback read from legacy collections."""
|
||||
results = []
|
||||
|
||||
# Determine which legacy collections to search
|
||||
legacy_collections = []
|
||||
if ctx.agent_id:
|
||||
agent_slug = ctx.agent_id.replace("agt_", "")
|
||||
if scope:
|
||||
legacy_collections.append(f"{agent_slug}_{scope}")
|
||||
elif scopes:
|
||||
for s in scopes:
|
||||
legacy_collections.append(f"{agent_slug}_{s}")
|
||||
else:
|
||||
legacy_collections.append(f"{agent_slug}_docs")
|
||||
legacy_collections.append(f"{agent_slug}_messages")
|
||||
|
||||
for legacy_collection in legacy_collections:
|
||||
try:
|
||||
legacy_results = self._client.search(
|
||||
collection_name=legacy_collection,
|
||||
query_vector=query_vector,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
with_payload=True,
|
||||
)
|
||||
|
||||
for result in legacy_results:
|
||||
results.append({
|
||||
"id": result.id,
|
||||
"score": result.score,
|
||||
"payload": result.payload,
|
||||
"_legacy_collection": legacy_collection,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Legacy read from {legacy_collection} failed: {e}")
|
||||
|
||||
return results
|
||||
|
||||
def _dict_to_filter(self, filter_dict: Dict[str, Any]) -> "Filter":
|
||||
"""Convert filter dictionary to Qdrant Filter object."""
|
||||
from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchAny
|
||||
|
||||
def build_condition(cond: Dict[str, Any]):
|
||||
key = cond.get("key")
|
||||
match = cond.get("match", {})
|
||||
|
||||
if "value" in match:
|
||||
return FieldCondition(
|
||||
key=key,
|
||||
match=MatchValue(value=match["value"])
|
||||
)
|
||||
elif "any" in match:
|
||||
return FieldCondition(
|
||||
key=key,
|
||||
match=MatchAny(any=match["any"])
|
||||
)
|
||||
return None
|
||||
|
||||
def build_conditions(conditions: List[Dict]) -> List:
|
||||
result = []
|
||||
for cond in conditions:
|
||||
if "must" in cond:
|
||||
# Nested filter
|
||||
nested = Filter(must=[build_condition(c) for c in cond["must"] if build_condition(c)])
|
||||
if "must_not" in cond:
|
||||
nested.must_not = [build_condition(c) for c in cond["must_not"] if build_condition(c)]
|
||||
result.append(nested)
|
||||
else:
|
||||
built = build_condition(cond)
|
||||
if built:
|
||||
result.append(built)
|
||||
return result
|
||||
|
||||
must = build_conditions(filter_dict.get("must", []))
|
||||
should = build_conditions(filter_dict.get("should", []))
|
||||
must_not = build_conditions(filter_dict.get("must_not", []))
|
||||
|
||||
return Filter(
|
||||
must=must if must else None,
|
||||
should=should if should else None,
|
||||
must_not=must_not if must_not else None,
|
||||
)
|
||||
|
||||
def delete_points(
|
||||
self,
|
||||
point_ids: List[str],
|
||||
collection_name: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Delete points by IDs.
|
||||
|
||||
Args:
|
||||
point_ids: List of point IDs to delete
|
||||
collection_name: Override collection name
|
||||
|
||||
Returns:
|
||||
Number of points deleted
|
||||
"""
|
||||
collection = collection_name or self.text_collection
|
||||
|
||||
self._client.delete(
|
||||
collection_name=collection,
|
||||
points_selector=point_ids,
|
||||
)
|
||||
|
||||
return len(point_ids)
|
||||
|
||||
def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get collection statistics."""
|
||||
collection = collection_name or self.text_collection
|
||||
|
||||
try:
|
||||
info = self._client.get_collection(collection)
|
||||
return {
|
||||
"name": collection,
|
||||
"vectors_count": info.vectors_count,
|
||||
"points_count": info.points_count,
|
||||
"status": info.status.value,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
def list_all_collections(self) -> Dict[str, List[str]]:
|
||||
"""List all collections categorized as canonical or legacy."""
|
||||
collections = self._client.get_collections().collections
|
||||
|
||||
canonical = []
|
||||
legacy = []
|
||||
|
||||
for col in collections:
|
||||
if col.name.startswith("cm_"):
|
||||
canonical.append(col.name)
|
||||
else:
|
||||
legacy.append(col.name)
|
||||
|
||||
return {
|
||||
"canonical": canonical,
|
||||
"legacy": legacy,
|
||||
}
|
||||
260
services/memory/qdrant/collections.py
Normal file
260
services/memory/qdrant/collections.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""
|
||||
Qdrant Collection Management for Co-Memory
|
||||
|
||||
Handles canonical collection creation and configuration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import (
|
||||
Distance,
|
||||
VectorParams,
|
||||
PayloadSchemaType,
|
||||
TextIndexParams,
|
||||
TokenizerType,
|
||||
)
|
||||
HAS_QDRANT = True
|
||||
except ImportError:
|
||||
HAS_QDRANT = False
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Canonical collection naming scheme: "<prefix>_<type>_<dim>_<version>".
COLLECTION_PREFIX = "cm"
COLLECTION_VERSION = "v1"


def get_canonical_collection_name(
    collection_type: str = "text",
    dim: int = 1024,
    version: str = COLLECTION_VERSION,
) -> str:
    """
    Build the canonical collection name for a given embedding type.

    Args:
        collection_type: Kind of embeddings stored (text, code, mm).
        dim: Vector dimensionality.
        version: Schema version suffix.

    Returns:
        A name of the form "cm_text_1024_v1".
    """
    return "_".join((COLLECTION_PREFIX, collection_type, str(dim), version))
def get_distance_metric(metric: str) -> "Distance":
    """Map a metric name (case-insensitive) to a Qdrant Distance; defaults to cosine."""
    if not HAS_QDRANT:
        raise ImportError("qdrant-client not installed")

    lookup = {
        "cosine": Distance.COSINE,
        "dot": Distance.DOT,
        "euclidean": Distance.EUCLID,
    }
    # Unknown metric names silently fall back to cosine similarity.
    return lookup.get(metric.lower(), Distance.COSINE)
# Payload indexes created by default on canonical collections.  These back the
# access-control filters in filters.py (tenant isolation, ACL reads) and the
# common query axes (scope, visibility, tags).
DEFAULT_PAYLOAD_INDEXES = [
    # Tenancy / ownership axes
    {"field": "tenant_id", "type": "keyword"},
    {"field": "team_id", "type": "keyword"},
    {"field": "project_id", "type": "keyword"},
    {"field": "agent_id", "type": "keyword"},
    # Content classification
    {"field": "scope", "type": "keyword"},
    {"field": "visibility", "type": "keyword"},
    {"field": "indexed", "type": "bool"},
    {"field": "source_id", "type": "keyword"},
    {"field": "owner_kind", "type": "keyword"},
    {"field": "owner_id", "type": "keyword"},
    {"field": "tags", "type": "keyword"},
    # ACL grant lists
    {"field": "acl.read_team_ids", "type": "keyword"},
    {"field": "acl.read_agent_ids", "type": "keyword"},
    {"field": "acl.read_role_ids", "type": "keyword"},
]
def ensure_collection(
    client: "QdrantClient",
    name: str,
    dim: int,
    metric: str = "cosine",
    payload_indexes: Optional[List[Dict[str, str]]] = None,
    on_disk: bool = True,
) -> bool:
    """
    Make sure a canonical collection exists with its payload indexes.

    Args:
        client: Qdrant client instance.
        name: Collection name.
        dim: Vector dimension.
        metric: Distance metric (cosine, dot, euclidean).
        payload_indexes: Payload fields to index (defaults to DEFAULT_PAYLOAD_INDEXES).
        on_disk: Store vectors on disk rather than in RAM.

    Returns:
        True if the collection was created, False if it already existed.

    Raises:
        ImportError: If qdrant-client is not installed.
    """
    if not HAS_QDRANT:
        raise ImportError("qdrant-client not installed")

    indexes = payload_indexes or DEFAULT_PAYLOAD_INDEXES
    existing = {c.name for c in client.get_collections().collections}

    if name in existing:
        logger.info(f"Collection '{name}' already exists")
        # Index creation is idempotent: (re)apply even on the fast path.
        _ensure_payload_indexes(client, name, indexes)
        return False

    logger.info(f"Creating collection '{name}' with dim={dim}, metric={metric}")
    client.create_collection(
        collection_name=name,
        vectors_config=VectorParams(
            size=dim,
            distance=get_distance_metric(metric),
            on_disk=on_disk,
        ),
    )
    _ensure_payload_indexes(client, name, indexes)

    logger.info(f"Collection '{name}' created successfully")
    return True
def _ensure_payload_indexes(
    client: "QdrantClient",
    collection_name: str,
    indexes: List[Dict[str, str]]
) -> None:
    """
    Idempotently create payload indexes on a collection.

    Args:
        client: Qdrant client.
        collection_name: Target collection.
        indexes: Index configs, each {"field": <name>, "type": <type>}.
            Supported types: keyword, bool, integer, float, datetime, text.
            Unknown types are skipped with a warning.
    """
    if not HAS_QDRANT:
        return

    # Dispatch table replaces the previous if/elif chain; "text" is handled
    # separately because it needs full TextIndexParams, not a schema enum.
    simple_schemas = {
        "keyword": PayloadSchemaType.KEYWORD,
        "bool": PayloadSchemaType.BOOL,
        "integer": PayloadSchemaType.INTEGER,
        "float": PayloadSchemaType.FLOAT,
        "datetime": PayloadSchemaType.DATETIME,
    }

    for index_config in indexes:
        field_name = index_config["field"]
        field_type = index_config.get("type", "keyword")

        if field_type == "text":
            schema = TextIndexParams(
                type="text",
                tokenizer=TokenizerType.WORD,
                min_token_len=2,
                max_token_len=15,
            )
        else:
            schema = simple_schemas.get(field_type)
            if schema is None:
                # BUG FIX: the old chain fell through silently on unknown types
                # yet still logged "Created payload index". Skip explicitly.
                logger.warning(
                    f"Unknown payload index type '{field_type}' for {field_name}; skipped"
                )
                continue

        try:
            client.create_payload_index(
                collection_name=collection_name,
                field_name=field_name,
                field_schema=schema,
            )
            logger.debug(f"Created payload index: {field_name} ({field_type})")
        except Exception as e:
            # Index might already exist; only warn on genuine failures.
            if "already exists" not in str(e).lower():
                logger.warning(f"Failed to create index {field_name}: {e}")
def get_collection_info(client: "QdrantClient", name: str) -> Optional[Dict[str, Any]]:
    """
    Fetch summary information about a collection.

    Args:
        client: Qdrant client.
        name: Collection name.

    Returns:
        Info dict (counts, status, vector config) or None when the lookup fails.

    Raises:
        ImportError: If qdrant-client is not installed.
    """
    if not HAS_QDRANT:
        raise ImportError("qdrant-client not installed")

    try:
        info = client.get_collection(name)
        vectors = info.config.params.vectors
        return {
            "name": name,
            "vectors_count": info.vectors_count,
            "points_count": info.points_count,
            "status": info.status.value,
            "config": {
                "size": vectors.size,
                "distance": vectors.distance.value,
            },
        }
    except Exception:
        # Collection missing or client error: signal absence with None.
        return None
def list_legacy_collections(client: "QdrantClient") -> List[str]:
    """
    Return every collection whose name lacks the canonical "cm_" prefix.

    Args:
        client: Qdrant client.

    Returns:
        Names of legacy (non-canonical) collections.

    Raises:
        ImportError: If qdrant-client is not installed.
    """
    if not HAS_QDRANT:
        raise ImportError("qdrant-client not installed")

    prefix = f"{COLLECTION_PREFIX}_"
    return [
        col.name
        for col in client.get_collections().collections
        if not col.name.startswith(prefix)
    ]
541
services/memory/qdrant/filters.py
Normal file
541
services/memory/qdrant/filters.py
Normal file
@@ -0,0 +1,541 @@
|
||||
"""
|
||||
Qdrant Filter Builder for Co-Memory
|
||||
|
||||
Builds complex Qdrant filters based on access context and query requirements.
|
||||
|
||||
SECURITY INVARIANTS:
|
||||
- tenant_id is ALWAYS in must conditions
|
||||
- indexed=true is default in must conditions
|
||||
- Empty should list is NEVER returned (would match everything)
|
||||
- visibility=private ONLY accessible by owner, NEVER leaked to multi-agent queries
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
|
||||
class FilterSecurityError(Exception):
    """Raised when a requested filter would violate a security invariant."""
@dataclass
class AccessContext:
    """
    Identity of the caller ("who is asking") used to build
    access-controlled Qdrant filters.
    """
    # Mandatory tenant-isolation key.
    tenant_id: str
    # Optional caller identities.
    team_id: Optional[str] = None
    project_id: Optional[str] = None
    agent_id: Optional[str] = None
    user_id: Optional[str] = None
    role_ids: List[str] = field(default_factory=list)

    # Additional identities this caller is permitted to read for.
    allowed_agent_ids: List[str] = field(default_factory=list)
    allowed_team_ids: List[str] = field(default_factory=list)

    # Admin widens visibility; tenant isolation still applies upstream.
    is_admin: bool = False

    def has_identity(self) -> bool:
        """True when at least one identity (or admin flag) is available for access control."""
        return any((
            self.agent_id,
            self.user_id,
            self.team_id,
            self.allowed_agent_ids,
            self.is_admin,
        ))
def build_qdrant_filter(
    ctx: AccessContext,
    scope: Optional[str] = None,
    scopes: Optional[List[str]] = None,
    visibility: Optional[str] = None,
    visibilities: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    source_id: Optional[str] = None,
    channel_id: Optional[str] = None,
    indexed_only: bool = True,
    include_private: bool = False,
) -> Dict[str, Any]:
    """
    Build a Qdrant filter for canonical collection queries.

    Args:
        ctx: Access context (who is querying).
        scope: Single scope to filter (docs, messages, etc.); wins over `scopes`.
        scopes: Multiple scopes to filter.
        visibility: Single visibility filter.
        visibilities: Multiple visibilities to filter.
        tags: Tags to filter by; ALL listed tags must match (AND semantics).
        source_id: Specific source ID.
        channel_id: Channel/chat ID for messages.
        indexed_only: Only return indexed=true items.
        include_private: Include private items (owner only).

    Returns:
        Qdrant filter dictionary.

    Raises:
        FilterSecurityError: If the filter would be insecure (missing tenant,
            or an empty access clause that would match everything).
    """
    # SECURITY: tenant_id is ALWAYS required, even for admin.
    if not ctx.tenant_id:
        raise FilterSecurityError("tenant_id is required and cannot be None/empty")

    # INVARIANT: always filter by tenant (even for admin).
    must_conditions: List[Dict[str, Any]] = [
        {"key": "tenant_id", "match": {"value": ctx.tenant_id}},
    ]

    # INVARIANT: default to indexed items only.
    if indexed_only:
        must_conditions.append({"key": "indexed", "match": {"value": True}})

    # Scope filter (single scope takes precedence over a list).
    if scope:
        must_conditions.append({"key": "scope", "match": {"value": scope}})
    elif scopes:
        must_conditions.append({"key": "scope", "match": {"any": scopes}})

    if source_id:
        must_conditions.append({"key": "source_id", "match": {"value": source_id}})

    if channel_id:
        must_conditions.append({"key": "channel_id", "match": {"value": channel_id}})

    # FIX: tags were previously appended to the must list *after* the result
    # dict was assembled, working only through list aliasing.  Append them
    # before building the dict; one condition per tag gives AND behavior.
    if tags:
        for tag in tags:
            must_conditions.append({"key": "tags", "match": {"value": tag}})

    # Access-control OR clauses (visibility + ownership + ACL).
    access_should = _build_access_filter(
        ctx=ctx,
        visibility=visibility,
        visibilities=visibilities,
        include_private=include_private,
    )

    # SECURITY: an empty should clause would match every document.
    if not access_should and not ctx.is_admin:
        raise FilterSecurityError(
            "Access filter is empty - would return all documents. "
            "Context must have at least one identity (agent_id, user_id, team_id) "
            "or is_admin=True"
        )

    filter_dict: Dict[str, Any] = {"must": must_conditions}
    if access_should:
        filter_dict["should"] = access_should
        filter_dict["minimum_should_match"] = 1

    return filter_dict
def _build_access_filter(
    ctx: AccessContext,
    visibility: Optional[str] = None,
    visibilities: Optional[List[str]] = None,
    include_private: bool = False,
) -> List[Dict[str, Any]]:
    """
    Build the access-control "should" clauses (OR semantics) for a query.

    SECURITY: never returns an empty list unless ctx.is_admin; private
    content is reachable only through the owner-match clause.
    """
    clauses: List[Dict[str, Any]] = []

    # --- Admin path: tenant isolation still applies upstream.  By default
    # admins see public+confidential only; private requires an explicit
    # visibility request or include_private=True.
    if ctx.is_admin:
        if visibility:
            clauses.append({
                "must": [{"key": "visibility", "match": {"value": visibility}}]
            })
        elif visibilities:
            clauses.append({
                "must": [{"key": "visibility", "match": {"any": visibilities}}]
            })
        elif include_private:
            clauses.append({
                "must": [{"key": "visibility", "match": {"any": ["public", "confidential", "private"]}}]
            })
        else:
            clauses.append({
                "must": [{"key": "visibility", "match": {"any": ["public", "confidential"]}}]
            })
        return clauses

    # --- Resolve which visibilities this caller may request.
    requested: Set[str] = set()
    if visibility:
        requested.add(visibility)
    elif visibilities:
        requested.update(visibilities)
    else:
        requested = {"public", "confidential"}
        # SECURITY: private only when explicitly asked AND an owner identity exists.
        if include_private and (ctx.agent_id or ctx.user_id):
            requested.add("private")

    # 1. Public content within the caller's team.
    if "public" in requested and ctx.team_id:
        clauses.append({
            "must": [
                {"key": "team_id", "match": {"value": ctx.team_id}},
                {"key": "visibility", "match": {"value": "public"}},
            ]
        })

    # 2. Content the caller owns - the ONLY route to private items.
    for owner_kind, owner_id in (("agent", ctx.agent_id), ("user", ctx.user_id)):
        if not owner_id:
            continue
        own_vis = ["public", "confidential"]
        if "private" in requested:
            own_vis.append("private")
        clauses.append({
            "must": [
                {"key": "owner_kind", "match": {"value": owner_kind}},
                {"key": "owner_id", "match": {"value": owner_id}},
                {"key": "visibility", "match": {"any": own_vis}},
            ]
        })

    # 3. Confidential content granted via ACL (private is NEVER ACL-readable).
    if "confidential" in requested:
        acl_grants: List[tuple] = []
        if ctx.agent_id:
            acl_grants.append(("acl.read_agent_ids", ctx.agent_id))
        if ctx.team_id:
            acl_grants.append(("acl.read_team_ids", ctx.team_id))
        acl_grants.extend(("acl.read_role_ids", rid) for rid in ctx.role_ids)
        for acl_key, acl_value in acl_grants:
            clauses.append({
                "must": [
                    {"key": "visibility", "match": {"value": "confidential"}},
                    {"key": acl_key, "match": {"value": acl_value}},
                ]
            })

    # 4. Cross-agent reads via allowed_agent_ids (never includes private).
    if ctx.allowed_agent_ids:
        cross_vis = [v for v in ("public", "confidential") if v in requested]
        if cross_vis:
            clauses.append({
                "must": [
                    {"key": "agent_id", "match": {"any": ctx.allowed_agent_ids}},
                    {"key": "visibility", "match": {"any": cross_vis}},
                ]
            })

    return clauses
def build_agent_only_filter(
|
||||
ctx: AccessContext,
|
||||
agent_id: str,
|
||||
scope: Optional[str] = None,
|
||||
include_team_public: bool = True,
|
||||
include_own_public: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build filter for agent reading only its own content + optional team public.
|
||||
|
||||
Use case: Agent answering user, sees only own knowledge.
|
||||
|
||||
SECURITY: Private content is only accessible if agent_id matches owner_id.
|
||||
|
||||
Args:
|
||||
ctx: Access context
|
||||
agent_id: The agent requesting data (must match for private access)
|
||||
scope: Optional scope filter
|
||||
include_team_public: Include public content from team
|
||||
include_own_public: Include own public content (default True)
|
||||
"""
|
||||
# SECURITY: tenant_id always required
|
||||
if not ctx.tenant_id:
|
||||
raise FilterSecurityError("tenant_id is required and cannot be None/empty")
|
||||
if not agent_id:
|
||||
raise FilterSecurityError("agent_id is required for build_agent_only_filter")
|
||||
|
||||
must = [
|
||||
{"key": "tenant_id", "match": {"value": ctx.tenant_id}},
|
||||
{"key": "indexed", "match": {"value": True}}
|
||||
]
|
||||
|
||||
if scope:
|
||||
must.append({"key": "scope", "match": {"value": scope}})
|
||||
|
||||
should = []
|
||||
|
||||
# Own content: confidential + private (agent is owner)
|
||||
should.append({
|
||||
"must": [
|
||||
{"key": "owner_kind", "match": {"value": "agent"}},
|
||||
{"key": "owner_id", "match": {"value": agent_id}},
|
||||
{"key": "visibility", "match": {"any": ["confidential", "private"]}}
|
||||
]
|
||||
})
|
||||
|
||||
# Own public content (if enabled)
|
||||
if include_own_public:
|
||||
should.append({
|
||||
"must": [
|
||||
{"key": "agent_id", "match": {"value": agent_id}},
|
||||
{"key": "visibility", "match": {"value": "public"}}
|
||||
]
|
||||
})
|
||||
|
||||
# Team public content (if enabled and team exists)
|
||||
if include_team_public and ctx.team_id:
|
||||
should.append({
|
||||
"must": [
|
||||
{"key": "team_id", "match": {"value": ctx.team_id}},
|
||||
{"key": "visibility", "match": {"value": "public"}}
|
||||
]
|
||||
})
|
||||
|
||||
# SECURITY: Ensure should is never empty
|
||||
if not should:
|
||||
raise FilterSecurityError("Filter would have empty should clause")
|
||||
|
||||
return {
|
||||
"must": must,
|
||||
"should": should,
|
||||
"minimum_should_match": 1
|
||||
}
|
||||
|
||||
|
||||
def build_multi_agent_filter(
|
||||
ctx: AccessContext,
|
||||
agent_ids: List[str],
|
||||
scope: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build filter for reading from multiple agents (aggregator use case).
|
||||
|
||||
Use case: DAARWIZZ orchestrator needs cross-agent retrieval.
|
||||
|
||||
SECURITY INVARIANTS:
|
||||
- Private content is ALWAYS excluded (hard-coded must_not)
|
||||
- Only public + confidential from specified agents
|
||||
- No parameter to override this (by design)
|
||||
- should blocks NEVER include visibility=private
|
||||
- agent_ids MUST be subset of ctx.allowed_agent_ids (unless admin)
|
||||
|
||||
Args:
|
||||
ctx: Access context (must have allowed_agent_ids or is_admin=True)
|
||||
agent_ids: List of agents to read from (must be allowed)
|
||||
scope: Optional scope filter
|
||||
|
||||
Raises:
|
||||
FilterSecurityError: If agent_ids contains unauthorized agents
|
||||
"""
|
||||
# SECURITY: tenant_id always required
|
||||
if not ctx.tenant_id:
|
||||
raise FilterSecurityError("tenant_id is required and cannot be None/empty")
|
||||
if not agent_ids:
|
||||
raise FilterSecurityError("agent_ids cannot be empty for build_multi_agent_filter")
|
||||
|
||||
# SECURITY: Validate agent_ids are allowed for this context
|
||||
if not ctx.is_admin:
|
||||
if not ctx.allowed_agent_ids:
|
||||
raise FilterSecurityError(
|
||||
"build_multi_agent_filter requires ctx.allowed_agent_ids or ctx.is_admin=True. "
|
||||
"Cannot access arbitrary agents without explicit permission."
|
||||
)
|
||||
|
||||
requested_set = set(agent_ids)
|
||||
allowed_set = set(ctx.allowed_agent_ids)
|
||||
unauthorized = requested_set - allowed_set
|
||||
|
||||
if unauthorized:
|
||||
raise FilterSecurityError(
|
||||
f"Unauthorized agent access: {sorted(unauthorized)}. "
|
||||
f"Allowed agents: {sorted(allowed_set)}"
|
||||
)
|
||||
|
||||
must = [
|
||||
{"key": "tenant_id", "match": {"value": ctx.tenant_id}},
|
||||
{"key": "indexed", "match": {"value": True}}
|
||||
]
|
||||
|
||||
if scope:
|
||||
must.append({"key": "scope", "match": {"value": scope}})
|
||||
|
||||
should = [
|
||||
# Content from allowed agents (public + confidential ONLY)
|
||||
{
|
||||
"must": [
|
||||
{"key": "agent_id", "match": {"any": agent_ids}},
|
||||
{"key": "visibility", "match": {"any": ["public", "confidential"]}}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Team public content
|
||||
if ctx.team_id:
|
||||
should.append({
|
||||
"must": [
|
||||
{"key": "team_id", "match": {"value": ctx.team_id}},
|
||||
{"key": "visibility", "match": {"value": "public"}}
|
||||
]
|
||||
})
|
||||
|
||||
return {
|
||||
"must": must,
|
||||
"should": should,
|
||||
"minimum_should_match": 1,
|
||||
# SECURITY: ALWAYS exclude private - no parameter to override
|
||||
"must_not": [
|
||||
{"key": "visibility", "match": {"value": "private"}}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def build_project_filter(
|
||||
ctx: AccessContext,
|
||||
project_id: str,
|
||||
scope: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build filter for project-scoped retrieval.
|
||||
|
||||
Use case: Reading only within one project.
|
||||
|
||||
SECURITY: Private is always excluded.
|
||||
"""
|
||||
# SECURITY: tenant_id always required
|
||||
if not ctx.tenant_id:
|
||||
raise FilterSecurityError("tenant_id is required and cannot be None/empty")
|
||||
if not project_id:
|
||||
raise FilterSecurityError("project_id is required for build_project_filter")
|
||||
|
||||
must = [
|
||||
{"key": "tenant_id", "match": {"value": ctx.tenant_id}},
|
||||
{"key": "project_id", "match": {"value": project_id}},
|
||||
{"key": "indexed", "match": {"value": True}},
|
||||
# SECURITY: Only public + confidential, never private
|
||||
{"key": "visibility", "match": {"any": ["public", "confidential"]}}
|
||||
]
|
||||
|
||||
if scope:
|
||||
must.append({"key": "scope", "match": {"value": scope}})
|
||||
|
||||
return {
|
||||
"must": must,
|
||||
"must_not": [
|
||||
{"key": "visibility", "match": {"value": "private"}}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def build_tag_filter(
|
||||
ctx: AccessContext,
|
||||
tags: List[str],
|
||||
scope: str = "docs",
|
||||
visibility: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build filter for tag-based retrieval.
|
||||
|
||||
Use case: Druid legal KB only, Nutra food knowledge.
|
||||
|
||||
SECURITY: Private is excluded by default.
|
||||
"""
|
||||
# SECURITY: tenant_id always required
|
||||
if not ctx.tenant_id:
|
||||
raise FilterSecurityError("tenant_id is required and cannot be None/empty")
|
||||
if not tags:
|
||||
raise FilterSecurityError("tags cannot be empty for build_tag_filter")
|
||||
|
||||
must = [
|
||||
{"key": "tenant_id", "match": {"value": ctx.tenant_id}},
|
||||
{"key": "scope", "match": {"value": scope}},
|
||||
{"key": "indexed", "match": {"value": True}},
|
||||
]
|
||||
|
||||
# Visibility: default to public+confidential, never private unless explicitly specified
|
||||
if visibility:
|
||||
must.append({"key": "visibility", "match": {"value": visibility}})
|
||||
else:
|
||||
must.append({"key": "visibility", "match": {"any": ["public", "confidential"]}})
|
||||
|
||||
# Add tag conditions (all must match)
|
||||
for tag in tags:
|
||||
must.append({"key": "tags", "match": {"value": tag}})
|
||||
|
||||
return {"must": must}
|
||||
283
services/memory/qdrant/payload_validation.py
Normal file
283
services/memory/qdrant/payload_validation.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Payload Validation for Co-Memory Qdrant
|
||||
|
||||
Validates payloads against cm_payload_v1 schema before upsert.
|
||||
"""
|
||||
|
||||
import json
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Try to use jsonschema if available, otherwise use manual validation
|
||||
try:
|
||||
import jsonschema
|
||||
HAS_JSONSCHEMA = True
|
||||
except ImportError:
|
||||
HAS_JSONSCHEMA = False
|
||||
|
||||
|
||||
class PayloadValidationError(Exception):
    """Signals that a payload failed cm_payload_v1 validation.

    Attributes:
        errors: Individual validation error messages (may be empty).
    """

    def __init__(self, message: str, errors: Optional[List[str]] = None):
        super().__init__(message)
        # Keep per-field messages available to callers for error reporting.
        self.errors = errors if errors else []
|
||||
|
||||
|
||||
# Enums — allowed values mirrored from the cm_payload_v1 schema.
VALID_SCOPES = {"docs", "messages", "memory", "artifacts", "signals"}
VALID_VISIBILITY = {"public", "confidential", "private"}
VALID_OWNER_KINDS = {"user", "team", "agent"}
VALID_SOURCE_KINDS = {"document", "wiki", "message", "artifact", "web", "code"}
VALID_METRICS = {"cosine", "dot", "euclidean"}

# ID patterns — canonical prefixed ID formats,
# e.g. "t_daarion", "team_core", "proj_helion", "agt_helion", "doc_123", "chk_abc".
TENANT_ID_PATTERN = re.compile(r"^t_[a-z0-9_]+$")
TEAM_ID_PATTERN = re.compile(r"^team_[a-z0-9_]+$")
PROJECT_ID_PATTERN = re.compile(r"^proj_[a-z0-9_]+$")
AGENT_ID_PATTERN = re.compile(r"^agt_[a-z0-9_]+$")
SOURCE_ID_PATTERN = re.compile(r"^(doc|msg|art|web|code)_[A-Za-z0-9]+$")
CHUNK_ID_PATTERN = re.compile(r"^chk_[A-Za-z0-9]+$")
|
||||
|
||||
|
||||
def _load_json_schema() -> Optional[Dict]:
|
||||
"""Load JSON schema from file if available."""
|
||||
schema_path = Path(__file__).parent.parent.parent.parent / "docs" / "memory" / "cm_payload_v1.schema.json"
|
||||
if schema_path.exists():
|
||||
with open(schema_path) as f:
|
||||
return json.load(f)
|
||||
return None
|
||||
|
||||
|
||||
_SCHEMA = _load_json_schema()
|
||||
|
||||
|
||||
def validate_payload(payload: Dict[str, Any], strict: bool = True) -> Dict[str, Any]:
    """
    Validate *payload* against the cm_payload_v1 schema.

    Prefers full jsonschema validation when both the library and the schema
    file are available; otherwise falls back to the manual field checks.

    Args:
        payload: Payload dictionary to validate.
        strict: When True (default), raise on any validation error.

    Returns:
        The payload unchanged (validation does not mutate it).

    Raises:
        PayloadValidationError: If validation fails and strict=True.
    """
    errors: List[str] = []

    if HAS_JSONSCHEMA and _SCHEMA:
        # jsonschema raises on the first violation; record it as one error.
        try:
            jsonschema.validate(payload, _SCHEMA)
        except jsonschema.ValidationError as exc:
            errors.append(f"Schema validation: {exc.message}")
    else:
        # Fallback: hand-rolled checks for required fields and value formats.
        errors += _validate_required_fields(payload)
        errors += _validate_field_values(payload)

    if strict and errors:
        raise PayloadValidationError(
            f"Payload validation failed: {len(errors)} error(s)",
            errors=errors,
        )

    return payload
|
||||
|
||||
|
||||
def _validate_required_fields(payload: Dict[str, Any]) -> List[str]:
|
||||
"""Validate required fields are present."""
|
||||
errors = []
|
||||
|
||||
required = [
|
||||
"schema_version",
|
||||
"tenant_id",
|
||||
"owner_kind",
|
||||
"owner_id",
|
||||
"scope",
|
||||
"visibility",
|
||||
"indexed",
|
||||
"source_kind",
|
||||
"source_id",
|
||||
"chunk",
|
||||
"fingerprint",
|
||||
"created_at",
|
||||
]
|
||||
|
||||
for field in required:
|
||||
if field not in payload:
|
||||
errors.append(f"Missing required field: {field}")
|
||||
|
||||
# Check nested required fields
|
||||
if "chunk" in payload and isinstance(payload["chunk"], dict):
|
||||
if "chunk_id" not in payload["chunk"]:
|
||||
errors.append("Missing required field: chunk.chunk_id")
|
||||
if "chunk_idx" not in payload["chunk"]:
|
||||
errors.append("Missing required field: chunk.chunk_idx")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def _validate_field_values(payload: Dict[str, Any]) -> List[str]:
    """Validate field values match expected formats.

    Collects every violation instead of stopping at the first one, so
    callers can report all problems at once. Returns an empty list when
    all present fields are well-formed.

    NOTE: presence of required fields is checked separately by
    _validate_required_fields; this function tolerates absent fields
    (optional-field checks use truthiness / ``is not None`` guards).
    """
    errors = []

    # Schema version
    if payload.get("schema_version") != "cm_payload_v1":
        errors.append(f"Invalid schema_version: {payload.get('schema_version')}, expected 'cm_payload_v1'")

    # Tenant ID
    # NOTE(review): the `if tenant_id and ...` guard means an empty-string ID
    # is not flagged here — presumably caught as "missing" upstream; verify.
    tenant_id = payload.get("tenant_id")
    if tenant_id and not TENANT_ID_PATTERN.match(tenant_id):
        errors.append(f"Invalid tenant_id format: {tenant_id}")

    # Team ID (optional)
    team_id = payload.get("team_id")
    if team_id and not TEAM_ID_PATTERN.match(team_id):
        errors.append(f"Invalid team_id format: {team_id}")

    # Project ID (optional)
    project_id = payload.get("project_id")
    if project_id and not PROJECT_ID_PATTERN.match(project_id):
        errors.append(f"Invalid project_id format: {project_id}")

    # Agent ID (optional)
    agent_id = payload.get("agent_id")
    if agent_id and not AGENT_ID_PATTERN.match(agent_id):
        errors.append(f"Invalid agent_id format: {agent_id}")

    # Scope
    if payload.get("scope") not in VALID_SCOPES:
        errors.append(f"Invalid scope: {payload.get('scope')}, valid: {VALID_SCOPES}")

    # Visibility
    if payload.get("visibility") not in VALID_VISIBILITY:
        errors.append(f"Invalid visibility: {payload.get('visibility')}, valid: {VALID_VISIBILITY}")

    # Owner kind
    if payload.get("owner_kind") not in VALID_OWNER_KINDS:
        errors.append(f"Invalid owner_kind: {payload.get('owner_kind')}, valid: {VALID_OWNER_KINDS}")

    # Source kind
    if payload.get("source_kind") not in VALID_SOURCE_KINDS:
        errors.append(f"Invalid source_kind: {payload.get('source_kind')}, valid: {VALID_SOURCE_KINDS}")

    # Source ID
    source_id = payload.get("source_id")
    if source_id and not SOURCE_ID_PATTERN.match(source_id):
        errors.append(f"Invalid source_id format: {source_id}")

    # Chunk — only inspected when it is a dict; a non-dict chunk is
    # silently skipped here (required-field validation handles presence).
    chunk = payload.get("chunk", {})
    if isinstance(chunk, dict):
        chunk_id = chunk.get("chunk_id")
        if chunk_id and not CHUNK_ID_PATTERN.match(chunk_id):
            errors.append(f"Invalid chunk.chunk_id format: {chunk_id}")

        chunk_idx = chunk.get("chunk_idx")
        if chunk_idx is not None and (not isinstance(chunk_idx, int) or chunk_idx < 0):
            errors.append(f"Invalid chunk.chunk_idx: {chunk_idx}, must be non-negative integer")

    # Indexed — must be a real boolean, not merely truthy.
    if not isinstance(payload.get("indexed"), bool):
        errors.append(f"Invalid indexed: {payload.get('indexed')}, must be boolean")

    # Created at — accepts ISO 8601; "Z" suffix is normalized to "+00:00"
    # because datetime.fromisoformat rejects "Z" on older Pythons.
    created_at = payload.get("created_at")
    if created_at:
        try:
            datetime.fromisoformat(created_at.replace("Z", "+00:00"))
        except (ValueError, AttributeError):
            errors.append(f"Invalid created_at format: {created_at}, expected ISO 8601")

    # Embedding (optional)
    embedding = payload.get("embedding", {})
    if isinstance(embedding, dict):
        if "metric" in embedding and embedding["metric"] not in VALID_METRICS:
            errors.append(f"Invalid embedding.metric: {embedding['metric']}, valid: {VALID_METRICS}")
        if "dim" in embedding and (not isinstance(embedding["dim"], int) or embedding["dim"] < 1):
            errors.append(f"Invalid embedding.dim: {embedding['dim']}, must be positive integer")

    # Importance (optional)
    # NOTE(review): isinstance(True, int) is True, so booleans pass this
    # 0-1 range check — confirm whether that is intended.
    importance = payload.get("importance")
    if importance is not None and (not isinstance(importance, (int, float)) or importance < 0 or importance > 1):
        errors.append(f"Invalid importance: {importance}, must be 0-1")

    # TTL days (optional)
    ttl_days = payload.get("ttl_days")
    if ttl_days is not None and (not isinstance(ttl_days, int) or ttl_days < 1):
        errors.append(f"Invalid ttl_days: {ttl_days}, must be positive integer")

    # ACL fields (must be arrays of non-empty strings, no nulls)
    acl = payload.get("acl", {})
    if isinstance(acl, dict):
        for acl_field in ["read_team_ids", "read_agent_ids", "read_role_ids"]:
            value = acl.get(acl_field)
            if value is not None:
                if not isinstance(value, list):
                    errors.append(f"Invalid acl.{acl_field}: must be array, got {type(value).__name__}")
                elif not all(isinstance(item, str) and item for item in value):
                    # Check: all items must be non-empty strings (no None, no "")
                    errors.append(f"Invalid acl.{acl_field}: all items must be non-empty strings (no null/empty)")
    elif acl is not None:
        # NOTE(review): an explicit `"acl": None` takes neither branch
        # (get returns None, and the elif compares None is not None), so a
        # null acl is silently accepted — confirm this is intended.
        errors.append(f"Invalid acl: must be object, got {type(acl).__name__}")

    # Tags must be array of non-empty strings
    tags = payload.get("tags")
    if tags is not None:
        if not isinstance(tags, list):
            errors.append(f"Invalid tags: must be array, got {type(tags).__name__}")
        elif not all(isinstance(item, str) and item for item in tags):
            errors.append(f"Invalid tags: all items must be non-empty strings (no null/empty)")

    return errors
|
||||
|
||||
|
||||
def create_minimal_payload(
    tenant_id: str,
    source_id: str,
    chunk_id: str,
    chunk_idx: int,
    fingerprint: str,
    scope: str = "docs",
    visibility: str = "confidential",
    owner_kind: str = "team",
    owner_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Create a minimal valid cm_payload_v1 payload with all required fields.

    Args:
        tenant_id: Tenant the payload belongs to.
        source_id: Source identifier, e.g. "doc_123".
        chunk_id: Chunk identifier, e.g. "chk_abc".
        chunk_idx: Zero-based chunk index within the source.
        fingerprint: Content fingerprint.
        scope: Payload scope (default "docs").
        visibility: Payload visibility (default "confidential").
        owner_kind: Kind of owner (default "team").
        owner_id: Explicit owner id; falls back to team_id, then tenant_id.
        agent_id: Optional agent attribution.
        team_id: Optional team attribution.
        **kwargs: Extra fields merged into the payload; may override the
            defaults above.

    Returns:
        A payload dict that passes validate_payload().
    """
    # Fix: datetime.utcnow() is deprecated (naive UTC). Use an aware "now"
    # and strip tzinfo so the original `...Z` wire format is preserved.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)

    payload = {
        "schema_version": "cm_payload_v1",
        "tenant_id": tenant_id,
        "team_id": team_id,
        "agent_id": agent_id,
        "owner_kind": owner_kind,
        # Owner precedence: explicit owner, else the team, else the tenant.
        "owner_id": owner_id or team_id or tenant_id,
        "scope": scope,
        "visibility": visibility,
        "indexed": True,
        "source_kind": "document",
        "source_id": source_id,
        "chunk": {
            "chunk_id": chunk_id,
            "chunk_idx": chunk_idx,
        },
        "fingerprint": fingerprint,
        "created_at": now_utc.isoformat() + "Z",
    }

    # Add optional fields supplied by the caller (intentional overrides OK).
    payload.update(kwargs)

    return payload
|
||||
1
services/memory/qdrant/tests/__init__.py
Normal file
1
services/memory/qdrant/tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Tests for Co-Memory Qdrant module
|
||||
446
services/memory/qdrant/tests/test_filters.py
Normal file
446
services/memory/qdrant/tests/test_filters.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
Unit tests for Co-Memory Qdrant filters.
|
||||
|
||||
Tests security invariants:
|
||||
- Empty should never returned (would match everything)
|
||||
- Private content only accessible by owner
|
||||
- tenant_id always in must conditions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from services.memory.qdrant.filters import (
|
||||
AccessContext,
|
||||
FilterSecurityError,
|
||||
build_qdrant_filter,
|
||||
build_agent_only_filter,
|
||||
build_multi_agent_filter,
|
||||
build_project_filter,
|
||||
build_tag_filter,
|
||||
)
|
||||
|
||||
|
||||
class TestAccessContext:
    """Tests for AccessContext dataclass construction and defaults."""

    def test_minimal_context(self):
        """Test creating minimal access context."""
        ctx = AccessContext(tenant_id="t_daarion")
        assert ctx.tenant_id == "t_daarion"
        # Optional identity fields default to None / False.
        assert ctx.team_id is None
        assert ctx.agent_id is None
        assert ctx.is_admin is False

    def test_full_context(self):
        """Test creating full access context."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            project_id="proj_helion",
            agent_id="agt_helion",
            user_id="user_123",
            role_ids=["role_admin"],
            allowed_agent_ids=["agt_nutra", "agt_druid"],
            is_admin=True,
        )
        assert ctx.tenant_id == "t_daarion"
        assert ctx.team_id == "team_core"
        assert ctx.agent_id == "agt_helion"
        assert len(ctx.allowed_agent_ids) == 2
|
||||
|
||||
|
||||
class TestBuildQdrantFilter:
    """Tests for build_qdrant_filter function (general-access filter)."""

    def test_basic_filter_with_identity(self):
        """Test basic filter with tenant and identity."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")

        result = build_qdrant_filter(ctx)

        assert "must" in result
        # Check tenant_id is in must conditions
        tenant_condition = next(
            (c for c in result["must"] if c.get("key") == "tenant_id"),
            None
        )
        assert tenant_condition is not None
        assert tenant_condition["match"]["value"] == "t_daarion"

    def test_empty_context_raises_security_error(self):
        """Test that context without identity raises FilterSecurityError."""
        ctx = AccessContext(tenant_id="t_daarion")
        # No team_id, agent_id, user_id, or allowed_agent_ids —
        # the builder must refuse rather than emit an empty should clause.

        with pytest.raises(FilterSecurityError) as exc_info:
            build_qdrant_filter(ctx)

        assert "empty" in str(exc_info.value).lower()

    def test_admin_bypasses_identity_check(self):
        """Test that admin context doesn't require identity."""
        ctx = AccessContext(tenant_id="t_daarion", is_admin=True)

        # Should not raise
        result = build_qdrant_filter(ctx)
        assert "must" in result

    def test_indexed_only_filter(self):
        """Test that indexed=true is included by default."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")

        result = build_qdrant_filter(ctx, indexed_only=True)

        indexed_condition = next(
            (c for c in result["must"] if c.get("key") == "indexed"),
            None
        )
        assert indexed_condition is not None
        assert indexed_condition["match"]["value"] is True

    def test_scope_filter(self):
        """Test single scope filter."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")

        result = build_qdrant_filter(ctx, scope="docs")

        scope_condition = next(
            (c for c in result["must"] if c.get("key") == "scope"),
            None
        )
        assert scope_condition is not None
        assert scope_condition["match"]["value"] == "docs"

    def test_multiple_scopes_filter(self):
        """Test multiple scopes filter."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")

        result = build_qdrant_filter(ctx, scopes=["docs", "messages"])

        scope_condition = next(
            (c for c in result["must"] if c.get("key") == "scope"),
            None
        )
        assert scope_condition is not None
        # Multiple scopes are expressed as a single "any" match.
        assert scope_condition["match"]["any"] == ["docs", "messages"]

    def test_source_id_filter(self):
        """Test source_id filter."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")

        result = build_qdrant_filter(ctx, source_id="doc_123")

        source_condition = next(
            (c for c in result["must"] if c.get("key") == "source_id"),
            None
        )
        assert source_condition is not None
        assert source_condition["match"]["value"] == "doc_123"

    def test_admin_access(self):
        """Test admin has broader access."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            is_admin=True,
        )

        result = build_qdrant_filter(ctx, visibility="private")

        # Admin should get simpler filter.
        # NOTE(review): this only asserts the filter exists, not its shape —
        # consider tightening once admin filter semantics are pinned down.
        assert "must" in result
|
||||
|
||||
|
||||
class TestSecurityInvariants:
    """Tests for security invariants in filters."""

    def test_private_not_in_multi_agent_filter(self):
        """Test that multi-agent filter ALWAYS excludes private."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            allowed_agent_ids=["agt_helion", "agt_nutra"],  # Must have permission
        )

        result = build_multi_agent_filter(
            ctx,
            agent_ids=["agt_helion", "agt_nutra"],
        )

        # Must have must_not with private
        assert "must_not" in result
        private_exclusion = next(
            (c for c in result["must_not"]
             if c.get("key") == "visibility" and c.get("match", {}).get("value") == "private"),
            None
        )
        assert private_exclusion is not None

    def test_empty_agent_ids_raises_error(self):
        """Test that empty agent_ids raises error."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            allowed_agent_ids=["agt_helion"]
        )

        with pytest.raises(FilterSecurityError):
            build_multi_agent_filter(ctx, agent_ids=[])

    def test_unauthorized_agent_ids_raises_error(self):
        """Test that requesting unauthorized agent_ids raises error."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            allowed_agent_ids=["agt_helion", "agt_nutra"]
        )

        # Try to access agent not in allowed list
        with pytest.raises(FilterSecurityError) as exc_info:
            build_multi_agent_filter(ctx, agent_ids=["agt_helion", "agt_druid"])

        assert "agt_druid" in str(exc_info.value)
        assert "Unauthorized" in str(exc_info.value)

    def test_multi_agent_requires_allowed_list_or_admin(self):
        """Test that multi-agent filter requires allowed_agent_ids or admin."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            # No allowed_agent_ids, not admin
        )

        with pytest.raises(FilterSecurityError) as exc_info:
            build_multi_agent_filter(ctx, agent_ids=["agt_helion"])

        assert "allowed_agent_ids" in str(exc_info.value)

    def test_admin_can_access_any_agents(self):
        """Test that admin can access any agent without allowed list."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            is_admin=True,
            # No allowed_agent_ids needed for admin
        )

        # Should not raise
        result = build_multi_agent_filter(ctx, agent_ids=["agt_helion", "agt_druid", "agt_nutra"])
        assert "must" in result
        assert "should" in result

    def test_empty_agent_id_for_agent_only_raises_error(self):
        """Test that empty agent_id raises error for agent_only filter."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")

        with pytest.raises(FilterSecurityError):
            build_agent_only_filter(ctx, agent_id="")

    def test_tenant_id_always_present(self):
        """Test that tenant_id is always in must conditions."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            agent_id="agt_helion",
            # Fix: build_multi_agent_filter below raises FilterSecurityError
            # for a non-admin context without allowed_agent_ids, which made
            # this test fail before reaching its assertions.
            allowed_agent_ids=["agt_helion"],
        )

        # Test all filter builders
        filters = [
            build_qdrant_filter(ctx),
            build_agent_only_filter(ctx, agent_id="agt_helion"),
            build_multi_agent_filter(ctx, agent_ids=["agt_helion"]),
            build_project_filter(ctx, project_id="proj_helion"),
            build_tag_filter(ctx, tags=["test"]),
        ]

        for f in filters:
            tenant_condition = next(
                (c for c in f["must"] if c.get("key") == "tenant_id"),
                None
            )
            assert tenant_condition is not None, f"tenant_id missing in {f}"

    def test_private_only_for_owner(self):
        """Test that private is only accessible when owner matches."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            agent_id="agt_helion",
        )

        result = build_agent_only_filter(ctx, agent_id="agt_helion")

        # Should have owner check with private
        should_conditions = result.get("should", [])

        # Find the condition that allows private
        private_allowed = False
        for cond in should_conditions:
            must = cond.get("must", [])
            has_private = any(
                c.get("key") == "visibility" and "private" in str(c.get("match", {}))
                for c in must
            )
            has_owner_check = any(
                c.get("key") == "owner_id" and c.get("match", {}).get("value") == "agt_helion"
                for c in must
            )
            if has_private and has_owner_check:
                private_allowed = True
                break

        assert private_allowed, "Private should only be allowed with owner check"
|
||||
|
||||
|
||||
class TestBuildAgentOnlyFilter:
    """Tests for build_agent_only_filter function."""

    def test_agent_own_content(self):
        """Test filter for agent's own content."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
        )

        result = build_agent_only_filter(ctx, agent_id="agt_helion")

        assert "must" in result
        assert "should" in result
        assert result.get("minimum_should_match") == 1

        # Check own content condition exists
        own_content = result["should"][0]
        assert "must" in own_content

    def test_agent_with_team_public(self):
        """Test filter includes team public when enabled."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
        )

        result = build_agent_only_filter(
            ctx,
            agent_id="agt_helion",
            include_team_public=True,
            # Fix: include_own_public defaults to True, adding a third
            # should clause (own public) that made the count below fail.
            include_own_public=False,
        )

        # Should have 2 should conditions: own + team public
        assert len(result["should"]) == 2

    def test_agent_without_team_public(self):
        """Test filter excludes team public when disabled."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
        )

        result = build_agent_only_filter(
            ctx,
            agent_id="agt_helion",
            include_team_public=False,
            # Fix: disable own-public too, so only the own-content clause
            # remains, matching the assertion below.
            include_own_public=False,
        )

        # Should have only 1 should condition: own
        assert len(result["should"]) == 1
|
||||
|
||||
|
||||
class TestBuildMultiAgentFilter:
    """Tests for build_multi_agent_filter function (aggregator access)."""

    def test_multi_agent_access(self):
        """Test filter for accessing multiple agents."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            allowed_agent_ids=["agt_helion", "agt_nutra", "agt_druid"],  # Must have permission
        )

        result = build_multi_agent_filter(
            ctx,
            agent_ids=["agt_helion", "agt_nutra", "agt_druid"],
        )

        assert "must" in result
        assert "should" in result

        # Check agent_ids are in any match
        agent_condition = result["should"][0]["must"][0]
        assert agent_condition["key"] == "agent_id"
        assert "any" in agent_condition["match"]
        assert len(agent_condition["match"]["any"]) == 3

    def test_multi_agent_excludes_private(self):
        """Test filter excludes private always (no parameter to override)."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            allowed_agent_ids=["agt_helion"],
        )

        result = build_multi_agent_filter(
            ctx,
            agent_ids=["agt_helion"],
        )

        # must_not is always present with private exclusion
        assert "must_not" in result
        private_exclusion = result["must_not"][0]
        assert private_exclusion["key"] == "visibility"
        assert private_exclusion["match"]["value"] == "private"
|
||||
|
||||
|
||||
class TestBuildProjectFilter:
    """Tests for build_project_filter function (project-scoped access)."""

    def test_project_scoped(self):
        """Test project-scoped filter."""
        ctx = AccessContext(tenant_id="t_daarion")

        result = build_project_filter(ctx, project_id="proj_helion")

        project_condition = next(
            (c for c in result["must"] if c.get("key") == "project_id"),
            None
        )
        assert project_condition is not None
        assert project_condition["match"]["value"] == "proj_helion"

    def test_project_excludes_private(self):
        """Test project filter excludes private."""
        ctx = AccessContext(tenant_id="t_daarion")

        result = build_project_filter(ctx, project_id="proj_helion")

        # The private exclusion lives in must_not (see build_project_filter).
        assert "must_not" in result
|
||||
|
||||
|
||||
class TestBuildTagFilter:
    """Unit tests covering build_tag_filter."""

    def test_single_tag(self):
        """A single tag surfaces as one exact-match must condition."""
        ctx = AccessContext(tenant_id="t_daarion")

        result = build_tag_filter(ctx, tags=["legal_kb"])

        found = [c for c in result["must"] if c.get("key") == "tags"]
        assert found
        assert found[0]["match"]["value"] == "legal_kb"

    def test_multiple_tags(self):
        """Each requested tag becomes its own must condition (AND semantics)."""
        ctx = AccessContext(tenant_id="t_daarion")

        result = build_tag_filter(ctx, tags=["legal_kb", "contracts"])

        # One must condition per tag — every tag has to match.
        found = [c for c in result["must"] if c.get("key") == "tags"]
        assert len(found) == 2
# Allow running this test module directly (outside a plain `pytest` invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
282
services/memory/qdrant/tests/test_payload_validation.py
Normal file
282
services/memory/qdrant/tests/test_payload_validation.py
Normal file
@@ -0,0 +1,282 @@
|
||||
"""
|
||||
Unit tests for Co-Memory payload validation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from services.memory.qdrant.payload_validation import (
|
||||
validate_payload,
|
||||
PayloadValidationError,
|
||||
create_minimal_payload,
|
||||
)
|
||||
|
||||
|
||||
class TestPayloadValidation:
    """Unit tests covering validate_payload and create_minimal_payload."""

    @staticmethod
    def _payload(**overrides):
        """Build a minimal well-formed payload, forwarding overrides to the factory.

        Centralises the boilerplate ids that every test case below shares.
        """
        kwargs = dict(
            tenant_id="t_daarion",
            source_id="doc_01HQ8K9X2NPQR3FGJKLM5678",
            chunk_id="chk_01HQ8K9X3MPQR3FGJKLM9012",
            chunk_idx=0,
            fingerprint="sha256:abc123",
        )
        kwargs.update(overrides)
        return create_minimal_payload(**kwargs)

    def test_valid_minimal_payload(self):
        """A minimal factory payload passes validation unchanged."""
        payload = self._payload(team_id="team_core")

        # Must not raise.
        validated = validate_payload(payload)
        assert validated["schema_version"] == "cm_payload_v1"
        assert validated["tenant_id"] == "t_daarion"

    def test_valid_full_payload(self):
        """A payload populated with every optional field passes validation."""
        payload = {
            "schema_version": "cm_payload_v1",
            "tenant_id": "t_daarion",
            "team_id": "team_core",
            "project_id": "proj_helion",
            "agent_id": "agt_helion",
            "owner_kind": "agent",
            "owner_id": "agt_helion",
            "scope": "docs",
            "visibility": "confidential",
            "indexed": True,
            "source_kind": "document",
            "source_id": "doc_01HQ8K9X2NPQR3FGJKLM5678",
            "chunk": {
                "chunk_id": "chk_01HQ8K9X3MPQR3FGJKLM9012",
                "chunk_idx": 0,
            },
            "fingerprint": "sha256:abc123def456",
            "created_at": "2026-01-26T12:00:00Z",
            "acl": {
                "read_team_ids": ["team_core"],
                "read_agent_ids": ["agt_nutra"],
            },
            "tags": ["product", "features"],
            "lang": "uk",
            "importance": 0.8,
            "embedding": {
                "model": "cohere-embed-v3",
                "dim": 1024,
                "metric": "cosine",
            },
        }

        validated = validate_payload(payload)
        assert validated["agent_id"] == "agt_helion"
        assert validated["scope"] == "docs"

    def test_missing_required_field(self):
        """Dropping required fields must raise PayloadValidationError."""
        incomplete = {
            "schema_version": "cm_payload_v1",
            "tenant_id": "t_daarion",
            # every other required field deliberately absent
        }

        with pytest.raises(PayloadValidationError) as excinfo:
            validate_payload(incomplete)

        assert "Missing required field" in str(excinfo.value)

    def test_invalid_schema_version(self):
        """An unrecognised schema version must be rejected."""
        payload = self._payload()
        payload["schema_version"] = "v2"  # not a supported version string

        with pytest.raises(PayloadValidationError) as excinfo:
            validate_payload(payload)

        assert "schema_version" in str(excinfo.value)

    def test_invalid_tenant_id_format(self):
        """A tenant id without the expected format must be rejected."""
        payload = self._payload(tenant_id="invalid-tenant")  # wrong format

        with pytest.raises(PayloadValidationError) as excinfo:
            validate_payload(payload)

        assert "tenant_id" in str(excinfo.value)

    def test_invalid_agent_id_format(self):
        """An agent id missing its prefix must be rejected."""
        payload = self._payload()
        payload["agent_id"] = "helion"  # missing agt_ prefix

        with pytest.raises(PayloadValidationError) as excinfo:
            validate_payload(payload)

        assert "agent_id" in str(excinfo.value)

    def test_invalid_scope(self):
        """A scope outside the allowed set must be rejected."""
        payload = self._payload(scope="invalid_scope")

        with pytest.raises(PayloadValidationError) as excinfo:
            validate_payload(payload)

        assert "scope" in str(excinfo.value)

    def test_invalid_visibility(self):
        """A visibility outside the allowed set must be rejected."""
        payload = self._payload(visibility="secret")  # not a valid visibility

        with pytest.raises(PayloadValidationError) as excinfo:
            validate_payload(payload)

        assert "visibility" in str(excinfo.value)

    def test_invalid_importance_range(self):
        """An importance value outside [0, 1] must be rejected."""
        payload = self._payload()
        payload["importance"] = 1.5  # above the allowed maximum

        with pytest.raises(PayloadValidationError) as excinfo:
            validate_payload(payload)

        assert "importance" in str(excinfo.value)

    def test_invalid_chunk_idx_negative(self):
        """A negative chunk index must be rejected."""
        payload = self._payload(chunk_idx=-1)  # indices are zero-based, non-negative

        with pytest.raises(PayloadValidationError) as excinfo:
            validate_payload(payload)

        assert "chunk_idx" in str(excinfo.value)

    def test_non_strict_mode(self):
        """With strict=False an invalid payload is returned instead of raising."""
        broken = {
            "schema_version": "cm_payload_v1",
            "tenant_id": "invalid",  # bad format, would raise in strict mode
        }

        validated = validate_payload(broken, strict=False)
        assert validated["schema_version"] == "cm_payload_v1"

    def test_all_valid_scopes(self):
        """Every member of the allowed scope set passes validation."""
        for scope in ["docs", "messages", "memory", "artifacts", "signals"]:
            validated = validate_payload(self._payload(scope=scope))
            assert validated["scope"] == scope

    def test_all_valid_visibilities(self):
        """Every member of the allowed visibility set passes validation."""
        for visibility in ["public", "confidential", "private"]:
            validated = validate_payload(self._payload(visibility=visibility))
            assert validated["visibility"] == visibility
class TestCollectionNameMapping:
    """Tests mapping legacy Qdrant collection names onto canonical payload fields."""

    def test_parse_docs_collection(self):
        """*_docs names map to the agent's docs scope with no tags."""
        from scripts.qdrant_migrate_to_canonical import parse_collection_name

        result = parse_collection_name("helion_docs")
        assert result is not None
        assert result["agent_id"] == "agt_helion"
        assert result["scope"] == "docs"
        assert result["tags"] == []

    def test_parse_messages_collection(self):
        """*_messages names map to the agent's messages scope."""
        from scripts.qdrant_migrate_to_canonical import parse_collection_name

        result = parse_collection_name("nutra_messages")
        assert result is not None
        assert result["agent_id"] == "agt_nutra"
        assert result["scope"] == "messages"

    def test_parse_special_kb_collection(self):
        """Special knowledge-base collections map to docs scope with the KB tag."""
        from scripts.qdrant_migrate_to_canonical import parse_collection_name

        result = parse_collection_name("druid_legal_kb")
        assert result is not None
        assert result["agent_id"] == "agt_druid"
        assert result["scope"] == "docs"
        assert "legal_kb" in result["tags"]

    def test_parse_unknown_collection(self):
        """An unknown name either fails to parse (None) or yields a usable mapping.

        Fix: the original test contained no assertion at all, so it could never
        fail. We pin the weakest contract callers rely on: the parser returns
        either None or a mapping that carries the mandatory routing keys.
        """
        from scripts.qdrant_migrate_to_canonical import parse_collection_name

        result = parse_collection_name("random_collection_xyz")
        # Depending on implementation this may match a generic *_<suffix>
        # pattern or return None — both are acceptable, garbage is not.
        assert result is None or ("agent_id" in result and "scope" in result)
# Allow running this test module directly (outside a plain `pytest` invocation).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])
Reference in New Issue
Block a user