base_router.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. import functools
  2. import logging
  3. from abc import abstractmethod
  4. from typing import Callable
  5. from fastapi import APIRouter, Depends, HTTPException, Request
  6. from fastapi.responses import FileResponse, StreamingResponse
  7. from core.base import R2RException
  8. from ...abstractions import R2RProviders, R2RServices
  9. from ...config import R2RConfig
  # Module-level logger named after this module (instead of the root logger)
  # so records can be filtered per module; they still propagate to any
  # handlers configured on the root logger.
  logger = logging.getLogger(__name__)
  class BaseRouterV3:
      """Base class for v3 API routers.

      Constructing an instance creates an ``APIRouter``, installs the
      rate-limiting dependency and then calls the ``_setup_routes`` and
      ``_register_workflows`` hooks, in that order.

      NOTE: ``_setup_routes`` is decorated with ``@abstractmethod``, but this
      class does not use ``ABCMeta``, so the decorator is not enforced when
      the class is instantiated.
      """

      def __init__(
          self, providers: R2RProviders, services: R2RServices, config: R2RConfig
      ):
          """
          :param providers: Typically includes auth, database, etc.
          :param services: Additional service references (ingestion, etc).
          :param config: Configuration object, stored as ``self.config``.
          """
          self.providers = providers
          self.services = services
          self.config = config
          self.router = APIRouter()
          self.openapi_extras = self._load_openapi_extras()

          # Add the rate-limiting dependency
          self.set_rate_limiting()

          # Initialize any routes
          self._setup_routes()
          self._register_workflows()

      def get_router(self):
          """Return the ``APIRouter`` owned by this instance."""
          return self.router

      def base_endpoint(self, func: Callable):
          """
          A decorator to wrap endpoints in a standard pattern:
           - error handling
           - response shaping

          Response shaping: a 2-tuple returned by ``func`` is unpacked as
          ``(results, extra)`` and returned as ``{"results": results, **extra}``;
          any other value is returned as ``{"results": value}``.
          ``StreamingResponse`` and ``FileResponse`` objects are returned
          untouched.

          Error handling: ``R2RException`` and ``HTTPException`` propagate
          unchanged so their status codes survive; every other exception is
          logged with its traceback and re-raised as a 500 ``HTTPException``.

          :param func: The async endpoint function to wrap.
          :return: The wrapped coroutine function, tagged with
              ``_is_base_endpoint = True``.
          """

          @functools.wraps(func)
          async def wrapper(*args, **kwargs):
              try:
                  func_result = await func(*args, **kwargs)
                  if isinstance(func_result, tuple) and len(func_result) == 2:
                      results, outer_kwargs = func_result
                  else:
                      results, outer_kwargs = func_result, {}

                  if isinstance(results, (StreamingResponse, FileResponse)):
                      return results
                  return {"results": results, **outer_kwargs}

              except (R2RException, HTTPException):
                  # Deliberate errors already carry their own status code;
                  # without this clause an HTTPException raised by the endpoint
                  # would be rewritten as a 500 by the handler below.
                  raise

              except Exception as e:
                  # logger.exception logs at ERROR level with the traceback.
                  logger.exception(
                      "Error in base endpoint %s() - %s", func.__name__, e
                  )
                  raise HTTPException(
                      status_code=500,
                      detail={
                          "message": f"An error '{e}' occurred during {func.__name__}",
                          "error": str(e),
                          "error_type": type(e).__name__,
                      },
                  ) from e

          wrapper._is_base_endpoint = True  # type: ignore
          return wrapper

      @classmethod
      def build_router(cls, engine, *args, **kwargs):
          """Class method for building a router instance (if you have a standard
          pattern).

          All arguments are forwarded to the constructor, ``engine`` being the
          first positional one. ``__init__`` of this class takes three
          arguments (providers, services, config), so forwarding only
          ``engine`` raised ``TypeError``; pass the remaining constructor
          arguments through ``*args`` / ``**kwargs``.

          :return: The ``router`` attribute of the newly built instance.
          """
          return cls(engine, *args, **kwargs).router

      def _register_workflows(self):
          """Hook called at the end of ``__init__``; does nothing by default."""
          pass

      def _load_openapi_extras(self):
          """Return the value stored as ``self.openapi_extras`` (empty dict here)."""
          return {}

      @abstractmethod
      def _setup_routes(self):
          """Subclasses override this to define actual endpoints."""
          pass

      def set_rate_limiting(self):
          """Adds a yield-based dependency for rate limiting each request.

          Checks the limits before the route runs; afterwards POST and DELETE
          requests are logged, whether or not the route raised. The dependency
          is stored on ``self.rate_limit_dependency``.
          """

          async def rate_limit_dependency(
              request: Request,
              auth_user=Depends(self.providers.auth.auth_wrapper()),
          ):
              """1) Fetch the user from the DB (including .limits_overrides).
              2) Pass it to limits_handler.check_limits. 3) After the endpoint
              completes, call limits_handler.log_request.
              """
              # If the user is superuser, skip checks (and request logging)
              if auth_user.is_superuser:
                  yield
                  return

              user_id = auth_user.id
              route = request.scope["path"]

              # 1) Fetch the user from DB
              user = await self.providers.database.users_handler.get_user_by_id(
                  user_id
              )
              if not user:
                  raise HTTPException(status_code=404, detail="User not found.")

              # 2) Rate-limit check
              try:
                  await self.providers.database.limits_handler.check_limits(
                      user=user,  # Pass the User object
                      route=route,
                  )
              except ValueError as e:
                  # If check_limits raises ValueError -> 429 Too Many Requests
                  raise HTTPException(status_code=429, detail=str(e)) from e

              request.state.user_id = user_id
              request.state.route = route

              # 3) Execute the route
              try:
                  yield
              finally:
                  # 4) Log only POST and DELETE requests
                  if request.method in ("POST", "DELETE"):
                      await self.providers.database.limits_handler.log_request(
                          user_id, route
                      )

          # Attach the dependencies so you can use them in your endpoints
          self.rate_limit_dependency = rate_limit_dependency
  121. )
  122. # Attach the dependencies so you can use them in your endpoints
  123. self.rate_limit_dependency = rate_limit_dependency