# base_router.py
  1. import functools
  2. import logging
  3. from abc import abstractmethod
  4. from typing import Callable, Optional
  5. from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
  6. from fastapi.responses import FileResponse, StreamingResponse
  7. from core.base import R2RException, manage_run
  8. from ...abstractions import R2RProviders, R2RServices
  9. logger = logging.getLogger()
  10. class BaseRouterV3:
  11. def __init__(self, providers: R2RProviders, services: R2RServices):
  12. """
  13. :param providers: Typically includes auth, database, etc.
  14. :param services: Additional service references (ingestion, run_manager, etc).
  15. """
  16. self.providers = providers
  17. self.services = services
  18. self.router = APIRouter()
  19. self.openapi_extras = self._load_openapi_extras()
  20. # Add the rate-limiting dependency
  21. self.set_rate_limiting()
  22. # Initialize any routes
  23. self._setup_routes()
  24. self._register_workflows()
  25. def get_router(self):
  26. return self.router
  27. def base_endpoint(self, func: Callable):
  28. """
  29. A decorator to wrap endpoints in a standard pattern:
  30. - manage_run context
  31. - error handling
  32. - response shaping
  33. """
  34. @functools.wraps(func)
  35. async def wrapper(*args, **kwargs):
  36. async with manage_run(
  37. self.services.ingestion.run_manager, func.__name__
  38. ) as run_id:
  39. auth_user = kwargs.get("auth_user")
  40. if auth_user:
  41. # Optionally log run info with the user
  42. await self.services.ingestion.run_manager.log_run_info(
  43. user=auth_user,
  44. )
  45. try:
  46. func_result = await func(*args, **kwargs)
  47. if (
  48. isinstance(func_result, tuple)
  49. and len(func_result) == 2
  50. ):
  51. results, outer_kwargs = func_result
  52. else:
  53. results, outer_kwargs = func_result, {}
  54. if isinstance(results, (StreamingResponse, FileResponse)):
  55. return results
  56. return {"results": results, **outer_kwargs}
  57. except R2RException:
  58. raise
  59. except Exception as e:
  60. logger.error(
  61. f"Error in base endpoint {func.__name__}() - {str(e)}",
  62. exc_info=True,
  63. )
  64. raise HTTPException(
  65. status_code=500,
  66. detail={
  67. "message": f"An error '{e}' occurred during {func.__name__}",
  68. "error": str(e),
  69. "error_type": type(e).__name__,
  70. },
  71. ) from e
  72. return wrapper
  73. @classmethod
  74. def build_router(cls, engine):
  75. """
  76. Class method for building a router instance (if you have a standard pattern).
  77. """
  78. return cls(engine).router
  79. def _register_workflows(self):
  80. pass
  81. def _load_openapi_extras(self):
  82. return {}
  83. @abstractmethod
  84. def _setup_routes(self):
  85. """
  86. Subclasses override this to define actual endpoints.
  87. """
  88. pass
  89. def set_rate_limiting(self):
  90. """
  91. Adds a yield-based dependency for rate limiting each request.
  92. Checks the limits, then logs the request if the check passes.
  93. """
  94. async def rate_limit_dependency(
  95. request: Request,
  96. auth_user=Depends(self.providers.auth.auth_wrapper()),
  97. ):
  98. """
  99. 1) Fetch the user from the DB (including .limits_overrides).
  100. 2) Pass it to limits_handler.check_limits.
  101. 3) After the endpoint completes, call limits_handler.log_request.
  102. """
  103. # If the user is superuser, skip checks
  104. if auth_user.is_superuser:
  105. yield
  106. return
  107. user_id = auth_user.id
  108. route = request.scope["path"]
  109. # 1) Fetch the user from DB
  110. user = await self.providers.database.users_handler.get_user_by_id(
  111. user_id
  112. )
  113. if not user:
  114. raise HTTPException(status_code=404, detail="User not found.")
  115. # 2) Rate-limit check
  116. try:
  117. await self.providers.database.limits_handler.check_limits(
  118. user=user, route=route # Pass the User object
  119. )
  120. except ValueError as e:
  121. # If check_limits raises ValueError -> 429 Too Many Requests
  122. raise HTTPException(status_code=429, detail=str(e))
  123. request.state.user_id = user_id
  124. request.state.route = route
  125. # 3) Execute the route
  126. try:
  127. yield
  128. finally:
  129. # 4) Log the request afterwards
  130. await self.providers.database.limits_handler.log_request(
  131. user_id, route
  132. )
  133. async def websocket_rate_limit_dependency(websocket: WebSocket):
  134. # Example: if you want to rate-limit websockets similarly
  135. route = websocket.scope["path"]
  136. # If you had a user or token, you'd do the same check.
  137. try:
  138. # e.g. check_limits(user_id, route)
  139. return True
  140. except ValueError:
  141. await websocket.close(code=4429, reason="Rate limit exceeded")
  142. return False
  143. # Attach the dependencies so you can use them in your endpoints
  144. self.rate_limit_dependency = rate_limit_dependency
  145. self.websocket_rate_limit_dependency = websocket_rate_limit_dependency