auth_service.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. import logging
  2. from datetime import datetime
  3. from typing import Optional
  4. from uuid import UUID
  5. from core.base import R2RException, Token
  6. from core.base.api.models import User
  7. from core.utils import generate_default_user_collection_id
  8. from ..abstractions import R2RProviders
  9. from ..config import R2RConfig
  10. from .base import Service
  # Module-scoped logger: records carry this module's name and still propagate
  # to handlers configured on the root logger.
  logger = logging.getLogger(__name__)
  class AuthService(Service):
      """Authentication and user-management service.

      Most methods delegate to ``self.providers.auth`` or to the database
      handlers on ``self.providers.database``. This class adds request-level
      checks (user existence, password confirmation, config gating) around
      those calls.
      """

      def __init__(
          self,
          config: R2RConfig,
          providers: R2RProviders,
      ):
          """Initialize the service.

          Args:
              config (R2RConfig): Application configuration; ``config.auth`` is
                  consulted by :meth:`verify_email`.
              providers (R2RProviders): Provider bundle used for auth and
                  database access.
          """
          super().__init__(
              config,
              providers,
          )

      async def register(
          self,
          email: str,
          password: str,
          is_verified: bool = False,
          name: Optional[str] = None,
          bio: Optional[str] = None,
          profile_picture: Optional[str] = None,
      ) -> User:
          """Register a new user through the auth provider.

          Args:
              email (str): The new user's email address.
              password (str): The new user's plain-text password.
              is_verified (bool): Whether the account starts out verified.
              name (Optional[str]): Optional display name.
              bio (Optional[str]): Optional biography.
              profile_picture (Optional[str]): Optional profile picture value.

          Returns:
              User: The user returned by the auth provider.
          """
          return await self.providers.auth.register(
              email=email,
              password=password,
              is_verified=is_verified,
              name=name,
              bio=bio,
              profile_picture=profile_picture,
          )

      async def send_verification_email(
          self, email: str
      ) -> tuple[str, datetime]:
          """Ask the auth provider to send a verification email to ``email``.

          Returns:
              tuple[str, datetime]: The provider's result. Presumably the
              verification code and its expiry; confirm against the provider.
          """
          return await self.providers.auth.send_verification_email(email=email)

      async def verify_email(
          self, email: str, verification_code: str
      ) -> dict[str, str]:
          """Mark the user owning ``verification_code`` as verified.

          Args:
              email (str): Email the caller claims the code belongs to.
              verification_code (str): The code to redeem.

          Returns:
              dict[str, str]: A success message containing the user id.

          Raises:
              R2RException: 400 if email verification is disabled in config, or
                  if the code does not resolve to a user with this email.
          """
          if not self.config.auth.require_email_verification:
              raise R2RException(
                  status_code=400, message="Email verification is not required"
              )

          users = self.providers.database.users_handler
          user_id = await users.get_user_id_by_verification_code(
              verification_code
          )
          # Never look a user up by a missing id; treat it as a bad code.
          if user_id is None:
              raise R2RException(
                  status_code=400, message="Invalid or expired verification code"
              )

          user = await users.get_user_by_id(user_id)
          # The email match stops a valid code from verifying another account.
          if not user or user.email != email:
              raise R2RException(
                  status_code=400, message="Invalid or expired verification code"
              )

          await users.mark_user_as_verified(user_id)
          # Remove the code once it has been redeemed.
          await users.remove_verification_code(verification_code)
          return {"message": f"User account {user_id} verified successfully."}

      async def login(self, email: str, password: str) -> dict[str, Token]:
          """Authenticate with email and password via the auth provider.

          Returns:
              dict[str, Token]: The tokens returned by the auth provider.
          """
          return await self.providers.auth.login(email, password)

      async def user(self, token: str) -> User:
          """Resolve the user identified by the email claim of ``token``.

          Raises:
              R2RException: 401 if the decoded token has no email or no user
                  exists with that email.
          """
          token_data = await self.providers.auth.decode_token(token)
          if not token_data.email:
              raise R2RException(
                  status_code=401, message="Invalid authentication credentials"
              )
          user = await self.providers.database.users_handler.get_user_by_email(
              token_data.email
          )
          if user is None:
              raise R2RException(
                  status_code=401, message="Invalid authentication credentials"
              )
          return user

      async def refresh_access_token(
          self, refresh_token: str
      ) -> dict[str, Token]:
          """Exchange ``refresh_token`` for new tokens via the auth provider."""
          return await self.providers.auth.refresh_access_token(refresh_token)

      async def change_password(
          self, user: User, current_password: str, new_password: str
      ) -> dict[str, str]:
          """Change ``user``'s password via the auth provider.

          Raises:
              R2RException: 404 if ``user`` is falsy.
          """
          if not user:
              raise R2RException(status_code=404, message="User not found")
          return await self.providers.auth.change_password(
              user, current_password, new_password
          )

      async def request_password_reset(self, email: str) -> dict[str, str]:
          """Start a password reset for ``email`` via the auth provider."""
          return await self.providers.auth.request_password_reset(email)

      async def confirm_password_reset(
          self, reset_token: str, new_password: str
      ) -> dict[str, str]:
          """Complete a password reset using ``reset_token`` via the auth provider."""
          return await self.providers.auth.confirm_password_reset(
              reset_token, new_password
          )

      async def logout(self, token: str) -> dict[str, str]:
          """Log out the session identified by ``token`` via the auth provider."""
          return await self.providers.auth.logout(token)

      async def update_user(
          self,
          user_id: UUID,
          email: Optional[str] = None,
          is_superuser: Optional[bool] = None,
          name: Optional[str] = None,
          bio: Optional[str] = None,
          profile_picture: Optional[str] = None,
          limits_overrides: Optional[dict] = None,
          merge_limits: bool = False,
          new_metadata: Optional[dict] = None,
      ) -> User:
          """Apply the given field changes to a user and persist them.

          Only arguments that are not ``None`` overwrite the stored values.
          No format or uniqueness validation of ``email`` happens here.

          Args:
              user_id (UUID): The user to update.
              email, is_superuser, name, bio, profile_picture, limits_overrides:
                  New values for the corresponding user fields.
              merge_limits (bool): Forwarded to the users handler.
              new_metadata (Optional[dict]): Forwarded to the users handler.

          Returns:
              User: The user returned by the users handler after the update.

          Raises:
              R2RException: 404 if the user does not exist.
          """
          users = self.providers.database.users_handler
          user: User = await users.get_user_by_id(user_id)
          if not user:
              raise R2RException(status_code=404, message="User not found")

          if email is not None:
              user.email = email
          if is_superuser is not None:
              user.is_superuser = is_superuser
          if name is not None:
              user.name = name
          if bio is not None:
              user.bio = bio
          if profile_picture is not None:
              user.profile_picture = profile_picture
          if limits_overrides is not None:
              user.limits_overrides = limits_overrides

          return await users.update_user(
              user, merge_limits=merge_limits, new_metadata=new_metadata
          )

      async def delete_user(
          self,
          user_id: UUID,
          password: Optional[str] = None,
          delete_vector_data: bool = False,
          is_superuser: bool = False,
      ) -> dict[str, str]:
          """Delete a user, their default collection and, optionally, vectors.

          Args:
              user_id (UUID): The user to delete.
              password (Optional[str]): The user's password. Required unless
                  ``is_superuser`` is true.
              delete_vector_data (bool): Also delete the user's chunk vectors
                  and those of the user's default collection.
              is_superuser (bool): Skip password confirmation.

          Returns:
              dict[str, str]: A success message containing the user id.

          Raises:
              R2RException: 404 if the user does not exist, 422 if a
                  non-superuser omits the password, 400 if the password does
                  not match.
          """
          users = self.providers.database.users_handler
          user = await users.get_user_by_id(user_id)
          if not user:
              raise R2RException(status_code=404, message="User not found")

          # Non-superusers must confirm the deletion with their password.
          if not is_superuser and not password:
              raise R2RException(
                  status_code=422, message="Password is required for deletion"
              )
          if not (
              is_superuser
              or (
                  user.hashed_password is not None
                  and password is not None
                  and self.providers.auth.crypto_provider.verify_password(
                      plain_password=password,
                      hashed_password=user.hashed_password,
                  )
              )
          ):
              raise R2RException(status_code=400, message="Incorrect password")

          await users.delete_user_relational(user_id)

          # Delete user's default collection
          # TODO: We need to better define what happens to the user's data when they are deleted
          # NOTE(review): the user row is already deleted here and the next call
          # is unguarded; if it raises, the graph/vector cleanup below is skipped.
          collection_id = generate_default_user_collection_id(user_id)
          await self.providers.database.collections_handler.delete_collection_relational(
              collection_id
          )
          try:
              await self.providers.database.graphs_handler.delete(
                  collection_id=collection_id,
              )
          except Exception as e:
              # Best-effort: a failed graph delete must not abort user deletion.
              logger.warning(
                  "Error deleting graph for collection %s: %s", collection_id, e
              )

          if delete_vector_data:
              chunks = self.providers.database.chunks_handler
              await chunks.delete_user_vector(user_id)
              await chunks.delete_collection_vector(collection_id)

          return {"message": f"User account {user_id} deleted successfully."}

      async def clean_expired_blacklisted_tokens(
          self,
          max_age_hours: int = 7 * 24,
          current_time: Optional[datetime] = None,
      ) -> None:
          """Purge blacklisted tokens older than ``max_age_hours``.

          Args:
              max_age_hours (int): Age threshold in hours (default one week).
              current_time (Optional[datetime]): Reference time forwarded to the
                  token handler. Presumably "now" is used when ``None``;
                  confirm in the handler.
          """
          await self.providers.database.token_handler.clean_expired_blacklisted_tokens(
              max_age_hours, current_time
          )

      async def get_user_verification_code(
          self,
          user_id: UUID,
      ) -> dict:
          """Get only the email-verification code data for a specific user.

          This method should be called after superuser authorization has been
          verified.

          Returns:
              dict: ``verification_code`` and its ``expiry``.

          Raises:
              KeyError: If the handler's payload lacks the expected keys.
          """
          validation_data = await self.providers.database.users_handler.get_user_validation_data(
              user_id=user_id
          )
          fields = validation_data["verification_data"]
          return {
              "verification_code": fields["verification_code"],
              "expiry": fields["verification_code_expiry"],
          }

      async def get_user_reset_token(
          self,
          user_id: UUID,
      ) -> dict:
          """Get only the password-reset token data for a specific user.

          This method should be called after superuser authorization has been
          verified.

          Returns:
              dict: ``reset_token`` and its ``expiry``.

          Raises:
              KeyError: If the handler's payload lacks the expected keys.
          """
          validation_data = await self.providers.database.users_handler.get_user_validation_data(
              user_id=user_id
          )
          fields = validation_data["verification_data"]
          return {
              "reset_token": fields["reset_token"],
              "expiry": fields["reset_token_expiry"],
          }

      async def send_reset_email(self, email: str) -> dict:
          """Ask the auth provider to send a password-reset email.

          Args:
              email (str): The email address of the user.

          Returns:
              dict: The provider's response. Expected to contain
              ``verification_code`` (for testing/sandbox environments) and
              ``message``; confirm against the provider.
          """
          return await self.providers.auth.send_reset_email(email)

      async def create_user_api_key(
          self, user_id: UUID, name: Optional[str], description: Optional[str]
      ) -> dict:
          """Generate a new API key for the user.

          Args:
              user_id (UUID): The ID of the user.
              name (Optional[str]): Name of the API key.
              description (Optional[str]): Description of the API key.

          Returns:
              dict: The auth provider's response (contains the API key and a
              message).
          """
          return await self.providers.auth.create_user_api_key(
              user_id=user_id, name=name, description=description
          )

      async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
          """Delete one of the user's API keys.

          Args:
              user_id (UUID): The ID of the user.
              key_id (UUID): The ID of the API key.

          Returns:
              bool: True if the API key was deleted successfully.
          """
          return await self.providers.auth.delete_user_api_key(
              user_id=user_id, key_id=key_id
          )

      async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
          """List all API keys for the user.

          Args:
              user_id (UUID): The ID of the user.

          Returns:
              list[dict]: One entry per API key, as returned by the auth
              provider.
          """
          return await self.providers.auth.list_user_api_keys(user_id)