feat: implement TTS, Document processing, and Memory Service /facts API
- TTS: xtts-v2 integration with voice cloning support
- Document: docling integration for PDF/DOCX/PPTX processing
- Memory Service: added /facts/upsert, /facts/{key}, /facts endpoints
- Added required dependencies (TTS, docling)
This commit is contained in:
37
services/chandra-inference/Dockerfile
Normal file
37
services/chandra-inference/Dockerfile
Normal file
@@ -0,0 +1,37 @@
|
||||
# CUDA runtime base image so torch can reach the GPU at inference time.
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04

WORKDIR /app

# System packages: Python runtime plus curl (health checks) and git
# (used by some pip/model tooling). Apt lists are removed to keep the
# layer small.
RUN apt-get update && apt-get install -y \
    python3.11 \
    python3-pip \
    curl \
    git \
    && rm -rf /var/lib/apt/lists/*

# Python dependencies for the OCR inference API.
# NOTE(review): versions are unpinned — consider pinning via a
# requirements.txt for reproducible builds.
RUN pip3 install --no-cache-dir \
    torch \
    torchvision \
    transformers \
    accelerate \
    pillow \
    fastapi \
    uvicorn \
    python-multipart \
    pydantic \
    httpx \
    tiktoken \
    sentencepiece \
    einops \
    verovio

# Copy the inference service source into the image.
COPY . /app

# FastAPI/uvicorn listen port.
EXPOSE 8000

# Launch the inference service.
CMD ["python3", "main.py"]
265
services/chandra-inference/main.py
Normal file
265
services/chandra-inference/main.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Chandra Inference Service
|
||||
Direct inference using HuggingFace model
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import base64
|
||||
import torch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="Chandra Inference Service")
|
||||
|
||||
# Configuration
|
||||
# Using GOT-OCR2.0 - best open-source OCR for documents and tables
|
||||
# Alternative: microsoft/trocr-base-printed for simple text
|
||||
OCR_MODEL = os.getenv("OCR_MODEL", "stepfun-ai/GOT-OCR2_0")
|
||||
DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Load model (lazy loading)
|
||||
model = None
|
||||
processor = None
|
||||
|
||||
|
||||
def load_model():
    """Lazily load the configured OCR model into the module globals.

    Populates ``model`` and ``processor``; calling again once loaded is a
    no-op. Any load failure is logged and swallowed so the service keeps
    running in degraded mode (endpoints then report 503).
    """
    global model, processor

    if model is not None:
        return  # already loaded

    try:
        logger.info(f"Loading OCR model: {OCR_MODEL} on {DEVICE}")

        # Half precision only makes sense on GPU.
        dtype = torch.float16 if DEVICE == "cuda" else torch.float32

        # "GOT-OCR" in OCR_MODEL is subsumed by the lowercase check.
        if "got-ocr" in OCR_MODEL.lower():
            # GOT-OCR2.0 ships custom modelling code and pairs the model
            # with a tokenizer rather than an image processor.
            from transformers import AutoModel, AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(OCR_MODEL, trust_remote_code=True)
            model = AutoModel.from_pretrained(
                OCR_MODEL,
                trust_remote_code=True,
                torch_dtype=dtype,
                device_map="auto" if DEVICE == "cuda" else None,
            )
            processor = tokenizer
            logger.info(f"GOT-OCR2.0 model loaded on {DEVICE}")
        else:
            # Fallback: TrOCR encoder/decoder for simple printed text.
            from transformers import TrOCRProcessor, VisionEncoderDecoderModel

            processor = TrOCRProcessor.from_pretrained(OCR_MODEL)
            model = VisionEncoderDecoderModel.from_pretrained(OCR_MODEL, torch_dtype=dtype)
            # TrOCR is moved manually (GOT-OCR uses device_map instead).
            model = model.to(DEVICE) if DEVICE == "cpu" else model.cuda()

        if model:
            model.eval()  # inference only; disable dropout etc.
            logger.info(f"Model loaded successfully on {DEVICE}")

    except Exception as e:
        logger.error(f"Failed to load model: {e}", exc_info=True)
        logger.warning("Service will run in degraded mode without model")
||||
@app.on_event("startup")
async def startup():
    """Eagerly load the OCR model at boot; never let a failure kill startup."""
    try:
        load_model()
    except Exception as e:
        # load_model already degrades gracefully; this is belt-and-braces.
        logger.error(f"Startup failed: {e}", exc_info=True)
|
||||
@app.get("/health")
async def health():
    """Report service status: 'loading' until the model is resident in memory."""
    base = {"service": "ocr-inference", "model": OCR_MODEL, "device": DEVICE}
    try:
        if model is None:
            return {"status": "loading", **base}
        return {
            "status": "healthy",
            **base,
            "cuda_available": torch.cuda.is_available(),
        }
    except Exception as e:
        return {"status": "unhealthy", "service": "ocr-inference", "error": str(e)}
|
||||
@app.post("/process")
async def process_document(
    file: Optional[UploadFile] = File(None),
    doc_url: Optional[str] = Form(None),
    doc_base64: Optional[str] = Form(None),
    output_format: str = Form("markdown"),
    accurate_mode: bool = Form(False)
):
    """
    Process a document with the configured OCR model.

    Args:
        file: Uploaded file (PDF, image)
        doc_url: URL to document
        doc_base64: Base64 encoded document
        output_format: markdown, html, or json
        accurate_mode: Use accurate mode (slower but more precise)

    Raises:
        HTTPException: 400 for missing/unfetchable input, 503 when the model
            is not loaded, 500 on unexpected processing errors.
    """
    try:
        # Ensure model is loaded (lazy load on first request).
        if model is None:
            load_model()

        # Resolve the document bytes from whichever input was supplied.
        image_data = None
        if file:
            image_data = await file.read()
        elif doc_base64:
            image_data = base64.b64decode(doc_base64)
        elif doc_url:
            import httpx
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.get(doc_url)
                if response.status_code == 200:
                    image_data = response.content
                else:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to download document: {response.status_code}"
                    )
        else:
            raise HTTPException(status_code=400, detail="No document provided")

        image = Image.open(BytesIO(image_data)).convert("RGB")

        if model is None or processor is None:
            raise HTTPException(
                status_code=503,
                detail="OCR model not loaded. Check logs for details."
            )

        # GOT-OCR exposes a chat() API; TrOCR uses generate()/batch_decode().
        if "got-ocr" in OCR_MODEL.lower():
            with torch.no_grad():
                result = model.chat(processor, image, ocr_type='ocr')
            generated_text = result if isinstance(result, str) else str(result)
        else:
            inputs = processor(images=image, return_tensors="pt").to(DEVICE)
            with torch.no_grad():
                generated_ids = model.generate(
                    inputs.pixel_values,
                    max_length=512,
                    num_beams=5 if accurate_mode else 3  # wider beam = slower, better
                )
            generated_text = processor.batch_decode(
                generated_ids,
                skip_special_tokens=True
            )[0]

        # Shape the payload for the requested output format.
        if output_format == "markdown":
            result = {"markdown": generated_text, "format": "markdown"}
        elif output_format == "html":
            # Simplified markdown -> HTML: newlines become <br>.
            result = {"html": generated_text.replace("\n", "<br>"), "format": "html"}
        else:  # json
            result = {
                "text": generated_text,
                "format": "json",
                "metadata": {
                    # BUG FIX: was CHANDRA_MODEL, which is undefined in this
                    # module and raised NameError for JSON output.
                    "model": OCR_MODEL,
                    "device": DEVICE,
                    "accurate_mode": accurate_mode
                }
            }

        return {
            "success": True,
            "output_format": output_format,
            "result": result
        }

    except HTTPException:
        # BUG FIX: propagate the intended 400/503 responses instead of
        # letting the generic handler below turn them into 500s.
        raise
    except Exception as e:
        logger.error(f"Document processing failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Processing failed: {str(e)}"
        )
||||
|
||||
@app.get("/models")
async def list_models():
    """Describe the active OCR model and the supported alternatives."""
    catalog = [
        {
            "name": "stepfun-ai/GOT-OCR2_0",
            "type": "document_ocr",
            "description": "Best for documents, tables, formulas, handwriting",
            "vram": "~8GB",
        },
        {
            "name": "microsoft/trocr-base-printed",
            "type": "text_ocr",
            "description": "Fast OCR for printed text",
            "vram": "~2GB",
        },
        {
            "name": "microsoft/trocr-base-handwritten",
            "type": "handwriting_ocr",
            "description": "OCR for handwritten text",
            "vram": "~2GB",
        },
    ]
    return {
        "current_model": OCR_MODEL,
        "available_models": catalog,
        "note": "GOT-OCR2.0 recommended for documents and tables. TrOCR for simple text.",
        "device": DEVICE,
        "cuda_available": torch.cuda.is_available(),
    }
|
||||
|
||||
if __name__ == "__main__":
    # Local/dev entry point; the container launches this via CMD.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
||||
27
services/chandra-service/Dockerfile
Normal file
27
services/chandra-service/Dockerfile
Normal file
@@ -0,0 +1,27 @@
|
||||
# Slim Python base: this service only proxies to the inference container,
# so no CUDA is needed here.
FROM python:3.11-slim

WORKDIR /app

# System packages: curl for health checks, git for pip/VCS tooling.
RUN apt-get update && apt-get install -y \
    curl \
    git \
    && rm -rf /var/lib/apt/lists/*

# Python dependencies for the thin HTTP wrapper.
RUN pip install --no-cache-dir \
    fastapi \
    uvicorn \
    httpx \
    pydantic \
    python-multipart \
    pillow

# Copy the service source into the image.
COPY . /app

# FastAPI/uvicorn listen port.
EXPOSE 8002

# Launch the wrapper service.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8002"]
||||
61
services/chandra-service/README.md
Normal file
61
services/chandra-service/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Chandra Document Processing Service

Wrapper service for the Datalab Chandra OCR model for document and table processing.

## Features

- Document OCR with structure preservation
- Table extraction with formatting
- Handwriting recognition
- Form processing
- Output formats: Markdown, HTML, JSON

## API Endpoints

### Health Check
```
GET /health
```

### Process Document
```
POST /process
```

**Request:**
- `file`: Uploaded file (PDF, image)
- `doc_url`: URL to document
- `doc_base64`: Base64-encoded document
- `output_format`: markdown, html, or json
- `accurate_mode`: true/false

**Response:**
```json
{
  "success": true,
  "output_format": "markdown",
  "result": {
    "markdown": "...",
    "metadata": {...}
  }
}
```

### List Models
```
GET /models
```

## Configuration

Environment variables:
- `CHANDRA_API_URL`: URL to the Chandra inference service (default: `http://chandra-inference:8000`)
- `CHANDRA_LICENSE_KEY`: Datalab license key (if required)
- `CHANDRA_MODEL`: Model name (`chandra-small` or `chandra`)

## Integration

This service integrates with:
- Router (`OCR_URL` and `CHANDRA_URL`)
- Gateway (`doc_service.py`)
- Memory Service (for storing processed documents)
||||
177
services/chandra-service/main.py
Normal file
177
services/chandra-service/main.py
Normal file
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Chandra Document Processing Service
|
||||
Wrapper for Datalab Chandra OCR model
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
import httpx
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(title="Chandra Document Processing Service")
|
||||
|
||||
# Configuration
|
||||
CHANDRA_API_URL = os.getenv("CHANDRA_API_URL", "http://chandra-inference:8000")
|
||||
CHANDRA_LICENSE_KEY = os.getenv("CHANDRA_LICENSE_KEY", "")
|
||||
CHANDRA_MODEL = os.getenv("CHANDRA_MODEL", "chandra-small") # chandra-small or chandra
|
||||
|
||||
# Health check endpoint
|
||||
@app.get("/health")
async def health():
    """Probe the downstream Chandra inference service and report our status."""
    try:
        async with httpx.AsyncClient(timeout=5.0) as client:
            upstream = await client.get(f"{CHANDRA_API_URL}/health")

        if upstream.status_code == 200:
            return {
                "status": "healthy",
                "service": "chandra-service",
                "chandra_api": CHANDRA_API_URL,
                "model": CHANDRA_MODEL,
            }

        # Upstream reachable but not healthy.
        return {
            "status": "degraded",
            "service": "chandra-service",
            "chandra_api": CHANDRA_API_URL,
            "error": "Chandra inference service unavailable",
        }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "service": "chandra-service",
            "error": str(e),
        }
|
||||
|
||||
class ProcessDocumentRequest(BaseModel):
    """JSON body accepted by /process when no multipart file is sent."""

    # Remote document to fetch over HTTP.
    doc_url: Optional[str] = None
    # Inline, base64-encoded document bytes.
    doc_base64: Optional[str] = None
    # One of: markdown, html, json.
    output_format: str = "markdown"
    # Slower but more precise OCR when True.
    accurate_mode: bool = False
|
||||
|
||||
@app.post("/process")
async def process_document(
    request: ProcessDocumentRequest,
    file: Optional[UploadFile] = File(None)
):
    """
    Process a document using Chandra OCR (proxies to the inference service).

    Accepts:
    - doc_url: URL to document/image
    - doc_base64: Base64 encoded document/image
    - file: Uploaded file
    - output_format: markdown, html, or json
    - accurate_mode: Use accurate mode (slower but more precise)

    Raises:
        HTTPException: 400 for missing/unfetchable input, 504 on upstream
            timeout, upstream status on upstream error, 500 otherwise.

    NOTE(review): mixing a Pydantic JSON body (`request`) with a multipart
    `File(...)` parameter in one endpoint is unusual for FastAPI — confirm
    clients can actually send both; the inference service uses Form fields.
    """
    try:
        # Resolve the document bytes from whichever input was supplied.
        image_data = None

        if file:
            image_data = await file.read()
        elif request.doc_base64:
            image_data = base64.b64decode(request.doc_base64)
        elif request.doc_url:
            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.get(request.doc_url)
                if response.status_code == 200:
                    image_data = response.content
                else:
                    raise HTTPException(
                        status_code=400,
                        detail=f"Failed to download document from URL: {response.status_code}"
                    )
        else:
            raise HTTPException(
                status_code=400,
                detail="No document provided. Use file, doc_url, or doc_base64"
            )

        # Forward the bytes to the Chandra inference service as multipart.
        files = {
            "file": ("document", image_data, "application/octet-stream")
        }
        data = {
            "output_format": request.output_format,
            "accurate_mode": str(request.accurate_mode).lower()
        }

        if CHANDRA_LICENSE_KEY:
            data["license_key"] = CHANDRA_LICENSE_KEY

        # Long timeout: OCR of large documents can be slow.
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(
                f"{CHANDRA_API_URL}/process",
                files=files,
                data=data
            )

        if response.status_code == 200:
            result = response.json()
            return {
                "success": True,
                "output_format": request.output_format,
                "result": result
            }

        logger.error(f"Chandra API error: {response.status_code} - {response.text}")
        raise HTTPException(
            status_code=response.status_code,
            detail=f"Chandra API error: {response.text}"
        )

    except HTTPException:
        # BUG FIX: propagate the intended 400/upstream status codes instead
        # of letting the generic handler below convert them into 500s.
        raise
    except httpx.TimeoutException:
        logger.error("Chandra API timeout")
        raise HTTPException(
            status_code=504,
            detail="Chandra API timeout"
        )
    except Exception as e:
        logger.error(f"Document processing failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Document processing failed: {str(e)}"
        )
|
||||
|
||||
@app.get("/models")
async def list_models():
    """List the Chandra model variants and which one is active."""
    catalog = [
        {
            "name": "chandra-small",
            "description": "Fast model with lower latency",
            "vram_required": "~8GB",
        },
        {
            "name": "chandra",
            "description": "Balanced model",
            "vram_required": "~16GB",
        },
    ]
    return {"models": catalog, "current_model": CHANDRA_MODEL}
|
||||
|
||||
if __name__ == "__main__":
    # Local/dev entry point; the container launches uvicorn via CMD.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8002)
||||
41
services/docling-service/Dockerfile
Normal file
41
services/docling-service/Dockerfile
Normal file
@@ -0,0 +1,41 @@
|
||||
# CUDA runtime base so torch-backed Docling models can use the GPU.
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04

WORKDIR /app

# System packages: Python runtime, curl/git tooling, GL + glib for image
# libraries, poppler for PDF rendering, tesseract for OCR fallback.
RUN apt-get update && apt-get install -y \
    python3.11 \
    python3-pip \
    curl \
    git \
    libgl1-mesa-glx \
    libglib2.0-0 \
    poppler-utils \
    tesseract-ocr \
    && rm -rf /var/lib/apt/lists/*

# Python dependencies for the Docling conversion API.
# NOTE(review): versions are unpinned — consider pinning for reproducibility.
RUN pip3 install --no-cache-dir \
    docling \
    docling-core \
    torch \
    torchvision \
    transformers \
    accelerate \
    pillow \
    fastapi \
    uvicorn \
    python-multipart \
    pydantic \
    httpx \
    PyMuPDF \
    pdf2image

# Copy the service source into the image.
COPY . /app

# FastAPI/uvicorn listen port.
EXPOSE 8003

# Launch the conversion service.
CMD ["python3", "main.py"]
||||
350
services/docling-service/main.py
Normal file
350
services/docling-service/main.py
Normal file
@@ -0,0 +1,350 @@
|
||||
"""
|
||||
IBM Docling Service - Document conversion with table/formula extraction
|
||||
|
||||
Converts PDF, DOCX, PPTX, images to Markdown/JSON while preserving:
|
||||
- Tables (with structure)
|
||||
- Formulas (LaTeX)
|
||||
- Code blocks
|
||||
- Images
|
||||
- Document structure
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import base64
|
||||
import tempfile
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pathlib import Path
|
||||
from io import BytesIO
|
||||
|
||||
from fastapi import FastAPI, HTTPException, File, UploadFile, Form
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
import torch
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = FastAPI(
|
||||
title="Docling Document Conversion Service",
|
||||
description="Convert documents to structured formats using IBM Docling"
|
||||
)
|
||||
|
||||
# Configuration
|
||||
DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
DOCLING_MODEL = os.getenv("DOCLING_MODEL", "ds4sd/docling-models")
|
||||
|
||||
# Global converter instance
|
||||
converter = None
|
||||
|
||||
|
||||
def load_docling():
    """Initialise the global Docling DocumentConverter (idempotent).

    Import or initialisation failures are logged and swallowed so the API
    can come up in degraded mode; endpoints then return 503.
    """
    global converter

    if converter is not None:
        return  # already initialised

    try:
        from docling.document_converter import DocumentConverter
        from docling.datamodel.pipeline_options import PdfPipelineOptions
        from docling.datamodel.base_models import InputFormat

        logger.info(f"Loading Docling on {DEVICE}...")

        # OCR + table-structure recognition for PDFs.
        pipeline_options = PdfPipelineOptions()
        pipeline_options.do_ocr = True
        pipeline_options.do_table_structure = True
        # NOTE(review): pipeline_options is configured but never handed to
        # DocumentConverter — confirm whether it should be wired in via the
        # converter's format options.

        accepted = [
            InputFormat.PDF,
            InputFormat.DOCX,
            InputFormat.PPTX,
            InputFormat.IMAGE,
            InputFormat.HTML,
        ]
        converter = DocumentConverter(allowed_formats=accepted)

        logger.info("Docling loaded successfully")

    except ImportError as e:
        logger.error(f"Failed to import Docling: {e}")
        logger.warning("Service will run in degraded mode")
    except Exception as e:
        logger.error(f"Failed to load Docling: {e}", exc_info=True)
        logger.warning("Service will run in degraded mode")
|
||||
|
||||
@app.on_event("startup")
async def startup():
    """Initialise the Docling converter when the app boots."""
    load_docling()
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Report service status: 'loading' until the converter is initialised."""
    if converter is None:
        return {
            "status": "loading",
            "service": "docling-service",
            "device": DEVICE,
        }

    return {
        "status": "healthy",
        "service": "docling-service",
        "device": DEVICE,
        "cuda_available": torch.cuda.is_available(),
        "features": ["pdf", "docx", "pptx", "images", "tables", "formulas"],
    }
||||
|
||||
|
||||
class ConvertRequest(BaseModel):
    """JSON body for document conversion requests."""

    # Remote document to fetch over HTTP.
    doc_url: Optional[str] = None
    # Inline, base64-encoded document bytes.
    doc_base64: Optional[str] = None
    # One of: markdown, json, text.
    output_format: str = "markdown"
    # Include extracted tables in the response.
    extract_tables: bool = True
    # Include extracted images in the response.
    extract_images: bool = False
    # Run OCR on scanned pages.
    ocr_enabled: bool = True
||||
|
||||
|
||||
@app.post("/convert")
async def convert_document(
    file: Optional[UploadFile] = File(None),
    doc_url: Optional[str] = Form(None),
    doc_base64: Optional[str] = Form(None),
    output_format: str = Form("markdown"),
    extract_tables: bool = Form(True),
    extract_images: bool = Form(False),
    ocr_enabled: bool = Form(True)
):
    """
    Convert a document to a structured format with Docling.

    Supports PDF, DOCX, PPTX, HTML and images; extracts tables with
    structure and can OCR scanned documents.

    Output formats:
    - markdown: Structured markdown with tables
    - json: Full document structure as JSON
    - text: Plain text extraction

    Raises:
        HTTPException: 400 for missing/unfetchable input, 503 when Docling
            is not loaded, 500 on unexpected conversion errors.

    NOTE(review): extract_images and ocr_enabled are accepted but not used
    in the body below — confirm whether they should influence conversion.
    """
    if converter is None:
        raise HTTPException(
            status_code=503,
            detail="Docling not loaded. Check logs for details."
        )

    temp_file = None
    try:
        # Materialise the document to a temp file (Docling converts paths).
        if file:
            temp_file = _write_temp(await file.read(), Path(file.filename).suffix)
        elif doc_base64:
            temp_file = _write_temp(base64.b64decode(doc_base64), ".pdf")
        elif doc_url:
            import httpx
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.get(doc_url)
            if response.status_code != 200:
                raise HTTPException(
                    status_code=400,
                    detail=f"Failed to download document: {response.status_code}"
                )
            # Guess the extension from the URL; default to PDF.
            ext = ".pdf"
            if doc_url.endswith(".docx"):
                ext = ".docx"
            elif doc_url.endswith(".pptx"):
                ext = ".pptx"
            temp_file = _write_temp(response.content, ext)
        else:
            raise HTTPException(
                status_code=400,
                detail="No document provided. Use file, doc_url, or doc_base64"
            )

        doc_path = temp_file.name
        logger.info(f"Converting document: {doc_path}")
        result = converter.convert(doc_path)

        # Export in the requested representation.
        if output_format == "markdown":
            output = result.document.export_to_markdown()
        elif output_format == "json":
            output = result.document.export_to_dict()
        else:
            output = result.document.export_to_text()

        # Collect table summaries when requested; hasattr guards keep us
        # tolerant of Docling API differences across versions.
        tables = []
        if extract_tables:
            for table in result.document.tables:
                tables.append({
                    "id": table.id if hasattr(table, 'id') else None,
                    "content": table.export_to_markdown() if hasattr(table, 'export_to_markdown') else str(table),
                    "rows": len(table.data) if hasattr(table, 'data') else 0
                })

        return {
            "success": True,
            "output_format": output_format,
            "result": output,
            "tables": tables if extract_tables else None,
            "pages": result.document.num_pages if hasattr(result.document, 'num_pages') else None,
            "metadata": {
                "title": result.document.title if hasattr(result.document, 'title') else None,
                "author": result.document.author if hasattr(result.document, 'author') else None
            }
        }

    except HTTPException:
        # BUG FIX: propagate intended 400s instead of re-wrapping as 500s.
        raise
    except Exception as e:
        logger.error(f"Document conversion failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Document conversion failed: {str(e)}"
        )
    finally:
        # BUG FIX: the temp file previously leaked on the 400 branches;
        # now it is removed on every exit path.
        if temp_file and os.path.exists(temp_file.name):
            os.unlink(temp_file.name)


def _write_temp(content: bytes, suffix: str):
    """Write *content* to a non-deleting named temp file and return the closed handle."""
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
    tmp.write(content)
    tmp.close()
    return tmp
||||
|
||||
|
||||
@app.post("/extract-tables")
async def extract_tables(
    file: Optional[UploadFile] = File(None),
    doc_base64: Optional[str] = Form(None)
):
    """
    Extract tables from a document.

    Returns each table as markdown, HTML, and (when available) structured
    row/column data.

    Raises:
        HTTPException: 400 when no document is supplied, 503 when Docling
            is not loaded, 500 on unexpected extraction errors.
    """
    if converter is None:
        raise HTTPException(
            status_code=503,
            detail="Docling not loaded. Check logs for details."
        )

    temp_file = None
    try:
        # Materialise the document to a temp file (Docling converts paths).
        if file:
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=Path(file.filename).suffix)
            temp_file.write(await file.read())
            temp_file.close()
        elif doc_base64:
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
            temp_file.write(base64.b64decode(doc_base64))
            temp_file.close()
        else:
            raise HTTPException(
                status_code=400,
                detail="No document provided"
            )

        result = converter.convert(temp_file.name)

        # hasattr guards keep us tolerant of Docling API differences.
        tables = []
        for idx, table in enumerate(result.document.tables):
            table_data = {
                "index": idx,
                "markdown": table.export_to_markdown() if hasattr(table, 'export_to_markdown') else None,
                "html": table.export_to_html() if hasattr(table, 'export_to_html') else None,
            }
            if hasattr(table, 'data'):
                table_data["data"] = table.data
                table_data["rows"] = len(table.data)
                table_data["columns"] = len(table.data[0]) if table.data else 0
            tables.append(table_data)

        return {
            "success": True,
            "tables_count": len(tables),
            "tables": tables
        }

    except HTTPException:
        # BUG FIX: propagate the intended 400 instead of converting to 500.
        raise
    except Exception as e:
        logger.error(f"Table extraction failed: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Table extraction failed: {str(e)}"
        )
    finally:
        # BUG FIX: temp file previously leaked on the 400 branch; remove it
        # on every exit path.
        if temp_file and os.path.exists(temp_file.name):
            os.unlink(temp_file.name)
||||
|
||||
|
||||
@app.get("/models")
async def list_models():
    """Describe the Docling model, supported formats, and device."""
    catalog = [
        {
            "name": "ds4sd/docling-models",
            "description": "IBM Docling - Document conversion with tables and formulas",
            "features": ["pdf", "docx", "pptx", "html", "images"],
            "capabilities": ["ocr", "tables", "formulas", "structure"],
        }
    ]
    return {
        "service": "docling-service",
        "models": catalog,
        "supported_formats": {
            "input": ["pdf", "docx", "pptx", "html", "png", "jpg", "tiff"],
            "output": ["markdown", "json", "text"],
        },
        "device": DEVICE,
        "cuda_available": torch.cuda.is_available(),
    }
||||
|
||||
|
||||
if __name__ == "__main__":
    # Local/dev entry point; the container launches this via CMD.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8003)
||||
26
services/image-gen-service/Dockerfile
Normal file
26
services/image-gen-service/Dockerfile
Normal file
@@ -0,0 +1,26 @@
|
||||
# CUDA runtime base so the diffusion pipeline can run on GPU.
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04

WORKDIR /app

# System packages: Python runtime plus curl (health checks) and git.
RUN apt-get update && apt-get install -y \
    python3.11 \
    python3-pip \
    curl \
    git \
    && rm -rf /var/lib/apt/lists/*

# Python dependencies, pinned via requirements.txt for cache-friendly builds.
COPY requirements.txt .
RUN pip3 install --no-cache-dir -r requirements.txt

# Application code only (keeps the image lean).
COPY app/ ./app/

EXPOSE 8892

# Container-level liveness probe against the FastAPI /health endpoint;
# long start period because model download/load is slow.
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
    CMD curl -f http://localhost:8892/health || exit 1

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8892"]
||||
126
services/image-gen-service/app/main.py
Normal file
126
services/image-gen-service/app/main.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers import Flux2KleinPipeline
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
app = FastAPI(title="Image Generation Service", version="1.0.0")
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
    """Validated payload for POST /generate."""

    # Required, non-empty text prompt.
    prompt: str = Field(..., min_length=1)
    # Optional concepts to steer away from.
    negative_prompt: Optional[str] = None
    # Output dimensions in pixels, clamped to a sane range.
    width: int = Field(1024, ge=256, le=2048)
    height: int = Field(1024, ge=256, le=2048)
    # Denoising steps: more = slower, higher quality.
    num_inference_steps: int = Field(50, ge=1, le=100)
    # Classifier-free guidance strength.
    guidance_scale: float = Field(4.0, ge=0.0, le=20.0)
    # Optional seed for reproducible generations.
    seed: Optional[int] = Field(None, ge=0)
|
||||
|
||||
MODEL_ID = os.getenv("IMAGE_GEN_MODEL", "black-forest-labs/FLUX.2-klein-base-4B")
|
||||
DEVICE = os.getenv("IMAGE_GEN_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
DTYPE_ENV = os.getenv("IMAGE_GEN_DTYPE", "float16")
|
||||
|
||||
|
||||
def _resolve_dtype() -> torch.dtype:
    """Map the IMAGE_GEN_DTYPE setting to a torch dtype for the active device.

    NOTE(review): on CUDA, _load_pipeline later promotes float16 to bfloat16,
    so the dtype reported by /health and /info (which call this) can differ
    from the one actually in use — confirm whether this helper should return
    bfloat16 directly.
    """
    if not DEVICE.startswith("cuda"):
        return torch.float32  # CPU path: full precision only
    return torch.float16 if DTYPE_ENV == "float16" else torch.bfloat16
|
||||
|
||||
PIPELINE: Optional[Flux2KleinPipeline] = None
|
||||
LOAD_ERROR: Optional[str] = None
|
||||
|
||||
|
||||
def _load_pipeline() -> None:
    """Instantiate the FLUX.2 Klein pipeline, recording any failure in LOAD_ERROR."""
    global PIPELINE, LOAD_ERROR
    try:
        dtype = _resolve_dtype()
        if DEVICE.startswith("cuda") and dtype == torch.float16:
            # FLUX.2 Klein is recommended to run in bfloat16 on GPU.
            dtype = torch.bfloat16

        pipe = Flux2KleinPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=dtype,
        )

        if DEVICE.startswith("cuda"):
            # Offload idle submodules to CPU to cut peak VRAM usage.
            pipe.enable_model_cpu_offload()
        else:
            pipe.to(DEVICE)

        PIPELINE = pipe
        LOAD_ERROR = None
    except Exception as exc:  # pragma: no cover - surfaced via /health and /info
        PIPELINE = None
        LOAD_ERROR = str(exc)
||||
|
||||
|
||||
@app.on_event("startup")
def startup_event() -> None:
    """Load the model when the server starts; failures land in LOAD_ERROR."""
    # NOTE(review): @app.on_event is deprecated in newer FastAPI versions in
    # favour of lifespan handlers — consider migrating on the next upgrade.
    _load_pipeline()
|
||||
|
||||
|
||||
@app.get("/health")
def health() -> dict:
    """Readiness probe: 503 with the load error if the model failed to load."""
    if LOAD_ERROR:
        raise HTTPException(status_code=503, detail=LOAD_ERROR)
    status = {
        "status": "ok",
        "model_loaded": PIPELINE is not None,
        "model_id": MODEL_ID,
        "device": DEVICE,
    }
    status["dtype"] = str(_resolve_dtype()).replace("torch.", "")
    return status
|
||||
|
||||
|
||||
@app.get("/info")
def info() -> dict:
    """Expose model/configuration details, including any pipeline load error."""
    dtype_name = str(_resolve_dtype()).replace("torch.", "")
    return {
        "model_id": MODEL_ID,
        "device": DEVICE,
        "dtype": dtype_name,
        "pipeline_loaded": PIPELINE is not None,
        "load_error": LOAD_ERROR,
    }
|
||||
|
||||
|
||||
@app.post("/generate")
def generate(payload: GenerateRequest) -> dict:
    """Generate one image from a text prompt.

    Returns the PNG image base64-encoded along with the request seed and the
    model id. Responds 503 while the pipeline is unavailable.

    Fix: this handler uses `io` and `base64`, which the module previously
    never imported — the first request would have raised NameError.
    """
    if LOAD_ERROR:
        raise HTTPException(status_code=503, detail=LOAD_ERROR)
    if PIPELINE is None:
        raise HTTPException(status_code=503, detail="Model is not loaded yet")

    # Seed a generator on the sampling device for reproducible output.
    generator = None
    if payload.seed is not None:
        generator = torch.Generator(device="cuda" if DEVICE.startswith("cuda") else "cpu")
        generator.manual_seed(payload.seed)

    with torch.inference_mode():
        result = PIPELINE(
            prompt=payload.prompt,
            # Empty string and None both mean "no negative prompt".
            negative_prompt=payload.negative_prompt or None,
            height=payload.height,
            width=payload.width,
            num_inference_steps=payload.num_inference_steps,
            guidance_scale=payload.guidance_scale,
            generator=generator,
        )

    # Encode the single generated image as base64 PNG for the JSON response.
    image = result.images[0]
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    encoded = base64.b64encode(buffer.getvalue()).decode("ascii")

    return {
        "image_base64": encoded,
        "seed": payload.seed,
        "model_id": MODEL_ID,
    }
|
||||
8
services/image-gen-service/requirements.txt
Normal file
8
services/image-gen-service/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi==0.110.0
|
||||
uvicorn==0.29.0
|
||||
torch
|
||||
git+https://github.com/huggingface/diffusers.git
|
||||
transformers
|
||||
accelerate
|
||||
safetensors
|
||||
pillow
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
import jwt
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
from fastapi import HTTPException, Security
|
||||
from fastapi import HTTPException, Security, Depends
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from app.config import get_settings
|
||||
|
||||
@@ -18,6 +18,7 @@ JWT_ALGORITHM = settings.jwt_algorithm
|
||||
JWT_EXPIRATION = settings.jwt_expiration
|
||||
|
||||
security = HTTPBearer()
|
||||
security_optional = HTTPBearer(auto_error=False)
|
||||
|
||||
|
||||
def generate_jwt_token(service_name: str, permissions: list = None) -> str:
|
||||
@@ -43,7 +44,7 @@ def verify_jwt_token(token: str) -> dict:
|
||||
|
||||
|
||||
async def get_current_service_optional(
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Security(security, auto_error=False)
|
||||
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional)
|
||||
) -> Optional[dict]:
|
||||
"""Dependency для отримання поточного сервісу з JWT (опціонально)"""
|
||||
if not credentials:
|
||||
|
||||
@@ -406,6 +406,117 @@ class Database:
|
||||
""", thread_id)
|
||||
return dict(row) if row else None
|
||||
|
||||
# ========================================================================
|
||||
# FACTS (Simple Key-Value storage)
|
||||
# ========================================================================
|
||||
|
||||
async def ensure_facts_table(self):
    """Create facts table if not exists.

    NOTE(review): with a plain UNIQUE constraint, two rows whose team_id is
    NULL never conflict with each other (SQL NULLs compare as distinct), so
    ON CONFLICT-based upserts of user-level facts can silently insert
    duplicates — consider UNIQUE NULLS NOT DISTINCT (PostgreSQL 15+) or a
    pair of partial unique indexes.
    """
    async with self.pool.acquire() as conn:
        await conn.execute("""
            CREATE TABLE IF NOT EXISTS user_facts (
                fact_id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
                user_id TEXT NOT NULL,
                team_id TEXT,
                fact_key TEXT NOT NULL,
                fact_value TEXT,
                fact_value_json JSONB,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
                UNIQUE(user_id, team_id, fact_key)
            );

            CREATE INDEX IF NOT EXISTS idx_user_facts_user_id ON user_facts(user_id);
            CREATE INDEX IF NOT EXISTS idx_user_facts_team_id ON user_facts(team_id);
        """)
|
||||
|
||||
async def upsert_fact(
    self,
    user_id: str,
    fact_key: str,
    fact_value: Optional[str] = None,
    fact_value_json: Optional[dict] = None,
    team_id: Optional[str] = None
) -> Dict[str, Any]:
    """Create or update a user fact.

    Returns the full inserted/updated row as a dict ({} if nothing returned).

    NOTE(review): binding a raw dict to the JSONB parameter requires the
    pool to have a json type codec registered — asyncpg does not encode
    dicts by default. Verify the pool setup, or json.dumps() before binding.
    NOTE(review): when team_id is None the UNIQUE constraint never fires
    (NULLs are distinct), so this ON CONFLICT upsert can insert duplicate
    user-level rows — see ensure_facts_table.
    """
    async with self.pool.acquire() as conn:
        row = await conn.fetchrow("""
            INSERT INTO user_facts (user_id, team_id, fact_key, fact_value, fact_value_json)
            VALUES ($1, $2, $3, $4, $5)
            ON CONFLICT (user_id, team_id, fact_key)
            DO UPDATE SET
                fact_value = EXCLUDED.fact_value,
                fact_value_json = EXCLUDED.fact_value_json,
                updated_at = NOW()
            RETURNING *
        """, user_id, team_id, fact_key, fact_value, fact_value_json)

        return dict(row) if row else {}
|
||||
|
||||
async def get_fact(
    self,
    user_id: str,
    fact_key: str,
    team_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """Get a specific fact"""
    # Team-scoped and user-scoped facts live in separate namespaces, so a
    # missing/falsy team_id must match NULL explicitly.
    if team_id:
        query = """
                SELECT * FROM user_facts
                WHERE user_id = $1 AND fact_key = $2 AND team_id = $3
            """
        args = (user_id, fact_key, team_id)
    else:
        query = """
                SELECT * FROM user_facts
                WHERE user_id = $1 AND fact_key = $2 AND team_id IS NULL
            """
        args = (user_id, fact_key)

    async with self.pool.acquire() as conn:
        row = await conn.fetchrow(query, *args)

    return dict(row) if row else None
|
||||
|
||||
async def list_facts(
    self,
    user_id: str,
    team_id: Optional[str] = None
) -> List[Dict[str, Any]]:
    """List all facts for a user"""
    # Without a team_id this returns the user's facts across ALL teams
    # (no team filter), ordered by key.
    if team_id:
        query = """
                SELECT * FROM user_facts
                WHERE user_id = $1 AND team_id = $2
                ORDER BY fact_key
            """
        args = (user_id, team_id)
    else:
        query = """
                SELECT * FROM user_facts
                WHERE user_id = $1
                ORDER BY fact_key
            """
        args = (user_id,)

    async with self.pool.acquire() as conn:
        rows = await conn.fetch(query, *args)

    return [dict(record) for record in rows]
|
||||
|
||||
async def delete_fact(
    self,
    user_id: str,
    fact_key: str,
    team_id: Optional[str] = None
) -> bool:
    """Delete a fact.

    Returns True if at least one row was deleted.

    Fix: the previous check `"DELETE 1" in result` was a substring test on
    asyncpg's status tag — it reported True for "DELETE 10"-style tags and
    False when more than one row happened to match. Parse the count instead.
    """
    async with self.pool.acquire() as conn:
        if team_id:
            result = await conn.execute("""
                DELETE FROM user_facts
                WHERE user_id = $1 AND fact_key = $2 AND team_id = $3
            """, user_id, fact_key, team_id)
        else:
            result = await conn.execute("""
                DELETE FROM user_facts
                WHERE user_id = $1 AND fact_key = $2 AND team_id IS NULL
            """, user_id, fact_key)

    # asyncpg returns a command tag like "DELETE <count>".
    try:
        return int(result.split()[-1]) > 0
    except (ValueError, IndexError):
        return False
|
||||
|
||||
# ========================================================================
|
||||
# STATS
|
||||
# ========================================================================
|
||||
@@ -418,11 +529,18 @@ class Database:
|
||||
memories = await conn.fetchval("SELECT COUNT(*) FROM long_term_memory_items WHERE valid_to IS NULL")
|
||||
summaries = await conn.fetchval("SELECT COUNT(*) FROM thread_summaries")
|
||||
|
||||
# Add facts count safely
|
||||
try:
|
||||
facts = await conn.fetchval("SELECT COUNT(*) FROM user_facts")
|
||||
except:
|
||||
facts = 0
|
||||
|
||||
return {
|
||||
"threads": threads,
|
||||
"events": events,
|
||||
"active_memories": memories,
|
||||
"summaries": summaries
|
||||
"summaries": summaries,
|
||||
"facts": facts
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -11,8 +11,20 @@ from .config import get_settings
|
||||
logger = structlog.get_logger()
|
||||
settings = get_settings()
|
||||
|
||||
# Initialize Cohere client
|
||||
co = cohere.Client(settings.cohere_api_key)
|
||||
# Cohere client will be initialized lazily
|
||||
_cohere_client = None
|
||||
|
||||
def get_cohere_client():
    """Lazy initialization of Cohere client"""
    global _cohere_client
    # Only attempt construction once: after a failure the sentinel False
    # prevents retries; after success the client is cached.
    if _cohere_client is None and settings.cohere_api_key:
        try:
            _cohere_client = cohere.Client(settings.cohere_api_key)
        except Exception as e:
            logger.warning("cohere_client_init_failed", error=str(e))
            _cohere_client = False  # Mark as failed to avoid retries
        else:
            logger.info("cohere_client_initialized")
    return _cohere_client or None
|
||||
|
||||
|
||||
@retry(
|
||||
@@ -36,9 +48,14 @@ async def get_embeddings(
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
co_client = get_cohere_client()
|
||||
if not co_client:
|
||||
logger.warning("cohere_not_configured", message="Cohere API key not set, returning empty embeddings")
|
||||
return [[] for _ in texts]
|
||||
|
||||
logger.info("generating_embeddings", count=len(texts), input_type=input_type)
|
||||
|
||||
response = co.embed(
|
||||
response = co_client.embed(
|
||||
texts=texts,
|
||||
model=settings.cohere_model,
|
||||
input_type=input_type,
|
||||
|
||||
698
services/memory-service/app/ingestion.py
Normal file
698
services/memory-service/app/ingestion.py
Normal file
@@ -0,0 +1,698 @@
|
||||
"""
|
||||
Memory Ingestion Pipeline
|
||||
Автоматичне витягування фактів/пам'яті з діалогів
|
||||
|
||||
Етапи:
|
||||
1. PII Scrubber - виявлення та редакція персональних даних
|
||||
2. Memory Candidate Extractor - класифікація та витягування
|
||||
3. Dedup & Merge - дедуплікація схожих пам'ятей
|
||||
4. Write - збереження в SQL + Vector + Graph
|
||||
5. Audit Log - запис в аудит
|
||||
"""
|
||||
|
||||
import re
|
||||
import hashlib
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from uuid import UUID, uuid4
|
||||
from enum import Enum
|
||||
import structlog
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class MemoryType(str, Enum):
    """High-level kind of a stored memory."""
    EPISODIC = "episodic"      # events/facts about an interaction
    SEMANTIC = "semantic"      # stable preferences / profile traits
    PROCEDURAL = "procedural"  # how to do something
||||
|
||||
|
||||
class MemoryCategory(str, Enum):
    """Fine-grained classification of an extracted memory."""
    PREFERENCE = "preference"          # user preference
    FACT = "fact"                      # fact about the user
    TOPIC_INTEREST = "topic_interest"  # interest in a topic
    ROLE = "role"                      # role (investor, engineer, ...)
    INTERACTION = "interaction"        # interaction type
    FEEDBACK = "feedback"              # feedback / rating
    OPT_OUT = "opt_out"                # request to stop/forget memory
|
||||
|
||||
|
||||
class PIIType(str, Enum):
    """Categories of personally identifiable information the scrubber knows."""
    PHONE = "phone"
    EMAIL = "email"
    ADDRESS = "address"
    PASSPORT = "passport"
    CARD_NUMBER = "card_number"
    NAME = "name"
    LOCATION = "location"
|
||||
|
||||
|
||||
class MemoryCandidate(BaseModel):
    """A candidate record to be persisted to memory."""
    content: str                      # full (PII-scrubbed) text
    summary: str                      # short human-readable summary
    memory_type: MemoryType
    category: MemoryCategory
    importance: float  # 0.0 - 1.0
    confidence: float  # 0.0 - 1.0
    ttl_days: Optional[int] = None    # None => no expiry
    # Pydantic deep-copies mutable field defaults per instance, so the usual
    # shared-mutable-default pitfall of plain Python does not apply here.
    source_message_ids: List[str] = []
    metadata: Dict[str, Any] = {}
|
||||
|
||||
|
||||
class PIIDetection(BaseModel):
    """Result of a single PII detection in a text."""
    pii_type: PIIType
    start: int       # match start offset in the original text
    end: int         # match end offset (exclusive)
    original: str    # the matched raw value
    redacted: str    # the masked replacement
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 1. PII SCRUBBER
|
||||
# =============================================================================
|
||||
|
||||
class PIIScrubber:
    """Detect and redact personally identifiable information (PII)."""

    # Regular expressions per PII type. Patterns are applied independently,
    # so overlapping matches across types are possible (e.g. a phone-like
    # run of digits inside a card number).
    PATTERNS = {
        PIIType.PHONE: [
            r'\+?38?\s?0?\d{2}[\s\-]?\d{3}[\s\-]?\d{2}[\s\-]?\d{2}',  # UA phones
            r'\+?\d{1,3}[\s\-]?\(?\d{2,3}\)?[\s\-]?\d{3}[\s\-]?\d{2}[\s\-]?\d{2}',
        ],
        PIIType.EMAIL: [
            r'[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}',
        ],
        PIIType.CARD_NUMBER: [
            r'\b\d{4}[\s\-]?\d{4}[\s\-]?\d{4}[\s\-]?\d{4}\b',
        ],
        PIIType.PASSPORT: [
            r'\b[A-Z]{2}\d{6}\b',  # UA passport
        ],
    }

    def detect(self, text: str) -> List[PIIDetection]:
        """Find all PII occurrences in *text* (one detection per regex match)."""
        detections = []

        for pii_type, patterns in self.PATTERNS.items():
            for pattern in patterns:
                for match in re.finditer(pattern, text, re.IGNORECASE):
                    detections.append(PIIDetection(
                        pii_type=pii_type,
                        start=match.start(),
                        end=match.end(),
                        original=match.group(),
                        redacted=self._redact(pii_type, match.group())
                    ))

        return detections

    def _redact(self, pii_type: PIIType, value: str) -> str:
        """Produce a masked replacement that keeps a small hint of the original."""
        if pii_type == PIIType.EMAIL:
            parts = value.split('@')
            return f"{parts[0][:2]}***@{parts[1]}" if len(parts) == 2 else "[EMAIL]"
        elif pii_type == PIIType.PHONE:
            return f"***{value[-4:]}" if len(value) > 4 else "[PHONE]"
        elif pii_type == PIIType.CARD_NUMBER:
            return f"****{value[-4:]}"
        else:
            return f"[{pii_type.value.upper()}]"

    def scrub(self, text: str) -> Tuple[str, List[PIIDetection], bool]:
        """
        Remove PII from the text.
        Returns: (cleaned_text, detections, has_pii)
        """
        detections = self.detect(text)

        if not detections:
            return text, [], False

        # Sort by position (from the end) so earlier offsets stay valid while
        # splicing replacements in.
        # NOTE(review): overlapping detections are replaced independently,
        # which can garble the output — consider coalescing overlaps first.
        detections.sort(key=lambda x: x.start, reverse=True)

        cleaned = text
        for detection in detections:
            cleaned = cleaned[:detection.start] + detection.redacted + cleaned[detection.end:]

        return cleaned, detections, True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 2. MEMORY CANDIDATE EXTRACTOR
|
||||
# =============================================================================
|
||||
|
||||
class MemoryExtractor:
    """Extract memory candidates from user messages via keyword/regex rules."""

    # Trigger phrases per category (Ukrainian-language heuristics).
    CATEGORY_PATTERNS = {
        MemoryCategory.PREFERENCE: [
            r'я (хочу|бажаю|віддаю перевагу|люблю|не люблю)',
            r'мені (подобається|не подобається|зручніше)',
            r'(краще|гірше) для мене',
        ],
        MemoryCategory.ROLE: [
            r'я (інвестор|інженер|розробник|науковець|журналіст|модератор)',
            r'працюю (як|в галузі)',
            r'моя (роль|посада|професія)',
        ],
        MemoryCategory.TOPIC_INTEREST: [
            r'цікавить (мене )?(BioMiner|EcoMiner|токеноміка|governance|стейкінг)',
            r'хочу (дізнатися|розібратися) (в|з)',
            r'питання (про|щодо|стосовно)',
        ],
        MemoryCategory.OPT_OUT: [
            # NOTE(review): the negation prefix is optional, so this pattern
            # also fires on POSITIVE "remember this" phrasings containing the
            # same stem — confirm that treating those as opt-out is intended.
            r'(не |НЕ )?(запам[\'ʼ]ятов|запамʼятовуй|запамятовуй)',
            r'забудь (мене|це|все)',
            r'вимкни (пам[\'ʼ]ять|память)',
        ],
    }

    # Base importance per category; OPT_OUT always wins.
    IMPORTANCE_WEIGHTS = {
        MemoryCategory.PREFERENCE: 0.7,
        MemoryCategory.ROLE: 0.8,
        MemoryCategory.TOPIC_INTEREST: 0.6,
        MemoryCategory.FACT: 0.5,
        MemoryCategory.OPT_OUT: 1.0,  # highest importance
    }

    def extract(
        self,
        messages: List[Dict[str, Any]],
        context: Optional[Dict[str, Any]] = None
    ) -> List[MemoryCandidate]:
        """
        Extract memory candidates from messages.

        Args:
            messages: list of messages [{role, content, message_id, ...}]
            context: extra context (group_id, user_id, etc.)

        Returns:
            List of MemoryCandidate
        """
        candidates = []

        for msg in messages:
            # Only user messages are mined for memories.
            if msg.get('role') != 'user':
                continue

            content = msg.get('content', '')
            message_id = msg.get('message_id', str(uuid4()))

            # Opt-out phrases take priority over any other category.
            opt_out = self._check_opt_out(content)
            if opt_out:
                candidates.append(opt_out)
                candidates[-1].source_message_ids = [message_id]
                continue

            # Scan the remaining categories; first matching pattern per
            # category wins (break), but multiple categories may match.
            for category, patterns in self.CATEGORY_PATTERNS.items():
                if category == MemoryCategory.OPT_OUT:
                    continue

                for pattern in patterns:
                    if re.search(pattern, content, re.IGNORECASE):
                        candidate = self._create_candidate(
                            content=content,
                            category=category,
                            message_id=message_id,
                            context=context
                        )
                        if candidate:
                            candidates.append(candidate)
                        break

        return candidates

    def _check_opt_out(self, content: str) -> Optional[MemoryCandidate]:
        """Return an OPT_OUT candidate if the message contains an opt-out phrase."""
        for pattern in self.CATEGORY_PATTERNS[MemoryCategory.OPT_OUT]:
            match = re.search(pattern, content, re.IGNORECASE)
            if match:
                # Distinguish "forget me" (delete) from "stop remembering" (disable).
                if 'забудь' in content.lower():
                    action = 'forget'
                    summary = "Користувач просить видалити пам'ять"
                else:
                    action = 'disable'
                    summary = "Користувач просить не запам'ятовувати"

                return MemoryCandidate(
                    content=content,
                    summary=summary,
                    memory_type=MemoryType.SEMANTIC,
                    category=MemoryCategory.OPT_OUT,
                    importance=1.0,
                    confidence=0.95,
                    metadata={'action': action}
                )
        return None

    def _create_candidate(
        self,
        content: str,
        category: MemoryCategory,
        message_id: str,
        context: Optional[Dict[str, Any]] = None
    ) -> Optional[MemoryCandidate]:
        """Build a MemoryCandidate for a matched category."""

        # Stable traits (preferences, roles) are semantic and never expire;
        # everything else is episodic with a 90-day TTL.
        if category in [MemoryCategory.PREFERENCE, MemoryCategory.ROLE]:
            memory_type = MemoryType.SEMANTIC
            ttl_days = None  # no expiry
        else:
            memory_type = MemoryType.EPISODIC
            ttl_days = 90  # 3 months

        # Build a short summary.
        summary = self._generate_summary(content, category)

        return MemoryCandidate(
            content=content,
            summary=summary,
            memory_type=memory_type,
            category=category,
            importance=self.IMPORTANCE_WEIGHTS.get(category, 0.5),
            confidence=0.7,  # base confidence; could be raised via an LLM pass
            ttl_days=ttl_days,
            source_message_ids=[message_id],
            metadata=context or {}
        )

    def _generate_summary(self, content: str, category: MemoryCategory) -> str:
        """Generate a short summary (simple truncation; an LLM in production)."""
        # Naive approach: first 100 characters.
        summary = content[:100]
        if len(content) > 100:
            summary += "..."
        return f"[{category.value}] {summary}"
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 3. DEDUP & MERGE
|
||||
# =============================================================================
|
||||
|
||||
class MemoryDeduplicator:
    """Deduplicate and merge similar memories.

    New candidates are compared against existing records first by an exact
    normalized-content hash, then by category + Jaccard summary similarity.
    """

    def __init__(self, similarity_threshold: float = 0.85):
        self.similarity_threshold = similarity_threshold

    def deduplicate(
        self,
        new_candidates: List[MemoryCandidate],
        existing_memories: List[Dict[str, Any]]
    ) -> Tuple[List[MemoryCandidate], List[Dict[str, Any]]]:
        """
        Deduplicate new candidates against existing memories.

        Returns:
            (candidates_to_create, memories_to_update)
        """
        to_create: List[MemoryCandidate] = []
        to_update: List[Dict[str, Any]] = []

        for cand in new_candidates:
            matched = self._find_similar(cand, existing_memories)
            if matched is None:
                to_create.append(cand)
                continue

            # Merge into the matched memory: keep the higher importance and
            # union the source message ids.
            merged_sources = set(matched.get('source_message_ids', []))
            merged_sources.update(cand.source_message_ids)
            to_update.append({
                'memory_id': matched['memory_id'],
                'content': cand.content,
                'summary': cand.summary,
                'importance': max(cand.importance, matched.get('importance', 0)),
                'source_message_ids': list(merged_sources),
            })

        return to_create, to_update

    def _find_similar(
        self,
        candidate: MemoryCandidate,
        existing: List[Dict[str, Any]]
    ) -> Optional[Dict[str, Any]]:
        """Return the first existing memory considered a duplicate, if any."""
        candidate_hash = self._content_hash(candidate.content)

        for memory in existing:
            # Fast path: identical normalized content.
            if self._content_hash(memory.get('content', '')) == candidate_hash:
                return memory

            # Slow path: same category and highly similar summaries.
            same_category = memory.get('category') == candidate.category.value
            if same_category:
                score = self._text_similarity(candidate.summary, memory.get('summary', ''))
                if score > self.similarity_threshold:
                    return memory

        return None

    def _content_hash(self, content: str) -> str:
        """MD5 of lowercased, stripped content (non-cryptographic fingerprint)."""
        return hashlib.md5(content.lower().strip().encode()).hexdigest()

    def _text_similarity(self, text1: str, text2: str) -> float:
        """Jaccard similarity over whitespace-separated lowercase tokens."""
        if not (text1 and text2):
            return 0.0

        tokens_a = set(text1.lower().split())
        tokens_b = set(text2.lower().split())

        union_size = len(tokens_a | tokens_b)
        if union_size == 0:
            return 0.0
        return len(tokens_a & tokens_b) / union_size
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# 4. MEMORY INGESTION PIPELINE
|
||||
# =============================================================================
|
||||
|
||||
class MemoryIngestionPipeline:
    """
    Full pipeline for extracting and persisting memories from a conversation.
    """

    def __init__(self, db=None, vector_store=None, graph_store=None):
        # All backends are optional so the pipeline can run without
        # infrastructure (DB writes become mocks / no-ops).
        self.db = db
        self.vector_store = vector_store
        self.graph_store = graph_store

        self.pii_scrubber = PIIScrubber()
        self.extractor = MemoryExtractor()
        self.deduplicator = MemoryDeduplicator()

    async def process_conversation(
        self,
        messages: List[Dict[str, Any]],
        user_id: Optional[str] = None,
        platform_user_id: Optional[str] = None,
        group_id: Optional[str] = None,
        conversation_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Process a conversation and extract memories.

        Returns:
            {
                "memories_created": int,
                "memories_updated": int,
                "pii_detected": bool,
                "opt_out_requested": bool,
                "details": [...]
            }
        """
        result = {
            "memories_created": 0,
            "memories_updated": 0,
            "pii_detected": False,
            "opt_out_requested": False,
            "details": []
        }

        # 1. PII Scrubbing — only user messages are scrubbed.
        cleaned_messages = []
        for msg in messages:
            if msg.get('role') == 'user':
                cleaned, detections, has_pii = self.pii_scrubber.scrub(msg.get('content', ''))
                if has_pii:
                    result["pii_detected"] = True
                    logger.info("pii_detected",
                                count=len(detections),
                                types=[d.pii_type.value for d in detections])
                cleaned_messages.append({**msg, 'content': cleaned, '_pii_detected': has_pii})
            else:
                cleaned_messages.append(msg)

        # 2. Extract candidates from the scrubbed messages.
        context = {
            'user_id': user_id,
            'platform_user_id': platform_user_id,
            'group_id': group_id,
            'conversation_id': conversation_id
        }
        candidates = self.extractor.extract(cleaned_messages, context)

        # Handle opt-out requests before any persistence.
        for candidate in candidates:
            if candidate.category == MemoryCategory.OPT_OUT:
                result["opt_out_requested"] = True
                await self._handle_opt_out(candidate, context)
                result["details"].append({
                    "type": "opt_out",
                    "action": candidate.metadata.get('action'),
                    "summary": candidate.summary
                })

        # If an opt-out was requested, do not store any other memories.
        if result["opt_out_requested"]:
            return result

        # 3. Dedup against existing memories (skipped when no DB is attached).
        existing_memories = []
        if self.db:
            existing_memories = await self._get_existing_memories(
                user_id=user_id,
                platform_user_id=platform_user_id,
                group_id=group_id
            )

        to_create, to_update = self.deduplicator.deduplicate(candidates, existing_memories)

        # 4. Write new/updated memories to storage.
        for candidate in to_create:
            memory_id = await self._create_memory(candidate, context)
            if memory_id:
                result["memories_created"] += 1
                result["details"].append({
                    "type": "created",
                    "memory_id": str(memory_id),
                    "category": candidate.category.value,
                    "summary": candidate.summary
                })

        for update in to_update:
            success = await self._update_memory(update)
            if success:
                result["memories_updated"] += 1
                result["details"].append({
                    "type": "updated",
                    "memory_id": update['memory_id'],
                    "summary": update.get('summary')
                })

        # 5. Audit log.
        await self._log_ingestion(result, context)

        logger.info("ingestion_complete",
                    created=result["memories_created"],
                    updated=result["memories_updated"],
                    pii=result["pii_detected"],
                    opt_out=result["opt_out_requested"])

        return result

    async def _handle_opt_out(
        self,
        candidate: MemoryCandidate,
        context: Dict[str, Any]
    ):
        """Apply an opt-out request: either forget stored memories or disable saving."""
        action = candidate.metadata.get('action', 'disable')
        group_id = context.get('group_id')
        platform_user_id = context.get('platform_user_id')

        # Without a platform user id there is nothing to act on.
        if not platform_user_id:
            return

        if self.db:
            if action == 'forget' and group_id:
                # Full deletion within the group (delegated to a SQL function).
                await self.db.execute(
                    "SELECT memory_forget_in_group($1::uuid, $2)",
                    group_id, platform_user_id
                )
            else:
                # Just disable future memory persistence.
                if group_id:
                    await self.db.execute("""
                        UPDATE group_members
                        SET no_memory_in_group = TRUE
                        WHERE group_id = $1::uuid AND platform_user_id = $2
                    """, group_id, platform_user_id)
                else:
                    await self.db.execute("""
                        UPDATE memory_consent
                        SET memory_enabled = FALSE, updated_at = NOW()
                        WHERE platform_user_id = $1
                    """, platform_user_id)

    async def _get_existing_memories(
        self,
        user_id: Optional[str],
        platform_user_id: Optional[str],
        group_id: Optional[str]
    ) -> List[Dict[str, Any]]:
        """Fetch active memories for the relevant user/group scope."""
        if not self.db:
            return []

        query = """
            SELECT memory_id, content, summary, category, importance, source_message_ids
            FROM memories
            WHERE is_active = TRUE
        """
        params = []

        # Scope resolution: group (+user) > user > platform user; parameter
        # numbers are hand-kept in sync with the appends below.
        if group_id:
            query += " AND group_id = $1::uuid"
            params.append(group_id)
            if platform_user_id:
                query += " AND platform_user_id = $2"
                params.append(platform_user_id)
        elif user_id:
            query += " AND user_id = $1::uuid AND group_id IS NULL"
            params.append(user_id)
        elif platform_user_id:
            query += " AND platform_user_id = $1 AND group_id IS NULL"
            params.append(platform_user_id)
        else:
            return []

        rows = await self.db.fetch(query, *params)
        return [dict(row) for row in rows]

    async def _create_memory(
        self,
        candidate: MemoryCandidate,
        context: Dict[str, Any]
    ) -> Optional[UUID]:
        """Insert a new memory row; returns its id (a mock UUID without a DB)."""
        if not self.db:
            return uuid4()  # mock for testing

        memory_id = uuid4()

        # Calculate expires_at from the TTL.
        # NOTE(review): datetime.now() is timezone-naive — confirm the column
        # type; a timestamptz column would want an aware UTC timestamp.
        expires_at = None
        if candidate.ttl_days:
            expires_at = datetime.now() + timedelta(days=candidate.ttl_days)

        # NOTE(review): binding candidate.metadata (a dict) assumes a json
        # codec is registered on the connection — verify, or serialize first.
        await self.db.execute("""
            INSERT INTO memories (
                memory_id, user_id, platform_user_id, group_id,
                memory_type, category, content, summary,
                importance, confidence, ttl_days, expires_at,
                source_message_ids, extraction_method, metadata
            ) VALUES (
                $1, $2::uuid, $3, $4::uuid,
                $5, $6, $7, $8,
                $9, $10, $11, $12,
                $13, $14, $15
            )
        """,
            memory_id,
            context.get('user_id'),
            context.get('platform_user_id'),
            context.get('group_id'),
            candidate.memory_type.value,
            candidate.category.value,
            candidate.content,
            candidate.summary,
            candidate.importance,
            candidate.confidence,
            candidate.ttl_days,
            expires_at,
            candidate.source_message_ids,
            'pipeline',
            candidate.metadata
        )

        # Store the embedding if a vector store is attached.
        if self.vector_store:
            await self._store_embedding(memory_id, candidate, context)

        # Store graph relations if a graph store is attached.
        if self.graph_store:
            await self._store_graph_relation(memory_id, candidate, context)

        return memory_id

    async def _update_memory(self, update: Dict[str, Any]) -> bool:
        """Update an existing memory row with merged content/importance/sources."""
        if not self.db:
            return True

        await self.db.execute("""
            UPDATE memories
            SET content = $2, summary = $3, importance = $4,
                source_message_ids = $5, updated_at = NOW()
            WHERE memory_id = $1::uuid
        """,
            update['memory_id'],
            update['content'],
            update['summary'],
            update['importance'],
            update['source_message_ids']
        )
        return True

    async def _store_embedding(
        self,
        memory_id: UUID,
        candidate: MemoryCandidate,
        context: Dict[str, Any]
    ):
        """Persist an embedding to the vector store (not yet implemented)."""
        # Implementation depends on the vector store (Qdrant, pgvector).
        pass

    async def _store_graph_relation(
        self,
        memory_id: UUID,
        candidate: MemoryCandidate,
        context: Dict[str, Any]
    ):
        """Persist a relation to the graph store (not yet implemented)."""
        # Implementation for Neo4j.
        pass

    async def _log_ingestion(
        self,
        result: Dict[str, Any],
        context: Dict[str, Any]
    ):
        """Record the ingestion run in the audit table (no-op without a DB)."""
        if not self.db:
            return

        # NOTE(review): `result` (a dict) is bound directly as $3 — this
        # assumes a json/jsonb codec on the connection; verify.
        await self.db.execute("""
            INSERT INTO memory_events (
                user_id, group_id, action, actor, new_value
            ) VALUES (
                $1::uuid, $2::uuid, 'ingestion', 'pipeline', $3
            )
        """,
            context.get('user_id'),
            context.get('group_id'),
            result
        )
|
||||
@@ -477,6 +477,102 @@ async def get_context(
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# FACTS (Simple Key-Value storage for Gateway compatibility)
|
||||
# ============================================================================
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Any
|
||||
|
||||
class FactUpsertRequest(BaseModel):
    """Request to upsert a user fact."""
    user_id: str                           # owner of the fact
    fact_key: str                          # unique key within (user, team) scope
    fact_value: Optional[str] = None       # plain-text value
    fact_value_json: Optional[dict] = None # structured value (JSONB)
    team_id: Optional[str] = None          # None => user-level fact
|
||||
|
||||
@app.post("/facts/upsert")
async def upsert_fact(request: FactUpsertRequest):
    """
    Create or update a user fact.

    This is a simple key-value store for Gateway compatibility.
    Facts are stored in PostgreSQL without vector indexing.

    Fix: structlog event names were written as f-strings with no
    placeholders (ruff F541) — plain string literals now.
    """
    try:
        # Ensure facts table exists (created lazily on first call).
        await db.ensure_facts_table()

        # Upsert the fact.
        result = await db.upsert_fact(
            user_id=request.user_id,
            fact_key=request.fact_key,
            fact_value=request.fact_value,
            fact_value_json=request.fact_value_json,
            team_id=request.team_id
        )

        logger.info("fact_upserted", user_id=request.user_id, fact_key=request.fact_key)
        return {"status": "ok", "fact_id": result.get("fact_id") if result else None}

    except Exception as e:
        logger.error("fact_upsert_failed", error=str(e), user_id=request.user_id)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/facts/{fact_key}")
|
||||
async def get_fact(
|
||||
fact_key: str,
|
||||
user_id: str = Query(...),
|
||||
team_id: Optional[str] = None
|
||||
):
|
||||
"""Get a specific fact for a user"""
|
||||
try:
|
||||
fact = await db.get_fact(user_id=user_id, fact_key=fact_key, team_id=team_id)
|
||||
if not fact:
|
||||
raise HTTPException(status_code=404, detail="Fact not found")
|
||||
return fact
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"fact_get_failed", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/facts")
|
||||
async def list_facts(
|
||||
user_id: str = Query(...),
|
||||
team_id: Optional[str] = None
|
||||
):
|
||||
"""List all facts for a user"""
|
||||
try:
|
||||
facts = await db.list_facts(user_id=user_id, team_id=team_id)
|
||||
return {"facts": facts}
|
||||
except Exception as e:
|
||||
logger.error(f"facts_list_failed", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.delete("/facts/{fact_key}")
|
||||
async def delete_fact(
|
||||
fact_key: str,
|
||||
user_id: str = Query(...),
|
||||
team_id: Optional[str] = None
|
||||
):
|
||||
"""Delete a fact"""
|
||||
try:
|
||||
deleted = await db.delete_fact(user_id=user_id, fact_key=fact_key, team_id=team_id)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Fact not found")
|
||||
return {"status": "ok", "deleted": True}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"fact_delete_failed", error=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ADMIN
|
||||
# ============================================================================
|
||||
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
import yaml
|
||||
import httpx
|
||||
import logging
|
||||
from neo4j import AsyncGraphDatabase
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -15,15 +16,27 @@ app = FastAPI(title="DAARION Router", version="2.0.0")
|
||||
|
||||
# Configuration
|
||||
NATS_URL = os.getenv("NATS_URL", "nats://nats:4222")
|
||||
SWAPPER_URL = os.getenv("SWAPPER_URL", "http://192.168.1.33:8890")
|
||||
STT_URL = os.getenv("STT_URL", "http://192.168.1.33:8895")
|
||||
VISION_URL = os.getenv("VISION_URL", "http://192.168.1.33:11434")
|
||||
OCR_URL = os.getenv("OCR_URL", "http://192.168.1.33:8896")
|
||||
SWAPPER_URL = os.getenv("SWAPPER_URL", "http://swapper-service:8890")
|
||||
# All multimodal services now through Swapper
|
||||
STT_URL = os.getenv("STT_URL", "http://swapper-service:8890") # Swapper /stt endpoint
|
||||
TTS_URL = os.getenv("TTS_URL", "http://swapper-service:8890") # Swapper /tts endpoint
|
||||
VISION_URL = os.getenv("VISION_URL", "http://172.18.0.1:11434") # Host Ollama
|
||||
OCR_URL = os.getenv("OCR_URL", "http://swapper-service:8890") # Swapper /ocr endpoint
|
||||
DOCUMENT_URL = os.getenv("DOCUMENT_URL", "http://swapper-service:8890") # Swapper /document endpoint
|
||||
CITY_SERVICE_URL = os.getenv("CITY_SERVICE_URL", "http://daarion-city-service:7001")
|
||||
|
||||
# Neo4j Configuration
|
||||
NEO4J_URI = os.getenv("NEO4J_BOLT_URL", "bolt://neo4j:7687")
|
||||
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
|
||||
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "DaarionNeo4j2026!")
|
||||
|
||||
# HTTP client for backend services
|
||||
http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
# Neo4j driver
|
||||
neo4j_driver = None
|
||||
neo4j_available = False
|
||||
|
||||
# NATS client
|
||||
nc = None
|
||||
nats_available = False
|
||||
@@ -82,13 +95,29 @@ router_config = load_router_config()
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize NATS connection and subscriptions"""
|
||||
global nc, nats_available, http_client
|
||||
global nc, nats_available, http_client, neo4j_driver, neo4j_available
|
||||
logger.info("🚀 DAGI Router v2.0.0 starting up...")
|
||||
|
||||
# Initialize HTTP client
|
||||
http_client = httpx.AsyncClient(timeout=60.0)
|
||||
logger.info("✅ HTTP client initialized")
|
||||
|
||||
# Initialize Neo4j connection
|
||||
try:
|
||||
neo4j_driver = AsyncGraphDatabase.driver(
|
||||
NEO4J_URI,
|
||||
auth=(NEO4J_USER, NEO4J_PASSWORD)
|
||||
)
|
||||
# Verify connection
|
||||
async with neo4j_driver.session() as session:
|
||||
result = await session.run("RETURN 1 as test")
|
||||
await result.consume()
|
||||
neo4j_available = True
|
||||
logger.info(f"✅ Connected to Neo4j at {NEO4J_URI}")
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ Neo4j not available: {e}")
|
||||
neo4j_available = False
|
||||
|
||||
# Try to connect to NATS
|
||||
try:
|
||||
import nats
|
||||
@@ -111,6 +140,7 @@ async def startup_event():
|
||||
logger.info(f"📡 STT URL: {STT_URL}")
|
||||
logger.info(f"📡 Vision URL: {VISION_URL}")
|
||||
logger.info(f"📡 OCR URL: {OCR_URL}")
|
||||
logger.info(f"📡 Neo4j URL: {NEO4J_URI}")
|
||||
|
||||
async def subscribe_to_filter_decisions():
|
||||
"""Subscribe to agent.filter.decision events"""
|
||||
@@ -409,47 +439,152 @@ async def agent_infer(agent_id: str, request: InferRequest):
|
||||
system_prompt = agent_config.get("system_prompt")
|
||||
|
||||
# Determine which backend to use
|
||||
model = request.model or "gpt-oss:latest"
|
||||
# Use router config to get default model for agent, fallback to qwen3-8b
|
||||
agent_config = router_config.get("agents", {}).get(agent_id, {})
|
||||
default_llm = agent_config.get("default_llm", "qwen3-8b")
|
||||
|
||||
# Try Swapper first (for LLM models)
|
||||
# Check if there's a routing rule for this agent
|
||||
routing_rules = router_config.get("routing", [])
|
||||
for rule in routing_rules:
|
||||
if rule.get("when", {}).get("agent") == agent_id:
|
||||
if "use_llm" in rule:
|
||||
default_llm = rule.get("use_llm")
|
||||
logger.info(f"🎯 Agent {agent_id} routing to: {default_llm}")
|
||||
break
|
||||
|
||||
# Get LLM profile configuration
|
||||
llm_profiles = router_config.get("llm_profiles", {})
|
||||
llm_profile = llm_profiles.get(default_llm, {})
|
||||
provider = llm_profile.get("provider", "ollama")
|
||||
|
||||
# Determine model name
|
||||
if provider in ["deepseek", "openai", "anthropic", "mistral"]:
|
||||
model = llm_profile.get("model", "deepseek-chat")
|
||||
else:
|
||||
# For local ollama, use swapper model name format
|
||||
model = request.model or "qwen3-8b"
|
||||
|
||||
# =========================================================================
|
||||
# CLOUD PROVIDERS (DeepSeek, OpenAI, etc.)
|
||||
# =========================================================================
|
||||
if provider == "deepseek":
|
||||
try:
|
||||
api_key = os.getenv(llm_profile.get("api_key_env", "DEEPSEEK_API_KEY"))
|
||||
base_url = llm_profile.get("base_url", "https://api.deepseek.com")
|
||||
|
||||
if not api_key:
|
||||
logger.error("❌ DeepSeek API key not configured")
|
||||
raise HTTPException(status_code=500, detail="DeepSeek API key not configured")
|
||||
|
||||
logger.info(f"🌐 Calling DeepSeek API with model: {model}")
|
||||
|
||||
# Build messages array for chat completion
|
||||
messages = []
|
||||
if system_prompt:
|
||||
messages.append({"role": "system", "content": system_prompt})
|
||||
messages.append({"role": "user", "content": request.prompt})
|
||||
|
||||
deepseek_resp = await http_client.post(
|
||||
f"{base_url}/v1/chat/completions",
|
||||
headers={
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
},
|
||||
json={
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"max_tokens": request.max_tokens or llm_profile.get("max_tokens", 2048),
|
||||
"temperature": request.temperature or llm_profile.get("temperature", 0.2),
|
||||
"stream": False
|
||||
},
|
||||
timeout=llm_profile.get("timeout_ms", 40000) / 1000
|
||||
)
|
||||
|
||||
if deepseek_resp.status_code == 200:
|
||||
data = deepseek_resp.json()
|
||||
response_text = data.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
tokens_used = data.get("usage", {}).get("total_tokens", 0)
|
||||
|
||||
logger.info(f"✅ DeepSeek response received, {tokens_used} tokens")
|
||||
return InferResponse(
|
||||
response=response_text,
|
||||
model=model,
|
||||
tokens_used=tokens_used,
|
||||
backend="deepseek-cloud"
|
||||
)
|
||||
else:
|
||||
logger.error(f"❌ DeepSeek error: {deepseek_resp.status_code} - {deepseek_resp.text}")
|
||||
raise HTTPException(status_code=deepseek_resp.status_code, detail=f"DeepSeek API error: {deepseek_resp.text}")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"❌ DeepSeek error: {e}")
|
||||
# Don't fallback to local for cloud agents - raise error
|
||||
raise HTTPException(status_code=503, detail=f"DeepSeek API error: {str(e)}")
|
||||
|
||||
# =========================================================================
|
||||
# LOCAL PROVIDERS (Ollama via Swapper)
|
||||
# =========================================================================
|
||||
try:
|
||||
# Check if Swapper is available
|
||||
health_resp = await http_client.get(f"{SWAPPER_URL}/health", timeout=5.0)
|
||||
if health_resp.status_code == 200:
|
||||
# Load model if needed
|
||||
load_resp = await http_client.post(
|
||||
f"{SWAPPER_URL}/load",
|
||||
json={"model": model},
|
||||
timeout=30.0
|
||||
logger.info(f"📡 Calling Swapper with model: {model}")
|
||||
# Generate response via Swapper (which handles model loading)
|
||||
generate_resp = await http_client.post(
|
||||
f"{SWAPPER_URL}/generate",
|
||||
json={
|
||||
"model": model,
|
||||
"prompt": request.prompt,
|
||||
"system": system_prompt,
|
||||
"max_tokens": request.max_tokens,
|
||||
"temperature": request.temperature,
|
||||
"stream": False
|
||||
},
|
||||
timeout=300.0
|
||||
)
|
||||
|
||||
if load_resp.status_code == 200:
|
||||
# Generate response via Ollama
|
||||
generate_resp = await http_client.post(
|
||||
f"{VISION_URL}/api/generate",
|
||||
json={
|
||||
"model": model,
|
||||
"prompt": request.prompt,
|
||||
"system": system_prompt,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"num_predict": request.max_tokens,
|
||||
"temperature": request.temperature
|
||||
}
|
||||
},
|
||||
timeout=120.0
|
||||
if generate_resp.status_code == 200:
|
||||
data = generate_resp.json()
|
||||
return InferResponse(
|
||||
response=data.get("response", ""),
|
||||
model=model,
|
||||
tokens_used=data.get("eval_count", 0),
|
||||
backend="swapper+ollama"
|
||||
)
|
||||
|
||||
if generate_resp.status_code == 200:
|
||||
data = generate_resp.json()
|
||||
return InferResponse(
|
||||
response=data.get("response", ""),
|
||||
model=model,
|
||||
tokens_used=data.get("eval_count"),
|
||||
backend="swapper+ollama"
|
||||
)
|
||||
else:
|
||||
logger.error(f"❌ Swapper error: {generate_resp.status_code} - {generate_resp.text}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Swapper/Ollama error: {e}")
|
||||
# Fallback to direct Ollama if Swapper fails
|
||||
try:
|
||||
logger.info(f"🔄 Falling back to direct Ollama connection")
|
||||
generate_resp = await http_client.post(
|
||||
f"{VISION_URL}/api/generate",
|
||||
json={
|
||||
"model": "qwen3:8b", # Use actual Ollama model name
|
||||
"prompt": request.prompt,
|
||||
"system": system_prompt,
|
||||
"stream": False,
|
||||
"options": {
|
||||
"num_predict": request.max_tokens,
|
||||
"temperature": request.temperature
|
||||
}
|
||||
},
|
||||
timeout=120.0
|
||||
)
|
||||
|
||||
if generate_resp.status_code == 200:
|
||||
data = generate_resp.json()
|
||||
return InferResponse(
|
||||
response=data.get("response", ""),
|
||||
model=model,
|
||||
tokens_used=data.get("eval_count", 0),
|
||||
backend="ollama-direct"
|
||||
)
|
||||
except Exception as e2:
|
||||
logger.error(f"❌ Direct Ollama fallback also failed: {e2}")
|
||||
|
||||
# Fallback: return error
|
||||
raise HTTPException(
|
||||
@@ -499,6 +634,290 @@ async def list_available_models():
|
||||
return {"models": models, "total": len(models)}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# NEO4J GRAPH API ENDPOINTS
|
||||
# =============================================================================
|
||||
|
||||
class GraphNode(BaseModel):
    """Model for creating/updating a graph node."""
    label: str                      # Node type: User, Agent, Topic, Fact, Entity, etc.
    properties: Dict[str, Any]      # arbitrary node properties, merged via SET n += $properties
    node_id: Optional[str] = None   # If provided, update existing node; otherwise one is generated
|
||||
|
||||
class GraphRelationship(BaseModel):
    """Model for creating a relationship between two existing nodes."""
    from_node_id: str                           # node_id of the source node
    to_node_id: str                             # node_id of the target node
    relationship_type: str                      # KNOWS, MENTIONED, RELATED_TO, CREATED_BY, etc.
    properties: Optional[Dict[str, Any]] = None # optional relationship properties
|
||||
|
||||
class GraphQuery(BaseModel):
    """Model for querying the graph.

    Either supply a raw read-only Cypher query via `cypher`, or use the
    structured fields (node_id takes precedence over node_label).
    """
    cypher: Optional[str] = None            # Direct Cypher query (advanced; read-only)
    # Or use structured query:
    node_label: Optional[str] = None        # fetch nodes by label
    node_id: Optional[str] = None           # fetch one node plus its relationships
    relationship_type: Optional[str] = None # NOTE(review): declared but not used by query_graph — confirm intent
    depth: int = 1                          # How many hops to traverse (NOTE(review): currently unused)
    limit: int = 50                         # max rows returned for list-style queries
|
||||
|
||||
class GraphSearchRequest(BaseModel):
    """Natural language search in graph."""
    query: str                                 # free-text search string
    entity_types: Optional[List[str]] = None   # Filter by node labels/types
    limit: int = 20                            # max results to return
|
||||
|
||||
|
||||
@app.post("/v1/graph/nodes")
|
||||
async def create_graph_node(node: GraphNode):
|
||||
"""Create or update a node in the knowledge graph"""
|
||||
if not neo4j_available or not neo4j_driver:
|
||||
raise HTTPException(status_code=503, detail="Neo4j not available")
|
||||
|
||||
try:
|
||||
async with neo4j_driver.session() as session:
|
||||
# Generate node_id if not provided
|
||||
node_id = node.node_id or f"{node.label.lower()}_{os.urandom(8).hex()}"
|
||||
|
||||
# Build properties with node_id
|
||||
props = {**node.properties, "node_id": node_id, "updated_at": "datetime()"}
|
||||
|
||||
# Create or merge node
|
||||
cypher = f"""
|
||||
MERGE (n:{node.label} {{node_id: $node_id}})
|
||||
SET n += $properties
|
||||
SET n.updated_at = datetime()
|
||||
RETURN n
|
||||
"""
|
||||
|
||||
result = await session.run(cypher, node_id=node_id, properties=node.properties)
|
||||
record = await result.single()
|
||||
|
||||
if record:
|
||||
created_node = dict(record["n"])
|
||||
logger.info(f"📊 Created/updated node: {node.label} - {node_id}")
|
||||
return {"status": "ok", "node_id": node_id, "node": created_node}
|
||||
|
||||
raise HTTPException(status_code=500, detail="Failed to create node")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Neo4j error creating node: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/v1/graph/relationships")
|
||||
async def create_graph_relationship(rel: GraphRelationship):
|
||||
"""Create a relationship between two nodes"""
|
||||
if not neo4j_available or not neo4j_driver:
|
||||
raise HTTPException(status_code=503, detail="Neo4j not available")
|
||||
|
||||
try:
|
||||
async with neo4j_driver.session() as session:
|
||||
props_clause = ""
|
||||
if rel.properties:
|
||||
props_clause = " SET r += $properties"
|
||||
|
||||
cypher = f"""
|
||||
MATCH (a {{node_id: $from_id}})
|
||||
MATCH (b {{node_id: $to_id}})
|
||||
MERGE (a)-[r:{rel.relationship_type}]->(b)
|
||||
{props_clause}
|
||||
SET r.created_at = datetime()
|
||||
RETURN a.node_id as from_id, b.node_id as to_id, type(r) as rel_type
|
||||
"""
|
||||
|
||||
result = await session.run(
|
||||
cypher,
|
||||
from_id=rel.from_node_id,
|
||||
to_id=rel.to_node_id,
|
||||
properties=rel.properties or {}
|
||||
)
|
||||
record = await result.single()
|
||||
|
||||
if record:
|
||||
logger.info(f"🔗 Created relationship: {rel.from_node_id} -[{rel.relationship_type}]-> {rel.to_node_id}")
|
||||
return {
|
||||
"status": "ok",
|
||||
"from_id": record["from_id"],
|
||||
"to_id": record["to_id"],
|
||||
"relationship_type": record["rel_type"]
|
||||
}
|
||||
|
||||
raise HTTPException(status_code=404, detail="One or both nodes not found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Neo4j error creating relationship: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/v1/graph/query")
|
||||
async def query_graph(query: GraphQuery):
|
||||
"""Query the knowledge graph"""
|
||||
if not neo4j_available or not neo4j_driver:
|
||||
raise HTTPException(status_code=503, detail="Neo4j not available")
|
||||
|
||||
try:
|
||||
async with neo4j_driver.session() as session:
|
||||
# If direct Cypher provided, use it (with safety check)
|
||||
if query.cypher:
|
||||
# Basic safety: only allow read queries
|
||||
if any(kw in query.cypher.upper() for kw in ["DELETE", "REMOVE", "DROP", "CREATE", "MERGE", "SET"]):
|
||||
raise HTTPException(status_code=400, detail="Only read queries allowed via cypher parameter")
|
||||
|
||||
result = await session.run(query.cypher)
|
||||
records = await result.data()
|
||||
return {"status": "ok", "results": records, "count": len(records)}
|
||||
|
||||
# Build structured query
|
||||
if query.node_id:
|
||||
# Get specific node with relationships
|
||||
cypher = f"""
|
||||
MATCH (n {{node_id: $node_id}})
|
||||
OPTIONAL MATCH (n)-[r]-(related)
|
||||
RETURN n, collect({{rel: type(r), node: related}}) as connections
|
||||
LIMIT 1
|
||||
"""
|
||||
result = await session.run(cypher, node_id=query.node_id)
|
||||
|
||||
elif query.node_label:
|
||||
# Get nodes by label
|
||||
cypher = f"""
|
||||
MATCH (n:{query.node_label})
|
||||
RETURN n
|
||||
ORDER BY n.updated_at DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
result = await session.run(cypher, limit=query.limit)
|
||||
|
||||
else:
|
||||
# Get recent nodes
|
||||
cypher = """
|
||||
MATCH (n)
|
||||
RETURN n, labels(n) as labels
|
||||
ORDER BY n.updated_at DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
result = await session.run(cypher, limit=query.limit)
|
||||
|
||||
records = await result.data()
|
||||
return {"status": "ok", "results": records, "count": len(records)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Neo4j query error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/v1/graph/search")
|
||||
async def search_graph(q: str, types: Optional[str] = None, limit: int = 20):
|
||||
"""Search nodes by text in properties"""
|
||||
if not neo4j_available or not neo4j_driver:
|
||||
raise HTTPException(status_code=503, detail="Neo4j not available")
|
||||
|
||||
try:
|
||||
async with neo4j_driver.session() as session:
|
||||
# Build label filter
|
||||
label_filter = ""
|
||||
if types:
|
||||
labels = [t.strip() for t in types.split(",")]
|
||||
label_filter = " AND (" + " OR ".join([f"n:{l}" for l in labels]) + ")"
|
||||
|
||||
# Search in common text properties
|
||||
cypher = f"""
|
||||
MATCH (n)
|
||||
WHERE (
|
||||
n.name CONTAINS $query OR
|
||||
n.title CONTAINS $query OR
|
||||
n.text CONTAINS $query OR
|
||||
n.description CONTAINS $query OR
|
||||
n.content CONTAINS $query
|
||||
){label_filter}
|
||||
RETURN n, labels(n) as labels
|
||||
ORDER BY n.updated_at DESC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
result = await session.run(cypher, query=q, limit=limit)
|
||||
records = await result.data()
|
||||
|
||||
return {"status": "ok", "query": q, "results": records, "count": len(records)}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Neo4j search error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/v1/graph/stats")
|
||||
async def get_graph_stats():
|
||||
"""Get knowledge graph statistics"""
|
||||
if not neo4j_available or not neo4j_driver:
|
||||
raise HTTPException(status_code=503, detail="Neo4j not available")
|
||||
|
||||
try:
|
||||
async with neo4j_driver.session() as session:
|
||||
# Get node counts by label
|
||||
labels_result = await session.run("""
|
||||
CALL db.labels() YIELD label
|
||||
CALL apoc.cypher.run('MATCH (n:`' + label + '`) RETURN count(n) as count', {}) YIELD value
|
||||
RETURN label, value.count as count
|
||||
""")
|
||||
|
||||
# If APOC not available, use simpler query
|
||||
try:
|
||||
labels_data = await labels_result.data()
|
||||
except:
|
||||
labels_result = await session.run("""
|
||||
MATCH (n)
|
||||
RETURN labels(n)[0] as label, count(*) as count
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
labels_data = await labels_result.data()
|
||||
|
||||
# Get relationship counts
|
||||
rels_result = await session.run("""
|
||||
MATCH ()-[r]->()
|
||||
RETURN type(r) as type, count(*) as count
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
rels_data = await rels_result.data()
|
||||
|
||||
# Get total counts
|
||||
total_result = await session.run("""
|
||||
MATCH (n) RETURN count(n) as nodes
|
||||
""")
|
||||
total_nodes = (await total_result.single())["nodes"]
|
||||
|
||||
total_rels_result = await session.run("""
|
||||
MATCH ()-[r]->() RETURN count(r) as relationships
|
||||
""")
|
||||
total_rels = (await total_rels_result.single())["relationships"]
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"total_nodes": total_nodes,
|
||||
"total_relationships": total_rels,
|
||||
"nodes_by_label": labels_data,
|
||||
"relationships_by_type": rels_data
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Neo4j stats error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Cleanup connections on shutdown"""
|
||||
global neo4j_driver, http_client, nc
|
||||
|
||||
if neo4j_driver:
|
||||
await neo4j_driver.close()
|
||||
logger.info("🔌 Neo4j connection closed")
|
||||
|
||||
if http_client:
|
||||
await http_client.aclose()
|
||||
logger.info("🔌 HTTP client closed")
|
||||
|
||||
if nc:
|
||||
await nc.close()
|
||||
logger.info("🔌 NATS connection closed")
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ pydantic==2.5.0
|
||||
nats-py==2.6.0
|
||||
PyYAML==6.0.1
|
||||
httpx>=0.25.0
|
||||
neo4j>=5.14.0
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,21 +1,30 @@
|
||||
FROM python:3.11-slim
|
||||
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
|
||||
|
||||
# Встановити wget для healthcheck
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends wget \
|
||||
# Install Python and system deps (including ffmpeg for audio processing)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.11 \
|
||||
python3-pip \
|
||||
wget \
|
||||
curl \
|
||||
git \
|
||||
ffmpeg \
|
||||
libsndfile1 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Install Python dependencies
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
RUN pip3 install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application
|
||||
COPY app/ ./app/
|
||||
COPY config/ ./config/
|
||||
|
||||
EXPOSE 8890
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=60s --retries=3 \
|
||||
CMD wget -qO- http://localhost:8890/health || exit 1
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8890"]
|
||||
|
||||
CMD ["python3", "-m", "uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8890"]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,64 +1,220 @@
|
||||
# Swapper Configuration for Node #1 (Production Server)
|
||||
# Single-active LLM scheduler
|
||||
# Optimized Multimodal Stack: LLM + Vision + OCR + Document + Audio
|
||||
# Hetzner GEX44 - NVIDIA RTX 4000 SFF Ada (20GB VRAM)
|
||||
#
|
||||
# ВАЖЛИВО: Ембедінги через зовнішні API:
|
||||
# - Text: Cohere API (embed-multilingual-v3.0, 1024 dim)
|
||||
# - Image: Vision Encoder (OpenCLIP ViT-L/14, 768 dim)
|
||||
# НЕ використовуємо локальні embedding моделі!
|
||||
|
||||
swapper:
|
||||
mode: single-active
|
||||
max_concurrent_models: 1
|
||||
mode: multi-active
|
||||
max_concurrent_models: 4 # LLM + OCR + STT + TTS (до 15GB)
|
||||
model_swap_timeout: 300
|
||||
gpu_enabled: true
|
||||
metal_acceleration: false # NVIDIA GPU, not Apple Silicon
|
||||
# Модель для автоматичного завантаження при старті
|
||||
# qwen3-8b - основна модель (4.87 GB), швидка відповідь на перший запит
|
||||
metal_acceleration: false
|
||||
default_model: qwen3-8b
|
||||
lazy_load_ocr: true
|
||||
lazy_load_audio: true
|
||||
# Автоматичне вивантаження при нестачі VRAM
|
||||
auto_unload_on_oom: true
|
||||
vram_threshold_gb: 18 # Починати вивантажувати при 18GB
|
||||
|
||||
models:
|
||||
# Primary LLM - Qwen3 8B (High Priority) - Main model from INFRASTRUCTURE.md
|
||||
# ============================================
|
||||
# LLM MODELS (Ollama) - тільки qwen3
|
||||
# ============================================
|
||||
|
||||
# Primary LLM - Qwen3 8B (includes math, coding, reasoning)
|
||||
qwen3-8b:
|
||||
path: ollama:qwen3:8b
|
||||
type: llm
|
||||
size_gb: 4.87
|
||||
size_gb: 5.2
|
||||
priority: high
|
||||
description: "Primary LLM for general tasks and conversations"
|
||||
|
||||
# Vision Model - Qwen3-VL 8B (High Priority) - For image processing
|
||||
description: "Qwen3 8B - primary LLM with math, coding, reasoning capabilities"
|
||||
capabilities:
|
||||
- chat
|
||||
- math
|
||||
- coding
|
||||
- reasoning
|
||||
- multilingual
|
||||
|
||||
# ============================================
|
||||
# VISION MODELS (Ollama)
|
||||
# ============================================
|
||||
|
||||
# Vision Model - Qwen3-VL 8B
|
||||
qwen3-vl-8b:
|
||||
path: ollama:qwen3-vl:8b
|
||||
type: vision
|
||||
size_gb: 5.72
|
||||
size_gb: 6.1
|
||||
priority: high
|
||||
description: "Vision model for image understanding and processing"
|
||||
|
||||
# Qwen2.5 7B Instruct (High Priority)
|
||||
qwen2.5-7b-instruct:
|
||||
path: ollama:qwen2.5:7b-instruct-q4_K_M
|
||||
type: llm
|
||||
size_gb: 4.36
|
||||
description: "Qwen3-VL 8B for image understanding and visual reasoning"
|
||||
capabilities:
|
||||
- image_understanding
|
||||
- visual_qa
|
||||
- diagram_analysis
|
||||
- ocr_basic
|
||||
|
||||
# ============================================
|
||||
# OCR/DOCUMENT MODELS (HuggingFace)
|
||||
# ============================================
|
||||
|
||||
# GOT-OCR2.0 - Best for documents, tables, formulas
|
||||
got-ocr2:
|
||||
path: huggingface:stepfun-ai/GOT-OCR2_0
|
||||
type: ocr
|
||||
size_gb: 7.0
|
||||
priority: high
|
||||
description: "Qwen2.5 7B Instruct model"
|
||||
description: "Best OCR for documents, tables, formulas, handwriting"
|
||||
capabilities:
|
||||
- documents
|
||||
- tables
|
||||
- formulas
|
||||
- handwriting
|
||||
- multilingual
|
||||
|
||||
# Lightweight LLM - Qwen2.5 3B Instruct (Medium Priority)
|
||||
qwen2.5-3b-instruct:
|
||||
path: ollama:qwen2.5:3b-instruct-q4_K_M
|
||||
type: llm
|
||||
size_gb: 1.80
|
||||
# Donut - Document Understanding (no external OCR, 91% CORD)
|
||||
donut-base:
|
||||
path: huggingface:naver-clova-ix/donut-base
|
||||
type: ocr
|
||||
size_gb: 3.0
|
||||
priority: high
|
||||
description: "Document parsing without OCR engine (91% CORD accuracy)"
|
||||
capabilities:
|
||||
- document_parsing
|
||||
- receipts
|
||||
- forms
|
||||
- invoices
|
||||
|
||||
# Donut fine-tuned for receipts/invoices (CORD dataset)
|
||||
donut-cord:
|
||||
path: huggingface:naver-clova-ix/donut-base-finetuned-cord-v2
|
||||
type: ocr
|
||||
size_gb: 3.0
|
||||
priority: medium
|
||||
description: "Lightweight LLM for faster responses"
|
||||
|
||||
# Math Specialist - Qwen2 Math 7B (High Priority)
|
||||
qwen2-math-7b:
|
||||
path: ollama:qwen2-math:7b
|
||||
type: math
|
||||
size_gb: 4.13
|
||||
description: "Donut fine-tuned for receipts extraction"
|
||||
capabilities:
|
||||
- receipts
|
||||
- invoices
|
||||
- structured_extraction
|
||||
|
||||
# IBM Granite Docling - Document conversion with structure preservation
|
||||
granite-docling:
|
||||
path: huggingface:ds4sd/docling-ibm-granite-vision-1b
|
||||
type: document
|
||||
size_gb: 2.5
|
||||
priority: high
|
||||
description: "Specialized model for mathematical tasks"
|
||||
description: "IBM Granite Docling for PDF/document structure extraction"
|
||||
capabilities:
|
||||
- pdf_conversion
|
||||
- table_extraction
|
||||
- formula_extraction
|
||||
- layout_preservation
|
||||
- doctags_format
|
||||
|
||||
# ============================================
|
||||
# AUDIO MODELS - STT (Speech-to-Text)
|
||||
# ============================================
|
||||
|
||||
# Faster Whisper Large-v3 - Best STT quality
|
||||
faster-whisper-large:
|
||||
path: huggingface:Systran/faster-whisper-large-v3
|
||||
type: stt
|
||||
size_gb: 3.0
|
||||
priority: high
|
||||
description: "Faster Whisper Large-v3 - best quality, 99 languages"
|
||||
capabilities:
|
||||
- speech_recognition
|
||||
- transcription
|
||||
- multilingual
|
||||
- timestamps
|
||||
- ukrainian
|
||||
|
||||
# Whisper Small - Fast/lightweight for quick transcription
|
||||
whisper-small:
|
||||
path: huggingface:openai/whisper-small
|
||||
type: stt
|
||||
size_gb: 0.5
|
||||
priority: medium
|
||||
description: "Whisper Small for fast transcription"
|
||||
capabilities:
|
||||
- speech_recognition
|
||||
- transcription
|
||||
|
||||
# ============================================
|
||||
# AUDIO MODELS - TTS (Text-to-Speech)
|
||||
# ============================================
|
||||
|
||||
# Coqui XTTS-v2 - Best multilingual TTS with Ukrainian support
|
||||
xtts-v2:
|
||||
path: huggingface:coqui/XTTS-v2
|
||||
type: tts
|
||||
size_gb: 2.0
|
||||
priority: high
|
||||
description: "XTTS-v2 multilingual TTS with voice cloning, Ukrainian support"
|
||||
capabilities:
|
||||
- text_to_speech
|
||||
- voice_cloning
|
||||
- multilingual
|
||||
- ukrainian
|
||||
- 17_languages
|
||||
|
||||
# ============================================
|
||||
# IMAGE GENERATION MODELS (HuggingFace/Diffusers)
|
||||
# ============================================
|
||||
|
||||
# FLUX.2 Klein 4B - High quality image generation with lazy loading
|
||||
flux-klein-4b:
|
||||
path: huggingface:black-forest-labs/FLUX.2-klein-base-4B
|
||||
type: image_generation
|
||||
size_gb: 15.4
|
||||
priority: medium
|
||||
description: "FLUX.2 Klein 4B - high quality image generation, lazy loaded on demand"
|
||||
capabilities:
|
||||
- text_to_image
|
||||
- high_quality
|
||||
- 1024x1024
|
||||
- artistic
|
||||
default_params:
|
||||
num_inference_steps: 50
|
||||
guidance_scale: 4.0
|
||||
width: 1024
|
||||
height: 1024
|
||||
|
||||
storage:
|
||||
models_dir: /app/models
|
||||
cache_dir: /app/cache
|
||||
swap_dir: /app/swap
|
||||
huggingface_cache: /root/.cache/huggingface
|
||||
|
||||
ollama:
|
||||
url: http://ollama:11434 # From Docker container to Ollama service
|
||||
url: http://172.18.0.1:11434
|
||||
timeout: 300
|
||||
|
||||
huggingface:
|
||||
device: cuda
|
||||
torch_dtype: float16
|
||||
trust_remote_code: true
|
||||
low_cpu_mem_usage: true
|
||||
|
||||
# ============================================
|
||||
# EMBEDDING SERVICES (External APIs)
|
||||
# НЕ через Swapper - окремі сервіси!
|
||||
# ============================================
|
||||
#
|
||||
# Text Embeddings:
|
||||
# Service: Memory Service → Cohere API
|
||||
# Model: embed-multilingual-v3.0
|
||||
# Dimension: 1024
|
||||
# Endpoint: Memory Service handles internally
|
||||
#
|
||||
# Image/Multimodal Embeddings:
|
||||
# Service: Vision Encoder (port 8001)
|
||||
# Model: OpenCLIP ViT-L/14
|
||||
# Dimension: 768
|
||||
# Endpoint: http://vision-encoder:8001/embed
|
||||
#
|
||||
# Vector Storage:
|
||||
# Qdrant (port 6333) - separate collections for text vs image embeddings
|
||||
# ВАЖЛИВО: НЕ змішувати embedding spaces в одній колекції!
|
||||
|
||||
63
services/swapper-service/config/swapper_config_node3.yaml
Normal file
63
services/swapper-service/config/swapper_config_node3.yaml
Normal file
@@ -0,0 +1,63 @@
|
||||
# Swapper Configuration for Node #3 (AI/ML Workstation)
|
||||
# Single-active LLM scheduler
|
||||
# Threadripper PRO + RTX 3090 24GB - GPU-intensive workloads
|
||||
|
||||
swapper:
|
||||
mode: single-active
|
||||
max_concurrent_models: 1
|
||||
model_swap_timeout: 300
|
||||
gpu_enabled: true
|
||||
metal_acceleration: false # NVIDIA GPU, not Apple Silicon
|
||||
# Model to auto-load at startup (Модель для автоматичного завантаження при старті)
|
||||
# qwen3-8b — primary model (4.87 GB); keeps first-request latency low (основна модель, швидка відповідь на перший запит)
|
||||
default_model: qwen3-8b
|
||||
|
||||
models:
|
||||
# Primary LLM - Qwen3 8B (High Priority) - Main model from INFRASTRUCTURE.md
|
||||
qwen3-8b:
|
||||
path: ollama:qwen3:8b
|
||||
type: llm
|
||||
size_gb: 4.87
|
||||
priority: high
|
||||
description: "Primary LLM for general tasks and conversations"
|
||||
|
||||
# Vision Model - Qwen3-VL 8B (High Priority) - For image processing
|
||||
qwen3-vl-8b:
|
||||
path: ollama:qwen3-vl:8b
|
||||
type: vision
|
||||
size_gb: 5.72
|
||||
priority: high
|
||||
description: "Vision model for image understanding and processing"
|
||||
|
||||
# Qwen2.5 7B Instruct (High Priority)
|
||||
qwen2.5-7b-instruct:
|
||||
path: ollama:qwen2.5:7b-instruct-q4_K_M
|
||||
type: llm
|
||||
size_gb: 4.36
|
||||
priority: high
|
||||
description: "Qwen2.5 7B Instruct model"
|
||||
|
||||
# Lightweight LLM - Qwen2.5 3B Instruct (Medium Priority)
|
||||
qwen2.5-3b-instruct:
|
||||
path: ollama:qwen2.5:3b-instruct-q4_K_M
|
||||
type: llm
|
||||
size_gb: 1.80
|
||||
priority: medium
|
||||
description: "Lightweight LLM for faster responses"
|
||||
|
||||
# Math Specialist - Qwen2 Math 7B (High Priority)
|
||||
qwen2-math-7b:
|
||||
path: ollama:qwen2-math:7b
|
||||
type: math
|
||||
size_gb: 4.13
|
||||
priority: high
|
||||
description: "Specialized model for mathematical tasks"
|
||||
|
||||
storage:
|
||||
models_dir: /app/models
|
||||
cache_dir: /app/cache
|
||||
swap_dir: /app/swap
|
||||
|
||||
ollama:
|
||||
url: http://ollama:11434 # From Docker container to Ollama service
|
||||
timeout: 300
|
||||
@@ -5,3 +5,31 @@ pydantic==2.5.0
|
||||
pyyaml==6.0.1
|
||||
python-multipart==0.0.6
|
||||
|
||||
# HuggingFace dependencies for OCR models
|
||||
torch>=2.0.0
|
||||
torchvision>=0.15.0
|
||||
transformers>=4.35.0
|
||||
accelerate>=0.25.0
|
||||
pillow>=10.0.0
|
||||
tiktoken>=0.5.0
|
||||
sentencepiece>=0.1.99
|
||||
einops>=0.7.0
|
||||
|
||||
# STT (Speech-to-Text) dependencies
|
||||
faster-whisper>=1.0.0
|
||||
openai-whisper>=20231117
|
||||
|
||||
# Image Generation (Diffusion models)
|
||||
diffusers @ git+https://github.com/huggingface/diffusers.git
|
||||
safetensors>=0.4.0
|
||||
|
||||
# Web Scraping & Search
|
||||
trafilatura>=1.6.0
|
||||
duckduckgo-search>=4.0.0
|
||||
lxml_html_clean>=0.1.0
|
||||
|
||||
# TTS (Text-to-Speech)
|
||||
TTS>=0.22.0
|
||||
|
||||
# Document Processing
|
||||
docling>=2.0.0
|
||||
Reference in New Issue
Block a user