import copy
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional
from uuid import UUID

from core.base import Handler
from shared.abstractions import User

from ..base.providers.database import DatabaseConfig, LimitSettings
from .base import PostgresConnectionManager

logger = logging.getLogger(__name__)


class PostgresLimitsHandler(Handler):
    """Rate-limit bookkeeping backed by a Postgres ``request_log`` table.

    Each logged request is one row of ``(time, user_id, route)``. Limits are
    enforced by counting a user's rows inside a time window and comparing
    the count with the configured thresholds.
    """

    TABLE_NAME = "request_log"

    # Names of the LimitSettings attributes that can be overridden.
    _LIMIT_FIELDS = ("global_per_min", "route_per_min", "monthly_limit")

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        config: DatabaseConfig,
    ):
        """
        :param project_name: Forwarded to the base ``Handler``.
        :param connection_manager: Used to run every query of this handler.
        :param config: The global DatabaseConfig with default rate limits.
        """
        super().__init__(project_name, connection_manager)
        self.config = config  # Contains e.g. self.config.limits for fallback
        logger.debug(
            f"Initialized PostgresLimitsHandler with project: {project_name}"
        )

    async def create_tables(self):
        """Create the request log table if it does not exist yet."""
        query = f"""
        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
            time TIMESTAMPTZ NOT NULL,
            user_id UUID NOT NULL,
            route TEXT NOT NULL
        );
        """
        logger.debug("Creating request_log table if not exists")
        await self.connection_manager.execute_query(query)

    async def _count_requests(
        self,
        user_id: UUID,
        route: Optional[str],
        since: datetime,
    ) -> int:
        """
        Count how many requests a user (optionally for a specific route)
        has made since the given datetime.

        :param user_id: The user whose log rows are counted.
        :param route: Restrict the count to this route; a falsy value
            (``None`` or ``""``) counts the user's requests on all routes.
        :param since: Inclusive lower bound on the row's ``time`` column.
        :return: The number of matching rows, or 0 if the query yields no row.
        """
        if route:
            query = f"""
            SELECT COUNT(*)::int
            FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
            WHERE user_id = $1
              AND route = $2
              AND time >= $3
            """
            params = [user_id, route, since]
            logger.debug(
                f"Counting requests for user={user_id}, route={route}"
            )
        else:
            query = f"""
            SELECT COUNT(*)::int
            FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
            WHERE user_id = $1
              AND time >= $2
            """
            params = [user_id, since]
            logger.debug(f"Counting all requests for user={user_id}")

        result = await self.connection_manager.fetchrow_query(query, params)
        return result["count"] if result else 0

    async def _count_monthly_requests(self, user_id: UUID) -> int:
        """
        Count the number of requests so far this month for a given user.

        The month starts at 00:00:00 UTC on day 1 of the current UTC month,
        and the count covers all routes.
        """
        now = datetime.now(timezone.utc)
        start_of_month = now.replace(
            day=1, hour=0, minute=0, second=0, microsecond=0
        )
        return await self._count_requests(
            user_id, route=None, since=start_of_month
        )

    def _determine_limits_for(
        self, user_id: UUID, route: str
    ) -> LimitSettings:
        """
        Resolve the effective limits for a user on a route.

        Precedence, lowest to highest: ``config.limits``, then
        ``config.route_limits[route]``, then ``config.user_limits[user_id]``.
        Only non-None override values replace the lower-priority value.

        :return: A new LimitSettings object; the shared configuration is
            left untouched.
        """
        # Work on a copy: assigning onto self.config.limits directly would
        # permanently leak one caller's overrides into the global defaults.
        limits = copy.copy(self.config.limits)

        # Applied in order, so user-specific values win over route-specific.
        overrides_by_priority = (
            self.config.route_limits.get(route),
            self.config.user_limits.get(user_id),
        )
        for overrides in overrides_by_priority:
            if not overrides:
                continue
            for field in self._LIMIT_FIELDS:
                value = getattr(overrides, field)
                if value is not None:
                    setattr(limits, field, value)

        return limits

    async def check_limits(self, user: User, route: str):
        """
        Perform rate limit checks for a user on a specific route.

        Checks run in this order and stop at the first violation:
        route-level configuration limits (per-minute, then monthly), then
        the limits resolved from the user's overrides and the base config
        (global per-minute, per-route per-minute, monthly).

        :param user: The fully-fetched User object with .limits_overrides, etc.
        :param route: The route/path being accessed.
        :raises ValueError: if any limit is exceeded.
        """
        # NOTE(review): every comparison below is ``count > limit``. If this
        # runs before the current request is logged, ``limit + 1`` requests
        # pass per window. Confirm the intended semantics with the caller.
        user_id = user.id
        now = datetime.now(timezone.utc)
        one_min_ago = now - timedelta(minutes=1)

        # 1) First check route-specific configuration limits
        route_config = self.config.route_limits.get(route)
        if route_config:
            # Check route-specific per-minute limit
            if route_config.route_per_min is not None:
                route_req_count = await self._count_requests(
                    user_id, route, one_min_ago
                )
                if route_req_count > route_config.route_per_min:
                    logger.warning(
                        f"Per-route per-minute limit exceeded for user_id={user_id}, route={route}"
                    )
                    raise ValueError(
                        "Per-route per-minute rate limit exceeded"
                    )

            # Check route-specific monthly limit
            # NOTE(review): this count covers the user's requests on ALL
            # routes this month, not just `route` - confirm this is intended.
            if route_config.monthly_limit is not None:
                monthly_count = await self._count_monthly_requests(user_id)
                if monthly_count > route_config.monthly_limit:
                    logger.warning(
                        f"Route monthly limit exceeded for user_id={user_id}, route={route}"
                    )
                    raise ValueError("Route monthly limit exceeded")

        # 2) Get user overrides and base limits
        user_overrides = user.limits_overrides or {}
        base_limits = self.config.limits

        # Extract user-level overrides
        global_per_min = user_overrides.get(
            "global_per_min", base_limits.global_per_min
        )
        monthly_limit = user_overrides.get(
            "monthly_limit", base_limits.monthly_limit
        )

        # 3) Check route-specific overrides from user config.
        # `or {}` tolerates keys that are present but explicitly None.
        route_overrides = user_overrides.get("route_overrides") or {}
        specific_config = route_overrides.get(route) or {}

        # Apply route-specific overrides for per-minute limits
        route_per_min = specific_config.get(
            "route_per_min", base_limits.route_per_min
        )

        # If route specifically overrides global or monthly limits, apply them
        if "global_per_min" in specific_config:
            global_per_min = specific_config["global_per_min"]
        if "monthly_limit" in specific_config:
            monthly_limit = specific_config["monthly_limit"]

        # 4) Check global per-minute limit
        if global_per_min is not None:
            user_req_count = await self._count_requests(
                user_id, None, one_min_ago
            )
            if user_req_count > global_per_min:
                logger.warning(
                    f"Global per-minute limit exceeded for user_id={user_id}, route={route}"
                )
                raise ValueError("Global per-minute rate limit exceeded")

        # 5) Check user-specific route per-minute limit
        if route_per_min is not None:
            route_req_count = await self._count_requests(
                user_id, route, one_min_ago
            )
            if route_req_count > route_per_min:
                logger.warning(
                    f"Per-route per-minute limit exceeded for user_id={user_id}, route={route}"
                )
                raise ValueError("Per-route per-minute rate limit exceeded")

        # 6) Check monthly limit
        if monthly_limit is not None:
            monthly_count = await self._count_monthly_requests(user_id)
            if monthly_count > monthly_limit:
                logger.warning(
                    f"Monthly limit exceeded for user_id={user_id}, route={route}"
                )
                raise ValueError("Monthly rate limit exceeded")

    async def log_request(self, user_id: UUID, route: str):
        """
        Log a successful request to the request_log table.

        :param user_id: The user who made the request.
        :param route: The route/path that was accessed.
        """
        # Insert CURRENT_TIMESTAMP as-is: the column is TIMESTAMPTZ, and
        # `CURRENT_TIMESTAMP AT TIME ZONE 'UTC'` would yield a zone-less
        # timestamp that Postgres re-interprets in the session time zone,
        # shifting the stored instant whenever that zone is not UTC.
        query = f"""
        INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
        (time, user_id, route)
        VALUES (CURRENT_TIMESTAMP, $1, $2)
        """
        await self.connection_manager.execute_query(query, [user_id, route])