users.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781
  1. from datetime import datetime
  2. from typing import Optional
  3. from uuid import UUID
  4. from fastapi import HTTPException
  5. from core.base import CryptoProvider, Handler
  6. from core.base.abstractions import R2RException
  7. from core.utils import generate_user_id
  8. from shared.abstractions import User
  9. from .base import PostgresConnectionManager, QueryBuilder
  10. from .collections import PostgresCollectionsHandler
  11. class PostgresUserHandler(Handler):
  12. TABLE_NAME = "users"
  13. API_KEYS_TABLE_NAME = "users_api_keys"
  def __init__(
      self,
      project_name: str,
      connection_manager: PostgresConnectionManager,
      crypto_provider: CryptoProvider,
  ):
      """Bind the handler to a project and keep a crypto provider for passwords.

      Args:
          project_name: Forwarded unchanged to ``Handler.__init__``.
          connection_manager: Forwarded unchanged to ``Handler.__init__``; all
              queries in this class go through it.
          crypto_provider: Used by ``create_user`` to hash passwords.
      """
      base_args = (project_name, connection_manager)
      super().__init__(*base_args)
      self.crypto_provider = crypto_provider
  async def create_tables(self):
      """Create the users table and the per-user API-key table if missing.

      Every statement uses ``IF NOT EXISTS``, so repeated calls are harmless.
      The users table is created first because the API-key table declares a
      foreign key to it with ``ON DELETE CASCADE`` (deleting a user deletes
      that user's API keys).
      """
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      api_keys_table = self._get_table_name(
          PostgresUserHandler.API_KEYS_TABLE_NAME
      )

      users_ddl = f"""
          CREATE TABLE IF NOT EXISTS {users_table} (
              id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
              email TEXT UNIQUE NOT NULL,
              hashed_password TEXT NOT NULL,
              is_superuser BOOLEAN DEFAULT FALSE,
              is_active BOOLEAN DEFAULT TRUE,
              is_verified BOOLEAN DEFAULT FALSE,
              verification_code TEXT,
              verification_code_expiry TIMESTAMPTZ,
              name TEXT,
              bio TEXT,
              profile_picture TEXT,
              reset_token TEXT,
              reset_token_expiry TIMESTAMPTZ,
              collection_ids UUID[] NULL,
              created_at TIMESTAMPTZ DEFAULT NOW(),
              updated_at TIMESTAMPTZ DEFAULT NOW()
          );
      """

      # API keys track updated_at instead of last_used_at; get_api_key_record
      # and update_api_key_name both refresh it.
      # NOTE(review): public_key is declared UNIQUE, which already creates an
      # index on it, so idx_api_keys_public_key looks redundant. Confirm
      # before dropping it.
      api_keys_ddl = f"""
          CREATE TABLE IF NOT EXISTS {api_keys_table} (
              id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
              user_id UUID NOT NULL REFERENCES {users_table}(id) ON DELETE CASCADE,
              public_key TEXT UNIQUE NOT NULL,
              hashed_key TEXT NOT NULL,
              name TEXT,
              created_at TIMESTAMPTZ DEFAULT NOW(),
              updated_at TIMESTAMPTZ DEFAULT NOW()
          );
          CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
          ON {api_keys_table}(user_id);
          CREATE INDEX IF NOT EXISTS idx_api_keys_public_key
          ON {api_keys_table}(public_key);
      """

      # Order matters: the API-key table references the users table.
      for ddl in (users_ddl, api_keys_ddl):
          await self.connection_manager.execute_query(ddl)
  async def get_user_by_id(self, id: UUID) -> User:
      """Fetch a single user by primary key.

      Args:
          id: The user's UUID.

      Returns:
          The matching ``User``, including ``hashed_password`` and the
          profile fields (name, profile_picture, bio, collection_ids).

      Raises:
          R2RException: 404 if no user has this id.
      """
      # Use the class constant like every other method instead of a
      # hard-coded "users" literal, so a table rename cannot miss this query.
      query, _ = (
          QueryBuilder(self._get_table_name(PostgresUserHandler.TABLE_NAME))
          .select(
              [
                  "id",
                  "email",
                  "hashed_password",
                  "is_superuser",
                  "is_active",
                  "is_verified",
                  "created_at",
                  "updated_at",
                  "name",
                  "profile_picture",
                  "bio",
                  "collection_ids",
              ]
          )
          .where("id = $1")
          .build()
      )
      result = await self.connection_manager.fetchrow_query(query, [id])
      if not result:
          raise R2RException(status_code=404, message="User not found")
      return User(
          id=result["id"],
          email=result["email"],
          hashed_password=result["hashed_password"],
          is_superuser=result["is_superuser"],
          is_active=result["is_active"],
          is_verified=result["is_verified"],
          created_at=result["created_at"],
          updated_at=result["updated_at"],
          name=result["name"],
          profile_picture=result["profile_picture"],
          bio=result["bio"],
          collection_ids=result["collection_ids"],
      )
  async def get_user_by_email(self, email: str) -> User:
      """Fetch a single user by email address (exact match).

      Args:
          email: The email to look up; compared with ``=`` as stored.

      Returns:
          The matching ``User``, including ``hashed_password``.

      Raises:
          R2RException: 404 if no user has this email.
      """
      # Table name comes from the class constant (consistent with the other
      # methods); the builder's params are unused because the single value is
      # bound explicitly below.
      query, _ = (
          QueryBuilder(self._get_table_name(PostgresUserHandler.TABLE_NAME))
          .select(
              [
                  "id",
                  "email",
                  "hashed_password",
                  "is_superuser",
                  "is_active",
                  "is_verified",
                  "created_at",
                  "updated_at",
                  "name",
                  "profile_picture",
                  "bio",
                  "collection_ids",
              ]
          )
          .where("email = $1")
          .build()
      )
      result = await self.connection_manager.fetchrow_query(query, [email])
      if not result:
          raise R2RException(status_code=404, message="User not found")
      return User(
          id=result["id"],
          email=result["email"],
          hashed_password=result["hashed_password"],
          is_superuser=result["is_superuser"],
          is_active=result["is_active"],
          is_verified=result["is_verified"],
          created_at=result["created_at"],
          updated_at=result["updated_at"],
          name=result["name"],
          profile_picture=result["profile_picture"],
          bio=result["bio"],
          collection_ids=result["collection_ids"],
      )
  async def create_user(
      self, email: str, password: str, is_superuser: bool = False
  ) -> User:
      """Insert a new user with a hashed password and an empty collection list.

      The row id is derived from the email via ``generate_user_id``.

      Args:
          email: Email for the new account; must not already be registered.
          password: Plain-text password, hashed with ``crypto_provider``.
          is_superuser: Whether the new account is a superuser.

      Returns:
          The created ``User`` (name/bio/profile_picture are not populated).

      Raises:
          R2RException: 400 if the email is already registered; any non-404
              error raised by the email lookup is propagated unchanged.
          HTTPException: 500 if the INSERT returns no row.
      """
      existing_user = None
      try:
          existing_user = await self.get_user_by_email(email)
      except R2RException as lookup_error:
          # A 404 only means the email is free; anything else is a failure.
          if lookup_error.status_code != 404:
              raise lookup_error
      if existing_user:
          raise R2RException(
              status_code=400,
              message="User with this email already exists",
          )

      # NOTE(review): check-then-insert is not atomic. A concurrent insert of
      # the same email would be rejected by the UNIQUE constraint on email,
      # presumably surfacing as a database error rather than the 400 above.
      hashed_password = self.crypto_provider.get_password_hash(password)  # type: ignore
      insert_sql = f"""
          INSERT INTO {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
          (email, id, is_superuser, hashed_password, collection_ids)
          VALUES ($1, $2, $3, $4, $5)
          RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
      """
      new_user_id = generate_user_id(email)
      row = await self.connection_manager.fetchrow_query(
          insert_sql,
          [email, new_user_id, is_superuser, hashed_password, []],
      )
      if not row:
          raise HTTPException(
              status_code=500,
              detail="Failed to create user",
          )

      returned_fields = (
          "id",
          "email",
          "is_superuser",
          "is_active",
          "is_verified",
          "created_at",
          "updated_at",
          "collection_ids",
      )
      return User(
          **{field: row[field] for field in returned_fields},
          hashed_password=hashed_password,
      )
  async def update_user(self, user: User) -> User:
      """Persist the editable fields of ``user`` and return the stored row.

      Writes email, is_superuser, is_active, is_verified, name,
      profile_picture, bio and collection_ids for ``user.id`` and sets
      ``updated_at`` to NOW(). The password hash is neither written nor
      returned.

      Raises:
          HTTPException: 500 if the UPDATE returns no row (this includes the
              case where ``user.id`` does not exist).
      """
      # One column list drives both the RETURNING clause and the User kwargs.
      returned_fields = (
          "id",
          "email",
          "is_superuser",
          "is_active",
          "is_verified",
          "created_at",
          "updated_at",
          "name",
          "profile_picture",
          "bio",
          "collection_ids",
      )
      update_sql = f"""
          UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
          SET email = $1, is_superuser = $2, is_active = $3, is_verified = $4, updated_at = NOW(),
              name = $5, profile_picture = $6, bio = $7, collection_ids = $8
          WHERE id = $9
          RETURNING {", ".join(returned_fields)}
      """
      values = [
          user.email,
          user.is_superuser,
          user.is_active,
          user.is_verified,
          user.name,
          user.profile_picture,
          user.bio,
          user.collection_ids,
          user.id,
      ]
      row = await self.connection_manager.fetchrow_query(update_sql, values)
      if not row:
          raise HTTPException(
              status_code=500,
              detail="Failed to update user",
          )
      return User(**{field: row[field] for field in returned_fields})
  async def delete_user_relational(self, id: UUID) -> None:
      """Delete a user's row after detaching the documents that user owns.

      Args:
          id: The user to delete.

      Raises:
          R2RException: 404 if no user has this id.
      """
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)

      # Fail with a 404 before any document row is modified.
      exists_query = f"""
          SELECT 1 FROM {users_table}
          WHERE id = $1
      """
      if not await self.connection_manager.fetchrow_query(exists_query, [id]):
          raise R2RException(status_code=404, message="User not found")

      # Clear ownership on the user's documents. Ownership is stored in
      # documents.owner_id (the column get_users_overview joins users on).
      # Filtering/assigning documents.id instead compared document primary
      # keys with a user id, so it matched nothing and left owner_id dangling.
      # NOTE(review): assumes documents.owner_id is nullable; confirm against
      # the documents table schema.
      doc_update_query = f"""
          UPDATE {self._get_table_name('documents')}
          SET owner_id = NULL
          WHERE owner_id = $1
      """
      await self.connection_manager.execute_query(doc_update_query, [id])

      delete_query = f"""
          DELETE FROM {users_table}
          WHERE id = $1
          RETURNING id
      """
      result = await self.connection_manager.fetchrow_query(
          delete_query, [id]
      )
      # The row may have disappeared since the existence check (e.g. a
      # concurrent delete), so the DELETE result is checked as well.
      if not result:
          raise R2RException(status_code=404, message="User not found")
  async def update_user_password(self, id: UUID, new_hashed_password: str):
      """Overwrite the stored password hash for ``id`` and refresh ``updated_at``.

      The value is stored exactly as given; no hashing happens here. The
      query result is not checked, so an unknown ``id`` is not reported.
      """
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      update_sql = f"""
          UPDATE {users_table}
          SET hashed_password = $1, updated_at = NOW()
          WHERE id = $2
      """
      bound_values = [new_hashed_password, id]
      await self.connection_manager.execute_query(update_sql, bound_values)
  async def get_all_users(self) -> list[User]:
      """Return every user row without pagination.

      Profile fields (name, bio, profile_picture) are not selected, and
      ``hashed_password`` is filled with the placeholder string ``"null"``
      rather than the real hash.
      """
      select_sql = f"""
          SELECT id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
          FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
      """
      rows = await self.connection_manager.fetch_query(select_sql)

      users: list[User] = []
      for row in rows:
          users.append(
              User(
                  id=row["id"],
                  email=row["email"],
                  hashed_password="null",
                  is_superuser=row["is_superuser"],
                  is_active=row["is_active"],
                  is_verified=row["is_verified"],
                  created_at=row["created_at"],
                  updated_at=row["updated_at"],
                  collection_ids=row["collection_ids"],
              )
          )
      return users
  async def store_verification_code(
      self, id: UUID, verification_code: str, expiry: datetime
  ):
      """Save a verification code and its expiry on the user ``id``.

      Any previously stored code is overwritten. ``verify_user`` and
      ``get_user_id_by_verification_code`` only accept the code while
      ``expiry`` is later than the database's NOW().
      """
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      update_sql = f"""
          UPDATE {users_table}
          SET verification_code = $1, verification_code_expiry = $2
          WHERE id = $3
      """
      bound_values = [verification_code, expiry, id]
      await self.connection_manager.execute_query(update_sql, bound_values)
  async def verify_user(self, verification_code: str) -> None:
      """Mark the owner of an unexpired verification code as verified.

      The code and its expiry are cleared in the same UPDATE, so a code can
      be used only once.

      Raises:
          R2RException: 400 if no user holds this code with a future expiry.
      """
      verify_sql = f"""
          UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
          SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
          WHERE verification_code = $1 AND verification_code_expiry > NOW()
          RETURNING id
      """
      updated = await self.connection_manager.fetchrow_query(
          verify_sql, [verification_code]
      )
      if updated:
          return
      raise R2RException(
          status_code=400, message="Invalid or expired verification code"
      )
  async def remove_verification_code(self, verification_code: str):
      """Clear ``verification_code`` and its expiry on whichever user holds it.

      Expired codes are cleared too; an unknown code is silently ignored.
      """
      clear_sql = f"""
          UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
          SET verification_code = NULL, verification_code_expiry = NULL
          WHERE verification_code = $1
      """
      bound_values = [verification_code]
      await self.connection_manager.execute_query(clear_sql, bound_values)
  async def expire_verification_code(self, id: UUID):
      """Invalidate the user's current verification code without deleting it.

      The expiry is moved one day into the past, so the expiry checks
      (``verification_code_expiry > NOW()``) no longer accept the code.
      """
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      expire_sql = f"""
          UPDATE {users_table}
          SET verification_code_expiry = NOW() - INTERVAL '1 day'
          WHERE id = $1
      """
      await self.connection_manager.execute_query(expire_sql, [id])
  async def store_reset_token(
      self, id: UUID, reset_token: str, expiry: datetime
  ):
      """Save a password-reset token and its expiry on the user ``id``.

      Any earlier token is overwritten. ``get_user_id_by_reset_token`` only
      resolves the token while ``expiry`` is later than NOW().
      """
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      update_sql = f"""
          UPDATE {users_table}
          SET reset_token = $1, reset_token_expiry = $2
          WHERE id = $3
      """
      bound_values = [reset_token, expiry, id]
      await self.connection_manager.execute_query(update_sql, bound_values)
  async def get_user_id_by_reset_token(
      self, reset_token: str
  ) -> Optional[UUID]:
      """Resolve a password-reset token to the id of the user holding it.

      Returns:
          The user id if ``reset_token`` matches and its expiry is in the
          future, otherwise ``None`` (unknown and expired tokens look alike).
      """
      lookup_sql = f"""
          SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
          WHERE reset_token = $1 AND reset_token_expiry > NOW()
      """
      row = await self.connection_manager.fetchrow_query(
          lookup_sql, [reset_token]
      )
      if row:
          return row["id"]
      return None
  async def remove_reset_token(self, id: UUID):
      """Clear the password-reset token and its expiry for user ``id``."""
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      clear_sql = f"""
          UPDATE {users_table}
          SET reset_token = NULL, reset_token_expiry = NULL
          WHERE id = $1
      """
      await self.connection_manager.execute_query(clear_sql, [id])
  async def remove_user_from_all_collections(self, id: UUID):
      """Reset the user's ``collection_ids`` to an empty UUID array.

      NOTE(review): add_user_to_collection increments the collection's
      user_count, but nothing here decrements it. Confirm whether a caller
      adjusts the counts.
      """
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      reset_sql = f"""
          UPDATE {users_table}
          SET collection_ids = ARRAY[]::UUID[]
          WHERE id = $1
      """
      await self.connection_manager.execute_query(reset_sql, [id])
  async def add_user_to_collection(
      self, id: UUID, collection_id: UUID
  ) -> bool:
      """Add ``collection_id`` to the user's memberships and bump ``user_count``.

      Args:
          id: The user to add.
          collection_id: The collection to join.

      Returns:
          ``True`` on success.

      Raises:
          R2RException: 404 if the user or the collection does not exist;
              400 if the user is already a member.
      """
      # get_user_by_id itself raises the 404 for a missing user; the check
      # is kept as a guard in case it ever returns a falsy value.
      if not await self.get_user_by_id(id):
          raise R2RException(status_code=404, message="User not found")
      if not await self._collection_exists(collection_id):
          raise R2RException(status_code=404, message="Collection not found")

      # collection_ids is declared nullable. `$1 = ANY(NULL)` evaluates to
      # NULL, so without COALESCE a user whose array is NULL matched no row
      # and was wrongly reported as "already in collection".
      # array_append(NULL, x) already yields {x}, so the SET needs no change.
      query = f"""
          UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
          SET collection_ids = array_append(collection_ids, $1)
          WHERE id = $2
            AND NOT ($1 = ANY(COALESCE(collection_ids, ARRAY[]::UUID[])))
          RETURNING id
      """
      result = await self.connection_manager.fetchrow_query(
          query, [collection_id, id]
      )
      if not result:
          raise R2RException(
              status_code=400, message="User already in collection"
          )

      update_collection_query = f"""
          UPDATE {self._get_table_name('collections')}
          SET user_count = user_count + 1
          WHERE id = $1
      """
      await self.connection_manager.execute_query(
          query=update_collection_query,
          params=[collection_id],
      )
      return True
  async def remove_user_from_collection(
      self, id: UUID, collection_id: UUID
  ) -> bool:
      """Remove ``collection_id`` from the user's ``collection_ids``.

      Returns:
          ``True`` when the membership was removed.

      Raises:
          R2RException: 404 if the user does not exist; 400 if the user is
              not a member of the collection.

      NOTE(review): unlike add_user_to_collection, this does not decrement
      the collection's user_count. Confirm whether a caller does.
      """
      user = await self.get_user_by_id(id)
      if not user:
          raise R2RException(status_code=404, message="User not found")

      removal_sql = f"""
          UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
          SET collection_ids = array_remove(collection_ids, $1)
          WHERE id = $2 AND $1 = ANY(collection_ids)
          RETURNING id
      """
      removed = await self.connection_manager.fetchrow_query(
          removal_sql, [collection_id, id]
      )
      if removed:
          return True
      raise R2RException(
          status_code=400,
          message="User is not a member of the specified collection",
      )
  async def get_users_in_collection(
      self, collection_id: UUID, offset: int, limit: int
  ) -> dict[str, list[User] | int]:
      """
      Get all users in a specific collection with pagination.

      Args:
          collection_id (UUID): The ID of the collection to get users from.
          offset (int): The number of users to skip.
          limit (int): The maximum number of users to return; -1 returns all
              remaining users.

      Returns:
          dict: ``{"results": list[User], "total_entries": int}``.
          ``total_entries`` is the member count across all pages. It is read
          from the first returned row, so it is 0 when the requested page is
          empty. Returned users carry no password hash.

      Raises:
          R2RException: 404 if the collection doesn't exist.
      """
      if not await self._collection_exists(collection_id):  # type: ignore
          raise R2RException(status_code=404, message="Collection not found")

      # `name` is nullable and not unique, so ordering by it alone leaves ties
      # in an unspecified order, and OFFSET/LIMIT pages can overlap or skip
      # users. Adding `u.id` makes the order total and pagination stable.
      query = f"""
          SELECT u.id, u.email, u.is_active, u.is_superuser, u.created_at, u.updated_at,
              u.is_verified, u.collection_ids, u.name, u.bio, u.profile_picture,
              COUNT(*) OVER() AS total_entries
          FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
          WHERE $1 = ANY(u.collection_ids)
          ORDER BY u.name, u.id
          OFFSET $2
      """
      params: list = [collection_id, offset]
      if limit != -1:
          query += " LIMIT $3"
          params.append(limit)

      results = await self.connection_manager.fetch_query(query, params)
      users = [
          User(
              id=row["id"],
              email=row["email"],
              is_active=row["is_active"],
              is_superuser=row["is_superuser"],
              created_at=row["created_at"],
              updated_at=row["updated_at"],
              is_verified=row["is_verified"],
              collection_ids=row["collection_ids"],
              name=row["name"],
              bio=row["bio"],
              profile_picture=row["profile_picture"],
              hashed_password=None,
              verification_code_expiry=None,
          )
          for row in results
      ]
      total_entries = results[0]["total_entries"] if results else 0
      return {"results": users, "total_entries": total_entries}
  async def mark_user_as_superuser(self, id: UUID):
      """Promote user ``id`` to superuser and mark them verified.

      Any pending verification code and its expiry are cleared as well.
      """
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      promote_sql = f"""
          UPDATE {users_table}
          SET is_superuser = TRUE, is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
          WHERE id = $1
      """
      await self.connection_manager.execute_query(promote_sql, [id])
  async def get_user_id_by_verification_code(
      self, verification_code: str
  ) -> UUID:
      """Resolve an unexpired verification code to its user id.

      Unlike ``verify_user``, this only reads; the code stays stored.

      Raises:
          R2RException: 400 if the code is unknown or expired.
      """
      lookup_sql = f"""
          SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
          WHERE verification_code = $1 AND verification_code_expiry > NOW()
      """
      row = await self.connection_manager.fetchrow_query(
          lookup_sql, [verification_code]
      )
      if row:
          return row["id"]
      raise R2RException(
          status_code=400, message="Invalid or expired verification code"
      )
  async def mark_user_as_verified(self, id: UUID):
      """Mark user ``id`` verified and clear any pending verification code."""
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      verify_sql = f"""
          UPDATE {users_table}
          SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
          WHERE id = $1
      """
      await self.connection_manager.execute_query(verify_sql, [id])
  async def get_users_overview(
      self,
      offset: int,
      limit: int,
      user_ids: Optional[list[UUID]] = None,
  ) -> dict[str, list[User] | int]:
      """Return a page of users annotated with statistics on their documents.

      Each user is joined to ``documents`` on ``owner_id`` and carries
      ``num_files`` (count of owned documents), ``total_size_in_bytes``
      (sum of their ``size_in_bytes``, 0 when none) and ``document_ids``.
      Users are ordered by email.

      Args:
          offset: Number of users to skip.
          limit: Maximum number of users to return; -1 means no limit.
          user_ids: If non-empty, only these users are included.

      Returns:
          ``{"results": list[User], "total_entries": int}`` where
          ``total_entries`` counts all matching users, not just this page.

      Raises:
          R2RException: 404 if the requested page contains no users.
      """
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      documents_table = self._get_table_name("documents")

      # Placeholders are numbered from the parameters actually bound. A fixed
      # `$3` for user_ids broke the query whenever limit == -1, because then
      # only two parameters existed.
      params: list = [offset]
      user_filter = ""
      if user_ids:
          params.append(user_ids)
          user_filter = f"WHERE u.id = ANY(${len(params)}::uuid[])"

      query = f"""
          WITH user_document_ids AS (
              SELECT
                  u.id as user_id,
                  ARRAY_AGG(d.id) FILTER (WHERE d.id IS NOT NULL) AS doc_ids
              FROM {users_table} u
              LEFT JOIN {documents_table} d ON u.id = d.owner_id
              GROUP BY u.id
          ),
          user_docs AS (
              SELECT
                  u.id,
                  u.email,
                  u.is_superuser,
                  u.is_active,
                  u.is_verified,
                  u.name,
                  u.bio,
                  u.profile_picture,
                  u.collection_ids,
                  u.created_at,
                  u.updated_at,
                  COUNT(d.id) AS num_files,
                  COALESCE(SUM(d.size_in_bytes), 0) AS total_size_in_bytes,
                  ud.doc_ids as document_ids
              FROM {users_table} u
              LEFT JOIN {documents_table} d ON u.id = d.owner_id
              LEFT JOIN user_document_ids ud ON u.id = ud.user_id
              {user_filter}
              GROUP BY u.id, u.email, u.is_superuser, u.is_active, u.is_verified,
                  u.created_at, u.updated_at, u.collection_ids, ud.doc_ids
          )
          SELECT
              user_docs.*,
              COUNT(*) OVER() AS total_entries
          FROM user_docs
          ORDER BY email
          OFFSET $1
      """
      if limit != -1:
          params.append(limit)
          query += f" LIMIT ${len(params)}"

      results = await self.connection_manager.fetch_query(query, params)
      users = [
          User(
              id=row["id"],
              email=row["email"],
              is_superuser=row["is_superuser"],
              is_active=row["is_active"],
              is_verified=row["is_verified"],
              name=row["name"],
              bio=row["bio"],
              # Already selected; pass it through like the other readers do.
              profile_picture=row["profile_picture"],
              created_at=row["created_at"],
              updated_at=row["updated_at"],
              collection_ids=row["collection_ids"] or [],
              num_files=row["num_files"],
              total_size_in_bytes=row["total_size_in_bytes"],
              document_ids=(
                  []
                  if row["document_ids"] is None
                  else list(row["document_ids"])
              ),
          )
          for row in results
      ]
      if not users:
          raise R2RException(status_code=404, message="No users found")
      total_entries = results[0]["total_entries"]
      return {"results": users, "total_entries": total_entries}
  async def _collection_exists(self, collection_id: UUID) -> bool:
      """Return whether a row with ``collection_id`` exists in the collections table."""
      collections_table = self._get_table_name(
          PostgresCollectionsHandler.TABLE_NAME
      )
      probe_sql = f"""
          SELECT 1 FROM {collections_table}
          WHERE id = $1
      """
      row = await self.connection_manager.fetchrow_query(
          probe_sql, [collection_id]
      )
      found = row is not None
      return found
  async def get_user_validation_data(
      self,
      user_id: UUID,
  ) -> dict:
      """
      Get verification data for a specific user.

      This method should be called after superuser authorization has been
      verified: it returns the raw verification code and reset token.

      Returns:
          ``{"verification_data": {...}}`` containing ``verification_code``
          and ``reset_token`` as stored, plus their expiries as ISO-8601
          strings (``None`` when unset).

      Raises:
          R2RException: 404 if the user does not exist.
      """
      # Use the class constant like the rest of the handler instead of a
      # hard-coded "users" literal.
      query = f"""
          SELECT
              verification_code,
              verification_code_expiry,
              reset_token,
              reset_token_expiry
          FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
          WHERE id = $1
      """
      result = await self.connection_manager.fetchrow_query(query, [user_id])
      if not result:
          raise R2RException(status_code=404, message="User not found")

      def _iso_or_none(value):
          # Expiry columns are TIMESTAMPTZ; keep None for unset values.
          return value.isoformat() if value else None

      return {
          "verification_data": {
              "verification_code": result["verification_code"],
              "verification_code_expiry": _iso_or_none(
                  result["verification_code_expiry"]
              ),
              "reset_token": result["reset_token"],
              "reset_token_expiry": _iso_or_none(
                  result["reset_token_expiry"]
              ),
          }
      }
  # API Key methods
  async def store_user_api_key(
      self,
      user_id: UUID,
      key_id: str,
      hashed_key: str,
      name: Optional[str] = None,
  ) -> UUID:
      """Store a new API key for a user.

      Args:
          user_id: Owner of the key.
          key_id: Public identifier, stored in the ``public_key`` column
              (which is UNIQUE).
          hashed_key: Stored as given (presumably already hashed by the
              caller; verify).
          name: Optional human-readable label.

      Returns:
          The generated id of the new API-key row.

      Raises:
          R2RException: 500 if the INSERT returns no row.
      """
      api_keys_table = self._get_table_name(
          PostgresUserHandler.API_KEYS_TABLE_NAME
      )
      insert_sql = f"""
          INSERT INTO {api_keys_table}
          (user_id, public_key, hashed_key, name)
          VALUES ($1, $2, $3, $4)
          RETURNING id
      """
      inserted = await self.connection_manager.fetchrow_query(
          insert_sql, [user_id, key_id, hashed_key, name]
      )
      if inserted:
          return inserted["id"]
      raise R2RException(
          status_code=500, message="Failed to store API key"
      )
  async def get_api_key_record(self, key_id: str) -> Optional[dict]:
      """Look up an API key by its public id and touch its ``updated_at``.

      The lookup is an UPDATE ... RETURNING, so every successful call also
      sets ``updated_at`` to NOW().

      Returns:
          ``{"user_id": ..., "hashed_key": ...}`` for the key, or ``None`` if
          no key has this public id.
      """
      touch_sql = f"""
          UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
          SET updated_at = NOW()
          WHERE public_key = $1
          RETURNING user_id, hashed_key
      """
      row = await self.connection_manager.fetchrow_query(touch_sql, [key_id])
      if not row:
          return None
      return {column: row[column] for column in ("user_id", "hashed_key")}
  async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
      """List the API keys of ``user_id``, newest first.

      Each entry has ``key_id`` (row id as a string), ``public_key``,
      ``name`` (``""`` when unset) and ``updated_at``. The hashed secret is
      never returned.

      NOTE(review): created_at is selected but not included in the returned
      dicts.
      """
      list_sql = f"""
          SELECT id, public_key, name, created_at, updated_at
          FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
          WHERE user_id = $1
          ORDER BY created_at DESC
      """
      rows = await self.connection_manager.fetch_query(list_sql, [user_id])

      keys: list[dict] = []
      for row in rows:
          keys.append(
              {
                  "key_id": str(row["id"]),
                  "public_key": row["public_key"],
                  "name": row["name"] or "",
                  "updated_at": row["updated_at"],
              }
          )
      return keys
  async def delete_api_key(self, user_id: UUID, key_id: UUID) -> dict:
      """Delete one API key, but only if it belongs to ``user_id``.

      Returns:
          ``{"key_id", "public_key", "name"}`` of the deleted key, with
          ``name`` as ``""`` when unset.

      Raises:
          R2RException: 404 if no key with this id is owned by ``user_id``.
      """
      delete_sql = f"""
          DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
          WHERE id = $1 AND user_id = $2
          RETURNING id, public_key, name
      """
      deleted = await self.connection_manager.fetchrow_query(
          delete_sql, [key_id, user_id]
      )
      if deleted is None:
          raise R2RException(status_code=404, message="API key not found")

      key_name = deleted["name"] or ""
      return {
          "key_id": str(deleted["id"]),
          "public_key": str(deleted["public_key"]),
          "name": key_name,
      }
  async def update_api_key_name(
      self, user_id: UUID, key_id: UUID, name: str
  ) -> bool:
      """Rename an API key owned by ``user_id`` and refresh its ``updated_at``.

      Returns:
          ``True`` when the key was renamed.

      Raises:
          R2RException: 404 if no key with this id is owned by ``user_id``.
      """
      rename_sql = f"""
          UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
          SET name = $1, updated_at = NOW()
          WHERE id = $2 AND user_id = $3
          RETURNING id
      """
      renamed = await self.connection_manager.fetchrow_query(
          rename_sql, [name, key_id, user_id]
      )
      if renamed is not None:
          return True
      raise R2RException(status_code=404, message="API key not found")