limits.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. import logging
  2. from datetime import datetime, timedelta, timezone
  3. from typing import Optional
  4. from uuid import UUID
  5. from core.base import Handler
  6. from shared.abstractions import User
  7. from ...base.providers.database import DatabaseConfig, LimitSettings
  8. from .base import PostgresConnectionManager
  9. logger = logging.getLogger(__name__)
  10. class PostgresLimitsHandler(Handler):
  11. TABLE_NAME = "request_log"
  12. def __init__(
  13. self,
  14. project_name: str,
  15. connection_manager: PostgresConnectionManager,
  16. config: DatabaseConfig,
  17. ):
  18. """
  19. :param config: The global DatabaseConfig with default rate limits.
  20. """
  21. super().__init__(project_name, connection_manager)
  22. self.config = config
  23. logger.debug(
  24. f"Initialized PostgresLimitsHandler with project: {project_name}"
  25. )
  26. async def create_tables(self):
  27. query = f"""
  28. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
  29. time TIMESTAMPTZ NOT NULL,
  30. user_id UUID NOT NULL,
  31. route TEXT NOT NULL
  32. );
  33. """
  34. logger.debug("Creating request_log table if not exists")
  35. await self.connection_manager.execute_query(query)
  36. async def _count_requests(
  37. self,
  38. user_id: UUID,
  39. route: Optional[str],
  40. since: datetime,
  41. ) -> int:
  42. """Count how many requests a user (optionally for a specific route) has
  43. made since the given datetime."""
  44. if route:
  45. query = f"""
  46. SELECT COUNT(*)::int
  47. FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  48. WHERE user_id = $1
  49. AND route = $2
  50. AND time >= $3
  51. """
  52. params = [user_id, route, since]
  53. logger.debug(
  54. f"Counting requests for user={user_id}, route={route}"
  55. )
  56. else:
  57. query = f"""
  58. SELECT COUNT(*)::int
  59. FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  60. WHERE user_id = $1
  61. AND time >= $2
  62. """
  63. params = [user_id, since]
  64. logger.debug(f"Counting all requests for user={user_id}")
  65. result = await self.connection_manager.fetchrow_query(query, params)
  66. return result["count"] if result else 0
  67. async def _count_monthly_requests(
  68. self,
  69. user_id: UUID,
  70. route: Optional[str] = None, # <--- ADDED THIS
  71. ) -> int:
  72. """Count the number of requests so far this month for a given user.
  73. If route is provided, count only for that route. Otherwise, count
  74. globally.
  75. """
  76. now = datetime.now(timezone.utc)
  77. start_of_month = now.replace(
  78. day=1, hour=0, minute=0, second=0, microsecond=0
  79. )
  80. return await self._count_requests(
  81. user_id, route=route, since=start_of_month
  82. )
  83. def determine_effective_limits(
  84. self, user: User, route: str
  85. ) -> LimitSettings:
  86. """
  87. Determine the final effective limits for a user+route combination,
  88. respecting:
  89. 1) Global defaults
  90. 2) Route-specific overrides
  91. 3) User-level overrides
  92. """
  93. # ------------------------
  94. # 1) Start with global/base
  95. # ------------------------
  96. base_limits = self.config.limits
  97. # We’ll make a copy so we don’t mutate self.config.limits directly
  98. effective = LimitSettings(
  99. global_per_min=base_limits.global_per_min,
  100. route_per_min=base_limits.route_per_min,
  101. monthly_limit=base_limits.monthly_limit,
  102. )
  103. # ------------------------
  104. # 2) Route-level overrides
  105. # ------------------------
  106. route_config = self.config.route_limits.get(route)
  107. if route_config:
  108. if route_config.global_per_min is not None:
  109. effective.global_per_min = route_config.global_per_min
  110. if route_config.route_per_min is not None:
  111. effective.route_per_min = route_config.route_per_min
  112. if route_config.monthly_limit is not None:
  113. effective.monthly_limit = route_config.monthly_limit
  114. # ------------------------
  115. # 3) User-level overrides
  116. # ------------------------
  117. # The user object might have a dictionary of overrides
  118. # which can include route_overrides, global_per_min, monthly_limit, etc.
  119. user_overrides = user.limits_overrides or {}
  120. # (a) "global" user overrides
  121. if user_overrides.get("global_per_min") is not None:
  122. effective.global_per_min = user_overrides["global_per_min"]
  123. if user_overrides.get("monthly_limit") is not None:
  124. effective.monthly_limit = user_overrides["monthly_limit"]
  125. # (b) route-level user overrides
  126. route_overrides = user_overrides.get("route_overrides", {})
  127. specific_config = route_overrides.get(route, {})
  128. if specific_config.get("global_per_min") is not None:
  129. effective.global_per_min = specific_config["global_per_min"]
  130. if specific_config.get("route_per_min") is not None:
  131. effective.route_per_min = specific_config["route_per_min"]
  132. if specific_config.get("monthly_limit") is not None:
  133. effective.monthly_limit = specific_config["monthly_limit"]
  134. return effective
  135. async def check_limits(self, user: User, route: str):
  136. """Perform rate limit checks for a user on a specific route.
  137. :param user: The fully-fetched User object with .limits_overrides, etc.
  138. :param route: The route/path being accessed.
  139. :raises ValueError: if any limit is exceeded.
  140. """
  141. user_id = user.id
  142. now = datetime.now(timezone.utc)
  143. one_min_ago = now - timedelta(minutes=1)
  144. # 1) Compute the final (effective) limits for this user & route
  145. limits = self.determine_effective_limits(user, route)
  146. # 2) Check each of them in turn, if they exist
  147. # ------------------------------------------------------------
  148. # Global per-minute limit
  149. # ------------------------------------------------------------
  150. if limits.global_per_min is not None:
  151. user_req_count = await self._count_requests(
  152. user_id, None, one_min_ago
  153. )
  154. if user_req_count > limits.global_per_min:
  155. logger.warning(
  156. f"Global per-minute limit exceeded for "
  157. f"user_id={user_id}, route={route}"
  158. )
  159. raise ValueError("Global per-minute rate limit exceeded")
  160. # ------------------------------------------------------------
  161. # Route-specific per-minute limit
  162. # ------------------------------------------------------------
  163. if limits.route_per_min is not None:
  164. route_req_count = await self._count_requests(
  165. user_id, route, one_min_ago
  166. )
  167. if route_req_count > limits.route_per_min:
  168. logger.warning(
  169. f"Per-route per-minute limit exceeded for "
  170. f"user_id={user_id}, route={route}"
  171. )
  172. raise ValueError("Per-route per-minute rate limit exceeded")
  173. # ------------------------------------------------------------
  174. # Monthly limit
  175. # ------------------------------------------------------------
  176. if limits.monthly_limit is not None:
  177. # If you truly want a per-route monthly limit, we pass 'route'.
  178. # If you want a global monthly limit, pass 'None'.
  179. monthly_count = await self._count_monthly_requests(user_id, route)
  180. if monthly_count > limits.monthly_limit:
  181. logger.warning(
  182. f"Monthly limit exceeded for user_id={user_id}, "
  183. f"route={route}"
  184. )
  185. raise ValueError("Monthly rate limit exceeded")
  186. async def log_request(self, user_id: UUID, route: str):
  187. """Log a successful request to the request_log table."""
  188. query = f"""
  189. INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  190. (time, user_id, route)
  191. VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
  192. """
  193. await self.connection_manager.execute_query(query, [user_id, route])
  194. # import logging
  195. # from datetime import datetime, timedelta, timezone
  196. # from typing import Optional
  197. # from uuid import UUID
  198. # from core.base import Handler
  199. # from shared.abstractions import User
  200. # from ..base.providers.database import DatabaseConfig, LimitSettings
  201. # from .base import PostgresConnectionManager
  202. # logger = logging.getLogger(__name__)
  203. # class PostgresLimitsHandler(Handler):
  204. # TABLE_NAME = "request_log"
  205. # def __init__(
  206. # self,
  207. # project_name: str,
  208. # connection_manager: PostgresConnectionManager,
  209. # config: DatabaseConfig,
  210. # ):
  211. # """
  212. # :param config: The global DatabaseConfig with default rate limits.
  213. # """
  214. # super().__init__(project_name, connection_manager)
  215. # self.config = config
  216. # logger.debug(
  217. # f"Initialized PostgresLimitsHandler with project: {project_name}"
  218. # )
  219. # async def create_tables(self):
  220. # query = f"""
  221. # CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
  222. # time TIMESTAMPTZ NOT NULL,
  223. # user_id UUID NOT NULL,
  224. # route TEXT NOT NULL
  225. # );
  226. # """
  227. # logger.debug("Creating request_log table if not exists")
  228. # await self.connection_manager.execute_query(query)
  229. # async def _count_requests(
  230. # self,
  231. # user_id: UUID,
  232. # route: Optional[str],
  233. # since: datetime,
  234. # ) -> int:
  235. # """
  236. # Count how many requests a user (optionally for a specific route)
  237. # has made since the given datetime.
  238. # """
  239. # if route:
  240. # query = f"""
  241. # SELECT COUNT(*)::int
  242. # FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  243. # WHERE user_id = $1
  244. # AND route = $2
  245. # AND time >= $3
  246. # """
  247. # params = [user_id, route, since]
  248. # logger.debug(f"Counting requests for user={user_id}, route={route}")
  249. # else:
  250. # query = f"""
  251. # SELECT COUNT(*)::int
  252. # FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  253. # WHERE user_id = $1
  254. # AND time >= $2
  255. # """
  256. # params = [user_id, since]
  257. # logger.debug(f"Counting all requests for user={user_id}")
  258. # result = await self.connection_manager.fetchrow_query(query, params)
  259. # return result["count"] if result else 0
  260. # async def _count_monthly_requests(self, user_id: UUID) -> int:
  261. # """
  262. # Count the number of requests so far this month for a given user.
  263. # """
  264. # now = datetime.now(timezone.utc)
  265. # start_of_month = now.replace(
  266. # day=1, hour=0, minute=0, second=0, microsecond=0
  267. # )
  268. # return await self._count_requests(
  269. # user_id, route=None, since=start_of_month
  270. # )
  271. # def determine_effective_limits(
  272. # self, user: User, route: str
  273. # ) -> LimitSettings:
  274. # """
  275. # Determine the final effective limits for a user+route combination,
  276. # respecting:
  277. # 1) Global defaults
  278. # 2) Route-specific overrides
  279. # 3) User-level overrides
  280. # """
  281. # # ------------------------
  282. # # 1) Start with global/base
  283. # # ------------------------
  284. # base_limits = self.config.limits
  285. # # We’ll make a copy so we don’t mutate self.config.limits directly
  286. # effective = LimitSettings(
  287. # global_per_min=base_limits.global_per_min,
  288. # route_per_min=base_limits.route_per_min,
  289. # monthly_limit=base_limits.monthly_limit,
  290. # )
  291. # # ------------------------
  292. # # 2) Route-level overrides
  293. # # ------------------------
  294. # route_config = self.config.route_limits.get(route)
  295. # if route_config:
  296. # if route_config.global_per_min is not None:
  297. # effective.global_per_min = route_config.global_per_min
  298. # if route_config.route_per_min is not None:
  299. # effective.route_per_min = route_config.route_per_min
  300. # if route_config.monthly_limit is not None:
  301. # effective.monthly_limit = route_config.monthly_limit
  302. # # ------------------------
  303. # # 3) User-level overrides
  304. # # ------------------------
  305. # # The user object might have a dictionary of overrides
  306. # # which can include route_overrides, global_per_min, monthly_limit, etc.
  307. # user_overrides = user.limits_overrides or {}
  308. # # (a) "global" user overrides
  309. # if user_overrides.get("global_per_min") is not None:
  310. # effective.global_per_min = user_overrides["global_per_min"]
  311. # if user_overrides.get("monthly_limit") is not None:
  312. # effective.monthly_limit = user_overrides["monthly_limit"]
  313. # # (b) route-level user overrides
  314. # route_overrides = user_overrides.get("route_overrides", {})
  315. # specific_config = route_overrides.get(route, {})
  316. # if specific_config.get("global_per_min") is not None:
  317. # effective.global_per_min = specific_config["global_per_min"]
  318. # if specific_config.get("route_per_min") is not None:
  319. # effective.route_per_min = specific_config["route_per_min"]
  320. # if specific_config.get("monthly_limit") is not None:
  321. # effective.monthly_limit = specific_config["monthly_limit"]
  322. # return effective
  323. # async def check_limits(self, user: User, route: str):
  324. # """
  325. # Perform rate limit checks for a user on a specific route.
  326. # :param user: The fully-fetched User object with .limits_overrides, etc.
  327. # :param route: The route/path being accessed.
  328. # :raises ValueError: if any limit is exceeded.
  329. # """
  330. # user_id = user.id
  331. # now = datetime.now(timezone.utc)
  332. # one_min_ago = now - timedelta(minutes=1)
  333. # # 1) Compute the final (effective) limits for this user & route
  334. # limits = self.determine_effective_limits(user, route)
  335. # # 2) Check each of them in turn, if they exist
  336. # # ------------------------------------------------------------
  337. # # Global per-minute limit
  338. # # ------------------------------------------------------------
  339. # if limits.global_per_min is not None:
  340. # user_req_count = await self._count_requests(
  341. # user_id, None, one_min_ago
  342. # )
  343. # if user_req_count > limits.global_per_min:
  344. # logger.warning(
  345. # f"Global per-minute limit exceeded for "
  346. # f"user_id={user_id}, route={route}"
  347. # )
  348. # raise ValueError("Global per-minute rate limit exceeded")
  349. # # ------------------------------------------------------------
  350. # # Route-specific per-minute limit
  351. # # ------------------------------------------------------------
  352. # if limits.route_per_min is not None:
  353. # route_req_count = await self._count_requests(
  354. # user_id, route, one_min_ago
  355. # )
  356. # if route_req_count > limits.route_per_min:
  357. # logger.warning(
  358. # f"Per-route per-minute limit exceeded for "
  359. # f"user_id={user_id}, route={route}"
  360. # )
  361. # raise ValueError("Per-route per-minute rate limit exceeded")
  362. # # ------------------------------------------------------------
  363. # # Monthly limit
  364. # # ------------------------------------------------------------
  365. # if limits.monthly_limit is not None:
  366. # monthly_count = await self._count_monthly_requests(user_id)
  367. # if monthly_count > limits.monthly_limit:
  368. # logger.warning(
  369. # f"Monthly limit exceeded for user_id={user_id}, "
  370. # f"route={route}"
  371. # )
  372. # raise ValueError("Monthly rate limit exceeded")
  373. # async def log_request(self, user_id: UUID, route: str):
  374. # """
  375. # Log a successful request to the request_log table.
  376. # """
  377. # query = f"""
  378. # INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
  379. # (time, user_id, route)
  380. # VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
  381. # """
  382. # await self.connection_manager.execute_query(query, [user_id, route])