# r2r_auth.py
import logging
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
from uuid import UUID

from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer

from core.base import (
    AuthConfig,
    AuthProvider,
    CollectionResponse,
    CryptoProvider,
    EmailProvider,
    R2RException,
    Token,
    TokenData,
)
from core.base.api.models import User

from ...database.postgres import PostgresDatabaseProvider

# Fallback access-token lifetime, in minutes, used when neither the auth
# config nor the R2R_ACCESS_LIFE_IN_MINUTES environment variable supplies one.
# NOTE(review): 3600 minutes is 60 hours -- confirm this is the intended
# default rather than 60 minutes.
DEFAULT_ACCESS_LIFETIME_IN_MINUTES = 3600

# Fallback refresh-token lifetime, in days, used when neither the auth config
# nor the R2R_REFRESH_LIFE_IN_DAYS environment variable supplies one.
DEFAULT_REFRESH_LIFETIME_IN_DAYS = 7

# getLogger() called without a name returns the root logger.
# NOTE(review): logging.getLogger(__name__) is the usual module-level idiom;
# confirm the root logger is intentional here.
logger = logging.getLogger()

# FastAPI bearer-token dependency; used as the default for the `token`
# parameter of R2RAuthProvider.user.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  24. class R2RAuthProvider(AuthProvider):
  25. def __init__(
  26. self,
  27. config: AuthConfig,
  28. crypto_provider: CryptoProvider,
  29. database_provider: PostgresDatabaseProvider,
  30. email_provider: EmailProvider,
  31. ):
  32. super().__init__(
  33. config, crypto_provider, database_provider, email_provider
  34. )
  35. self.database_provider: PostgresDatabaseProvider = database_provider
  36. logger.debug(f"Initializing R2RAuthProvider with config: {config}")
  37. # We no longer use a local secret_key or defaults here.
  38. # All key handling is done in the crypto_provider.
  39. self.access_token_lifetime_in_minutes = (
  40. config.access_token_lifetime_in_minutes
  41. or os.getenv("R2R_ACCESS_LIFE_IN_MINUTES")
  42. or DEFAULT_ACCESS_LIFETIME_IN_MINUTES
  43. )
  44. self.refresh_token_lifetime_in_days = (
  45. config.refresh_token_lifetime_in_days
  46. or os.getenv("R2R_REFRESH_LIFE_IN_DAYS")
  47. or DEFAULT_REFRESH_LIFETIME_IN_DAYS
  48. )
  49. self.config: AuthConfig = config
  50. async def initialize(self):
  51. try:
  52. user = await self.register(
  53. email=self.admin_email,
  54. password=self.admin_password,
  55. is_superuser=True,
  56. )
  57. await self.database_provider.users_handler.mark_user_as_superuser(
  58. id=user.id
  59. )
  60. except R2RException:
  61. logger.info("Default admin user already exists.")
  62. def create_access_token(self, data: dict) -> str:
  63. expire = datetime.now(timezone.utc) + timedelta(
  64. minutes=float(self.access_token_lifetime_in_minutes)
  65. )
  66. # Add token_type and pass data/expiry to crypto_provider
  67. data_with_type = {**data, "token_type": "access"}
  68. return self.crypto_provider.generate_secure_token(
  69. data=data_with_type,
  70. expiry=expire,
  71. )
  72. def create_refresh_token(self, data: dict) -> str:
  73. expire = datetime.now(timezone.utc) + timedelta(
  74. days=float(self.refresh_token_lifetime_in_days)
  75. )
  76. data_with_type = {**data, "token_type": "refresh"}
  77. return self.crypto_provider.generate_secure_token(
  78. data=data_with_type,
  79. expiry=expire,
  80. )
  81. async def decode_token(self, token: str) -> TokenData:
  82. # First, check if the token is blacklisted
  83. if await self.database_provider.token_handler.is_token_blacklisted(
  84. token=token
  85. ):
  86. raise R2RException(
  87. status_code=401, message="Token has been invalidated"
  88. )
  89. # Verify token using crypto_provider
  90. payload = self.crypto_provider.verify_secure_token(token=token)
  91. if payload is None:
  92. raise R2RException(
  93. status_code=401, message="Invalid or expired token"
  94. )
  95. email = payload.get("sub")
  96. token_type = payload.get("token_type")
  97. exp = payload.get("exp")
  98. if email is None or token_type is None or exp is None:
  99. raise R2RException(status_code=401, message="Invalid token claims")
  100. email_str: str = email
  101. token_type_str: str = token_type
  102. exp_float: float = exp
  103. exp_datetime = datetime.fromtimestamp(exp_float, tz=timezone.utc)
  104. if exp_datetime < datetime.now(timezone.utc):
  105. raise R2RException(status_code=401, message="Token has expired")
  106. return TokenData(
  107. email=email_str, token_type=token_type_str, exp=exp_datetime
  108. )
  109. async def authenticate_api_key(self, api_key: str) -> Optional[User]:
  110. """
  111. Authenticate using an API key of the form "public_key.raw_key".
  112. Returns a User if successful, or raises R2RException if not.
  113. """
  114. try:
  115. key_id, raw_key = api_key.split(".", 1)
  116. except ValueError:
  117. raise R2RException(
  118. status_code=401, message="Invalid API key format"
  119. )
  120. key_record = (
  121. await self.database_provider.users_handler.get_api_key_record(
  122. key_id=key_id
  123. )
  124. )
  125. if not key_record:
  126. raise R2RException(status_code=401, message="Invalid API key")
  127. if not self.crypto_provider.verify_api_key(
  128. raw_api_key=raw_key, hashed_key=key_record["hashed_key"]
  129. ):
  130. raise R2RException(status_code=401, message="Invalid API key")
  131. user = await self.database_provider.users_handler.get_user_by_id(
  132. id=key_record["user_id"]
  133. )
  134. if not user.is_active:
  135. raise R2RException(
  136. status_code=401, message="User account is inactive"
  137. )
  138. return user
  139. async def user(self, token: str = Depends(oauth2_scheme)) -> User:
  140. """
  141. Attempt to authenticate via JWT first, then fallback to API key.
  142. """
  143. # Try JWT auth
  144. try:
  145. token_data = await self.decode_token(token=token)
  146. if not token_data.email:
  147. raise R2RException(
  148. status_code=401, message="Could not validate credentials"
  149. )
  150. user = (
  151. await self.database_provider.users_handler.get_user_by_email(
  152. email=token_data.email
  153. )
  154. )
  155. if user is None:
  156. raise R2RException(
  157. status_code=401,
  158. message="Invalid authentication credentials",
  159. )
  160. return user
  161. except R2RException:
  162. # If JWT fails, try API key auth
  163. # OAuth2PasswordBearer provides token as "Bearer xxx", strip it if needed
  164. token = token.removeprefix("Bearer ")
  165. return await self.authenticate_api_key(api_key=token)
  166. def get_current_active_user(
  167. self, current_user: User = Depends(user)
  168. ) -> User:
  169. if not current_user.is_active:
  170. raise R2RException(status_code=400, message="Inactive user")
  171. return current_user
  172. async def register(
  173. self, email: str, password: str, is_superuser: bool = False
  174. ) -> User:
  175. new_user = await self.database_provider.users_handler.create_user(
  176. email=email, password=password, is_superuser=is_superuser
  177. )
  178. default_collection: CollectionResponse = (
  179. await self.database_provider.collections_handler.create_collection(
  180. owner_id=new_user.id,
  181. )
  182. )
  183. graph_result = await self.database_provider.graphs_handler.create(
  184. collection_id=default_collection.id,
  185. name=default_collection.name,
  186. description=default_collection.description,
  187. )
  188. await self.database_provider.users_handler.add_user_to_collection(
  189. new_user.id, default_collection.id
  190. )
  191. if self.config.require_email_verification:
  192. verification_code = (
  193. self.crypto_provider.generate_verification_code()
  194. )
  195. expiry = datetime.now(timezone.utc) + timedelta(hours=24)
  196. await self.database_provider.users_handler.store_verification_code(
  197. id=new_user.id,
  198. verification_code=verification_code,
  199. expiry=expiry,
  200. )
  201. new_user.verification_code_expiry = expiry
  202. first_name = (
  203. new_user.name.split(" ")[0]
  204. if new_user.name
  205. else email.split("@")[0]
  206. )
  207. await self.email_provider.send_verification_email(
  208. new_user.email, verification_code, {"first_name": first_name}
  209. )
  210. else:
  211. expiry = datetime.now(timezone.utc) + timedelta(hours=366 * 10)
  212. await self.database_provider.users_handler.store_verification_code(
  213. id=new_user.id,
  214. verification_code=str(-1),
  215. expiry=expiry,
  216. )
  217. await self.database_provider.users_handler.mark_user_as_verified(
  218. id=new_user.id
  219. )
  220. return new_user
  221. async def verify_email(
  222. self, email: str, verification_code: str
  223. ) -> dict[str, str]:
  224. user_id = await self.database_provider.users_handler.get_user_id_by_verification_code(
  225. verification_code=verification_code
  226. )
  227. await self.database_provider.users_handler.mark_user_as_verified(
  228. id=user_id
  229. )
  230. await self.database_provider.users_handler.remove_verification_code(
  231. verification_code=verification_code
  232. )
  233. return {"message": "Email verified successfully"}
  234. async def login(self, email: str, password: str) -> dict[str, Token]:
  235. logger.debug(f"Attempting login for email: {email}")
  236. user = await self.database_provider.users_handler.get_user_by_email(
  237. email=email
  238. )
  239. logger.debug(f"User found: {user}")
  240. if not isinstance(user.hashed_password, str):
  241. logger.error(
  242. f"Invalid hashed_password type: {type(user.hashed_password)}"
  243. )
  244. raise HTTPException(
  245. status_code=500,
  246. detail="Invalid password hash in database",
  247. )
  248. try:
  249. password_verified = self.crypto_provider.verify_password(
  250. plain_password=password,
  251. hashed_password=user.hashed_password,
  252. )
  253. except Exception as e:
  254. logger.error(f"Error during password verification: {str(e)}")
  255. raise HTTPException(
  256. status_code=500,
  257. detail="Error during password verification",
  258. ) from e
  259. if not password_verified:
  260. logger.warning(f"Invalid password for user: {email}")
  261. raise R2RException(
  262. status_code=401, message="Incorrect email or password"
  263. )
  264. if not user.is_verified and self.config.require_email_verification:
  265. logger.warning(f"Unverified user attempted login: {email}")
  266. raise R2RException(status_code=401, message="Email not verified")
  267. access_token = self.create_access_token(data={"sub": user.email})
  268. refresh_token = self.create_refresh_token(data={"sub": user.email})
  269. return {
  270. "access_token": Token(token=access_token, token_type="access"),
  271. "refresh_token": Token(token=refresh_token, token_type="refresh"),
  272. }
  273. async def refresh_access_token(
  274. self, refresh_token: str
  275. ) -> dict[str, Token]:
  276. token_data = await self.decode_token(refresh_token)
  277. if token_data.token_type != "refresh":
  278. raise R2RException(
  279. status_code=401, message="Invalid refresh token"
  280. )
  281. # Invalidate the old refresh token and create a new one
  282. await self.database_provider.token_handler.blacklist_token(
  283. token=refresh_token
  284. )
  285. new_access_token = self.create_access_token(
  286. data={"sub": token_data.email}
  287. )
  288. new_refresh_token = self.create_refresh_token(
  289. data={"sub": token_data.email}
  290. )
  291. return {
  292. "access_token": Token(token=new_access_token, token_type="access"),
  293. "refresh_token": Token(
  294. token=new_refresh_token, token_type="refresh"
  295. ),
  296. }
  297. async def change_password(
  298. self, user: User, current_password: str, new_password: str
  299. ) -> dict[str, str]:
  300. if not isinstance(user.hashed_password, str):
  301. logger.error(
  302. f"Invalid hashed_password type: {type(user.hashed_password)}"
  303. )
  304. raise HTTPException(
  305. status_code=500,
  306. detail="Invalid password hash in database",
  307. )
  308. if not self.crypto_provider.verify_password(
  309. plain_password=current_password,
  310. hashed_password=user.hashed_password,
  311. ):
  312. raise R2RException(
  313. status_code=400, message="Incorrect current password"
  314. )
  315. hashed_new_password = self.crypto_provider.get_password_hash(
  316. password=new_password
  317. )
  318. await self.database_provider.users_handler.update_user_password(
  319. id=user.id,
  320. new_hashed_password=hashed_new_password,
  321. )
  322. return {"message": "Password changed successfully"}
  323. async def request_password_reset(self, email: str) -> dict[str, str]:
  324. try:
  325. user = (
  326. await self.database_provider.users_handler.get_user_by_email(
  327. email=email
  328. )
  329. )
  330. reset_token = self.crypto_provider.generate_verification_code()
  331. expiry = datetime.now(timezone.utc) + timedelta(hours=1)
  332. await self.database_provider.users_handler.store_reset_token(
  333. id=user.id,
  334. reset_token=reset_token,
  335. expiry=expiry,
  336. )
  337. first_name = (
  338. user.name.split(" ")[0] if user.name else email.split("@")[0]
  339. )
  340. await self.email_provider.send_password_reset_email(
  341. email, reset_token, {"first_name": first_name}
  342. )
  343. return {
  344. "message": "If the email exists, a reset link has been sent"
  345. }
  346. except R2RException as e:
  347. if e.status_code == 404:
  348. # User doesn't exist; return a success message anyway
  349. return {
  350. "message": "If the email exists, a reset link has been sent"
  351. }
  352. else:
  353. raise
  354. async def confirm_password_reset(
  355. self, reset_token: str, new_password: str
  356. ) -> dict[str, str]:
  357. user_id = await self.database_provider.users_handler.get_user_id_by_reset_token(
  358. reset_token=reset_token
  359. )
  360. if not user_id:
  361. raise R2RException(
  362. status_code=400, message="Invalid or expired reset token"
  363. )
  364. hashed_new_password = self.crypto_provider.get_password_hash(
  365. password=new_password
  366. )
  367. await self.database_provider.users_handler.update_user_password(
  368. id=user_id,
  369. new_hashed_password=hashed_new_password,
  370. )
  371. await self.database_provider.users_handler.remove_reset_token(
  372. id=user_id
  373. )
  374. return {"message": "Password reset successfully"}
  375. async def logout(self, token: str) -> dict[str, str]:
  376. await self.database_provider.token_handler.blacklist_token(token=token)
  377. return {"message": "Logged out successfully"}
  378. async def clean_expired_blacklisted_tokens(self):
  379. await self.database_provider.token_handler.clean_expired_blacklisted_tokens()
  380. async def send_reset_email(self, email: str) -> dict:
  381. user = await self.database_provider.users_handler.get_user_by_email(
  382. email=email
  383. )
  384. if not user:
  385. raise R2RException(status_code=404, message="User not found")
  386. verification_code = self.crypto_provider.generate_verification_code()
  387. expiry = datetime.now(timezone.utc) + timedelta(hours=24)
  388. await self.database_provider.users_handler.store_verification_code(
  389. id=user.id,
  390. verification_code=verification_code,
  391. expiry=expiry,
  392. )
  393. first_name = (
  394. user.name.split(" ")[0] if user.name else email.split("@")[0]
  395. )
  396. await self.email_provider.send_verification_email(
  397. email, verification_code, {"first_name": first_name}
  398. )
  399. return {
  400. "verification_code": verification_code,
  401. "expiry": expiry,
  402. "message": f"Verification email sent successfully to {email}",
  403. }
  404. async def create_user_api_key(
  405. self, user_id: UUID, name: Optional[str] = None
  406. ) -> dict[str, str]:
  407. key_id, raw_api_key = self.crypto_provider.generate_api_key()
  408. hashed_key = self.crypto_provider.hash_api_key(raw_api_key)
  409. api_key_uuid = (
  410. await self.database_provider.users_handler.store_user_api_key(
  411. user_id=user_id,
  412. key_id=key_id,
  413. hashed_key=hashed_key,
  414. name=name,
  415. )
  416. )
  417. return {
  418. "api_key": f"{key_id}.{raw_api_key}",
  419. "key_id": str(api_key_uuid),
  420. "public_key": key_id,
  421. "name": name or "",
  422. }
  423. async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
  424. return await self.database_provider.users_handler.get_user_api_keys(
  425. user_id=user_id
  426. )
  427. async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> dict:
  428. return await self.database_provider.users_handler.delete_api_key(
  429. user_id=user_id,
  430. key_id=key_id,
  431. )
  432. async def rename_api_key(
  433. self, user_id: UUID, key_id: UUID, new_name: str
  434. ) -> bool:
  435. return await self.database_provider.users_handler.update_api_key_name(
  436. user_id=user_id,
  437. key_id=key_id,
  438. name=new_name,
  439. )