- TTS: xtts-v2 integration with voice cloning support
- Document: docling integration for PDF/DOCX/PPTX processing
- Memory Service: added /facts/upsert, /facts/{key}, /facts endpoints
- Added required dependencies (TTS, docling)
104 lines · 2.9 KiB · Python
"""
|
|
DAARION Memory Service - Embedding Layer (Cohere)
|
|
"""
import asyncio
from typing import List

import cohere
import structlog
from tenacity import retry, stop_after_attempt, wait_exponential

from .config import get_settings
logger = structlog.get_logger()
settings = get_settings()

# The Cohere client is created lazily on first use; see get_cohere_client().
_cohere_client = None
def get_cohere_client():
    """Return the shared Cohere client, creating it on first use.

    A failed construction caches the sentinel ``False`` so the constructor
    is not retried on every call. Callers always receive either a usable
    client or ``None`` (missing API key, or a previous init failure).
    """
    global _cohere_client

    if _cohere_client is None and settings.cohere_api_key:
        try:
            client = cohere.Client(settings.cohere_api_key)
        except Exception as exc:
            logger.warning("cohere_client_init_failed", error=str(exc))
            _cohere_client = False  # sentinel: do not retry construction
        else:
            _cohere_client = client
            logger.info("cohere_client_initialized")

    return _cohere_client or None
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=1, max=10)
)
async def get_embeddings(
    texts: List[str],
    input_type: str = "search_document"
) -> List[List[float]]:
    """
    Get embeddings from Cohere API.

    Args:
        texts: List of texts to embed
        input_type: "search_document" for indexing, "search_query" for queries

    Returns:
        List of embedding vectors (1024 dimensions for embed-multilingual-v3.0).
        Empty list for empty input; one empty vector per text when the
        Cohere client is not configured.
    """
    if not texts:
        return []

    co_client = get_cohere_client()
    if not co_client:
        logger.warning("cohere_not_configured", message="Cohere API key not set, returning empty embeddings")
        return [[] for _ in texts]

    logger.info("generating_embeddings", count=len(texts), input_type=input_type)

    # cohere.Client.embed is a blocking HTTP call; run it in a worker thread
    # so the event loop is not stalled for the duration of the request.
    response = await asyncio.to_thread(
        co_client.embed,
        texts=texts,
        model=settings.cohere_model,
        input_type=input_type,
        truncate="END",
    )

    embeddings = response.embeddings

    logger.info(
        "embeddings_generated",
        count=len(embeddings),
        dimensions=len(embeddings[0]) if embeddings else 0
    )

    return embeddings
async def get_query_embedding(query: str) -> List[float]:
    """Embed a single search query; returns an empty vector on failure."""
    vectors = await get_embeddings([query], input_type="search_query")
    if not vectors:
        return []
    return vectors[0]
async def get_document_embeddings(texts: List[str]) -> List[List[float]]:
    """Embed document-side texts (memories, summaries) for indexing."""
    return await get_embeddings(texts, input_type="search_document")
# Batch processing for large sets
async def batch_embed(
    texts: List[str],
    input_type: str = "search_document",
    batch_size: int = 96  # Cohere limit
) -> List[List[float]]:
    """
    Embed an arbitrarily large list of texts by splitting it into
    chunks of at most ``batch_size`` and embedding each chunk in turn.
    """
    results: List[List[float]] = []

    start = 0
    while start < len(texts):
        chunk = texts[start:start + batch_size]
        results.extend(await get_embeddings(chunk, input_type))
        start += batch_size

    return results