test_limits.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. import uuid
  2. from datetime import datetime, timedelta, timezone
  3. from uuid import UUID
  4. import pytest
  5. from core.base import LimitSettings
  6. from core.providers.database.postgres import PostgresLimitsHandler
  7. from shared.abstractions import User
  8. @pytest.mark.asyncio
  9. async def test_log_request_and_count(limits_handler):
  10. """Test that when we log requests, the count increments, and rate-limits
  11. are enforced.
  12. Route-specific test using the /v3/retrieval/search endpoint limits.
  13. """
  14. # Clear existing logs first
  15. clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
  16. await limits_handler.connection_manager.execute_query(clear_query)
  17. user_id = uuid.uuid4()
  18. route = "/v3/retrieval/search" # Using actual route from config
  19. test_user = User(
  20. id=user_id,
  21. email="test@example.com",
  22. is_active=True,
  23. is_verified=True,
  24. is_superuser=False,
  25. limits_overrides=None,
  26. )
  27. # Set route limit to match config: 5 requests per minute
  28. old_route_limits = limits_handler.config.route_limits
  29. new_route_limits = {
  30. route: LimitSettings(route_per_min=5, monthly_limit=10)
  31. }
  32. limits_handler.config.route_limits = new_route_limits
  33. print(f"\nTesting with route limits: {new_route_limits}")
  34. print(f"Route settings: {limits_handler.config.route_limits[route]}")
  35. try:
  36. # Initial check should pass (no requests yet)
  37. await limits_handler.check_limits(test_user, route)
  38. print("Initial check passed (no requests)")
  39. # Log 5 requests (exactly at limit)
  40. for i in range(5):
  41. await limits_handler.log_request(user_id, route)
  42. now = datetime.now(timezone.utc)
  43. one_min_ago = now - timedelta(minutes=1)
  44. route_count = await limits_handler._count_requests(
  45. user_id, route, one_min_ago)
  46. print(f"Route count after request {i + 1}: {route_count}")
  47. # This should pass for all 5 requests
  48. await limits_handler.check_limits(test_user, route)
  49. print(f"Check limits passed after request {i + 1}")
  50. # Log the 6th request (over limit)
  51. await limits_handler.log_request(user_id, route)
  52. route_count = await limits_handler._count_requests(
  53. user_id, route, one_min_ago)
  54. print(f"Route count after request 6: {route_count}")
  55. # This check should fail as we've exceeded route_per_min=5
  56. with pytest.raises(ValueError,
  57. match="Per-route per-minute rate limit exceeded"):
  58. await limits_handler.check_limits(test_user, route)
  59. finally:
  60. limits_handler.config.route_limits = old_route_limits
  61. @pytest.mark.asyncio
  62. async def test_global_limit(limits_handler):
  63. """Test global limit using the configured limit of 10 requests per
  64. minute."""
  65. # Clear existing logs
  66. clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
  67. await limits_handler.connection_manager.execute_query(clear_query)
  68. user_id = uuid.uuid4()
  69. route = "/global-test"
  70. test_user = User(
  71. id=user_id,
  72. email="globaltest@example.com",
  73. is_active=True,
  74. is_verified=True,
  75. is_superuser=False,
  76. limits_overrides=None,
  77. )
  78. # Set global limit to match config: 10 requests per minute
  79. old_limits = limits_handler.config.limits
  80. limits_handler.config.limits = LimitSettings(global_per_min=10,
  81. monthly_limit=20)
  82. try:
  83. # Initial check should pass (no requests)
  84. await limits_handler.check_limits(test_user, route)
  85. print("Initial global check passed (no requests)")
  86. # Log 10 requests (hits the limit)
  87. for i in range(11):
  88. await limits_handler.log_request(user_id, route)
  89. # Debug counts
  90. now = datetime.now(timezone.utc)
  91. one_min_ago = now - timedelta(minutes=1)
  92. global_count = await limits_handler._count_requests(
  93. user_id, None, one_min_ago)
  94. print(f"Global count after 10 requests: {global_count}")
  95. # This should fail as we've hit global_per_min=10
  96. with pytest.raises(ValueError,
  97. match="Global per-minute rate limit exceeded"):
  98. await limits_handler.check_limits(test_user, route)
  99. finally:
  100. limits_handler.config.limits = old_limits
  101. @pytest.mark.asyncio
  102. async def test_monthly_limit(limits_handler):
  103. """Test monthly limit using the configured limit of 20 requests per
  104. month."""
  105. # Clear existing logs
  106. clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
  107. await limits_handler.connection_manager.execute_query(clear_query)
  108. user_id = uuid.uuid4()
  109. route = "/monthly-test"
  110. test_user = User(
  111. id=user_id,
  112. email="monthly@example.com",
  113. is_active=True,
  114. is_verified=True,
  115. is_superuser=False,
  116. limits_overrides=None,
  117. )
  118. old_limits = limits_handler.config.limits
  119. limits_handler.config.limits = LimitSettings(monthly_limit=20)
  120. try:
  121. # Initial check should pass (no requests)
  122. await limits_handler.check_limits(test_user, route)
  123. print("Initial monthly check passed (no requests)")
  124. # Log 20 requests (hits the monthly limit)
  125. for i in range(21):
  126. await limits_handler.log_request(user_id, route)
  127. # Get current month's count
  128. now = datetime.now(timezone.utc)
  129. first_of_month = now.replace(day=1,
  130. hour=0,
  131. minute=0,
  132. second=0,
  133. microsecond=0)
  134. monthly_count = await limits_handler._count_requests(
  135. user_id, None, first_of_month)
  136. print(f"Monthly count after 20 requests: {monthly_count}")
  137. # This should fail as we've hit monthly_limit=20
  138. with pytest.raises(ValueError, match="Monthly rate limit exceeded"):
  139. await limits_handler.check_limits(test_user, route)
  140. finally:
  141. limits_handler.config.limits = old_limits
  142. @pytest.mark.asyncio
  143. async def test_user_level_override(limits_handler):
  144. """Test user-specific override limits with debug logging."""
  145. user_id = UUID("47e53676-b478-5b3f-a409-234ca2164de5")
  146. route = "/test-route"
  147. # Clear existing logs first
  148. clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
  149. await limits_handler.connection_manager.execute_query(clear_query)
  150. test_user = User(
  151. id=user_id,
  152. email="override@example.com",
  153. is_active=True,
  154. is_verified=True,
  155. is_superuser=False,
  156. limits_overrides={
  157. "global_per_min": 2,
  158. "route_per_min": 1,
  159. "route_overrides": {
  160. "/test-route": {
  161. "route_per_min": 1
  162. }
  163. },
  164. },
  165. )
  166. # Set default limits that should be overridden
  167. old_limits = limits_handler.config.limits
  168. limits_handler.config.limits = LimitSettings(global_per_min=10,
  169. monthly_limit=20)
  170. # Debug: Print current limits
  171. print(f"\nDefault limits: {limits_handler.config.limits}")
  172. print(f"User overrides: {test_user.limits_overrides}")
  173. try:
  174. # First check limits (should pass as no requests yet)
  175. await limits_handler.check_limits(test_user, route)
  176. print("Initial check passed (no requests yet)")
  177. # Log first request
  178. await limits_handler.log_request(user_id, route)
  179. # Debug: Get current counts
  180. now = datetime.now(timezone.utc)
  181. one_min_ago = now - timedelta(minutes=1)
  182. global_count = await limits_handler._count_requests(
  183. user_id, None, one_min_ago)
  184. route_count = await limits_handler._count_requests(
  185. user_id, route, one_min_ago)
  186. print("\nAfter first request:")
  187. print(f"Global count: {global_count}")
  188. print(f"Route count: {route_count}")
  189. # Log second request
  190. await limits_handler.log_request(user_id, route)
  191. # This check should fail as we've hit route_per_min=1
  192. with pytest.raises(ValueError,
  193. match="Per-route per-minute rate limit exceeded"):
  194. await limits_handler.check_limits(test_user, route)
  195. finally:
  196. # Cleanup
  197. limits_handler.config.limits = old_limits
  198. @pytest.mark.asyncio
  199. async def test_determine_effective_limits(limits_handler):
  200. """Test that user-level overrides > route-level overrides > global
  201. defaults.
  202. This is a pure logic test of the 'determine_effective_limits' method.
  203. """
  204. # Setup global/base defaults
  205. old_limits = limits_handler.config.limits
  206. limits_handler.config.limits = LimitSettings(global_per_min=10,
  207. route_per_min=5,
  208. monthly_limit=50)
  209. # Setup route-level override
  210. route = "/some-route"
  211. old_route_limits = limits_handler.config.route_limits
  212. limits_handler.config.route_limits = {
  213. route: LimitSettings(global_per_min=8,
  214. route_per_min=3,
  215. monthly_limit=30)
  216. }
  217. # Setup user-level override
  218. test_user = User(
  219. id=uuid.uuid4(),
  220. email="test@example.com",
  221. is_active=True,
  222. is_verified=True,
  223. is_superuser=False,
  224. limits_overrides={
  225. "global_per_min": 6, # should override
  226. "route_overrides": {
  227. route: {
  228. "route_per_min": 2
  229. } # should override
  230. },
  231. },
  232. )
  233. try:
  234. effective = limits_handler.determine_effective_limits(test_user, route)
  235. # Check final / effective limits
  236. # Global limit overridden to 6
  237. assert effective.global_per_min == 6, (
  238. "User-level global override not applied")
  239. # route_per_min should be overridden to 2 (not the route-level 3)
  240. assert effective.route_per_min == 2, (
  241. "User-level route override not applied")
  242. # monthly_limit from route-level override is 30, user didn't override it, so it should stay 30
  243. assert effective.monthly_limit == 30, (
  244. "Route-level monthly override not applied")
  245. finally:
  246. # revert changes
  247. limits_handler.config.limits = old_limits
  248. limits_handler.config.route_limits = old_route_limits
  249. @pytest.mark.asyncio
  250. async def test_separate_route_usage_is_isolated(limits_handler):
  251. """Confirm that calls to /routeA do NOT increment the per-route usage for
  252. /routeB, and vice-versa."""
  253. # 1) Clear existing logs
  254. clear_query = f"DELETE FROM {limits_handler._get_table_name(limits_handler.TABLE_NAME)}"
  255. await limits_handler.connection_manager.execute_query(clear_query)
  256. # 2) Setup user & routes
  257. import uuid
  258. from shared.abstractions import User
  259. user_id = uuid.uuid4()
  260. routeA = "/v3/retrieval/rag"
  261. routeB = "/v3/retrieval/search"
  262. test_user = User(
  263. id=user_id,
  264. email="test@example.com",
  265. is_active=True,
  266. is_verified=True,
  267. is_superuser=False,
  268. limits_overrides=None,
  269. )
  270. # 3) Insert some logs for routeA only
  271. for _ in range(3):
  272. await limits_handler.log_request(user_id, routeA)
  273. # 4) Check usage for routeA → Should be 3 in last minute
  274. now = datetime.now(timezone.utc)
  275. one_min_ago = now - timedelta(minutes=1)
  276. routeA_count = await limits_handler._count_requests(
  277. user_id, routeA, one_min_ago)
  278. assert routeA_count == 3, f"Expected 3 for routeA, got {routeA_count}"
  279. # 5) Check usage for routeB → Should be 0
  280. routeB_count = await limits_handler._count_requests(
  281. user_id, routeB, one_min_ago)
  282. assert routeB_count == 0, f"Expected 0 for routeB, got {routeB_count}"
  283. # 6) Insert some logs for routeB only
  284. for _ in range(2):
  285. await limits_handler.log_request(user_id, routeB)
  286. # 7) Recheck usage
  287. routeA_count_after = await limits_handler._count_requests(
  288. user_id, routeA, one_min_ago)
  289. routeB_count_after = await limits_handler._count_requests(
  290. user_id, routeB, one_min_ago)
  291. assert routeA_count_after == 3, (
  292. f"RouteA usage changed unexpectedly: {routeA_count_after}")
  293. assert routeB_count_after == 2, (
  294. f"RouteB usage is wrong: {routeB_count_after}")
  295. # @pytest.mark.asyncio
  296. # async def test_check_limits_multiple_routes(limits_handler):
  297. # """
  298. # Demonstrates that routeA calls do not count against routeB's per-minute limit.
  299. # """
  300. # # Clear logs
  301. # clear_query = f"DELETE FROM {limits_handler._get_table_name(limits_handler.TABLE_NAME)}"
  302. # await limits_handler.connection_manager.execute_query(clear_query)
  303. # import uuid
  304. # from shared.abstractions import User
  305. # user_id = uuid.uuid4()
  306. # routeA = "/v3/retrieval/rag"
  307. # routeB = "/v3/retrieval/search"
  308. # # Suppose routeA has a limit of 2/min, routeB has a limit of 3/min
  309. # # (You can do this by setting config.route_limits[routeA].route_per_min, etc.)
  310. # # Or just rely on your global config if needed.
  311. # test_user = User(
  312. # id=user_id,
  313. # email="test@example.com",
  314. # is_active=True,
  315. # is_verified=True,
  316. # is_superuser=False,
  317. # limits_overrides=None,
  318. # )
  319. # # 1) Make 2 calls to routeA
  320. # await limits_handler.check_limits(test_user, routeA)
  321. # await limits_handler.log_request(user_id, routeA)
  322. # await limits_handler.check_limits(test_user, routeA)
  323. # await limits_handler.log_request(user_id, routeA)
  324. # await limits_handler.check_limits(test_user, routeA)
  325. # await limits_handler.log_request(user_id, routeA)
  326. # # 2) Confirm next call to routeA fails if the limit is 2/min
  327. # with pytest.raises(ValueError, match="Per-route per-minute rate limit exceeded"):
  328. # await limits_handler.check_limits(test_user, routeA)
  329. # # 3) Meanwhile, routeB usage should be unaffected
  330. # # We can still do 3 calls to routeB (assuming route_per_min=3).
  331. # await limits_handler.check_limits(test_user, routeB)
  332. # await limits_handler.log_request(user_id, routeB)
  333. # await limits_handler.check_limits(test_user, routeB)
  334. # await limits_handler.log_request(user_id, routeB)
  335. # await limits_handler.check_limits(test_user, routeB)
  336. # await limits_handler.log_request(user_id, routeB)
  337. @pytest.mark.asyncio
  338. async def test_route_specific_monthly_usage(limits_handler):
  339. """Confirm that monthly usage is tracked per-route and doesn't get
  340. incremented by calls to other routes."""
  341. # 1) Clear existing logs
  342. clear_query = f"DELETE FROM {limits_handler._get_table_name(limits_handler.TABLE_NAME)}"
  343. await limits_handler.connection_manager.execute_query(clear_query)
  344. # 2) Setup
  345. user_id = uuid.uuid4()
  346. routeA = "/v3/retrieval/rag"
  347. routeB = "/v3/retrieval/search"
  348. test_user = User(
  349. id=user_id,
  350. email="test_monthly_routes@example.com",
  351. is_active=True,
  352. is_verified=True,
  353. is_superuser=False,
  354. limits_overrides=None,
  355. )
  356. # 3) Log 5 requests for routeA
  357. for _ in range(5):
  358. await limits_handler.log_request(user_id, routeA)
  359. # 4) Check monthly usage for routeA => should be 5
  360. routeA_monthly = await limits_handler._count_monthly_requests(
  361. user_id, routeA)
  362. assert routeA_monthly == 5, f"Expected 5 for routeA, got {routeA_monthly}"
  363. # routeB => should still be 0
  364. routeB_monthly = await limits_handler._count_monthly_requests(
  365. user_id, routeB)
  366. assert routeB_monthly == 0, f"Expected 0 for routeB, got {routeB_monthly}"
  367. # 5) Now log 3 requests for routeB
  368. for _ in range(3):
  369. await limits_handler.log_request(user_id, routeB)
  370. # Re-check usage
  371. routeA_monthly_after = await limits_handler._count_monthly_requests(
  372. user_id, routeA)
  373. routeB_monthly_after = await limits_handler._count_monthly_requests(
  374. user_id, routeB)
  375. assert routeA_monthly_after == 5, (
  376. f"RouteA usage changed unexpectedly: {routeA_monthly_after}")
  377. assert routeB_monthly_after == 3, (
  378. f"RouteB usage is wrong: {routeB_monthly_after}")
  379. # Additionally confirm total usage across all routes
  380. global_monthly = await limits_handler._count_monthly_requests(user_id,
  381. route=None)
  382. assert global_monthly == 8, (
  383. f"Expected total of 8 monthly requests, got {global_monthly}")