auth.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. import logging
  2. from abc import ABC, abstractmethod
  3. from typing import TYPE_CHECKING, Optional
  4. from fastapi import Security
  5. from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
  6. from ..abstractions import R2RException, Token, TokenData
  7. from ..api.models import User
  8. from .base import Provider, ProviderConfig
  9. from .crypto import CryptoProvider
  10. # from .database import DatabaseProvider
  11. from .email import EmailProvider
  12. logger = logging.getLogger()
  13. if TYPE_CHECKING:
  14. from core.database import PostgresDatabaseProvider
  15. class AuthConfig(ProviderConfig):
  16. secret_key: Optional[str] = None
  17. require_authentication: bool = False
  18. require_email_verification: bool = False
  19. default_admin_email: str = "admin@example.com"
  20. default_admin_password: str = "change_me_immediately"
  21. access_token_lifetime_in_minutes: Optional[int] = None
  22. refresh_token_lifetime_in_days: Optional[int] = None
  23. @property
  24. def supported_providers(self) -> list[str]:
  25. return ["r2r"]
  26. def validate_config(self) -> None:
  27. pass
  28. class AuthProvider(Provider, ABC):
  29. security = HTTPBearer(auto_error=False)
  30. crypto_provider: CryptoProvider
  31. email_provider: EmailProvider
  32. database_provider: "PostgresDatabaseProvider"
  33. def __init__(
  34. self,
  35. config: AuthConfig,
  36. crypto_provider: CryptoProvider,
  37. database_provider: "PostgresDatabaseProvider",
  38. email_provider: EmailProvider,
  39. ):
  40. if not isinstance(config, AuthConfig):
  41. raise ValueError(
  42. "AuthProvider must be initialized with an AuthConfig"
  43. )
  44. self.config = config
  45. self.admin_email = config.default_admin_email
  46. self.admin_password = config.default_admin_password
  47. self.crypto_provider = crypto_provider
  48. self.database_provider = database_provider
  49. self.email_provider = email_provider
  50. super().__init__(config)
  51. self.config: AuthConfig = config # for type hinting
  52. self.database_provider: "PostgresDatabaseProvider" = (
  53. database_provider # for type hinting
  54. )
  55. async def _get_default_admin_user(self) -> User:
  56. return await self.database_provider.users_handler.get_user_by_email(
  57. self.admin_email
  58. )
  59. @abstractmethod
  60. def create_access_token(self, data: dict) -> str:
  61. pass
  62. @abstractmethod
  63. def create_refresh_token(self, data: dict) -> str:
  64. pass
  65. @abstractmethod
  66. async def decode_token(self, token: str) -> TokenData:
  67. pass
  68. @abstractmethod
  69. async def user(self, token: str) -> User:
  70. pass
  71. @abstractmethod
  72. def get_current_active_user(self, current_user: User) -> User:
  73. pass
  74. @abstractmethod
  75. async def register(self, email: str, password: str) -> User:
  76. pass
  77. @abstractmethod
  78. async def verify_email(
  79. self, email: str, verification_code: str
  80. ) -> dict[str, str]:
  81. pass
  82. @abstractmethod
  83. async def login(self, email: str, password: str) -> dict[str, Token]:
  84. pass
  85. @abstractmethod
  86. async def refresh_access_token(
  87. self, refresh_token: str
  88. ) -> dict[str, Token]:
  89. pass
  90. async def auth_wrapper(
  91. self, auth: Optional[HTTPAuthorizationCredentials] = Security(security)
  92. ) -> User:
  93. if not self.config.require_authentication and auth is None:
  94. return await self._get_default_admin_user()
  95. if auth is None:
  96. raise R2RException(
  97. message="Authentication required.",
  98. status_code=401,
  99. )
  100. try:
  101. return await self.user(auth.credentials)
  102. except Exception as e:
  103. raise R2RException(
  104. message=f"Error '{e}' occurred during authentication.",
  105. status_code=404,
  106. )
  107. @abstractmethod
  108. async def change_password(
  109. self, user: User, current_password: str, new_password: str
  110. ) -> dict[str, str]:
  111. pass
  112. @abstractmethod
  113. async def request_password_reset(self, email: str) -> dict[str, str]:
  114. pass
  115. @abstractmethod
  116. async def confirm_password_reset(
  117. self, reset_token: str, new_password: str
  118. ) -> dict[str, str]:
  119. pass
  120. @abstractmethod
  121. async def logout(self, token: str) -> dict[str, str]:
  122. pass
  123. @abstractmethod
  124. async def send_reset_email(self, email: str) -> dict[str, str]:
  125. pass