feat: Add Alateya, Clan, Eonarch agents + fix gateway-router connection

## Agents Added
- Alateya: R&D, biotech, innovations
- Clan (Spirit): Community spirit agent
- Eonarch: Consciousness evolution agent

## Changes
- docker-compose.node1.yml: Added tokens for all 3 new agents
- gateway-bot/http_api.py: Added configs and webhook endpoints
- gateway-bot/clan_prompt.txt: New prompt file
- gateway-bot/eonarch_prompt.txt: New prompt file

## Fixes
- Fixed ROUTER_URL from :9102 to :8000 (internal container port)
- All 9 Telegram agents now working

## Documentation
- Created PROJECT-MASTER-INDEX.md - single entry point
- Added various status documents and scripts

Tokens configured:
- Helion, NUTRA, Agromatrix (existing)
- Alateya, Clan, Eonarch (new)
- Druid, GreenFood, DAARWIZZ (configured)
This commit is contained in:
Apple
2026-01-28 06:40:34 -08:00
parent 4aeb69e7ae
commit 0c8bef82f4
120 changed files with 21905 additions and 425 deletions

View File

@@ -0,0 +1,43 @@
"""
Co-Memory Qdrant Module
Canonical Qdrant client with payload validation and filter building.
Security Invariants:
- tenant_id is ALWAYS required in filters
- indexed=true is default for search
- Empty should clause is NEVER allowed (would match everything)
- visibility=private is ONLY accessible by owner
"""
from .payload_validation import validate_payload, PayloadValidationError
from .collections import ensure_collection, get_canonical_collection_name
from .filters import (
build_qdrant_filter,
build_agent_only_filter,
build_multi_agent_filter,
build_project_filter,
build_tag_filter,
AccessContext,
FilterSecurityError,
)
from .client import CoMemoryQdrantClient
__all__ = [
# Validation
"validate_payload",
"PayloadValidationError",
# Collections
"ensure_collection",
"get_canonical_collection_name",
# Filters
"build_qdrant_filter",
"build_agent_only_filter",
"build_multi_agent_filter",
"build_project_filter",
"build_tag_filter",
"AccessContext",
"FilterSecurityError",
# Client
"CoMemoryQdrantClient",
]

View File

@@ -0,0 +1,413 @@
"""
Co-Memory Qdrant Client
High-level client for canonical Qdrant operations with validation and filtering.
"""
import logging
import os
from typing import Any, Dict, List, Optional, Tuple
from uuid import uuid4
from .payload_validation import validate_payload, PayloadValidationError
from .collections import ensure_collection, get_canonical_collection_name, list_legacy_collections
from .filters import AccessContext, build_qdrant_filter
# qdrant-client is an optional dependency: this module must stay importable
# without it. HAS_QDRANT is checked before any Qdrant operation is attempted.
try:
    from qdrant_client import QdrantClient
    from qdrant_client.models import (
        PointStruct,
        Filter,
        SearchRequest,
        ScoredPoint,
    )
    HAS_QDRANT = True
except ImportError:
    HAS_QDRANT = False
    # Placeholder so type references to QdrantClient do not raise NameError.
    QdrantClient = None

# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)
class CoMemoryQdrantClient:
    """
    High-level Qdrant client for Co-Memory operations.

    Features:
    - Automatic payload validation
    - Canonical collection management
    - Access-controlled filtering
    - Dual-write/dual-read support for migration
    """

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        url: Optional[str] = None,
        api_key: Optional[str] = None,
        text_dim: int = 1024,
        text_metric: str = "cosine",
    ):
        """
        Initialize Co-Memory Qdrant client.

        Args:
            host: Qdrant host (falls back to QDRANT_HOST env, then "localhost")
            port: Qdrant port (falls back to QDRANT_PORT env, then 6333)
            url: Full Qdrant URL (overrides host/port; falls back to QDRANT_URL env)
            api_key: Qdrant API key (optional; falls back to QDRANT_API_KEY env)
            text_dim: Default text embedding dimension
            text_metric: Default distance metric

        Raises:
            ImportError: If qdrant-client is not installed.
        """
        if not HAS_QDRANT:
            raise ImportError("qdrant-client not installed. Run: pip install qdrant-client")
        # BUGFIX: host/port previously defaulted to "localhost"/6333, which are
        # truthy, so the env fallbacks below could never take effect. Defaulting
        # the parameters to None restores the intended precedence:
        # explicit argument > environment variable > hardcoded default.
        host = host or os.getenv("QDRANT_HOST", "localhost")
        port = port or int(os.getenv("QDRANT_PORT", "6333"))
        url = url or os.getenv("QDRANT_URL")
        api_key = api_key or os.getenv("QDRANT_API_KEY")
        if url:
            self._client = QdrantClient(url=url, api_key=api_key)
        else:
            self._client = QdrantClient(host=host, port=port, api_key=api_key)
        self.text_dim = text_dim
        self.text_metric = text_metric
        self.text_collection = get_canonical_collection_name("text", text_dim)
        # Migration feature flags: mirror writes/reads to the legacy per-agent
        # collections while data is being moved to the canonical ones.
        self.dual_write_enabled = os.getenv("DUAL_WRITE_OLD", "false").lower() == "true"
        self.dual_read_enabled = os.getenv("DUAL_READ_OLD", "false").lower() == "true"
        logger.info(f"CoMemoryQdrantClient initialized: {self.text_collection}")

    @property
    def client(self) -> "QdrantClient":
        """Get underlying Qdrant client."""
        return self._client

    def ensure_collections(self) -> None:
        """Ensure canonical collections exist (creates them if missing)."""
        ensure_collection(
            self._client,
            self.text_collection,
            self.text_dim,
            self.text_metric,
        )
        logger.info(f"Ensured collection: {self.text_collection}")

    def upsert_text(
        self,
        points: List[Dict[str, Any]],
        validate: bool = True,
        collection_name: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Upsert text embeddings to canonical collection.

        Points missing a vector, or failing payload validation, are skipped
        and reported; the remaining valid points are still written.

        Args:
            points: List of dicts with 'id', 'vector', 'payload'
            validate: Validate payloads before upsert
            collection_name: Override collection name (for migration)

        Returns:
            Dict with 'upserted' and 'errors' counts and 'error_details'
            (None when there were no errors).
        """
        collection = collection_name or self.text_collection
        valid_points = []
        errors = []
        for point in points:
            # Generate an id when the caller did not supply one.
            point_id = point.get("id") or str(uuid4())
            vector = point.get("vector")
            payload = point.get("payload", {})
            if not vector:
                errors.append({"id": point_id, "error": "Missing vector"})
                continue
            # Validate payload against the canonical schema before writing.
            if validate:
                try:
                    validate_payload(payload)
                except PayloadValidationError as e:
                    errors.append({"id": point_id, "error": str(e), "details": e.errors})
                    continue
            valid_points.append(PointStruct(
                id=point_id,
                vector=vector,
                payload=payload,
            ))
        if valid_points:
            self._client.upsert(
                collection_name=collection,
                points=valid_points,
            )
            logger.info(f"Upserted {len(valid_points)} points to {collection}")
            # Dual-write to legacy collections if enabled; skipped when the
            # caller explicitly targeted a specific collection.
            if self.dual_write_enabled and collection_name is None:
                self._dual_write_legacy(valid_points)
        return {
            "upserted": len(valid_points),
            "errors": len(errors),
            "error_details": errors if errors else None,
        }

    def _dual_write_legacy(self, points: List["PointStruct"]) -> None:
        """Write to legacy per-agent collections for migration compatibility."""
        # Group points by legacy collection name "<agent-slug>_<scope>".
        legacy_points: Dict[str, List[PointStruct]] = {}
        for point in points:
            payload = point.payload
            agent_id = payload.get("agent_id")
            scope = payload.get("scope")
            if agent_id and scope:
                # Legacy names drop the canonical "agt_" prefix.
                agent_slug = agent_id.replace("agt_", "")
                legacy_name = f"{agent_slug}_{scope}"
                if legacy_name not in legacy_points:
                    legacy_points[legacy_name] = []
                legacy_points[legacy_name].append(point)
        # Best-effort: a failure on one legacy collection must not abort the
        # others — the canonical write has already succeeded.
        for legacy_collection, pts in legacy_points.items():
            try:
                self._client.upsert(
                    collection_name=legacy_collection,
                    points=pts,
                )
                logger.debug(f"Dual-write: {len(pts)} points to {legacy_collection}")
            except Exception as e:
                logger.warning(f"Dual-write to {legacy_collection} failed: {e}")

    def search_text(
        self,
        query_vector: List[float],
        ctx: AccessContext,
        limit: int = 10,
        scope: Optional[str] = None,
        scopes: Optional[List[str]] = None,
        tags: Optional[List[str]] = None,
        score_threshold: Optional[float] = None,
        with_payload: bool = True,
        collection_name: Optional[str] = None,
    ) -> List[Dict[str, Any]]:
        """
        Search text embeddings with access control.

        Args:
            query_vector: Query embedding vector
            ctx: Access context for filtering (who is asking)
            limit: Maximum results
            scope: Filter by scope (docs, messages, etc.)
            scopes: Filter by multiple scopes
            tags: Filter by tags
            score_threshold: Minimum similarity score
            with_payload: Include payload in results
            collection_name: Override collection name

        Returns:
            List of dicts with 'id', 'score' and 'payload'.
        """
        collection = collection_name or self.text_collection
        # Build access-controlled filter; raises FilterSecurityError when the
        # context cannot be filtered safely.
        filter_dict = build_qdrant_filter(
            ctx=ctx,
            scope=scope,
            scopes=scopes,
            tags=tags,
        )
        # Convert to Qdrant Filter
        qdrant_filter = self._dict_to_filter(filter_dict)
        results = self._client.search(
            collection_name=collection,
            query_vector=query_vector,
            query_filter=qdrant_filter,
            limit=limit,
            score_threshold=score_threshold,
            with_payload=with_payload,
        )
        # Convert results to plain dicts.
        output = []
        for result in results:
            output.append({
                "id": result.id,
                "score": result.score,
                "payload": result.payload if with_payload else None,
            })
        # Fallback: read from legacy collections when enabled and the canonical
        # collection returned nothing (data may not be migrated yet).
        if self.dual_read_enabled and not output:
            legacy_results = self._dual_read_legacy(
                query_vector, ctx, limit, scope, scopes, tags, score_threshold
            )
            output.extend(legacy_results)
        return output

    def _dual_read_legacy(
        self,
        query_vector: List[float],
        ctx: AccessContext,
        limit: int,
        scope: Optional[str],
        scopes: Optional[List[str]],
        tags: Optional[List[str]],
        score_threshold: Optional[float],
    ) -> List[Dict[str, Any]]:
        """Fallback read from legacy collections."""
        results = []
        # Determine which legacy collections to search from the requesting
        # agent's id and the requested scope(s); default to docs + messages.
        legacy_collections = []
        if ctx.agent_id:
            agent_slug = ctx.agent_id.replace("agt_", "")
            if scope:
                legacy_collections.append(f"{agent_slug}_{scope}")
            elif scopes:
                for s in scopes:
                    legacy_collections.append(f"{agent_slug}_{s}")
            else:
                legacy_collections.append(f"{agent_slug}_docs")
                legacy_collections.append(f"{agent_slug}_messages")
        for legacy_collection in legacy_collections:
            try:
                legacy_results = self._client.search(
                    collection_name=legacy_collection,
                    query_vector=query_vector,
                    limit=limit,
                    score_threshold=score_threshold,
                    with_payload=True,
                )
                for result in legacy_results:
                    results.append({
                        "id": result.id,
                        "score": result.score,
                        "payload": result.payload,
                        # Mark the origin so callers can tell legacy hits apart.
                        "_legacy_collection": legacy_collection,
                    })
            except Exception as e:
                # Missing legacy collection is expected; keep it quiet.
                logger.debug(f"Legacy read from {legacy_collection} failed: {e}")
        return results

    def _dict_to_filter(self, filter_dict: Dict[str, Any]) -> "Filter":
        """Convert filter dictionary to Qdrant Filter object."""
        from qdrant_client.models import Filter, FieldCondition, MatchValue, MatchAny

        def build_condition(cond: Dict[str, Any]):
            # Translate one {"key": ..., "match": {...}} dict into a
            # FieldCondition; returns None for unsupported match shapes.
            key = cond.get("key")
            match = cond.get("match", {})
            if "value" in match:
                return FieldCondition(
                    key=key,
                    match=MatchValue(value=match["value"])
                )
            elif "any" in match:
                return FieldCondition(
                    key=key,
                    match=MatchAny(any=match["any"])
                )
            return None

        def build_conditions(conditions: List[Dict]) -> List:
            result = []
            for cond in conditions:
                if "must" in cond:
                    # Nested filter (e.g. an access-control "should" block).
                    # BUGFIX: build each condition once instead of calling
                    # build_condition twice per entry (once in the `if`,
                    # once in the value position).
                    nested_must = [fc for fc in map(build_condition, cond["must"]) if fc]
                    nested = Filter(must=nested_must)
                    if "must_not" in cond:
                        nested.must_not = [
                            fc for fc in map(build_condition, cond["must_not"]) if fc
                        ]
                    result.append(nested)
                else:
                    built = build_condition(cond)
                    if built:
                        result.append(built)
            return result

        must = build_conditions(filter_dict.get("must", []))
        should = build_conditions(filter_dict.get("should", []))
        must_not = build_conditions(filter_dict.get("must_not", []))
        return Filter(
            must=must if must else None,
            should=should if should else None,
            must_not=must_not if must_not else None,
        )

    def delete_points(
        self,
        point_ids: List[str],
        collection_name: Optional[str] = None,
    ) -> int:
        """
        Delete points by IDs.

        Args:
            point_ids: List of point IDs to delete
            collection_name: Override collection name

        Returns:
            Number of points requested for deletion.
        """
        collection = collection_name or self.text_collection
        self._client.delete(
            collection_name=collection,
            points_selector=point_ids,
        )
        return len(point_ids)

    def get_collection_stats(self, collection_name: Optional[str] = None) -> Dict[str, Any]:
        """Get collection statistics, or {"error": ...} on failure."""
        collection = collection_name or self.text_collection
        try:
            info = self._client.get_collection(collection)
            return {
                "name": collection,
                "vectors_count": info.vectors_count,
                "points_count": info.points_count,
                "status": info.status.value,
            }
        except Exception as e:
            return {"error": str(e)}

    def list_all_collections(self) -> Dict[str, List[str]]:
        """List all collections categorized as canonical ("cm_"-prefixed) or legacy."""
        collections = self._client.get_collections().collections
        canonical = []
        legacy = []
        for col in collections:
            if col.name.startswith("cm_"):
                canonical.append(col.name)
            else:
                legacy.append(col.name)
        return {
            "canonical": canonical,
            "legacy": legacy,
        }

View File

@@ -0,0 +1,260 @@
"""
Qdrant Collection Management for Co-Memory
Handles canonical collection creation and configuration.
"""
import logging
from typing import Any, Dict, List, Optional
# qdrant-client is an optional dependency: this module must stay importable
# without it. HAS_QDRANT gates every function that touches a Qdrant type.
try:
    from qdrant_client import QdrantClient
    from qdrant_client.models import (
        Distance,
        VectorParams,
        PayloadSchemaType,
        TextIndexParams,
        TokenizerType,
    )
    HAS_QDRANT = True
except ImportError:
    HAS_QDRANT = False

# Module-level logger; handler/level configuration is left to the application.
logger = logging.getLogger(__name__)
# Canonical collection naming
COLLECTION_PREFIX = "cm"
COLLECTION_VERSION = "v1"


def get_canonical_collection_name(
    collection_type: str = "text",
    dim: int = 1024,
    version: str = COLLECTION_VERSION
) -> str:
    """
    Build the canonical collection name for a given embedding family.

    Args:
        collection_type: Type of embeddings (text, code, mm)
        dim: Vector dimension
        version: Schema version

    Returns:
        Collection name like "cm_text_1024_v1"
    """
    parts = (COLLECTION_PREFIX, collection_type, str(dim), version)
    return "_".join(parts)
def get_distance_metric(metric: str) -> "Distance":
    """Map a metric name (case-insensitive) to the Qdrant Distance enum.

    Unknown names fall back to cosine.
    """
    if not HAS_QDRANT:
        raise ImportError("qdrant-client not installed")
    lookup = {
        "cosine": Distance.COSINE,
        "dot": Distance.DOT,
        "euclidean": Distance.EUCLID,
    }
    return lookup.get(metric.lower(), Distance.COSINE)
# Default payload indexes for optimal query performance.
# Every field referenced by the filter builders (tenant/team/project/agent ids,
# scope, visibility, the indexed flag, ownership fields, tags and the nested
# ACL lists) gets an index so filtered searches avoid full payload scans.
DEFAULT_PAYLOAD_INDEXES = [
    {"field": "tenant_id", "type": "keyword"},
    {"field": "team_id", "type": "keyword"},
    {"field": "project_id", "type": "keyword"},
    {"field": "agent_id", "type": "keyword"},
    {"field": "scope", "type": "keyword"},
    {"field": "visibility", "type": "keyword"},
    {"field": "indexed", "type": "bool"},
    {"field": "source_id", "type": "keyword"},
    {"field": "owner_kind", "type": "keyword"},
    {"field": "owner_id", "type": "keyword"},
    {"field": "tags", "type": "keyword"},
    # Nested ACL fields (dot notation is understood by Qdrant).
    {"field": "acl.read_team_ids", "type": "keyword"},
    {"field": "acl.read_agent_ids", "type": "keyword"},
    {"field": "acl.read_role_ids", "type": "keyword"},
]
def ensure_collection(
    client: "QdrantClient",
    name: str,
    dim: int,
    metric: str = "cosine",
    payload_indexes: Optional[List[Dict[str, str]]] = None,
    on_disk: bool = True,
) -> bool:
    """
    Ensure a canonical collection exists with proper configuration.

    Args:
        client: Qdrant client instance
        name: Collection name
        dim: Vector dimension
        metric: Distance metric (cosine, dot, euclidean)
        payload_indexes: List of payload fields to index
        on_disk: Whether to store vectors on disk

    Returns:
        True if collection was created, False if already exists
    """
    if not HAS_QDRANT:
        raise ImportError("qdrant-client not installed")
    indexes = payload_indexes or DEFAULT_PAYLOAD_INDEXES
    # Already present? Still make sure the payload indexes exist on it.
    existing = {c.name for c in client.get_collections().collections}
    if name in existing:
        logger.info(f"Collection '{name}' already exists")
        _ensure_payload_indexes(client, name, indexes)
        return False
    logger.info(f"Creating collection '{name}' with dim={dim}, metric={metric}")
    client.create_collection(
        collection_name=name,
        vectors_config=VectorParams(
            size=dim,
            distance=get_distance_metric(metric),
            on_disk=on_disk,
        ),
    )
    _ensure_payload_indexes(client, name, indexes)
    logger.info(f"Collection '{name}' created successfully")
    return True
def _ensure_payload_indexes(
    client: "QdrantClient",
    collection_name: str,
    indexes: List[Dict[str, str]]
) -> None:
    """
    Ensure payload indexes exist on collection.

    Index creation is best-effort: "already exists" errors are silently
    ignored, any other failure is logged as a warning.

    Args:
        client: Qdrant client
        collection_name: Collection name
        indexes: List of index configurations ({"field": ..., "type": ...})
    """
    if not HAS_QDRANT:
        return
    # Simple field types map directly to a PayloadSchemaType enum value;
    # "text" needs a full TextIndexParams config and is handled separately.
    # (Replaces the previous five-way if/elif ladder.)
    simple_schemas = {
        "keyword": PayloadSchemaType.KEYWORD,
        "bool": PayloadSchemaType.BOOL,
        "integer": PayloadSchemaType.INTEGER,
        "float": PayloadSchemaType.FLOAT,
        "datetime": PayloadSchemaType.DATETIME,
    }
    for index_config in indexes:
        field_name = index_config["field"]
        field_type = index_config.get("type", "keyword")
        if field_type == "text":
            field_schema = TextIndexParams(
                type="text",
                tokenizer=TokenizerType.WORD,
                min_token_len=2,
                max_token_len=15,
            )
        else:
            field_schema = simple_schemas.get(field_type)
        if field_schema is None:
            # Unknown type: skip it. (The old if/elif ladder also created no
            # index here, but still logged a misleading "Created" debug line.)
            logger.warning(f"Unknown payload index type '{field_type}' for {field_name}")
            continue
        try:
            client.create_payload_index(
                collection_name=collection_name,
                field_name=field_name,
                field_schema=field_schema,
            )
            logger.debug(f"Created payload index: {field_name} ({field_type})")
        except Exception as e:
            # Index might already exist
            if "already exists" not in str(e).lower():
                logger.warning(f"Failed to create index {field_name}: {e}")
def get_collection_info(client: "QdrantClient", name: str) -> Optional[Dict[str, Any]]:
    """
    Fetch basic information about a collection.

    Args:
        client: Qdrant client
        name: Collection name

    Returns:
        Collection info dict or None if not found
    """
    if not HAS_QDRANT:
        raise ImportError("qdrant-client not installed")
    try:
        info = client.get_collection(name)
        vectors_config = info.config.params.vectors
        return {
            "name": name,
            "vectors_count": info.vectors_count,
            "points_count": info.points_count,
            "status": info.status.value,
            "config": {
                "size": vectors_config.size,
                "distance": vectors_config.distance.value,
            }
        }
    except Exception:
        # Missing collection (or unexpected response shape) -> None.
        return None
def list_legacy_collections(client: "QdrantClient") -> List[str]:
    """
    List all legacy (non-canonical) collections.

    Canonical collections are identified by the "cm_" prefix; everything
    else is considered legacy.

    Args:
        client: Qdrant client

    Returns:
        List of legacy collection names
    """
    if not HAS_QDRANT:
        raise ImportError("qdrant-client not installed")
    canonical_prefix = f"{COLLECTION_PREFIX}_"
    return [
        col.name
        for col in client.get_collections().collections
        if not col.name.startswith(canonical_prefix)
    ]

View File

@@ -0,0 +1,541 @@
"""
Qdrant Filter Builder for Co-Memory
Builds complex Qdrant filters based on access context and query requirements.
SECURITY INVARIANTS:
- tenant_id is ALWAYS in must conditions
- indexed=true is default in must conditions
- Empty should list is NEVER returned (would match everything)
- visibility=private ONLY accessible by owner, NEVER leaked to multi-agent queries
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
class FilterSecurityError(Exception):
    """Raised when filter would violate security invariants.

    Examples seen in this module: a missing tenant_id, an empty ``should``
    clause (which would match every document), or a request for agents
    outside the caller's allowed set.
    """
    pass
@dataclass
class AccessContext:
    """
    Context for building access-controlled filters.

    Represents "who is asking" for the query: tenant, optional team/project,
    the acting agent or user, plus any explicitly granted cross-access.
    """
    tenant_id: str
    team_id: Optional[str] = None
    project_id: Optional[str] = None
    agent_id: Optional[str] = None
    user_id: Optional[str] = None
    role_ids: List[str] = field(default_factory=list)
    # Access permissions
    allowed_agent_ids: List[str] = field(default_factory=list)
    allowed_team_ids: List[str] = field(default_factory=list)
    # Query constraints
    is_admin: bool = False

    def has_identity(self) -> bool:
        """Check if context has at least one identity for access control."""
        identity_sources = (
            self.agent_id,
            self.user_id,
            self.team_id,
            self.allowed_agent_ids,
            self.is_admin,
        )
        return any(identity_sources)
def build_qdrant_filter(
    ctx: AccessContext,
    scope: Optional[str] = None,
    scopes: Optional[List[str]] = None,
    visibility: Optional[str] = None,
    visibilities: Optional[List[str]] = None,
    tags: Optional[List[str]] = None,
    source_id: Optional[str] = None,
    channel_id: Optional[str] = None,
    indexed_only: bool = True,
    include_private: bool = False,
) -> Dict[str, Any]:
    """
    Build a Qdrant filter for canonical collection queries.

    Args:
        ctx: Access context (who is querying)
        scope: Single scope to filter (docs, messages, etc.)
        scopes: Multiple scopes to filter (used only when `scope` is unset)
        visibility: Single visibility filter
        visibilities: Multiple visibilities to filter
        tags: Tags to filter by; ALL listed tags must match (AND semantics)
        source_id: Specific source ID
        channel_id: Channel/chat ID for messages
        indexed_only: Only return indexed=true items
        include_private: Include private items (owner only)

    Returns:
        Qdrant filter dictionary

    Raises:
        FilterSecurityError: If filter would be insecure (e.g. empty should)
    """
    # SECURITY: tenant_id is ALWAYS required, even for admin
    if not ctx.tenant_id:
        raise FilterSecurityError("tenant_id is required and cannot be None/empty")

    must_conditions = []

    # INVARIANT: Always filter by tenant (even for admin)
    must_conditions.append({
        "key": "tenant_id",
        "match": {"value": ctx.tenant_id}
    })
    # INVARIANT: Default only indexed items (indexed=true unless explicitly disabled)
    if indexed_only:
        must_conditions.append({
            "key": "indexed",
            "match": {"value": True}
        })
    # Scope filter (single scope takes precedence over the list)
    if scope:
        must_conditions.append({
            "key": "scope",
            "match": {"value": scope}
        })
    elif scopes:
        must_conditions.append({
            "key": "scope",
            "match": {"any": scopes}
        })
    # Source ID filter
    if source_id:
        must_conditions.append({
            "key": "source_id",
            "match": {"value": source_id}
        })
    # Channel filter
    if channel_id:
        must_conditions.append({
            "key": "channel_id",
            "match": {"value": channel_id}
        })
    # Tags filter: one must-condition per tag, so ALL tags are required.
    # BUGFIX: these were previously appended AFTER must_conditions had been
    # placed into the result dict, silently relying on list aliasing to take
    # effect; they are now added before the dict is assembled.
    if tags:
        for tag in tags:
            must_conditions.append({
                "key": "tags",
                "match": {"value": tag}
            })

    # Build access control filter (OR-of-AND "should" blocks)
    access_should = _build_access_filter(
        ctx=ctx,
        visibility=visibility,
        visibilities=visibilities,
        include_private=include_private,
    )
    # SECURITY: Validate access filter is not empty (would match everything)
    if not access_should and not ctx.is_admin:
        raise FilterSecurityError(
            "Access filter is empty - would return all documents. "
            "Context must have at least one identity (agent_id, user_id, team_id) "
            "or is_admin=True"
        )

    # Combine filters. (The never-populated must_not list from the original
    # implementation was dead code and has been removed.)
    filter_dict: Dict[str, Any] = {"must": must_conditions}
    if access_should:
        filter_dict["should"] = access_should
        filter_dict["minimum_should_match"] = 1
    return filter_dict
def _build_access_filter(
    ctx: AccessContext,
    visibility: Optional[str] = None,
    visibilities: Optional[List[str]] = None,
    include_private: bool = False,
) -> List[Dict[str, Any]]:
    """
    Build access control filter based on visibility and ACL.

    Returns list of should conditions for OR matching: a document is visible
    if ANY one of the returned {"must": [...]} blocks fully matches it.

    SECURITY: Never returns empty list unless ctx.is_admin
    """
    should = []
    # Admin access: still respects tenant isolation, but has broader visibility
    # SECURITY: Admin without explicit visibility sees public+confidential by default
    # To see private, admin must explicitly request visibility="private" or include_private=True
    if ctx.is_admin:
        if visibility:
            should.append({
                "must": [{"key": "visibility", "match": {"value": visibility}}]
            })
        elif visibilities:
            should.append({
                "must": [{"key": "visibility", "match": {"any": visibilities}}]
            })
        elif include_private:
            # Admin explicitly requested private access
            should.append({
                "must": [{"key": "visibility", "match": {"any": ["public", "confidential", "private"]}}]
            })
        else:
            # Admin default: public + confidential only (no private leak by default)
            should.append({
                "must": [{"key": "visibility", "match": {"any": ["public", "confidential"]}}]
            })
        # Admin path ends here; non-admin clauses below never apply.
        return should
    # Determine allowed visibilities (explicit request wins over the default).
    allowed_vis = set()
    if visibility:
        allowed_vis.add(visibility)
    elif visibilities:
        allowed_vis.update(visibilities)
    else:
        allowed_vis = {"public", "confidential"}
    # SECURITY: private only if explicitly requested AND owner identity exists
    if include_private and (ctx.agent_id or ctx.user_id):
        allowed_vis.add("private")
    # 1. Public content in same team
    if "public" in allowed_vis and ctx.team_id:
        should.append({
            "must": [
                {"key": "team_id", "match": {"value": ctx.team_id}},
                {"key": "visibility", "match": {"value": "public"}}
            ]
        })
    # 2. Own content (owner match) - ONLY way to access private
    if ctx.agent_id:
        own_vis = ["public", "confidential"]
        if "private" in allowed_vis:
            own_vis.append("private")
        should.append({
            "must": [
                {"key": "owner_kind", "match": {"value": "agent"}},
                {"key": "owner_id", "match": {"value": ctx.agent_id}},
                {"key": "visibility", "match": {"any": own_vis}}
            ]
        })
    if ctx.user_id:
        own_vis = ["public", "confidential"]
        if "private" in allowed_vis:
            own_vis.append("private")
        should.append({
            "must": [
                {"key": "owner_kind", "match": {"value": "user"}},
                {"key": "owner_id", "match": {"value": ctx.user_id}},
                {"key": "visibility", "match": {"any": own_vis}}
            ]
        })
    # 3. Confidential with ACL access (NEVER private via ACL)
    if "confidential" in allowed_vis:
        # Access via agent ACL
        if ctx.agent_id:
            should.append({
                "must": [
                    {"key": "visibility", "match": {"value": "confidential"}},
                    {"key": "acl.read_agent_ids", "match": {"value": ctx.agent_id}}
                ]
            })
        # Access via team ACL
        if ctx.team_id:
            should.append({
                "must": [
                    {"key": "visibility", "match": {"value": "confidential"}},
                    {"key": "acl.read_team_ids", "match": {"value": ctx.team_id}}
                ]
            })
        # Access via role ACL (one should-block per role the caller holds)
        for role_id in ctx.role_ids:
            should.append({
                "must": [
                    {"key": "visibility", "match": {"value": "confidential"}},
                    {"key": "acl.read_role_ids", "match": {"value": role_id}}
                ]
            })
    # 4. Cross-agent access (if allowed_agent_ids specified)
    # SECURITY: NEVER includes private - only public+confidential
    if ctx.allowed_agent_ids:
        cross_vis = []
        if "public" in allowed_vis:
            cross_vis.append("public")
        if "confidential" in allowed_vis:
            cross_vis.append("confidential")
        if cross_vis:
            should.append({
                "must": [
                    {"key": "agent_id", "match": {"any": ctx.allowed_agent_ids}},
                    {"key": "visibility", "match": {"any": cross_vis}}
                ]
            })
    return should
def build_agent_only_filter(
    ctx: AccessContext,
    agent_id: str,
    scope: Optional[str] = None,
    include_team_public: bool = True,
    include_own_public: bool = True,
) -> Dict[str, Any]:
    """
    Build filter for agent reading only its own content + optional team public.

    Use case: Agent answering user, sees only own knowledge.
    SECURITY: Private content is only accessible if agent_id matches owner_id.

    Args:
        ctx: Access context
        agent_id: The agent requesting data (must match for private access)
        scope: Optional scope filter
        include_team_public: Include public content from team
        include_own_public: Include own public content (default True)
    """
    # SECURITY: tenant_id always required
    if not ctx.tenant_id:
        raise FilterSecurityError("tenant_id is required and cannot be None/empty")
    if not agent_id:
        raise FilterSecurityError("agent_id is required for build_agent_only_filter")

    must_clauses: List[Dict[str, Any]] = [
        {"key": "tenant_id", "match": {"value": ctx.tenant_id}},
        {"key": "indexed", "match": {"value": True}},
    ]
    if scope:
        must_clauses.append({"key": "scope", "match": {"value": scope}})

    # Own content as owner: confidential + private (the ONLY private path).
    should_clauses = [{
        "must": [
            {"key": "owner_kind", "match": {"value": "agent"}},
            {"key": "owner_id", "match": {"value": agent_id}},
            {"key": "visibility", "match": {"any": ["confidential", "private"]}}
        ]
    }]
    # Own public content (if enabled).
    if include_own_public:
        should_clauses.append({
            "must": [
                {"key": "agent_id", "match": {"value": agent_id}},
                {"key": "visibility", "match": {"value": "public"}}
            ]
        })
    # Team public content (if enabled and team exists).
    if include_team_public and ctx.team_id:
        should_clauses.append({
            "must": [
                {"key": "team_id", "match": {"value": ctx.team_id}},
                {"key": "visibility", "match": {"value": "public"}}
            ]
        })

    # SECURITY: Ensure should is never empty
    if not should_clauses:
        raise FilterSecurityError("Filter would have empty should clause")

    return {
        "must": must_clauses,
        "should": should_clauses,
        "minimum_should_match": 1
    }
def build_multi_agent_filter(
    ctx: AccessContext,
    agent_ids: List[str],
    scope: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Build filter for reading from multiple agents (aggregator use case).

    Use case: DAARWIZZ orchestrator needs cross-agent retrieval.

    SECURITY INVARIANTS:
    - Private content is ALWAYS excluded (hard-coded must_not)
    - Only public + confidential from specified agents
    - No parameter to override this (by design)
    - agent_ids MUST be subset of ctx.allowed_agent_ids (unless admin)

    Args:
        ctx: Access context (must have allowed_agent_ids or is_admin=True)
        agent_ids: List of agents to read from (must be allowed)
        scope: Optional scope filter

    Raises:
        FilterSecurityError: If agent_ids contains unauthorized agents
    """
    # SECURITY: tenant_id always required
    if not ctx.tenant_id:
        raise FilterSecurityError("tenant_id is required and cannot be None/empty")
    if not agent_ids:
        raise FilterSecurityError("agent_ids cannot be empty for build_multi_agent_filter")

    # SECURITY: non-admin callers may only touch explicitly allowed agents.
    if not ctx.is_admin:
        if not ctx.allowed_agent_ids:
            raise FilterSecurityError(
                "build_multi_agent_filter requires ctx.allowed_agent_ids or ctx.is_admin=True. "
                "Cannot access arbitrary agents without explicit permission."
            )
        allowed_set = set(ctx.allowed_agent_ids)
        unauthorized = set(agent_ids) - allowed_set
        if unauthorized:
            raise FilterSecurityError(
                f"Unauthorized agent access: {sorted(unauthorized)}. "
                f"Allowed agents: {sorted(allowed_set)}"
            )

    must_clauses: List[Dict[str, Any]] = [
        {"key": "tenant_id", "match": {"value": ctx.tenant_id}},
        {"key": "indexed", "match": {"value": True}},
    ]
    if scope:
        must_clauses.append({"key": "scope", "match": {"value": scope}})

    # Content from allowed agents (public + confidential ONLY).
    should_clauses = [{
        "must": [
            {"key": "agent_id", "match": {"any": agent_ids}},
            {"key": "visibility", "match": {"any": ["public", "confidential"]}}
        ]
    }]
    # Team public content.
    if ctx.team_id:
        should_clauses.append({
            "must": [
                {"key": "team_id", "match": {"value": ctx.team_id}},
                {"key": "visibility", "match": {"value": "public"}}
            ]
        })

    return {
        "must": must_clauses,
        "should": should_clauses,
        "minimum_should_match": 1,
        # SECURITY: ALWAYS exclude private - no parameter to override
        "must_not": [
            {"key": "visibility", "match": {"value": "private"}}
        ]
    }
def build_project_filter(
    ctx: AccessContext,
    project_id: str,
    scope: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Build filter for project-scoped retrieval.

    Use case: Reading only within one project.
    SECURITY: Private is always excluded.
    """
    # SECURITY: tenant_id always required
    if not ctx.tenant_id:
        raise FilterSecurityError("tenant_id is required and cannot be None/empty")
    if not project_id:
        raise FilterSecurityError("project_id is required for build_project_filter")

    conditions: List[Dict[str, Any]] = [
        {"key": "tenant_id", "match": {"value": ctx.tenant_id}},
        {"key": "project_id", "match": {"value": project_id}},
        {"key": "indexed", "match": {"value": True}},
        # SECURITY: Only public + confidential, never private
        {"key": "visibility", "match": {"any": ["public", "confidential"]}},
    ]
    if scope:
        conditions.append({"key": "scope", "match": {"value": scope}})

    return {
        "must": conditions,
        # Belt-and-braces exclusion on top of the visibility must-clause.
        "must_not": [
            {"key": "visibility", "match": {"value": "private"}}
        ]
    }
def build_tag_filter(
    ctx: AccessContext,
    tags: List[str],
    scope: str = "docs",
    visibility: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Build filter for tag-based retrieval.

    Use case: Druid legal KB only, Nutra food knowledge.
    SECURITY: Private is excluded by default.
    """
    # SECURITY: tenant_id always required
    if not ctx.tenant_id:
        raise FilterSecurityError("tenant_id is required and cannot be None/empty")
    if not tags:
        raise FilterSecurityError("tags cannot be empty for build_tag_filter")

    # Visibility: explicit request wins; otherwise public+confidential only.
    if visibility:
        vis_clause = {"key": "visibility", "match": {"value": visibility}}
    else:
        vis_clause = {"key": "visibility", "match": {"any": ["public", "confidential"]}}

    conditions: List[Dict[str, Any]] = [
        {"key": "tenant_id", "match": {"value": ctx.tenant_id}},
        {"key": "scope", "match": {"value": scope}},
        {"key": "indexed", "match": {"value": True}},
        vis_clause,
    ]
    # One must-clause per tag: ALL tags are required (AND semantics).
    conditions.extend({"key": "tags", "match": {"value": tag}} for tag in tags)
    return {"must": conditions}

View File

@@ -0,0 +1,283 @@
"""
Payload Validation for Co-Memory Qdrant
Validates payloads against cm_payload_v1 schema before upsert.
"""
import json
import re
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
# Try to use jsonschema if available, otherwise use manual validation.
# jsonschema is an optional dependency; this module must import without it.
try:
    import jsonschema
    HAS_JSONSCHEMA = True
except ImportError:
    # Fall back to the hand-rolled field checks implemented below.
    HAS_JSONSCHEMA = False
class PayloadValidationError(Exception):
    """Raised when a payload fails cm_payload_v1 validation.

    The human-readable summary is the exception message; the individual
    per-field validation messages are available on the ``errors`` attribute.
    """

    def __init__(self, message: str, errors: Optional[List[str]] = None):
        super().__init__(message)
        # Normalize falsy input (None or []) to a fresh empty list so
        # callers can always iterate ``errors`` safely.
        self.errors = errors or []
# Enums
# Closed sets of allowed values for the corresponding payload fields;
# referenced by the manual validation in _validate_field_values().
VALID_SCOPES = {"docs", "messages", "memory", "artifacts", "signals"}
VALID_VISIBILITY = {"public", "confidential", "private"}
VALID_OWNER_KINDS = {"user", "team", "agent"}
VALID_SOURCE_KINDS = {"document", "wiki", "message", "artifact", "web", "code"}
VALID_METRICS = {"cosine", "dot", "euclidean"}
# ID patterns
# Each identifier carries a fixed type prefix. Org-level IDs (tenant/team/
# project/agent) are lowercase; source/chunk suffixes allow mixed case.
TENANT_ID_PATTERN = re.compile(r"^t_[a-z0-9_]+$")
TEAM_ID_PATTERN = re.compile(r"^team_[a-z0-9_]+$")
PROJECT_ID_PATTERN = re.compile(r"^proj_[a-z0-9_]+$")
AGENT_ID_PATTERN = re.compile(r"^agt_[a-z0-9_]+$")
SOURCE_ID_PATTERN = re.compile(r"^(doc|msg|art|web|code)_[A-Za-z0-9]+$")
CHUNK_ID_PATTERN = re.compile(r"^chk_[A-Za-z0-9]+$")
def _load_json_schema() -> Optional[Dict]:
    """Load JSON schema from file if available.

    Resolves docs/memory/cm_payload_v1.schema.json four directory levels
    above this module (presumably the repo root -- TODO confirm layout).
    Returns None when the file is absent so validation can fall back to
    the manual checks.
    """
    schema_path = Path(__file__).parent.parent.parent.parent / "docs" / "memory" / "cm_payload_v1.schema.json"
    if schema_path.exists():
        with open(schema_path) as f:
            return json.load(f)
    return None
# Loaded once at import time; None when the schema file is not present.
_SCHEMA = _load_json_schema()
def validate_payload(payload: Dict[str, Any], strict: bool = True) -> Dict[str, Any]:
    """
    Validate payload against cm_payload_v1 schema.

    Prefers jsonschema validation when both the library and the schema
    file are available; otherwise runs the manual field checks.

    Args:
        payload: The payload dictionary to validate
        strict: If True, raise exception on validation failure

    Returns:
        The validated payload (potentially with defaults added)

    Raises:
        PayloadValidationError: If validation fails and strict=True
    """
    errors: List[str] = []
    if HAS_JSONSCHEMA and _SCHEMA:
        # Schema-driven path: a single jsonschema error is recorded.
        try:
            jsonschema.validate(payload, _SCHEMA)
        except jsonschema.ValidationError as e:
            errors.append(f"Schema validation: {e.message}")
    else:
        # Manual fallback: structural checks first, then value checks.
        errors = _validate_required_fields(payload) + _validate_field_values(payload)
    if strict and errors:
        raise PayloadValidationError(
            f"Payload validation failed: {len(errors)} error(s)",
            errors=errors
        )
    return payload
def _validate_required_fields(payload: Dict[str, Any]) -> List[str]:
"""Validate required fields are present."""
errors = []
required = [
"schema_version",
"tenant_id",
"owner_kind",
"owner_id",
"scope",
"visibility",
"indexed",
"source_kind",
"source_id",
"chunk",
"fingerprint",
"created_at",
]
for field in required:
if field not in payload:
errors.append(f"Missing required field: {field}")
# Check nested required fields
if "chunk" in payload and isinstance(payload["chunk"], dict):
if "chunk_id" not in payload["chunk"]:
errors.append("Missing required field: chunk.chunk_id")
if "chunk_idx" not in payload["chunk"]:
errors.append("Missing required field: chunk.chunk_idx")
return errors
def _validate_field_values(payload: Dict[str, Any]) -> List[str]:
    """Validate field values match expected formats.

    Collects (rather than raising on) every format/enum/type violation and
    returns the list of error messages; an empty list means all present
    fields are valid. Missing-field checks live in _validate_required_fields.
    """
    errors = []
    # Schema version
    if payload.get("schema_version") != "cm_payload_v1":
        errors.append(f"Invalid schema_version: {payload.get('schema_version')}, expected 'cm_payload_v1'")
    # Tenant ID
    # NOTE: pattern checks below are skipped for falsy values -- absence is
    # reported by the required-fields pass, truly-optional IDs may be None.
    tenant_id = payload.get("tenant_id")
    if tenant_id and not TENANT_ID_PATTERN.match(tenant_id):
        errors.append(f"Invalid tenant_id format: {tenant_id}")
    # Team ID (optional)
    team_id = payload.get("team_id")
    if team_id and not TEAM_ID_PATTERN.match(team_id):
        errors.append(f"Invalid team_id format: {team_id}")
    # Project ID (optional)
    project_id = payload.get("project_id")
    if project_id and not PROJECT_ID_PATTERN.match(project_id):
        errors.append(f"Invalid project_id format: {project_id}")
    # Agent ID (optional)
    agent_id = payload.get("agent_id")
    if agent_id and not AGENT_ID_PATTERN.match(agent_id):
        errors.append(f"Invalid agent_id format: {agent_id}")
    # Scope
    if payload.get("scope") not in VALID_SCOPES:
        errors.append(f"Invalid scope: {payload.get('scope')}, valid: {VALID_SCOPES}")
    # Visibility
    if payload.get("visibility") not in VALID_VISIBILITY:
        errors.append(f"Invalid visibility: {payload.get('visibility')}, valid: {VALID_VISIBILITY}")
    # Owner kind
    if payload.get("owner_kind") not in VALID_OWNER_KINDS:
        errors.append(f"Invalid owner_kind: {payload.get('owner_kind')}, valid: {VALID_OWNER_KINDS}")
    # Source kind
    if payload.get("source_kind") not in VALID_SOURCE_KINDS:
        errors.append(f"Invalid source_kind: {payload.get('source_kind')}, valid: {VALID_SOURCE_KINDS}")
    # Source ID
    source_id = payload.get("source_id")
    if source_id and not SOURCE_ID_PATTERN.match(source_id):
        errors.append(f"Invalid source_id format: {source_id}")
    # Chunk
    chunk = payload.get("chunk", {})
    if isinstance(chunk, dict):
        chunk_id = chunk.get("chunk_id")
        if chunk_id and not CHUNK_ID_PATTERN.match(chunk_id):
            errors.append(f"Invalid chunk.chunk_id format: {chunk_id}")
        chunk_idx = chunk.get("chunk_idx")
        if chunk_idx is not None and (not isinstance(chunk_idx, int) or chunk_idx < 0):
            errors.append(f"Invalid chunk.chunk_idx: {chunk_idx}, must be non-negative integer")
    # Indexed
    if not isinstance(payload.get("indexed"), bool):
        errors.append(f"Invalid indexed: {payload.get('indexed')}, must be boolean")
    # Created at
    # "Z" suffix is normalized to the "+00:00" offset fromisoformat accepts.
    created_at = payload.get("created_at")
    if created_at:
        try:
            datetime.fromisoformat(created_at.replace("Z", "+00:00"))
        except (ValueError, AttributeError):
            errors.append(f"Invalid created_at format: {created_at}, expected ISO 8601")
    # Embedding (optional)
    embedding = payload.get("embedding", {})
    if isinstance(embedding, dict):
        if "metric" in embedding and embedding["metric"] not in VALID_METRICS:
            errors.append(f"Invalid embedding.metric: {embedding['metric']}, valid: {VALID_METRICS}")
        if "dim" in embedding and (not isinstance(embedding["dim"], int) or embedding["dim"] < 1):
            errors.append(f"Invalid embedding.dim: {embedding['dim']}, must be positive integer")
    # Importance (optional)
    # NOTE(review): isinstance(..., (int, float)) also admits bool -- confirm
    # whether importance=True should be rejected.
    importance = payload.get("importance")
    if importance is not None and (not isinstance(importance, (int, float)) or importance < 0 or importance > 1):
        errors.append(f"Invalid importance: {importance}, must be 0-1")
    # TTL days (optional)
    ttl_days = payload.get("ttl_days")
    if ttl_days is not None and (not isinstance(ttl_days, int) or ttl_days < 1):
        errors.append(f"Invalid ttl_days: {ttl_days}, must be positive integer")
    # ACL fields (must be arrays of non-empty strings, no nulls)
    acl = payload.get("acl", {})
    if isinstance(acl, dict):
        for acl_field in ["read_team_ids", "read_agent_ids", "read_role_ids"]:
            value = acl.get(acl_field)
            if value is not None:
                if not isinstance(value, list):
                    errors.append(f"Invalid acl.{acl_field}: must be array, got {type(value).__name__}")
                elif not all(isinstance(item, str) and item for item in value):
                    # Check: all items must be non-empty strings (no None, no "")
                    errors.append(f"Invalid acl.{acl_field}: all items must be non-empty strings (no null/empty)")
    elif acl is not None:
        # Reached when acl is present but not a dict; an explicit acl=None
        # is deliberately accepted (treated as "no ACL").
        errors.append(f"Invalid acl: must be object, got {type(acl).__name__}")
    # Tags must be array of non-empty strings
    tags = payload.get("tags")
    if tags is not None:
        if not isinstance(tags, list):
            errors.append(f"Invalid tags: must be array, got {type(tags).__name__}")
        elif not all(isinstance(item, str) and item for item in tags):
            errors.append(f"Invalid tags: all items must be non-empty strings (no null/empty)")
    return errors
def create_minimal_payload(
    tenant_id: str,
    source_id: str,
    chunk_id: str,
    chunk_idx: int,
    fingerprint: str,
    scope: str = "docs",
    visibility: str = "confidential",
    owner_kind: str = "team",
    owner_id: Optional[str] = None,
    agent_id: Optional[str] = None,
    team_id: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Create a minimal valid payload with required fields.

    Args:
        tenant_id: Tenant identifier (``t_*``).
        source_id: Source identifier (``doc_*``/``msg_*``/...).
        chunk_id: Chunk identifier (``chk_*``).
        chunk_idx: Zero-based chunk index within the source.
        fingerprint: Content fingerprint (e.g. ``sha256:...``).
        scope: One of VALID_SCOPES (default "docs").
        visibility: One of VALID_VISIBILITY (default "confidential").
        owner_kind: One of VALID_OWNER_KINDS (default "team").
        owner_id: Defaults to team_id, then tenant_id, when not given.
        agent_id: Optional agent identifier (kept as None when unset).
        team_id: Optional team identifier (kept as None when unset).
        **kwargs: Extra payload fields merged in last. WARNING: these can
            overwrite the generated required fields; pass only optional keys.

    Returns:
        A payload dict that passes validation.
    """
    # Local import: module top only imports `datetime` from datetime.
    from datetime import timezone

    payload = {
        "schema_version": "cm_payload_v1",
        "tenant_id": tenant_id,
        "team_id": team_id,
        "agent_id": agent_id,
        "owner_kind": owner_kind,
        "owner_id": owner_id or team_id or tenant_id,
        "scope": scope,
        "visibility": visibility,
        "indexed": True,
        "source_kind": "document",
        "source_id": source_id,
        "chunk": {
            "chunk_id": chunk_id,
            "chunk_idx": chunk_idx,
        },
        "fingerprint": fingerprint,
        # BUG FIX: datetime.utcnow() is deprecated (naive UTC). Use a
        # timezone-aware now and rewrite the "+00:00" offset to the same
        # "...Z" suffix the original produced.
        "created_at": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"),
    }
    # Add optional fields
    payload.update(kwargs)
    return payload

View File

@@ -0,0 +1 @@
# Tests for Co-Memory Qdrant module

View File

@@ -0,0 +1,446 @@
"""
Unit tests for Co-Memory Qdrant filters.
Tests security invariants:
- Empty should never returned (would match everything)
- Private content only accessible by owner
- tenant_id always in must conditions
"""
import pytest
from services.memory.qdrant.filters import (
AccessContext,
FilterSecurityError,
build_qdrant_filter,
build_agent_only_filter,
build_multi_agent_filter,
build_project_filter,
build_tag_filter,
)
class TestAccessContext:
    """Tests for AccessContext dataclass.

    Verifies constructor defaults and that all identity fields round-trip.
    """
    def test_minimal_context(self):
        """Test creating minimal access context."""
        # Only tenant_id is mandatory; other identity fields default off.
        ctx = AccessContext(tenant_id="t_daarion")
        assert ctx.tenant_id == "t_daarion"
        assert ctx.team_id is None
        assert ctx.agent_id is None
        assert ctx.is_admin is False
    def test_full_context(self):
        """Test creating full access context."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            project_id="proj_helion",
            agent_id="agt_helion",
            user_id="user_123",
            role_ids=["role_admin"],
            allowed_agent_ids=["agt_nutra", "agt_druid"],
            is_admin=True,
        )
        assert ctx.tenant_id == "t_daarion"
        assert ctx.team_id == "team_core"
        assert ctx.agent_id == "agt_helion"
        assert len(ctx.allowed_agent_ids) == 2
class TestBuildQdrantFilter:
    """Tests for build_qdrant_filter function.

    Covers identity requirements, admin bypass, and the scope/indexed/
    source_id condition builders.
    """
    def test_basic_filter_with_identity(self):
        """Test basic filter with tenant and identity."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")
        result = build_qdrant_filter(ctx)
        assert "must" in result
        # Check tenant_id is in must conditions
        tenant_condition = next(
            (c for c in result["must"] if c.get("key") == "tenant_id"),
            None
        )
        assert tenant_condition is not None
        assert tenant_condition["match"]["value"] == "t_daarion"
    def test_empty_context_raises_security_error(self):
        """Test that context without identity raises FilterSecurityError."""
        ctx = AccessContext(tenant_id="t_daarion")
        # No team_id, agent_id, user_id, or allowed_agent_ids
        with pytest.raises(FilterSecurityError) as exc_info:
            build_qdrant_filter(ctx)
        assert "empty" in str(exc_info.value).lower()
    def test_admin_bypasses_identity_check(self):
        """Test that admin context doesn't require identity."""
        ctx = AccessContext(tenant_id="t_daarion", is_admin=True)
        # Should not raise
        result = build_qdrant_filter(ctx)
        assert "must" in result
    def test_indexed_only_filter(self):
        """Test that indexed=true is included by default."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")
        result = build_qdrant_filter(ctx, indexed_only=True)
        indexed_condition = next(
            (c for c in result["must"] if c.get("key") == "indexed"),
            None
        )
        assert indexed_condition is not None
        assert indexed_condition["match"]["value"] is True
    def test_scope_filter(self):
        """Test single scope filter."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")
        result = build_qdrant_filter(ctx, scope="docs")
        scope_condition = next(
            (c for c in result["must"] if c.get("key") == "scope"),
            None
        )
        assert scope_condition is not None
        assert scope_condition["match"]["value"] == "docs"
    def test_multiple_scopes_filter(self):
        """Test multiple scopes filter."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")
        result = build_qdrant_filter(ctx, scopes=["docs", "messages"])
        # Multiple scopes use an "any" (OR) match instead of "value".
        scope_condition = next(
            (c for c in result["must"] if c.get("key") == "scope"),
            None
        )
        assert scope_condition is not None
        assert scope_condition["match"]["any"] == ["docs", "messages"]
    def test_source_id_filter(self):
        """Test source_id filter."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")
        result = build_qdrant_filter(ctx, source_id="doc_123")
        source_condition = next(
            (c for c in result["must"] if c.get("key") == "source_id"),
            None
        )
        assert source_condition is not None
        assert source_condition["match"]["value"] == "doc_123"
    def test_admin_access(self):
        """Test admin has broader access."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            is_admin=True,
        )
        result = build_qdrant_filter(ctx, visibility="private")
        # Admin should get simpler filter
        assert "must" in result
class TestSecurityInvariants:
    """Tests for security invariants in filters.

    These pin the module's hard guarantees: private content is excluded
    from shared filters, tenant_id is always present, and unauthorized
    agent access is rejected.
    """
    def test_private_not_in_multi_agent_filter(self):
        """Test that multi-agent filter ALWAYS excludes private."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            allowed_agent_ids=["agt_helion", "agt_nutra"],  # Must have permission
        )
        result = build_multi_agent_filter(
            ctx,
            agent_ids=["agt_helion", "agt_nutra"],
        )
        # Must have must_not with private
        assert "must_not" in result
        private_exclusion = next(
            (c for c in result["must_not"]
             if c.get("key") == "visibility" and c.get("match", {}).get("value") == "private"),
            None
        )
        assert private_exclusion is not None
    def test_empty_agent_ids_raises_error(self):
        """Test that empty agent_ids raises error."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            allowed_agent_ids=["agt_helion"]
        )
        with pytest.raises(FilterSecurityError):
            build_multi_agent_filter(ctx, agent_ids=[])
    def test_unauthorized_agent_ids_raises_error(self):
        """Test that requesting unauthorized agent_ids raises error."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            allowed_agent_ids=["agt_helion", "agt_nutra"]
        )
        # Try to access agent not in allowed list
        with pytest.raises(FilterSecurityError) as exc_info:
            build_multi_agent_filter(ctx, agent_ids=["agt_helion", "agt_druid"])
        assert "agt_druid" in str(exc_info.value)
        assert "Unauthorized" in str(exc_info.value)
    def test_multi_agent_requires_allowed_list_or_admin(self):
        """Test that multi-agent filter requires allowed_agent_ids or admin."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            # No allowed_agent_ids, not admin
        )
        with pytest.raises(FilterSecurityError) as exc_info:
            build_multi_agent_filter(ctx, agent_ids=["agt_helion"])
        assert "allowed_agent_ids" in str(exc_info.value)
    def test_admin_can_access_any_agents(self):
        """Test that admin can access any agent without allowed list."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            is_admin=True,
            # No allowed_agent_ids needed for admin
        )
        # Should not raise
        result = build_multi_agent_filter(ctx, agent_ids=["agt_helion", "agt_druid", "agt_nutra"])
        assert "must" in result
        assert "should" in result
    def test_empty_agent_id_for_agent_only_raises_error(self):
        """Test that empty agent_id raises error for agent_only filter."""
        ctx = AccessContext(tenant_id="t_daarion", team_id="team_core")
        with pytest.raises(FilterSecurityError):
            build_agent_only_filter(ctx, agent_id="")
    def test_tenant_id_always_present(self):
        """Test that tenant_id is always in must conditions."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            agent_id="agt_helion",
        )
        # Test all filter builders
        filters = [
            build_qdrant_filter(ctx),
            build_agent_only_filter(ctx, agent_id="agt_helion"),
            build_multi_agent_filter(ctx, agent_ids=["agt_helion"]),
            build_project_filter(ctx, project_id="proj_helion"),
            build_tag_filter(ctx, tags=["test"]),
        ]
        for f in filters:
            tenant_condition = next(
                (c for c in f["must"] if c.get("key") == "tenant_id"),
                None
            )
            assert tenant_condition is not None, f"tenant_id missing in {f}"
    def test_private_only_for_owner(self):
        """Test that private is only accessible when owner matches."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            agent_id="agt_helion",
        )
        result = build_agent_only_filter(ctx, agent_id="agt_helion")
        # Should have owner check with private
        should_conditions = result.get("should", [])
        # Find the condition that allows private
        # (a nested must combining visibility=private AND owner_id=self).
        private_allowed = False
        for cond in should_conditions:
            must = cond.get("must", [])
            has_private = any(
                c.get("key") == "visibility" and "private" in str(c.get("match", {}))
                for c in must
            )
            has_owner_check = any(
                c.get("key") == "owner_id" and c.get("match", {}).get("value") == "agt_helion"
                for c in must
            )
            if has_private and has_owner_check:
                private_allowed = True
                break
        assert private_allowed, "Private should only be allowed with owner check"
class TestBuildAgentOnlyFilter:
    """Tests for build_agent_only_filter function.

    Checks the own-content clause and the optional team-public clause.
    """
    def test_agent_own_content(self):
        """Test filter for agent's own content."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
        )
        result = build_agent_only_filter(ctx, agent_id="agt_helion")
        assert "must" in result
        assert "should" in result
        assert result.get("minimum_should_match") == 1
        # Check own content condition exists
        own_content = result["should"][0]
        assert "must" in own_content
    def test_agent_with_team_public(self):
        """Test filter includes team public when enabled."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
        )
        result = build_agent_only_filter(
            ctx,
            agent_id="agt_helion",
            include_team_public=True,
        )
        # Should have 2 should conditions: own + team public
        assert len(result["should"]) == 2
    def test_agent_without_team_public(self):
        """Test filter excludes team public when disabled."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
        )
        result = build_agent_only_filter(
            ctx,
            agent_id="agt_helion",
            include_team_public=False,
        )
        # Should have only 1 should condition: own
        assert len(result["should"]) == 1
class TestBuildMultiAgentFilter:
    """Tests for build_multi_agent_filter function.

    Verifies the OR-over-agents condition shape and the unconditional
    exclusion of private content.
    """
    def test_multi_agent_access(self):
        """Test filter for accessing multiple agents."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            allowed_agent_ids=["agt_helion", "agt_nutra", "agt_druid"],  # Must have permission
        )
        result = build_multi_agent_filter(
            ctx,
            agent_ids=["agt_helion", "agt_nutra", "agt_druid"],
        )
        assert "must" in result
        assert "should" in result
        # Check agent_ids are in any match
        agent_condition = result["should"][0]["must"][0]
        assert agent_condition["key"] == "agent_id"
        assert "any" in agent_condition["match"]
        assert len(agent_condition["match"]["any"]) == 3
    def test_multi_agent_excludes_private(self):
        """Test filter excludes private always (no parameter to override)."""
        ctx = AccessContext(
            tenant_id="t_daarion",
            team_id="team_core",
            allowed_agent_ids=["agt_helion"],
        )
        result = build_multi_agent_filter(
            ctx,
            agent_ids=["agt_helion"],
        )
        # must_not is always present with private exclusion
        assert "must_not" in result
        private_exclusion = result["must_not"][0]
        assert private_exclusion["key"] == "visibility"
        assert private_exclusion["match"]["value"] == "private"
class TestBuildProjectFilter:
    """Tests for build_project_filter function.

    Project filters scope to one project and always exclude private.
    """
    def test_project_scoped(self):
        """Test project-scoped filter."""
        ctx = AccessContext(tenant_id="t_daarion")
        result = build_project_filter(ctx, project_id="proj_helion")
        project_condition = next(
            (c for c in result["must"] if c.get("key") == "project_id"),
            None
        )
        assert project_condition is not None
        assert project_condition["match"]["value"] == "proj_helion"
    def test_project_excludes_private(self):
        """Test project filter excludes private."""
        ctx = AccessContext(tenant_id="t_daarion")
        result = build_project_filter(ctx, project_id="proj_helion")
        assert "must_not" in result
class TestBuildTagFilter:
    """Tests for build_tag_filter function.

    Multiple tags use AND semantics: one must condition per tag.
    """
    def test_single_tag(self):
        """Test single tag filter."""
        ctx = AccessContext(tenant_id="t_daarion")
        result = build_tag_filter(ctx, tags=["legal_kb"])
        tag_condition = next(
            (c for c in result["must"] if c.get("key") == "tags"),
            None
        )
        assert tag_condition is not None
        assert tag_condition["match"]["value"] == "legal_kb"
    def test_multiple_tags(self):
        """Test multiple tags filter (all must match)."""
        ctx = AccessContext(tenant_id="t_daarion")
        result = build_tag_filter(ctx, tags=["legal_kb", "contracts"])
        # All tags should be in must conditions
        tag_conditions = [
            c for c in result["must"] if c.get("key") == "tags"
        ]
        assert len(tag_conditions) == 2
# Allow running this test module directly (verbose pytest run).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])

View File

@@ -0,0 +1,282 @@
"""
Unit tests for Co-Memory payload validation.
"""
import pytest
from datetime import datetime
from services.memory.qdrant.payload_validation import (
validate_payload,
PayloadValidationError,
create_minimal_payload,
)
class TestPayloadValidation:
    """Tests for payload validation.

    Exercises validate_payload()/create_minimal_payload() across valid
    payloads, each required/format check, and strict vs non-strict modes.
    """
    def test_valid_minimal_payload(self):
        """Test that minimal valid payload passes validation."""
        payload = create_minimal_payload(
            tenant_id="t_daarion",
            source_id="doc_01HQ8K9X2NPQR3FGJKLM5678",
            chunk_id="chk_01HQ8K9X3MPQR3FGJKLM9012",
            chunk_idx=0,
            fingerprint="sha256:abc123",
            team_id="team_core",
        )
        # Should not raise
        result = validate_payload(payload)
        assert result["schema_version"] == "cm_payload_v1"
        assert result["tenant_id"] == "t_daarion"
    def test_valid_full_payload(self):
        """Test that full payload with all fields passes validation."""
        payload = {
            "schema_version": "cm_payload_v1",
            "tenant_id": "t_daarion",
            "team_id": "team_core",
            "project_id": "proj_helion",
            "agent_id": "agt_helion",
            "owner_kind": "agent",
            "owner_id": "agt_helion",
            "scope": "docs",
            "visibility": "confidential",
            "indexed": True,
            "source_kind": "document",
            "source_id": "doc_01HQ8K9X2NPQR3FGJKLM5678",
            "chunk": {
                "chunk_id": "chk_01HQ8K9X3MPQR3FGJKLM9012",
                "chunk_idx": 0,
            },
            "fingerprint": "sha256:abc123def456",
            "created_at": "2026-01-26T12:00:00Z",
            "acl": {
                "read_team_ids": ["team_core"],
                "read_agent_ids": ["agt_nutra"],
            },
            "tags": ["product", "features"],
            "lang": "uk",
            "importance": 0.8,
            "embedding": {
                "model": "cohere-embed-v3",
                "dim": 1024,
                "metric": "cosine",
            }
        }
        result = validate_payload(payload)
        assert result["agent_id"] == "agt_helion"
        assert result["scope"] == "docs"
    def test_missing_required_field(self):
        """Test that missing required field raises error."""
        payload = {
            "schema_version": "cm_payload_v1",
            "tenant_id": "t_daarion",
            # Missing other required fields
        }
        with pytest.raises(PayloadValidationError) as exc_info:
            validate_payload(payload)
        assert "Missing required field" in str(exc_info.value)
    def test_invalid_schema_version(self):
        """Test that invalid schema version raises error."""
        payload = create_minimal_payload(
            tenant_id="t_daarion",
            source_id="doc_01HQ8K9X2NPQR3FGJKLM5678",
            chunk_id="chk_01HQ8K9X3MPQR3FGJKLM9012",
            chunk_idx=0,
            fingerprint="sha256:abc123",
        )
        payload["schema_version"] = "v2"  # Invalid
        with pytest.raises(PayloadValidationError) as exc_info:
            validate_payload(payload)
        assert "schema_version" in str(exc_info.value)
    def test_invalid_tenant_id_format(self):
        """Test that invalid tenant_id format raises error."""
        payload = create_minimal_payload(
            tenant_id="invalid-tenant",  # Wrong format
            source_id="doc_01HQ8K9X2NPQR3FGJKLM5678",
            chunk_id="chk_01HQ8K9X3MPQR3FGJKLM9012",
            chunk_idx=0,
            fingerprint="sha256:abc123",
        )
        with pytest.raises(PayloadValidationError) as exc_info:
            validate_payload(payload)
        assert "tenant_id" in str(exc_info.value)
    def test_invalid_agent_id_format(self):
        """Test that invalid agent_id format raises error."""
        payload = create_minimal_payload(
            tenant_id="t_daarion",
            source_id="doc_01HQ8K9X2NPQR3FGJKLM5678",
            chunk_id="chk_01HQ8K9X3MPQR3FGJKLM9012",
            chunk_idx=0,
            fingerprint="sha256:abc123",
        )
        payload["agent_id"] = "helion"  # Missing prefix
        with pytest.raises(PayloadValidationError) as exc_info:
            validate_payload(payload)
        assert "agent_id" in str(exc_info.value)
    def test_invalid_scope(self):
        """Test that invalid scope raises error."""
        payload = create_minimal_payload(
            tenant_id="t_daarion",
            source_id="doc_01HQ8K9X2NPQR3FGJKLM5678",
            chunk_id="chk_01HQ8K9X3MPQR3FGJKLM9012",
            chunk_idx=0,
            fingerprint="sha256:abc123",
            scope="invalid_scope",
        )
        with pytest.raises(PayloadValidationError) as exc_info:
            validate_payload(payload)
        assert "scope" in str(exc_info.value)
    def test_invalid_visibility(self):
        """Test that invalid visibility raises error."""
        payload = create_minimal_payload(
            tenant_id="t_daarion",
            source_id="doc_01HQ8K9X2NPQR3FGJKLM5678",
            chunk_id="chk_01HQ8K9X3MPQR3FGJKLM9012",
            chunk_idx=0,
            fingerprint="sha256:abc123",
            visibility="secret",  # Invalid
        )
        with pytest.raises(PayloadValidationError) as exc_info:
            validate_payload(payload)
        assert "visibility" in str(exc_info.value)
    def test_invalid_importance_range(self):
        """Test that importance outside 0-1 raises error."""
        payload = create_minimal_payload(
            tenant_id="t_daarion",
            source_id="doc_01HQ8K9X2NPQR3FGJKLM5678",
            chunk_id="chk_01HQ8K9X3MPQR3FGJKLM9012",
            chunk_idx=0,
            fingerprint="sha256:abc123",
        )
        payload["importance"] = 1.5  # Invalid
        with pytest.raises(PayloadValidationError) as exc_info:
            validate_payload(payload)
        assert "importance" in str(exc_info.value)
    def test_invalid_chunk_idx_negative(self):
        """Test that negative chunk_idx raises error."""
        payload = create_minimal_payload(
            tenant_id="t_daarion",
            source_id="doc_01HQ8K9X2NPQR3FGJKLM5678",
            chunk_id="chk_01HQ8K9X3MPQR3FGJKLM9012",
            chunk_idx=-1,  # Invalid
            fingerprint="sha256:abc123",
        )
        with pytest.raises(PayloadValidationError) as exc_info:
            validate_payload(payload)
        assert "chunk_idx" in str(exc_info.value)
    def test_non_strict_mode(self):
        """Test that non-strict mode returns payload with errors."""
        payload = {
            "schema_version": "cm_payload_v1",
            "tenant_id": "invalid",  # Invalid format
        }
        # Should not raise in non-strict mode
        result = validate_payload(payload, strict=False)
        assert result["schema_version"] == "cm_payload_v1"
    def test_all_valid_scopes(self):
        """Test that all valid scopes pass validation."""
        # Keep in sync with VALID_SCOPES in payload_validation.
        valid_scopes = ["docs", "messages", "memory", "artifacts", "signals"]
        for scope in valid_scopes:
            payload = create_minimal_payload(
                tenant_id="t_daarion",
                source_id="doc_01HQ8K9X2NPQR3FGJKLM5678",
                chunk_id="chk_01HQ8K9X3MPQR3FGJKLM9012",
                chunk_idx=0,
                fingerprint="sha256:abc123",
                scope=scope,
            )
            result = validate_payload(payload)
            assert result["scope"] == scope
    def test_all_valid_visibilities(self):
        """Test that all valid visibilities pass validation."""
        # Keep in sync with VALID_VISIBILITY in payload_validation.
        valid_visibilities = ["public", "confidential", "private"]
        for visibility in valid_visibilities:
            payload = create_minimal_payload(
                tenant_id="t_daarion",
                source_id="doc_01HQ8K9X2NPQR3FGJKLM5678",
                chunk_id="chk_01HQ8K9X3MPQR3FGJKLM9012",
                chunk_idx=0,
                fingerprint="sha256:abc123",
                visibility=visibility,
            )
            result = validate_payload(payload)
            assert result["visibility"] == visibility
class TestCollectionNameMapping:
    """Tests for legacy collection name to payload mapping.

    Imports parse_collection_name lazily inside each test so collecting
    this module does not require the migration script on sys.path.
    """
    def test_parse_docs_collection(self):
        """Test parsing *_docs collection names."""
        from scripts.qdrant_migrate_to_canonical import parse_collection_name
        result = parse_collection_name("helion_docs")
        assert result is not None
        assert result["agent_id"] == "agt_helion"
        assert result["scope"] == "docs"
        assert result["tags"] == []
    def test_parse_messages_collection(self):
        """Test parsing *_messages collection names."""
        from scripts.qdrant_migrate_to_canonical import parse_collection_name
        result = parse_collection_name("nutra_messages")
        assert result is not None
        assert result["agent_id"] == "agt_nutra"
        assert result["scope"] == "messages"
    def test_parse_special_kb_collection(self):
        """Test parsing special knowledge base collections."""
        from scripts.qdrant_migrate_to_canonical import parse_collection_name
        result = parse_collection_name("druid_legal_kb")
        assert result is not None
        assert result["agent_id"] == "agt_druid"
        assert result["scope"] == "docs"
        assert "legal_kb" in result["tags"]
    def test_parse_unknown_collection(self):
        """Test parsing an unknown collection does not crash."""
        from scripts.qdrant_migrate_to_canonical import parse_collection_name
        result = parse_collection_name("random_collection_xyz")
        # BUG FIX: the original test made no assertion at all, so it could
        # never fail. Without over-specifying (the name may match a generic
        # pattern or be unmapped), pin the contract: no exception, and the
        # result is either None or a mapping dict.
        assert result is None or isinstance(result, dict)
# Allow running this test module directly (verbose pytest run).
if __name__ == "__main__":
    pytest.main([__file__, "-v"])