Files
microdao-daarion/services/auth-service/database.py

184 lines
4.8 KiB
Python

"""
Database connection and operations for Auth Service
"""
import asyncpg
from typing import Optional, List, Dict, Any
from uuid import UUID
from datetime import datetime, timedelta, timezone
from contextlib import asynccontextmanager
from config import get_settings
settings = get_settings()
_pool: Optional[asyncpg.Pool] = None
async def get_pool() -> asyncpg.Pool:
global _pool
if _pool is None:
_pool = await asyncpg.create_pool(
settings.database_url,
min_size=2,
max_size=10
)
return _pool
async def close_pool():
global _pool
if _pool:
await _pool.close()
_pool = None
@asynccontextmanager
async def get_connection():
pool = await get_pool()
async with pool.acquire() as conn:
yield conn
# User operations
async def create_user(
email: str,
password_hash: str,
display_name: Optional[str] = None
) -> Dict[str, Any]:
async with get_connection() as conn:
# Create user
user = await conn.fetchrow(
"""
INSERT INTO auth_users (email, password_hash, display_name)
VALUES ($1, $2, $3)
RETURNING id, email, display_name, avatar_url, is_active, is_admin, created_at
""",
email, password_hash, display_name
)
user_id = user['id']
# Assign default 'user' role
await conn.execute(
"""
INSERT INTO auth_user_roles (user_id, role_id)
VALUES ($1, 'user')
""",
user_id
)
return dict(user)
async def get_user_by_email(email: str) -> Optional[Dict[str, Any]]:
async with get_connection() as conn:
user = await conn.fetchrow(
"""
SELECT id, email, password_hash, display_name, avatar_url,
is_active, is_admin, created_at, updated_at
FROM auth_users
WHERE email = $1
""",
email
)
return dict(user) if user else None
async def get_user_by_id(user_id: UUID) -> Optional[Dict[str, Any]]:
async with get_connection() as conn:
user = await conn.fetchrow(
"""
SELECT id, email, display_name, avatar_url,
is_active, is_admin, created_at, updated_at
FROM auth_users
WHERE id = $1
""",
user_id
)
return dict(user) if user else None
async def get_user_roles(user_id: UUID) -> List[str]:
async with get_connection() as conn:
rows = await conn.fetch(
"""
SELECT role_id FROM auth_user_roles
WHERE user_id = $1
""",
user_id
)
return [row['role_id'] for row in rows]
# Session operations
async def create_session(
user_id: UUID,
user_agent: Optional[str] = None,
ip_address: Optional[str] = None,
ttl_seconds: int = 604800
) -> UUID:
async with get_connection() as conn:
expires_at = datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)
row = await conn.fetchrow(
"""
INSERT INTO auth_sessions (user_id, user_agent, ip_address, expires_at)
VALUES ($1, $2, $3::inet, $4)
RETURNING id
""",
user_id, user_agent, ip_address, expires_at
)
return row['id']
async def get_session(session_id: UUID) -> Optional[Dict[str, Any]]:
async with get_connection() as conn:
session = await conn.fetchrow(
"""
SELECT id, user_id, expires_at, revoked_at
FROM auth_sessions
WHERE id = $1
""",
session_id
)
return dict(session) if session else None
async def revoke_session(session_id: UUID) -> bool:
async with get_connection() as conn:
result = await conn.execute(
"""
UPDATE auth_sessions
SET revoked_at = now()
WHERE id = $1 AND revoked_at IS NULL
""",
session_id
)
return result == "UPDATE 1"
async def is_session_valid(session_id: UUID) -> bool:
async with get_connection() as conn:
row = await conn.fetchrow(
"""
SELECT 1 FROM auth_sessions
WHERE id = $1
AND revoked_at IS NULL
AND expires_at > now()
""",
session_id
)
return row is not None
async def cleanup_expired_sessions():
"""Remove expired sessions (can be run periodically)"""
async with get_connection() as conn:
await conn.execute(
"""
DELETE FROM auth_sessions
WHERE expires_at < now() - INTERVAL '7 days'
"""
)