123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173 |
- import functools
- import logging
- from abc import abstractmethod
- from typing import Callable, Optional
- from fastapi import APIRouter, Depends, HTTPException, Request, WebSocket
- from fastapi.responses import FileResponse, StreamingResponse
- from core.base import R2RException, manage_run
- from ...abstractions import R2RProviders, R2RServices
# Module-level logger shared by the router code below.
# NOTE: getLogger() called without a name returns the root logger.
logger = logging.getLogger()
class BaseRouterV3:
    """Base class for v3 API routers.

    Stores the shared providers and services, owns an ``APIRouter``, and
    offers subclasses the ``base_endpoint`` decorator plus rate-limiting
    dependencies (created by ``set_rate_limiting``) to build endpoints with.

    NOTE(review): ``_setup_routes`` is decorated with ``@abstractmethod`` but
    this class does not use ``ABCMeta``, so the marker is not enforced when
    the class is instantiated.
    """

    def __init__(self, providers: R2RProviders, services: R2RServices):
        """
        :param providers: Typically includes auth, database, etc.
        :param services: Additional service references (ingestion, run_manager, etc).
        """
        self.providers = providers
        self.services = services
        self.router = APIRouter()
        self.openapi_extras = self._load_openapi_extras()

        # The rate-limiting dependencies are attached before _setup_routes()
        # runs, so route definitions can refer to them.
        self.set_rate_limiting()

        # Initialize any routes
        self._setup_routes()
        self._register_workflows()

    def get_router(self):
        """Return the ``APIRouter`` owned by this instance."""
        return self.router

    def base_endpoint(self, func: Callable):
        """Wrap an async endpoint in the standard run/error/response pattern.

        The returned coroutine function:

        - runs ``func`` inside a ``manage_run`` context keyed by
          ``func.__name__``;
        - logs run info when a truthy ``auth_user`` keyword argument is passed;
        - returns ``StreamingResponse`` / ``FileResponse`` results untouched
          and wraps anything else as ``{"results": ...}``. A 2-tuple result is
          unpacked as ``(results, extra_fields)`` and ``extra_fields`` is
          merged into the top level of the response dict;
        - re-raises ``R2RException`` and ``HTTPException`` unchanged, and
          converts any other exception into a 500 ``HTTPException``.

        :param func: The async endpoint function to wrap.
        :return: The wrapped coroutine function.
        """

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            run_manager = self.services.ingestion.run_manager
            async with manage_run(run_manager, func.__name__):
                auth_user = kwargs.get("auth_user")
                if auth_user:
                    # Optionally log run info with the user
                    await run_manager.log_run_info(
                        user=auth_user,
                    )

                try:
                    func_result = await func(*args, **kwargs)
                    if (
                        isinstance(func_result, tuple)
                        and len(func_result) == 2
                    ):
                        results, outer_kwargs = func_result
                    else:
                        results, outer_kwargs = func_result, {}

                    # Streaming/file responses are already complete responses
                    # and must not be wrapped in the {"results": ...} envelope.
                    if isinstance(results, (StreamingResponse, FileResponse)):
                        return results
                    return {"results": results, **outer_kwargs}

                except (R2RException, HTTPException):
                    # Both carry a deliberate status code chosen by the
                    # endpoint; let them propagate instead of masking them
                    # as a generic 500 below.
                    raise

                except Exception as e:
                    logger.error(
                        f"Error in base endpoint {func.__name__}() - {str(e)}",
                        exc_info=True,
                    )
                    raise HTTPException(
                        status_code=500,
                        detail={
                            "message": f"An error '{e}' occurred during {func.__name__}",
                            "error": str(e),
                            "error_type": type(e).__name__,
                        },
                    ) from e

        return wrapper

    @classmethod
    def build_router(cls, engine):
        """
        Class method for building a router instance (if you have a standard pattern).

        NOTE(review): this calls ``cls(engine)`` with a single argument, while
        ``BaseRouterV3.__init__`` requires ``(providers, services)``. It can
        only succeed for subclasses whose constructor accepts one argument;
        confirm the intended usage.
        """
        return cls(engine).router

    def _register_workflows(self):
        """Hook called at the end of ``__init__``; does nothing by default."""
        pass

    def _load_openapi_extras(self):
        """Return the value stored in ``self.openapi_extras``; empty by default."""
        return {}

    @abstractmethod
    def _setup_routes(self):
        """
        Subclasses override this to define actual endpoints.
        """
        pass

    def set_rate_limiting(self):
        """
        Adds a yield-based dependency for rate limiting each request.
        Checks the limits, then logs the request if the check passes.

        Two callables are created and stored on the instance:
        ``self.rate_limit_dependency`` (HTTP) and
        ``self.websocket_rate_limit_dependency`` (WebSocket).
        """

        async def rate_limit_dependency(
            request: Request,
            auth_user=Depends(self.providers.auth.auth_wrapper()),
        ):
            """
            1) Fetch the user from the DB (including .limits_overrides).
            2) Pass it to limits_handler.check_limits.
            3) After the endpoint completes, call limits_handler.log_request.

            Raises ``HTTPException`` 404 when the user cannot be found and
            429 when ``check_limits`` raises ``ValueError``.
            """
            # If the user is superuser, skip checks
            if auth_user.is_superuser:
                yield
                return

            user_id = auth_user.id
            route = request.scope["path"]

            # 1) Fetch the user from DB
            user = await self.providers.database.users_handler.get_user_by_id(
                user_id
            )
            if not user:
                raise HTTPException(status_code=404, detail="User not found.")

            # 2) Rate-limit check
            try:
                await self.providers.database.limits_handler.check_limits(
                    user=user, route=route  # Pass the User object
                )
            except ValueError as e:
                # If check_limits raises ValueError -> 429 Too Many Requests
                raise HTTPException(status_code=429, detail=str(e)) from e

            request.state.user_id = user_id
            request.state.route = route

            # 3) Execute the route
            try:
                yield
            finally:
                # 4) Log the request afterwards; the finally block means this
                # also runs when the endpoint raised.
                await self.providers.database.limits_handler.log_request(
                    user_id, route
                )

        async def websocket_rate_limit_dependency(websocket: WebSocket):
            """Placeholder WebSocket limiter.

            No limit check is performed yet: nothing inside the ``try``
            raises, so this currently always returns True and the close
            branch is not reached.
            """
            # Example: if you want to rate-limit websockets similarly
            route = websocket.scope["path"]
            # If you had a user or token, you'd do the same check.
            try:
                # e.g. check_limits(user_id, route)
                return True
            except ValueError:
                await websocket.close(code=4429, reason="Rate limit exceeded")
                return False

        # Attach the dependencies so you can use them in your endpoints
        self.rate_limit_dependency = rate_limit_dependency
        self.websocket_rate_limit_dependency = websocket_rate_limit_dependency
|