limits.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. import logging
  2. from datetime import datetime, timedelta, timezone
  3. from typing import Optional
  4. from uuid import UUID
  5. from core.base import Handler, R2RException
  6. from .base import PostgresConnectionManager
  7. logger = logging.getLogger()
  8. class PostgresLimitsHandler(Handler):
  9. TABLE_NAME = "request_log"
  10. def __init__(
  11. self,
  12. project_name: str,
  13. connection_manager: PostgresConnectionManager,
  14. route_limits: dict,
  15. ):
  16. super().__init__(project_name, connection_manager)
  17. self.route_limits = route_limits
  18. async def create_tables(self):
  19. query = f"""
  20. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
  21. time TIMESTAMPTZ NOT NULL,
  22. user_id UUID NOT NULL,
  23. route TEXT NOT NULL
  24. );
  25. """
  26. await self.connection_manager.execute_query(query)
  27. async def _count_requests(
  28. self, user_id: UUID, route: Optional[str], since: datetime
  29. ) -> int:
  30. if route:
  31. query = f"""
  32. SELECT COUNT(*)::int
  33. FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  34. WHERE user_id = $1
  35. AND route = $2
  36. AND time >= $3
  37. """
  38. params = [user_id, route, since]
  39. else:
  40. query = f"""
  41. SELECT COUNT(*)::int
  42. FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  43. WHERE user_id = $1
  44. AND time >= $2
  45. """
  46. params = [user_id, since]
  47. result = await self.connection_manager.fetchrow_query(query, params)
  48. return result["count"] if result else 0
  49. async def _count_monthly_requests(self, user_id: UUID) -> int:
  50. now = datetime.now(timezone.utc)
  51. start_of_month = now.replace(
  52. day=1, hour=0, minute=0, second=0, microsecond=0
  53. )
  54. return await self._count_requests(
  55. user_id, route=None, since=start_of_month
  56. )
  57. async def check_limits(self, user_id: UUID, route: str):
  58. limits = self.route_limits.get(
  59. route,
  60. {
  61. "global_per_min": 60,
  62. "route_per_min": 30,
  63. "monthly_limit": 10000,
  64. },
  65. )
  66. global_per_min = limits["global_per_min"]
  67. route_per_min = limits["route_per_min"]
  68. monthly_limit = limits["monthly_limit"]
  69. now = datetime.now(timezone.utc)
  70. one_min_ago = now - timedelta(minutes=1)
  71. # Global per-minute check
  72. user_req_count = await self._count_requests(user_id, None, one_min_ago)
  73. print("min req count = ", user_req_count)
  74. if user_req_count >= global_per_min:
  75. raise ValueError("Global per-minute rate limit exceeded")
  76. # Per-route per-minute check
  77. route_req_count = await self._count_requests(
  78. user_id, route, one_min_ago
  79. )
  80. if route_req_count >= route_per_min:
  81. raise ValueError("Per-route per-minute rate limit exceeded")
  82. # Monthly limit check
  83. monthly_count = await self._count_monthly_requests(user_id)
  84. print("monthly_count = ", monthly_count)
  85. if monthly_count >= monthly_limit:
  86. raise ValueError("Monthly rate limit exceeded")
  87. async def log_request(self, user_id: UUID, route: str):
  88. query = f"""
  89. INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
  90. VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
  91. """
  92. await self.connection_manager.execute_query(query, [user_id, route])
  93. # import logging
  94. # from datetime import datetime, timedelta
  95. # from typing import Optional
  96. # from uuid import UUID
  97. # from core.base import Handler, R2RException
  98. # from .base import PostgresConnectionManager
  99. # logger = logging.getLogger()
  100. # class PostgresLimitsHandler(Handler):
  101. # TABLE_NAME = "request_log"
  102. # def __init__(
  103. # self,
  104. # project_name: str,
  105. # connection_manager: PostgresConnectionManager,
  106. # route_limits: dict,
  107. # ):
  108. # super().__init__(project_name, connection_manager)
  109. # self.route_limits = route_limits
  110. # async def create_tables(self):
  111. # """
  112. # Create the request_log table if it doesn't exist.
  113. # """
  114. # query = f"""
  115. # CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
  116. # time TIMESTAMPTZ NOT NULL,
  117. # user_id UUID NOT NULL,
  118. # route TEXT NOT NULL
  119. # );
  120. # """
  121. # await self.connection_manager.execute_query(query)
  122. # async def _count_requests(
  123. # self, user_id: UUID, route: Optional[str], since: datetime
  124. # ) -> int:
  125. # if route:
  126. # query = f"""
  127. # SELECT COUNT(*)::int
  128. # FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  129. # WHERE user_id = $1
  130. # AND route = $2
  131. # AND time >= $3
  132. # """
  133. # params = [user_id, route, since]
  134. # else:
  135. # query = f"""
  136. # SELECT COUNT(*)::int
  137. # FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  138. # WHERE user_id = $1
  139. # AND time >= $2
  140. # """
  141. # params = [user_id, since]
  142. # result = await self.connection_manager.fetchrow_query(query, params)
  143. # return result["count"] if result else 0
  144. # async def _count_monthly_requests(self, user_id: UUID) -> int:
  145. # now = datetime.utcnow()
  146. # start_of_month = now.replace(
  147. # day=1, hour=0, minute=0, second=0, microsecond=0
  148. # )
  149. # return await self._count_requests(
  150. # user_id, route=None, since=start_of_month
  151. # )
  152. # async def check_limits(self, user_id: UUID, route: str):
  153. # """
  154. # Check if the user can proceed with the request, using route-specific limits.
  155. # Raises ValueError if the user exceeded any limit.
  156. # """
  157. # limits = self.route_limits.get(
  158. # route,
  159. # {
  160. # "global_per_min": 60, # default global per min
  161. # "route_per_min": 20, # default route per min
  162. # "monthly_limit": 10000, # default monthly limit
  163. # },
  164. # )
  165. # global_per_min = limits["global_per_min"]
  166. # route_per_min = limits["route_per_min"]
  167. # monthly_limit = limits["monthly_limit"]
  168. # now = datetime.utcnow()
  169. # one_min_ago = now - timedelta(minutes=1)
  170. # # Global per-minute check
  171. # user_req_count = await self._count_requests(user_id, None, one_min_ago)
  172. # print('min req count = ', user_req_count)
  173. # if user_req_count >= global_per_min:
  174. # raise ValueError("Global per-minute rate limit exceeded")
  175. # # Per-route per-minute check
  176. # route_req_count = await self._count_requests(
  177. # user_id, route, one_min_ago
  178. # )
  179. # if route_req_count >= route_per_min:
  180. # raise ValueError("Per-route per-minute rate limit exceeded")
  181. # # Monthly limit check
  182. # monthly_count = await self._count_monthly_requests(user_id)
  183. # print('monthly_count = ', monthly_count)
  184. # if monthly_count >= monthly_limit:
  185. # raise ValueError("Monthly rate limit exceeded")
  186. # async def log_request(self, user_id: UUID, route: str):
  187. # query = f"""
  188. # INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
  189. # VALUES (NOW(), $1, $2)
  190. # """
  191. # await self.connection_manager.execute_query(query, [user_id, route])