Files
microdao-daarion/services/memory-service/app/main.py
Apple 5290287058 feat: implement TTS, Document processing, and Memory Service /facts API
- TTS: xtts-v2 integration with voice cloning support
- Document: docling integration for PDF/DOCX/PPTX processing
- Memory Service: added /facts/upsert, /facts/{key}, /facts endpoints
- Added required dependencies (TTS, docling)
2026-01-17 08:16:37 -08:00

592 lines
19 KiB
Python

"""
DAARION Memory Service - FastAPI Application
Three-tier agent memory:
- Short-term: conversation events (working buffer)
- Mid-term: thread summaries (session/topic level)
- Long-term: memory items (personal/project level)
"""
from contextlib import asynccontextmanager
from typing import List, Optional
from fastapi import Depends
from uuid import UUID
import structlog
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from .config import get_settings
from .models import (
CreateThreadRequest, AddEventRequest, CreateMemoryRequest,
MemoryFeedbackRequest, RetrievalRequest, SummaryRequest,
ThreadResponse, EventResponse, MemoryResponse,
SummaryResponse, RetrievalResponse, RetrievalResult,
ContextResponse, MemoryCategory, FeedbackAction
)
from .vector_store import vector_store
from .database import db
from .auth import get_current_service, get_current_service_optional
# Shared module-level structlog logger used by the endpoints in this module.
logger = structlog.get_logger()
# Service settings, obtained once at import time from .config.get_settings().
settings = get_settings()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage service-wide resources for the lifetime of the FastAPI app.

    Startup connects the database and then initializes the vector store.
    Shutdown disconnects the database.

    The teardown runs in a ``finally`` clause so the database connection is
    released even if vector-store initialization fails, or if an exception
    (e.g. cancellation) is thrown into this context manager while the app is
    running.
    """
    logger.info("starting_memory_service")
    await db.connect()
    try:
        await vector_store.initialize()
        yield
    finally:
        # Always release the DB connection opened above.
        await db.disconnect()
        logger.info("memory_service_stopped")
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS setup; the Starlette CORSMiddleware docs advise against
# wildcards when credentials are allowed. Confirm this is intended outside dev.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

# The ASGI application; `lifespan` wires up DB / vector-store startup and shutdown.
app = FastAPI(
    lifespan=lifespan,
    title="DAARION Memory Service",
    version="1.0.0",
    description="Agent memory management with PostgreSQL + Qdrant + Cohere",
)
app.add_middleware(CORSMiddleware, **_cors_options)
# ============================================================================
# HEALTH
# ============================================================================
@app.get("/health")
async def health():
    """Health check.

    Returns a static "healthy" status, the configured service name, and the
    current vector-store collection statistics.
    """
    service_name = settings.service_name
    collection_stats = await vector_store.get_collection_stats()
    return dict(
        status="healthy",
        service=service_name,
        vector_store=collection_stats,
    )
# ============================================================================
# THREADS (Conversations)
# ============================================================================
@app.post("/threads", response_model=ThreadResponse)
async def create_thread(
    request: CreateThreadRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Create a new conversation thread from the request fields."""
    # Auth is optional: if a JWT is provided it is checked by
    # get_current_service_optional; if not, the request is allowed (dev mode).
    thread_fields = (
        "org_id",
        "user_id",
        "workspace_id",
        "agent_id",
        "title",
        "tags",
        "metadata",
    )
    created = await db.create_thread(
        **{field: getattr(request, field) for field in thread_fields}
    )
    return created
@app.get("/threads/{thread_id}", response_model=ThreadResponse)
async def get_thread(thread_id: UUID):
    """Return a single thread, or 404 when `db.get_thread` finds nothing."""
    found = await db.get_thread(thread_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Thread not found")
@app.get("/threads", response_model=List[ThreadResponse])
async def list_threads(
    user_id: UUID = Query(...),
    org_id: UUID = Query(...),
    workspace_id: Optional[UUID] = None,
    agent_id: Optional[UUID] = None,
    limit: int = Query(default=20, ge=1, le=100)
):
    """List threads for a user within an organization.

    Args:
        user_id: Owner of the threads (required).
        org_id: Organization scope (required).
        workspace_id: Optional workspace filter, forwarded to the DB layer.
        agent_id: Optional agent filter, forwarded to the DB layer.
        limit: Maximum number of threads to return, 1..100. The lower bound
            makes FastAPI reject non-positive values with a 422 instead of
            passing them to the database.
    """
    threads = await db.list_threads(
        org_id=org_id,
        user_id=user_id,
        workspace_id=workspace_id,
        agent_id=agent_id,
        limit=limit
    )
    return threads
# ============================================================================
# EVENTS (Short-term Memory)
# ============================================================================
@app.post("/events", response_model=EventResponse)
async def add_event(
    request: AddEventRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Append an event (message, tool call, etc.) to a conversation thread.

    Every field of the request is forwarded unchanged to `db.add_event`.
    """
    event_fields = (
        "thread_id",
        "event_type",
        "role",
        "content",
        "tool_name",
        "tool_input",
        "tool_output",
        "payload",
        "token_count",
        "model_used",
        "latency_ms",
        "metadata",
    )
    stored = await db.add_event(
        **{field: getattr(request, field) for field in event_fields}
    )
    return stored
@app.get("/threads/{thread_id}/events", response_model=List[EventResponse])
async def get_events(
    thread_id: UUID,
    limit: int = Query(default=50, ge=1, le=200),
    offset: int = Query(default=0, ge=0)
):
    """Get a page of events for a thread (most recent first, per `db.get_events`).

    Args:
        thread_id: Thread whose events are returned.
        limit: Page size, 1..200.
        offset: Number of events to skip; must be >= 0.

    The lower bounds make FastAPI reject negative paging values with a 422
    instead of passing them to the database.
    """
    events = await db.get_events(thread_id, limit=limit, offset=offset)
    return events
# ============================================================================
# MEMORIES (Long-term Memory)
# ============================================================================
@app.post("/memories", response_model=MemoryResponse)
async def create_memory(
    request: CreateMemoryRequest,
    service: Optional[dict] = Depends(get_current_service_optional)
):
    """Create a long-term memory item.

    The memory is written to the database first, then indexed in the vector
    store. The resulting point id is stored back on the DB row.
    """
    # 1) Persist the memory row.
    memory = await db.create_memory(
        org_id=request.org_id,
        user_id=request.user_id,
        workspace_id=request.workspace_id,
        agent_id=request.agent_id,
        category=request.category,
        fact_text=request.fact_text,
        confidence=request.confidence,
        source_event_id=request.source_event_id,
        source_thread_id=request.source_thread_id,
        extraction_method=request.extraction_method,
        is_sensitive=request.is_sensitive,
        retention=request.retention,
        ttl_days=request.ttl_days,
        tags=request.tags,
        metadata=request.metadata
    )
    new_memory_id = memory["memory_id"]

    # 2) Index the fact text for semantic search.
    # NOTE(review): if indexing fails, the DB row from step 1 remains without
    # an embedding; there is no compensation here.
    embedding_point = await vector_store.index_memory(
        memory_id=new_memory_id,
        text=request.fact_text,
        org_id=request.org_id,
        user_id=request.user_id,
        category=request.category,
        agent_id=request.agent_id,
        workspace_id=request.workspace_id,
        thread_id=request.source_thread_id
    )

    # 3) Link the DB row to its vector-store point.
    # NOTE(review): `memory` was read before this update, so any embedding id
    # field in the response may be stale — confirm against MemoryResponse.
    await db.update_memory_embedding_id(new_memory_id, embedding_point)
    return memory
@app.get("/memories/{memory_id}", response_model=MemoryResponse)
async def get_memory(memory_id: UUID):
    """Return a single memory item, or 404 when `db.get_memory` finds nothing."""
    found = await db.get_memory(memory_id)
    if found:
        return found
    raise HTTPException(status_code=404, detail="Memory not found")
@app.get("/memories", response_model=List[MemoryResponse])
async def list_memories(
    user_id: UUID = Query(...),
    org_id: UUID = Query(...),
    agent_id: Optional[UUID] = None,
    workspace_id: Optional[UUID] = None,
    category: Optional[MemoryCategory] = None,
    include_global: bool = True,
    limit: int = Query(default=50, ge=1, le=200)
):
    """List long-term memories for a user.

    Args:
        user_id: Owner of the memories (required).
        org_id: Organization scope (required).
        agent_id: Optional agent filter, forwarded to the DB layer.
        workspace_id: Optional workspace filter, forwarded to the DB layer.
        category: Optional category filter.
        include_global: Forwarded to `db.list_memories` unchanged.
        limit: Maximum number of items, 1..200. The lower bound makes FastAPI
            reject non-positive values with a 422 instead of passing them to
            the database.
    """
    memories = await db.list_memories(
        org_id=org_id,
        user_id=user_id,
        agent_id=agent_id,
        workspace_id=workspace_id,
        category=category,
        include_global=include_global,
        limit=limit
    )
    return memories
@app.post("/memories/{memory_id}/feedback")
async def memory_feedback(memory_id: UUID, request: MemoryFeedbackRequest):
    """Apply user feedback (confirm/reject/edit/delete) to a long-term memory.

    The feedback is recorded via `db.add_memory_feedback`, then the action is
    applied:

    - CONFIRM: confidence += ``settings.memory_confirm_boost`` (capped at 1.0),
      and the memory is marked verified.
    - REJECT: confidence -= ``settings.memory_reject_penalty`` (floored at 0.0).
      If the result is below ``settings.memory_min_confidence``, the memory is
      invalidated and removed from the vector store.
    - EDIT: the fact text is replaced and re-indexed in the vector store.
      Re-indexing uses the same payload fields as ``create_memory``, including
      the source thread, and the new point id is stored on the DB row.
    - DELETE: the memory is invalidated and removed from the vector store.

    Raises:
        HTTPException(404): the memory does not exist.
        HTTPException(400): EDIT was requested without a ``new_value``.
    """
    memory = await db.get_memory(memory_id)
    if not memory:
        raise HTTPException(status_code=404, detail="Memory not found")

    # Validate before recording feedback, so a no-op EDIT is not stored and
    # reported as applied.
    if request.action == FeedbackAction.EDIT and not request.new_value:
        raise HTTPException(status_code=400, detail="new_value is required for edit")

    # Record feedback (audit trail of old/new values).
    await db.add_memory_feedback(
        memory_id=memory_id,
        user_id=request.user_id,
        action=request.action,
        old_value=memory["fact_text"],
        new_value=request.new_value,
        reason=request.reason
    )

    # Apply action
    if request.action == FeedbackAction.CONFIRM:
        new_confidence = min(1.0, memory["confidence"] + settings.memory_confirm_boost)
        await db.update_memory_confidence(memory_id, new_confidence, verified=True)
    elif request.action == FeedbackAction.REJECT:
        new_confidence = max(0.0, memory["confidence"] - settings.memory_reject_penalty)
        if new_confidence < settings.memory_min_confidence:
            # Too unreliable to keep: mark invalid and drop it from semantic search.
            await db.invalidate_memory(memory_id)
            await vector_store.delete_memory(memory_id)
        else:
            await db.update_memory_confidence(memory_id, new_confidence)
    elif request.action == FeedbackAction.EDIT:
        await db.update_memory_text(memory_id, request.new_value)
        # Re-index with the new text. Keep the thread link that create_memory
        # indexed, and store the new point id so the DB row stays in sync.
        await vector_store.delete_memory(memory_id)
        point_id = await vector_store.index_memory(
            memory_id=memory_id,
            text=request.new_value,
            org_id=memory["org_id"],
            user_id=memory["user_id"],
            category=memory["category"],
            agent_id=memory.get("agent_id"),
            workspace_id=memory.get("workspace_id"),
            thread_id=memory.get("source_thread_id")
        )
        await db.update_memory_embedding_id(memory_id, point_id)
    elif request.action == FeedbackAction.DELETE:
        await db.invalidate_memory(memory_id)
        await vector_store.delete_memory(memory_id)

    return {"status": "ok", "action": request.action.value}
# ============================================================================
# RETRIEVAL (Semantic Search)
# ============================================================================
@app.post("/retrieve", response_model=RetrievalResponse)
async def retrieve_memories(request: RetrievalRequest):
    """
    Semantic retrieval of relevant memories.

    1. Run one vector search per entry in ``request.queries``.
    2. Merge the hits by ``memory_id``. When several queries return the same
       memory, keep its highest relevance score.
    3. Load each unique memory from the database. Drop memories that are
       missing or whose confidence is below ``request.min_confidence``.
    4. Sort by relevance (descending) and truncate to ``request.top_k``.

    Usage statistics are incremented only for the memories actually returned,
    not for candidates discarded by the ``top_k`` cut.

    Returns:
        RetrievalResponse with the truncated results, the number of queries
        run, and ``total_results`` (count before truncation).
    """
    # memory_id (as returned by the vector store) -> best-scoring hit so far.
    # Dict insertion order preserves first-seen order for stable tie-breaking.
    best_hits = {}
    for query in request.queries:
        hits = await vector_store.search_memories(
            query=query,
            org_id=request.org_id,
            user_id=request.user_id,
            agent_id=request.agent_id,
            workspace_id=request.workspace_id,
            categories=request.categories,
            include_global=request.include_global,
            top_k=request.top_k
        )
        for hit in hits:
            memory_id = hit.get("memory_id")
            if not memory_id:
                continue
            previous = best_hits.get(memory_id)
            if previous is None or hit["score"] > previous["score"]:
                best_hits[memory_id] = hit

    all_results = []
    for memory_id, hit in best_hits.items():
        # Full memory from the DB is the source of truth for confidence.
        memory = await db.get_memory(UUID(memory_id))
        if memory and memory["confidence"] >= request.min_confidence:
            all_results.append(RetrievalResult(
                memory_id=UUID(memory_id),
                fact_text=hit["text"],
                category=MemoryCategory(hit["category"]),
                confidence=memory["confidence"],
                relevance_score=hit["score"],
                agent_id=UUID(hit["agent_id"]) if hit.get("agent_id") else None,
                is_global=hit.get("agent_id") is None
            ))

    # Sort by relevance, then keep only what the caller asked for.
    all_results.sort(key=lambda x: x.relevance_score, reverse=True)
    returned = all_results[:request.top_k]

    # Count usage only for memories that are actually handed back.
    for result in returned:
        await db.increment_memory_usage(result.memory_id)

    return RetrievalResponse(
        results=returned,
        query_count=len(request.queries),
        total_results=len(all_results)
    )
# ============================================================================
# SUMMARIES (Mid-term Memory)
# ============================================================================
@app.post("/threads/{thread_id}/summarize", response_model=SummaryResponse)
async def create_summary(thread_id: UUID, request: SummaryRequest):
    """
    Generate a rolling summary for a thread (mid-term memory).

    Unless ``request.force`` is set, the thread's ``total_tokens`` must reach
    ``settings.summary_trigger_tokens``; otherwise a 400 is returned. The
    summary text is currently a placeholder. It is stored via
    `db.create_summary` and indexed in the vector store.

    Raises:
        HTTPException(404): the thread does not exist.
        HTTPException(400): token count below threshold and not forced.
    """
    thread = await db.get_thread(thread_id)
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    if not request.force:
        token_total = thread["total_tokens"]
        threshold = settings.summary_trigger_tokens
        if token_total < threshold:
            raise HTTPException(
                status_code=400,
                detail=f"Token count ({token_total}) below threshold ({threshold})"
            )

    events = await db.get_events_for_summary(thread_id)
    event_count = len(events)

    # TODO: call an LLM to generate the summary; this is a placeholder.
    summary_text = f"Summary of {event_count} events. [Implement LLM summarization]"
    state_keys = ("goals", "decisions", "open_questions", "next_steps", "key_facts")
    state = {key: [] for key in state_keys}

    # Sequence range covered by this summary (0..0 when there are no events).
    if events:
        first_seq = events[0]["sequence_num"]
        last_seq = events[-1]["sequence_num"]
    else:
        first_seq = last_seq = 0

    summary = await db.create_summary(
        thread_id=thread_id,
        summary_text=summary_text,
        state=state,
        events_from_seq=first_seq,
        events_to_seq=last_seq,
        events_count=event_count
    )

    # Make the summary searchable alongside long-term memories.
    await vector_store.index_summary(
        summary_id=summary["summary_id"],
        text=summary_text,
        thread_id=thread_id,
        org_id=thread["org_id"],
        user_id=thread["user_id"],
        agent_id=thread.get("agent_id"),
        workspace_id=thread.get("workspace_id")
    )
    return summary
@app.get("/threads/{thread_id}/summary", response_model=Optional[SummaryResponse])
async def get_latest_summary(thread_id: UUID):
    """Return the thread's most recent summary, or None if `db` has none."""
    return await db.get_latest_summary(thread_id)
# ============================================================================
# CONTEXT (Full context for agent)
# ============================================================================
@app.get("/threads/{thread_id}/context", response_model=ContextResponse)
async def get_context(
    thread_id: UUID,
    queries: List[str] = Query(default=[]),
    top_k: int = Query(default=10, ge=1)
):
    """
    Get the full context for an agent prompt.

    Combines:
    - Latest summary (mid-term)
    - Recent messages (short-term), limited to
      ``settings.short_term_window_messages``
    - Retrieved memories (long-term), only when ``queries`` are given

    Args:
        thread_id: Thread to build context for.
        queries: Optional semantic queries for long-term retrieval.
        top_k: Maximum number of retrieved memories. Must be >= 1, so
            non-positive values are rejected with a 422 instead of flowing
            into retrieval.

    Raises:
        HTTPException(404): the thread does not exist.
    """
    thread = await db.get_thread(thread_id)
    if not thread:
        raise HTTPException(status_code=404, detail="Thread not found")

    summary = await db.get_latest_summary(thread_id)
    recent = await db.get_events(
        thread_id,
        limit=settings.short_term_window_messages
    )

    # Long-term retrieval is scoped to the thread's org/user/agent/workspace.
    retrieved = []
    if queries:
        retrieval_response = await retrieve_memories(RetrievalRequest(
            org_id=thread["org_id"],
            user_id=thread["user_id"],
            agent_id=thread.get("agent_id"),
            workspace_id=thread.get("workspace_id"),
            queries=queries,
            top_k=top_k,
            include_global=True
        ))
        retrieved = retrieval_response.results

    # Rough token estimate; missing/None counts are treated as 0.
    token_estimate = sum(e.get("token_count", 0) or 0 for e in recent)
    if summary:
        token_estimate += summary.get("summary_tokens", 0) or 0

    return ContextResponse(
        thread_id=thread_id,
        summary=summary,
        recent_messages=recent,
        retrieved_memories=retrieved,
        token_estimate=token_estimate
    )
# ============================================================================
# FACTS (Simple Key-Value storage for Gateway compatibility)
# ============================================================================
from pydantic import BaseModel
from typing import Any
class FactUpsertRequest(BaseModel):
    """Request body for POST /facts/upsert (create or update one user fact).

    All fields are forwarded unchanged to ``db.upsert_fact``.
    """
    # Identifier of the user the fact belongs to.
    user_id: str
    # Key under which the fact is stored and later looked up via /facts/{fact_key}.
    fact_key: str
    # Optional plain-text value.
    fact_value: Optional[str] = None
    # Optional structured (JSON object) value.
    fact_value_json: Optional[dict] = None
    # Optional team scope; presumably narrows the fact to a team — confirm in db layer.
    team_id: Optional[str] = None
@app.post("/facts/upsert")
async def upsert_fact(request: FactUpsertRequest):
    """
    Create or update a user fact.

    This is a simple key-value store for Gateway compatibility. Facts are
    stored in PostgreSQL without vector indexing. The facts table is ensured
    via ``db.ensure_facts_table()`` on each call before the upsert.

    Returns:
        ``{"status": "ok", "fact_id": <id or None>}``.

    Raises:
        HTTPException(500): any failure in the database layer. The original
        exception is chained as ``__cause__`` for server-side tracebacks.
    """
    try:
        await db.ensure_facts_table()
        result = await db.upsert_fact(
            user_id=request.user_id,
            fact_key=request.fact_key,
            fact_value=request.fact_value,
            fact_value_json=request.fact_value_json,
            team_id=request.team_id
        )
        logger.info("fact_upserted", user_id=request.user_id, fact_key=request.fact_key)
        return {"status": "ok", "fact_id": result.get("fact_id") if result else None}
    except Exception as e:
        logger.error(
            "fact_upsert_failed",
            error=str(e),
            user_id=request.user_id,
            fact_key=request.fact_key,
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/facts/{fact_key}")
async def get_fact(
    fact_key: str,
    user_id: str = Query(...),
    team_id: Optional[str] = None
):
    """Get a specific fact for a user.

    Raises:
        HTTPException(404): no fact found for (user_id, fact_key, team_id).
        HTTPException(500): database failure (original exception chained).
    """
    # Only the DB call is guarded, so the 404 below needs no re-raise special case.
    try:
        fact = await db.get_fact(user_id=user_id, fact_key=fact_key, team_id=team_id)
    except Exception as e:
        logger.error("fact_get_failed", error=str(e), user_id=user_id, fact_key=fact_key)
        raise HTTPException(status_code=500, detail=str(e)) from e
    if not fact:
        raise HTTPException(status_code=404, detail="Fact not found")
    return fact
@app.get("/facts")
async def list_facts(
    user_id: str = Query(...),
    team_id: Optional[str] = None
):
    """List all facts for a user.

    Returns:
        ``{"facts": <list from db.list_facts>}``.

    Raises:
        HTTPException(500): database failure (original exception chained).
    """
    try:
        facts = await db.list_facts(user_id=user_id, team_id=team_id)
    except Exception as e:
        logger.error("facts_list_failed", error=str(e), user_id=user_id)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"facts": facts}
@app.delete("/facts/{fact_key}")
async def delete_fact(
    fact_key: str,
    user_id: str = Query(...),
    team_id: Optional[str] = None
):
    """Delete a fact.

    Returns:
        ``{"status": "ok", "deleted": True}`` when a fact was removed.

    Raises:
        HTTPException(404): ``db.delete_fact`` reported nothing deleted.
        HTTPException(500): database failure (original exception chained).
    """
    # Only the DB call is guarded, so the 404 below needs no re-raise special case.
    try:
        deleted = await db.delete_fact(user_id=user_id, fact_key=fact_key, team_id=team_id)
    except Exception as e:
        logger.error("fact_delete_failed", error=str(e), user_id=user_id, fact_key=fact_key)
        raise HTTPException(status_code=500, detail=str(e)) from e
    if not deleted:
        raise HTTPException(status_code=404, detail="Fact not found")
    return {"status": "ok", "deleted": True}
# ============================================================================
# ADMIN
# ============================================================================
@app.get("/stats")
async def get_stats():
    """Return vector-store collection stats and database stats."""
    vector_stats = await vector_store.get_collection_stats()
    database_stats = await db.get_stats()
    return {"vector_store": vector_stats, "database": database_stats}
if __name__ == "__main__":
    # Direct-execution entry point: serve the app with uvicorn on all
    # interfaces, port 8000.
    import uvicorn as _uvicorn

    _uvicorn.run(app, port=8000, host="0.0.0.0")