tokens.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667
  1. from datetime import datetime, timedelta
  2. from typing import Optional
  3. from core.base import Handler
  4. from .base import PostgresConnectionManager
  5. class PostgresTokensHandler(Handler):
  6. TABLE_NAME = "blacklisted_tokens"
  7. def __init__(
  8. self, project_name: str, connection_manager: PostgresConnectionManager
  9. ):
  10. super().__init__(project_name, connection_manager)
  11. async def create_tables(self):
  12. query = f"""
  13. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (
  14. id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
  15. token TEXT NOT NULL,
  16. blacklisted_at TIMESTAMPTZ DEFAULT NOW()
  17. );
  18. CREATE INDEX IF NOT EXISTS idx_{self.project_name}_{PostgresTokensHandler.TABLE_NAME}_token
  19. ON {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (token);
  20. CREATE INDEX IF NOT EXISTS idx_{self.project_name}_{PostgresTokensHandler.TABLE_NAME}_blacklisted_at
  21. ON {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (blacklisted_at);
  22. """
  23. await self.connection_manager.execute_query(query)
  24. async def blacklist_token(
  25. self, token: str, current_time: Optional[datetime] = None
  26. ):
  27. if current_time is None:
  28. current_time = datetime.utcnow()
  29. query = f"""
  30. INSERT INTO {self._get_table_name(PostgresTokensHandler.TABLE_NAME)} (token, blacklisted_at)
  31. VALUES ($1, $2)
  32. """
  33. await self.connection_manager.execute_query(
  34. query, [token, current_time]
  35. )
  36. async def is_token_blacklisted(self, token: str) -> bool:
  37. query = f"""
  38. SELECT 1 FROM {self._get_table_name(PostgresTokensHandler.TABLE_NAME)}
  39. WHERE token = $1
  40. LIMIT 1
  41. """
  42. result = await self.connection_manager.fetchrow_query(query, [token])
  43. return bool(result)
  44. async def clean_expired_blacklisted_tokens(
  45. self,
  46. max_age_hours: int = 7 * 24,
  47. current_time: Optional[datetime] = None,
  48. ):
  49. if current_time is None:
  50. current_time = datetime.utcnow()
  51. expiry_time = current_time - timedelta(hours=max_age_hours)
  52. query = f"""
  53. DELETE FROM {self._get_table_name(PostgresTokensHandler.TABLE_NAME)}
  54. WHERE blacklisted_at < $1
  55. """
  56. await self.connection_manager.execute_query(query, [expiry_time])