"""
Tests for ThreatModel Tool
"""

import os
import sys
import textwrap

import pytest

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from services.router.tool_manager import ToolManager, ToolResult


def _artifact(artifact_type, value):
    """Build one inline artifact entry (``source`` is always ``"text"``)."""
    return {"type": artifact_type, "source": "text", "value": value}


def _request(action, service_name, artifacts, options=None):
    """Build a ``_threatmodel_tool`` request payload.

    ``options`` is only added when provided, so payloads without options have
    exactly the ``action``/``inputs`` shape used throughout these tests.
    """
    payload = {
        "action": action,
        "inputs": {"service_name": service_name, "artifacts": artifacts},
    }
    if options is not None:
        payload["options"] = options
    return payload


class TestThreatModelTool:
    """Test threatmodel tool functionality"""

    @pytest.mark.asyncio
    async def test_analyze_service_basic(self):
        """Test basic service analysis"""
        tool_mgr = ToolManager({})
        openapi_spec = textwrap.dedent(
            """
            openapi: 3.0.0
            paths:
              /api/users:
                get:
                  summary: Get users
                  security:
                    - BearerAuth: []
              /api/admin:
                post:
                  summary: Admin endpoint
            components:
              securitySchemes:
                BearerAuth:
                  type: http
                  scheme: bearer
            """
        )

        result = await tool_mgr._threatmodel_tool(
            _request(
                "analyze_service",
                "test-service",
                [_artifact("openapi", openapi_spec)],
            )
        )

        assert result.success is True
        assert result.result is not None
        assert "scope" in result.result
        assert result.result["scope"]["service_name"] == "test-service"
        assert "entrypoints" in result.result
        assert "threats" in result.result
        assert "security_checklist" in result.result

    @pytest.mark.asyncio
    async def test_analyze_diff_with_rce(self):
        """Test that diff with RCE patterns generates threats"""
        tool_mgr = ToolManager({})
        # Hunk header counts match the two added lines below.
        diff = textwrap.dedent(
            """
            diff --git a/app.py b/app.py
            --- a/app.py
            +++ b/app.py
            @@ -0,0 +1,2 @@
            +import os
            +result = os.system(request.args.get('cmd'))
            """
        )

        result = await tool_mgr._threatmodel_tool(
            _request("analyze_diff", "vulnerable-service", [_artifact("diff", diff)])
        )

        assert result.success is True
        threats = result.result["threats"]
        rce_threats = [t for t in threats if "RCE" in t.get("title", "")]
        assert len(rce_threats) > 0

    @pytest.mark.asyncio
    async def test_analyze_diff_with_ssrf(self):
        """Test that diff with URL fetch patterns generates SSRF threats"""
        tool_mgr = ToolManager({})
        diff = textwrap.dedent(
            """
            diff --git a/app.py b/app.py
            --- a/app.py
            +++ b/app.py
            @@ -0,0 +1,2 @@
            +import requests
            +result = requests.get(url)
            """
        )

        result = await tool_mgr._threatmodel_tool(
            _request("analyze_diff", "fetch-service", [_artifact("diff", diff)])
        )

        assert result.success is True
        threats = result.result["threats"]
        ssrf_threats = [t for t in threats if "SSRF" in t.get("title", "")]
        assert len(ssrf_threats) > 0

    @pytest.mark.asyncio
    async def test_no_auth_endpoint_generates_threat(self):
        """Test that endpoints without auth generate threats"""
        tool_mgr = ToolManager({})
        openapi_spec = textwrap.dedent(
            """
            openapi: 3.0.0
            paths:
              /public/data:
                get:
                  summary: Public data
            """
        )

        result = await tool_mgr._threatmodel_tool(
            _request(
                "analyze_service",
                "public-api",
                [_artifact("openapi", openapi_spec)],
            )
        )

        assert result.success is True
        threats = result.result["threats"]
        auth_threats = [t for t in threats if "Unauthenticated" in t.get("title", "")]
        assert len(auth_threats) > 0

    @pytest.mark.asyncio
    async def test_agentic_risk_profile_adds_threats(self):
        """Test that agentic_tools profile adds specific threats"""
        tool_mgr = ToolManager({})

        result = await tool_mgr._threatmodel_tool(
            _request(
                "analyze_service",
                "agent-service",
                [_artifact("text", "Agent tool execution service")],
                options={"risk_profile": "agentic_tools"},
            )
        )

        assert result.success is True
        threats = result.result["threats"]
        threat_ids = [t.get("id", "") for t in threats]
        assert any("TM-AI-" in tid for tid in threat_ids)

        boundaries = result.result["trust_boundaries"]
        agent_boundary = [b for b in boundaries if "agent_to_tool" in b.get("name", "")]
        assert len(agent_boundary) > 0

    @pytest.mark.asyncio
    async def test_public_api_risk_profile(self):
        """Test that public_api profile adds specific threats"""
        tool_mgr = ToolManager({})

        result = await tool_mgr._threatmodel_tool(
            _request(
                "analyze_service",
                "public-api",
                [_artifact("text", "Public API service")],
                options={"risk_profile": "public_api"},
            )
        )

        assert result.success is True
        threats = result.result["threats"]
        threat_ids = [t.get("id", "") for t in threats]
        assert any("TM-PA-" in tid for tid in threat_ids)

    @pytest.mark.asyncio
    async def test_extracts_assets_from_content(self):
        """Test that assets are extracted from content"""
        tool_mgr = ToolManager({})
        content = textwrap.dedent(
            """
            const API_KEY = "sk-1234567890";
            const USER_TOKEN = "user_session_abc123";
            user_email = "user@example.com";
            """
        )

        result = await tool_mgr._threatmodel_tool(
            _request("analyze_diff", "test-service", [_artifact("text", content)])
        )

        assert result.success is True
        assets = result.result["assets"]
        assert len(assets) > 0
        asset_names = [a.get("name", "").lower() for a in assets]
        assert any("api_key" in n or "token" in n for n in asset_names)

    @pytest.mark.asyncio
    async def test_strict_mode_impact(self):
        """Test that strict mode affects summary output"""
        tool_mgr = ToolManager({})
        diff = textwrap.dedent(
            """
            diff --git a/app.py b/app.py
            --- a/app.py
            +++ b/app.py
            @@ -0,0 +1,1 @@
            +os.system(request.args.get('cmd'))
            """
        )

        result_normal = await tool_mgr._threatmodel_tool(
            _request(
                "analyze_diff",
                "test-service",
                [_artifact("diff", diff)],
                options={"strict": False},
            )
        )
        result_strict = await tool_mgr._threatmodel_tool(
            _request(
                "analyze_diff",
                "test-service",
                [_artifact("diff", diff)],
                options={"strict": True},
            )
        )

        assert result_normal.success is True
        assert result_strict.success is True
        summary = result_strict.result["summary"]
        assert "FAIL" in summary or "HIGH" in summary

    @pytest.mark.asyncio
    async def test_security_checklist_generated(self):
        """Test that security checklist is generated"""
        tool_mgr = ToolManager({})

        result = await tool_mgr._threatmodel_tool(
            _request("generate_checklist", "test-service", [])
        )

        assert result.success is True
        checklist = result.result["security_checklist"]
        assert len(checklist) > 0
        checklist_types = [c.get("type") for c in checklist]
        assert "auth" in checklist_types or "authz" in checklist_types

    @pytest.mark.asyncio
    async def test_max_chars_limit(self):
        """Test that max_chars limit is enforced"""
        tool_mgr = ToolManager({})
        large_content = "a" * 700000

        result = await tool_mgr._threatmodel_tool(
            _request(
                "analyze_service",
                "test-service",
                [_artifact("text", large_content)],
            )
        )

        assert result.success is False
        # Check explicitly so a missing error message fails as an assertion,
        # not as an AttributeError on ``None.lower()``.
        assert result.error is not None
        assert "max_chars" in result.error.lower()

    @pytest.mark.asyncio
    async def test_deterministic_ordering(self):
        """Test that threats are in deterministic order"""
        tool_mgr = ToolManager({})

        def build():
            # Fresh payload per call so one run cannot influence the next
            # through a shared (possibly mutated) request dict.
            return _request(
                "analyze_service",
                "test-service",
                [_artifact("text", "test content")],
            )

        first = await tool_mgr._threatmodel_tool(build())
        second = await tool_mgr._threatmodel_tool(build())

        assert first.success is True
        assert second.success is True
        ids = [t.get("id", "") for t in first.result["threats"]]
        # Identical input must give an identical threat sequence; unlike a
        # sortedness check, this is meaningful even with 0 or 1 threats.
        assert ids == [t.get("id", "") for t in second.result["threats"]]
        assert ids == sorted(ids)

    @pytest.mark.asyncio
    async def test_entrypoints_from_openapi(self):
        """Test that entrypoints are extracted from OpenAPI"""
        tool_mgr = ToolManager({})
        # GET and POST share one path item: repeating the "/v1/users" key is
        # invalid YAML and compliant parsers silently keep only the last one.
        openapi_spec = textwrap.dedent(
            """
            openapi: 3.0.0
            paths:
              /v1/users:
                get:
                  summary: Get users
                post:
                  summary: Create user
                  security:
                    - ApiKeyAuth: []
              /v1/admin/users:
                delete:
                  summary: Delete user
            components:
              securitySchemes:
                ApiKeyAuth:
                  type: apiKey
                  in: header
                  name: X-API-Key
            """
        )

        result = await tool_mgr._threatmodel_tool(
            _request(
                "analyze_service",
                "test-api",
                [_artifact("openapi", openapi_spec)],
            )
        )

        assert result.success is True
        entrypoints = result.result["entrypoints"]
        assert len(entrypoints) >= 2
        http_entrypoints = [e for e in entrypoints if e.get("type") == "http"]
        assert len(http_entrypoints) >= 2

    @pytest.mark.asyncio
    async def test_trust_boundaries_identified(self):
        """Test that trust boundaries are identified"""
        tool_mgr = ToolManager({})

        result = await tool_mgr._threatmodel_tool(
            _request(
                "analyze_service",
                "test-service",
                [_artifact("text", "HTTP service with database")],
            )
        )

        assert result.success is True
        boundaries = result.result["trust_boundaries"]
        assert len(boundaries) > 0
        boundary_names = [b.get("name", "") for b in boundaries]
        assert any("client" in n or "gateway" in n for n in boundary_names)

    @pytest.mark.asyncio
    async def test_controls_generated_from_threats(self):
        """Test that controls are generated from threats"""
        tool_mgr = ToolManager({})
        diff = textwrap.dedent(
            """
            diff --git a/app.py b/app.py
            --- a/app.py
            +++ b/app.py
            @@ -0,0 +1,1 @@
            +result = os.system(cmd)
            """
        )

        result = await tool_mgr._threatmodel_tool(
            _request("analyze_diff", "test-service", [_artifact("diff", diff)])
        )

        assert result.success is True
        controls = result.result["controls"]
        assert len(controls) > 0