test_limits.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. import uuid
  2. from datetime import datetime, timedelta, timezone
  3. from uuid import UUID
  4. import pytest
  5. from core.base import LimitSettings
  6. from core.database.postgres import PostgresLimitsHandler
  7. from shared.abstractions import User
  8. @pytest.mark.asyncio
  9. async def test_log_request_and_count(limits_handler):
  10. """
  11. Test that when we log requests, the count increments, and rate-limits are enforced.
  12. Route-specific test using the /v3/retrieval/search endpoint limits.
  13. """
  14. # Clear existing logs first
  15. clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
  16. await limits_handler.connection_manager.execute_query(clear_query)
  17. user_id = uuid.uuid4()
  18. route = "/v3/retrieval/search" # Using actual route from config
  19. test_user = User(
  20. id=user_id,
  21. email="test@example.com",
  22. is_active=True,
  23. is_verified=True,
  24. is_superuser=False,
  25. limits_overrides=None,
  26. )
  27. # Set route limit to match config: 5 requests per minute
  28. old_route_limits = limits_handler.config.route_limits
  29. new_route_limits = {
  30. route: LimitSettings(route_per_min=5, monthly_limit=10)
  31. }
  32. limits_handler.config.route_limits = new_route_limits
  33. print(f"\nTesting with route limits: {new_route_limits}")
  34. print(f"Route settings: {limits_handler.config.route_limits[route]}")
  35. try:
  36. # Initial check should pass (no requests yet)
  37. await limits_handler.check_limits(test_user, route)
  38. print("Initial check passed (no requests)")
  39. # Log 5 requests (exactly at limit)
  40. for i in range(5):
  41. await limits_handler.log_request(user_id, route)
  42. now = datetime.now(timezone.utc)
  43. one_min_ago = now - timedelta(minutes=1)
  44. route_count = await limits_handler._count_requests(
  45. user_id, route, one_min_ago
  46. )
  47. print(f"Route count after request {i+1}: {route_count}")
  48. # This should pass for all 5 requests
  49. await limits_handler.check_limits(test_user, route)
  50. print(f"Check limits passed after request {i+1}")
  51. # Log the 6th request (over limit)
  52. await limits_handler.log_request(user_id, route)
  53. route_count = await limits_handler._count_requests(
  54. user_id, route, one_min_ago
  55. )
  56. print(f"Route count after request 6: {route_count}")
  57. # This check should fail as we've exceeded route_per_min=5
  58. with pytest.raises(
  59. ValueError, match="Per-route per-minute rate limit exceeded"
  60. ):
  61. await limits_handler.check_limits(test_user, route)
  62. finally:
  63. limits_handler.config.route_limits = old_route_limits
  64. @pytest.mark.asyncio
  65. async def test_global_limit(limits_handler):
  66. """
  67. Test global limit using the configured limit of 10 requests per minute
  68. """
  69. # Clear existing logs
  70. clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
  71. await limits_handler.connection_manager.execute_query(clear_query)
  72. user_id = uuid.uuid4()
  73. route = "/global-test"
  74. test_user = User(
  75. id=user_id,
  76. email="globaltest@example.com",
  77. is_active=True,
  78. is_verified=True,
  79. is_superuser=False,
  80. limits_overrides=None,
  81. )
  82. # Set global limit to match config: 10 requests per minute
  83. old_limits = limits_handler.config.limits
  84. limits_handler.config.limits = LimitSettings(
  85. global_per_min=10, monthly_limit=20
  86. )
  87. try:
  88. # Initial check should pass (no requests)
  89. await limits_handler.check_limits(test_user, route)
  90. print("Initial global check passed (no requests)")
  91. # Log 10 requests (hits the limit)
  92. for i in range(11):
  93. await limits_handler.log_request(user_id, route)
  94. # Debug counts
  95. now = datetime.now(timezone.utc)
  96. one_min_ago = now - timedelta(minutes=1)
  97. global_count = await limits_handler._count_requests(
  98. user_id, None, one_min_ago
  99. )
  100. print(f"Global count after 10 requests: {global_count}")
  101. # This should fail as we've hit global_per_min=10
  102. with pytest.raises(
  103. ValueError, match="Global per-minute rate limit exceeded"
  104. ):
  105. await limits_handler.check_limits(test_user, route)
  106. finally:
  107. limits_handler.config.limits = old_limits
  108. @pytest.mark.asyncio
  109. async def test_monthly_limit(limits_handler):
  110. """
  111. Test monthly limit using the configured limit of 20 requests per month
  112. """
  113. # Clear existing logs
  114. clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
  115. await limits_handler.connection_manager.execute_query(clear_query)
  116. user_id = uuid.uuid4()
  117. route = "/monthly-test"
  118. test_user = User(
  119. id=user_id,
  120. email="monthly@example.com",
  121. is_active=True,
  122. is_verified=True,
  123. is_superuser=False,
  124. limits_overrides=None,
  125. )
  126. old_limits = limits_handler.config.limits
  127. limits_handler.config.limits = LimitSettings(monthly_limit=20)
  128. try:
  129. # Initial check should pass (no requests)
  130. await limits_handler.check_limits(test_user, route)
  131. print("Initial monthly check passed (no requests)")
  132. # Log 20 requests (hits the monthly limit)
  133. for i in range(21):
  134. await limits_handler.log_request(user_id, route)
  135. # Get current month's count
  136. now = datetime.now(timezone.utc)
  137. first_of_month = now.replace(
  138. day=1, hour=0, minute=0, second=0, microsecond=0
  139. )
  140. monthly_count = await limits_handler._count_requests(
  141. user_id, None, first_of_month
  142. )
  143. print(f"Monthly count after 20 requests: {monthly_count}")
  144. # This should fail as we've hit monthly_limit=20
  145. with pytest.raises(ValueError, match="Monthly rate limit exceeded"):
  146. await limits_handler.check_limits(test_user, route)
  147. finally:
  148. limits_handler.config.limits = old_limits
  149. @pytest.mark.asyncio
  150. async def test_user_level_override(limits_handler):
  151. """
  152. Test user-specific override limits with debug logging
  153. """
  154. user_id = UUID("47e53676-b478-5b3f-a409-234ca2164de5")
  155. route = "/test-route"
  156. # Clear existing logs first
  157. clear_query = f"DELETE FROM {limits_handler._get_table_name(PostgresLimitsHandler.TABLE_NAME)}"
  158. await limits_handler.connection_manager.execute_query(clear_query)
  159. test_user = User(
  160. id=user_id,
  161. email="override@example.com",
  162. is_active=True,
  163. is_verified=True,
  164. is_superuser=False,
  165. limits_overrides={
  166. "global_per_min": 2,
  167. "route_per_min": 1,
  168. "route_overrides": {"/test-route": {"route_per_min": 1}},
  169. },
  170. )
  171. # Set default limits that should be overridden
  172. old_limits = limits_handler.config.limits
  173. limits_handler.config.limits = LimitSettings(
  174. global_per_min=10, monthly_limit=20
  175. )
  176. # Debug: Print current limits
  177. print(f"\nDefault limits: {limits_handler.config.limits}")
  178. print(f"User overrides: {test_user.limits_overrides}")
  179. try:
  180. # First check limits (should pass as no requests yet)
  181. await limits_handler.check_limits(test_user, route)
  182. print("Initial check passed (no requests yet)")
  183. # Log first request
  184. await limits_handler.log_request(user_id, route)
  185. # Debug: Get current counts
  186. now = datetime.now(timezone.utc)
  187. one_min_ago = now - timedelta(minutes=1)
  188. global_count = await limits_handler._count_requests(
  189. user_id, None, one_min_ago
  190. )
  191. route_count = await limits_handler._count_requests(
  192. user_id, route, one_min_ago
  193. )
  194. print(f"\nAfter first request:")
  195. print(f"Global count: {global_count}")
  196. print(f"Route count: {route_count}")
  197. # Log second request
  198. await limits_handler.log_request(user_id, route)
  199. # This check should fail as we've hit route_per_min=1
  200. with pytest.raises(
  201. ValueError, match="Per-route per-minute rate limit exceeded"
  202. ):
  203. await limits_handler.check_limits(test_user, route)
  204. finally:
  205. # Cleanup
  206. limits_handler.config.limits = old_limits