213 lines
7.0 KiB
Python
213 lines
7.0 KiB
Python
import asyncio
|
|
from datetime import datetime, timezone
|
|
import hmac
|
|
import json
|
|
import os
|
|
import uuid
|
|
from typing import Any, Dict, List
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, HTTPException, Request, status
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel, Field
|
|
|
|
from .redis_jobs import create_job, enqueue_job, get_job
|
|
from .registry_api import _load_registry
|
|
|
|
# Router for every endpoint declared in this module; all routes are mounted
# under the "/v1" prefix and tagged "daarion-facade".
router = APIRouter(prefix="/v1", tags=["daarion-facade"])
|
|
|
|
# Job statuses after which the event stream stops polling.
EVENT_TERMINAL_STATUSES = {"done", "failed"}

# Statuses that are emitted under their own SSE event name; any other status
# is emitted under the generic "status" event.
EVENT_KNOWN_STATUSES = EVENT_TERMINAL_STATUSES | {"queued", "running"}

# Seconds to sleep between two job lookups while streaming events.
EVENT_POLL_SECONDS = float(os.environ.get("DAARION_JOB_EVENTS_POLL_SECONDS", "0.5"))

# Base URL of the router service, stored without a trailing slash so request
# paths can be appended directly.
ROUTER_URL = os.environ.get("ROUTER_URL", "http://router:8000").rstrip("/")

# Timeout, in seconds, handed to the HTTP client for router calls.
ROUTER_REVIEW_TIMEOUT = float(os.environ.get("DAARION_ROUTER_REVIEW_TIMEOUT_SECONDS", "20"))

# Auth mode for the review endpoint, normalised to lower case without
# surrounding whitespace ("bearer" unless overridden).
AGROMATRIX_REVIEW_AUTH_MODE = os.environ.get("AGROMATRIX_REVIEW_AUTH_MODE", "bearer").strip().lower()

# Accepted bearer tokens: the environment value is split on "," or ";",
# each piece is trimmed, and empty pieces are dropped.
AGROMATRIX_REVIEW_BEARER_TOKENS = [
    token
    for token in map(
        str.strip,
        os.environ.get("AGROMATRIX_REVIEW_BEARER_TOKENS", "").replace(";", ",").split(","),
    )
    if token
]
|
|
|
|
|
|
class InvokeInput(BaseModel):
    # Input section of an invoke request.

    # Prompt text; validation rejects an empty string (min_length=1).
    prompt: str = Field(min_length=1)
    # Optional list of image strings; defaults to a fresh empty list.
    # NOTE(review): the expected encoding (URL, base64, ...) is not visible here.
    images: List[str] = Field(default_factory=list)
|
|
|
|
|
|
class InvokeRequest(BaseModel):
    # Request body of POST /v1/invoke.

    # Identifier of the target agent; the invoke endpoint checks it against
    # the "agents" section of the loaded registry.
    agent_id: str
    # Prompt and optional images for the agent.
    input: InvokeInput
    # Free-form caller metadata, stored on the job document as-is.
    metadata: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class InvokeResponse(BaseModel):
    # Response body of POST /v1/invoke.

    # Identifier of the newly created job.
    job_id: str
    # Job status at creation time (the invoke endpoint sets "queued").
    status: str
    # Relative URL where the job can be polled.
    status_url: str
|
|
|
|
|
|
class SharedMemoryReviewRequest(BaseModel):
    # Request body of POST /v1/agromatrix/shared-memory/review; it is dumped
    # and forwarded unchanged to the router service.

    # Identifier of the item under review.
    # NOTE(review): presumably a shared-memory point id defined by the router - confirm.
    point_id: str
    # Review decision flag.
    approve: bool
    # Optional reviewer name.
    reviewer: str | None = None
    # Optional free-text note.
    note: str | None = None
|
|
|
|
|
|
def _extract_bearer_token(request: Request) -> str:
    """Return the credentials carried by a ``Bearer`` Authorization header.

    The scheme name is matched case-insensitively, so ``Bearer``, ``bearer``
    and ``BEARER`` are all accepted. Whitespace around the credentials is
    stripped.

    Args:
        request: Incoming request whose ``Authorization`` header is read.

    Returns:
        The non-empty token that follows the scheme name.

    Raises:
        HTTPException: 401 when the header is absent, uses another scheme,
            or carries no credentials. The response includes a
            ``WWW-Authenticate: Bearer`` challenge.
    """
    # A 401 response is expected to tell the client which scheme to use.
    challenge = {"WWW-Authenticate": "Bearer"}

    auth_header = request.headers.get("Authorization", "")
    scheme, separator, credentials = auth_header.partition(" ")

    # HTTP authentication scheme names are case-insensitive, so compare the
    # lower-cased scheme instead of requiring the exact spelling "Bearer".
    # A header without a space after the scheme has no credentials part.
    if not separator or scheme.lower() != "bearer":
        raise HTTPException(status_code=401, detail="Missing Bearer token", headers=challenge)

    token = credentials.strip()
    if not token:
        raise HTTPException(status_code=401, detail="Empty Bearer token", headers=challenge)
    return token
|
|
|
|
|
|
def _require_mentor_auth(request: Request) -> str:
    """Enforce the configured auth mode for the review endpoint.

    Args:
        request: Incoming request; its bearer token is read when the mode
            is ``"bearer"``.

    Returns:
        The validated bearer token, or ``""`` when auth is switched off
        (mode ``"off"``, ``"none"`` or ``"disabled"``).

    Raises:
        HTTPException: 500 for an unsupported auth mode, 503 when no tokens
            are configured, 401 when the header is missing or empty (raised
            by ``_extract_bearer_token``), 403 when the token matches none
            of the configured tokens.
    """
    mode = AGROMATRIX_REVIEW_AUTH_MODE
    if mode in {"off", "none", "disabled"}:
        return ""
    if mode != "bearer":
        raise HTTPException(status_code=500, detail=f"Unsupported AGROMATRIX_REVIEW_AUTH_MODE={mode}")
    if not AGROMATRIX_REVIEW_BEARER_TOKENS:
        raise HTTPException(status_code=503, detail="Review auth is not configured")

    token = _extract_bearer_token(request)

    # hmac.compare_digest only accepts ASCII-only str arguments and raises
    # TypeError otherwise. The token comes from an untrusted header, so
    # compare encoded bytes: a non-ASCII token is then rejected with 403
    # instead of crashing the request with an unhandled error.
    presented = token.encode("utf-8", "surrogateescape")
    if not any(
        hmac.compare_digest(presented, candidate.encode("utf-8", "surrogateescape"))
        for candidate in AGROMATRIX_REVIEW_BEARER_TOKENS
    ):
        raise HTTPException(status_code=403, detail="Invalid mentor token")
    return token
|
|
|
|
|
|
async def _router_json(
    method: str,
    path: str,
    *,
    payload: Dict[str, Any] | None = None,
    params: Dict[str, Any] | None = None,
    authorization: str | None = None,
) -> Dict[str, Any]:
    """Call the router service and return its response as a dict.

    Args:
        method: HTTP method to use.
        path: Path appended to ``ROUTER_URL``.
        payload: Optional JSON request body.
        params: Optional query parameters.
        authorization: Optional value for the ``Authorization`` header;
            the header is omitted when this is empty or ``None``.

    Returns:
        The decoded JSON body when it is a dict, otherwise the decoded
        value wrapped as ``{"data": value}``. A body that cannot be
        decoded is returned as ``{"raw": <response text>}``.

    Raises:
        HTTPException: 504 when the request times out, 502 for any other
            failure to perform the request, and the router's own status
            code when it answers with a status of 400 or above.
    """
    target = f"{ROUTER_URL}{path}"
    request_headers: Dict[str, str] = {"Authorization": authorization} if authorization else {}

    try:
        async with httpx.AsyncClient(timeout=ROUTER_REVIEW_TIMEOUT) as session:
            response = await session.request(
                method,
                target,
                json=payload,
                params=params,
                headers=request_headers,
            )
    except httpx.TimeoutException:
        raise HTTPException(status_code=504, detail="Router timeout")
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"Router unavailable: {e}")

    # Fall back to the raw text when the body cannot be decoded as JSON.
    try:
        decoded = response.json()
    except Exception:
        decoded = {"raw": response.text}

    is_mapping = isinstance(decoded, dict)
    code = response.status_code
    if code >= 400:
        # Prefer the router's own "detail" field; use a generic message
        # when it is missing or empty.
        detail = decoded.get("detail") if is_mapping else decoded
        raise HTTPException(status_code=code, detail=detail or f"Router error {code}")

    if is_mapping:
        return decoded
    return {"data": decoded}
|
|
|
|
|
|
def _sse_message(event: str, payload: Dict[str, Any]) -> str:
    """Format one server-sent-events frame.

    The frame is an ``event:`` line, a ``data:`` line holding the payload
    as JSON (non-ASCII characters are kept as-is), and a terminating
    blank line.
    """
    encoded = json.dumps(payload, ensure_ascii=False)
    return f"event: {event}\n" f"data: {encoded}\n\n"
|
|
|
|
|
|
@router.post("/invoke", status_code=status.HTTP_202_ACCEPTED, response_model=InvokeResponse)
async def invoke(payload: InvokeRequest) -> InvokeResponse:
    """Accept an agent invocation and register it as a queued job.

    The agent id must be present in the ``"agents"`` section of the loaded
    registry. A job document is then passed to ``create_job`` and the job
    id to ``enqueue_job``.

    Returns:
        The new job id, the status ``"queued"`` and the URL to poll.

    Raises:
        HTTPException: 404 when the agent id is not in the registry.
    """
    known_agents = _load_registry().get("agents", {})
    agent_id = payload.agent_id
    if agent_id not in known_agents:
        raise HTTPException(status_code=404, detail=f"Unknown agent_id: {agent_id}")

    job_id = "job_" + uuid.uuid4().hex
    # One UTC timestamp is used for both created_at and updated_at.
    timestamp = datetime.now(timezone.utc).isoformat()
    initial_status = "queued"

    job_doc = {
        "job_id": job_id,
        "status": initial_status,
        "agent_id": agent_id,
        "input": payload.input.model_dump(),
        "metadata": payload.metadata,
        "result": None,
        "error": None,
        "created_at": timestamp,
        "updated_at": timestamp,
        "started_at": None,
        "finished_at": None,
    }

    await create_job(job_id, job_doc)
    await enqueue_job(job_id)

    return InvokeResponse(
        job_id=job_id,
        status=initial_status,
        status_url=f"/v1/jobs/{job_id}",
    )
|
|
|
|
|
|
@router.get("/jobs/{job_id}")
async def job_status(job_id: str) -> Dict[str, Any]:
    """Return the stored job document for ``job_id``.

    Raises:
        HTTPException: 404 when ``get_job`` yields nothing for the id.
    """
    document = await get_job(job_id)
    if document:
        return document
    raise HTTPException(status_code=404, detail="Job not found")
|
|
|
|
|
|
@router.get("/jobs/{job_id}/events")
async def job_events(job_id: str, request: Request) -> StreamingResponse:
    """Stream status changes of a job as server-sent events.

    The job is looked up every ``EVENT_POLL_SECONDS`` seconds. An event is
    emitted whenever the ``(status, updated_at)`` pair changes. The stream
    ends when the client disconnects, when the job reaches a status in
    ``EVENT_TERMINAL_STATUSES``, or when the job can no longer be found
    (a final ``failed`` event is sent in that case).

    Raises:
        HTTPException: 404 when the job does not exist at request time.
    """
    if not await get_job(job_id):
        raise HTTPException(status_code=404, detail="Job not found")

    async def event_stream():
        seen = None
        # SSE "retry" field: reconnection delay in milliseconds.
        yield "retry: 1000\n\n"

        while not await request.is_disconnected():
            snapshot = await get_job(job_id)
            if not snapshot:
                # The job vanished mid-stream: report it as a failure and stop.
                yield _sse_message(
                    "failed",
                    {"job_id": job_id, "status": "failed", "error": {"message": "Job not found"}},
                )
                return

            current_status = str(snapshot.get("status", "unknown"))
            fingerprint = (current_status, str(snapshot.get("updated_at", "")))

            if fingerprint != seen:
                # Unrecognised statuses go out under the generic "status" event.
                if current_status in EVENT_KNOWN_STATUSES:
                    event_name = current_status
                else:
                    event_name = "status"
                yield _sse_message(event_name, snapshot)
                seen = fingerprint

            if current_status in EVENT_TERMINAL_STATUSES:
                return

            await asyncio.sleep(EVENT_POLL_SECONDS)

    # NOTE(review): "X-Accel-Buffering: no" is presumably there to stop a
    # reverse proxy such as nginx from buffering the stream - confirm.
    stream_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers=stream_headers,
    )
|
|
|
|
|
|
@router.get("/agromatrix/shared-memory/pending")
async def agromatrix_shared_pending(limit: int = 50) -> Dict[str, Any]:
    """Proxy the router's pending shared-memory listing.

    Args:
        limit: Requested number of items; clamped to the range 1..200
            before it is forwarded as a query parameter.

    Returns:
        The router's response as returned by ``_router_json``.
    """
    bounded_limit = min(max(limit, 1), 200)
    query = {"limit": bounded_limit}
    return await _router_json("GET", "/v1/agromatrix/shared-memory/pending", params=query)
|
|
|
|
|
|
@router.post("/agromatrix/shared-memory/review")
async def agromatrix_shared_review(req: SharedMemoryReviewRequest, request: Request) -> Dict[str, Any]:
    """Authorise a review decision and forward it to the router.

    The caller is checked with ``_require_mentor_auth``. When that returns
    a token, it is forwarded as a ``Bearer`` Authorization header; when it
    returns an empty string, no Authorization header is sent.

    Returns:
        The router's response as returned by ``_router_json``.
    """
    token = _require_mentor_auth(request)

    forwarded_auth = None
    if token:
        forwarded_auth = f"Bearer {token}"

    body = req.model_dump()
    return await _router_json(
        "POST",
        "/v1/agromatrix/shared-memory/review",
        payload=body,
        authorization=forwarded_auth,
    )
|