auth_service.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. import logging
  2. from datetime import datetime
  3. from typing import Optional
  4. from uuid import UUID
  5. from core.base import R2RException, RunManager, Token
  6. from core.base.api.models import User
  7. from core.telemetry.telemetry_decorator import telemetry_event
  8. from core.utils import generate_default_user_collection_id
  9. from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
  10. from ..config import R2RConfig
  11. from .base import Service
  12. logger = logging.getLogger()
  13. class AuthService(Service):
  14. def __init__(
  15. self,
  16. config: R2RConfig,
  17. providers: R2RProviders,
  18. pipes: R2RPipes,
  19. pipelines: R2RPipelines,
  20. agents: R2RAgents,
  21. run_manager: RunManager,
  22. ):
  23. super().__init__(
  24. config,
  25. providers,
  26. pipes,
  27. pipelines,
  28. agents,
  29. run_manager,
  30. )
  31. @telemetry_event("RegisterUser")
  32. async def register(self, email: str, password: str) -> User:
  33. return await self.providers.auth.register(email, password)
  34. @telemetry_event("VerifyEmail")
  35. async def verify_email(
  36. self, email: str, verification_code: str
  37. ) -> dict[str, str]:
  38. if not self.config.auth.require_email_verification:
  39. raise R2RException(
  40. status_code=400, message="Email verification is not required"
  41. )
  42. user_id = await self.providers.database.users_handler.get_user_id_by_verification_code(
  43. verification_code
  44. )
  45. user = await self.providers.database.users_handler.get_user_by_id(
  46. user_id
  47. )
  48. if not user or user.email != email:
  49. raise R2RException(
  50. status_code=400, message="Invalid or expired verification code"
  51. )
  52. await self.providers.database.users_handler.mark_user_as_verified(
  53. user_id
  54. )
  55. await self.providers.database.users_handler.remove_verification_code(
  56. verification_code
  57. )
  58. return {"message": f"User account {user_id} verified successfully."}
  59. @telemetry_event("Login")
  60. async def login(self, email: str, password: str) -> dict[str, Token]:
  61. return await self.providers.auth.login(email, password)
  62. @telemetry_event("GetCurrentUser")
  63. async def user(self, token: str) -> User:
  64. token_data = await self.providers.auth.decode_token(token)
  65. if not token_data.email:
  66. raise R2RException(
  67. status_code=401, message="Invalid authentication credentials"
  68. )
  69. user = await self.providers.database.users_handler.get_user_by_email(
  70. token_data.email
  71. )
  72. if user is None:
  73. raise R2RException(
  74. status_code=401, message="Invalid authentication credentials"
  75. )
  76. return user
  77. @telemetry_event("RefreshToken")
  78. async def refresh_access_token(
  79. self, refresh_token: str
  80. ) -> dict[str, Token]:
  81. return await self.providers.auth.refresh_access_token(refresh_token)
  82. @telemetry_event("ChangePassword")
  83. async def change_password(
  84. self, user: User, current_password: str, new_password: str
  85. ) -> dict[str, str]:
  86. if not user:
  87. raise R2RException(status_code=404, message="User not found")
  88. return await self.providers.auth.change_password(
  89. user, current_password, new_password
  90. )
  91. @telemetry_event("RequestPasswordReset")
  92. async def request_password_reset(self, email: str) -> dict[str, str]:
  93. return await self.providers.auth.request_password_reset(email)
  94. @telemetry_event("ConfirmPasswordReset")
  95. async def confirm_password_reset(
  96. self, reset_token: str, new_password: str
  97. ) -> dict[str, str]:
  98. return await self.providers.auth.confirm_password_reset(
  99. reset_token, new_password
  100. )
  101. @telemetry_event("Logout")
  102. async def logout(self, token: str) -> dict[str, str]:
  103. return await self.providers.auth.logout(token)
  104. @telemetry_event("UpdateUserProfile")
  105. async def update_user(
  106. self,
  107. user_id: UUID,
  108. email: Optional[str] = None,
  109. is_superuser: Optional[bool] = None,
  110. name: Optional[str] = None,
  111. bio: Optional[str] = None,
  112. profile_picture: Optional[str] = None,
  113. ) -> User:
  114. user: User = (
  115. await self.providers.database.users_handler.get_user_by_id(user_id)
  116. )
  117. if not user:
  118. raise R2RException(status_code=404, message="User not found")
  119. if email is not None:
  120. user.email = email
  121. if is_superuser is not None:
  122. user.is_superuser = is_superuser
  123. if name is not None:
  124. user.name = name
  125. if bio is not None:
  126. user.bio = bio
  127. if profile_picture is not None:
  128. user.profile_picture = profile_picture
  129. return await self.providers.database.users_handler.update_user(user)
  130. @telemetry_event("DeleteUserAccount")
  131. async def delete_user(
  132. self,
  133. user_id: UUID,
  134. password: Optional[str] = None,
  135. delete_vector_data: bool = False,
  136. is_superuser: bool = False,
  137. ) -> dict[str, str]:
  138. user = await self.providers.database.users_handler.get_user_by_id(
  139. user_id
  140. )
  141. if not user:
  142. raise R2RException(status_code=404, message="User not found")
  143. if not is_superuser and not password:
  144. raise R2RException(
  145. status_code=422, message="Password is required for deletion"
  146. )
  147. if not (
  148. is_superuser
  149. or (
  150. user.hashed_password is not None
  151. and self.providers.auth.crypto_provider.verify_password(
  152. password, user.hashed_password # type: ignore
  153. )
  154. )
  155. ):
  156. raise R2RException(status_code=400, message="Incorrect password")
  157. await self.providers.database.users_handler.delete_user_relational(
  158. user_id
  159. )
  160. # Delete user's default collection
  161. # TODO: We need to better define what happens to the user's data when they are deleted
  162. collection_id = generate_default_user_collection_id(user_id)
  163. await self.providers.database.collections_handler.delete_collection_relational(
  164. collection_id
  165. )
  166. try:
  167. await self.providers.database.graphs_handler.delete(
  168. collection_id=collection_id,
  169. )
  170. except Exception as e:
  171. logger.warning(
  172. f"Error deleting graph for collection {collection_id}: {e}"
  173. )
  174. if delete_vector_data:
  175. await self.providers.database.chunks_handler.delete_user_vector(
  176. user_id
  177. )
  178. await self.providers.database.chunks_handler.delete_collection_vector(
  179. collection_id
  180. )
  181. return {"message": f"User account {user_id} deleted successfully."}
  182. @telemetry_event("CleanExpiredBlacklistedTokens")
  183. async def clean_expired_blacklisted_tokens(
  184. self,
  185. max_age_hours: int = 7 * 24,
  186. current_time: Optional[datetime] = None,
  187. ):
  188. await self.providers.database.token_handler.clean_expired_blacklisted_tokens(
  189. max_age_hours, current_time
  190. )
  191. @telemetry_event("GetUserVerificationCode")
  192. async def get_user_verification_code(
  193. self,
  194. user_id: UUID,
  195. ) -> dict:
  196. """
  197. Get only the verification code data for a specific user.
  198. This method should be called after superuser authorization has been verified.
  199. """
  200. verification_data = await self.providers.database.users_handler.get_user_validation_data(
  201. user_id=user_id
  202. )
  203. return {
  204. "verification_code": verification_data["verification_data"][
  205. "verification_code"
  206. ],
  207. "expiry": verification_data["verification_data"][
  208. "verification_code_expiry"
  209. ],
  210. }
  211. @telemetry_event("GetUserVerificationCode")
  212. async def get_user_reset_token(
  213. self,
  214. user_id: UUID,
  215. ) -> dict:
  216. """
  217. Get only the verification code data for a specific user.
  218. This method should be called after superuser authorization has been verified.
  219. """
  220. verification_data = await self.providers.database.users_handler.get_user_validation_data(
  221. user_id=user_id
  222. )
  223. return {
  224. "reset_token": verification_data["verification_data"][
  225. "reset_token"
  226. ],
  227. "expiry": verification_data["verification_data"][
  228. "reset_token_expiry"
  229. ],
  230. }
  231. @telemetry_event("SendResetEmail")
  232. async def send_reset_email(self, email: str) -> dict:
  233. """
  234. Generate a new verification code and send a reset email to the user.
  235. Returns the verification code for testing/sandbox environments.
  236. Args:
  237. email (str): The email address of the user
  238. Returns:
  239. dict: Contains verification_code and message
  240. """
  241. return await self.providers.auth.send_reset_email(email)
  242. async def create_user_api_key(self, user_id: UUID) -> dict:
  243. """
  244. Generate a new API key for the user.
  245. Args:
  246. user_id (UUID): The ID of the user
  247. Returns:
  248. dict: Contains the API key and message
  249. """
  250. return await self.providers.auth.create_user_api_key(user_id)
  251. async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> dict:
  252. """
  253. Delete the API key for the user.
  254. Args:
  255. user_id (UUID): The ID of the user
  256. key_id (str): The ID of the API key
  257. Returns:
  258. dict: Contains the message
  259. """
  260. return await self.providers.auth.delete_user_api_key(
  261. user_id=user_id, key_id=key_id
  262. )
  263. async def list_user_api_keys(self, user_id: UUID) -> dict:
  264. """
  265. List all API keys for the user.
  266. Args:
  267. user_id (UUID): The ID of the user
  268. Returns:
  269. dict: Contains the list of API keys
  270. """
  271. return await self.providers.auth.list_user_api_keys(user_id)