Files
microdao-daarion/metrics_middleware.py

212 lines
5.6 KiB
Python

"""
Prometheus Metrics Middleware for DAGI Router
"""
import time
from prometheus_client import Counter, Histogram, Gauge, generate_latest, CONTENT_TYPE_LATEST
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
import logging
logger = logging.getLogger(__name__)
# ============================================================================
# Metrics Definitions
# ============================================================================
# HTTP request count, labelled by method, endpoint and response status
http_requests_total = Counter(
    name='http_requests_total',
    documentation='Total HTTP requests',
    labelnames=('method', 'endpoint', 'status'),
)

# HTTP request latency distribution, labelled by method and endpoint
http_request_duration_seconds = Histogram(
    name='http_request_duration_seconds',
    documentation='HTTP request latency in seconds',
    labelnames=('method', 'endpoint'),
)

# HTTP requests currently being handled, labelled by method and endpoint
http_requests_in_progress = Gauge(
    name='http_requests_in_progress',
    documentation='Number of HTTP requests in progress',
    labelnames=('method', 'endpoint'),
)
# LLM-specific metrics

# LLM request count, labelled by agent, provider and success/error status
llm_requests_total = Counter(
    name='llm_requests_total',
    documentation='Total LLM requests',
    labelnames=('agent_id', 'provider', 'status'),
)

# LLM request latency distribution, labelled by agent and provider
llm_request_duration_seconds = Histogram(
    name='llm_request_duration_seconds',
    documentation='LLM request latency in seconds',
    labelnames=('agent_id', 'provider'),
)

# LLM token usage; the 'type' label takes the values prompt / completion
llm_tokens_total = Counter(
    name='llm_tokens_total',
    documentation='Total LLM tokens used',
    labelnames=('agent_id', 'provider', 'type'),
)

# LLM error count, labelled by agent, provider and error type
llm_errors_total = Counter(
    name='llm_errors_total',
    documentation='Total LLM errors',
    labelnames=('agent_id', 'provider', 'error_type'),
)
# Router-specific metrics

# Requests routed to each agent, labelled by agent and request mode
router_agent_requests = Counter(
    name='router_agent_requests',
    documentation='Total requests per agent',
    labelnames=('agent_id', 'mode'),
)

# Number of times each provider was used
router_provider_usage = Counter(
    name='router_provider_usage',
    documentation='Provider usage counts',
    labelnames=('provider',),
)
# ============================================================================
# Middleware
# ============================================================================
class PrometheusMiddleware(BaseHTTPMiddleware):
    """
    FastAPI middleware for Prometheus metrics collection.

    For every request except ``/metrics`` it updates the in-progress gauge,
    the request counter and the latency histogram defined in this module.
    """

    async def dispatch(self, request: Request, call_next):
        """
        Forward the request to the application and record its outcome.

        Args:
            request: Incoming HTTP request.
            call_next: Callable that hands the request to the rest of the app
                and returns its response.

        Returns:
            The response produced by ``call_next``.

        Raises:
            Anything raised by ``call_next``. The request is counted with
            status 500 and the exception is re-raised unchanged.
        """
        # Skip metrics endpoint itself
        if request.url.path == "/metrics":
            return await call_next(request)

        # NOTE: this is the concrete request path, not the route template, so
        # parameterised routes produce one label value per distinct URL.
        endpoint = request.url.path
        method = request.method

        # Track in-progress requests
        in_progress = http_requests_in_progress.labels(method=method, endpoint=endpoint)
        in_progress.inc()

        # Pre-set the status so the `finally` block always has a value, even
        # when a BaseException that `except Exception` does not catch (for
        # example asyncio.CancelledError) propagates. Previously that path hit
        # an UnboundLocalError which masked the real exception and skipped the
        # gauge decrement.
        status_code = 500

        # Monotonic clock: unaffected by system clock adjustments.
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
            status_code = response.status_code
        except Exception as e:
            logger.error(f"Request failed: {e}")
            raise
        finally:
            duration = time.perf_counter() - start_time
            # Restore the gauge first so it cannot leak if recording fails.
            in_progress.dec()
            http_requests_total.labels(
                method=method,
                endpoint=endpoint,
                status=status_code,
            ).inc()
            http_request_duration_seconds.labels(
                method=method,
                endpoint=endpoint,
            ).observe(duration)
        return response
# ============================================================================
# Metrics Endpoint
# ============================================================================
def metrics_endpoint():
    """
    Serve the current metrics snapshot.

    Returns:
        A ``Response`` whose body is the output of ``generate_latest()`` and
        whose media type is ``CONTENT_TYPE_LATEST``.
    """
    payload = generate_latest()
    return Response(content=payload, media_type=CONTENT_TYPE_LATEST)
# ============================================================================
# Helper Functions
# ============================================================================
def track_llm_request(agent_id: str, provider: str, duration: float, tokens: dict = None, error: str = None):
    """
    Track LLM request metrics.

    Every call increments ``llm_requests_total``. Successful calls also
    record the duration and, when ``tokens`` is given, the token counts;
    failed calls increment ``llm_errors_total`` instead.

    Args:
        agent_id: Agent identifier (e.g., "daarwizz", "helion")
        provider: LLM provider (e.g., "ollama", "deepseek")
        duration: Request duration in seconds (recorded on success only)
        tokens: Token usage dict with "prompt" and "completion" keys;
            a missing or ``None`` value counts as 0
        error: Error type if request failed; any falsy value means success
    """
    status = "error" if error else "success"
    llm_requests_total.labels(
        agent_id=agent_id,
        provider=provider,
        status=status,
    ).inc()

    if error:
        llm_errors_total.labels(
            agent_id=agent_id,
            provider=provider,
            error_type=error,
        ).inc()
        return

    llm_request_duration_seconds.labels(
        agent_id=agent_id,
        provider=provider,
    ).observe(duration)

    if tokens:
        for token_type in ("prompt", "completion"):
            # `or 0` also covers a key that is present but set to None,
            # which would otherwise make Counter.inc() raise TypeError.
            amount = tokens.get(token_type) or 0
            llm_tokens_total.labels(
                agent_id=agent_id,
                provider=provider,
                type=token_type,
            ).inc(amount)
def track_agent_request(agent_id: str, mode: str):
    """
    Count one request routed to an agent.

    Args:
        agent_id: Agent identifier
        mode: Request mode (chat, doc_parse, rag_query, etc.)
    """
    agent_counter = router_agent_requests.labels(agent_id=agent_id, mode=mode)
    agent_counter.inc()
def track_provider_usage(provider: str):
    """
    Count one use of a provider.

    Args:
        provider: Provider name
    """
    provider_counter = router_provider_usage.labels(provider=provider)
    provider_counter.inc()