123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- import logging
- from datetime import datetime, timedelta, timezone
- from typing import Optional
- from uuid import UUID
- from core.base import Handler, R2RException
- from .base import PostgresConnectionManager
- logger = logging.getLogger()
class PostgresLimitsHandler(Handler):
    """Per-user request rate limiting backed by a Postgres request log.

    Each request is stored as one ``(time, user_id, route)`` row. Limits are
    enforced by counting a user's rows whose ``time`` falls inside a window
    (the last minute, or the current UTC calendar month).
    """

    TABLE_NAME = "request_log"

    # Fallback values used for any limit key that the configuration for a
    # route does not supply (or when the route has no configuration at all).
    DEFAULT_LIMITS = {
        "global_per_min": 60,
        "route_per_min": 30,
        "monthly_limit": 10000,
    }

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        route_limits: dict,
    ):
        """Store the per-route limit configuration.

        Args:
            project_name: Forwarded unchanged to the ``Handler`` base class.
            connection_manager: Forwarded unchanged to the ``Handler`` base
                class; queries are issued through ``self.connection_manager``.
            route_limits: Mapping of route -> mapping that may contain the
                keys ``global_per_min``, ``route_per_min`` and
                ``monthly_limit``.
        """
        super().__init__(project_name, connection_manager)
        self.route_limits = route_limits

    async def create_tables(self):
        """Create the request log table if it does not already exist."""
        table = self._get_table_name(PostgresLimitsHandler.TABLE_NAME)
        query = f"""
        CREATE TABLE IF NOT EXISTS {table} (
            time TIMESTAMPTZ NOT NULL,
            user_id UUID NOT NULL,
            route TEXT NOT NULL
        );
        """
        await self.connection_manager.execute_query(query)

    async def _count_requests(
        self, user_id: UUID, route: Optional[str], since: datetime
    ) -> int:
        """Count logged requests for ``user_id`` with ``time >= since``.

        Args:
            user_id: User whose requests are counted.
            route: When truthy, only rows for this route are counted; when
                ``None`` (or empty) rows for every route are counted.
            since: Inclusive lower bound on the row timestamp.

        Returns:
            The number of matching rows, or 0 if the query returned no row.
        """
        table = self._get_table_name(PostgresLimitsHandler.TABLE_NAME)
        if route:
            query = f"""
            SELECT COUNT(*)::int
            FROM {table}
            WHERE user_id = $1
              AND route = $2
              AND time >= $3
            """
            params = [user_id, route, since]
        else:
            query = f"""
            SELECT COUNT(*)::int
            FROM {table}
            WHERE user_id = $1
              AND time >= $2
            """
            params = [user_id, since]

        result = await self.connection_manager.fetchrow_query(query, params)
        return result["count"] if result else 0

    async def _count_monthly_requests(self, user_id: UUID) -> int:
        """Count the user's requests, on any route, since the start of the
        current calendar month (computed in UTC)."""
        now = datetime.now(timezone.utc)
        start_of_month = now.replace(
            day=1, hour=0, minute=0, second=0, microsecond=0
        )
        return await self._count_requests(
            user_id, route=None, since=start_of_month
        )

    async def check_limits(self, user_id: UUID, route: str):
        """Raise if ``user_id`` has reached any limit that applies to ``route``.

        The checks run in this order: global per-minute (all routes),
        per-route per-minute, then monthly (all routes). This method only
        reads the request log; it does not insert a row for the request.

        Args:
            user_id: User making the request.
            route: Route being requested; selects the limits from
                ``self.route_limits``.

        Raises:
            ValueError: If any of the three counts has reached its limit.
        """
        # Merge so that a route configured with only some of the keys falls
        # back to the defaults for the rest instead of raising KeyError.
        limits = {
            **PostgresLimitsHandler.DEFAULT_LIMITS,
            **(self.route_limits.get(route) or {}),
        }
        global_per_min = limits["global_per_min"]
        route_per_min = limits["route_per_min"]
        monthly_limit = limits["monthly_limit"]

        now = datetime.now(timezone.utc)
        one_min_ago = now - timedelta(minutes=1)

        # Global per-minute check
        user_req_count = await self._count_requests(user_id, None, one_min_ago)
        logger.debug("Requests in the last minute: %s", user_req_count)
        if user_req_count >= global_per_min:
            raise ValueError("Global per-minute rate limit exceeded")

        # Per-route per-minute check
        route_req_count = await self._count_requests(
            user_id, route, one_min_ago
        )
        if route_req_count >= route_per_min:
            raise ValueError("Per-route per-minute rate limit exceeded")

        # Monthly limit check
        monthly_count = await self._count_monthly_requests(user_id)
        logger.debug("Requests this month: %s", monthly_count)
        if monthly_count >= monthly_limit:
            raise ValueError("Monthly rate limit exceeded")

    async def log_request(self, user_id: UUID, route: str):
        """Insert one request log row for ``user_id`` and ``route``, stamped
        with the database's current time."""
        table = self._get_table_name(PostgresLimitsHandler.TABLE_NAME)
        # CURRENT_TIMESTAMP is already a timestamptz (an absolute instant).
        # Applying AT TIME ZONE 'UTC' would strip the zone, and the resulting
        # naive value would be re-interpreted in the session TimeZone when
        # stored in the TIMESTAMPTZ column, shifting it on non-UTC sessions.
        query = f"""
        INSERT INTO {table} (time, user_id, route)
        VALUES (CURRENT_TIMESTAMP, $1, $2)
        """
        await self.connection_manager.execute_query(query, [user_id, route])
- # import logging
- # from datetime import datetime, timedelta
- # from typing import Optional
- # from uuid import UUID
- # from core.base import Handler, R2RException
- # from .base import PostgresConnectionManager
- # logger = logging.getLogger()
- # class PostgresLimitsHandler(Handler):
- # TABLE_NAME = "request_log"
- # def __init__(
- # self,
- # project_name: str,
- # connection_manager: PostgresConnectionManager,
- # route_limits: dict,
- # ):
- # super().__init__(project_name, connection_manager)
- # self.route_limits = route_limits
- # async def create_tables(self):
- # """
- # Create the request_log table if it doesn't exist.
- # """
- # query = f"""
- # CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
- # time TIMESTAMPTZ NOT NULL,
- # user_id UUID NOT NULL,
- # route TEXT NOT NULL
- # );
- # """
- # await self.connection_manager.execute_query(query)
- # async def _count_requests(
- # self, user_id: UUID, route: Optional[str], since: datetime
- # ) -> int:
- # if route:
- # query = f"""
- # SELECT COUNT(*)::int
- # FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
- # WHERE user_id = $1
- # AND route = $2
- # AND time >= $3
- # """
- # params = [user_id, route, since]
- # else:
- # query = f"""
- # SELECT COUNT(*)::int
- # FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
- # WHERE user_id = $1
- # AND time >= $2
- # """
- # params = [user_id, since]
- # result = await self.connection_manager.fetchrow_query(query, params)
- # return result["count"] if result else 0
- # async def _count_monthly_requests(self, user_id: UUID) -> int:
- # now = datetime.utcnow()
- # start_of_month = now.replace(
- # day=1, hour=0, minute=0, second=0, microsecond=0
- # )
- # return await self._count_requests(
- # user_id, route=None, since=start_of_month
- # )
- # async def check_limits(self, user_id: UUID, route: str):
- # """
- # Check if the user can proceed with the request, using route-specific limits.
- # Raises ValueError if the user exceeded any limit.
- # """
- # limits = self.route_limits.get(
- # route,
- # {
- # "global_per_min": 60, # default global per min
- # "route_per_min": 20, # default route per min
- # "monthly_limit": 10000, # default monthly limit
- # },
- # )
- # global_per_min = limits["global_per_min"]
- # route_per_min = limits["route_per_min"]
- # monthly_limit = limits["monthly_limit"]
- # now = datetime.utcnow()
- # one_min_ago = now - timedelta(minutes=1)
- # # Global per-minute check
- # user_req_count = await self._count_requests(user_id, None, one_min_ago)
- # print('min req count = ', user_req_count)
- # if user_req_count >= global_per_min:
- # raise ValueError("Global per-minute rate limit exceeded")
- # # Per-route per-minute check
- # route_req_count = await self._count_requests(
- # user_id, route, one_min_ago
- # )
- # if route_req_count >= route_per_min:
- # raise ValueError("Per-route per-minute rate limit exceeded")
- # # Monthly limit check
- # monthly_count = await self._count_monthly_requests(user_id)
- # print('monthly_count = ', monthly_count)
- # if monthly_count >= monthly_limit:
- # raise ValueError("Monthly rate limit exceeded")
- # async def log_request(self, user_id: UUID, route: str):
- # query = f"""
- # INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (time, user_id, route)
- # VALUES (NOW(), $1, $2)
- # """
- # await self.connection_manager.execute_query(query, [user_id, route])
|