auth.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import logging
  2. from abc import ABC, abstractmethod
  3. from typing import TYPE_CHECKING, Optional
  4. from fastapi import Security
  5. from fastapi.security import (
  6. APIKeyHeader,
  7. HTTPAuthorizationCredentials,
  8. HTTPBearer,
  9. )
  10. from ..abstractions import R2RException, Token, TokenData
  11. from ..api.models import User
  12. from .base import Provider, ProviderConfig
  13. from .crypto import CryptoProvider
  14. # from .database import DatabaseProvider
  15. from .email import EmailProvider
  # Module-scoped logger: naming it after the module (instead of using the
  # root logger) lets deployments filter or re-level auth logging on its own.
  # Records still propagate to the root logger's handlers.
  logger = logging.getLogger(__name__)

  if TYPE_CHECKING:
      # Only needed for annotations; presumably guarded to avoid an import
      # cycle with the database package at runtime.
      from core.database import PostgresDatabaseProvider

  # Shared extractor for the optional `X-API-Key` request header.
  # auto_error=False makes a missing header resolve to None instead of
  # raising, so the auth dependency decides how to react.
  api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
  class AuthConfig(ProviderConfig):
      """Settings shared by authentication providers.

      Only the ``"r2r"`` provider name is accepted. Token lifetimes left as
      ``None`` are not interpreted here; the concrete provider decides what
      an unset lifetime means (not visible in this module).
      """

      # Optional secret key; None means it was not configured.
      secret_key: Optional[str] = None
      # When False, requests arriving without any credentials are served as
      # the default admin user instead of being rejected.
      require_authentication: bool = False
      require_email_verification: bool = False
      # Identity used for the default admin user lookup.
      default_admin_email: str = "admin@example.com"
      # Placeholder credential -- override it in any real deployment.
      default_admin_password: str = "change_me_immediately"
      access_token_lifetime_in_minutes: Optional[int] = None
      refresh_token_lifetime_in_days: Optional[int] = None

      @property
      def supported_providers(self) -> list[str]:
          """Names of the auth providers this configuration can drive."""
          return ["r2r"]

      def validate_config(self) -> None:
          """Accept the configuration as-is; no extra checks are performed."""
          return None
  class AuthProvider(Provider, ABC):
      """Abstract base class for authentication providers.

      Stores the crypto, email and database collaborators that concrete
      providers need, declares the token and account-management interface
      they must implement, and supplies ``auth_wrapper``, a FastAPI
      dependency factory that resolves the calling ``User`` from either a
      Bearer token (JWT or ``key_id.raw_api_key``) or an ``X-API-Key``
      header.
      """

      # `Authorization: Bearer ...` extractor; auto_error=False makes a
      # missing or non-Bearer header resolve to None instead of raising.
      security = HTTPBearer(auto_error=False)
      crypto_provider: CryptoProvider
      email_provider: EmailProvider
      database_provider: "PostgresDatabaseProvider"

      def __init__(
          self,
          config: AuthConfig,
          crypto_provider: CryptoProvider,
          database_provider: "PostgresDatabaseProvider",
          email_provider: EmailProvider,
      ):
          """Wire the provider to its configuration and collaborators.

          Args:
              config: Authentication settings; must be an ``AuthConfig``.
              crypto_provider: Used to verify API-key secrets.
              database_provider: Source of user and API-key records.
              email_provider: Stored for concrete providers' email flows.

          Raises:
              ValueError: If ``config`` is not an ``AuthConfig``.
          """
          if not isinstance(config, AuthConfig):
              raise ValueError(
                  "AuthProvider must be initialized with an AuthConfig"
              )
          # Attributes are populated before Provider.__init__ runs, so they
          # are already available to anything it may invoke.
          self.config = config
          self.admin_email = config.default_admin_email
          self.admin_password = config.default_admin_password
          self.crypto_provider = crypto_provider
          self.database_provider = database_provider
          self.email_provider = email_provider
          super().__init__(config)
          # Re-bound after the base initializer to carry the narrower types
          # for type checkers (and to restore them should the base class
          # have rebound these names).
          self.config: AuthConfig = config
          self.database_provider: "PostgresDatabaseProvider" = (
              database_provider
          )

      async def _get_default_admin_user(self) -> User:
          """Look up the user whose email is the configured default admin."""
          return await self.database_provider.users_handler.get_user_by_email(
              self.admin_email
          )

      async def _authenticate_api_key(self, api_key: str) -> Optional[User]:
          """Resolve an active user from a ``key_id.raw_api_key`` credential.

          The credential is split on its FIRST dot only, so the raw secret
          may itself contain dots.

          Returns:
              The owning user when the record exists, the secret verifies
              against the stored hash, and the user is present and active;
              otherwise ``None``. Never raises for a merely invalid key.
          """
          key_id, separator, raw_api_key = api_key.partition(".")
          if not separator:
              # Not in key_id.raw_api_key form -> cannot be an API key.
              return None

          users_handler = self.database_provider.users_handler
          api_key_record = await users_handler.get_api_key_record(key_id)
          if api_key_record is None:
              return None

          if not self.crypto_provider.verify_api_key(
              raw_api_key, api_key_record["hashed_key"]
          ):
              return None

          user = await users_handler.get_user_by_id(api_key_record["user_id"])
          if user is not None and user.is_active:
              return user
          return None

      @abstractmethod
      def create_access_token(self, data: dict) -> str:
          """Build a short-lived access token carrying ``data``."""

      @abstractmethod
      def create_refresh_token(self, data: dict) -> str:
          """Build a refresh token carrying ``data``."""

      @abstractmethod
      async def decode_token(self, token: str) -> TokenData:
          """Decode ``token`` into ``TokenData``.

          ``auth_wrapper`` treats an ``R2RException`` raised here as "not a
          valid JWT" and falls back to API-key handling.
          """

      @abstractmethod
      async def user(self, token: str) -> User:
          """Return the user identified by ``token``."""

      @abstractmethod
      def get_current_active_user(self, current_user: User) -> User:
          """Return ``current_user`` if it may proceed (semantics are
          implementation-defined; presumably rejects inactive users)."""

      @abstractmethod
      async def register(self, email: str, password: str) -> User:
          """Create and return a new user account."""

      @abstractmethod
      async def verify_email(
          self, email: str, verification_code: str
      ) -> dict[str, str]:
          """Confirm ``email`` using ``verification_code``."""

      @abstractmethod
      async def login(self, email: str, password: str) -> dict[str, Token]:
          """Authenticate by email/password and return issued tokens."""

      @abstractmethod
      async def refresh_access_token(
          self, refresh_token: str
      ) -> dict[str, Token]:
          """Exchange ``refresh_token`` for newly issued tokens."""

      def auth_wrapper(
          self,
          public: bool = False,
      ):
          """Build a FastAPI dependency that resolves the calling user.

          Resolution order:
            1. No credentials and (auth not required or ``public``): the
               default admin user.
            2. Bearer token decoded as a JWT, then looked up by email.
            3. Bearer token interpreted as ``key_id.raw_api_key``.
            4. ``X-API-Key`` header interpreted as ``key_id.raw_api_key``.

          Args:
              public: When True, credential-less requests are allowed even if
                  ``config.require_authentication`` is set.

          Returns:
              An async dependency returning the resolved ``User``. It raises
              ``R2RException`` with 401 when no credentials are given (and
              they are required) or none of them validate, and with 400 when
              both a Bearer token and an API key are supplied.
          """

          async def _auth_wrapper(
              auth: Optional[HTTPAuthorizationCredentials] = Security(
                  self.security
              ),
              api_key: Optional[str] = Security(api_key_header),
          ) -> User:
              # NOTE(review): anonymous callers on public (or auth-optional)
              # routes are resolved to the *default admin* user. Confirm that
              # no public route relies on the returned user's privileges.
              if (
                  ((not self.config.require_authentication) or public)
                  and auth is None
                  and api_key is None
              ):
                  return await self._get_default_admin_user()

              if not auth and not api_key:
                  raise R2RException(
                      message="No credentials provided",
                      status_code=401,
                  )

              if auth and api_key:
                  raise R2RException(
                      message="Cannot have both Bearer token and API key",
                      status_code=400,
                  )

              if auth is not None:
                  credentials = auth.credentials

                  # 1. Try the Bearer credentials as a JWT.
                  # NOTE(review): unlike the API-key path, this does not check
                  # `user.is_active`; confirm decode_token or
                  # get_current_active_user enforces it.
                  try:
                      token_data = await self.decode_token(credentials)
                      user = await self.database_provider.users_handler.get_user_by_email(
                          token_data.email
                      )
                      if user is not None:
                          return user
                  except R2RException:
                      # Invalid token for logical reasons; it may still be
                      # an API key, so fall through.
                      pass
                  except Exception as e:
                      # Unexpected decode/lookup failure: best effort, log
                      # and fall through to API-key handling.
                      logger.debug("JWT verification failed: %s", e)

                  # 2. Try the Bearer credentials as an API key.
                  user = await self._authenticate_api_key(credentials)
                  if user is not None:
                      return user

              # 3. Try the X-API-Key header. Only reachable without a Bearer
              # token, because supplying both was rejected above.
              if api_key is not None:
                  user = await self._authenticate_api_key(api_key)
                  if user is not None:
                      return user

              raise R2RException(
                  message="Invalid token or API key",
                  status_code=401,
              )

          return _auth_wrapper

      @abstractmethod
      async def change_password(
          self, user: User, current_password: str, new_password: str
      ) -> dict[str, str]:
          """Change ``user``'s password after checking the current one."""

      @abstractmethod
      async def request_password_reset(self, email: str) -> dict[str, str]:
          """Start a password-reset flow for ``email``."""

      @abstractmethod
      async def confirm_password_reset(
          self, reset_token: str, new_password: str
      ) -> dict[str, str]:
          """Complete a password reset using ``reset_token``."""

      @abstractmethod
      async def logout(self, token: str) -> dict[str, str]:
          """End the session associated with ``token``."""

      @abstractmethod
      async def send_reset_email(self, email: str) -> dict[str, str]:
          """Send a password-reset email to ``email``."""