# r2r_auth.py
import logging
import os
from datetime import datetime, timedelta, timezone

import jwt
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer

from core.base import (
    AuthConfig,
    AuthProvider,
    CollectionResponse,
    CryptoProvider,
    EmailProvider,
    R2RException,
    Token,
    TokenData,
)
from core.base.api.models import User

from ...database.postgres import PostgresDatabaseProvider
  19. DEFAULT_ACCESS_LIFETIME_IN_MINUTES = 3600
  20. DEFAULT_REFRESH_LIFETIME_IN_DAYS = 7
  21. logger = logging.getLogger()
  22. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  23. DEFAULT_R2R_SK = "wNFbczH3QhUVcPALwtWZCPi0lrDlGV3P1DPRVEQCPbM"
  24. class R2RAuthProvider(AuthProvider):
  25. def __init__(
  26. self,
  27. config: AuthConfig,
  28. crypto_provider: CryptoProvider,
  29. database_provider: PostgresDatabaseProvider,
  30. email_provider: EmailProvider,
  31. ):
  32. super().__init__(
  33. config, crypto_provider, database_provider, email_provider
  34. )
  35. self.database_provider: PostgresDatabaseProvider = database_provider
  36. logger.debug(f"Initializing R2RAuthProvider with config: {config}")
  37. self.secret_key = (
  38. config.secret_key or os.getenv("R2R_SECRET_KEY") or DEFAULT_R2R_SK
  39. )
  40. self.access_token_lifetime_in_minutes = (
  41. config.access_token_lifetime_in_minutes
  42. or os.getenv("R2R_ACCESS_LIFE_IN_MINUTES")
  43. )
  44. self.refresh_token_lifetime_in_days = (
  45. config.refresh_token_lifetime_in_days
  46. or os.getenv("R2R_REFRESH_LIFE_IN_MINUTES")
  47. )
  48. self.config: AuthConfig = config
  49. async def initialize(self):
  50. try:
  51. user = await self.register(
  52. email=self.admin_email,
  53. password=self.admin_password,
  54. is_superuser=True,
  55. )
  56. await self.database_provider.users_handler.mark_user_as_superuser(
  57. user.id
  58. )
  59. except R2RException:
  60. logger.info("Default admin user already exists.")
  61. def create_access_token(self, data: dict) -> str:
  62. to_encode = data.copy()
  63. expire = datetime.now(timezone.utc) + timedelta(
  64. minutes=float(
  65. self.access_token_lifetime_in_minutes
  66. or DEFAULT_ACCESS_LIFETIME_IN_MINUTES
  67. )
  68. )
  69. to_encode |= {"exp": expire.timestamp(), "token_type": "access"}
  70. return jwt.encode(to_encode, self.secret_key, algorithm="HS256")
  71. def create_refresh_token(self, data: dict) -> str:
  72. to_encode = data.copy()
  73. expire = datetime.now(timezone.utc) + timedelta(
  74. days=float(
  75. self.refresh_token_lifetime_in_days
  76. or DEFAULT_REFRESH_LIFETIME_IN_DAYS
  77. )
  78. )
  79. to_encode |= {"exp": expire, "token_type": "refresh"}
  80. return jwt.encode(to_encode, self.secret_key, algorithm="HS256")
  81. async def decode_token(self, token: str) -> TokenData:
  82. try:
  83. # First, check if the token is blacklisted
  84. if await self.database_provider.token_handler.is_token_blacklisted(
  85. token
  86. ):
  87. raise R2RException(
  88. status_code=401, message="Token has been invalidated"
  89. )
  90. payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])
  91. email: str = payload.get("sub")
  92. token_type: str = payload.get("token_type")
  93. exp: float = payload.get("exp")
  94. exp_datetime = datetime.fromtimestamp(exp, tz=timezone.utc)
  95. if (
  96. email is None
  97. or token_type is None
  98. or exp is None
  99. or exp_datetime < datetime.now(timezone.utc)
  100. ):
  101. raise R2RException(status_code=401, message="Invalid token")
  102. return TokenData(
  103. email=email, token_type=token_type, exp=exp_datetime
  104. )
  105. except jwt.ExpiredSignatureError as e:
  106. raise R2RException(
  107. status_code=401, message="Token has expired"
  108. ) from e
  109. except jwt.InvalidTokenError as e:
  110. raise R2RException(status_code=401, message="Invalid token") from e
  111. async def user(self, token: str = Depends(oauth2_scheme)) -> User:
  112. token_data = await self.decode_token(token)
  113. if not token_data.email:
  114. raise R2RException(
  115. status_code=401, message="Could not validate credentials"
  116. )
  117. user = await self.database_provider.users_handler.get_user_by_email(
  118. token_data.email
  119. )
  120. if user is None:
  121. raise R2RException(
  122. status_code=401, message="Invalid authentication credentials"
  123. )
  124. return user
  125. def get_current_active_user(
  126. self, current_user: User = Depends(user)
  127. ) -> User:
  128. if not current_user.is_active:
  129. raise R2RException(status_code=400, message="Inactive user")
  130. return current_user
  131. async def register(
  132. self, email: str, password: str, is_superuser: bool = False
  133. ) -> User:
  134. # Create new user and give them a default collection
  135. new_user = await self.database_provider.users_handler.create_user(
  136. email, password, is_superuser
  137. )
  138. default_collection: CollectionResponse = (
  139. await self.database_provider.collections_handler.create_collection(
  140. owner_id=new_user.id,
  141. )
  142. )
  143. await self.database_provider.graphs_handler.create(
  144. collection_id=default_collection.id,
  145. name=default_collection.name,
  146. description=default_collection.description,
  147. )
  148. await self.database_provider.users_handler.add_user_to_collection(
  149. new_user.id, default_collection.id
  150. )
  151. if self.config.require_email_verification:
  152. verification_code = (
  153. self.crypto_provider.generate_verification_code()
  154. )
  155. expiry = datetime.now(timezone.utc) + timedelta(hours=24)
  156. await self.database_provider.users_handler.store_verification_code(
  157. new_user.id, verification_code, expiry
  158. )
  159. new_user.verification_code_expiry = expiry
  160. # Safely get first name, defaulting to email if name is None
  161. first_name = (
  162. new_user.name.split(" ")[0]
  163. if new_user.name
  164. else email.split("@")[0]
  165. )
  166. await self.email_provider.send_verification_email(
  167. new_user.email, verification_code, {"first_name": first_name}
  168. )
  169. else:
  170. expiry = datetime.now(timezone.utc) + timedelta(hours=366 * 10)
  171. # Mark user as verified
  172. await self.database_provider.users_handler.store_verification_code(
  173. new_user.id, str(-1), expiry
  174. )
  175. await self.database_provider.users_handler.mark_user_as_verified(
  176. new_user.id
  177. )
  178. return new_user
  179. async def verify_email(
  180. self, email: str, verification_code: str
  181. ) -> dict[str, str]:
  182. user_id = await self.database_provider.users_handler.get_user_id_by_verification_code(
  183. verification_code
  184. )
  185. if not user_id:
  186. raise R2RException(
  187. status_code=400, message="Invalid or expired verification code"
  188. )
  189. await self.database_provider.users_handler.mark_user_as_verified(
  190. user_id
  191. )
  192. await self.database_provider.users_handler.remove_verification_code(
  193. verification_code
  194. )
  195. return {"message": "Email verified successfully"}
  196. async def login(self, email: str, password: str) -> dict[str, Token]:
  197. logger = logging.getLogger()
  198. logger.debug(f"Attempting login for email: {email}")
  199. user = await self.database_provider.users_handler.get_user_by_email(
  200. email
  201. )
  202. if not user:
  203. logger.warning(f"No user found for email: {email}")
  204. raise R2RException(
  205. status_code=401, message="Incorrect email or password"
  206. )
  207. logger.debug(f"User found: {user}")
  208. if not isinstance(user.hashed_password, str):
  209. logger.error(
  210. f"Invalid hashed_password type: {type(user.hashed_password)}"
  211. )
  212. raise HTTPException(
  213. status_code=500,
  214. detail="Invalid password hash in database",
  215. )
  216. try:
  217. password_verified = self.crypto_provider.verify_password(
  218. password, user.hashed_password
  219. )
  220. except Exception as e:
  221. logger.error(f"Error during password verification: {str(e)}")
  222. raise HTTPException(
  223. status_code=500,
  224. detail="Error during password verification",
  225. ) from e
  226. if not password_verified:
  227. logger.warning(f"Invalid password for user: {email}")
  228. raise R2RException(
  229. status_code=401, message="Incorrect email or password"
  230. )
  231. if not user.is_verified and self.config.require_email_verification:
  232. logger.warning(f"Unverified user attempted login: {email}")
  233. raise R2RException(status_code=401, message="Email not verified")
  234. access_token = self.create_access_token(data={"sub": user.email})
  235. refresh_token = self.create_refresh_token(data={"sub": user.email})
  236. return {
  237. "access_token": Token(token=access_token, token_type="access"),
  238. "refresh_token": Token(token=refresh_token, token_type="refresh"),
  239. }
  240. async def refresh_access_token(
  241. self, refresh_token: str
  242. ) -> dict[str, Token]:
  243. token_data = await self.decode_token(refresh_token)
  244. if token_data.token_type != "refresh":
  245. raise R2RException(
  246. status_code=401, message="Invalid refresh token"
  247. )
  248. # Invalidate the old refresh token and create a new one
  249. await self.database_provider.token_handler.blacklist_token(
  250. refresh_token
  251. )
  252. new_access_token = self.create_access_token(
  253. data={"sub": token_data.email}
  254. )
  255. new_refresh_token = self.create_refresh_token(
  256. data={"sub": token_data.email}
  257. )
  258. return {
  259. "access_token": Token(token=new_access_token, token_type="access"),
  260. "refresh_token": Token(
  261. token=new_refresh_token, token_type="refresh"
  262. ),
  263. }
  264. async def change_password(
  265. self, user: User, current_password: str, new_password: str
  266. ) -> dict[str, str]:
  267. if not isinstance(user.hashed_password, str):
  268. logger.error(
  269. f"Invalid hashed_password type: {type(user.hashed_password)}"
  270. )
  271. raise HTTPException(
  272. status_code=500,
  273. detail="Invalid password hash in database",
  274. )
  275. if not self.crypto_provider.verify_password(
  276. current_password, user.hashed_password
  277. ):
  278. raise R2RException(
  279. status_code=400, message="Incorrect current password"
  280. )
  281. hashed_new_password = self.crypto_provider.get_password_hash(
  282. new_password
  283. )
  284. await self.database_provider.users_handler.update_user_password(
  285. user.id, hashed_new_password
  286. )
  287. return {"message": "Password changed successfully"}
  288. async def request_password_reset(self, email: str) -> dict[str, str]:
  289. user = await self.database_provider.users_handler.get_user_by_email(
  290. email
  291. )
  292. if not user:
  293. # To prevent email enumeration, always return a success message
  294. return {
  295. "message": "If the email exists, a reset link has been sent"
  296. }
  297. reset_token = self.crypto_provider.generate_verification_code()
  298. expiry = datetime.now(timezone.utc) + timedelta(hours=1)
  299. await self.database_provider.users_handler.store_reset_token(
  300. user.id, reset_token, expiry
  301. )
  302. # Safely get first name, defaulting to email if name is None
  303. first_name = (
  304. user.name.split(" ")[0] if user.name else email.split("@")[0]
  305. )
  306. await self.email_provider.send_password_reset_email(
  307. email, reset_token, {"first_name": first_name}
  308. )
  309. return {"message": "If the email exists, a reset link has been sent"}
  310. async def confirm_password_reset(
  311. self, reset_token: str, new_password: str
  312. ) -> dict[str, str]:
  313. user_id = await self.database_provider.users_handler.get_user_id_by_reset_token(
  314. reset_token
  315. )
  316. if not user_id:
  317. raise R2RException(
  318. status_code=400, message="Invalid or expired reset token"
  319. )
  320. hashed_new_password = self.crypto_provider.get_password_hash(
  321. new_password
  322. )
  323. await self.database_provider.users_handler.update_user_password(
  324. user_id, hashed_new_password
  325. )
  326. await self.database_provider.users_handler.remove_reset_token(user_id)
  327. return {"message": "Password reset successfully"}
  328. async def logout(self, token: str) -> dict[str, str]:
  329. # Add the token to a blacklist
  330. await self.database_provider.token_handler.blacklist_token(token)
  331. return {"message": "Logged out successfully"}
  332. async def clean_expired_blacklisted_tokens(self):
  333. await self.database_provider.token_handler.clean_expired_blacklisted_tokens()
  334. async def send_reset_email(self, email: str) -> dict:
  335. user = await self.database_provider.users_handler.get_user_by_email(
  336. email
  337. )
  338. if not user:
  339. raise R2RException(status_code=404, message="User not found")
  340. # Generate new verification code
  341. verification_code = self.crypto_provider.generate_verification_code()
  342. expiry = datetime.now(timezone.utc) + timedelta(hours=24)
  343. # Store the verification code
  344. await self.database_provider.users_handler.store_verification_code(
  345. user.id,
  346. verification_code,
  347. expiry,
  348. )
  349. # Safely get first name, defaulting to email if name is None
  350. first_name = (
  351. user.name.split(" ")[0] if user.name else email.split("@")[0]
  352. )
  353. # Send verification email
  354. await self.email_provider.send_verification_email(
  355. email, verification_code, {"first_name": first_name}
  356. )
  357. return {
  358. "verification_code": verification_code,
  359. "expiry": expiry,
  360. "message": f"Verification email sent successfully to {email}",
  361. }