12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009 |
- import csv
- import json
- import tempfile
- from datetime import datetime
- from typing import IO, Optional
- from uuid import UUID
- from fastapi import HTTPException
- from core.base import CryptoProvider, Handler
- from core.base.abstractions import R2RException
- from core.utils import generate_user_id
- from shared.abstractions import User
- from .base import PostgresConnectionManager, QueryBuilder
- from .collections import PostgresCollectionsHandler
class PostgresUserHandler(Handler):
    """Postgres-backed storage for users and their API keys.

    Every table name is resolved through ``self._get_table_name`` (inherited
    from ``Handler``); password hashing is delegated to ``crypto_provider``.
    """

    TABLE_NAME = "users"
    API_KEYS_TABLE_NAME = "users_api_keys"

    # Columns selected / returned whenever a complete ``User`` is built from a
    # row. Keep in sync with ``_user_from_row``.
    _USER_COLUMNS = (
        "id",
        "email",
        "hashed_password",
        "is_superuser",
        "is_active",
        "is_verified",
        "created_at",
        "updated_at",
        "name",
        "profile_picture",
        "bio",
        "collection_ids",
        "limits_overrides",
    )

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        crypto_provider: CryptoProvider,
    ):
        super().__init__(project_name, connection_manager)
        self.crypto_provider = crypto_provider

    @staticmethod
    def _user_from_row(row) -> User:
        """Build a ``User`` from a row that contains every ``_USER_COLUMNS`` key.

        A NULL ``collection_ids`` becomes ``[]``. ``limits_overrides`` arrives
        as JSON text and is decoded; NULL becomes ``{}``.
        """
        return User(
            id=row["id"],
            email=row["email"],
            hashed_password=row["hashed_password"],
            is_superuser=row["is_superuser"],
            is_active=row["is_active"],
            is_verified=row["is_verified"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
            name=row["name"],
            profile_picture=row["profile_picture"],
            bio=row["bio"],
            collection_ids=row["collection_ids"] or [],
            limits_overrides=json.loads(row["limits_overrides"] or "{}"),
        )

    async def create_tables(self):
        """Create the users and API-key tables, plus API-key indexes, if absent."""
        users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
        api_keys_table = self._get_table_name(
            PostgresUserHandler.API_KEYS_TABLE_NAME
        )

        user_table_query = f"""
        CREATE TABLE IF NOT EXISTS {users_table} (
            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
            email TEXT UNIQUE NOT NULL,
            hashed_password TEXT NOT NULL,
            is_superuser BOOLEAN DEFAULT FALSE,
            is_active BOOLEAN DEFAULT TRUE,
            is_verified BOOLEAN DEFAULT FALSE,
            verification_code TEXT,
            verification_code_expiry TIMESTAMPTZ,
            name TEXT,
            bio TEXT,
            profile_picture TEXT,
            reset_token TEXT,
            reset_token_expiry TIMESTAMPTZ,
            collection_ids UUID[] NULL,
            limits_overrides JSONB,
            created_at TIMESTAMPTZ DEFAULT NOW(),
            updated_at TIMESTAMPTZ DEFAULT NOW()
        );
        """

        # API keys carry updated_at (bumped by get_api_key_record on every
        # lookup) instead of a separate last_used_at column.
        api_keys_table_query = f"""
        CREATE TABLE IF NOT EXISTS {api_keys_table} (
            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
            user_id UUID NOT NULL REFERENCES {users_table}(id) ON DELETE CASCADE,
            public_key TEXT UNIQUE NOT NULL,
            hashed_key TEXT NOT NULL,
            name TEXT,
            created_at TIMESTAMPTZ DEFAULT NOW(),
            updated_at TIMESTAMPTZ DEFAULT NOW()
        );

        CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
        ON {api_keys_table}(user_id);

        CREATE INDEX IF NOT EXISTS idx_api_keys_public_key
        ON {api_keys_table}(public_key);
        """

        await self.connection_manager.execute_query(user_table_query)
        await self.connection_manager.execute_query(api_keys_table_query)

    async def get_user_by_id(self, id: UUID) -> User:
        """Fetch one user by primary key.

        Raises:
            R2RException: 404 if no user has this ``id``.
        """
        query, _ = (
            QueryBuilder(self._get_table_name(self.TABLE_NAME))
            .select(list(self._USER_COLUMNS))
            .where("id = $1")
            .build()
        )
        row = await self.connection_manager.fetchrow_query(query, [id])
        if not row:
            raise R2RException(status_code=404, message="User not found")
        return self._user_from_row(row)

    async def get_user_by_email(self, email: str) -> User:
        """Fetch one user by email address.

        Raises:
            R2RException: 404 if no user has this ``email``.
        """
        query, _ = (
            QueryBuilder(self._get_table_name(self.TABLE_NAME))
            .select(list(self._USER_COLUMNS))
            .where("email = $1")
            .build()
        )
        row = await self.connection_manager.fetchrow_query(query, [email])
        if not row:
            raise R2RException(status_code=404, message="User not found")
        return self._user_from_row(row)

    async def create_user(
        self, email: str, password: str, is_superuser: bool = False
    ) -> User:
        """Create a new user with a hashed password.

        The user id is derived from the email via ``generate_user_id``.

        Raises:
            R2RException: 400 if the email is already registered, 500 if the
                insert returns no row. Any non-404 error raised by the
                existence lookup is propagated unchanged.
        """
        try:
            await self.get_user_by_email(email)
        except R2RException as e:
            # 404 is the expected outcome: the email is still free.
            if e.status_code != 404:
                raise
        else:
            raise R2RException(
                status_code=400,
                message="User with this email already exists",
            )

        hashed_password = self.crypto_provider.get_password_hash(password)  # type: ignore
        query, params = (
            QueryBuilder(self._get_table_name(self.TABLE_NAME))
            .insert(
                {
                    "email": email,
                    "id": generate_user_id(email),
                    "is_superuser": is_superuser,
                    "hashed_password": hashed_password,
                    "collection_ids": [],
                    "limits_overrides": None,
                }
            )
            .returning(
                [
                    "id",
                    "email",
                    "is_superuser",
                    "is_active",
                    "is_verified",
                    "created_at",
                    "updated_at",
                    "collection_ids",
                    "limits_overrides",
                ]
            )
            .build()
        )

        result = await self.connection_manager.fetchrow_query(query, params)
        if not result:
            raise R2RException(
                status_code=500,
                message="Failed to create user",
            )

        return User(
            id=result["id"],
            email=result["email"],
            is_superuser=result["is_superuser"],
            is_active=result["is_active"],
            is_verified=result["is_verified"],
            created_at=result["created_at"],
            updated_at=result["updated_at"],
            collection_ids=result["collection_ids"] or [],
            hashed_password=hashed_password,
            limits_overrides=json.loads(result["limits_overrides"] or "{}"),
            name=None,
            bio=None,
            profile_picture=None,
        )

    async def update_user(
        self, user: User, merge_limits: bool = False
    ) -> User:
        """
        Update user information including limits_overrides.

        The password is not touched; use ``update_user_password`` for that.

        Args:
            user: User object containing updated information.
            merge_limits: If True, the stored limits_overrides are merged with
                ``user.limits_overrides`` (incoming keys win; empty/None
                incoming overrides leave the stored ones intact). If False,
                the stored limits_overrides are overwritten.

        Returns:
            Updated User object (including hashed_password).

        Raises:
            R2RException: 404 if the user does not exist.
            HTTPException: 500 if the UPDATE returns no row.
        """
        try:
            current_user = await self.get_user_by_id(user.id)
        except R2RException as e:
            raise R2RException(
                status_code=404, message="User not found"
            ) from e

        final_limits = user.limits_overrides
        if merge_limits:
            final_limits = {
                **(current_user.limits_overrides or {}),
                **(user.limits_overrides or {}),
            }

        query = f"""
        UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
        SET email = $1,
            is_superuser = $2,
            is_active = $3,
            is_verified = $4,
            updated_at = NOW(),
            name = $5,
            profile_picture = $6,
            bio = $7,
            collection_ids = $8,
            limits_overrides = $9::jsonb
        WHERE id = $10
        RETURNING id, email, is_superuser, is_active, is_verified,
            created_at, updated_at, name, profile_picture, bio,
            collection_ids, limits_overrides, hashed_password
        """
        result = await self.connection_manager.fetchrow_query(
            query,
            [
                user.email,
                user.is_superuser,
                user.is_active,
                user.is_verified,
                user.name,
                user.profile_picture,
                user.bio,
                user.collection_ids or [],  # never store NULL here
                # SQL NULL (not the JSON literal 'null') when there are no
                # overrides, matching create_user, which inserts None.
                (
                    json.dumps(final_limits)
                    if final_limits is not None
                    else None
                ),
                user.id,
            ],
        )

        if not result:
            raise HTTPException(
                status_code=500,
                detail="Failed to update user",
            )

        return self._user_from_row(result)

    async def delete_user_relational(self, id: UUID) -> None:
        """Delete a user after detaching the documents they own.

        Documents whose ``owner_id`` is this user get ``owner_id = NULL``.

        Raises:
            R2RException: 404 if no user has this ``id``.
        """
        # NOTE(review): collections.user_count is not decremented for the
        # collections this user belonged to; confirm whether callers handle it.
        lookup_query, _ = (
            QueryBuilder(self._get_table_name(self.TABLE_NAME))
            .select(["collection_ids"])
            .where("id = $1")
            .build()
        )
        if not await self.connection_manager.fetchrow_query(
            lookup_query, [id]
        ):
            raise R2RException(status_code=404, message="User not found")

        # Detach owned documents. The ownership column is owner_id (the same
        # column get_users_overview joins on), not the document's own id.
        doc_update_query = f"""
            UPDATE {self._get_table_name("documents")}
            SET owner_id = NULL
            WHERE owner_id = $1
        """
        await self.connection_manager.execute_query(doc_update_query, [id])

        delete_query, _ = (
            QueryBuilder(self._get_table_name(self.TABLE_NAME))
            .delete()
            .where("id = $1")
            .returning(["id"])
            .build()
        )
        result = await self.connection_manager.fetchrow_query(
            delete_query, [id]
        )
        if not result:
            raise R2RException(status_code=404, message="User not found")

    async def update_user_password(self, id: UUID, new_hashed_password: str):
        """Store an already-hashed password for the user and bump updated_at."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET hashed_password = $1, updated_at = NOW()
            WHERE id = $2
        """
        await self.connection_manager.execute_query(
            query, [new_hashed_password, id]
        )

    async def get_all_users(self) -> list[User]:
        """Return every user (unpaginated)."""
        query, params = (
            QueryBuilder(self._get_table_name(self.TABLE_NAME))
            .select(list(self._USER_COLUMNS))
            .build()
        )
        rows = await self.connection_manager.fetch_query(query, params)
        return [self._user_from_row(row) for row in rows]

    async def store_verification_code(
        self, id: UUID, verification_code: str, expiry: datetime
    ):
        """Save a verification code and its expiry time for the user."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET verification_code = $1, verification_code_expiry = $2
            WHERE id = $3
        """
        await self.connection_manager.execute_query(
            query, [verification_code, expiry, id]
        )

    async def verify_user(self, verification_code: str) -> None:
        """Mark the user holding an unexpired ``verification_code`` as verified.

        The code and its expiry are cleared in the same statement.

        Raises:
            R2RException: 400 if no user has this code or it has expired.
        """
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
            WHERE verification_code = $1 AND verification_code_expiry > NOW()
            RETURNING id
        """
        result = await self.connection_manager.fetchrow_query(
            query, [verification_code]
        )
        if not result:
            raise R2RException(
                status_code=400, message="Invalid or expired verification code"
            )

    async def remove_verification_code(self, verification_code: str):
        """Clear the given verification code (and its expiry) wherever it is set."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET verification_code = NULL, verification_code_expiry = NULL
            WHERE verification_code = $1
        """
        await self.connection_manager.execute_query(query, [verification_code])

    async def expire_verification_code(self, id: UUID):
        """Force the user's verification code to be expired (expiry set 1 day ago)."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET verification_code_expiry = NOW() - INTERVAL '1 day'
            WHERE id = $1
        """
        await self.connection_manager.execute_query(query, [id])

    async def store_reset_token(
        self, id: UUID, reset_token: str, expiry: datetime
    ):
        """Save a password-reset token and its expiry time for the user."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET reset_token = $1, reset_token_expiry = $2
            WHERE id = $3
        """
        await self.connection_manager.execute_query(
            query, [reset_token, expiry, id]
        )

    async def get_user_id_by_reset_token(
        self, reset_token: str
    ) -> Optional[UUID]:
        """Return the id of the user holding an unexpired ``reset_token``, else None."""
        query = f"""
            SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            WHERE reset_token = $1 AND reset_token_expiry > NOW()
        """
        result = await self.connection_manager.fetchrow_query(
            query, [reset_token]
        )
        return result["id"] if result else None

    async def remove_reset_token(self, id: UUID):
        """Clear the user's reset token and its expiry."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET reset_token = NULL, reset_token_expiry = NULL
            WHERE id = $1
        """
        await self.connection_manager.execute_query(query, [id])

    async def remove_user_from_all_collections(self, id: UUID):
        """Reset the user's ``collection_ids`` to an empty array.

        Collection ``user_count`` values are not adjusted here.
        """
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET collection_ids = ARRAY[]::UUID[]
            WHERE id = $1
        """
        await self.connection_manager.execute_query(query, [id])

    async def add_user_to_collection(
        self, id: UUID, collection_id: UUID
    ) -> bool:
        """Add ``collection_id`` to the user's collections and bump ``user_count``.

        Raises:
            R2RException: 404 if the user or collection does not exist; 400 if
                the user is already in the collection.
        """
        if not await self.get_user_by_id(id):
            raise R2RException(status_code=404, message="User not found")

        if not await self._collection_exists(collection_id):
            raise R2RException(status_code=404, message="Collection not found")

        # The NOT ANY guard makes the append a no-op (no row returned) when
        # the user is already a member.
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET collection_ids = array_append(collection_ids, $1)
            WHERE id = $2 AND NOT ($1 = ANY(collection_ids))
            RETURNING id
        """
        result = await self.connection_manager.fetchrow_query(
            query, [collection_id, id]
        )
        if not result:
            raise R2RException(
                status_code=400, message="User already in collection"
            )

        update_collection_query = f"""
            UPDATE {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
            SET user_count = user_count + 1
            WHERE id = $1
        """
        await self.connection_manager.execute_query(
            query=update_collection_query,
            params=[collection_id],
        )

        return True

    async def remove_user_from_collection(
        self, id: UUID, collection_id: UUID
    ) -> bool:
        """Remove ``collection_id`` from the user's collections and decrement ``user_count``.

        Raises:
            R2RException: 404 if the user does not exist; 400 if the user is
                not a member of the collection.
        """
        if not await self.get_user_by_id(id):
            raise R2RException(status_code=404, message="User not found")

        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET collection_ids = array_remove(collection_ids, $1)
            WHERE id = $2 AND $1 = ANY(collection_ids)
            RETURNING id
        """
        result = await self.connection_manager.fetchrow_query(
            query, [collection_id, id]
        )
        if not result:
            raise R2RException(
                status_code=400,
                message="User is not a member of the specified collection",
            )

        # Mirror add_user_to_collection, which increments user_count. Clamp
        # at zero so a counter that has already drifted cannot go negative.
        update_collection_query = f"""
            UPDATE {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
            SET user_count = GREATEST(user_count - 1, 0)
            WHERE id = $1
        """
        await self.connection_manager.execute_query(
            query=update_collection_query,
            params=[collection_id],
        )

        return True

    async def get_users_in_collection(
        self, collection_id: UUID, offset: int, limit: int
    ) -> dict[str, list[User] | int]:
        """Get all users in a specific collection with pagination.

        Args:
            collection_id: Collection whose members are listed.
            offset: Number of rows to skip (ordered by name).
            limit: Maximum rows to return; -1 means no limit.

        Returns:
            ``{"results": [User, ...], "total_entries": int}`` where
            ``total_entries`` counts all members, ignoring pagination.

        Raises:
            R2RException: 404 if the collection does not exist.
        """
        if not await self._collection_exists(collection_id):
            raise R2RException(status_code=404, message="Collection not found")

        query, _ = (
            QueryBuilder(self._get_table_name(self.TABLE_NAME))
            .select(
                [*self._USER_COLUMNS, "COUNT(*) OVER() AS total_entries"]
            )
            .where("$1 = ANY(collection_ids)")
            .order_by("name")
            .offset("$2")
            .limit("$3" if limit != -1 else None)
            .build()
        )

        query_params: list = [collection_id, offset]
        if limit != -1:
            query_params.append(limit)

        rows = await self.connection_manager.fetch_query(query, query_params)
        users_list = [self._user_from_row(row) for row in rows]
        total_entries = rows[0]["total_entries"] if rows else 0
        return {"results": users_list, "total_entries": total_entries}

    async def mark_user_as_superuser(self, id: UUID):
        """Grant superuser, mark verified, and clear any pending verification code."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET is_superuser = TRUE, is_verified = TRUE,
                verification_code = NULL, verification_code_expiry = NULL
            WHERE id = $1
        """
        await self.connection_manager.execute_query(query, [id])

    async def get_user_id_by_verification_code(
        self, verification_code: str
    ) -> UUID:
        """Return the id of the user holding an unexpired ``verification_code``.

        Raises:
            R2RException: 400 if no user has this code or it has expired.
        """
        query = f"""
            SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            WHERE verification_code = $1 AND verification_code_expiry > NOW()
        """
        result = await self.connection_manager.fetchrow_query(
            query, [verification_code]
        )
        if not result:
            raise R2RException(
                status_code=400, message="Invalid or expired verification code"
            )
        return result["id"]

    async def mark_user_as_verified(self, id: UUID):
        """Mark the user verified and clear any pending verification code."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET is_verified = TRUE,
                verification_code = NULL,
                verification_code_expiry = NULL
            WHERE id = $1
        """
        await self.connection_manager.execute_query(query, [id])

    async def get_users_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[User] | int]:
        """
        Return users with document usage and total entries.

        Args:
            offset: Number of rows to skip (ordered by email).
            limit: Maximum rows to return; -1 means no limit.
            user_ids: If given and non-empty, only these users are included.

        Returns:
            ``{"results": [User, ...], "total_entries": int}``. Each User
            carries ``num_files``, ``total_size_in_bytes`` and
            ``document_ids`` computed from documents whose ``owner_id`` is
            that user. ``total_entries`` ignores pagination.

        Raises:
            R2RException: 404 if the requested page contains no users.
        """
        # Number placeholders from the params actually supplied. A fixed $3
        # for user_ids breaks when there is no LIMIT parameter.
        params: list = [offset]
        limit_clause = ""
        if limit != -1:
            params.append(limit)
            limit_clause = f"LIMIT ${len(params)}"
        user_filter = ""
        if user_ids:
            params.append(user_ids)
            user_filter = f"WHERE u.id = ANY(${len(params)}::uuid[])"

        users_table = self._get_table_name(PostgresUserHandler.TABLE_NAME)
        documents_table = self._get_table_name("documents")
        # u.id is the primary key, so Postgres accepts the other u.* columns
        # as functionally dependent on it in the GROUP BY below.
        query = f"""
            WITH user_document_ids AS (
                SELECT
                    u.id as user_id,
                    ARRAY_AGG(d.id) FILTER (WHERE d.id IS NOT NULL) AS doc_ids
                FROM {users_table} u
                LEFT JOIN {documents_table} d ON u.id = d.owner_id
                GROUP BY u.id
            ),
            user_docs AS (
                SELECT
                    u.id,
                    u.email,
                    u.is_superuser,
                    u.is_active,
                    u.is_verified,
                    u.name,
                    u.bio,
                    u.profile_picture,
                    u.collection_ids,
                    u.created_at,
                    u.updated_at,
                    COUNT(d.id) AS num_files,
                    COALESCE(SUM(d.size_in_bytes), 0) AS total_size_in_bytes,
                    ud.doc_ids as document_ids
                FROM {users_table} u
                LEFT JOIN {documents_table} d ON u.id = d.owner_id
                LEFT JOIN user_document_ids ud ON u.id = ud.user_id
                {user_filter}
                GROUP BY u.id, u.email, u.is_superuser, u.is_active, u.is_verified,
                         u.created_at, u.updated_at, u.collection_ids, ud.doc_ids
            )
            SELECT
                user_docs.*,
                COUNT(*) OVER() AS total_entries
            FROM user_docs
            ORDER BY email
            OFFSET $1
            {limit_clause}
        """

        results = await self.connection_manager.fetch_query(query, params)
        if not results:
            raise R2RException(status_code=404, message="No users found")

        users_list = [
            User(
                id=row["id"],
                email=row["email"],
                is_superuser=row["is_superuser"],
                is_active=row["is_active"],
                is_verified=row["is_verified"],
                name=row["name"],
                bio=row["bio"],
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                profile_picture=row["profile_picture"],
                collection_ids=row["collection_ids"] or [],
                num_files=row["num_files"],
                total_size_in_bytes=row["total_size_in_bytes"],
                document_ids=(
                    list(row["document_ids"]) if row["document_ids"] else []
                ),
            )
            for row in results
        ]

        return {
            "results": users_list,
            "total_entries": results[0]["total_entries"],
        }

    async def _collection_exists(self, collection_id: UUID) -> bool:
        """Check if a collection exists."""
        query = f"""
            SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
            WHERE id = $1
        """
        result = await self.connection_manager.fetchrow_query(
            query, [collection_id]
        )
        return result is not None

    async def get_user_validation_data(
        self,
        user_id: UUID,
    ) -> dict:
        """
        Get verification data for a specific user.

        This method should be called after superuser authorization has been
        verified. Expiry timestamps are returned as ISO-8601 strings, or None
        when unset.

        Raises:
            R2RException: 404 if no user has this ``user_id``.
        """
        query = f"""
            SELECT
                verification_code,
                verification_code_expiry,
                reset_token,
                reset_token_expiry
            FROM {self._get_table_name(self.TABLE_NAME)}
            WHERE id = $1
        """
        result = await self.connection_manager.fetchrow_query(query, [user_id])
        if not result:
            raise R2RException(status_code=404, message="User not found")

        def _iso(value):
            return value.isoformat() if value else None

        return {
            "verification_data": {
                "verification_code": result["verification_code"],
                "verification_code_expiry": _iso(
                    result["verification_code_expiry"]
                ),
                "reset_token": result["reset_token"],
                "reset_token_expiry": _iso(result["reset_token_expiry"]),
            }
        }

    # API Key methods

    async def store_user_api_key(
        self,
        user_id: UUID,
        key_id: str,
        hashed_key: str,
        name: Optional[str] = None,
    ) -> UUID:
        """Store a new API key for a user and return the new row's id.

        ``key_id`` is saved in the ``public_key`` column.

        Raises:
            R2RException: 500 if the insert returns no row.
        """
        query = f"""
            INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
            (user_id, public_key, hashed_key, name)
            VALUES ($1, $2, $3, $4)
            RETURNING id
        """
        result = await self.connection_manager.fetchrow_query(
            query, [user_id, key_id, hashed_key, name]
        )
        if not result:
            raise R2RException(
                status_code=500, message="Failed to store API key"
            )
        return result["id"]

    async def get_api_key_record(self, key_id: str) -> Optional[dict]:
        """
        Get the API key record by 'public_key' and update 'updated_at' to now.

        Returns { "user_id", "hashed_key" }, or None if not found.
        """
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
            SET updated_at = NOW()
            WHERE public_key = $1
            RETURNING user_id, hashed_key
        """
        result = await self.connection_manager.fetchrow_query(query, [key_id])
        if not result:
            return None
        return {
            "user_id": result["user_id"],
            "hashed_key": result["hashed_key"],
        }

    async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
        """Get all API keys for a user, newest first (no secrets returned)."""
        query = f"""
            SELECT id, public_key, name, created_at, updated_at
            FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
            WHERE user_id = $1
            ORDER BY created_at DESC
        """
        results = await self.connection_manager.fetch_query(query, [user_id])
        return [
            {
                "key_id": str(row["id"]),
                "public_key": row["public_key"],
                "name": row["name"] or "",
                "updated_at": row["updated_at"],
            }
            for row in results
        ]

    async def delete_api_key(self, user_id: UUID, key_id: UUID) -> dict:
        """Delete one of the user's API keys and return what was deleted.

        Raises:
            R2RException: 404 if the key does not exist or belongs to
                another user.
        """
        query = f"""
            DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
            WHERE id = $1 AND user_id = $2
            RETURNING id, public_key, name
        """
        result = await self.connection_manager.fetchrow_query(
            query, [key_id, user_id]
        )
        if result is None:
            raise R2RException(status_code=404, message="API key not found")

        return {
            "key_id": str(result["id"]),
            "public_key": str(result["public_key"]),
            "name": result["name"] or "",
        }

    async def update_api_key_name(
        self, user_id: UUID, key_id: UUID, name: str
    ) -> bool:
        """Rename one of the user's API keys.

        Raises:
            R2RException: 404 if the key does not exist or belongs to
                another user.
        """
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
            SET name = $1, updated_at = NOW()
            WHERE id = $2 AND user_id = $3
            RETURNING id
        """
        result = await self.connection_manager.fetchrow_query(
            query, [name, key_id, user_id]
        )
        if result is None:
            raise R2RException(status_code=404, message="API key not found")
        return True

    async def export_to_csv(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        """
        Creates a CSV file from the PostgreSQL data and returns the path to the temp file.

        Args:
            columns: Columns to export, in output order. Defaults to every
                exportable column. The header row and the data rows both
                follow this list.
            filters: ``{column: value}`` for equality, or
                ``{column: {"$eq" | "$gt" | "$lt": value}}``. Unknown columns
                and unsupported operators are ignored.
            include_header: Write ``columns`` as the first CSV row.

        Returns:
            ``(path, file)`` for the still-open temp file. It was created with
            ``delete=True``, so it is removed once the caller closes it.

        Raises:
            ValueError: If ``columns`` names a non-exportable column.
            HTTPException: 500 if querying or writing fails.
        """
        # Exportable column -> SQL expression that produces its CSV value.
        column_expressions = {
            "id": "id::text",
            "email": "email",
            "is_superuser": "is_superuser",
            "is_active": "is_active",
            "is_verified": "is_verified",
            "name": "name",
            "bio": "bio",
            "collection_ids": "collection_ids::text",
            "created_at": "to_char(created_at, 'YYYY-MM-DD HH24:MI:SS')",
            "updated_at": "to_char(updated_at, 'YYYY-MM-DD HH24:MI:SS')",
        }
        valid_columns = set(column_expressions)

        if not columns:
            columns = list(column_expressions)
        elif invalid_cols := set(columns) - valid_columns:
            raise ValueError(f"Invalid columns: {invalid_cols}")

        # Select exactly the requested columns, in order, so every data row
        # lines up with the header. Names are whitelisted above.
        select_list = ", ".join(
            f'{column_expressions[col]} AS "{col}"' for col in columns
        )
        select_stmt = (
            f"SELECT {select_list} "
            f"FROM {self._get_table_name(self.TABLE_NAME)} AS u"
        )

        comparison_ops = {"$eq": "=", "$gt": ">", "$lt": "<"}
        conditions: list[str] = []
        params: list = []
        for field, value in (filters or {}).items():
            if field not in valid_columns:
                continue
            # A bare value means direct equality.
            clauses = (
                value.items() if isinstance(value, dict) else [("$eq", value)]
            )
            for op, val in clauses:
                sql_op = comparison_ops.get(op)
                if sql_op is None:
                    continue
                params.append(val)
                conditions.append(f"u.{field} {sql_op} ${len(params)}")

        if conditions:
            select_stmt = f"{select_stmt} WHERE {' AND '.join(conditions)}"
        # Qualify with the table alias so ORDER BY uses the real timestamp,
        # not the to_char() text alias of the same name.
        select_stmt = f"{select_stmt} ORDER BY u.created_at DESC"

        temp_file = None
        try:
            # newline="" is what the csv module expects for files it writes.
            temp_file = tempfile.NamedTemporaryFile(
                mode="w",
                delete=True,
                suffix=".csv",
                newline="",
                encoding="utf-8",
            )
            writer = csv.writer(temp_file, quoting=csv.QUOTE_ALL)

            async with self.connection_manager.pool.get_connection() as conn:  # type: ignore
                # Presumably an asyncpg connection: its cursors must run
                # inside a transaction.
                async with conn.transaction():
                    cursor = await conn.cursor(select_stmt, *params)

                    if include_header:
                        writer.writerow(columns)

                    chunk_size = 1000
                    while rows := await cursor.fetch(chunk_size):
                        writer.writerows(rows)

            temp_file.flush()
            return temp_file.name, temp_file

        except Exception as e:
            if temp_file:
                temp_file.close()
            raise HTTPException(
                status_code=500,
                detail=f"Failed to export data: {str(e)}",
            ) from e
|