auth_service.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. from datetime import datetime
  2. from typing import Optional
  3. from uuid import UUID
  4. from core.base import R2RException, RunManager, Token
  5. from core.base.api.models import User
  6. from core.telemetry.telemetry_decorator import telemetry_event
  7. from core.utils import generate_default_user_collection_id
  8. from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
  9. from ..config import R2RConfig
  10. from .base import Service
  11. class AuthService(Service):
  12. def __init__(
  13. self,
  14. config: R2RConfig,
  15. providers: R2RProviders,
  16. pipes: R2RPipes,
  17. pipelines: R2RPipelines,
  18. agents: R2RAgents,
  19. run_manager: RunManager,
  20. ):
  21. super().__init__(
  22. config,
  23. providers,
  24. pipes,
  25. pipelines,
  26. agents,
  27. run_manager,
  28. )
  29. @telemetry_event("RegisterUser")
  30. async def register(self, email: str, password: str) -> User:
  31. return await self.providers.auth.register(email, password)
  32. @telemetry_event("VerifyEmail")
  33. async def verify_email(
  34. self, email: str, verification_code: str
  35. ) -> dict[str, str]:
  36. if not self.config.auth.require_email_verification:
  37. raise R2RException(
  38. status_code=400, message="Email verification is not required"
  39. )
  40. user_id = await self.providers.database.users_handler.get_user_id_by_verification_code(
  41. verification_code
  42. )
  43. if not user_id:
  44. raise R2RException(
  45. status_code=400, message="Invalid or expired verification code"
  46. )
  47. user = await self.providers.database.users_handler.get_user_by_id(
  48. user_id
  49. )
  50. if not user or user.email != email:
  51. raise R2RException(
  52. status_code=400, message="Invalid or expired verification code"
  53. )
  54. await self.providers.database.users_handler.mark_user_as_verified(
  55. user_id
  56. )
  57. await self.providers.database.users_handler.remove_verification_code(
  58. verification_code
  59. )
  60. return {"message": f"User account {user_id} verified successfully."}
  61. @telemetry_event("Login")
  62. async def login(self, email: str, password: str) -> dict[str, Token]:
  63. return await self.providers.auth.login(email, password)
  64. @telemetry_event("GetCurrentUser")
  65. async def user(self, token: str) -> User:
  66. token_data = await self.providers.auth.decode_token(token)
  67. if not token_data.email:
  68. raise R2RException(
  69. status_code=401, message="Invalid authentication credentials"
  70. )
  71. user = await self.providers.database.users_handler.get_user_by_email(
  72. token_data.email
  73. )
  74. if user is None:
  75. raise R2RException(
  76. status_code=401, message="Invalid authentication credentials"
  77. )
  78. return user
  79. @telemetry_event("RefreshToken")
  80. async def refresh_access_token(
  81. self, refresh_token: str
  82. ) -> dict[str, Token]:
  83. return await self.providers.auth.refresh_access_token(refresh_token)
  84. @telemetry_event("ChangePassword")
  85. async def change_password(
  86. self, user: User, current_password: str, new_password: str
  87. ) -> dict[str, str]:
  88. if not user:
  89. raise R2RException(status_code=404, message="User not found")
  90. return await self.providers.auth.change_password(
  91. user, current_password, new_password
  92. )
  93. @telemetry_event("RequestPasswordReset")
  94. async def request_password_reset(self, email: str) -> dict[str, str]:
  95. return await self.providers.auth.request_password_reset(email)
  96. @telemetry_event("ConfirmPasswordReset")
  97. async def confirm_password_reset(
  98. self, reset_token: str, new_password: str
  99. ) -> dict[str, str]:
  100. return await self.providers.auth.confirm_password_reset(
  101. reset_token, new_password
  102. )
  103. @telemetry_event("Logout")
  104. async def logout(self, token: str) -> dict[str, str]:
  105. return await self.providers.auth.logout(token)
  106. @telemetry_event("UpdateUserProfile")
  107. async def update_user(
  108. self,
  109. user_id: UUID,
  110. email: Optional[str] = None,
  111. is_superuser: Optional[bool] = None,
  112. name: Optional[str] = None,
  113. bio: Optional[str] = None,
  114. profile_picture: Optional[str] = None,
  115. ) -> User:
  116. user: User = (
  117. await self.providers.database.users_handler.get_user_by_id(user_id)
  118. )
  119. if not user:
  120. raise R2RException(status_code=404, message="User not found")
  121. if email is not None:
  122. user.email = email
  123. if is_superuser is not None:
  124. user.is_superuser = is_superuser
  125. if name is not None:
  126. user.name = name
  127. if bio is not None:
  128. user.bio = bio
  129. if profile_picture is not None:
  130. user.profile_picture = profile_picture
  131. return await self.providers.database.users_handler.update_user(user)
  132. @telemetry_event("DeleteUserAccount")
  133. async def delete_user(
  134. self,
  135. user_id: UUID,
  136. password: Optional[str] = None,
  137. delete_vector_data: bool = False,
  138. is_superuser: bool = False,
  139. ) -> dict[str, str]:
  140. user = await self.providers.database.users_handler.get_user_by_id(
  141. user_id
  142. )
  143. if not user:
  144. raise R2RException(status_code=404, message="User not found")
  145. if not is_superuser and not password:
  146. raise R2RException(
  147. status_code=422, message="Password is required for deletion"
  148. )
  149. if not (
  150. is_superuser
  151. or (
  152. user.hashed_password is not None
  153. and self.providers.auth.crypto_provider.verify_password(
  154. password, user.hashed_password # type: ignore
  155. )
  156. )
  157. ):
  158. raise R2RException(status_code=400, message="Incorrect password")
  159. await self.providers.database.users_handler.delete_user_relational(
  160. user_id
  161. )
  162. # Delete user's default collection
  163. # TODO: We need to better define what happens to the user's data when they are deleted
  164. collection_id = generate_default_user_collection_id(user_id)
  165. await self.providers.database.collections_handler.delete_collection_relational(
  166. collection_id
  167. )
  168. try:
  169. await self.providers.database.graphs_handler.delete_graph_for_collection(
  170. collection_id=collection_id,
  171. )
  172. except Exception as e:
  173. # print(f"Error deleting graph for collection {collection_id}: {e}")
  174. pass
  175. if delete_vector_data:
  176. await self.providers.database.chunks_handler.delete_user_vector(
  177. user_id
  178. )
  179. await self.providers.database.chunks_handler.delete_collection_vector(
  180. collection_id
  181. )
  182. return {"message": f"User account {user_id} deleted successfully."}
  183. @telemetry_event("CleanExpiredBlacklistedTokens")
  184. async def clean_expired_blacklisted_tokens(
  185. self,
  186. max_age_hours: int = 7 * 24,
  187. current_time: Optional[datetime] = None,
  188. ):
  189. await self.providers.database.token_handler.clean_expired_blacklisted_tokens(
  190. max_age_hours, current_time
  191. )
  192. @telemetry_event("GetUserVerificationCode")
  193. async def get_user_verification_code(
  194. self,
  195. user_id: UUID,
  196. ) -> dict:
  197. """
  198. Get only the verification code data for a specific user.
  199. This method should be called after superuser authorization has been verified.
  200. """
  201. verification_data = await self.providers.database.users_handler.get_user_validation_data(
  202. user_id=user_id
  203. )
  204. return {
  205. "verification_code": verification_data["verification_data"][
  206. "verification_code"
  207. ],
  208. "expiry": verification_data["verification_data"][
  209. "verification_code_expiry"
  210. ],
  211. }
  212. @telemetry_event("GetUserVerificationCode")
  213. async def get_user_reset_token(
  214. self,
  215. user_id: UUID,
  216. ) -> dict:
  217. """
  218. Get only the verification code data for a specific user.
  219. This method should be called after superuser authorization has been verified.
  220. """
  221. verification_data = await self.providers.database.users_handler.get_user_validation_data(
  222. user_id=user_id
  223. )
  224. return {
  225. "reset_token": verification_data["verification_data"][
  226. "reset_token"
  227. ],
  228. "expiry": verification_data["verification_data"][
  229. "reset_token_expiry"
  230. ],
  231. }
  232. @telemetry_event("SendResetEmail")
  233. async def send_reset_email(self, email: str) -> dict:
  234. """
  235. Generate a new verification code and send a reset email to the user.
  236. Returns the verification code for testing/sandbox environments.
  237. Args:
  238. email (str): The email address of the user
  239. Returns:
  240. dict: Contains verification_code and message
  241. """
  242. return await self.providers.auth.send_reset_email(email)