12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326 |
- import csv
- import json
- import tempfile
- from datetime import datetime
- from typing import IO, Optional
- from uuid import UUID
- from fastapi import HTTPException
- from core.base import CryptoProvider, Handler
- from core.base.abstractions import R2RException
- from core.utils import generate_user_id
- from shared.abstractions import User
- from .base import PostgresConnectionManager, QueryBuilder
- from .collections import PostgresCollectionsHandler
- def _merge_metadata(
- existing_metadata: dict[str, str], new_metadata: dict[str, Optional[str]]
- ) -> dict[str, str]:
- """
- Merges the new metadata with the existing metadata in the Stripe-style approach:
- - new_metadata[key] = <string> => update or add that key
- - new_metadata[key] = "" => remove that key
- - if new_metadata is empty => remove all keys
- """
- # If new_metadata is an empty dict, it signals removal of all keys.
- if new_metadata == {}:
- return {}
- # Copy so we don't mutate the original
- final_metadata = dict(existing_metadata)
- for key, value in new_metadata.items():
- # If the user sets the key to an empty string, it means "delete" that key
- if value == "":
- if key in final_metadata:
- del final_metadata[key]
- # If not None and not empty, set or override
- elif value is not None:
- final_metadata[key] = value
- else:
- # If the user sets the value to None in some contexts, decide if you want to remove or ignore
- # For now we might treat None same as empty string => remove
- if key in final_metadata:
- del final_metadata[key]
- return final_metadata
- class PostgresUserHandler(Handler):
- TABLE_NAME = "users"
- API_KEYS_TABLE_NAME = "users_api_keys"
def __init__(
    self,
    project_name: str,
    connection_manager: PostgresConnectionManager,
    crypto_provider: CryptoProvider,
) -> None:
    """Initialise the handler.

    Args:
        project_name: Forwarded to the base ``Handler``.
        connection_manager: Forwarded to the base ``Handler``; used by every
            query method on this class.
        crypto_provider: Kept on the instance; ``create_user`` uses it to
            hash passwords.
    """
    super().__init__(project_name, connection_manager)
    self.crypto_provider = crypto_provider
async def create_tables(self):
    """Create the users and API-key tables and bring older schemas up to date.

    Runs four statements in order: the users table DDL, the API-keys table
    DDL (with its indexes), an ``ADD COLUMN IF NOT EXISTS`` migration for
    ``metadata`` / ``limits_overrides`` / ``description``, and a second
    migration that adds the OAuth columns and their indexes.
    """
    users_ddl = f"""
    CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.TABLE_NAME)} (
        id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
        email TEXT UNIQUE NOT NULL,
        hashed_password TEXT NOT NULL,
        is_superuser BOOLEAN DEFAULT FALSE,
        is_active BOOLEAN DEFAULT TRUE,
        is_verified BOOLEAN DEFAULT FALSE,
        verification_code TEXT,
        verification_code_expiry TIMESTAMPTZ,
        name TEXT,
        bio TEXT,
        profile_picture TEXT,
        reset_token TEXT,
        reset_token_expiry TIMESTAMPTZ,
        collection_ids UUID[] NULL,
        limits_overrides JSONB,
        metadata JSONB,
        created_at TIMESTAMPTZ DEFAULT NOW(),
        updated_at TIMESTAMPTZ DEFAULT NOW(),
        account_type TEXT NOT NULL DEFAULT 'password',
        google_id TEXT,
        github_id TEXT
    );
    """

    # The API keys table tracks updated_at rather than last_used_at.
    api_keys_ddl = f"""
    CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
        id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
        user_id UUID NOT NULL REFERENCES {self._get_table_name(PostgresUserHandler.TABLE_NAME)}(id) ON DELETE CASCADE,
        public_key TEXT UNIQUE NOT NULL,
        hashed_key TEXT NOT NULL,
        name TEXT,
        description TEXT,
        created_at TIMESTAMPTZ DEFAULT NOW(),
        updated_at TIMESTAMPTZ DEFAULT NOW()
    );
    CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
    ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(user_id);
    CREATE INDEX IF NOT EXISTS idx_api_keys_public_key
    ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(public_key);
    """

    # Migration for tables created before these columns existed.
    # Postgres >= 9.6 supports "ADD COLUMN IF NOT EXISTS".
    legacy_columns_migration = f"""
    ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
    ADD COLUMN IF NOT EXISTS metadata JSONB;
    ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
    ADD COLUMN IF NOT EXISTS limits_overrides JSONB;
    ALTER TABLE {self._get_table_name(self.API_KEYS_TABLE_NAME)}
    ADD COLUMN IF NOT EXISTS description TEXT;
    """

    # OAuth columns, plus indexes for quick lookups by provider id.
    oauth_columns_migration = f"""
    ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
    ADD COLUMN IF NOT EXISTS account_type TEXT NOT NULL DEFAULT 'password',
    ADD COLUMN IF NOT EXISTS google_id TEXT,
    ADD COLUMN IF NOT EXISTS github_id TEXT;
    CREATE INDEX IF NOT EXISTS idx_users_google_id
    ON {self._get_table_name(self.TABLE_NAME)}(google_id);
    CREATE INDEX IF NOT EXISTS idx_users_github_id
    ON {self._get_table_name(self.TABLE_NAME)}(github_id);
    """

    # Order matters: tables first, then the migrations that alter them.
    for statement in (
        users_ddl,
        api_keys_ddl,
        legacy_columns_migration,
        oauth_columns_migration,
    ):
        await self.connection_manager.execute_query(statement)
async def get_user_by_id(self, id: UUID) -> User:
    """Fetch one user by primary key.

    Args:
        id: The user's UUID.

    Returns:
        The matching ``User``, including ``hashed_password`` and OAuth ids.

    Raises:
        R2RException: 404 when no row has this id.
    """
    columns = [
        "id",
        "email",
        "is_superuser",
        "is_active",
        "is_verified",
        "created_at",
        "updated_at",
        "name",
        "profile_picture",
        "bio",
        "collection_ids",
        "limits_overrides",
        "metadata",
        "account_type",
        "hashed_password",
        "google_id",
        "github_id",
    ]
    query, _ = (
        QueryBuilder(self._get_table_name("users"))
        .select(columns)
        .where("id = $1")
        .build()
    )
    row = await self.connection_manager.fetchrow_query(query, [id])
    if not row:
        raise R2RException(status_code=404, message="User not found")

    return User(
        id=row["id"],
        email=row["email"],
        is_superuser=row["is_superuser"],
        is_active=row["is_active"],
        is_verified=row["is_verified"],
        created_at=row["created_at"],
        updated_at=row["updated_at"],
        name=row["name"],
        profile_picture=row["profile_picture"],
        bio=row["bio"],
        # The column is nullable (UUID[] NULL); normalise NULL to an empty
        # list the same way the other row-to-User mappings in this class do.
        collection_ids=row["collection_ids"] or [],
        # JSONB columns are decoded here; NULL becomes an empty dict.
        limits_overrides=json.loads(row["limits_overrides"] or "{}"),
        metadata=json.loads(row["metadata"] or "{}"),
        hashed_password=row["hashed_password"],
        account_type=row["account_type"],
        google_id=row["google_id"],
        github_id=row["github_id"],
    )
async def get_user_by_email(self, email: str) -> User:
    """Fetch one user by email address.

    Args:
        email: Address to look up (matched with ``email = $1``).

    Returns:
        The matching ``User``, including ``hashed_password`` and OAuth ids.

    Raises:
        R2RException: 404 when no row has this email. Callers in this class
            rely on that 404 to detect an unused address.
    """
    columns = [
        "id",
        "email",
        "is_superuser",
        "is_active",
        "is_verified",
        "created_at",
        "updated_at",
        "name",
        "profile_picture",
        "bio",
        "collection_ids",
        "metadata",
        "limits_overrides",
        "account_type",
        "hashed_password",
        "google_id",
        "github_id",
    ]
    # The builder's params are unused: the only bind value is supplied below.
    query, _ = (
        QueryBuilder(self._get_table_name("users"))
        .select(columns)
        .where("email = $1")
        .build()
    )
    row = await self.connection_manager.fetchrow_query(query, [email])
    if not row:
        raise R2RException(status_code=404, message="User not found")

    return User(
        id=row["id"],
        email=row["email"],
        is_superuser=row["is_superuser"],
        is_active=row["is_active"],
        is_verified=row["is_verified"],
        created_at=row["created_at"],
        updated_at=row["updated_at"],
        name=row["name"],
        profile_picture=row["profile_picture"],
        bio=row["bio"],
        # The column is nullable (UUID[] NULL); normalise NULL to an empty
        # list the same way the other row-to-User mappings in this class do.
        collection_ids=row["collection_ids"] or [],
        # JSONB columns are decoded here; NULL becomes an empty dict.
        limits_overrides=json.loads(row["limits_overrides"] or "{}"),
        metadata=json.loads(row["metadata"] or "{}"),
        account_type=row["account_type"],
        hashed_password=row["hashed_password"],
        google_id=row["google_id"],
        github_id=row["github_id"],
    )
async def create_user(
    self,
    email: str,
    password: Optional[str] = None,
    account_type: Optional[str] = "password",
    google_id: Optional[str] = None,
    github_id: Optional[str] = None,
    is_superuser: bool = False,
    is_verified: bool = False,
    name: Optional[str] = None,
    bio: Optional[str] = None,
    profile_picture: Optional[str] = None,
) -> User:
    """Insert a new user row and return it as a ``User``.

    Args:
        email: Must not belong to an existing user.
        password: Required when ``account_type`` is ``"password"``.
        account_type: ``"password"`` or another (OAuth) account type.
        google_id: Optional Google id; must not already be linked.
        github_id: Optional GitHub id; must not already be linked.
        is_superuser: Stored as given.
        is_verified: Non-password accounts are stored as verified
            regardless of this flag.
        name: Optional display name.
        bio: Optional bio.
        profile_picture: Optional profile picture value.

    Returns:
        The created ``User``. ``hashed_password`` is ``None`` on the returned
        object for non-password accounts, although ``""`` is what is stored.

    Raises:
        R2RException: 400 on a duplicate email / Google id / GitHub id or a
            missing password; 500 if the insert returns no row.
    """
    # 1) The email must be free. get_user_by_email raises a 404 when it is,
    #    so a 404 is the success path and anything else propagates.
    try:
        email_owner = await self.get_user_by_email(email)
    except R2RException as lookup_error:
        if lookup_error.status_code != 404:
            raise
    else:
        if email_owner:
            raise R2RException(
                status_code=400,
                message="User with this email already exists",
            )

    # 2) A Google id may be linked to one user only.
    #    NOTE(review): get_user_by_google_id is defined outside this view;
    #    presumably it returns a falsy value when nothing matches - confirm.
    if google_id and await self.get_user_by_google_id(google_id):
        raise R2RException(
            status_code=400,
            message="User with this Google account already exists",
        )

    # 3) Same rule for GitHub (same assumption about the lookup helper).
    if github_id and await self.get_user_by_github_id(github_id):
        raise R2RException(
            status_code=400,
            message="User with this GitHub account already exists",
        )

    # Only password accounts carry a password hash.
    hashed_password = None
    if account_type == "password":
        if password is None:
            raise R2RException(
                status_code=400,
                message="Password is required for a 'password' account_type",
            )
        hashed_password = self.crypto_provider.get_password_hash(password)  # type: ignore

    new_row = {
        "email": email,
        "id": generate_user_id(email),
        "is_superuser": is_superuser,
        "collection_ids": [],
        "limits_overrides": None,
        "metadata": None,
        "account_type": account_type,
        # The column is NOT NULL, so OAuth accounts store an empty string.
        # !!WARNING - Upstream checks are required to treat oauth differently from password!!
        "hashed_password": hashed_password or "",
        "google_id": google_id,
        "github_id": github_id,
        "is_verified": is_verified or (account_type != "password"),
        "name": name,
        "bio": bio,
        "profile_picture": profile_picture,
    }
    returned_columns = [
        "id",
        "email",
        "is_superuser",
        "is_active",
        "is_verified",
        "created_at",
        "updated_at",
        "collection_ids",
        "limits_overrides",
        "metadata",
        "name",
        "bio",
        "profile_picture",
    ]
    query, params = (
        QueryBuilder(self._get_table_name(self.TABLE_NAME))
        .insert(new_row)
        .returning(returned_columns)
        .build()
    )

    row = await self.connection_manager.fetchrow_query(query, params)
    if not row:
        raise R2RException(
            status_code=500,
            message="Failed to create user",
        )

    # Fields not in RETURNING are echoed back from the arguments.
    return User(
        id=row["id"],
        email=row["email"],
        is_superuser=row["is_superuser"],
        is_active=row["is_active"],
        is_verified=row["is_verified"],
        created_at=row["created_at"],
        updated_at=row["updated_at"],
        collection_ids=row["collection_ids"] or [],
        limits_overrides=json.loads(row["limits_overrides"] or "{}"),
        metadata=json.loads(row["metadata"] or "{}"),
        name=row["name"],
        bio=row["bio"],
        profile_picture=row["profile_picture"],
        account_type=account_type or "password",
        hashed_password=hashed_password,
        google_id=google_id,
        github_id=github_id,
    )
async def update_user(
    self,
    user: User,
    merge_limits: bool = False,
    new_metadata: dict[str, Optional[str]] | None = None,
) -> User:
    """Update a user's profile fields, limits_overrides and metadata.

    Args:
        user: User object carrying the updated values; ``user.id`` selects
            the row.
        merge_limits: If True and both the stored and the new
            limits_overrides are non-empty, merge them (new keys win).
            Otherwise the stored value is overwritten.
        new_metadata: Stripe-style metadata patch applied through
            ``_merge_metadata``. ``None`` leaves stored metadata untouched.

    Returns:
        The updated ``User`` as read back from the database.

    Raises:
        R2RException: 404 if the user does not exist; 400 if the new email,
            Google id or GitHub id already belongs to another user.
        HTTPException: 500 if the UPDATE returns no row.
    """
    # The stored row is needed for the duplicate checks and the merges.
    try:
        current_user = await self.get_user_by_id(user.id)
    except R2RException:
        raise R2RException(
            status_code=404, message="User not found"
        ) from None

    # If the email is changing, make sure nobody else already owns it.
    if user.email and (user.email != current_user.email):
        try:
            existing_email_user = await self.get_user_by_email(user.email)
        except R2RException as e:
            # get_user_by_email reports "no such user" as a 404. Here that
            # means the address is free, which is the success case.
            if e.status_code != 404:
                raise
            existing_email_user = None
        if existing_email_user and existing_email_user.id != user.id:
            raise R2RException(
                status_code=400,
                message="That email account is already associated with another user.",
            )

    # If the Google id is changing, check for duplicates.
    if user.google_id and (user.google_id != current_user.google_id):
        existing_google_user = await self.get_user_by_google_id(
            user.google_id
        )
        if existing_google_user and existing_google_user.id != user.id:
            raise R2RException(
                status_code=400,
                message="That Google account is already associated with another user.",
            )

    # Similarly for GitHub.
    if user.github_id and (user.github_id != current_user.github_id):
        existing_github_user = await self.get_user_by_github_id(
            user.github_id
        )
        if existing_github_user and existing_github_user.id != user.id:
            raise R2RException(
                status_code=400,
                message="That GitHub account is already associated with another user.",
            )

    # Metadata: start from what is stored and patch it only when asked to.
    final_metadata = current_user.metadata or {}
    if new_metadata is not None:
        final_metadata = _merge_metadata(final_metadata, new_metadata)

    # Limits: overwrite by default, merge only when requested and possible.
    final_limits = user.limits_overrides
    if (
        merge_limits
        and current_user.limits_overrides
        and user.limits_overrides
    ):
        final_limits = {
            **current_user.limits_overrides,
            **user.limits_overrides,
        }

    query = f"""
        UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
        SET email = $1,
            is_superuser = $2,
            is_active = $3,
            is_verified = $4,
            updated_at = NOW(),
            name = $5,
            profile_picture = $6,
            bio = $7,
            collection_ids = $8,
            limits_overrides = $9::jsonb,
            metadata = $10::jsonb
        WHERE id = $11
        RETURNING id, email, is_superuser, is_active, is_verified,
                  created_at, updated_at, name, profile_picture, bio,
                  collection_ids, limits_overrides, metadata, hashed_password,
                  account_type, google_id, github_id
    """
    result = await self.connection_manager.fetchrow_query(
        query,
        [
            user.email,
            user.is_superuser,
            user.is_active,
            user.is_verified,
            user.name,
            user.profile_picture,
            user.bio,
            user.collection_ids or [],
            json.dumps(final_limits),
            json.dumps(final_metadata),
            user.id,
        ],
    )

    if not result:
        # Kept as HTTPException (not R2RException) so existing callers
        # that catch it keep working.
        raise HTTPException(
            status_code=500,
            detail="Failed to update user",
        )

    return User(
        id=result["id"],
        email=result["email"],
        is_superuser=result["is_superuser"],
        is_active=result["is_active"],
        is_verified=result["is_verified"],
        created_at=result["created_at"],
        updated_at=result["updated_at"],
        name=result["name"],
        profile_picture=result["profile_picture"],
        bio=result["bio"],
        # Ensure null becomes an empty list.
        collection_ids=result["collection_ids"] or [],
        # Can be null.
        limits_overrides=json.loads(result["limits_overrides"] or "{}"),
        metadata=json.loads(result["metadata"] or "{}"),
        account_type=result["account_type"],
        hashed_password=result["hashed_password"],
        google_id=result["google_id"],
        github_id=result["github_id"],
    )
async def delete_user_relational(self, id: UUID) -> None:
    """Delete a user row after touching related document rows.

    Args:
        id: UUID of the user to delete.

    Raises:
        R2RException: 404 if the user does not exist, either at the initial
            lookup or when the DELETE returns no row.
    """
    users_table = self._get_table_name(self.TABLE_NAME)

    # Existence check; the selected collection_ids value is not used.
    lookup_sql, _ = (
        QueryBuilder(users_table)
        .select(["collection_ids"])
        .where("id = $1")
        .build()
    )
    found = await self.connection_manager.fetchrow_query(lookup_sql, [id])
    if not found:
        raise R2RException(status_code=404, message="User not found")

    # NOTE(review): this update targets the documents table's `id` column
    # both in SET and in WHERE, and binds the *user* id. It looks like it
    # was meant to act on the document owner column instead - confirm the
    # intent before relying on it. Behaviour is preserved as-is here.
    documents_sql, _ = (
        QueryBuilder(self._get_table_name("documents"))
        .update({"id": None})
        .where("id = $1")
        .build()
    )
    await self.connection_manager.execute_query(documents_sql, [id])

    # Remove the user; RETURNING tells us whether a row was deleted.
    delete_sql, _ = (
        QueryBuilder(users_table)
        .delete()
        .where("id = $1")
        .returning(["id"])
        .build()
    )
    deleted = await self.connection_manager.fetchrow_query(delete_sql, [id])
    if not deleted:
        raise R2RException(status_code=404, message="User not found")
async def update_user_password(self, id: UUID, new_hashed_password: str):
    """Replace a user's stored password hash and bump ``updated_at``.

    Args:
        id: UUID of the user.
        new_hashed_password: Already-hashed password; stored verbatim.
    """
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        UPDATE {users_table}
        SET hashed_password = $1, updated_at = NOW()
        WHERE id = $2
    """
    bind_values = [new_hashed_password, id]
    await self.connection_manager.execute_query(sql, bind_values)
async def get_all_users(self) -> list[User]:
    """Return every user row as a ``User``.

    No filtering or pagination is applied. Each ``User`` includes
    ``hashed_password`` and the OAuth ids.
    """
    columns = [
        "id",
        "email",
        "is_superuser",
        "is_active",
        "is_verified",
        "created_at",
        "updated_at",
        "collection_ids",
        "hashed_password",
        "limits_overrides",
        "metadata",
        "name",
        "bio",
        "profile_picture",
        "account_type",
        "google_id",
        "github_id",
    ]
    query, params = (
        QueryBuilder(self._get_table_name(self.TABLE_NAME))
        .select(columns)
        .build()
    )
    rows = await self.connection_manager.fetch_query(query, params)

    users: list[User] = []
    for row in rows:
        users.append(
            User(
                id=row["id"],
                email=row["email"],
                is_superuser=row["is_superuser"],
                is_active=row["is_active"],
                is_verified=row["is_verified"],
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                # NULL array / JSONB columns become empty containers.
                collection_ids=row["collection_ids"] or [],
                limits_overrides=json.loads(row["limits_overrides"] or "{}"),
                metadata=json.loads(row["metadata"] or "{}"),
                name=row["name"],
                bio=row["bio"],
                profile_picture=row["profile_picture"],
                account_type=row["account_type"],
                hashed_password=row["hashed_password"],
                google_id=row["google_id"],
                github_id=row["github_id"],
            )
        )
    return users
async def store_verification_code(
    self, id: UUID, verification_code: str, expiry: datetime
):
    """Save a verification code and its expiry on a user row.

    Args:
        id: UUID of the user.
        verification_code: Code to store.
        expiry: Stored in ``verification_code_expiry``.
    """
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        UPDATE {users_table}
        SET verification_code = $1, verification_code_expiry = $2
        WHERE id = $3
    """
    bind_values = [verification_code, expiry, id]
    await self.connection_manager.execute_query(sql, bind_values)
async def verify_user(self, verification_code: str) -> None:
    """Mark the user holding an unexpired ``verification_code`` as verified.

    On success the code and its expiry are cleared.

    Raises:
        R2RException: 400 when no user holds this code or it has expired.
    """
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        UPDATE {users_table}
        SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
        WHERE verification_code = $1 AND verification_code_expiry > NOW()
        RETURNING id
    """
    updated = await self.connection_manager.fetchrow_query(
        sql, [verification_code]
    )
    # RETURNING yields a row only when the code matched and was still valid.
    if updated:
        return
    raise R2RException(
        status_code=400, message="Invalid or expired verification code"
    )
async def remove_verification_code(self, verification_code: str):
    """Clear ``verification_code`` and its expiry on rows holding that code."""
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        UPDATE {users_table}
        SET verification_code = NULL, verification_code_expiry = NULL
        WHERE verification_code = $1
    """
    await self.connection_manager.execute_query(sql, [verification_code])
async def expire_verification_code(self, id: UUID):
    """Force a user's verification code to be expired.

    Sets ``verification_code_expiry`` to one day in the past; the code
    itself is left in place.
    """
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        UPDATE {users_table}
        SET verification_code_expiry = NOW() - INTERVAL '1 day'
        WHERE id = $1
    """
    await self.connection_manager.execute_query(sql, [id])
async def store_reset_token(
    self, id: UUID, reset_token: str, expiry: datetime
):
    """Save a password-reset token and its expiry on a user row.

    Args:
        id: UUID of the user.
        reset_token: Token to store.
        expiry: Stored in ``reset_token_expiry``.
    """
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        UPDATE {users_table}
        SET reset_token = $1, reset_token_expiry = $2
        WHERE id = $3
    """
    bind_values = [reset_token, expiry, id]
    await self.connection_manager.execute_query(sql, bind_values)
async def get_user_id_by_reset_token(
    self, reset_token: str
) -> Optional[UUID]:
    """Look up the user who holds an unexpired reset token.

    Returns:
        The user's id, or ``None`` when the token is unknown or expired.
    """
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        SELECT id FROM {users_table}
        WHERE reset_token = $1 AND reset_token_expiry > NOW()
    """
    row = await self.connection_manager.fetchrow_query(sql, [reset_token])
    if not row:
        return None
    return row["id"]
async def remove_reset_token(self, id: UUID):
    """Clear a user's reset token and its expiry."""
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        UPDATE {users_table}
        SET reset_token = NULL, reset_token_expiry = NULL
        WHERE id = $1
    """
    await self.connection_manager.execute_query(sql, [id])
async def remove_user_from_all_collections(self, id: UUID):
    """Reset a user's ``collection_ids`` to an empty array.

    Only the users table is touched by this method.
    """
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        UPDATE {users_table}
        SET collection_ids = ARRAY[]::UUID[]
        WHERE id = $1
    """
    await self.connection_manager.execute_query(sql, [id])
async def add_user_to_collection(
    self, id: UUID, collection_id: UUID
) -> bool:
    """Add a collection to a user's ``collection_ids`` and bump its user count.

    Args:
        id: UUID of the user.
        collection_id: UUID of the collection to join.

    Returns:
        True when the user was added.

    Raises:
        R2RException: 404 if the user or the collection does not exist;
            400 if the user is already a member.
    """
    # Check if the user exists (get_user_by_id itself raises a 404 on a miss).
    if not await self.get_user_by_id(id):
        raise R2RException(status_code=404, message="User not found")

    # Check if the collection exists.
    if not await self._collection_exists(collection_id):
        raise R2RException(status_code=404, message="Collection not found")

    # collection_ids is a nullable column. `NOT ($1 = ANY(NULL))` evaluates
    # to NULL, which would filter the row out and be misreported as
    # "already in collection", so a NULL array is accepted explicitly.
    # array_append on a NULL array yields a one-element array.
    query = f"""
        UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
        SET collection_ids = array_append(collection_ids, $1)
        WHERE id = $2
          AND (collection_ids IS NULL OR NOT ($1 = ANY(collection_ids)))
        RETURNING id
    """
    result = await self.connection_manager.fetchrow_query(
        query, [collection_id, id]
    )
    if not result:
        # The user exists (checked above), so no row means already a member.
        raise R2RException(
            status_code=400, message="User already in collection"
        )

    update_collection_query = f"""
        UPDATE {self._get_table_name("collections")}
        SET user_count = user_count + 1
        WHERE id = $1
    """
    await self.connection_manager.execute_query(
        query=update_collection_query,
        params=[collection_id],
    )

    return True
async def remove_user_from_collection(
    self, id: UUID, collection_id: UUID
) -> bool:
    """Remove a collection from a user's ``collection_ids``.

    Args:
        id: UUID of the user.
        collection_id: UUID of the collection to leave.

    Returns:
        True when the membership was removed.

    Raises:
        R2RException: 404 if the user does not exist; 400 if the user is
            not a member of the collection.
    """
    # NOTE(review): unlike add_user_to_collection, this method does not
    # adjust the collection's user_count - confirm whether a caller does.
    user = await self.get_user_by_id(id)
    if not user:
        raise R2RException(status_code=404, message="User not found")

    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        UPDATE {users_table}
        SET collection_ids = array_remove(collection_ids, $1)
        WHERE id = $2 AND $1 = ANY(collection_ids)
        RETURNING id
    """
    updated = await self.connection_manager.fetchrow_query(
        sql, [collection_id, id]
    )
    if updated:
        return True

    raise R2RException(
        status_code=400,
        message="User is not a member of the specified collection",
    )
async def get_users_in_collection(
    self, collection_id: UUID, offset: int, limit: int
) -> dict[str, list[User] | int]:
    """Return one page of the users that belong to a collection.

    Args:
        collection_id: UUID of the collection.
        offset: Number of rows to skip.
        limit: Maximum rows to return; ``-1`` disables the limit.

    Returns:
        ``{"results": [User, ...], "total_entries": int}``, where
        ``total_entries`` comes from the window count and is 0 when the
        page is empty.

    Raises:
        R2RException: 404 if the collection does not exist.
    """
    collection_found = await self._collection_exists(collection_id)
    if not collection_found:
        raise R2RException(status_code=404, message="Collection not found")

    unlimited = limit == -1
    columns = [
        "id",
        "email",
        "is_active",
        "is_superuser",
        "created_at",
        "updated_at",
        "is_verified",
        "collection_ids",
        "name",
        "bio",
        "profile_picture",
        "limits_overrides",
        "metadata",
        "account_type",
        "hashed_password",
        "google_id",
        "github_id",
        "COUNT(*) OVER() AS total_entries",
    ]
    # The builder's params are ignored; bind values are assembled below to
    # match the explicit $1/$2/$3 placeholders.
    query, _ = (
        QueryBuilder(self._get_table_name(self.TABLE_NAME))
        .select(columns)
        .where("$1 = ANY(collection_ids)")
        .order_by("name")
        .offset("$2")
        .limit(None if unlimited else "$3")
        .build()
    )

    bind_values = [collection_id, offset]
    if not unlimited:
        bind_values.append(limit)

    rows = await self.connection_manager.fetch_query(query, bind_values)

    members = []
    for row in rows:
        members.append(
            User(
                id=row["id"],
                email=row["email"],
                is_active=row["is_active"],
                is_superuser=row["is_superuser"],
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                is_verified=row["is_verified"],
                collection_ids=row["collection_ids"] or [],
                name=row["name"],
                bio=row["bio"],
                profile_picture=row["profile_picture"],
                limits_overrides=json.loads(row["limits_overrides"] or "{}"),
                metadata=json.loads(row["metadata"] or "{}"),
                account_type=row["account_type"],
                hashed_password=row["hashed_password"],
                google_id=row["google_id"],
                github_id=row["github_id"],
            )
        )

    # Every row carries the same window count, so the first one suffices.
    total_entries = rows[0]["total_entries"] if rows else 0
    return {"results": members, "total_entries": total_entries}
async def mark_user_as_superuser(self, id: UUID):
    """Promote a user to superuser.

    Also marks the user verified and clears any pending verification code
    and expiry.
    """
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        UPDATE {users_table}
        SET is_superuser = TRUE, is_verified = TRUE,
            verification_code = NULL, verification_code_expiry = NULL
        WHERE id = $1
    """
    await self.connection_manager.execute_query(sql, [id])
async def get_user_id_by_verification_code(
    self, verification_code: str
) -> UUID:
    """Return the id of the user holding an unexpired verification code.

    Raises:
        R2RException: 400 when the code is unknown or expired.
    """
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        SELECT id FROM {users_table}
        WHERE verification_code = $1 AND verification_code_expiry > NOW()
    """
    row = await self.connection_manager.fetchrow_query(
        sql, [verification_code]
    )
    if row:
        return row["id"]

    raise R2RException(
        status_code=400, message="Invalid or expired verification code"
    )
async def mark_user_as_verified(self, id: UUID):
    """Mark a user verified and clear the verification code and expiry."""
    users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
    sql = f"""
        UPDATE {users_table}
        SET is_verified = TRUE,
            verification_code = NULL,
            verification_code_expiry = NULL
        WHERE id = $1
    """
    await self.connection_manager.execute_query(sql, [id])
async def get_users_overview(
    self,
    offset: int,
    limit: int,
    user_ids: Optional[list[UUID]] = None,
) -> dict[str, list[User] | int]:
    """Return users with document usage and total entries.

    Args:
        offset: Number of rows to skip (ordered by email).
        limit: Maximum rows to return; ``-1`` disables the limit.
        user_ids: When given and non-empty, restrict the overview to these
            users.

    Returns:
        ``{"results": [User, ...], "total_entries": int}``. Each ``User``
        carries ``num_files``, ``total_size_in_bytes`` and ``document_ids``.

    Raises:
        R2RException: 404 when the query returns no rows.
    """
    # Placeholders are numbered from the parameter list as it is built, so
    # every combination of optional filter and optional limit stays
    # consistent. Hard-coding $3 for user_ids left $2 unbound whenever
    # limit == -1.
    params: list = [offset]

    user_filter = ""
    if user_ids:
        params.append(user_ids)
        user_filter = f" WHERE u.id = ANY(${len(params)}::uuid[])"

    limit_clause = ""
    if limit != -1:
        params.append(limit)
        limit_clause = f" LIMIT ${len(params)}"

    query = f"""
        WITH user_document_ids AS (
            SELECT
                u.id as user_id,
                ARRAY_AGG(d.id) FILTER (WHERE d.id IS NOT NULL) AS doc_ids
            FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
            LEFT JOIN {self._get_table_name("documents")} d ON u.id = d.owner_id
            GROUP BY u.id
        ),
        user_docs AS (
            SELECT
                u.id,
                u.email,
                u.is_superuser,
                u.is_active,
                u.is_verified,
                u.name,
                u.bio,
                u.profile_picture,
                u.collection_ids,
                u.created_at,
                u.updated_at,
                COUNT(d.id) AS num_files,
                COALESCE(SUM(d.size_in_bytes), 0) AS total_size_in_bytes,
                ud.doc_ids as document_ids
            FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
            LEFT JOIN {self._get_table_name("documents")} d ON u.id = d.owner_id
            LEFT JOIN user_document_ids ud ON u.id = ud.user_id
            {user_filter}
            GROUP BY u.id, u.email, u.is_superuser, u.is_active, u.is_verified,
                     u.created_at, u.updated_at, u.collection_ids, ud.doc_ids
        )
        SELECT
            user_docs.*,
            COUNT(*) OVER() AS total_entries
        FROM user_docs
        ORDER BY email
        OFFSET $1
    """
    query += limit_clause

    results = await self.connection_manager.fetch_query(query, params)
    if not results:
        raise R2RException(status_code=404, message="No users found")

    users_list = [
        User(
            id=row["id"],
            email=row["email"],
            is_superuser=row["is_superuser"],
            is_active=row["is_active"],
            is_verified=row["is_verified"],
            name=row["name"],
            bio=row["bio"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
            profile_picture=row["profile_picture"],
            collection_ids=row["collection_ids"] or [],
            num_files=row["num_files"],
            total_size_in_bytes=row["total_size_in_bytes"],
            # ARRAY_AGG yields NULL for users without documents.
            document_ids=(
                list(row["document_ids"]) if row["document_ids"] else []
            ),
        )
        for row in results
    ]

    # Every row carries the same window count; results is non-empty here.
    total_entries = results[0]["total_entries"]
    return {"results": users_list, "total_entries": total_entries}
async def _collection_exists(self, collection_id: UUID) -> bool:
    """Return True when a collection row with this id exists."""
    collections_table = self._get_table_name(
        PostgresCollectionsHandler.TABLE_NAME
    )
    sql = f"""
        SELECT 1 FROM {collections_table}
        WHERE id = $1
    """
    row = await self.connection_manager.fetchrow_query(sql, [collection_id])
    return row is not None
async def get_user_validation_data(
    self,
    user_id: UUID,
) -> dict:
    """Return a user's verification code and reset token with their expiries.

    This method performs no authorization itself; it should be called only
    after superuser authorization has been verified.

    Returns:
        ``{"verification_data": {...}}`` with the two secrets and their
        expiries as ISO-8601 strings (``None`` where unset).

    Raises:
        R2RException: 404 if the user does not exist.
    """

    def _iso_or_none(timestamp):
        # Expiry columns are nullable; only format real timestamps.
        return timestamp.isoformat() if timestamp else None

    users_table = self._get_table_name("users")
    sql = f"""
        SELECT
            verification_code,
            verification_code_expiry,
            reset_token,
            reset_token_expiry
        FROM {users_table}
        WHERE id = $1
    """
    row = await self.connection_manager.fetchrow_query(sql, [user_id])
    if not row:
        raise R2RException(status_code=404, message="User not found")

    verification_data = {
        "verification_code": row["verification_code"],
        "verification_code_expiry": _iso_or_none(
            row["verification_code_expiry"]
        ),
        "reset_token": row["reset_token"],
        "reset_token_expiry": _iso_or_none(row["reset_token_expiry"]),
    }
    return {"verification_data": verification_data}
- # API Key methods
async def store_user_api_key(
    self,
    user_id: UUID,
    key_id: str,
    hashed_key: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> UUID:
    """Insert a new API key row for a user.

    Args:
        user_id: Owner of the key.
        key_id: Stored in the ``public_key`` column.
        hashed_key: Stored verbatim in ``hashed_key``.
        name: Optional label; a missing or empty value is stored as ``""``.
        description: Optional description; a missing or empty value is
            stored as ``""``.

    Returns:
        The generated row id of the new key.

    Raises:
        R2RException: 500 if the insert returns no row.
    """
    api_keys_table = self._get_table_name(
        PostgresUserHandler.API_KEYS_TABLE_NAME
    )
    sql = f"""
        INSERT INTO {api_keys_table}
        (user_id, public_key, hashed_key, name, description)
        VALUES ($1, $2, $3, $4, $5)
        RETURNING id
    """
    bind_values = [
        user_id,
        key_id,
        hashed_key,
        name or "",
        description or "",
    ]
    inserted = await self.connection_manager.fetchrow_query(sql, bind_values)
    if inserted:
        return inserted["id"]

    raise R2RException(status_code=500, message="Failed to store API key")
async def get_api_key_record(self, key_id: str) -> Optional[dict]:
    """Look up an API key by its ``public_key`` and touch ``updated_at``.

    The lookup is done with an UPDATE ... RETURNING statement, so a
    successful read also sets the row's ``updated_at`` to ``NOW()``.

    Args:
        key_id: Value matched against the ``public_key`` column.

    Returns:
        ``{"user_id": ..., "hashed_key": ...}`` for the matching row, or
        None when no row matches.
    """
    api_keys_table = self._get_table_name(
        PostgresUserHandler.API_KEYS_TABLE_NAME
    )
    query = f"""
        UPDATE {api_keys_table}
        SET updated_at = NOW()
        WHERE public_key = $1
        RETURNING user_id, hashed_key
    """
    row = await self.connection_manager.fetchrow_query(query, [key_id])
    if row:
        return {key: row[key] for key in ("user_id", "hashed_key")}
    return None
async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
    """List the API keys belonging to a user, newest first.

    Args:
        user_id: Value matched against the ``user_id`` column.

    Returns:
        One dict per key with ``key_id`` (the row id as a string),
        ``public_key``, ``name``, ``description`` and ``updated_at``.
        Falsy names and descriptions are returned as empty strings.
        ``created_at`` is selected and used for ordering but is not part
        of the returned dicts.
    """
    api_keys_table = self._get_table_name(
        PostgresUserHandler.API_KEYS_TABLE_NAME
    )
    query = f"""
        SELECT id, public_key, name, description, created_at, updated_at
        FROM {api_keys_table}
        WHERE user_id = $1
        ORDER BY created_at DESC
    """
    rows = await self.connection_manager.fetch_query(query, [user_id])

    api_keys = []
    for row in rows:
        api_keys.append(
            {
                "key_id": str(row["id"]),
                "public_key": row["public_key"],
                "name": row["name"] or "",
                "description": row["description"] or "",
                "updated_at": row["updated_at"],
            }
        )
    return api_keys
async def delete_api_key(self, user_id: UUID, key_id: UUID) -> bool:
    """Delete one API key, matching on both the key id and its owner.

    Args:
        user_id: Value matched against the ``user_id`` column.
        key_id: Value matched against the ``id`` column.

    Returns:
        True when a row was deleted; the not-found case raises instead
        of returning False.

    Raises:
        R2RException: With status 404 when the DELETE returns no row.
    """
    api_keys_table = self._get_table_name(
        PostgresUserHandler.API_KEYS_TABLE_NAME
    )
    query = f"""
        DELETE FROM {api_keys_table}
        WHERE id = $1 AND user_id = $2
        RETURNING id, public_key, name, description
    """
    # Bind order follows the placeholders: $1 is the key id, $2 the owner.
    deleted = await self.connection_manager.fetchrow_query(
        query, [key_id, user_id]
    )
    if deleted is not None:
        return True
    raise R2RException(status_code=404, message="API key not found")
async def update_api_key_name(
    self, user_id: UUID, key_id: UUID, name: str
) -> bool:
    """Rename an API key, matching on both the key id and its owner.

    The same statement also sets the row's ``updated_at`` to ``NOW()``.

    Args:
        user_id: Value matched against the ``user_id`` column.
        key_id: Value matched against the ``id`` column.
        name: New value for the ``name`` column.

    Returns:
        True when a row was updated; the not-found case raises instead
        of returning False.

    Raises:
        R2RException: With status 404 when the UPDATE returns no row.
    """
    api_keys_table = self._get_table_name(
        PostgresUserHandler.API_KEYS_TABLE_NAME
    )
    query = f"""
        UPDATE {api_keys_table}
        SET name = $1, updated_at = NOW()
        WHERE id = $2 AND user_id = $3
        RETURNING id
    """
    updated = await self.connection_manager.fetchrow_query(
        query, [name, key_id, user_id]
    )
    if updated is not None:
        return True
    raise R2RException(status_code=404, message="API key not found")
async def export_to_csv(
    self,
    columns: Optional[list[str]] = None,
    filters: Optional[dict] = None,
    include_header: bool = True,
) -> tuple[str, IO]:
    """Write rows of the users table to a temporary CSV file.

    Args:
        columns: Names of the columns to write, in output order. A falsy
            value selects every exportable column, in the order of the
            SELECT list below.
        filters: Optional mapping of column name to either a plain value
            (compared with ``=``) or a dict of operators (``$eq``,
            ``$gt``, ``$lt``). Entries for non-exportable columns and
            unrecognized operators are ignored. Multiple conditions are
            combined with AND.
        include_header: Whether to write ``columns`` as the first row.

    Returns:
        A ``(path, file_object)`` tuple. The file is created with
        ``delete=True``, so it is removed when the file object is closed;
        the caller is responsible for closing it.

    Raises:
        ValueError: If ``columns`` names a column that is not exportable.
        HTTPException: With status 500, wrapping any exception raised
            while creating or writing the file.
    """
    # Ordered to mirror the SELECT list below, so a name's position here is
    # also its index in every fetched row, and the default export order is
    # stable between runs (iterating a set would give an arbitrary order).
    exportable_columns = (
        "id",
        "email",
        "is_superuser",
        "is_active",
        "is_verified",
        "name",
        "bio",
        "collection_ids",
        "created_at",
        "updated_at",
    )
    column_position = {
        name: idx for idx, name in enumerate(exportable_columns)
    }

    if not columns:
        columns = list(exportable_columns)
    elif invalid_cols := set(columns).difference(column_position):
        raise ValueError(f"Invalid columns: {invalid_cols}")

    select_stmt = f"""
        SELECT
            id::text,
            email,
            is_superuser,
            is_active,
            is_verified,
            name,
            bio,
            collection_ids::text,
            to_char(created_at, 'YYYY-MM-DD HH24:MI:SS') AS created_at,
            to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS') AS updated_at
        FROM {self._get_table_name(self.TABLE_NAME)}
    """

    # Filter operator -> SQL comparison. Field names are only interpolated
    # after being checked against the exportable columns; values are always
    # passed as bound parameters.
    sql_operators = {"$eq": "=", "$gt": ">", "$lt": "<"}
    params: list = []
    conditions: list[str] = []
    for field, value in (filters or {}).items():
        if field not in column_position:
            continue
        if isinstance(value, dict):
            # Unrecognized operators are skipped rather than rejected.
            comparisons = [
                (sql_operators[op], operand)
                for op, operand in value.items()
                if op in sql_operators
            ]
        else:
            # Direct equality
            comparisons = [("=", value)]
        for sql_op, operand in comparisons:
            params.append(operand)
            conditions.append(f"{field} {sql_op} ${len(params)}")

    if conditions:
        select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
    select_stmt = f"{select_stmt} ORDER BY created_at DESC"

    # Resolve the requested columns to row indexes once, not per row.
    selected_positions = [column_position[col] for col in columns]

    temp_file = None
    try:
        # newline="" lets the csv writer control line endings itself, as
        # the csv module requires for files handed to a writer.
        temp_file = tempfile.NamedTemporaryFile(
            mode="w", delete=True, suffix=".csv", newline=""
        )
        writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

        async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
            async with conn.transaction():
                cursor = await conn.cursor(select_stmt, *params)

                if include_header:
                    writer.writerow(columns)

                # Fetch in fixed-size chunks so the whole result set is
                # never held in memory at once.
                chunk_size = 1000
                while rows := await cursor.fetch(chunk_size):
                    writer.writerows(
                        [row[pos] for pos in selected_positions]
                        for row in rows
                    )

        temp_file.flush()
        return temp_file.name, temp_file

    except Exception as e:
        # Closing the file also deletes it (delete=True), so a failed
        # export leaves nothing behind.
        if temp_file:
            temp_file.close()
        raise HTTPException(
            status_code=500,
            detail=f"Failed to export data: {str(e)}",
        ) from e
async def get_user_by_google_id(self, google_id: str) -> Optional[User]:
    """Look up the user whose ``google_id`` column equals the given value.

    Args:
        google_id: Value bound to the ``google_id = $1`` predicate.

    Returns:
        A ``User`` built from the matching row, or None when the query
        returns no row.
    """
    selected_fields = [
        "id",
        "email",
        "is_superuser",
        "is_active",
        "is_verified",
        "created_at",
        "updated_at",
        "name",
        "profile_picture",
        "bio",
        "collection_ids",
        "limits_overrides",
        "metadata",
        "account_type",
        "hashed_password",
        "google_id",
        "github_id",
    ]
    builder = QueryBuilder(self._get_table_name("users"))
    # build() also returns a params value; it is not used because the only
    # bound value is passed explicitly to the query call below.
    query, _unused_params = (
        builder.select(selected_fields).where("google_id = $1").build()
    )

    row = await self.connection_manager.fetchrow_query(query, [google_id])
    if not row:
        return None

    # These two columns are decoded with json.loads; a falsy stored value
    # is treated as an empty JSON object.
    limits_overrides = json.loads(row["limits_overrides"] or "{}")
    metadata = json.loads(row["metadata"] or "{}")

    return User(
        id=row["id"],
        email=row["email"],
        is_superuser=row["is_superuser"],
        is_active=row["is_active"],
        is_verified=row["is_verified"],
        created_at=row["created_at"],
        updated_at=row["updated_at"],
        name=row["name"],
        profile_picture=row["profile_picture"],
        bio=row["bio"],
        collection_ids=row["collection_ids"] or [],
        limits_overrides=limits_overrides,
        metadata=metadata,
        account_type=row["account_type"],
        hashed_password=row["hashed_password"],
        google_id=row["google_id"],
        github_id=row["github_id"],
    )
async def get_user_by_github_id(self, github_id: str) -> Optional[User]:
    """Look up the user whose ``github_id`` column equals the given value.

    Args:
        github_id: Value bound to the ``github_id = $1`` predicate.

    Returns:
        A ``User`` built from the matching row, or None when the query
        returns no row.
    """
    selected_fields = [
        "id",
        "email",
        "is_superuser",
        "is_active",
        "is_verified",
        "created_at",
        "updated_at",
        "name",
        "profile_picture",
        "bio",
        "collection_ids",
        "limits_overrides",
        "metadata",
        "account_type",
        "hashed_password",
        "google_id",
        "github_id",
    ]
    builder = QueryBuilder(self._get_table_name("users"))
    # build() also returns a params value; it is not used because the only
    # bound value is passed explicitly to the query call below.
    query, _unused_params = (
        builder.select(selected_fields).where("github_id = $1").build()
    )

    row = await self.connection_manager.fetchrow_query(query, [github_id])
    if not row:
        return None

    # These two columns are decoded with json.loads; a falsy stored value
    # is treated as an empty JSON object.
    limits_overrides = json.loads(row["limits_overrides"] or "{}")
    metadata = json.loads(row["metadata"] or "{}")

    return User(
        id=row["id"],
        email=row["email"],
        is_superuser=row["is_superuser"],
        is_active=row["is_active"],
        is_verified=row["is_verified"],
        created_at=row["created_at"],
        updated_at=row["updated_at"],
        name=row["name"],
        profile_picture=row["profile_picture"],
        bio=row["bio"],
        collection_ids=row["collection_ids"] or [],
        limits_overrides=limits_overrides,
        metadata=metadata,
        account_type=row["account_type"],
        hashed_password=row["hashed_password"],
        google_id=row["google_id"],
        github_id=row["github_id"],
    )
|