from datetime import datetime
from typing import Optional
from uuid import UUID

from fastapi import HTTPException

from core.base import CryptoProvider, Handler
from core.base.abstractions import R2RException
from core.utils import generate_user_id
from shared.abstractions import User

from .base import PostgresConnectionManager, QueryBuilder
from .collections import PostgresCollectionsHandler


class PostgresUserHandler(Handler):
    """Persistence layer for users and their API keys on top of Postgres.

    Every method issues SQL through ``self.connection_manager`` against the
    project-scoped ``users`` and ``users_api_keys`` tables and maps rows to
    ``User`` objects or plain dicts.
    """

    TABLE_NAME = "users"
    API_KEYS_TABLE_NAME = "users_api_keys"

    # Columns loaded when a single full user record is fetched.
    _USER_COLUMNS = (
        "id",
        "email",
        "hashed_password",
        "is_superuser",
        "is_active",
        "is_verified",
        "created_at",
        "updated_at",
        "name",
        "profile_picture",
        "bio",
        "collection_ids",
    )

    def __init__(
        self,
        project_name: str,
        connection_manager: PostgresConnectionManager,
        crypto_provider: CryptoProvider,
    ) -> None:
        """Store the crypto provider used to hash passwords in create_user."""
        super().__init__(project_name, connection_manager)
        self.crypto_provider = crypto_provider

    async def create_tables(self):
        """Create the users and API-key tables (and indexes) if missing."""
        user_table_query = f"""
        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.TABLE_NAME)} (
            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
            email TEXT UNIQUE NOT NULL,
            hashed_password TEXT NOT NULL,
            is_superuser BOOLEAN DEFAULT FALSE,
            is_active BOOLEAN DEFAULT TRUE,
            is_verified BOOLEAN DEFAULT FALSE,
            verification_code TEXT,
            verification_code_expiry TIMESTAMPTZ,
            name TEXT,
            bio TEXT,
            profile_picture TEXT,
            reset_token TEXT,
            reset_token_expiry TIMESTAMPTZ,
            collection_ids UUID[] NULL,
            created_at TIMESTAMPTZ DEFAULT NOW(),
            updated_at TIMESTAMPTZ DEFAULT NOW()
        );
        """

        # API keys table with updated_at instead of last_used_at
        api_keys_table_query = f"""
        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
            id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
            user_id UUID NOT NULL REFERENCES {self._get_table_name(PostgresUserHandler.TABLE_NAME)}(id) ON DELETE CASCADE,
            public_key TEXT UNIQUE NOT NULL,
            hashed_key TEXT NOT NULL,
            name TEXT,
            created_at TIMESTAMPTZ DEFAULT NOW(),
            updated_at TIMESTAMPTZ DEFAULT NOW()
        );

        CREATE INDEX IF NOT EXISTS idx_api_keys_user_id
        ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(user_id);

        CREATE INDEX IF NOT EXISTS idx_api_keys_public_key
        ON {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}(public_key);
        """

        await self.connection_manager.execute_query(user_table_query)
        await self.connection_manager.execute_query(api_keys_table_query)

    async def _get_user_where(self, condition: str, value: UUID | str) -> User:
        """Fetch one user matching ``condition`` (which must use ``$1``).

        Raises:
            R2RException: 404 when no row matches.
        """
        query, _ = (
            QueryBuilder(self._get_table_name(PostgresUserHandler.TABLE_NAME))
            .select(list(self._USER_COLUMNS))
            .where(condition)
            .build()
        )
        result = await self.connection_manager.fetchrow_query(query, [value])

        if not result:
            raise R2RException(status_code=404, message="User not found")

        return User(**{column: result[column] for column in self._USER_COLUMNS})

    async def get_user_by_id(self, id: UUID) -> User:
        """Return the user with the given id.

        Raises:
            R2RException: 404 when the user does not exist.
        """
        return await self._get_user_where("id = $1", id)

    async def get_user_by_email(self, email: str) -> User:
        """Return the user with the given email.

        Raises:
            R2RException: 404 when the user does not exist.
        """
        return await self._get_user_where("email = $1", email)

    async def create_user(
        self, email: str, password: str, is_superuser: bool = False
    ) -> User:
        """Insert a new user with a hashed password and no collections.

        The user id is derived from the email via ``generate_user_id``.

        Raises:
            R2RException: 400 when a user with this email already exists.
            HTTPException: 500 when the insert returns no row.
        """
        try:
            await self.get_user_by_email(email)
        except R2RException as e:
            # A 404 means the email is free; anything else is a real error.
            if e.status_code != 404:
                raise
        else:
            raise R2RException(
                status_code=400,
                message="User with this email already exists",
            )

        hashed_password = self.crypto_provider.get_password_hash(password)  # type: ignore
        query = f"""
            INSERT INTO {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            (email, id, is_superuser, hashed_password, collection_ids)
            VALUES ($1, $2, $3, $4, $5)
            RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
        """
        result = await self.connection_manager.fetchrow_query(
            query,
            [
                email,
                generate_user_id(email),
                is_superuser,
                hashed_password,
                [],
            ],
        )

        if not result:
            raise HTTPException(
                status_code=500,
                detail="Failed to create user",
            )

        return User(
            id=result["id"],
            email=result["email"],
            is_superuser=result["is_superuser"],
            is_active=result["is_active"],
            is_verified=result["is_verified"],
            created_at=result["created_at"],
            updated_at=result["updated_at"],
            collection_ids=result["collection_ids"],
            hashed_password=hashed_password,
        )

    async def update_user(self, user: User) -> User:
        """Overwrite the mutable profile fields of ``user`` and bump updated_at.

        The password hash is not touched here; see ``update_user_password``.

        Raises:
            HTTPException: 500 when no row was updated.
        """
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET email = $1, is_superuser = $2, is_active = $3, is_verified = $4, updated_at = NOW(),
                name = $5, profile_picture = $6, bio = $7, collection_ids = $8
            WHERE id = $9
            RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, name, profile_picture, bio, collection_ids
        """
        result = await self.connection_manager.fetchrow_query(
            query,
            [
                user.email,
                user.is_superuser,
                user.is_active,
                user.is_verified,
                user.name,
                user.profile_picture,
                user.bio,
                user.collection_ids,
                user.id,
            ],
        )

        if not result:
            raise HTTPException(
                status_code=500,
                detail="Failed to update user",
            )

        return User(
            id=result["id"],
            email=result["email"],
            is_superuser=result["is_superuser"],
            is_active=result["is_active"],
            is_verified=result["is_verified"],
            created_at=result["created_at"],
            updated_at=result["updated_at"],
            name=result["name"],
            profile_picture=result["profile_picture"],
            bio=result["bio"],
            collection_ids=result["collection_ids"],
        )

    async def delete_user_relational(self, id: UUID) -> None:
        """Delete a user row after detaching the user from owned documents.

        Raises:
            R2RException: 404 when the user does not exist.
        """
        # Get the collections the user belongs to (also serves as the
        # existence check before anything is modified).
        collection_query = f"""
            SELECT collection_ids FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            WHERE id = $1
        """
        collection_result = await self.connection_manager.fetchrow_query(
            collection_query, [id]
        )

        if not collection_result:
            raise R2RException(status_code=404, message="User not found")

        # Remove user from documents. The link is documents.owner_id (the
        # column get_users_overview joins on); filtering/nulling documents.id
        # would compare document ids with a user id and never detach anything.
        # NOTE(review): assumes documents.owner_id is nullable -- confirm.
        doc_update_query = f"""
            UPDATE {self._get_table_name('documents')}
            SET owner_id = NULL
            WHERE owner_id = $1
        """
        await self.connection_manager.execute_query(doc_update_query, [id])

        # Delete the user
        delete_query = f"""
            DELETE FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            WHERE id = $1
            RETURNING id
        """
        result = await self.connection_manager.fetchrow_query(
            delete_query, [id]
        )

        if not result:
            raise R2RException(status_code=404, message="User not found")

    async def update_user_password(self, id: UUID, new_hashed_password: str):
        """Store an already-hashed password for the user and bump updated_at."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET hashed_password = $1, updated_at = NOW()
            WHERE id = $2
        """
        await self.connection_manager.execute_query(
            query, [new_hashed_password, id]
        )

    async def get_all_users(self) -> list[User]:
        """Return every user; password hashes are not loaded.

        ``hashed_password`` is filled with the literal string "null".
        """
        query = f"""
            SELECT id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
            FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
        """
        results = await self.connection_manager.fetch_query(query)

        return [
            User(
                id=result["id"],
                email=result["email"],
                hashed_password="null",
                is_superuser=result["is_superuser"],
                is_active=result["is_active"],
                is_verified=result["is_verified"],
                created_at=result["created_at"],
                updated_at=result["updated_at"],
                collection_ids=result["collection_ids"],
            )
            for result in results
        ]

    async def store_verification_code(
        self, id: UUID, verification_code: str, expiry: datetime
    ):
        """Save a verification code and its expiry for the user."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET verification_code = $1, verification_code_expiry = $2
            WHERE id = $3
        """
        await self.connection_manager.execute_query(
            query, [verification_code, expiry, id]
        )

    async def verify_user(self, verification_code: str) -> None:
        """Mark the user holding an unexpired code as verified and clear it.

        Raises:
            R2RException: 400 when the code is unknown or expired.
        """
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
            WHERE verification_code = $1 AND verification_code_expiry > NOW()
            RETURNING id
        """
        result = await self.connection_manager.fetchrow_query(
            query, [verification_code]
        )

        if not result:
            raise R2RException(
                status_code=400, message="Invalid or expired verification code"
            )

    async def remove_verification_code(self, verification_code: str):
        """Clear the given verification code wherever it is stored."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET verification_code = NULL, verification_code_expiry = NULL
            WHERE verification_code = $1
        """
        await self.connection_manager.execute_query(query, [verification_code])

    async def expire_verification_code(self, id: UUID):
        """Force the user's verification code to be expired (one day ago)."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET verification_code_expiry = NOW() - INTERVAL '1 day'
            WHERE id = $1
        """
        await self.connection_manager.execute_query(query, [id])

    async def store_reset_token(
        self, id: UUID, reset_token: str, expiry: datetime
    ):
        """Save a password-reset token and its expiry for the user."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET reset_token = $1, reset_token_expiry = $2
            WHERE id = $3
        """
        await self.connection_manager.execute_query(
            query, [reset_token, expiry, id]
        )

    async def get_user_id_by_reset_token(
        self, reset_token: str
    ) -> Optional[UUID]:
        """Return the id of the user holding an unexpired token, else None."""
        query = f"""
            SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            WHERE reset_token = $1 AND reset_token_expiry > NOW()
        """
        result = await self.connection_manager.fetchrow_query(
            query, [reset_token]
        )
        return result["id"] if result else None

    async def remove_reset_token(self, id: UUID):
        """Clear the user's reset token and its expiry."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET reset_token = NULL, reset_token_expiry = NULL
            WHERE id = $1
        """
        await self.connection_manager.execute_query(query, [id])

    async def remove_user_from_all_collections(self, id: UUID):
        """Reset the user's collection membership to an empty array."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET collection_ids = ARRAY[]::UUID[]
            WHERE id = $1
        """
        await self.connection_manager.execute_query(query, [id])

    async def add_user_to_collection(
        self, id: UUID, collection_id: UUID
    ) -> bool:
        """Append a collection to the user's membership and bump its user_count.

        Returns:
            True on success.

        Raises:
            R2RException: 404 when the user or collection does not exist,
                400 when the user is already a member.
        """
        # Check if the user exists
        if not await self.get_user_by_id(id):
            raise R2RException(status_code=404, message="User not found")

        # Check if the collection exists
        if not await self._collection_exists(collection_id):
            raise R2RException(status_code=404, message="Collection not found")

        # collection_ids is nullable; without COALESCE, "$1 = ANY(NULL)" is
        # NULL, the row is skipped, and a user with no array yet would be
        # reported as already being in the collection.
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET collection_ids = array_append(collection_ids, $1)
            WHERE id = $2
              AND NOT ($1 = ANY(COALESCE(collection_ids, ARRAY[]::UUID[])))
            RETURNING id
        """
        result = await self.connection_manager.fetchrow_query(
            query, [collection_id, id]
        )
        if not result:
            raise R2RException(
                status_code=400, message="User already in collection"
            )

        update_collection_query = f"""
            UPDATE {self._get_table_name('collections')}
            SET user_count = user_count + 1
            WHERE id = $1
        """
        await self.connection_manager.execute_query(
            query=update_collection_query,
            params=[collection_id],
        )

        return True

    async def remove_user_from_collection(
        self, id: UUID, collection_id: UUID
    ) -> bool:
        """Remove a collection from the user's membership.

        Returns:
            True on success.

        Raises:
            R2RException: 404 when the user does not exist, 400 when the
                user is not a member of the collection.
        """
        if not await self.get_user_by_id(id):
            raise R2RException(status_code=404, message="User not found")

        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET collection_ids = array_remove(collection_ids, $1)
            WHERE id = $2 AND $1 = ANY(collection_ids)
            RETURNING id
        """
        result = await self.connection_manager.fetchrow_query(
            query, [collection_id, id]
        )
        if not result:
            raise R2RException(
                status_code=400,
                message="User is not a member of the specified collection",
            )
        return True

    async def get_users_in_collection(
        self, collection_id: UUID, offset: int, limit: int
    ) -> dict[str, list[User] | int]:
        """
        Get all users in a specific collection with pagination.

        Args:
            collection_id (UUID): The ID of the collection to get users from.
            offset (int): The number of users to skip.
            limit (int): The maximum number of users to return; -1 disables
                the limit.

        Returns:
            dict: ``{"results": [User, ...], "total_entries": int}`` where
            ``total_entries`` is the number of matching users before
            pagination (0 when the page is empty).

        Raises:
            R2RException: If the collection doesn't exist.
        """
        if not await self._collection_exists(collection_id):  # type: ignore
            raise R2RException(status_code=404, message="Collection not found")

        query = f"""
            SELECT u.id, u.email, u.is_active, u.is_superuser, u.created_at, u.updated_at,
                u.is_verified, u.collection_ids, u.name, u.bio, u.profile_picture,
                COUNT(*) OVER() AS total_entries
            FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
            WHERE $1 = ANY(u.collection_ids)
            ORDER BY u.name
            OFFSET $2
        """

        conditions = [collection_id, offset]
        if limit != -1:
            query += " LIMIT $3"
            conditions.append(limit)

        results = await self.connection_manager.fetch_query(query, conditions)

        users = [
            User(
                id=row["id"],
                email=row["email"],
                is_active=row["is_active"],
                is_superuser=row["is_superuser"],
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                is_verified=row["is_verified"],
                collection_ids=row["collection_ids"],
                name=row["name"],
                bio=row["bio"],
                profile_picture=row["profile_picture"],
                hashed_password=None,
                verification_code_expiry=None,
            )
            for row in results
        ]

        total_entries = results[0]["total_entries"] if results else 0

        return {"results": users, "total_entries": total_entries}

    async def mark_user_as_superuser(self, id: UUID):
        """Promote the user to a verified superuser and clear any pending code."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET is_superuser = TRUE, is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
            WHERE id = $1
        """
        await self.connection_manager.execute_query(query, [id])

    async def get_user_id_by_verification_code(
        self, verification_code: str
    ) -> UUID:
        """Return the id of the user holding an unexpired verification code.

        Raises:
            R2RException: 400 when the code is unknown or expired.
        """
        query = f"""
            SELECT id FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            WHERE verification_code = $1 AND verification_code_expiry > NOW()
        """
        result = await self.connection_manager.fetchrow_query(
            query, [verification_code]
        )

        if not result:
            raise R2RException(
                status_code=400, message="Invalid or expired verification code"
            )

        return result["id"]

    async def mark_user_as_verified(self, id: UUID):
        """Flag the user as verified and clear any pending verification code."""
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
            SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
            WHERE id = $1
        """
        await self.connection_manager.execute_query(query, [id])

    async def get_users_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[User] | int]:
        """Return users with per-user document statistics, paginated by email.

        Args:
            offset: Number of users to skip.
            limit: Maximum number of users to return; -1 disables the limit.
            user_ids: When non-empty, restrict the overview to these users.

        Returns:
            ``{"results": [User, ...], "total_entries": int}``.

        Raises:
            R2RException: 404 when the query yields no users.
        """
        # Placeholders are numbered from the actual parameter list so that any
        # combination of limit / user_ids yields contiguous $1..$n. (A fixed
        # "$3" for user_ids breaks when limit == -1 and "$2" is never bound.)
        params: list = [offset]
        user_filter = ""
        if user_ids:
            params.append(user_ids)
            user_filter = f"WHERE u.id = ANY(${len(params)}::uuid[])"

        query = f"""
            WITH user_document_ids AS (
                SELECT
                    u.id as user_id,
                    ARRAY_AGG(d.id) FILTER (WHERE d.id IS NOT NULL) AS doc_ids
                FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
                LEFT JOIN {self._get_table_name('documents')} d ON u.id = d.owner_id
                GROUP BY u.id
            ),
            user_docs AS (
                SELECT
                    u.id,
                    u.email,
                    u.is_superuser,
                    u.is_active,
                    u.is_verified,
                    u.name,
                    u.bio,
                    u.profile_picture,
                    u.collection_ids,
                    u.created_at,
                    u.updated_at,
                    COUNT(d.id) AS num_files,
                    COALESCE(SUM(d.size_in_bytes), 0) AS total_size_in_bytes,
                    ud.doc_ids as document_ids
                FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
                LEFT JOIN {self._get_table_name('documents')} d ON u.id = d.owner_id
                LEFT JOIN user_document_ids ud ON u.id = ud.user_id
                {user_filter}
                GROUP BY u.id, u.email, u.is_superuser, u.is_active, u.is_verified,
                         u.name, u.bio, u.profile_picture,
                         u.created_at, u.updated_at, u.collection_ids, ud.doc_ids
            )
            SELECT
                user_docs.*,
                COUNT(*) OVER() AS total_entries
            FROM user_docs
            ORDER BY email
            OFFSET $1
        """

        if limit != -1:
            params.append(limit)
            query += f" LIMIT ${len(params)}"

        results = await self.connection_manager.fetch_query(query, params)

        users = [
            User(
                id=row["id"],
                email=row["email"],
                is_superuser=row["is_superuser"],
                is_active=row["is_active"],
                is_verified=row["is_verified"],
                name=row["name"],
                bio=row["bio"],
                profile_picture=row["profile_picture"],
                created_at=row["created_at"],
                updated_at=row["updated_at"],
                collection_ids=row["collection_ids"] or [],
                num_files=row["num_files"],
                total_size_in_bytes=row["total_size_in_bytes"],
                document_ids=(
                    []
                    if row["document_ids"] is None
                    else list(row["document_ids"])
                ),
            )
            for row in results
        ]

        if not users:
            raise R2RException(status_code=404, message="No users found")

        total_entries = results[0]["total_entries"]

        return {"results": users, "total_entries": total_entries}

    async def _collection_exists(self, collection_id: UUID) -> bool:
        """Check if a collection exists."""
        query = f"""
            SELECT 1 FROM {self._get_table_name(PostgresCollectionsHandler.TABLE_NAME)}
            WHERE id = $1
        """
        result = await self.connection_manager.fetchrow_query(
            query, [collection_id]
        )
        return result is not None

    async def get_user_validation_data(
        self,
        user_id: UUID,
    ) -> dict:
        """
        Get verification data for a specific user.
        This method should be called after superuser authorization has been verified.

        Returns:
            ``{"verification_data": {...}}`` holding the verification code,
            the reset token, and their expiries as ISO-8601 strings (or None).

        Raises:
            R2RException: 404 when the user does not exist.
        """
        query = f"""
            SELECT
                verification_code,
                verification_code_expiry,
                reset_token,
                reset_token_expiry
            FROM {self._get_table_name("users")}
            WHERE id = $1
        """
        result = await self.connection_manager.fetchrow_query(query, [user_id])

        if not result:
            raise R2RException(status_code=404, message="User not found")

        return {
            "verification_data": {
                "verification_code": result["verification_code"],
                "verification_code_expiry": (
                    result["verification_code_expiry"].isoformat()
                    if result["verification_code_expiry"]
                    else None
                ),
                "reset_token": result["reset_token"],
                "reset_token_expiry": (
                    result["reset_token_expiry"].isoformat()
                    if result["reset_token_expiry"]
                    else None
                ),
            }
        }

    # API Key methods
    async def store_user_api_key(
        self,
        user_id: UUID,
        key_id: str,
        hashed_key: str,
        name: Optional[str] = None,
    ) -> UUID:
        """Store a new API key for a user and return the new row's id.

        ``key_id`` is stored in the ``public_key`` column.

        Raises:
            R2RException: 500 when the insert returns no row.
        """
        query = f"""
            INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
            (user_id, public_key, hashed_key, name)
            VALUES ($1, $2, $3, $4)
            RETURNING id
        """
        result = await self.connection_manager.fetchrow_query(
            query, [user_id, key_id, hashed_key, name]
        )
        if not result:
            raise R2RException(
                status_code=500, message="Failed to store API key"
            )
        return result["id"]

    async def get_api_key_record(self, key_id: str) -> Optional[dict]:
        """Get API key record and update updated_at.

        Looks the key up by ``public_key``; returns None when it is unknown,
        otherwise a dict with ``user_id`` and ``hashed_key``.
        """
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
            SET updated_at = NOW()
            WHERE public_key = $1
            RETURNING user_id, hashed_key
        """
        result = await self.connection_manager.fetchrow_query(query, [key_id])
        if not result:
            return None
        return {
            "user_id": result["user_id"],
            "hashed_key": result["hashed_key"],
        }

    async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
        """Get all API keys for a user, newest first (hashes are not returned)."""
        query = f"""
            SELECT id, public_key, name, created_at, updated_at
            FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
            WHERE user_id = $1
            ORDER BY created_at DESC
        """
        results = await self.connection_manager.fetch_query(query, [user_id])
        return [
            {
                "key_id": str(row["id"]),
                "public_key": row["public_key"],
                "name": row["name"] or "",
                "updated_at": row["updated_at"],
            }
            for row in results
        ]

    async def delete_api_key(self, user_id: UUID, key_id: UUID) -> dict:
        """Delete a specific API key owned by the user and describe it.

        Raises:
            R2RException: 404 when no such key belongs to the user.
        """
        query = f"""
            DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
            WHERE id = $1 AND user_id = $2
            RETURNING id, public_key, name
        """
        result = await self.connection_manager.fetchrow_query(
            query, [key_id, user_id]
        )
        if result is None:
            raise R2RException(status_code=404, message="API key not found")

        return {
            "key_id": str(result["id"]),
            "public_key": str(result["public_key"]),
            "name": result["name"] or "",
        }

    async def update_api_key_name(
        self, user_id: UUID, key_id: UUID, name: str
    ) -> bool:
        """Update the name of an API key owned by the user.

        Returns:
            True on success.

        Raises:
            R2RException: 404 when no such key belongs to the user.
        """
        query = f"""
            UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
            SET name = $1, updated_at = NOW()
            WHERE id = $2 AND user_id = $3
            RETURNING id
        """
        result = await self.connection_manager.fetchrow_query(
            query, [name, key_id, user_id]
        )
        if result is None:
            raise R2RException(status_code=404, message="API key not found")
        return True