limits.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. import logging
  2. from datetime import datetime, timedelta, timezone
  3. from typing import Optional
  4. from uuid import UUID
  5. from core.base import Handler
  6. from shared.abstractions import User # your domain user model
  7. from ..base.providers.database import DatabaseConfig, LimitSettings
  8. from .base import PostgresConnectionManager
  9. logger = logging.getLogger(__name__)
  10. class PostgresLimitsHandler(Handler):
  11. TABLE_NAME = "request_log"
  12. def __init__(
  13. self,
  14. project_name: str,
  15. connection_manager: PostgresConnectionManager,
  16. config: DatabaseConfig,
  17. ):
  18. """
  19. :param config: The global DatabaseConfig with default rate limits.
  20. """
  21. super().__init__(project_name, connection_manager)
  22. self.config = config # Contains e.g. self.config.limits for fallback
  23. logger.debug(
  24. f"Initialized PostgresLimitsHandler with project: {project_name}"
  25. )
  26. async def create_tables(self):
  27. query = f"""
  28. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
  29. time TIMESTAMPTZ NOT NULL,
  30. user_id UUID NOT NULL,
  31. route TEXT NOT NULL
  32. );
  33. """
  34. logger.debug("Creating request_log table if not exists")
  35. await self.connection_manager.execute_query(query)
  36. async def _count_requests(
  37. self,
  38. user_id: UUID,
  39. route: Optional[str],
  40. since: datetime,
  41. ) -> int:
  42. """
  43. Count how many requests a user (optionally for a specific route)
  44. has made since the given datetime.
  45. """
  46. if route:
  47. query = f"""
  48. SELECT COUNT(*)::int
  49. FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  50. WHERE user_id = $1
  51. AND route = $2
  52. AND time >= $3
  53. """
  54. params = [user_id, route, since]
  55. logger.debug(
  56. f"Counting requests for user={user_id}, route={route}"
  57. )
  58. else:
  59. query = f"""
  60. SELECT COUNT(*)::int
  61. FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  62. WHERE user_id = $1
  63. AND time >= $2
  64. """
  65. params = [user_id, since]
  66. logger.debug(f"Counting all requests for user={user_id}")
  67. result = await self.connection_manager.fetchrow_query(query, params)
  68. return result["count"] if result else 0
  69. async def _count_monthly_requests(self, user_id: UUID) -> int:
  70. """
  71. Count the number of requests so far this month for a given user.
  72. """
  73. now = datetime.now(timezone.utc)
  74. start_of_month = now.replace(
  75. day=1, hour=0, minute=0, second=0, microsecond=0
  76. )
  77. return await self._count_requests(user_id, None, start_of_month)
  78. async def check_limits(self, user: User, route: str):
  79. """
  80. Perform rate limit checks for a user on a specific route.
  81. :param user: The fully-fetched User object with .limits_overrides, etc.
  82. :param route: The route/path being accessed.
  83. :raises ValueError: if any limit is exceeded.
  84. """
  85. user_id = user.id
  86. now = datetime.now(timezone.utc)
  87. one_min_ago = now - timedelta(minutes=1)
  88. # 1) First check route-specific configuration limits
  89. route_config = self.config.route_limits.get(route)
  90. if route_config:
  91. # Check route-specific per-minute limit
  92. if route_config.route_per_min is not None:
  93. route_req_count = await self._count_requests(
  94. user_id, route, one_min_ago
  95. )
  96. if route_req_count > route_config.route_per_min:
  97. logger.warning(
  98. f"Per-route per-minute limit exceeded for user_id={user_id}, route={route}"
  99. )
  100. raise ValueError(
  101. "Per-route per-minute rate limit exceeded"
  102. )
  103. # Check route-specific monthly limit
  104. if route_config.monthly_limit is not None:
  105. monthly_count = await self._count_monthly_requests(user_id)
  106. if monthly_count > route_config.monthly_limit:
  107. logger.warning(
  108. f"Route monthly limit exceeded for user_id={user_id}, route={route}"
  109. )
  110. raise ValueError("Route monthly limit exceeded")
  111. # 2) Get user overrides and base limits
  112. user_overrides = user.limits_overrides or {}
  113. base_limits = self.config.limits
  114. # Extract user-level overrides
  115. global_per_min = user_overrides.get(
  116. "global_per_min", base_limits.global_per_min
  117. )
  118. monthly_limit = user_overrides.get(
  119. "monthly_limit", base_limits.monthly_limit
  120. )
  121. # 3) Check route-specific overrides from user config
  122. route_overrides = user_overrides.get("route_overrides", {})
  123. specific_config = route_overrides.get(route, {})
  124. # Apply route-specific overrides for per-minute limits
  125. route_per_min = specific_config.get(
  126. "route_per_min", base_limits.route_per_min
  127. )
  128. # If route specifically overrides global or monthly limits, apply them
  129. if "global_per_min" in specific_config:
  130. global_per_min = specific_config["global_per_min"]
  131. if "monthly_limit" in specific_config:
  132. monthly_limit = specific_config["monthly_limit"]
  133. # 4) Check global per-minute limit
  134. if global_per_min is not None:
  135. user_req_count = await self._count_requests(
  136. user_id, None, one_min_ago
  137. )
  138. if user_req_count > global_per_min:
  139. logger.warning(
  140. f"Global per-minute limit exceeded for user_id={user_id}, route={route}"
  141. )
  142. raise ValueError("Global per-minute rate limit exceeded")
  143. # 5) Check user-specific route per-minute limit
  144. if route_per_min is not None:
  145. route_req_count = await self._count_requests(
  146. user_id, route, one_min_ago
  147. )
  148. if route_req_count > route_per_min:
  149. logger.warning(
  150. f"Per-route per-minute limit exceeded for user_id={user_id}, route={route}"
  151. )
  152. raise ValueError("Per-route per-minute rate limit exceeded")
  153. # 6) Check monthly limit
  154. if monthly_limit is not None:
  155. monthly_count = await self._count_monthly_requests(user_id)
  156. if monthly_count > monthly_limit:
  157. logger.warning(
  158. f"Monthly limit exceeded for user_id={user_id}, route={route}"
  159. )
  160. raise ValueError("Monthly rate limit exceeded")
  161. async def log_request(self, user_id: UUID, route: str):
  162. """
  163. Log a successful request to the request_log table.
  164. """
  165. query = f"""
  166. INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
  167. VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
  168. """
  169. await self.connection_manager.execute_query(query, [user_id, route])