users.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660
  1. from datetime import datetime
  2. from typing import Optional
  3. from uuid import UUID
  4. from fastapi import HTTPException
  5. from core.base import CryptoProvider, Handler
  6. from core.base.abstractions import R2RException
  7. from core.utils import generate_user_id
  8. from shared.abstractions import User
  9. from .base import PostgresConnectionManager, QueryBuilder
  10. from .collections import PostgresCollectionsHandler
  11. class PostgresUserHandler(Handler):
  12. TABLE_NAME = "users"
  13. def __init__(
  14. self,
  15. project_name: str,
  16. connection_manager: PostgresConnectionManager,
  17. crypto_provider: CryptoProvider,
  18. ):
  19. super().__init__(project_name, connection_manager)
  20. self.crypto_provider = crypto_provider
  21. async def create_tables(self):
  22. query = f"""
  23. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.TABLE_NAME)} (
  24. id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
  25. email TEXT UNIQUE NOT NULL,
  26. hashed_password TEXT NOT NULL,
  27. is_superuser BOOLEAN DEFAULT FALSE,
  28. is_active BOOLEAN DEFAULT TRUE,
  29. is_verified BOOLEAN DEFAULT FALSE,
  30. verification_code TEXT,
  31. verification_code_expiry TIMESTAMPTZ,
  32. name TEXT,
  33. bio TEXT,
  34. profile_picture TEXT,
  35. reset_token TEXT,
  36. reset_token_expiry TIMESTAMPTZ,
  37. collection_ids UUID[] NULL,
  38. created_at TIMESTAMPTZ DEFAULT NOW(),
  39. updated_at TIMESTAMPTZ DEFAULT NOW()
  40. );
  41. """
  42. await self.connection_manager.execute_query(query)
  43. async def get_user_by_id(self, id: UUID) -> User:
  44. query, _ = (
  45. QueryBuilder(self._get_table_name("users"))
  46. .select(
  47. [
  48. "id",
  49. "email",
  50. "hashed_password",
  51. "is_superuser",
  52. "is_active",
  53. "is_verified",
  54. "created_at",
  55. "updated_at",
  56. "name",
  57. "profile_picture",
  58. "bio",
  59. "collection_ids",
  60. ]
  61. )
  62. .where("id = $1")
  63. .build()
  64. )
  65. result = await self.connection_manager.fetchrow_query(query, [id])
  66. if not result:
  67. raise R2RException(status_code=404, message="User not found")
  68. return User(
  69. id=result["id"],
  70. email=result["email"],
  71. hashed_password=result["hashed_password"],
  72. is_superuser=result["is_superuser"],
  73. is_active=result["is_active"],
  74. is_verified=result["is_verified"],
  75. created_at=result["created_at"],
  76. updated_at=result["updated_at"],
  77. name=result["name"],
  78. profile_picture=result["profile_picture"],
  79. bio=result["bio"],
  80. collection_ids=result["collection_ids"],
  81. )
  82. async def get_user_by_email(self, email: str) -> User:
  83. query, params = (
  84. QueryBuilder(self._get_table_name("users"))
  85. .select(
  86. [
  87. "id",
  88. "email",
  89. "hashed_password",
  90. "is_superuser",
  91. "is_active",
  92. "is_verified",
  93. "created_at",
  94. "updated_at",
  95. "name",
  96. "profile_picture",
  97. "bio",
  98. "collection_ids",
  99. ]
  100. )
  101. .where("email = $1")
  102. .build()
  103. )
  104. result = await self.connection_manager.fetchrow_query(query, [email])
  105. if not result:
  106. raise R2RException(status_code=404, message="User not found")
  107. return User(
  108. id=result["id"],
  109. email=result["email"],
  110. hashed_password=result["hashed_password"],
  111. is_superuser=result["is_superuser"],
  112. is_active=result["is_active"],
  113. is_verified=result["is_verified"],
  114. created_at=result["created_at"],
  115. updated_at=result["updated_at"],
  116. name=result["name"],
  117. profile_picture=result["profile_picture"],
  118. bio=result["bio"],
  119. collection_ids=result["collection_ids"],
  120. )
  121. async def create_user(
  122. self, email: str, password: str, is_superuser: bool = False
  123. ) -> User:
  124. try:
  125. if await self.get_user_by_email(email):
  126. raise R2RException(
  127. status_code=400,
  128. message="User with this email already exists",
  129. )
  130. except R2RException as e:
  131. if e.status_code != 404:
  132. raise e
  133. hashed_password = self.crypto_provider.get_password_hash(password) # type: ignore
  134. query = f"""
  135. INSERT INTO {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  136. (email, id, is_superuser, hashed_password, collection_ids)
  137. VALUES ($1, $2, $3, $4, $5)
  138. RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
  139. """
  140. result = await self.connection_manager.fetchrow_query(
  141. query,
  142. [
  143. email,
  144. generate_user_id(email),
  145. is_superuser,
  146. hashed_password,
  147. [],
  148. ],
  149. )
  150. if not result:
  151. raise HTTPException(
  152. status_code=500,
  153. detail="Failed to create user",
  154. )
  155. return User(
  156. id=result["id"],
  157. email=result["email"],
  158. is_superuser=result["is_superuser"],
  159. is_active=result["is_active"],
  160. is_verified=result["is_verified"],
  161. created_at=result["created_at"],
  162. updated_at=result["updated_at"],
  163. collection_ids=result["collection_ids"],
  164. hashed_password=hashed_password,
  165. )
  166. async def update_user(self, user: User) -> User:
  167. query = f"""
  168. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  169. SET email = $1, is_superuser = $2, is_active = $3, is_verified = $4, updated_at = NOW(),
  170. name = $5, profile_picture = $6, bio = $7, collection_ids = $8
  171. WHERE id = $9
  172. RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, name, profile_picture, bio, collection_ids
  173. """
  174. result = await self.connection_manager.fetchrow_query(
  175. query,
  176. [
  177. user.email,
  178. user.is_superuser,
  179. user.is_active,
  180. user.is_verified,
  181. user.name,
  182. user.profile_picture,
  183. user.bio,
  184. user.collection_ids,
  185. user.id,
  186. ],
  187. )
  188. if not result:
  189. raise HTTPException(
  190. status_code=500,
  191. detail="Failed to update user",
  192. )
  193. return User(
  194. id=result["id"],
  195. email=result["email"],
  196. is_superuser=result["is_superuser"],
  197. is_active=result["is_active"],
  198. is_verified=result["is_verified"],
  199. created_at=result["created_at"],
  200. updated_at=result["updated_at"],
  201. name=result["name"],
  202. profile_picture=result["profile_picture"],
  203. bio=result["bio"],
  204. collection_ids=result["collection_ids"],
  205. )
  206. async def delete_user_relational(self, id: UUID) -> None:
  207. # Get the collections the user belongs to
  208. collection_query = f"""
  209. SELECT collection_ids FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  210. WHERE id = $1
  211. """
  212. collection_result = await self.connection_manager.fetchrow_query(
  213. collection_query, [id]
  214. )
  215. if not collection_result:
  216. raise R2RException(status_code=404, message="User not found")
  217. # Remove user from documents
  218. doc_update_query = f"""
  219. UPDATE {self._get_table_name('documents')}
  220. SET id = NULL
  221. WHERE id = $1
  222. """
  223. await self.connection_manager.execute_query(doc_update_query, [id])
  224. # Delete the user
  225. delete_query = f"""
  226. DELETE FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  227. WHERE id = $1
  228. RETURNING id
  229. """
  230. result = await self.connection_manager.fetchrow_query(
  231. delete_query, [id]
  232. )
  233. if not result:
  234. raise R2RException(status_code=404, message="User not found")
  235. async def update_user_password(self, id: UUID, new_hashed_password: str):
  236. query = f"""
  237. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  238. SET hashed_password = $1, updated_at = NOW()
  239. WHERE id = $2
  240. """
  241. await self.connection_manager.execute_query(
  242. query, [new_hashed_password, id]
  243. )
  244. async def get_all_users(self) -> list[User]:
  245. query = f"""
  246. SELECT id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
  247. FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  248. """
  249. results = await self.connection_manager.fetch_query(query)
  250. return [
  251. User(
  252. id=result["id"],
  253. email=result["email"],
  254. hashed_password="null",
  255. is_superuser=result["is_superuser"],
  256. is_active=result["is_active"],
  257. is_verified=result["is_verified"],
  258. created_at=result["created_at"],
  259. updated_at=result["updated_at"],
  260. collection_ids=result["collection_ids"],
  261. )
  262. for result in results
  263. ]
  264. async def store_verification_code(
  265. self, id: UUID, verification_code: str, expiry: datetime
  266. ):
  267. query = f"""
  268. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  269. SET verification_code = $1, verification_code_expiry = $2
  270. WHERE id = $3
  271. """
  272. await self.connection_manager.execute_query(
  273. query, [verification_code, expiry, id]
  274. )
  275. async def verify_user(self, verification_code: str) -> None:
  276. query = f"""
  277. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  278. SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
  279. WHERE verification_code = $1 AND verification_code_expiry > NOW()
  280. RETURNING id
  281. """
  282. result = await self.connection_manager.fetchrow_query(
  283. query, [verification_code]
  284. )
  285. if not result:
  286. raise R2RException(
  287. status_code=400, message="Invalid or expired verification code"
  288. )
  289. async def remove_verification_code(self, verification_code: str):
  290. query = f"""
  291. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  292. SET verification_code = NULL, verification_code_expiry = NULL
  293. WHERE verification_code = $1
  294. """
  295. await self.connection_manager.execute_query(query, [verification_code])
  296. async def expire_verification_code(self, id: UUID):
  297. query = f"""
  298. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  299. SET verification_code_expiry = NOW() - INTERVAL '1 day'
  300. WHERE id = $1
  301. """
  302. await self.connection_manager.execute_query(query, [id])
  303. async def store_reset_token(
  304. self, id: UUID, reset_token: str, expiry: datetime
  305. ):
  306. query = f"""
  307. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  308. SET reset_token = $1, reset_token_expiry = $2
  309. WHERE id = $3
  310. """
  311. await self.connection_manager.execute_query(
  312. query, [reset_token, expiry, id]
  313. )
  314. async def get_user_id_by_reset_token(
  315. self, reset_token: str
  316. ) -> Optional[UUID]:
  317. query = f"""
  318. SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  319. WHERE reset_token = $1 AND reset_token_expiry > NOW()
  320. """
  321. result = await self.connection_manager.fetchrow_query(
  322. query, [reset_token]
  323. )
  324. return result["id"] if result else None
  325. async def remove_reset_token(self, id: UUID):
  326. query = f"""
  327. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  328. SET reset_token = NULL, reset_token_expiry = NULL
  329. WHERE id = $1
  330. """
  331. await self.connection_manager.execute_query(query, [id])
  332. async def remove_user_from_all_collections(self, id: UUID):
  333. query = f"""
  334. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  335. SET collection_ids = ARRAY[]::UUID[]
  336. WHERE id = $1
  337. """
  338. await self.connection_manager.execute_query(query, [id])
  339. async def add_user_to_collection(
  340. self, id: UUID, collection_id: UUID
  341. ) -> bool:
  342. # Check if the user exists
  343. if not await self.get_user_by_id(id):
  344. raise R2RException(status_code=404, message="User not found")
  345. # Check if the collection exists
  346. if not await self._collection_exists(collection_id):
  347. raise R2RException(status_code=404, message="Collection not found")
  348. query = f"""
  349. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  350. SET collection_ids = array_append(collection_ids, $1)
  351. WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
  352. RETURNING id
  353. """
  354. result = await self.connection_manager.fetchrow_query(
  355. query, [collection_id, id]
  356. )
  357. if not result:
  358. raise R2RException(
  359. status_code=400, message="User already in collection"
  360. )
  361. update_collection_query = f"""
  362. UPDATE {self._get_table_name('collections')}
  363. SET user_count = user_count + 1
  364. WHERE id = $1
  365. """
  366. await self.connection_manager.execute_query(
  367. query=update_collection_query,
  368. params=[collection_id],
  369. )
  370. return True
  371. async def remove_user_from_collection(
  372. self, id: UUID, collection_id: UUID
  373. ) -> bool:
  374. if not await self.get_user_by_id(id):
  375. raise R2RException(status_code=404, message="User not found")
  376. query = f"""
  377. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  378. SET collection_ids = array_remove(collection_ids, $1)
  379. WHERE id = $2 AND $1 = ANY(collection_ids)
  380. RETURNING id
  381. """
  382. result = await self.connection_manager.fetchrow_query(
  383. query, [collection_id, id]
  384. )
  385. if not result:
  386. raise R2RException(
  387. status_code=400,
  388. message="User is not a member of the specified collection",
  389. )
  390. return True
  391. async def get_users_in_collection(
  392. self, collection_id: UUID, offset: int, limit: int
  393. ) -> dict[str, list[User] | int]:
  394. """
  395. Get all users in a specific collection with pagination.
  396. Args:
  397. collection_id (UUID): The ID of the collection to get users from.
  398. offset (int): The number of users to skip.
  399. limit (int): The maximum number of users to return.
  400. Returns:
  401. List[User]: A list of User objects representing the users in the collection.
  402. Raises:
  403. R2RException: If the collection doesn't exist.
  404. """
  405. if not await self._collection_exists(collection_id): # type: ignore
  406. raise R2RException(status_code=404, message="Collection not found")
  407. query = f"""
  408. SELECT u.id, u.email, u.is_active, u.is_superuser, u.created_at, u.updated_at,
  409. u.is_verified, u.collection_ids, u.name, u.bio, u.profile_picture,
  410. COUNT(*) OVER() AS total_entries
  411. FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
  412. WHERE $1 = ANY(u.collection_ids)
  413. ORDER BY u.name
  414. OFFSET $2
  415. """
  416. conditions = [collection_id, offset]
  417. if limit != -1:
  418. query += " LIMIT $3"
  419. conditions.append(limit)
  420. results = await self.connection_manager.fetch_query(query, conditions)
  421. users = [
  422. User(
  423. id=row["id"],
  424. email=row["email"],
  425. is_active=row["is_active"],
  426. is_superuser=row["is_superuser"],
  427. created_at=row["created_at"],
  428. updated_at=row["updated_at"],
  429. is_verified=row["is_verified"],
  430. collection_ids=row["collection_ids"],
  431. name=row["name"],
  432. bio=row["bio"],
  433. profile_picture=row["profile_picture"],
  434. hashed_password=None,
  435. verification_code_expiry=None,
  436. )
  437. for row in results
  438. ]
  439. total_entries = results[0]["total_entries"] if results else 0
  440. return {"results": users, "total_entries": total_entries}
  441. async def mark_user_as_superuser(self, id: UUID):
  442. query = f"""
  443. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  444. SET is_superuser = TRUE, is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
  445. WHERE id = $1
  446. """
  447. await self.connection_manager.execute_query(query, [id])
  448. async def get_user_id_by_verification_code(
  449. self, verification_code: str
  450. ) -> Optional[UUID]:
  451. query = f"""
  452. SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  453. WHERE verification_code = $1 AND verification_code_expiry > NOW()
  454. """
  455. result = await self.connection_manager.fetchrow_query(
  456. query, [verification_code]
  457. )
  458. if not result:
  459. raise R2RException(
  460. status_code=400, message="Invalid or expired verification code"
  461. )
  462. return result["id"]
  463. async def mark_user_as_verified(self, id: UUID):
  464. query = f"""
  465. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  466. SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
  467. WHERE id = $1
  468. """
  469. await self.connection_manager.execute_query(query, [id])
  470. async def get_users_overview(
  471. self,
  472. offset: int,
  473. limit: int,
  474. user_ids: Optional[list[UUID]] = None,
  475. ) -> dict[str, list[User] | int]:
  476. query = f"""
  477. WITH user_document_ids AS (
  478. SELECT
  479. u.id as user_id,
  480. ARRAY_AGG(d.id) FILTER (WHERE d.id IS NOT NULL) AS doc_ids
  481. FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
  482. LEFT JOIN {self._get_table_name('documents')} d ON u.id = d.owner_id
  483. GROUP BY u.id
  484. ),
  485. user_docs AS (
  486. SELECT
  487. u.id,
  488. u.email,
  489. u.is_superuser,
  490. u.is_active,
  491. u.is_verified,
  492. u.created_at,
  493. u.updated_at,
  494. u.collection_ids,
  495. COUNT(d.id) AS num_files,
  496. COALESCE(SUM(d.size_in_bytes), 0) AS total_size_in_bytes,
  497. ud.doc_ids as document_ids
  498. FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
  499. LEFT JOIN {self._get_table_name('documents')} d ON u.id = d.owner_id
  500. LEFT JOIN user_document_ids ud ON u.id = ud.user_id
  501. {' WHERE u.id = ANY($3::uuid[])' if user_ids else ''}
  502. GROUP BY u.id, u.email, u.is_superuser, u.is_active, u.is_verified,
  503. u.created_at, u.updated_at, u.collection_ids, ud.doc_ids
  504. )
  505. SELECT
  506. user_docs.*,
  507. COUNT(*) OVER() AS total_entries
  508. FROM user_docs
  509. ORDER BY email
  510. OFFSET $1
  511. """
  512. params: list = [offset]
  513. if limit != -1:
  514. query += " LIMIT $2"
  515. params.append(limit)
  516. if user_ids:
  517. params.append(user_ids)
  518. results = await self.connection_manager.fetch_query(query, params)
  519. users = [
  520. User(
  521. id=row["id"],
  522. email=row["email"],
  523. is_superuser=row["is_superuser"],
  524. is_active=row["is_active"],
  525. is_verified=row["is_verified"],
  526. created_at=row["created_at"],
  527. updated_at=row["updated_at"],
  528. collection_ids=row["collection_ids"] or [],
  529. num_files=row["num_files"],
  530. total_size_in_bytes=row["total_size_in_bytes"],
  531. document_ids=(
  532. []
  533. if row["document_ids"] is None
  534. else [doc_id for doc_id in row["document_ids"]]
  535. ),
  536. )
  537. for row in results
  538. ]
  539. if not users:
  540. raise R2RException(status_code=404, message="No users found")
  541. total_entries = results[0]["total_entries"]
  542. return {"results": users, "total_entries": total_entries}
  543. async def _collection_exists(self, collection_id: UUID) -> bool:
  544. """Check if a collection exists."""
  545. query = f"""
  546. SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
  547. WHERE id = $1
  548. """
  549. result = await self.connection_manager.fetchrow_query(
  550. query, [collection_id]
  551. )
  552. return result is not None
  553. async def get_user_validation_data(
  554. self,
  555. user_id: UUID,
  556. ) -> dict:
  557. """
  558. Get verification data for a specific user.
  559. This method should be called after superuser authorization has been verified.
  560. """
  561. query = f"""
  562. SELECT
  563. verification_code,
  564. verification_code_expiry,
  565. reset_token,
  566. reset_token_expiry
  567. FROM {self._get_table_name("users")}
  568. WHERE id = $1
  569. """
  570. result = await self.connection_manager.fetchrow_query(query, [user_id])
  571. if not result:
  572. raise R2RException(status_code=404, message="User not found")
  573. return {
  574. "verification_data": {
  575. "verification_code": result["verification_code"],
  576. "verification_code_expiry": (
  577. result["verification_code_expiry"].isoformat()
  578. if result["verification_code_expiry"]
  579. else None
  580. ),
  581. "reset_token": result["reset_token"],
  582. "reset_token_expiry": (
  583. result["reset_token_expiry"].isoformat()
  584. if result["reset_token_expiry"]
  585. else None
  586. ),
  587. }
  588. }