limits.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. import logging
  2. from datetime import datetime, timedelta, timezone
  3. from typing import Optional
  4. from uuid import UUID
  5. from core.base import Handler
  6. from shared.abstractions import User
  7. from ..base.providers.database import DatabaseConfig, LimitSettings
  8. from .base import PostgresConnectionManager
  9. logger = logging.getLogger(__name__)
  10. class PostgresLimitsHandler(Handler):
  11. TABLE_NAME = "request_log"
  12. def __init__(
  13. self,
  14. project_name: str,
  15. connection_manager: PostgresConnectionManager,
  16. config: DatabaseConfig,
  17. ):
  18. """
  19. :param config: The global DatabaseConfig with default rate limits.
  20. """
  21. super().__init__(project_name, connection_manager)
  22. self.config = config # Contains e.g. self.config.limits for fallback
  23. logger.debug(
  24. f"Initialized PostgresLimitsHandler with project: {project_name}"
  25. )
  26. async def create_tables(self):
  27. query = f"""
  28. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
  29. time TIMESTAMPTZ NOT NULL,
  30. user_id UUID NOT NULL,
  31. route TEXT NOT NULL
  32. );
  33. """
  34. logger.debug("Creating request_log table if not exists")
  35. await self.connection_manager.execute_query(query)
  36. async def _count_requests(
  37. self,
  38. user_id: UUID,
  39. route: Optional[str],
  40. since: datetime,
  41. ) -> int:
  42. """
  43. Count how many requests a user (optionally for a specific route)
  44. has made since the given datetime.
  45. """
  46. if route:
  47. query = f"""
  48. SELECT COUNT(*)::int
  49. FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  50. WHERE user_id = $1
  51. AND route = $2
  52. AND time >= $3
  53. """
  54. params = [user_id, route, since]
  55. logger.debug(
  56. f"Counting requests for user={user_id}, route={route}"
  57. )
  58. else:
  59. query = f"""
  60. SELECT COUNT(*)::int
  61. FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  62. WHERE user_id = $1
  63. AND time >= $2
  64. """
  65. params = [user_id, since]
  66. logger.debug(f"Counting all requests for user={user_id}")
  67. result = await self.connection_manager.fetchrow_query(query, params)
  68. return result["count"] if result else 0
  69. async def _count_monthly_requests(self, user_id: UUID) -> int:
  70. """
  71. Count the number of requests so far this month for a given user.
  72. """
  73. now = datetime.now(timezone.utc)
  74. start_of_month = now.replace(
  75. day=1, hour=0, minute=0, second=0, microsecond=0
  76. )
  77. return await self._count_requests(user_id, None, start_of_month)
  78. return await self._count_requests(
  79. user_id, route=None, since=start_of_month
  80. )
  81. def _determine_limits_for(
  82. self, user_id: UUID, route: str
  83. ) -> LimitSettings:
  84. # Start with base limits
  85. limits = self.config.limits
  86. # Route-specific limits - directly override if present
  87. if route_limits := self.config.route_limits.get(route):
  88. # Only override non-None values from route_limits
  89. if route_limits.global_per_min is not None:
  90. limits.global_per_min = route_limits.global_per_min
  91. if route_limits.route_per_min is not None:
  92. limits.route_per_min = route_limits.route_per_min
  93. if route_limits.monthly_limit is not None:
  94. limits.monthly_limit = route_limits.monthly_limit
  95. # User-specific limits - directly override if present
  96. if user_limits := self.config.user_limits.get(user_id):
  97. # Only override non-None values from user_limits
  98. if user_limits.global_per_min is not None:
  99. limits.global_per_min = user_limits.global_per_min
  100. if user_limits.route_per_min is not None:
  101. limits.route_per_min = user_limits.route_per_min
  102. if user_limits.monthly_limit is not None:
  103. limits.monthly_limit = user_limits.monthly_limit
  104. return limits
  105. async def check_limits(self, user: User, route: str):
  106. """
  107. Perform rate limit checks for a user on a specific route.
  108. :param user: The fully-fetched User object with .limits_overrides, etc.
  109. :param route: The route/path being accessed.
  110. :raises ValueError: if any limit is exceeded.
  111. """
  112. user_id = user.id
  113. now = datetime.now(timezone.utc)
  114. one_min_ago = now - timedelta(minutes=1)
  115. # 1) First check route-specific configuration limits
  116. route_config = self.config.route_limits.get(route)
  117. if route_config:
  118. # Check route-specific per-minute limit
  119. if route_config.route_per_min is not None:
  120. route_req_count = await self._count_requests(
  121. user_id, route, one_min_ago
  122. )
  123. if route_req_count > route_config.route_per_min:
  124. logger.warning(
  125. f"Per-route per-minute limit exceeded for user_id={user_id}, route={route}"
  126. )
  127. raise ValueError(
  128. "Per-route per-minute rate limit exceeded"
  129. )
  130. # Check route-specific monthly limit
  131. if route_config.monthly_limit is not None:
  132. monthly_count = await self._count_monthly_requests(user_id)
  133. if monthly_count > route_config.monthly_limit:
  134. logger.warning(
  135. f"Route monthly limit exceeded for user_id={user_id}, route={route}"
  136. )
  137. raise ValueError("Route monthly limit exceeded")
  138. # 2) Get user overrides and base limits
  139. user_overrides = user.limits_overrides or {}
  140. base_limits = self.config.limits
  141. # Extract user-level overrides
  142. global_per_min = user_overrides.get(
  143. "global_per_min", base_limits.global_per_min
  144. )
  145. monthly_limit = user_overrides.get(
  146. "monthly_limit", base_limits.monthly_limit
  147. )
  148. # 3) Check route-specific overrides from user config
  149. route_overrides = user_overrides.get("route_overrides", {})
  150. specific_config = route_overrides.get(route, {})
  151. # Apply route-specific overrides for per-minute limits
  152. route_per_min = specific_config.get(
  153. "route_per_min", base_limits.route_per_min
  154. )
  155. # If route specifically overrides global or monthly limits, apply them
  156. if "global_per_min" in specific_config:
  157. global_per_min = specific_config["global_per_min"]
  158. if "monthly_limit" in specific_config:
  159. monthly_limit = specific_config["monthly_limit"]
  160. # 4) Check global per-minute limit
  161. if global_per_min is not None:
  162. user_req_count = await self._count_requests(
  163. user_id, None, one_min_ago
  164. )
  165. if user_req_count > global_per_min:
  166. logger.warning(
  167. f"Global per-minute limit exceeded for user_id={user_id}, route={route}"
  168. )
  169. raise ValueError("Global per-minute rate limit exceeded")
  170. # 5) Check user-specific route per-minute limit
  171. if route_per_min is not None:
  172. route_req_count = await self._count_requests(
  173. user_id, route, one_min_ago
  174. )
  175. if route_req_count > route_per_min:
  176. logger.warning(
  177. f"Per-route per-minute limit exceeded for user_id={user_id}, route={route}"
  178. )
  179. raise ValueError("Per-route per-minute rate limit exceeded")
  180. # 6) Check monthly limit
  181. if monthly_limit is not None:
  182. monthly_count = await self._count_monthly_requests(user_id)
  183. if monthly_count > monthly_limit:
  184. logger.warning(
  185. f"Monthly limit exceeded for user_id={user_id}, route={route}"
  186. )
  187. raise ValueError("Monthly rate limit exceeded")
  188. async def log_request(self, user_id: UUID, route: str):
  189. """
  190. Log a successful request to the request_log table.
  191. """
  192. query = f"""
  193. INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
  194. VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
  195. """
  196. await self.connection_manager.execute_query(query, [user_id, route])