limits.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. import logging
  2. from datetime import datetime, timedelta, timezone
  3. from typing import Optional
  4. from uuid import UUID
  5. from core.base import Handler
  6. from ..base.providers.database import DatabaseConfig, LimitSettings
  7. from .base import PostgresConnectionManager
  8. logger = logging.getLogger(__name__)
  9. class PostgresLimitsHandler(Handler):
  10. TABLE_NAME = "request_log"
  11. def __init__(
  12. self,
  13. project_name: str,
  14. connection_manager: PostgresConnectionManager,
  15. config: DatabaseConfig,
  16. ):
  17. super().__init__(project_name, connection_manager)
  18. self.config = config
  19. logger.debug(
  20. f"Initialized PostgresLimitsHandler with project: {project_name}"
  21. )
  22. async def create_tables(self):
  23. query = f"""
  24. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
  25. time TIMESTAMPTZ NOT NULL,
  26. user_id UUID NOT NULL,
  27. route TEXT NOT NULL
  28. );
  29. """
  30. logger.debug("Creating request_log table if not exists")
  31. await self.connection_manager.execute_query(query)
  32. async def _count_requests(
  33. self, user_id: UUID, route: Optional[str], since: datetime
  34. ) -> int:
  35. if route:
  36. query = f"""
  37. SELECT COUNT(*)::int
  38. FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  39. WHERE user_id = $1
  40. AND route = $2
  41. AND time >= $3
  42. """
  43. params = [user_id, route, since]
  44. logger.debug(f"Counting requests for route {route}")
  45. else:
  46. query = f"""
  47. SELECT COUNT(*)::int
  48. FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  49. WHERE user_id = $1
  50. AND time >= $2
  51. """
  52. params = [user_id, since]
  53. logger.debug("Counting all requests")
  54. result = await self.connection_manager.fetchrow_query(query, params)
  55. count = result["count"] if result else 0
  56. return count
  57. async def _count_monthly_requests(self, user_id: UUID) -> int:
  58. now = datetime.now(timezone.utc)
  59. start_of_month = now.replace(
  60. day=1, hour=0, minute=0, second=0, microsecond=0
  61. )
  62. count = await self._count_requests(
  63. user_id, route=None, since=start_of_month
  64. )
  65. return count
  66. def _determine_limits_for(
  67. self, user_id: UUID, route: str
  68. ) -> LimitSettings:
  69. # Start with base limits
  70. limits = self.config.limits
  71. # Route-specific limits - directly override if present
  72. route_limits = self.config.route_limits.get(route)
  73. if route_limits:
  74. # Only override non-None values from route_limits
  75. if route_limits.global_per_min is not None:
  76. limits.global_per_min = route_limits.global_per_min
  77. if route_limits.route_per_min is not None:
  78. limits.route_per_min = route_limits.route_per_min
  79. if route_limits.monthly_limit is not None:
  80. limits.monthly_limit = route_limits.monthly_limit
  81. # User-specific limits - directly override if present
  82. user_limits = self.config.user_limits.get(user_id)
  83. if user_limits:
  84. # Only override non-None values from user_limits
  85. if user_limits.global_per_min is not None:
  86. limits.global_per_min = user_limits.global_per_min
  87. if user_limits.route_per_min is not None:
  88. limits.route_per_min = user_limits.route_per_min
  89. if user_limits.monthly_limit is not None:
  90. limits.monthly_limit = user_limits.monthly_limit
  91. return limits
  92. async def check_limits(self, user_id: UUID, route: str):
  93. # Determine final applicable limits
  94. limits = self._determine_limits_for(user_id, route)
  95. if not limits:
  96. limits = self.config.default_limits
  97. global_per_min = limits.global_per_min
  98. route_per_min = limits.route_per_min
  99. monthly_limit = limits.monthly_limit
  100. now = datetime.now(timezone.utc)
  101. one_min_ago = now - timedelta(minutes=1)
  102. # Global per-minute check
  103. if global_per_min is not None:
  104. user_req_count = await self._count_requests(
  105. user_id, None, one_min_ago
  106. )
  107. if user_req_count > global_per_min:
  108. logger.warning(
  109. f"Global per-minute limit exceeded for user_id={user_id}, route={route}"
  110. )
  111. raise ValueError("Global per-minute rate limit exceeded")
  112. # Per-route per-minute check
  113. if route_per_min is not None:
  114. route_req_count = await self._count_requests(
  115. user_id, route, one_min_ago
  116. )
  117. if route_req_count > route_per_min:
  118. logger.warning(
  119. f"Per-route per-minute limit exceeded for user_id={user_id}, route={route}"
  120. )
  121. raise ValueError("Per-route per-minute rate limit exceeded")
  122. # Monthly limit check
  123. if monthly_limit is not None:
  124. monthly_count = await self._count_monthly_requests(user_id)
  125. if monthly_count > monthly_limit:
  126. logger.warning(
  127. f"Monthly limit exceeded for user_id={user_id}, route={route}"
  128. )
  129. raise ValueError("Monthly rate limit exceeded")
  130. async def log_request(self, user_id: UUID, route: str):
  131. query = f"""
  132. INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
  133. VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
  134. """
  135. await self.connection_manager.execute_query(query, [user_id, route])