users.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326
  1. import csv
  2. import json
  3. import tempfile
  4. from datetime import datetime
  5. from typing import IO, Optional
  6. from uuid import UUID
  7. from fastapi import HTTPException
  8. from core.base import CryptoProvider, Handler
  9. from core.base.abstractions import R2RException
  10. from core.utils import generate_user_id
  11. from shared.abstractions import User
  12. from .base import PostgresConnectionManager, QueryBuilder
  13. from .collections import PostgresCollectionsHandler
  14. def _merge_metadata(
  15. existing_metadata: dict[str, str], new_metadata: dict[str, Optional[str]]
  16. ) -> dict[str, str]:
  17. """
  18. Merges the new metadata with the existing metadata in the Stripe-style approach:
  19. - new_metadata[key] = <string> => update or add that key
  20. - new_metadata[key] = "" => remove that key
  21. - if new_metadata is empty => remove all keys
  22. """
  23. # If new_metadata is an empty dict, it signals removal of all keys.
  24. if new_metadata == {}:
  25. return {}
  26. # Copy so we don't mutate the original
  27. final_metadata = dict(existing_metadata)
  28. for key, value in new_metadata.items():
  29. # If the user sets the key to an empty string, it means "delete" that key
  30. if value == "":
  31. if key in final_metadata:
  32. del final_metadata[key]
  33. # If not None and not empty, set or override
  34. elif value is not None:
  35. final_metadata[key] = value
  36. else:
  37. # If the user sets the value to None in some contexts, decide if you want to remove or ignore
  38. # For now we might treat None same as empty string => remove
  39. if key in final_metadata:
  40. del final_metadata[key]
  41. return final_metadata
  42. class PostgresUserHandler(Handler):
  43. TABLE_NAME = "users"
  44. API_KEYS_TABLE_NAME = "users_api_keys"
    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        crypto_provider: CryptoProvider,
    ):
        """Initialise the handler.

        Args:
            project_name: Forwarded to the base ``Handler`` initialiser.
            connection_manager: Forwarded to the base ``Handler`` initialiser;
                the methods of this class issue their queries through
                ``self.connection_manager``.
            crypto_provider: Kept on the instance; ``create_user`` uses it to
                hash passwords.
        """
        super().__init__(project_name, connection_manager)
        self.crypto_provider = crypto_provider
  53. async def create_tables(self):
  54. user_table_query = f"""
  55. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.TABLE_NAME)} (
  56. id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
  57. email TEXT UNIQUE NOT NULL,
  58. hashed_password TEXT NOT NULL,
  59. is_superuser BOOLEAN DEFAULT FALSE,
  60. is_active BOOLEAN DEFAULT TRUE,
  61. is_verified BOOLEAN DEFAULT FALSE,
  62. verification_code TEXT,
  63. verification_code_expiry TIMESTAMPTZ,
  64. name TEXT,
  65. bio TEXT,
  66. profile_picture TEXT,
  67. reset_token TEXT,
  68. reset_token_expiry TIMESTAMPTZ,
  69. collection_ids UUID[] NULL,
  70. limits_overrides JSONB,
  71. metadata JSONB,
  72. created_at TIMESTAMPTZ DEFAULT NOW(),
  73. updated_at TIMESTAMPTZ DEFAULT NOW(),
  74. account_type TEXT NOT NULL DEFAULT 'password',
  75. google_id TEXT,
  76. github_id TEXT
  77. );
  78. """
  79. # API keys table with updated_at instead of last_used_at
  80. api_keys_table_query = f"""
  81. CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
  82. id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
  83. user_id UUID NOT NULL REFERENCES {self._get_table_name(PostgresUserHandler.TABLE_NAME)}(id) ON DELETE CASCADE,
  84. public_key TEXT UNIQUE NOT NULL,
  85. hashed_key TEXT NOT NULL,
  86. name TEXT,
  87. description TEXT,
  88. created_at TIMESTAMPTZ DEFAULT NOW(),
  89. updated_at TIMESTAMPTZ DEFAULT NOW()
  90. );
  91. CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
  92. ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(user_id);
  93. CREATE INDEX IF NOT EXISTS idx_api_keys_public_key
  94. ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(public_key);
  95. """
  96. await self.connection_manager.execute_query(user_table_query)
  97. await self.connection_manager.execute_query(api_keys_table_query)
  98. # (New) Code snippet for adding columns if missing
  99. # Postgres >= 9.6 supports "ADD COLUMN IF NOT EXISTS"
  100. check_columns_query = f"""
  101. ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
  102. ADD COLUMN IF NOT EXISTS metadata JSONB;
  103. ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
  104. ADD COLUMN IF NOT EXISTS limits_overrides JSONB;
  105. ALTER TABLE {self._get_table_name(self.API_KEYS_TABLE_NAME)}
  106. ADD COLUMN IF NOT EXISTS description TEXT;
  107. """
  108. await self.connection_manager.execute_query(check_columns_query)
  109. # Optionally, create indexes for quick lookups:
  110. check_columns_query = f"""
  111. ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
  112. ADD COLUMN IF NOT EXISTS account_type TEXT NOT NULL DEFAULT 'password',
  113. ADD COLUMN IF NOT EXISTS google_id TEXT,
  114. ADD COLUMN IF NOT EXISTS github_id TEXT;
  115. CREATE INDEX IF NOT EXISTS idx_users_google_id
  116. ON {self._get_table_name(self.TABLE_NAME)}(google_id);
  117. CREATE INDEX IF NOT EXISTS idx_users_github_id
  118. ON {self._get_table_name(self.TABLE_NAME)}(github_id);
  119. """
  120. await self.connection_manager.execute_query(check_columns_query)
  121. async def get_user_by_id(self, id: UUID) -> User:
  122. query, _ = (
  123. QueryBuilder(self._get_table_name("users"))
  124. .select(
  125. [
  126. "id",
  127. "email",
  128. "is_superuser",
  129. "is_active",
  130. "is_verified",
  131. "created_at",
  132. "updated_at",
  133. "name",
  134. "profile_picture",
  135. "bio",
  136. "collection_ids",
  137. "limits_overrides",
  138. "metadata",
  139. "account_type",
  140. "hashed_password",
  141. "google_id",
  142. "github_id",
  143. ]
  144. )
  145. .where("id = $1")
  146. .build()
  147. )
  148. result = await self.connection_manager.fetchrow_query(query, [id])
  149. if not result:
  150. raise R2RException(status_code=404, message="User not found")
  151. return User(
  152. id=result["id"],
  153. email=result["email"],
  154. is_superuser=result["is_superuser"],
  155. is_active=result["is_active"],
  156. is_verified=result["is_verified"],
  157. created_at=result["created_at"],
  158. updated_at=result["updated_at"],
  159. name=result["name"],
  160. profile_picture=result["profile_picture"],
  161. bio=result["bio"],
  162. collection_ids=result["collection_ids"],
  163. limits_overrides=json.loads(result["limits_overrides"] or "{}"),
  164. metadata=json.loads(result["metadata"] or "{}"),
  165. hashed_password=result["hashed_password"],
  166. account_type=result["account_type"],
  167. google_id=result["google_id"],
  168. github_id=result["github_id"],
  169. )
  170. async def get_user_by_email(self, email: str) -> User:
  171. query, params = (
  172. QueryBuilder(self._get_table_name("users"))
  173. .select(
  174. [
  175. "id",
  176. "email",
  177. "is_superuser",
  178. "is_active",
  179. "is_verified",
  180. "created_at",
  181. "updated_at",
  182. "name",
  183. "profile_picture",
  184. "bio",
  185. "collection_ids",
  186. "metadata",
  187. "limits_overrides",
  188. "account_type",
  189. "hashed_password",
  190. "google_id",
  191. "github_id",
  192. ]
  193. )
  194. .where("email = $1")
  195. .build()
  196. )
  197. result = await self.connection_manager.fetchrow_query(query, [email])
  198. if not result:
  199. raise R2RException(status_code=404, message="User not found")
  200. return User(
  201. id=result["id"],
  202. email=result["email"],
  203. is_superuser=result["is_superuser"],
  204. is_active=result["is_active"],
  205. is_verified=result["is_verified"],
  206. created_at=result["created_at"],
  207. updated_at=result["updated_at"],
  208. name=result["name"],
  209. profile_picture=result["profile_picture"],
  210. bio=result["bio"],
  211. collection_ids=result["collection_ids"],
  212. limits_overrides=json.loads(result["limits_overrides"] or "{}"),
  213. metadata=json.loads(result["metadata"] or "{}"),
  214. account_type=result["account_type"],
  215. hashed_password=result["hashed_password"],
  216. google_id=result["google_id"],
  217. github_id=result["github_id"],
  218. )
  219. async def create_user(
  220. self,
  221. email: str,
  222. password: Optional[str] = None,
  223. account_type: Optional[str] = "password",
  224. google_id: Optional[str] = None,
  225. github_id: Optional[str] = None,
  226. is_superuser: bool = False,
  227. is_verified: bool = False,
  228. name: Optional[str] = None,
  229. bio: Optional[str] = None,
  230. profile_picture: Optional[str] = None,
  231. ) -> User:
  232. """Create a new user."""
  233. # 1) Check if a user with this email already exists
  234. try:
  235. existing = await self.get_user_by_email(email)
  236. if existing:
  237. raise R2RException(
  238. status_code=400,
  239. message="User with this email already exists",
  240. )
  241. except R2RException as e:
  242. if e.status_code != 404:
  243. raise e
  244. # 2) If google_id is provided, ensure no user already has it
  245. if google_id:
  246. existing_google_user = await self.get_user_by_google_id(google_id)
  247. if existing_google_user:
  248. raise R2RException(
  249. status_code=400,
  250. message="User with this Google account already exists",
  251. )
  252. # 3) If github_id is provided, ensure no user already has it
  253. if github_id:
  254. existing_github_user = await self.get_user_by_github_id(github_id)
  255. if existing_github_user:
  256. raise R2RException(
  257. status_code=400,
  258. message="User with this GitHub account already exists",
  259. )
  260. hashed_password = None
  261. if account_type == "password":
  262. if password is None:
  263. raise R2RException(
  264. status_code=400,
  265. message="Password is required for a 'password' account_type",
  266. )
  267. hashed_password = self.crypto_provider.get_password_hash(password) # type: ignore
  268. query, params = (
  269. QueryBuilder(self._get_table_name(self.TABLE_NAME))
  270. .insert(
  271. {
  272. "email": email,
  273. "id": generate_user_id(email),
  274. "is_superuser": is_superuser,
  275. "collection_ids": [],
  276. "limits_overrides": None,
  277. "metadata": None,
  278. "account_type": account_type,
  279. "hashed_password": hashed_password
  280. or "", # Ensure hashed_password is not None
  281. # !!WARNING - Upstream checks are required to treat oauth differently from password!!
  282. "google_id": google_id,
  283. "github_id": github_id,
  284. "is_verified": is_verified or (account_type != "password"),
  285. "name": name,
  286. "bio": bio,
  287. "profile_picture": profile_picture,
  288. }
  289. )
  290. .returning(
  291. [
  292. "id",
  293. "email",
  294. "is_superuser",
  295. "is_active",
  296. "is_verified",
  297. "created_at",
  298. "updated_at",
  299. "collection_ids",
  300. "limits_overrides",
  301. "metadata",
  302. "name",
  303. "bio",
  304. "profile_picture",
  305. ]
  306. )
  307. .build()
  308. )
  309. result = await self.connection_manager.fetchrow_query(query, params)
  310. if not result:
  311. raise R2RException(
  312. status_code=500,
  313. message="Failed to create user",
  314. )
  315. return User(
  316. id=result["id"],
  317. email=result["email"],
  318. is_superuser=result["is_superuser"],
  319. is_active=result["is_active"],
  320. is_verified=result["is_verified"],
  321. created_at=result["created_at"],
  322. updated_at=result["updated_at"],
  323. collection_ids=result["collection_ids"] or [],
  324. limits_overrides=json.loads(result["limits_overrides"] or "{}"),
  325. metadata=json.loads(result["metadata"] or "{}"),
  326. name=result["name"],
  327. bio=result["bio"],
  328. profile_picture=result["profile_picture"],
  329. account_type=account_type or "password",
  330. hashed_password=hashed_password,
  331. google_id=google_id,
  332. github_id=github_id,
  333. )
  334. async def update_user(
  335. self,
  336. user: User,
  337. merge_limits: bool = False,
  338. new_metadata: dict[str, Optional[str]] | None = None,
  339. ) -> User:
  340. """Update user information including limits_overrides.
  341. Args:
  342. user: User object containing updated information
  343. merge_limits: If True, will merge existing limits_overrides with new ones.
  344. If False, will overwrite existing limits_overrides.
  345. Returns:
  346. Updated User object
  347. """
  348. # Get current user if we need to merge limits or get hashed password
  349. current_user = None
  350. try:
  351. current_user = await self.get_user_by_id(user.id)
  352. except R2RException:
  353. raise R2RException(
  354. status_code=404, message="User not found"
  355. ) from None
  356. # If the new user.google_id != current_user.google_id, check for duplicates
  357. if user.email and (user.email != current_user.email):
  358. existing_email_user = await self.get_user_by_email(user.email)
  359. if existing_email_user and existing_email_user.id != user.id:
  360. raise R2RException(
  361. status_code=400,
  362. message="That email account is already associated with another user.",
  363. )
  364. # If the new user.google_id != current_user.google_id, check for duplicates
  365. if user.google_id and (user.google_id != current_user.google_id):
  366. existing_google_user = await self.get_user_by_google_id(
  367. user.google_id
  368. )
  369. if existing_google_user and existing_google_user.id != user.id:
  370. raise R2RException(
  371. status_code=400,
  372. message="That Google account is already associated with another user.",
  373. )
  374. # Similarly for GitHub:
  375. if user.github_id and (user.github_id != current_user.github_id):
  376. existing_github_user = await self.get_user_by_github_id(
  377. user.github_id
  378. )
  379. if existing_github_user and existing_github_user.id != user.id:
  380. raise R2RException(
  381. status_code=400,
  382. message="That GitHub account is already associated with another user.",
  383. )
  384. # Merge or replace metadata if provided
  385. final_metadata = current_user.metadata or {}
  386. if new_metadata is not None:
  387. final_metadata = _merge_metadata(final_metadata, new_metadata)
  388. # Merge or replace limits_overrides
  389. final_limits = user.limits_overrides
  390. if (
  391. merge_limits
  392. and current_user.limits_overrides
  393. and user.limits_overrides
  394. ):
  395. final_limits = {
  396. **current_user.limits_overrides,
  397. **user.limits_overrides,
  398. }
  399. query = f"""
  400. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  401. SET email = $1,
  402. is_superuser = $2,
  403. is_active = $3,
  404. is_verified = $4,
  405. updated_at = NOW(),
  406. name = $5,
  407. profile_picture = $6,
  408. bio = $7,
  409. collection_ids = $8,
  410. limits_overrides = $9::jsonb,
  411. metadata = $10::jsonb
  412. WHERE id = $11
  413. RETURNING id, email, is_superuser, is_active, is_verified,
  414. created_at, updated_at, name, profile_picture, bio,
  415. collection_ids, limits_overrides, metadata, hashed_password,
  416. account_type, google_id, github_id
  417. """
  418. result = await self.connection_manager.fetchrow_query(
  419. query,
  420. [
  421. user.email,
  422. user.is_superuser,
  423. user.is_active,
  424. user.is_verified,
  425. user.name,
  426. user.profile_picture,
  427. user.bio,
  428. user.collection_ids or [],
  429. json.dumps(final_limits),
  430. json.dumps(final_metadata),
  431. user.id,
  432. ],
  433. )
  434. if not result:
  435. raise HTTPException(
  436. status_code=500,
  437. detail="Failed to update user",
  438. )
  439. return User(
  440. id=result["id"],
  441. email=result["email"],
  442. is_superuser=result["is_superuser"],
  443. is_active=result["is_active"],
  444. is_verified=result["is_verified"],
  445. created_at=result["created_at"],
  446. updated_at=result["updated_at"],
  447. name=result["name"],
  448. profile_picture=result["profile_picture"],
  449. bio=result["bio"],
  450. collection_ids=result["collection_ids"]
  451. or [], # Ensure null becomes empty array
  452. limits_overrides=json.loads(
  453. result["limits_overrides"] or "{}"
  454. ), # Can be null
  455. metadata=json.loads(result["metadata"] or "{}"),
  456. account_type=result["account_type"],
  457. hashed_password=result[
  458. "hashed_password"
  459. ], # Include hashed_password
  460. google_id=result["google_id"],
  461. github_id=result["github_id"],
  462. )
    async def delete_user_relational(self, id: UUID) -> None:
        """Delete a user row after issuing an update against the documents table.

        Steps, in order: select the user's ``collection_ids`` (used here only
        as an existence check), run an UPDATE on the documents table, then
        DELETE the user row.

        Args:
            id: UUID of the user to delete.

        Raises:
            R2RException: 404 if the user row is missing, either at the
                initial lookup or when the DELETE returns no row.
        """
        # Get the collections the user belongs to
        # (the selected value is not read below; only row presence matters).
        collection_query, params = (
            QueryBuilder(self._get_table_name(self.TABLE_NAME))
            .select(["collection_ids"])
            .where("id = $1")
            .build()
        )

        collection_result = await self.connection_manager.fetchrow_query(
            collection_query, [id]
        )

        if not collection_result:
            raise R2RException(status_code=404, message="User not found")

        # Update documents query
        # NOTE(review): this targets the documents table's own ``id`` column in
        # both the SET and the WHERE, keyed on the *user* id, and the builder's
        # doc_params are discarded in favour of [id]. Presumably the intent was
        # to clear an owner reference on the user's documents — confirm against
        # QueryBuilder.update() and the documents schema before relying on it.
        doc_update_query, doc_params = (
            QueryBuilder(self._get_table_name("documents"))
            .update({"id": None})
            .where("id = $1")
            .build()
        )

        await self.connection_manager.execute_query(doc_update_query, [id])

        # Delete user query
        delete_query, del_params = (
            QueryBuilder(self._get_table_name(self.TABLE_NAME))
            .delete()
            .where("id = $1")
            .returning(["id"])
            .build()
        )

        result = await self.connection_manager.fetchrow_query(
            delete_query, [id]
        )

        # No row returned means the user vanished between lookup and delete.
        if not result:
            raise R2RException(status_code=404, message="User not found")
  498. async def update_user_password(self, id: UUID, new_hashed_password: str):
  499. query = f"""
  500. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  501. SET hashed_password = $1, updated_at = NOW()
  502. WHERE id = $2
  503. """
  504. await self.connection_manager.execute_query(
  505. query, [new_hashed_password, id]
  506. )
  507. async def get_all_users(self) -> list[User]:
  508. """Get all users with minimal information."""
  509. query, params = (
  510. QueryBuilder(self._get_table_name(self.TABLE_NAME))
  511. .select(
  512. [
  513. "id",
  514. "email",
  515. "is_superuser",
  516. "is_active",
  517. "is_verified",
  518. "created_at",
  519. "updated_at",
  520. "collection_ids",
  521. "hashed_password",
  522. "limits_overrides",
  523. "metadata",
  524. "name",
  525. "bio",
  526. "profile_picture",
  527. "account_type",
  528. "google_id",
  529. "github_id",
  530. ]
  531. )
  532. .build()
  533. )
  534. results = await self.connection_manager.fetch_query(query, params)
  535. return [
  536. User(
  537. id=result["id"],
  538. email=result["email"],
  539. is_superuser=result["is_superuser"],
  540. is_active=result["is_active"],
  541. is_verified=result["is_verified"],
  542. created_at=result["created_at"],
  543. updated_at=result["updated_at"],
  544. collection_ids=result["collection_ids"] or [],
  545. limits_overrides=json.loads(
  546. result["limits_overrides"] or "{}"
  547. ),
  548. metadata=json.loads(result["metadata"] or "{}"),
  549. name=result["name"],
  550. bio=result["bio"],
  551. profile_picture=result["profile_picture"],
  552. account_type=result["account_type"],
  553. hashed_password=result["hashed_password"],
  554. google_id=result["google_id"],
  555. github_id=result["github_id"],
  556. )
  557. for result in results
  558. ]
  559. async def store_verification_code(
  560. self, id: UUID, verification_code: str, expiry: datetime
  561. ):
  562. query = f"""
  563. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  564. SET verification_code = $1, verification_code_expiry = $2
  565. WHERE id = $3
  566. """
  567. await self.connection_manager.execute_query(
  568. query, [verification_code, expiry, id]
  569. )
  570. async def verify_user(self, verification_code: str) -> None:
  571. query = f"""
  572. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  573. SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
  574. WHERE verification_code = $1 AND verification_code_expiry > NOW()
  575. RETURNING id
  576. """
  577. result = await self.connection_manager.fetchrow_query(
  578. query, [verification_code]
  579. )
  580. if not result:
  581. raise R2RException(
  582. status_code=400, message="Invalid or expired verification code"
  583. )
  584. async def remove_verification_code(self, verification_code: str):
  585. query = f"""
  586. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  587. SET verification_code = NULL, verification_code_expiry = NULL
  588. WHERE verification_code = $1
  589. """
  590. await self.connection_manager.execute_query(query, [verification_code])
  591. async def expire_verification_code(self, id: UUID):
  592. query = f"""
  593. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  594. SET verification_code_expiry = NOW() - INTERVAL '1 day'
  595. WHERE id = $1
  596. """
  597. await self.connection_manager.execute_query(query, [id])
  598. async def store_reset_token(
  599. self, id: UUID, reset_token: str, expiry: datetime
  600. ):
  601. query = f"""
  602. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  603. SET reset_token = $1, reset_token_expiry = $2
  604. WHERE id = $3
  605. """
  606. await self.connection_manager.execute_query(
  607. query, [reset_token, expiry, id]
  608. )
  609. async def get_user_id_by_reset_token(
  610. self, reset_token: str
  611. ) -> Optional[UUID]:
  612. query = f"""
  613. SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  614. WHERE reset_token = $1 AND reset_token_expiry > NOW()
  615. """
  616. result = await self.connection_manager.fetchrow_query(
  617. query, [reset_token]
  618. )
  619. return result["id"] if result else None
  620. async def remove_reset_token(self, id: UUID):
  621. query = f"""
  622. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  623. SET reset_token = NULL, reset_token_expiry = NULL
  624. WHERE id = $1
  625. """
  626. await self.connection_manager.execute_query(query, [id])
  627. async def remove_user_from_all_collections(self, id: UUID):
  628. query = f"""
  629. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  630. SET collection_ids = ARRAY[]::UUID[]
  631. WHERE id = $1
  632. """
  633. await self.connection_manager.execute_query(query, [id])
  634. async def add_user_to_collection(
  635. self, id: UUID, collection_id: UUID
  636. ) -> bool:
  637. # Check if the user exists
  638. if not await self.get_user_by_id(id):
  639. raise R2RException(status_code=404, message="User not found")
  640. # Check if the collection exists
  641. if not await self._collection_exists(collection_id):
  642. raise R2RException(status_code=404, message="Collection not found")
  643. query = f"""
  644. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  645. SET collection_ids = array_append(collection_ids, $1)
  646. WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
  647. RETURNING id
  648. """
  649. result = await self.connection_manager.fetchrow_query(
  650. query, [collection_id, id]
  651. )
  652. if not result:
  653. raise R2RException(
  654. status_code=400, message="User already in collection"
  655. )
  656. update_collection_query = f"""
  657. UPDATE {self._get_table_name("collections")}
  658. SET user_count = user_count + 1
  659. WHERE id = $1
  660. """
  661. await self.connection_manager.execute_query(
  662. query=update_collection_query,
  663. params=[collection_id],
  664. )
  665. return True
  666. async def remove_user_from_collection(
  667. self, id: UUID, collection_id: UUID
  668. ) -> bool:
  669. if not await self.get_user_by_id(id):
  670. raise R2RException(status_code=404, message="User not found")
  671. query = f"""
  672. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  673. SET collection_ids = array_remove(collection_ids, $1)
  674. WHERE id = $2 AND $1 = ANY(collection_ids)
  675. RETURNING id
  676. """
  677. result = await self.connection_manager.fetchrow_query(
  678. query, [collection_id, id]
  679. )
  680. if not result:
  681. raise R2RException(
  682. status_code=400,
  683. message="User is not a member of the specified collection",
  684. )
  685. return True
  686. async def get_users_in_collection(
  687. self, collection_id: UUID, offset: int, limit: int
  688. ) -> dict[str, list[User] | int]:
  689. """Get all users in a specific collection with pagination."""
  690. if not await self._collection_exists(collection_id):
  691. raise R2RException(status_code=404, message="Collection not found")
  692. query, params = (
  693. QueryBuilder(self._get_table_name(self.TABLE_NAME))
  694. .select(
  695. [
  696. "id",
  697. "email",
  698. "is_active",
  699. "is_superuser",
  700. "created_at",
  701. "updated_at",
  702. "is_verified",
  703. "collection_ids",
  704. "name",
  705. "bio",
  706. "profile_picture",
  707. "limits_overrides",
  708. "metadata",
  709. "account_type",
  710. "hashed_password",
  711. "google_id",
  712. "github_id",
  713. "COUNT(*) OVER() AS total_entries",
  714. ]
  715. )
  716. .where("$1 = ANY(collection_ids)")
  717. .order_by("name")
  718. .offset("$2")
  719. .limit("$3" if limit != -1 else None)
  720. .build()
  721. )
  722. conditions = [collection_id, offset]
  723. if limit != -1:
  724. conditions.append(limit)
  725. results = await self.connection_manager.fetch_query(query, conditions)
  726. users_list = [
  727. User(
  728. id=row["id"],
  729. email=row["email"],
  730. is_active=row["is_active"],
  731. is_superuser=row["is_superuser"],
  732. created_at=row["created_at"],
  733. updated_at=row["updated_at"],
  734. is_verified=row["is_verified"],
  735. collection_ids=row["collection_ids"] or [],
  736. name=row["name"],
  737. bio=row["bio"],
  738. profile_picture=row["profile_picture"],
  739. limits_overrides=json.loads(row["limits_overrides"] or "{}"),
  740. metadata=json.loads(row["metadata"] or "{}"),
  741. account_type=row["account_type"],
  742. hashed_password=row["hashed_password"],
  743. google_id=row["google_id"],
  744. github_id=row["github_id"],
  745. )
  746. for row in results
  747. ]
  748. total_entries = results[0]["total_entries"] if results else 0
  749. return {"results": users_list, "total_entries": total_entries}
  750. async def mark_user_as_superuser(self, id: UUID):
  751. query = f"""
  752. UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  753. SET is_superuser = TRUE, is_verified = TRUE,
  754. verification_code = NULL, verification_code_expiry = NULL
  755. WHERE id = $1
  756. """
  757. await self.connection_manager.execute_query(query, [id])
  758. async def get_user_id_by_verification_code(
  759. self, verification_code: str
  760. ) -> UUID:
  761. query = f"""
  762. SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
  763. WHERE verification_code = $1 AND verification_code_expiry > NOW()
  764. """
  765. result = await self.connection_manager.fetchrow_query(
  766. query, [verification_code]
  767. )
  768. if not result:
  769. raise R2RException(
  770. status_code=400, message="Invalid or expired verification code"
  771. )
  772. return result["id"]
  async def mark_user_as_verified(self, id: UUID):
      """Flag a user as verified and discard the pending verification code.

      Args:
          id: Primary key of the user row to update.
      """
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      sql = f"""
          UPDATE {users_table}
          SET is_verified = TRUE,
          verification_code = NULL,
          verification_code_expiry = NULL
          WHERE id = $1
      """
      await self.connection_manager.execute_query(sql, [id])
  async def get_users_overview(
      self,
      offset: int,
      limit: int,
      user_ids: Optional[list[UUID]] = None,
  ) -> dict[str, list[User] | int]:
      """Return users with document usage and total entries.

      Args:
          offset: Number of rows (ordered by email) to skip.
          limit: Maximum number of rows to return; ``-1`` disables the LIMIT
              clause entirely.
          user_ids: When non-empty, restrict the overview to these user ids.

      Returns:
          ``{"results": [...], "total_entries": n}`` where ``results`` holds
          one ``User`` per returned row (including ``num_files``,
          ``total_size_in_bytes`` and ``document_ids``) and ``total_entries``
          is the ``COUNT(*) OVER()`` value read from the first row.

      Raises:
          R2RException: 404 when the query returns no rows.
      """
      users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
      documents_table = self._get_table_name("documents")

      # Placeholder numbers are derived from the parameter list itself.
      # Hard-coding `$3` for the id filter breaks when `limit == -1`,
      # because only two parameters are bound in that case.
      params: list = [offset]

      limit_clause = ""
      if limit != -1:
          params.append(limit)
          limit_clause = f" LIMIT ${len(params)}"

      user_filter = ""
      if user_ids:
          params.append(user_ids)
          user_filter = f" WHERE u.id = ANY(${len(params)}::uuid[])"

      query = f"""
          WITH user_document_ids AS (
              SELECT
                  u.id as user_id,
                  ARRAY_AGG(d.id) FILTER (WHERE d.id IS NOT NULL) AS doc_ids
              FROM {users_table} u
              LEFT JOIN {documents_table} d ON u.id = d.owner_id
              GROUP BY u.id
          ),
          user_docs AS (
              SELECT
                  u.id,
                  u.email,
                  u.is_superuser,
                  u.is_active,
                  u.is_verified,
                  u.name,
                  u.bio,
                  u.profile_picture,
                  u.collection_ids,
                  u.created_at,
                  u.updated_at,
                  COUNT(d.id) AS num_files,
                  COALESCE(SUM(d.size_in_bytes), 0) AS total_size_in_bytes,
                  ud.doc_ids as document_ids
              FROM {users_table} u
              LEFT JOIN {documents_table} d ON u.id = d.owner_id
              LEFT JOIN user_document_ids ud ON u.id = ud.user_id
              {user_filter}
              GROUP BY u.id, u.email, u.is_superuser, u.is_active, u.is_verified,
                  u.created_at, u.updated_at, u.collection_ids, ud.doc_ids
          )
          SELECT
              user_docs.*,
              COUNT(*) OVER() AS total_entries
          FROM user_docs
          ORDER BY email
          OFFSET $1
      """
      query += limit_clause

      results = await self.connection_manager.fetch_query(query, params)
      if not results:
          raise R2RException(status_code=404, message="No users found")

      users_list = [
          User(
              id=row["id"],
              email=row["email"],
              is_superuser=row["is_superuser"],
              is_active=row["is_active"],
              is_verified=row["is_verified"],
              name=row["name"],
              bio=row["bio"],
              created_at=row["created_at"],
              updated_at=row["updated_at"],
              profile_picture=row["profile_picture"],
              collection_ids=row["collection_ids"] or [],
              num_files=row["num_files"],
              total_size_in_bytes=row["total_size_in_bytes"],
              # ARRAY_AGG yields NULL for users without documents.
              document_ids=(
                  list(row["document_ids"]) if row["document_ids"] else []
              ),
          )
          for row in results
      ]

      # Every row carries the same window count, so the first one suffices.
      total_entries = results[0]["total_entries"]
      return {"results": users_list, "total_entries": total_entries}
  async def _collection_exists(self, collection_id: UUID) -> bool:
      """Report whether a collection row with the given id is present."""
      collections_table = self._get_table_name(
          PostgresCollectionsHandler.TABLE_NAME
      )
      sql = f"""
          SELECT 1 FROM {collections_table}
          WHERE id = $1
      """
      hit = await self.connection_manager.fetchrow_query(
          sql, [collection_id]
      )
      if hit is None:
          return False
      return True
  async def get_user_validation_data(
      self,
      user_id: UUID,
  ) -> dict:
      """Fetch the verification and reset-token state stored for one user.

      Callers are expected to have checked superuser authorization before
      invoking this method.

      Args:
          user_id: Primary key of the user row to read.

      Returns:
          ``{"verification_data": {...}}`` with the verification code, the
          reset token, and both expiry values rendered via ``isoformat()``
          (or ``None`` when unset).

      Raises:
          R2RException: 404 when no row exists for ``user_id``.
      """
      sql = f"""
          SELECT
          verification_code,
          verification_code_expiry,
          reset_token,
          reset_token_expiry
          FROM {self._get_table_name("users")}
          WHERE id = $1
      """
      row = await self.connection_manager.fetchrow_query(sql, [user_id])
      if not row:
          raise R2RException(status_code=404, message="User not found")

      def _iso_or_none(moment):
          # Expiry columns may be unset; only format real timestamps.
          return moment.isoformat() if moment else None

      payload = {
          "verification_code": row["verification_code"],
          "verification_code_expiry": _iso_or_none(
              row["verification_code_expiry"]
          ),
          "reset_token": row["reset_token"],
          "reset_token_expiry": _iso_or_none(row["reset_token_expiry"]),
      }
      return {"verification_data": payload}
  909. # API Key methods
  async def store_user_api_key(
      self,
      user_id: UUID,
      key_id: str,
      hashed_key: str,
      name: Optional[str] = None,
      description: Optional[str] = None,
  ) -> UUID:
      """Insert an API key row for a user and return the new row's id.

      Args:
          user_id: Owner of the key.
          key_id: Value stored in the ``public_key`` column.
          hashed_key: Value stored in the ``hashed_key`` column.
          name: Optional label; stored as an empty string when falsy.
          description: Optional text; stored as an empty string when falsy.

      Raises:
          R2RException: 500 when the INSERT returns no row.
      """
      api_keys_table = self._get_table_name(
          PostgresUserHandler.API_KEYS_TABLE_NAME
      )
      sql = f"""
          INSERT INTO {api_keys_table}
          (user_id, public_key, hashed_key, name, description)
          VALUES ($1, $2, $3, $4, $5)
          RETURNING id
      """
      values = [user_id, key_id, hashed_key, name or "", description or ""]
      inserted = await self.connection_manager.fetchrow_query(sql, values)
      if inserted:
          return inserted["id"]

      raise R2RException(
          status_code=500, message="Failed to store API key"
      )
  async def get_api_key_record(self, key_id: str) -> Optional[dict]:
      """Look up an API key by its ``public_key`` and touch ``updated_at``.

      The lookup is an ``UPDATE ... RETURNING``, so a successful read also
      sets the row's ``updated_at`` to ``NOW()``.

      Returns:
          ``{"user_id": ..., "hashed_key": ...}`` or ``None`` if no row
          matches ``key_id``.
      """
      api_keys_table = self._get_table_name(
          PostgresUserHandler.API_KEYS_TABLE_NAME
      )
      sql = f"""
          UPDATE {api_keys_table}
          SET updated_at = NOW()
          WHERE public_key = $1
          RETURNING user_id, hashed_key
      """
      row = await self.connection_manager.fetchrow_query(sql, [key_id])
      if not row:
          return None
      return {field: row[field] for field in ("user_id", "hashed_key")}
  async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
      """List a user's API keys, newest first (ordered by ``created_at``).

      Returns:
          One dict per key with ``key_id`` (row id as a string),
          ``public_key``, ``name``, ``description`` (empty string when
          unset) and ``updated_at``.
      """
      api_keys_table = self._get_table_name(
          PostgresUserHandler.API_KEYS_TABLE_NAME
      )
      sql = f"""
          SELECT id, public_key, name, description, created_at, updated_at
          FROM {api_keys_table}
          WHERE user_id = $1
          ORDER BY created_at DESC
      """
      rows = await self.connection_manager.fetch_query(sql, [user_id])

      keys: list[dict] = []
      for entry in rows:
          keys.append(
              {
                  "key_id": str(entry["id"]),
                  "public_key": entry["public_key"],
                  "name": entry["name"] or "",
                  "description": entry["description"] or "",
                  "updated_at": entry["updated_at"],
              }
          )
      return keys
  async def delete_api_key(self, user_id: UUID, key_id: UUID) -> bool:
      """Remove one API key, scoped to the user that owns it.

      Returns:
          ``True`` once a row has been deleted.

      Raises:
          R2RException: 404 when no key with this id belongs to the user.
      """
      api_keys_table = self._get_table_name(
          PostgresUserHandler.API_KEYS_TABLE_NAME
      )
      sql = f"""
          DELETE FROM {api_keys_table}
          WHERE id = $1 AND user_id = $2
          RETURNING id, public_key, name, description
      """
      deleted = await self.connection_manager.fetchrow_query(
          sql, [key_id, user_id]
      )
      if deleted is not None:
          return True

      raise R2RException(status_code=404, message="API key not found")
  async def update_api_key_name(
      self, user_id: UUID, key_id: UUID, name: str
  ) -> bool:
      """Rename an API key owned by the given user and bump ``updated_at``.

      Returns:
          ``True`` once a row has been updated.

      Raises:
          R2RException: 404 when no key with this id belongs to the user.
      """
      api_keys_table = self._get_table_name(
          PostgresUserHandler.API_KEYS_TABLE_NAME
      )
      sql = f"""
          UPDATE {api_keys_table}
          SET name = $1, updated_at = NOW()
          WHERE id = $2 AND user_id = $3
          RETURNING id
      """
      updated = await self.connection_manager.fetchrow_query(
          sql, [name, key_id, user_id]
      )
      if updated is not None:
          return True

      raise R2RException(status_code=404, message="API key not found")
  async def export_to_csv(
      self,
      columns: Optional[list[str]] = None,
      filters: Optional[dict] = None,
      include_header: bool = True,
  ) -> tuple[str, IO]:
      """Stream user rows from PostgreSQL into a temporary CSV file.

      Args:
          columns: Columns to write, in output order. Defaults to every
              exportable column in a fixed order.
          filters: Optional mapping of column name to either a plain value
              (equality) or a dict of ``$eq`` / ``$gt`` / ``$lt`` operators.
              Unknown columns and unsupported operators are ignored.
          include_header: Write the column names as the first CSV row.

      Returns:
          ``(path, file_object)``. The file is created with ``delete=True``,
          so it disappears when ``file_object`` is closed; the caller must
          keep the handle open while reading and close it afterwards.

      Raises:
          ValueError: If ``columns`` names a column that is not exportable.
          HTTPException: 500 when querying or writing fails.
      """
      # Ordered to match the SELECT list below. A tuple (not a set) keeps
      # the default export's column order identical on every run.
      exportable_columns = (
          "id",
          "email",
          "is_superuser",
          "is_active",
          "is_verified",
          "name",
          "bio",
          "collection_ids",
          "created_at",
          "updated_at",
      )
      valid_columns = set(exportable_columns)

      if not columns:
          columns = list(exportable_columns)
      elif invalid_cols := set(columns) - valid_columns:
          raise ValueError(f"Invalid columns: {invalid_cols}")

      select_stmt = f"""
          SELECT
              id::text,
              email,
              is_superuser,
              is_active,
              is_verified,
              name,
              bio,
              collection_ids::text,
              to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
              to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
          FROM {self._get_table_name(self.TABLE_NAME)}
      """

      # Field names are interpolated into SQL only after the whitelist
      # check; values always travel as bound parameters.
      comparison_ops = {"$eq": "=", "$gt": ">", "$lt": "<"}
      params: list = []
      conditions: list[str] = []
      for field, value in (filters or {}).items():
          if field not in valid_columns:
              continue
          if isinstance(value, dict):
              for op, operand in value.items():
                  sql_op = comparison_ops.get(op)
                  if sql_op is None:
                      continue
                  params.append(operand)
                  conditions.append(f"{field} {sql_op} ${len(params)}")
          else:
              # Direct equality
              params.append(value)
              conditions.append(f"{field} = ${len(params)}")

      if conditions:
          select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
      select_stmt = f"{select_stmt} ORDER BY created_at DESC"

      # Resolve the requested columns to SELECT positions once, instead of
      # building a dict for every row.
      position = {name: idx for idx, name in enumerate(exportable_columns)}
      column_indexes = [position[name] for name in columns]

      temp_file = None
      try:
          # newline="" is required by the csv module so that its own line
          # terminators are written untranslated.
          temp_file = tempfile.NamedTemporaryFile(
              mode="w", delete=True, suffix=".csv", newline=""
          )
          writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

          async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
              async with conn.transaction():
                  cursor = await conn.cursor(select_stmt, *params)

                  if include_header:
                      writer.writerow(columns)

                  chunk_size = 1000
                  while True:
                      rows = await cursor.fetch(chunk_size)
                      if not rows:
                          break
                      writer.writerows(
                          [row[idx] for idx in column_indexes]
                          for row in rows
                      )

          temp_file.flush()
          return temp_file.name, temp_file

      except Exception as e:
          if temp_file:
              temp_file.close()
          raise HTTPException(
              status_code=500,
              detail=f"Failed to export data: {str(e)}",
          ) from e
  async def get_user_by_google_id(self, google_id: str) -> Optional[User]:
      """Look up the user linked to a Google account id.

      Returns:
          The matching ``User``, or ``None`` when no row has this
          ``google_id``.
      """
      selected_columns = [
          "id",
          "email",
          "is_superuser",
          "is_active",
          "is_verified",
          "created_at",
          "updated_at",
          "name",
          "profile_picture",
          "bio",
          "collection_ids",
          "limits_overrides",
          "metadata",
          "account_type",
          "hashed_password",
          "google_id",
          "github_id",
      ]
      builder = QueryBuilder(self._get_table_name("users"))
      builder = builder.select(selected_columns).where("google_id = $1")
      # The builder's own parameter list is not used; the id is bound here.
      query, _builder_params = builder.build()

      record = await self.connection_manager.fetchrow_query(
          query, [google_id]
      )
      if not record:
          return None

      # These columns need a fallback or JSON decoding; the rest are
      # copied to the model unchanged.
      converted = ("collection_ids", "limits_overrides", "metadata")
      passthrough = {
          column: record[column]
          for column in selected_columns
          if column not in converted
      }
      return User(
          **passthrough,
          collection_ids=record["collection_ids"] or [],
          limits_overrides=json.loads(record["limits_overrides"] or "{}"),
          metadata=json.loads(record["metadata"] or "{}"),
      )
  async def get_user_by_github_id(self, github_id: str) -> Optional[User]:
      """Look up the user linked to a GitHub account id.

      Returns:
          The matching ``User``, or ``None`` when no row has this
          ``github_id``.
      """
      selected_columns = [
          "id",
          "email",
          "is_superuser",
          "is_active",
          "is_verified",
          "created_at",
          "updated_at",
          "name",
          "profile_picture",
          "bio",
          "collection_ids",
          "limits_overrides",
          "metadata",
          "account_type",
          "hashed_password",
          "google_id",
          "github_id",
      ]
      builder = QueryBuilder(self._get_table_name("users"))
      builder = builder.select(selected_columns).where("github_id = $1")
      # The builder's own parameter list is not used; the id is bound here.
      query, _builder_params = builder.build()

      record = await self.connection_manager.fetchrow_query(
          query, [github_id]
      )
      if not record:
          return None

      # These columns need a fallback or JSON decoding; the rest are
      # copied to the model unchanged.
      converted = ("collection_ids", "limits_overrides", "metadata")
      passthrough = {
          column: record[column]
          for column in selected_columns
          if column not in converted
      }
      return User(
          **passthrough,
          collection_ids=record["collection_ids"] or [],
          limits_overrides=json.loads(record["limits_overrides"] or "{}"),
          metadata=json.loads(record["metadata"] or "{}"),
      )