r2r_auth.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. import logging
  2. import os
  3. from datetime import datetime, timedelta, timezone
  4. from typing import Optional
  5. from uuid import UUID
  6. from fastapi import Depends, HTTPException
  7. from fastapi.security import OAuth2PasswordBearer
  8. from core.base import (
  9. AuthConfig,
  10. AuthProvider,
  11. CollectionResponse,
  12. CryptoProvider,
  13. EmailProvider,
  14. R2RException,
  15. Token,
  16. TokenData,
  17. )
  18. from core.base.api.models import User
  19. from ...database.postgres import PostgresDatabaseProvider
  # Fallback token lifetimes, used when neither the auth config nor the
  # environment provides a value.
  # NOTE(review): 3600 minutes is 60 hours -- confirm this was not meant to be
  # one hour expressed in seconds.
  DEFAULT_ACCESS_LIFETIME_IN_MINUTES = 3600
  DEFAULT_REFRESH_LIFETIME_IN_DAYS = 7

  # Named after this module (instead of the root logger) so records can be
  # attributed to, and filtered for, the auth provider specifically.
  logger = logging.getLogger(__name__)

  # FastAPI security dependency; it is the default for the `token` argument of
  # `R2RAuthProvider.user`.
  oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  24. class R2RAuthProvider(AuthProvider):
  def __init__(
      self,
      config: AuthConfig,
      crypto_provider: CryptoProvider,
      database_provider: PostgresDatabaseProvider,
      email_provider: EmailProvider,
  ):
      """Set up the provider and resolve the token lifetimes.

      Each lifetime is taken from the first truthy source, in this order:
      the auth config, the matching environment variable
      (``R2R_ACCESS_LIFE_IN_MINUTES`` / ``R2R_REFRESH_LIFE_IN_DAYS``), the
      module default. Environment values are kept as the raw strings
      returned by ``os.getenv``; they are converted with ``float`` when a
      token is created.

      Args:
          config: Auth settings; also stored on ``self.config``.
          crypto_provider: Forwarded to the base class.
          database_provider: Forwarded to the base class and stored on
              ``self.database_provider``.
          email_provider: Forwarded to the base class.
      """
      super().__init__(
          config, crypto_provider, database_provider, email_provider
      )
      self.database_provider: PostgresDatabaseProvider = database_provider
      logger.debug(f"Initializing R2RAuthProvider with config: {config}")

      # No secret key is held here; key handling lives in the crypto provider.
      access_minutes = config.access_token_lifetime_in_minutes
      if not access_minutes:
          access_minutes = os.getenv("R2R_ACCESS_LIFE_IN_MINUTES")
      if not access_minutes:
          access_minutes = DEFAULT_ACCESS_LIFETIME_IN_MINUTES
      self.access_token_lifetime_in_minutes = access_minutes

      refresh_days = config.refresh_token_lifetime_in_days
      if not refresh_days:
          refresh_days = os.getenv("R2R_REFRESH_LIFE_IN_DAYS")
      if not refresh_days:
          refresh_days = DEFAULT_REFRESH_LIFETIME_IN_DAYS
      self.refresh_token_lifetime_in_days = refresh_days

      self.config: AuthConfig = config
  async def initialize(self):
      """Register the default admin account and flag it as a superuser.

      Uses ``self.admin_email`` and ``self.admin_password``. Any
      ``R2RException`` raised by either step is treated as "the admin
      already exists" and is only logged at info level.
      """
      try:
          admin = await self.register(
              email=self.admin_email,
              password=self.admin_password,
              is_superuser=True,
          )
          users = self.database_provider.users_handler
          await users.mark_user_as_superuser(id=admin.id)
      except R2RException:
          logger.info("Default admin user already exists.")
  def create_access_token(self, data: dict) -> str:
      """Build an access token for the given claims.

      Args:
          data: Claims to embed; the dict itself is not modified.

      Returns:
          Whatever ``crypto_provider.generate_secure_token`` returns for the
          claims plus ``"token_type": "access"``, expiring
          ``access_token_lifetime_in_minutes`` minutes from now (UTC).
      """
      issued_at = datetime.now(timezone.utc)
      lifetime = timedelta(
          minutes=float(self.access_token_lifetime_in_minutes)
      )
      # Tag the claims so an access token can be told apart from a refresh one.
      claims = {**data, "token_type": "access"}
      return self.crypto_provider.generate_secure_token(
          data=claims,
          expiry=issued_at + lifetime,
      )
  def create_refresh_token(self, data: dict) -> str:
      """Build a refresh token for the given claims.

      Args:
          data: Claims to embed; the dict itself is not modified.

      Returns:
          Whatever ``crypto_provider.generate_secure_token`` returns for the
          claims plus ``"token_type": "refresh"``, expiring
          ``refresh_token_lifetime_in_days`` days from now (UTC).
      """
      issued_at = datetime.now(timezone.utc)
      lifetime = timedelta(days=float(self.refresh_token_lifetime_in_days))
      claims = {**data, "token_type": "refresh"}
      return self.crypto_provider.generate_secure_token(
          data=claims,
          expiry=issued_at + lifetime,
      )
  async def decode_token(self, token: str) -> TokenData:
      """Validate a token and return its claims as ``TokenData``.

      The blacklist is consulted first, then the token is verified by the
      crypto provider, then the ``sub``, ``token_type`` and ``exp`` claims
      are checked.

      Raises:
          R2RException: 401 when the token is blacklisted, fails
              verification, lacks a required claim, or has expired.
      """
      tokens = self.database_provider.token_handler
      if await tokens.is_token_blacklisted(token=token):
          raise R2RException(
              status_code=401, message="Token has been invalidated"
          )

      claims = self.crypto_provider.verify_secure_token(token=token)
      if claims is None:
          raise R2RException(
              status_code=401, message="Invalid or expired token"
          )

      email: str = claims.get("sub")
      token_type: str = claims.get("token_type")
      exp: float = claims.get("exp")
      if any(claim is None for claim in (email, token_type, exp)):
          raise R2RException(status_code=401, message="Invalid token claims")

      # `exp` is read as a POSIX timestamp and compared in UTC.
      expires_at = datetime.fromtimestamp(exp, tz=timezone.utc)
      if expires_at < datetime.now(timezone.utc):
          raise R2RException(status_code=401, message="Token has expired")

      return TokenData(email=email, token_type=token_type, exp=expires_at)
  async def authenticate_api_key(self, api_key: str) -> Optional[User]:
      """Authenticate using an API key of the form "public_key.raw_key".

      Args:
          api_key: The full key; everything before the first "." is the
              public key id, the remainder is the raw secret.

      Returns:
          The user that owns the key.

      Raises:
          R2RException: 401 when the key has no "." separator, no record
              exists for the key id, the secret does not match the stored
              hash, no user is returned for the record, or the user is
              inactive.
      """
      key_id, separator, raw_key = api_key.partition(".")
      if not separator:
          raise R2RException(
              status_code=401, message="Invalid API key format"
          )

      users = self.database_provider.users_handler
      key_record = await users.get_api_key_record(key_id=key_id)
      if not key_record:
          raise R2RException(status_code=401, message="Invalid API key")

      if not self.crypto_provider.verify_api_key(
          raw_api_key=raw_key, hashed_key=key_record["hashed_key"]
      ):
          raise R2RException(status_code=401, message="Invalid API key")

      user = await users.get_user_by_id(id=key_record["user_id"])
      if user is None:
          # A key whose owner cannot be loaded must fail as an auth error
          # rather than crash on the attribute access below.
          raise R2RException(status_code=401, message="Invalid API key")
      if not user.is_active:
          raise R2RException(
              status_code=401, message="User account is inactive"
          )
      return user
  async def user(self, token: str = Depends(oauth2_scheme)) -> User:
      """Resolve the caller from a bearer credential.

      The credential is first treated as a JWT. If any step of that path
      raises ``R2RException``, the same credential (with a leading
      "Bearer " removed) is retried as an API key.

      Raises:
          R2RException: Propagated from ``authenticate_api_key`` when the
              fallback fails as well.
      """
      try:
          claims = await self.decode_token(token=token)
          if not claims.email:
              raise R2RException(
                  status_code=401, message="Could not validate credentials"
              )

          users = self.database_provider.users_handler
          account = await users.get_user_by_email(email=claims.email)
          if account is not None:
              return account
          raise R2RException(
              status_code=401,
              message="Invalid authentication credentials",
          )
      except R2RException:
          # JWT path failed: treat the credential as an API key instead.
          api_key = token.removeprefix("Bearer ")
          return await self.authenticate_api_key(api_key=api_key)
  def get_current_active_user(
      self, current_user: User = Depends(user)
  ) -> User:
      """Return ``current_user`` if the account is active.

      Raises:
          R2RException: 400 when ``current_user.is_active`` is falsy.
      """
      if current_user.is_active:
          return current_user
      raise R2RException(status_code=400, message="Inactive user")
  async def register(
      self, email: str, password: str, is_superuser: bool = False
  ) -> User:
      """Create a user together with a default collection and graph.

      Steps, in order: create the user, create a collection owned by the
      user, create a graph for that collection, add the user to the
      collection, then handle email verification.

      When ``config.require_email_verification`` is set, a verification
      code valid for 24 hours is stored and emailed. Otherwise a
      placeholder code of "-1" is stored and the user is marked verified
      immediately.

      Args:
          email: Address for the new account.
          password: Password forwarded to the users handler.
          is_superuser: Forwarded to the users handler.

      Returns:
          The newly created user. ``verification_code_expiry`` is set on it
          only when email verification is required.
      """
      users = self.database_provider.users_handler

      new_user = await users.create_user(
          email=email, password=password, is_superuser=is_superuser
      )
      default_collection: CollectionResponse = (
          await self.database_provider.collections_handler.create_collection(
              owner_id=new_user.id,
          )
      )
      # The graph is created for its side effect; the result is not used.
      await self.database_provider.graphs_handler.create(
          collection_id=default_collection.id,
          name=default_collection.name,
          description=default_collection.description,
      )
      await users.add_user_to_collection(new_user.id, default_collection.id)

      if self.config.require_email_verification:
          verification_code = (
              self.crypto_provider.generate_verification_code()
          )
          expiry = datetime.now(timezone.utc) + timedelta(hours=24)
          await users.store_verification_code(
              id=new_user.id,
              verification_code=verification_code,
              expiry=expiry,
          )
          new_user.verification_code_expiry = expiry

          # Greet by first name, falling back to the email's local part.
          first_name = (
              new_user.name.split(" ")[0]
              if new_user.name
              else email.split("@")[0]
          )
          await self.email_provider.send_verification_email(
              new_user.email, verification_code, {"first_name": first_name}
          )
      else:
          # NOTE(review): 366 * 10 hours is roughly 152 days; this looks
          # like it was meant to be ten years (days=366 * 10) -- confirm.
          expiry = datetime.now(timezone.utc) + timedelta(hours=366 * 10)
          await users.store_verification_code(
              id=new_user.id,
              verification_code=str(-1),
              expiry=expiry,
          )
          await users.mark_user_as_verified(id=new_user.id)

      return new_user
  async def verify_email(
      self, email: str, verification_code: str
  ) -> dict[str, str]:
      """Mark the user that owns ``verification_code`` as verified.

      Args:
          email: Accepted for interface compatibility; it is not consulted.
              NOTE(review): the code is not checked against this address --
              confirm that is intended.
          verification_code: Code previously stored for the user.

      Returns:
          A dict with a single "message" entry.

      Raises:
          R2RException: 400 when no user id is returned for the code.
      """
      users = self.database_provider.users_handler

      user_id = await users.get_user_id_by_verification_code(
          verification_code=verification_code
      )
      if not user_id:
          # Same guard as confirm_password_reset: never act on a missing id.
          raise R2RException(
              status_code=400,
              message="Invalid or expired verification code",
          )

      await users.mark_user_as_verified(id=user_id)
      # Remove the code so it cannot be used a second time.
      await users.remove_verification_code(
          verification_code=verification_code
      )
      return {"message": "Email verified successfully"}
  async def login(self, email: str, password: str) -> dict[str, Token]:
      """Check credentials and issue an access/refresh token pair.

      Args:
          email: Address of the account to log in.
          password: Plain-text password checked against the stored hash.

      Returns:
          A dict with "access_token" and "refresh_token" entries, both
          carrying the user's email as the "sub" claim.

      Raises:
          R2RException: 401 when no user is returned for the email, the
              password does not match, or the account is unverified while
              email verification is required.
          HTTPException: 500 when the stored hash is not a string or the
              crypto provider raises while verifying the password.
      """
      logger.debug(f"Attempting login for email: {email}")
      user = await self.database_provider.users_handler.get_user_by_email(
          email=email
      )
      if user is None:
          # Answer exactly like a wrong password so the response does not
          # reveal whether the address is registered.
          logger.warning(f"Login attempt for unknown email: {email}")
          raise R2RException(
              status_code=401, message="Incorrect email or password"
          )
      # Log the identifier only: the user record carries hashed_password.
      logger.debug(f"User found: {user.id}")

      if not isinstance(user.hashed_password, str):
          logger.error(
              f"Invalid hashed_password type: {type(user.hashed_password)}"
          )
          raise HTTPException(
              status_code=500,
              detail="Invalid password hash in database",
          )

      try:
          password_verified = self.crypto_provider.verify_password(
              plain_password=password,
              hashed_password=user.hashed_password,
          )
      except Exception as e:
          # Report a failure inside the crypto provider as a server error,
          # not as bad credentials.
          logger.error(f"Error during password verification: {str(e)}")
          raise HTTPException(
              status_code=500,
              detail="Error during password verification",
          ) from e

      if not password_verified:
          logger.warning(f"Invalid password for user: {email}")
          raise R2RException(
              status_code=401, message="Incorrect email or password"
          )

      if not user.is_verified and self.config.require_email_verification:
          logger.warning(f"Unverified user attempted login: {email}")
          raise R2RException(status_code=401, message="Email not verified")

      access_token = self.create_access_token(data={"sub": user.email})
      refresh_token = self.create_refresh_token(data={"sub": user.email})
      return {
          "access_token": Token(token=access_token, token_type="access"),
          "refresh_token": Token(token=refresh_token, token_type="refresh"),
      }
  async def refresh_access_token(
      self, refresh_token: str
  ) -> dict[str, Token]:
      """Exchange a refresh token for a new access/refresh pair.

      The presented refresh token is blacklisted before the new pair is
      created, so it cannot be used again.

      Raises:
          R2RException: 401 when the decoded token is not of type
              "refresh"; anything raised by ``decode_token`` propagates.
      """
      claims = await self.decode_token(refresh_token)
      if claims.token_type != "refresh":
          raise R2RException(
              status_code=401, message="Invalid refresh token"
          )

      # Rotate: retire the old refresh token, then mint replacements.
      tokens = self.database_provider.token_handler
      await tokens.blacklist_token(token=refresh_token)

      access = self.create_access_token(data={"sub": claims.email})
      refresh = self.create_refresh_token(data={"sub": claims.email})
      return {
          "access_token": Token(token=access, token_type="access"),
          "refresh_token": Token(token=refresh, token_type="refresh"),
      }
  async def change_password(
      self, user: User, current_password: str, new_password: str
  ) -> dict[str, str]:
      """Replace a user's password after checking the current one.

      Args:
          user: Account whose password is changed; its ``hashed_password``
              and ``id`` are read.
          current_password: Plain-text password to verify first.
          new_password: Plain-text replacement; it is hashed before storage.

      Returns:
          A dict with a single "message" entry.

      Raises:
          HTTPException: 500 when the stored hash is not a string.
          R2RException: 400 when ``current_password`` does not match.
      """
      if not isinstance(user.hashed_password, str):
          logger.error(
              f"Invalid hashed_password type: {type(user.hashed_password)}"
          )
          raise HTTPException(
              status_code=500,
              detail="Invalid password hash in database",
          )

      matches = self.crypto_provider.verify_password(
          plain_password=current_password,
          hashed_password=user.hashed_password,
      )
      if not matches:
          raise R2RException(
              status_code=400, message="Incorrect current password"
          )

      replacement_hash = self.crypto_provider.get_password_hash(
          password=new_password
      )
      users = self.database_provider.users_handler
      await users.update_user_password(
          id=user.id,
          new_hashed_password=replacement_hash,
      )
      return {"message": "Password changed successfully"}
  async def request_password_reset(self, email: str) -> dict[str, str]:
      """Store a one-hour reset token for the user and email it.

      The same message is returned whether the flow succeeded or an
      ``R2RException`` with status 404 was raised along the way, so the
      response does not disclose whether the address exists.

      Raises:
          R2RException: Re-raised when its status code is not 404.
      """
      try:
          users = self.database_provider.users_handler
          account = await users.get_user_by_email(email=email)

          token = self.crypto_provider.generate_verification_code()
          expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
          await users.store_reset_token(
              id=account.id,
              reset_token=token,
              expiry=expires_at,
          )

          # Greet by first name, falling back to the email's local part.
          if account.name:
              first_name = account.name.split(" ")[0]
          else:
              first_name = email.split("@")[0]
          await self.email_provider.send_password_reset_email(
              email, token, {"first_name": first_name}
          )
      except R2RException as exc:
          if exc.status_code != 404:
              raise
          # 404: fall through to the generic success message below.

      return {"message": "If the email exists, a reset link has been sent"}
  async def confirm_password_reset(
      self, reset_token: str, new_password: str
  ) -> dict[str, str]:
      """Set a new password for the user that owns ``reset_token``.

      After the password is updated the reset token is removed for that
      user.

      Returns:
          A dict with a single "message" entry.

      Raises:
          R2RException: 400 when no user id is returned for the token.
      """
      users = self.database_provider.users_handler

      owner_id = await users.get_user_id_by_reset_token(
          reset_token=reset_token
      )
      if not owner_id:
          raise R2RException(
              status_code=400, message="Invalid or expired reset token"
          )

      replacement_hash = self.crypto_provider.get_password_hash(
          password=new_password
      )
      await users.update_user_password(
          id=owner_id,
          new_hashed_password=replacement_hash,
      )
      await users.remove_reset_token(id=owner_id)
      return {"message": "Password reset successfully"}
  async def logout(self, token: str) -> dict[str, str]:
      """Blacklist ``token`` and return a confirmation message."""
      tokens = self.database_provider.token_handler
      await tokens.blacklist_token(token=token)
      return {"message": "Logged out successfully"}
  async def clean_expired_blacklisted_tokens(self):
      """Delegate blacklist cleanup to the database token handler."""
      tokens = self.database_provider.token_handler
      await tokens.clean_expired_blacklisted_tokens()
  async def send_reset_email(self, email: str) -> dict:
      """Issue a new 24-hour verification code and email it to the user.

      Despite its name, this method stores a verification code and sends
      the verification email (not the password-reset email).

      Returns:
          A dict with "verification_code", "expiry" and "message" entries.

      Raises:
          R2RException: 404 when no user is returned for ``email``.
      """
      users = self.database_provider.users_handler
      account = await users.get_user_by_email(email=email)
      if not account:
          raise R2RException(status_code=404, message="User not found")

      code = self.crypto_provider.generate_verification_code()
      expires_at = datetime.now(timezone.utc) + timedelta(hours=24)
      await users.store_verification_code(
          id=account.id,
          verification_code=code,
          expiry=expires_at,
      )

      # Greet by first name, falling back to the email's local part.
      if account.name:
          first_name = account.name.split(" ")[0]
      else:
          first_name = email.split("@")[0]
      await self.email_provider.send_verification_email(
          email, code, {"first_name": first_name}
      )

      return {
          "verification_code": code,
          "expiry": expires_at,
          "message": f"Verification email sent successfully to {email}",
      }
  async def create_user_api_key(
      self, user_id: UUID, name: Optional[str] = None
  ) -> dict[str, str]:
      """Generate an API key for a user and store its hash.

      Only the hash of the raw key is handed to the users handler; the raw
      key appears solely in the returned "api_key" value.

      Args:
          user_id: Owner of the new key.
          name: Optional label stored with the key.

      Returns:
          A dict with "api_key" ("public_key.raw_key"), "key_id" (the
          stored record's id as a string), "public_key" and "name" (empty
          string when no name was given).
      """
      public_key, raw_key = self.crypto_provider.generate_api_key()
      key_hash = self.crypto_provider.hash_api_key(raw_key)

      users = self.database_provider.users_handler
      record_id = await users.store_user_api_key(
          user_id=user_id,
          key_id=public_key,
          hashed_key=key_hash,
          name=name,
      )

      return {
          "api_key": f"{public_key}.{raw_key}",
          "key_id": str(record_id),
          "public_key": public_key,
          "name": name if name else "",
      }
  async def list_user_api_keys(self, user_id: UUID) -> list[dict]:
      """Return the API key records the users handler holds for ``user_id``."""
      users = self.database_provider.users_handler
      api_keys = await users.get_user_api_keys(user_id=user_id)
      return api_keys
  async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
      """Delete one of a user's API keys; returns the handler's result."""
      users = self.database_provider.users_handler
      deleted = await users.delete_api_key(
          user_id=user_id,
          key_id=key_id,
      )
      return deleted
  async def rename_api_key(
      self, user_id: UUID, key_id: UUID, new_name: str
  ) -> bool:
      """Rename one of a user's API keys; returns the handler's result."""
      users = self.database_provider.users_handler
      renamed = await users.update_api_key_name(
          user_id=user_id,
          key_id=key_id,
          name=new_name,
      )
      return renamed