auth_service.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. import logging
  2. from datetime import datetime
  3. from typing import Optional
  4. from uuid import UUID
  5. from core.base import R2RException, RunManager, Token
  6. from core.base.api.models import User
  7. from core.telemetry.telemetry_decorator import telemetry_event
  8. from core.utils import generate_default_user_collection_id
  9. from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
  10. from ..config import R2RConfig
  11. from .base import Service
  12. logger = logging.getLogger()
  13. class AuthService(Service):
  14. def __init__(
  15. self,
  16. config: R2RConfig,
  17. providers: R2RProviders,
  18. pipes: R2RPipes,
  19. pipelines: R2RPipelines,
  20. agents: R2RAgents,
  21. run_manager: RunManager,
  22. ):
  23. super().__init__(
  24. config,
  25. providers,
  26. pipes,
  27. pipelines,
  28. agents,
  29. run_manager,
  30. )
  31. @telemetry_event("RegisterUser")
  32. async def register(self, email: str, password: str) -> User:
  33. return await self.providers.auth.register(email, password)
  34. @telemetry_event("VerifyEmail")
  35. async def verify_email(
  36. self, email: str, verification_code: str
  37. ) -> dict[str, str]:
  38. if not self.config.auth.require_email_verification:
  39. raise R2RException(
  40. status_code=400, message="Email verification is not required"
  41. )
  42. user_id = await self.providers.database.users_handler.get_user_id_by_verification_code(
  43. verification_code
  44. )
  45. user = await self.providers.database.users_handler.get_user_by_id(
  46. user_id
  47. )
  48. if not user or user.email != email:
  49. raise R2RException(
  50. status_code=400, message="Invalid or expired verification code"
  51. )
  52. await self.providers.database.users_handler.mark_user_as_verified(
  53. user_id
  54. )
  55. await self.providers.database.users_handler.remove_verification_code(
  56. verification_code
  57. )
  58. return {"message": f"User account {user_id} verified successfully."}
  59. @telemetry_event("Login")
  60. async def login(self, email: str, password: str) -> dict[str, Token]:
  61. return await self.providers.auth.login(email, password)
  62. @telemetry_event("GetCurrentUser")
  63. async def user(self, token: str) -> User:
  64. token_data = await self.providers.auth.decode_token(token)
  65. if not token_data.email:
  66. raise R2RException(
  67. status_code=401, message="Invalid authentication credentials"
  68. )
  69. user = await self.providers.database.users_handler.get_user_by_email(
  70. token_data.email
  71. )
  72. if user is None:
  73. raise R2RException(
  74. status_code=401, message="Invalid authentication credentials"
  75. )
  76. return user
  77. @telemetry_event("RefreshToken")
  78. async def refresh_access_token(
  79. self, refresh_token: str
  80. ) -> dict[str, Token]:
  81. return await self.providers.auth.refresh_access_token(refresh_token)
  82. @telemetry_event("ChangePassword")
  83. async def change_password(
  84. self, user: User, current_password: str, new_password: str
  85. ) -> dict[str, str]:
  86. if not user:
  87. raise R2RException(status_code=404, message="User not found")
  88. return await self.providers.auth.change_password(
  89. user, current_password, new_password
  90. )
  91. @telemetry_event("RequestPasswordReset")
  92. async def request_password_reset(self, email: str) -> dict[str, str]:
  93. return await self.providers.auth.request_password_reset(email)
  94. @telemetry_event("ConfirmPasswordReset")
  95. async def confirm_password_reset(
  96. self, reset_token: str, new_password: str
  97. ) -> dict[str, str]:
  98. return await self.providers.auth.confirm_password_reset(
  99. reset_token, new_password
  100. )
  101. @telemetry_event("Logout")
  102. async def logout(self, token: str) -> dict[str, str]:
  103. return await self.providers.auth.logout(token)
  104. @telemetry_event("UpdateUserProfile")
  105. async def update_user(
  106. self,
  107. user_id: UUID,
  108. email: Optional[str] = None,
  109. is_superuser: Optional[bool] = None,
  110. name: Optional[str] = None,
  111. bio: Optional[str] = None,
  112. profile_picture: Optional[str] = None,
  113. limits_overrides: Optional[dict] = None,
  114. ) -> User:
  115. user: User = (
  116. await self.providers.database.users_handler.get_user_by_id(user_id)
  117. )
  118. if not user:
  119. raise R2RException(status_code=404, message="User not found")
  120. if email is not None:
  121. user.email = email
  122. if is_superuser is not None:
  123. user.is_superuser = is_superuser
  124. if name is not None:
  125. user.name = name
  126. if bio is not None:
  127. user.bio = bio
  128. if profile_picture is not None:
  129. user.profile_picture = profile_picture
  130. if limits_overrides is not None:
  131. user.limits_overrides = limits_overrides
  132. return await self.providers.database.users_handler.update_user(user)
  133. @telemetry_event("DeleteUserAccount")
  134. async def delete_user(
  135. self,
  136. user_id: UUID,
  137. password: Optional[str] = None,
  138. delete_vector_data: bool = False,
  139. is_superuser: bool = False,
  140. ) -> dict[str, str]:
  141. user = await self.providers.database.users_handler.get_user_by_id(
  142. user_id
  143. )
  144. if not user:
  145. raise R2RException(status_code=404, message="User not found")
  146. if not is_superuser and not password:
  147. raise R2RException(
  148. status_code=422, message="Password is required for deletion"
  149. )
  150. if not (
  151. is_superuser
  152. or (
  153. user.hashed_password is not None
  154. and self.providers.auth.crypto_provider.verify_password(
  155. password, user.hashed_password # type: ignore
  156. )
  157. )
  158. ):
  159. raise R2RException(status_code=400, message="Incorrect password")
  160. await self.providers.database.users_handler.delete_user_relational(
  161. user_id
  162. )
  163. # Delete user's default collection
  164. # TODO: We need to better define what happens to the user's data when they are deleted
  165. collection_id = generate_default_user_collection_id(user_id)
  166. await self.providers.database.collections_handler.delete_collection_relational(
  167. collection_id
  168. )
  169. try:
  170. await self.providers.database.graphs_handler.delete(
  171. collection_id=collection_id,
  172. )
  173. except Exception as e:
  174. logger.warning(
  175. f"Error deleting graph for collection {collection_id}: {e}"
  176. )
  177. if delete_vector_data:
  178. await self.providers.database.chunks_handler.delete_user_vector(
  179. user_id
  180. )
  181. await self.providers.database.chunks_handler.delete_collection_vector(
  182. collection_id
  183. )
  184. return {"message": f"User account {user_id} deleted successfully."}
  185. @telemetry_event("CleanExpiredBlacklistedTokens")
  186. async def clean_expired_blacklisted_tokens(
  187. self,
  188. max_age_hours: int = 7 * 24,
  189. current_time: Optional[datetime] = None,
  190. ):
  191. await self.providers.database.token_handler.clean_expired_blacklisted_tokens(
  192. max_age_hours, current_time
  193. )
  194. @telemetry_event("GetUserVerificationCode")
  195. async def get_user_verification_code(
  196. self,
  197. user_id: UUID,
  198. ) -> dict:
  199. """
  200. Get only the verification code data for a specific user.
  201. This method should be called after superuser authorization has been verified.
  202. """
  203. verification_data = await self.providers.database.users_handler.get_user_validation_data(
  204. user_id=user_id
  205. )
  206. return {
  207. "verification_code": verification_data["verification_data"][
  208. "verification_code"
  209. ],
  210. "expiry": verification_data["verification_data"][
  211. "verification_code_expiry"
  212. ],
  213. }
  214. @telemetry_event("GetUserVerificationCode")
  215. async def get_user_reset_token(
  216. self,
  217. user_id: UUID,
  218. ) -> dict:
  219. """
  220. Get only the verification code data for a specific user.
  221. This method should be called after superuser authorization has been verified.
  222. """
  223. verification_data = await self.providers.database.users_handler.get_user_validation_data(
  224. user_id=user_id
  225. )
  226. return {
  227. "reset_token": verification_data["verification_data"][
  228. "reset_token"
  229. ],
  230. "expiry": verification_data["verification_data"][
  231. "reset_token_expiry"
  232. ],
  233. }
  234. @telemetry_event("SendResetEmail")
  235. async def send_reset_email(self, email: str) -> dict:
  236. """
  237. Generate a new verification code and send a reset email to the user.
  238. Returns the verification code for testing/sandbox environments.
  239. Args:
  240. email (str): The email address of the user
  241. Returns:
  242. dict: Contains verification_code and message
  243. """
  244. return await self.providers.auth.send_reset_email(email)
  245. async def create_user_api_key(self, user_id: UUID) -> dict:
  246. """
  247. Generate a new API key for the user.
  248. Args:
  249. user_id (UUID): The ID of the user
  250. Returns:
  251. dict: Contains the API key and message
  252. """
  253. return await self.providers.auth.create_user_api_key(user_id)
  254. async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> dict:
  255. """
  256. Delete the API key for the user.
  257. Args:
  258. user_id (UUID): The ID of the user
  259. key_id (str): The ID of the API key
  260. Returns:
  261. dict: Contains the message
  262. """
  263. return await self.providers.auth.delete_user_api_key(
  264. user_id=user_id, key_id=key_id
  265. )
  266. async def list_user_api_keys(self, user_id: UUID) -> dict:
  267. """
  268. List all API keys for the user.
  269. Args:
  270. user_id (UUID): The ID of the user
  271. Returns:
  272. dict: Contains the list of API keys
  273. """
  274. return await self.providers.auth.list_user_api_keys(user_id)