184 lines
4.8 KiB
Python
184 lines
4.8 KiB
Python
"""
|
|
Database connection and operations for Auth Service
|
|
"""
|
|
import asyncpg
|
|
from typing import Optional, List, Dict, Any
|
|
from uuid import UUID
|
|
from datetime import datetime, timedelta, timezone
|
|
from contextlib import asynccontextmanager
|
|
|
|
from config import get_settings
|
|
|
|
# Application settings, loaded once at import time. `database_url` is read
# from this object when the connection pool is created.
settings = get_settings()
|
|
|
|
# Module-wide connection pool. Starts as None, is created on first use by
# get_pool(), and is reset to None by close_pool().
_pool: Optional[asyncpg.Pool] = None
|
|
|
|
|
|
async def get_pool() -> asyncpg.Pool:
    """Return the shared asyncpg pool, creating it on the first call.

    The pool connects to ``settings.database_url`` and is cached in the
    module-level ``_pool`` so later calls reuse the same instance.
    """
    global _pool

    # Fast path: the pool already exists.
    if _pool is not None:
        return _pool

    _pool = await asyncpg.create_pool(
        settings.database_url,
        min_size=2,
        max_size=10,
    )
    return _pool
|
|
|
|
|
|
async def close_pool():
    """Close the shared connection pool, if there is one, and forget it.

    Does nothing when no pool has been created.
    """
    global _pool

    if not _pool:
        return

    await _pool.close()
    _pool = None
|
|
|
|
|
|
@asynccontextmanager
async def get_connection():
    """Async context manager that yields a connection from the shared pool.

    The connection is acquired when the block is entered and handed back
    to the pool when the block exits.
    """
    async with (await get_pool()).acquire() as connection:
        yield connection
|
|
|
|
|
|
# User operations
|
|
async def create_user(
    email: str,
    password_hash: str,
    display_name: Optional[str] = None
) -> Dict[str, Any]:
    """Insert a new user and assign it the default ``'user'`` role.

    Both inserts run inside one transaction, so a failure while assigning
    the role rolls back the user row instead of leaving an account with
    no roles behind.

    Args:
        email: Value stored in ``auth_users.email``.
        password_hash: Stored as given in ``auth_users.password_hash``.
        display_name: Optional display name; ``None`` is stored as NULL.

    Returns:
        The new row as a dict with keys ``id``, ``email``, ``display_name``,
        ``avatar_url``, ``is_active``, ``is_admin`` and ``created_at``.

    Raises:
        Any database error raised by asyncpg for either insert propagates
        unchanged; in that case nothing is committed.
    """
    async with get_connection() as conn:
        # One transaction for both statements: user and role are created
        # together or not at all.
        async with conn.transaction():
            # Create user
            user = await conn.fetchrow(
                """
                INSERT INTO auth_users (email, password_hash, display_name)
                VALUES ($1, $2, $3)
                RETURNING id, email, display_name, avatar_url, is_active, is_admin, created_at
                """,
                email, password_hash, display_name
            )

            # Assign default 'user' role
            await conn.execute(
                """
                INSERT INTO auth_user_roles (user_id, role_id)
                VALUES ($1, 'user')
                """,
                user['id']
            )

        return dict(user)
|
|
|
|
|
|
async def get_user_by_email(email: str) -> Optional[Dict[str, Any]]:
    """Fetch a user row by email, including the stored password hash.

    Returns:
        The matching ``auth_users`` row as a dict, or ``None`` when no
        row has that email.
    """
    query = """
        SELECT id, email, password_hash, display_name, avatar_url,
               is_active, is_admin, created_at, updated_at
        FROM auth_users
        WHERE email = $1
        """
    async with get_connection() as db:
        record = await db.fetchrow(query, email)
        if not record:
            return None
        return dict(record)
|
|
|
|
|
|
async def get_user_by_id(user_id: UUID) -> Optional[Dict[str, Any]]:
    """Fetch a user row by primary id (the password hash is not selected).

    Returns:
        The matching ``auth_users`` row as a dict, or ``None`` when no
        row has that id.
    """
    query = """
        SELECT id, email, display_name, avatar_url,
               is_active, is_admin, created_at, updated_at
        FROM auth_users
        WHERE id = $1
        """
    async with get_connection() as db:
        record = await db.fetchrow(query, user_id)
        if not record:
            return None
        return dict(record)
|
|
|
|
|
|
async def get_user_roles(user_id: UUID) -> List[str]:
    """Return the ``role_id`` of every role assigned to the given user.

    The list is empty when the user has no rows in ``auth_user_roles``.
    """
    query = """
        SELECT role_id FROM auth_user_roles
        WHERE user_id = $1
        """
    async with get_connection() as db:
        records = await db.fetch(query, user_id)
        role_ids = []
        for record in records:
            role_ids.append(record['role_id'])
        return role_ids
|
|
|
|
|
|
# Session operations
|
|
async def create_session(
    user_id: UUID,
    user_agent: Optional[str] = None,
    ip_address: Optional[str] = None,
    ttl_seconds: int = 604800
) -> UUID:
    """Insert a session row for a user and return the new session id.

    Args:
        user_id: Owner of the session.
        user_agent: Optional user-agent string stored with the session.
        ip_address: Optional client address; cast to ``inet`` by the query.
        ttl_seconds: Session lifetime in seconds, counted from the current
            UTC time (default 604800, i.e. 7 days).

    Returns:
        The ``id`` of the inserted ``auth_sessions`` row.
    """
    lifetime = timedelta(seconds=ttl_seconds)
    insert_sql = """
        INSERT INTO auth_sessions (user_id, user_agent, ip_address, expires_at)
        VALUES ($1, $2, $3::inet, $4)
        RETURNING id
        """
    async with get_connection() as db:
        expiry = datetime.now(timezone.utc) + lifetime
        record = await db.fetchrow(
            insert_sql, user_id, user_agent, ip_address, expiry
        )
        return record['id']
|
|
|
|
|
|
async def get_session(session_id: UUID) -> Optional[Dict[str, Any]]:
    """Fetch a session row by id.

    Returns:
        A dict with keys ``id``, ``user_id``, ``expires_at`` and
        ``revoked_at``, or ``None`` when no session has that id. No
        expiry or revocation filtering is applied here.
    """
    query = """
        SELECT id, user_id, expires_at, revoked_at
        FROM auth_sessions
        WHERE id = $1
        """
    async with get_connection() as db:
        record = await db.fetchrow(query, session_id)
        if not record:
            return None
        return dict(record)
|
|
|
|
|
|
async def revoke_session(session_id: UUID) -> bool:
    """Mark a session as revoked by setting ``revoked_at`` to ``now()``.

    Only sessions that are not already revoked are updated.

    Returns:
        ``True`` when the UPDATE reports exactly one affected row,
        ``False`` otherwise (unknown id or already revoked).
    """
    update_sql = """
        UPDATE auth_sessions
        SET revoked_at = now()
        WHERE id = $1 AND revoked_at IS NULL
        """
    async with get_connection() as db:
        # execute() returns the command status tag of the statement.
        status = await db.execute(update_sql, session_id)
        return status == "UPDATE 1"
|
|
|
|
|
|
async def is_session_valid(session_id: UUID) -> bool:
    """Report whether a session exists, is not revoked, and has not expired.

    Expiry is evaluated by the database against its own ``now()``.
    """
    query = """
        SELECT 1 FROM auth_sessions
        WHERE id = $1
          AND revoked_at IS NULL
          AND expires_at > now()
        """
    async with get_connection() as db:
        match = await db.fetchrow(query, session_id)
        if match is None:
            return False
        return True
|
|
|
|
|
|
async def cleanup_expired_sessions():
    """Delete sessions that expired more than 7 days ago.

    Sessions whose ``expires_at`` is within the last 7 days are kept.
    Intended to be run periodically; returns nothing.
    """
    delete_sql = """
        DELETE FROM auth_sessions
        WHERE expires_at < now() - INTERVAL '7 days'
        """
    async with get_connection() as db:
        await db.execute(delete_sql)
|
|
|