""" Tool Governance: RBAC enforcement, Safety Middleware, Audit. Applies to ALL /v1/tools/* dispatch. Components: 1. RBAC Matrix enforcement – deny without entitlement 2. Tool Safety Middleware – limits, redaction, allowlist, audit 3. Audit events – structured per-call events (no payload, only metadata) Usage (in tool_manager.py execute_tool): from tool_governance import ToolGovernance governance = ToolGovernance() # Pre-call check = governance.pre_call(tool_name, action, agent_id, user_id, workspace_id, input_text) if not check.allowed: return ToolResult(success=False, error=check.reason) # Execute actual tool handler ... result = await _actual_handler(args) # Post-call governance.post_call(check.call_ctx, result, duration_ms) """ import hashlib import ipaddress import json import logging import os import re import time import uuid from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Tuple logger = logging.getLogger(__name__) # ─── Config Paths ───────────────────────────────────────────────────────────── _CONFIG_DIR = Path(__file__).parent.parent.parent / "config" _RBAC_PATH = _CONFIG_DIR / "rbac_tools_matrix.yml" _LIMITS_PATH = _CONFIG_DIR / "tool_limits.yml" _ALLOWLIST_PATH = _CONFIG_DIR / "network_allowlist.yml" # ─── Data Classes ───────────────────────────────────────────────────────────── @dataclass class CallContext: req_id: str tool: str action: str agent_id: str user_id: str workspace_id: str ts_start: float input_hash: str input_chars: int limits_applied: Dict[str, Any] = field(default_factory=dict) @dataclass class PreCallResult: allowed: bool reason: str = "" call_ctx: Optional[CallContext] = None @dataclass class AuditEvent: ts: str req_id: str tool: str action: str workspace_id: str user_id: str agent_id: str status: str # "pass" | "deny" | "error" duration_ms: float limits_applied: Dict[str, Any] input_hash: str input_chars: int output_size_bytes: int # ─── YAML Loader (lazy, cached) 
_yaml_cache: Dict[str, Any] = {}


def _load_yaml(path: Path) -> dict:
    """
    Load a YAML mapping from ``path``, caching the result per path.

    The file is read once with ``yaml.safe_load``; later calls return the
    cached dict until ``_reload_yaml_cache()`` clears it. Any failure (PyYAML
    not installed, unreadable file, parse error) or a top-level value that is
    not a mapping is logged as a warning and cached as ``{}``, so callers can
    always use ``.get`` on the result.

    NOTE(review): an empty RBAC matrix means no tool requires entitlements
    (check_rbac then allows everything), while an empty allowlist makes
    check_url_allowed deny every URL.
    """
    key = str(path)
    if key not in _yaml_cache:
        data: Any = {}
        try:
            import yaml

            # Explicit UTF-8 so config decoding does not depend on the host locale.
            with open(path, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f) or {}
        except Exception as e:
            logger.warning("Could not load %s: %s", path, e)
            data = {}
        if not isinstance(data, dict):
            # A list/scalar document would break every `.get` caller.
            logger.warning(
                "Ignoring %s: top-level YAML value is %s, expected a mapping",
                path,
                type(data).__name__,
            )
            data = {}
        _yaml_cache[key] = data
    return _yaml_cache[key]


def _reload_yaml_cache():
    """Force reload all yaml caches (for tests / hot-reload)."""
    _yaml_cache.clear()


def _as_list(value: Any) -> List[Any]:
    """
    Normalize a YAML setting that may be null, a scalar, or a sequence into a list.

    Prevents a bare string from being iterated character by character:
    ``"http" in "https"`` would otherwise be a substring match.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple, set, frozenset)):
        return list(value)
    return [value]


# ─── Secret Redaction ─────────────────────────────────────────────────────────

_SECRET_PATTERNS = [
    # "<label><sep><value>" pairs: api_key=..., "password": "...", Bearer ...
    # An optional quote right after the label lets JSON/dict-style keys match.
    # The value runs to the next whitespace or quote, so secrets containing
    # symbols (P@ss/w0rd!) are masked in full rather than skipped.
    re.compile(
        r'(?i)(api[_-]?key|token|secret|password|passwd|pwd|auth|bearer|jwt|'
        r'oauth|private[_-]?key|sk-|ghp_|xoxb-|AKIA|client_secret)'
        r'[\'"`]?[\s=:]+[\'"`]?([^\s\'"`]{8,})[\'"`]?',
        re.MULTILINE,
    ),
    # Generic high-entropy strings after known labels
    re.compile(
        r'(?i)(credential|access[_-]?key|refresh[_-]?token|signing[_-]?key)'
        r'[\'"`]?[\s=:]+[\'"`]?([a-zA-Z0-9/+]{20,}={0,2})[\'"`]?',
        re.MULTILINE,
    ),
]

# Well-known token prefixes that appear on their own, with no "label=" before them.
_TOKEN_PREFIX_PATTERN = re.compile(r"\b(sk-|ghp_|xoxb-|AKIA)[A-Za-z0-9_\-]{16,}")


def _mask_labelled_secret(match) -> str:
    """Replacement for _SECRET_PATTERNS matches: keep the label, drop the value."""
    return f"{match.group(1)}=***REDACTED***"


def redact(text: str) -> str:
    """
    Mask secret values in ``text``. Always enabled by default.

    Labelled secrets (``api_key=...``, ``"password": "..."``, ``Bearer ...``)
    become ``<label>=***REDACTED***``. Bare well-known tokens (``sk-...``,
    ``ghp_...``, ``xoxb-...``, ``AKIA...``) keep their prefix, and the rest of
    the token is replaced with ``***REDACTED***``. Falsy input is returned
    unchanged.
    """
    if not text:
        return text
    for pat in _SECRET_PATTERNS:
        text = pat.sub(_mask_labelled_secret, text)
    return _TOKEN_PREFIX_PATTERN.sub(r"\1***REDACTED***", text)


# ─── Network Allowlist Check ──────────────────────────────────────────────────

_PRIVATE_RANGES = [
    ipaddress.ip_network("10.0.0.0/8"),
    ipaddress.ip_network("172.16.0.0/12"),
    ipaddress.ip_network("192.168.0.0/16"),
    ipaddress.ip_network("127.0.0.0/8"),
    ipaddress.ip_network("169.254.0.0/16"),
    ipaddress.ip_network("::1/128"),
    ipaddress.ip_network("fc00::/7"),
    # "This network" (0.0.0.0 is commonly routed to the local host),
    # carrier-grade NAT, the IPv6 unspecified address and IPv6 link-local.
    ipaddress.ip_network("0.0.0.0/8"),
    ipaddress.ip_network("100.64.0.0/10"),
    ipaddress.ip_network("::/128"),
    ipaddress.ip_network("fe80::/10"),
]

# Hostnames that conventionally name the local machine.
_LOCAL_HOSTNAMES = frozenset(
    {"localhost", "localhost.localdomain", "ip6-localhost", "ip6-loopback"}
)

# Characters that can make up a legacy numeric IPv4 literal ("127.1", "0x7f000001").
_NUMERIC_HOST_RE = re.compile(r"[0-9a-fA-FxX.]+")


def _parse_ip(host: str):
    """Parse ``host`` as an IP literal (including legacy IPv4 forms); None if not an IP."""
    try:
        return ipaddress.ip_address(host)
    except ValueError:
        pass
    # ipaddress rejects shorthand, decimal, hex and octal IPv4 forms such as
    # "2130706433", but socket.inet_aton accepts them, so a client may still
    # connect to such a host.
    if host and _NUMERIC_HOST_RE.fullmatch(host):
        import socket

        try:
            return ipaddress.IPv4Address(socket.inet_aton(host))
        except (OSError, ValueError):
            return None
    return None


def _is_private_ip(host: str) -> bool:
    """
    Return True if ``host`` is an IP literal in a non-public range.

    Non-IP hostnames return False. DNS names are NOT resolved here, so a
    public name that resolves to a private address is not caught.
    """
    addr = _parse_ip(host)
    if addr is None:
        return False
    # Check IPv4-mapped IPv6 ("::ffff:127.0.0.1") as the embedded IPv4 address.
    mapped = getattr(addr, "ipv4_mapped", None)
    if mapped is not None:
        addr = mapped
    return (
        any(addr in net for net in _PRIVATE_RANGES)
        or addr.is_private
        or addr.is_loopback
        or addr.is_link_local
        or addr.is_unspecified
        or addr.is_reserved
    )


def _is_local_hostname(host: str) -> bool:
    """True for localhost-style names (case-insensitive, trailing dot ignored)."""
    name = host.lower().rstrip(".")
    return name in _LOCAL_HOSTNAMES or name.endswith(".localhost")


def check_url_allowed(tool: str, url: str) -> Tuple[bool, str]:
    """
    Check if a URL is allowed for a given tool per network_allowlist.yml.
    Returns (allowed, reason).

    Denial order:
      1. malformed URL;
      2. no config for the tool;
      3. scheme not in ``schemes`` (default ``["https"]``);
      4. URL has no host;
      5. private IP or localhost, when ``allow_any_public`` and
         ``block_private_ranges`` are set;
      6. host not in ``hosts``.
    Scheme and host comparisons are case-insensitive, and a trailing dot on
    the host is ignored.
    """
    import urllib.parse

    try:
        parsed = urllib.parse.urlparse(url)
        host = (parsed.hostname or "").rstrip(".")
    except ValueError as e:
        # e.g. unbalanced IPv6 brackets: "http://[::1/"
        return False, f"Malformed URL for tool '{tool}': {e}"
    scheme = parsed.scheme or "https"

    allowlist_cfg = _load_yaml(_ALLOWLIST_PATH)
    tool_cfg = allowlist_cfg.get(tool, {})
    if not tool_cfg or not isinstance(tool_cfg, dict):
        # No config: deny by default (safe default)
        return False, f"No allowlist config for tool '{tool}'"

    # Check scheme
    allowed_schemes = {
        str(s).strip().lower() for s in _as_list(tool_cfg.get("schemes", ["https"]))
    }
    if scheme not in allowed_schemes:
        return False, f"Scheme '{scheme}' not allowed for tool '{tool}'"

    # A URL without a host can match nothing; never treat it as public.
    if not host:
        return False, f"No host in URL for tool '{tool}'"

    # Check allow_any_public flag
    if tool_cfg.get("allow_any_public"):
        if tool_cfg.get("block_private_ranges") and (
            _is_private_ip(host) or _is_local_hostname(host)
        ):
            return False, f"Private IP blocked: {host}"
        return True, ""

    # Check explicit hosts
    allowed_hosts = {
        str(h).strip().lower().rstrip(".") for h in _as_list(tool_cfg.get("hosts", []))
    }
    if host in allowed_hosts:
        return True, ""

    return False, f"Host '{host}' not in allowlist for tool '{tool}'"


# ─── RBAC Matrix ──────────────────────────────────────────────────────────────


def _get_agent_role(agent_id: str) -> str:
    """Resolve agent role (delegates to agent_tools_config); "agent_default" on any failure."""
    try:
        from agent_tools_config import get_agent_role

        return get_agent_role(agent_id)
    except Exception as e:
        # Best-effort lookup: degrade to the default role, but leave a trace.
        logger.debug(
            "Role lookup failed for agent %s (%s); using agent_default", agent_id, e
        )
        return "agent_default"


def _get_role_entitlements(role: str) -> List[str]:
    """Get entitlements for a role from RBAC matrix (unknown roles get agent_default's)."""
    matrix = _load_yaml(_RBAC_PATH)
    role_entitlements = matrix.get("role_entitlements", {})
    ents = role_entitlements.get(role, role_entitlements.get("agent_default", []))
    # A null entry yields no entitlements; a single string is one entitlement.
    return _as_list(ents)


def _get_required_entitlements(tool: str, action: str) -> List[str]:
    """Get required entitlements for tool+action from matrix."""
    matrix = _load_yaml(_RBAC_PATH)
    tools_section = matrix.get("tools", {})
    tool_cfg = tools_section.get(tool, {})
    actions = tool_cfg.get("actions", {})
    # Try exact action, then _default
    action_cfg = actions.get(action) or actions.get("_default", {})
    return _as_list(action_cfg.get("entitlements", [])) if action_cfg else []


def check_rbac(agent_id: str, tool: str, action: str) -> Tuple[bool, str]:
    """
    Check RBAC: agent role → entitlements → required entitlements for tool+action.
    Returns (allowed, reason).

    NOTE(review): tools/actions with no configured entitlements (including
    the case where the matrix file could not be loaded) are allowed.
    """
    role = _get_agent_role(agent_id)
    agent_ents = set(_get_role_entitlements(role))
    required = _get_required_entitlements(tool, action)

    if not required:
        # No entitlements required → allowed
        return True, ""

    missing = [e for e in required if e not in agent_ents]
    if missing:
        return False, f"Missing entitlements: {missing} (agent={agent_id}, role={role})"

    return True, ""


# ─── Limits ───────────────────────────────────────────────────────────────────

# Built-in fallbacks; keys missing from tool_limits.yml are filled from here.
_DEFAULT_LIMITS: Dict[str, Any] = {
    "timeout_ms": 30000,
    "max_chars_in": 200000,
    "max_bytes_out": 524288,
    "rate_limit_rpm": 60,
    "concurrency": 5,
}


def _get_limits(tool: str) -> Dict[str, Any]:
    """
    Get effective limits for a tool.

    Layers, later overriding earlier: built-in defaults → config ``defaults``
    → config ``tools.<tool>``. Non-mapping (including null) layers are
    ignored. Always returns a fresh dict.
    """
    cfg = _load_yaml(_LIMITS_PATH)
    tools_cfg = cfg.get("tools")
    per_tool = tools_cfg.get(tool) if isinstance(tools_cfg, dict) else None

    merged = dict(_DEFAULT_LIMITS)
    for layer in (cfg.get("defaults"), per_tool):
        if isinstance(layer, dict):
            merged.update(layer)
    return merged


def check_input_limits(tool: str, input_text: str) -> Tuple[bool, str, Dict]:
    """
    Enforce max_chars_in limit.
    Returns (ok, reason, limits_applied).
    """
    limits = _get_limits(tool)
    max_chars = limits.get("max_chars_in", 200000)
    actual = len(input_text) if input_text else 0
    if actual > max_chars:
        return False, f"Input too large: {actual} chars (max {max_chars} for {tool})", limits
    return True, "", limits


# ─── Audit ────────────────────────────────────────────────────────────────────


def _emit_audit(event: AuditEvent):
    """
    Emit structured audit event.

    1. Writes to logger (structured, no payload).
    2. Persists to AuditStore (JSONL/Postgres/Memory) for FinOps analysis.

    Persistence is non-fatal: errors are logged as warnings without
    interrupting tool execution.
    """
    import datetime

    record = {
        "ts": event.ts or datetime.datetime.now(datetime.timezone.utc).isoformat(),
        "req_id": event.req_id,
        "tool": event.tool,
        "action": event.action,
        "workspace_id": event.workspace_id,
        "user_id": event.user_id,
        "agent_id": event.agent_id,
        "status": event.status,
        "duration_ms": round(event.duration_ms, 2),
        "limits_applied": event.limits_applied,
        "input_hash": event.input_hash,
        "input_chars": event.input_chars,
        "output_size_bytes": event.output_size_bytes,
    }
    # default=str: limits come from YAML and may hold non-JSON types
    # (e.g. dates); audit logging must not fail the tool call over that.
    logger.info("TOOL_AUDIT %s", json.dumps(record, default=str))

    # Persist to audit store (non-fatal)
    try:
        from audit_store import get_audit_store

        store = get_audit_store()
        store.write(event)
    except Exception as _audit_err:
        logger.warning("audit_store.write failed (non-fatal): %s", _audit_err)


def _output_size_bytes(value: Any) -> int:
    """
    Size of a tool result in bytes, computed without logging its content.

    - ``None`` → 0.
    - ``str`` → UTF-8 length.
    - ``bytes``/``bytearray`` → ``len()``.
    - Anything else → UTF-8 length of its JSON encoding (non-JSON types via
      ``str()``).

    Returns 0 if the value cannot be measured.
    """
    if value is None:
        return 0
    if isinstance(value, (bytes, bytearray)):
        return len(value)
    try:
        if isinstance(value, str):
            return len(value.encode("utf-8", "surrogatepass"))
        # ensure_ascii=False: count real UTF-8 bytes, not \uXXXX escapes.
        encoded = json.dumps(value, ensure_ascii=False, default=str)
        return len(encoded.encode("utf-8", "surrogatepass"))
    except Exception:
        return 0


# ─── Main Governance Class ────────────────────────────────────────────────────


class ToolGovernance:
    """
    Single entry point for tool governance.

    Call pre_call() before executing any tool.
    Call post_call() after execution to emit audit event.
    """

    def __init__(self, *, enable_rbac: bool = True, enable_redaction: bool = True,
                 enable_limits: bool = True, enable_audit: bool = True,
                 enable_allowlist: bool = True):
        self.enable_rbac = enable_rbac
        self.enable_redaction = enable_redaction
        self.enable_limits = enable_limits
        self.enable_audit = enable_audit
        self.enable_allowlist = enable_allowlist

    def _audit_deny(
        self,
        req_id: str,
        tool: str,
        action: str,
        agent_id: str,
        user_id: str,
        workspace_id: str,
        *,
        limits_applied: Dict[str, Any],
        input_chars: int,
    ) -> None:
        """Emit a "deny" audit event for a pre-call rejection (no-op if audit is off)."""
        if not self.enable_audit:
            return
        _emit_audit(AuditEvent(
            ts=_now_iso(), req_id=req_id, tool=tool, action=action,
            workspace_id=workspace_id, user_id=user_id, agent_id=agent_id,
            status="deny", duration_ms=0, limits_applied=limits_applied,
            input_hash="", input_chars=input_chars, output_size_bytes=0,
        ))

    def pre_call(
        self,
        tool: str,
        action: str,
        agent_id: str,
        user_id: str = "unknown",
        workspace_id: str = "unknown",
        input_text: str = "",
    ) -> PreCallResult:
        """
        Run all pre-call checks. Returns PreCallResult.

        If allowed=False, caller must return error immediately. Denials are
        audited with status "deny" when audit is enabled.
        """
        req_id = str(uuid.uuid4())[:12]
        ts_start = time.monotonic()

        # 1. RBAC check
        if self.enable_rbac:
            ok, reason = check_rbac(agent_id, tool, action)
            if not ok:
                self._audit_deny(
                    req_id, tool, action, agent_id, user_id, workspace_id,
                    limits_applied={}, input_chars=0,
                )
                return PreCallResult(allowed=False, reason=f"RBAC denied: {reason}")

        # 2. Input limits
        limits_applied: Dict[str, Any] = {}
        if self.enable_limits and input_text:
            ok, reason, limits_applied = check_input_limits(tool, input_text)
            if not ok:
                self._audit_deny(
                    req_id, tool, action, agent_id, user_id, workspace_id,
                    limits_applied=limits_applied, input_chars=len(input_text),
                )
                return PreCallResult(allowed=False, reason=f"Limits exceeded: {reason}")
        else:
            # No input check ran (limits disabled or empty input); still
            # record the configured limits in the call context.
            limits_applied = _get_limits(tool)

        # Build call context. surrogatepass: text decoded from JSON can carry
        # lone surrogates, which a strict UTF-8 encode rejects.
        input_hash = (
            hashlib.sha256(input_text.encode("utf-8", "surrogatepass")).hexdigest()[:16]
            if input_text
            else ""
        )
        ctx = CallContext(
            req_id=req_id,
            tool=tool,
            action=action,
            agent_id=agent_id,
            user_id=user_id,
            workspace_id=workspace_id,
            ts_start=ts_start,
            input_hash=input_hash,
            input_chars=len(input_text) if input_text else 0,
            limits_applied=limits_applied,
        )
        return PreCallResult(allowed=True, call_ctx=ctx)

    def post_call(self, ctx: CallContext, result_value: Any, error: Optional[str] = None):
        """
        Emit audit event after tool execution.

        Args:
            ctx: the ``call_ctx`` returned by pre_call; ``None`` is ignored.
            result_value: raw result data (used only for size calculation, not logged).
            error: error message if the tool failed. Any truthy value records
                status "error"; otherwise the status is "pass".
        """
        if not self.enable_audit or ctx is None:
            return

        duration_ms = (time.monotonic() - ctx.ts_start) * 1000
        status = "error" if error else "pass"

        _emit_audit(AuditEvent(
            ts=_now_iso(),
            req_id=ctx.req_id,
            tool=ctx.tool,
            action=ctx.action,
            workspace_id=ctx.workspace_id,
            user_id=ctx.user_id,
            agent_id=ctx.agent_id,
            status=status,
            duration_ms=duration_ms,
            limits_applied=ctx.limits_applied,
            input_hash=ctx.input_hash,
            input_chars=ctx.input_chars,
            output_size_bytes=_output_size_bytes(result_value),
        ))

    def apply_redaction(self, text: str) -> str:
        """Apply secret redaction if enabled."""
        if not self.enable_redaction:
            return text
        return redact(text)

    def check_url(self, tool: str, url: str) -> Tuple[bool, str]:
        """Check URL against allowlist if enabled."""
        if not self.enable_allowlist:
            return True, ""
        return check_url_allowed(tool, url)

    def get_timeout_ms(self, tool: str) -> int:
        """Get configured timeout for a tool."""
        limits = _get_limits(tool)
        return limits.get("timeout_ms", 30000)


# ─── Helpers ──────────────────────────────────────────────────────────────────


def _now_iso() -> str:
    """Current UTC time as an ISO-8601 string."""
    import datetime

    return datetime.datetime.now(datetime.timezone.utc).isoformat()


# ─── Module-level singleton ───────────────────────────────────────────────────

_governance: Optional[ToolGovernance] = None


def get_governance() -> ToolGovernance:
    """Get the shared ToolGovernance singleton (created lazily on first use)."""
    global _governance
    if _governance is None:
        _governance = ToolGovernance()
    return _governance


def reset_governance(instance: Optional[ToolGovernance] = None):
    """Reset singleton (for testing); pass an instance to install it directly."""
    global _governance
    _governance = instance