r2r_auth.py 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
  1. import logging
  2. import os
  3. from datetime import datetime, timedelta, timezone
  4. from typing import Optional
  5. from uuid import UUID
  6. from fastapi import Depends, HTTPException
  7. from fastapi.security import OAuth2PasswordBearer
  8. from core.base import (
  9. AuthConfig,
  10. AuthProvider,
  11. CollectionResponse,
  12. CryptoProvider,
  13. EmailProvider,
  14. R2RException,
  15. Token,
  16. TokenData,
  17. )
  18. from core.base.api.models import User
  19. from ..database import PostgresDatabaseProvider
  20. DEFAULT_ACCESS_LIFETIME_IN_MINUTES = 3600
  21. DEFAULT_REFRESH_LIFETIME_IN_DAYS = 7
  22. logger = logging.getLogger()
  23. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  24. def normalize_email(email: str) -> str:
  25. """Normalizes an email address by converting it to lowercase. This ensures
  26. consistent email handling throughout the application.
  27. Args:
  28. email: The email address to normalize
  29. Returns:
  30. The normalized (lowercase) email address
  31. """
  32. return email.lower() if email else ""
  33. class R2RAuthProvider(AuthProvider):
  34. def __init__(
  35. self,
  36. config: AuthConfig,
  37. crypto_provider: CryptoProvider,
  38. database_provider: PostgresDatabaseProvider,
  39. email_provider: EmailProvider,
  40. ):
  41. super().__init__(
  42. config, crypto_provider, database_provider, email_provider
  43. )
  44. self.database_provider: PostgresDatabaseProvider = database_provider
  45. logger.debug(f"Initializing R2RAuthProvider with config: {config}")
  46. # We no longer use a local secret_key or defaults here.
  47. # All key handling is done in the crypto_provider.
  48. self.access_token_lifetime_in_minutes = (
  49. config.access_token_lifetime_in_minutes
  50. or os.getenv("R2R_ACCESS_LIFE_IN_MINUTES")
  51. or DEFAULT_ACCESS_LIFETIME_IN_MINUTES
  52. )
  53. self.refresh_token_lifetime_in_days = (
  54. config.refresh_token_lifetime_in_days
  55. or os.getenv("R2R_REFRESH_LIFE_IN_DAYS")
  56. or DEFAULT_REFRESH_LIFETIME_IN_DAYS
  57. )
  58. self.config: AuthConfig = config
  59. async def initialize(self):
  60. try:
  61. user = await self.register(
  62. email=normalize_email(self.admin_email),
  63. password=self.admin_password,
  64. is_superuser=True,
  65. )
  66. await self.database_provider.users_handler.mark_user_as_superuser(
  67. id=user.id
  68. )
  69. except R2RException:
  70. logger.info("Default admin user already exists.")
  71. def create_access_token(self, data: dict) -> str:
  72. expire = datetime.now(timezone.utc) + timedelta(
  73. minutes=float(self.access_token_lifetime_in_minutes)
  74. )
  75. # Add token_type and pass data/expiry to crypto_provider
  76. data_with_type = {**data, "token_type": "access"}
  77. return self.crypto_provider.generate_secure_token(
  78. data=data_with_type,
  79. expiry=expire,
  80. )
  81. def create_refresh_token(self, data: dict) -> str:
  82. expire = datetime.now(timezone.utc) + timedelta(
  83. days=float(self.refresh_token_lifetime_in_days)
  84. )
  85. data_with_type = {**data, "token_type": "refresh"}
  86. return self.crypto_provider.generate_secure_token(
  87. data=data_with_type,
  88. expiry=expire,
  89. )
  90. async def decode_token(self, token: str) -> TokenData:
  91. if "token=" in token:
  92. token = token.split("token=")[1]
  93. if "&tokenType=refresh" in token:
  94. token = token.split("&tokenType=refresh")[0]
  95. # First, check if the token is blacklisted
  96. if await self.database_provider.token_handler.is_token_blacklisted(
  97. token=token
  98. ):
  99. raise R2RException(
  100. status_code=401, message="Token has been invalidated"
  101. )
  102. # Verify token using crypto_provider
  103. payload = self.crypto_provider.verify_secure_token(token=token)
  104. if payload is None:
  105. raise R2RException(
  106. status_code=401, message="Invalid or expired token"
  107. )
  108. email = payload.get("sub")
  109. token_type = payload.get("token_type")
  110. exp = payload.get("exp")
  111. if email is None or token_type is None or exp is None:
  112. raise R2RException(status_code=401, message="Invalid token claims")
  113. email_str: str = email
  114. token_type_str: str = token_type
  115. exp_float: float = exp
  116. exp_datetime = datetime.fromtimestamp(exp_float, tz=timezone.utc)
  117. if exp_datetime < datetime.now(timezone.utc):
  118. raise R2RException(status_code=401, message="Token has expired")
  119. return TokenData(
  120. email=normalize_email(email_str),
  121. token_type=token_type_str,
  122. exp=exp_datetime,
  123. )
  124. async def authenticate_api_key(self, api_key: str) -> User:
  125. """Authenticate using an API key of the form "public_key.raw_key".
  126. Returns a User if successful, or raises R2RException if not.
  127. """
  128. try:
  129. key_id, raw_key = api_key.split(".", 1)
  130. except ValueError as e:
  131. raise R2RException(
  132. status_code=401, message="Invalid API key format"
  133. ) from e
  134. key_record = (
  135. await self.database_provider.users_handler.get_api_key_record(
  136. key_id=key_id
  137. )
  138. )
  139. if not key_record:
  140. raise R2RException(status_code=401, message="Invalid API key")
  141. if not self.crypto_provider.verify_api_key(
  142. raw_api_key=raw_key, hashed_key=key_record["hashed_key"]
  143. ):
  144. raise R2RException(status_code=401, message="Invalid API key")
  145. user = await self.database_provider.users_handler.get_user_by_id(
  146. id=key_record["user_id"]
  147. )
  148. if not user.is_active:
  149. raise R2RException(
  150. status_code=401, message="User account is inactive"
  151. )
  152. return user
  153. async def user(self, token: str = Depends(oauth2_scheme)) -> User:
  154. """Attempt to authenticate via JWT first, then fallback to API key."""
  155. # Try JWT auth
  156. try:
  157. token_data = await self.decode_token(token=token)
  158. if not token_data.email:
  159. raise R2RException(
  160. status_code=401, message="Could not validate credentials"
  161. )
  162. user = (
  163. await self.database_provider.users_handler.get_user_by_email(
  164. email=normalize_email(token_data.email)
  165. )
  166. )
  167. if user is None:
  168. raise R2RException(
  169. status_code=401,
  170. message="Invalid authentication credentials",
  171. )
  172. return user
  173. except R2RException:
  174. # If JWT fails, try API key auth
  175. # OAuth2PasswordBearer provides token as "Bearer xxx", strip it if needed
  176. token = token.removeprefix("Bearer ")
  177. return await self.authenticate_api_key(api_key=token)
  178. def get_current_active_user(
  179. self, current_user: User = Depends(user)
  180. ) -> User:
  181. if not current_user.is_active:
  182. raise R2RException(status_code=400, message="Inactive user")
  183. return current_user
  184. async def register(
  185. self,
  186. email: str,
  187. password: Optional[str] = None,
  188. is_superuser: bool = False,
  189. is_verified: bool = False,
  190. account_type: str = "password",
  191. github_id: Optional[str] = None,
  192. google_id: Optional[str] = None,
  193. name: Optional[str] = None,
  194. bio: Optional[str] = None,
  195. profile_picture: Optional[str] = None,
  196. ) -> User:
  197. if account_type == "password":
  198. if not password:
  199. raise R2RException(
  200. status_code=400,
  201. message="Password is required for password accounts",
  202. )
  203. else:
  204. if github_id and google_id:
  205. raise R2RException(
  206. status_code=400,
  207. message="Cannot register OAuth with both GitHub and Google IDs",
  208. )
  209. if not github_id and not google_id:
  210. raise R2RException(
  211. status_code=400,
  212. message="Invalid OAuth specification without GitHub or Google ID",
  213. )
  214. new_user = await self.database_provider.users_handler.create_user(
  215. email=normalize_email(email),
  216. password=password,
  217. is_superuser=is_superuser,
  218. is_verified=is_verified,
  219. account_type=account_type,
  220. github_id=github_id,
  221. google_id=google_id,
  222. name=name,
  223. bio=bio,
  224. profile_picture=profile_picture,
  225. )
  226. default_collection: CollectionResponse = (
  227. await self.database_provider.collections_handler.create_collection(
  228. owner_id=new_user.id,
  229. )
  230. )
  231. await self.database_provider.graphs_handler.create(
  232. collection_id=default_collection.id,
  233. name=default_collection.name,
  234. description=default_collection.description,
  235. )
  236. await self.database_provider.users_handler.add_user_to_collection(
  237. new_user.id, default_collection.id
  238. )
  239. new_user = await self.database_provider.users_handler.get_user_by_id(
  240. new_user.id
  241. )
  242. if self.config.require_email_verification and not is_verified:
  243. verification_code, _ = await self.send_verification_email(
  244. email=normalize_email(email), user=new_user
  245. )
  246. return new_user
  247. async def send_verification_email(
  248. self, email: str, user: Optional[User] = None
  249. ) -> tuple[str, datetime]:
  250. if user is None:
  251. user = (
  252. await self.database_provider.users_handler.get_user_by_email(
  253. email=normalize_email(email)
  254. )
  255. )
  256. if not user:
  257. raise R2RException(status_code=404, message="User not found")
  258. verification_code = self.crypto_provider.generate_verification_code()
  259. expiry = datetime.now(timezone.utc) + timedelta(hours=24)
  260. await self.database_provider.users_handler.store_verification_code(
  261. id=user.id,
  262. verification_code=verification_code,
  263. expiry=expiry,
  264. )
  265. if hasattr(user, "verification_code_expiry"):
  266. user.verification_code_expiry = expiry
  267. first_name = (
  268. user.name.split(" ")[0] if user.name else email.split("@")[0]
  269. )
  270. await self.email_provider.send_verification_email(
  271. to_email=user.email,
  272. verification_code=verification_code,
  273. dynamic_template_data={"first_name": first_name},
  274. )
  275. return verification_code, expiry
  276. async def verify_email(
  277. self, email: str, verification_code: str
  278. ) -> dict[str, str]:
  279. user_id = await self.database_provider.users_handler.get_user_id_by_verification_code(
  280. verification_code=verification_code
  281. )
  282. await self.database_provider.users_handler.mark_user_as_verified(
  283. id=user_id
  284. )
  285. await self.database_provider.users_handler.remove_verification_code(
  286. verification_code=verification_code
  287. )
  288. return {"message": "Email verified successfully"}
  289. async def login(self, email: str, password: str) -> dict[str, Token]:
  290. logger.debug(f"Attempting login for email: {email}")
  291. user = await self.database_provider.users_handler.get_user_by_email(
  292. email=normalize_email(email)
  293. )
  294. if user.account_type != "password":
  295. logger.warning(
  296. f"Password login not allowed for {user.account_type} accounts: {email}"
  297. )
  298. raise R2RException(
  299. status_code=401,
  300. message=f"This account is configured for {user.account_type} login, not password.",
  301. )
  302. logger.debug(f"User found: {user}")
  303. if not isinstance(user.hashed_password, str):
  304. logger.error(
  305. f"Invalid hashed_password type: {type(user.hashed_password)}"
  306. )
  307. raise HTTPException(
  308. status_code=500,
  309. detail="Invalid password hash in database",
  310. )
  311. try:
  312. password_verified = self.crypto_provider.verify_password(
  313. plain_password=password,
  314. hashed_password=user.hashed_password,
  315. )
  316. except Exception as e:
  317. logger.error(f"Error during password verification: {str(e)}")
  318. raise HTTPException(
  319. status_code=500,
  320. detail="Error during password verification",
  321. ) from e
  322. if not password_verified:
  323. logger.warning(f"Invalid password for user: {email}")
  324. raise R2RException(
  325. status_code=401, message="Incorrect email or password"
  326. )
  327. if not user.is_verified and self.config.require_email_verification:
  328. logger.warning(f"Unverified user attempted login: {email}")
  329. raise R2RException(status_code=401, message="Email not verified")
  330. access_token = self.create_access_token(
  331. data={"sub": normalize_email(user.email)}
  332. )
  333. refresh_token = self.create_refresh_token(
  334. data={"sub": normalize_email(user.email)}
  335. )
  336. return {
  337. "access_token": Token(token=access_token, token_type="access"),
  338. "refresh_token": Token(token=refresh_token, token_type="refresh"),
  339. }
  340. async def refresh_access_token(
  341. self, refresh_token: str
  342. ) -> dict[str, Token]:
  343. token_data = await self.decode_token(refresh_token)
  344. if token_data.token_type != "refresh":
  345. raise R2RException(
  346. status_code=401, message="Invalid refresh token"
  347. )
  348. # Invalidate the old refresh token and create a new one
  349. await self.database_provider.token_handler.blacklist_token(
  350. token=refresh_token
  351. )
  352. new_access_token = self.create_access_token(
  353. data={"sub": normalize_email(token_data.email)}
  354. )
  355. new_refresh_token = self.create_refresh_token(
  356. data={"sub": normalize_email(token_data.email)}
  357. )
  358. return {
  359. "access_token": Token(token=new_access_token, token_type="access"),
  360. "refresh_token": Token(
  361. token=new_refresh_token, token_type="refresh"
  362. ),
  363. }
  364. async def change_password(
  365. self, user: User, current_password: str, new_password: str
  366. ) -> dict[str, str]:
  367. if not isinstance(user.hashed_password, str):
  368. logger.error(
  369. f"Invalid hashed_password type: {type(user.hashed_password)}"
  370. )
  371. raise HTTPException(
  372. status_code=500,
  373. detail="Invalid password hash in database",
  374. )
  375. if not self.crypto_provider.verify_password(
  376. plain_password=current_password,
  377. hashed_password=user.hashed_password,
  378. ):
  379. raise R2RException(
  380. status_code=400, message="Incorrect current password"
  381. )
  382. hashed_new_password = self.crypto_provider.get_password_hash(
  383. password=new_password
  384. )
  385. await self.database_provider.users_handler.update_user_password(
  386. id=user.id,
  387. new_hashed_password=hashed_new_password,
  388. )
  389. try:
  390. await self.email_provider.send_password_changed_email(
  391. to_email=normalize_email(user.email),
  392. dynamic_template_data={
  393. "first_name": (
  394. user.name.split(" ")[0] or "User"
  395. if user.name
  396. else "User"
  397. )
  398. },
  399. )
  400. except Exception as e:
  401. logger.error(
  402. f"Failed to send password change notification: {str(e)}"
  403. )
  404. return {"message": "Password changed successfully"}
  405. async def request_password_reset(self, email: str) -> dict[str, str]:
  406. try:
  407. user = (
  408. await self.database_provider.users_handler.get_user_by_email(
  409. email=normalize_email(email)
  410. )
  411. )
  412. reset_token = self.crypto_provider.generate_verification_code()
  413. expiry = datetime.now(timezone.utc) + timedelta(hours=1)
  414. await self.database_provider.users_handler.store_reset_token(
  415. id=user.id,
  416. reset_token=reset_token,
  417. expiry=expiry,
  418. )
  419. first_name = (
  420. user.name.split(" ")[0] if user.name else email.split("@")[0]
  421. )
  422. await self.email_provider.send_password_reset_email(
  423. to_email=normalize_email(email),
  424. reset_token=reset_token,
  425. dynamic_template_data={"first_name": first_name},
  426. )
  427. return {
  428. "message": "If the email exists, a reset link has been sent"
  429. }
  430. except R2RException as e:
  431. if e.status_code == 404:
  432. # User doesn't exist; return a success message anyway
  433. return {
  434. "message": "If the email exists, a reset link has been sent"
  435. }
  436. else:
  437. raise
  438. async def confirm_password_reset(
  439. self, reset_token: str, new_password: str
  440. ) -> dict[str, str]:
  441. user_id = await self.database_provider.users_handler.get_user_id_by_reset_token(
  442. reset_token=reset_token
  443. )
  444. if not user_id:
  445. raise R2RException(
  446. status_code=400, message="Invalid or expired reset token"
  447. )
  448. hashed_new_password = self.crypto_provider.get_password_hash(
  449. password=new_password
  450. )
  451. await self.database_provider.users_handler.update_user_password(
  452. id=user_id,
  453. new_hashed_password=hashed_new_password,
  454. )
  455. await self.database_provider.users_handler.remove_reset_token(
  456. id=user_id
  457. )
  458. # Get the user information
  459. user = await self.database_provider.users_handler.get_user_by_id(
  460. id=user_id
  461. )
  462. try:
  463. await self.email_provider.send_password_changed_email(
  464. to_email=normalize_email(user.email),
  465. dynamic_template_data={
  466. "first_name": (
  467. user.name.split(" ")[0] or "User"
  468. if user.name
  469. else "User"
  470. )
  471. },
  472. )
  473. except Exception as e:
  474. logger.error(
  475. f"Failed to send password change notification: {str(e)}"
  476. )
  477. return {"message": "Password reset successfully"}
  478. async def logout(self, token: str) -> dict[str, str]:
  479. await self.database_provider.token_handler.blacklist_token(token=token)
  480. return {"message": "Logged out successfully"}
  481. async def clean_expired_blacklisted_tokens(self):
  482. await self.database_provider.token_handler.clean_expired_blacklisted_tokens()
  483. async def send_reset_email(self, email: str) -> dict:
  484. verification_code, expiry = await self.send_verification_email(
  485. email=normalize_email(email)
  486. )
  487. return {
  488. "verification_code": verification_code,
  489. "expiry": expiry,
  490. "message": f"Verification email sent successfully to {email}",
  491. }
  492. async def create_user_api_key(
  493. self,
  494. user_id: UUID,
  495. name: Optional[str] = None,
  496. description: Optional[str] = None,
  497. ) -> dict[str, str]:
  498. key_id, raw_api_key = self.crypto_provider.generate_api_key()
  499. hashed_key = self.crypto_provider.hash_api_key(raw_api_key)
  500. api_key_uuid = (
  501. await self.database_provider.users_handler.store_user_api_key(
  502. user_id=user_id,
  503. key_id=key_id,
  504. hashed_key=hashed_key,
  505. name=name,
  506. description=description,
  507. )
  508. )
  509. return {
  510. "api_key": f"{key_id}.{raw_api_key}",
  511. "key_id": str(api_key_uuid),
  512. "public_key": key_id,
  513. "name": name or "",
  514. }
  515. async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
  516. return await self.database_provider.users_handler.get_user_api_keys(
  517. user_id=user_id
  518. )
  519. async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
  520. return await self.database_provider.users_handler.delete_api_key(
  521. user_id=user_id,
  522. key_id=key_id,
  523. )
  524. async def rename_api_key(
  525. self, user_id: UUID, key_id: UUID, new_name: str
  526. ) -> bool:
  527. return await self.database_provider.users_handler.update_api_key_name(
  528. user_id=user_id,
  529. key_id=key_id,
  530. name=new_name,
  531. )
  532. async def oauth_callback_handler(
  533. self, provider: str, oauth_id: str, email: str
  534. ) -> dict[str, Token]:
  535. """Handles a login/registration flow for OAuth providers (e.g., Google
  536. or GitHub).
  537. :param provider: "google" or "github"
  538. :param oauth_id: The unique ID from the OAuth provider (e.g. Google's
  539. 'sub')
  540. :param email: The user's email from the provider, if available.
  541. :return: dict with access_token and refresh_token
  542. """
  543. # 1) Attempt to find user by google_id or github_id, or by email
  544. # The logic depends on your preference. We'll assume "google" => google_id, etc.
  545. try:
  546. if provider == "google":
  547. try:
  548. user = await self.database_provider.users_handler.get_user_by_email(
  549. normalize_email(email)
  550. )
  551. # If user found, check if user.google_id matches or is null. If null, update it
  552. if user and not user.google_id:
  553. raise R2RException(
  554. status_code=401,
  555. message="User already exists and is not linked to Google account",
  556. )
  557. except Exception:
  558. # Create new user
  559. user = await self.register(
  560. email=normalize_email(email)
  561. or f"{oauth_id}@google_oauth.fake", # fallback
  562. password=None, # no password
  563. account_type="oauth",
  564. google_id=oauth_id,
  565. )
  566. elif provider == "github":
  567. try:
  568. user = await self.database_provider.users_handler.get_user_by_email(
  569. normalize_email(email)
  570. )
  571. # If user found, check if user.google_id matches or is null. If null, update it
  572. if user and not user.github_id:
  573. raise R2RException(
  574. status_code=401,
  575. message="User already exists and is not linked to Github account",
  576. )
  577. except Exception:
  578. # Create new user
  579. user = await self.register(
  580. email=normalize_email(email)
  581. or f"{oauth_id}@github_oauth.fake", # fallback
  582. password=None, # no password
  583. account_type="oauth",
  584. github_id=oauth_id,
  585. )
  586. # else handle other providers
  587. except R2RException:
  588. # If no user found or creation fails
  589. raise R2RException(
  590. status_code=401, message="Could not create or fetch user"
  591. ) from None
  592. # If user is inactive, etc.
  593. if not user.is_active:
  594. raise R2RException(
  595. status_code=401, message="User account is inactive"
  596. )
  597. # Possibly mark user as verified if you trust the OAuth provider's email
  598. user.is_verified = True
  599. await self.database_provider.users_handler.update_user(user)
  600. # 2) Generate tokens
  601. access_token = self.create_access_token(
  602. data={"sub": normalize_email(user.email)}
  603. )
  604. refresh_token = self.create_refresh_token(
  605. data={"sub": normalize_email(user.email)}
  606. )
  607. return {
  608. "access_token": Token(token=access_token, token_type="access"),
  609. "refresh_token": Token(token=refresh_token, token_type="refresh"),
  610. }