|
@@ -1,5 +1,6 @@
|
|
|
|
+import json
|
|
from datetime import datetime
|
|
from datetime import datetime
|
|
-from typing import Optional
|
|
|
|
|
|
+from typing import Any, Dict, List, Optional
|
|
from uuid import UUID
|
|
from uuid import UUID
|
|
|
|
|
|
from fastapi import HTTPException
|
|
from fastapi import HTTPException
|
|
@@ -43,10 +44,12 @@ class PostgresUserHandler(Handler):
|
|
reset_token TEXT,
|
|
reset_token TEXT,
|
|
reset_token_expiry TIMESTAMPTZ,
|
|
reset_token_expiry TIMESTAMPTZ,
|
|
collection_ids UUID[] NULL,
|
|
collection_ids UUID[] NULL,
|
|
|
|
+ limits_overrides JSONB,
|
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
|
created_at TIMESTAMPTZ DEFAULT NOW(),
|
|
updated_at TIMESTAMPTZ DEFAULT NOW()
|
|
updated_at TIMESTAMPTZ DEFAULT NOW()
|
|
);
|
|
);
|
|
"""
|
|
"""
|
|
|
|
+
|
|
# API keys table with updated_at instead of last_used_at
|
|
# API keys table with updated_at instead of last_used_at
|
|
api_keys_table_query = f"""
|
|
api_keys_table_query = f"""
|
|
CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
|
|
CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)} (
|
|
@@ -86,6 +89,7 @@ class PostgresUserHandler(Handler):
|
|
"profile_picture",
|
|
"profile_picture",
|
|
"bio",
|
|
"bio",
|
|
"collection_ids",
|
|
"collection_ids",
|
|
|
|
+ "limits_overrides", # Fetch JSONB column
|
|
]
|
|
]
|
|
)
|
|
)
|
|
.where("id = $1")
|
|
.where("id = $1")
|
|
@@ -109,6 +113,8 @@ class PostgresUserHandler(Handler):
|
|
profile_picture=result["profile_picture"],
|
|
profile_picture=result["profile_picture"],
|
|
bio=result["bio"],
|
|
bio=result["bio"],
|
|
collection_ids=result["collection_ids"],
|
|
collection_ids=result["collection_ids"],
|
|
|
|
+ # Add the new field
|
|
|
|
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
|
|
)
|
|
)
|
|
|
|
|
|
async def get_user_by_email(self, email: str) -> User:
|
|
async def get_user_by_email(self, email: str) -> User:
|
|
@@ -128,6 +134,7 @@ class PostgresUserHandler(Handler):
|
|
"profile_picture",
|
|
"profile_picture",
|
|
"bio",
|
|
"bio",
|
|
"collection_ids",
|
|
"collection_ids",
|
|
|
|
+ "limits_overrides",
|
|
]
|
|
]
|
|
)
|
|
)
|
|
.where("email = $1")
|
|
.where("email = $1")
|
|
@@ -150,13 +157,16 @@ class PostgresUserHandler(Handler):
|
|
profile_picture=result["profile_picture"],
|
|
profile_picture=result["profile_picture"],
|
|
bio=result["bio"],
|
|
bio=result["bio"],
|
|
collection_ids=result["collection_ids"],
|
|
collection_ids=result["collection_ids"],
|
|
|
|
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
|
|
)
|
|
)
|
|
|
|
|
|
async def create_user(
|
|
async def create_user(
|
|
self, email: str, password: str, is_superuser: bool = False
|
|
self, email: str, password: str, is_superuser: bool = False
|
|
) -> User:
|
|
) -> User:
|
|
|
|
+ """Create a new user."""
|
|
try:
|
|
try:
|
|
- if await self.get_user_by_email(email):
|
|
|
|
|
|
+ existing = await self.get_user_by_email(email)
|
|
|
|
+ if existing:
|
|
raise R2RException(
|
|
raise R2RException(
|
|
status_code=400,
|
|
status_code=400,
|
|
message="User with this email already exists",
|
|
message="User with this email already exists",
|
|
@@ -166,27 +176,39 @@ class PostgresUserHandler(Handler):
|
|
raise e
|
|
raise e
|
|
|
|
|
|
hashed_password = self.crypto_provider.get_password_hash(password) # type: ignore
|
|
hashed_password = self.crypto_provider.get_password_hash(password) # type: ignore
|
|
- query = f"""
|
|
|
|
- INSERT INTO {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
|
|
|
|
- (email, id, is_superuser, hashed_password, collection_ids)
|
|
|
|
- VALUES ($1, $2, $3, $4, $5)
|
|
|
|
- RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
|
|
|
|
- """
|
|
|
|
- result = await self.connection_manager.fetchrow_query(
|
|
|
|
- query,
|
|
|
|
- [
|
|
|
|
- email,
|
|
|
|
- generate_user_id(email),
|
|
|
|
- is_superuser,
|
|
|
|
- hashed_password,
|
|
|
|
- [],
|
|
|
|
- ],
|
|
|
|
|
|
+ query, params = (
|
|
|
|
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
|
|
|
|
+ .insert(
|
|
|
|
+ {
|
|
|
|
+ "email": email,
|
|
|
|
+ "id": generate_user_id(email),
|
|
|
|
+ "is_superuser": is_superuser,
|
|
|
|
+ "hashed_password": hashed_password,
|
|
|
|
+ "collection_ids": [],
|
|
|
|
+ "limits_overrides": None,
|
|
|
|
+ }
|
|
|
|
+ )
|
|
|
|
+ .returning(
|
|
|
|
+ [
|
|
|
|
+ "id",
|
|
|
|
+ "email",
|
|
|
|
+ "is_superuser",
|
|
|
|
+ "is_active",
|
|
|
|
+ "is_verified",
|
|
|
|
+ "created_at",
|
|
|
|
+ "updated_at",
|
|
|
|
+ "collection_ids",
|
|
|
|
+ "limits_overrides",
|
|
|
|
+ ]
|
|
|
|
+ )
|
|
|
|
+ .build()
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ result = await self.connection_manager.fetchrow_query(query, params)
|
|
if not result:
|
|
if not result:
|
|
- raise HTTPException(
|
|
|
|
|
|
+ raise R2RException(
|
|
status_code=500,
|
|
status_code=500,
|
|
- detail="Failed to create user",
|
|
|
|
|
|
+ message="Failed to create user",
|
|
)
|
|
)
|
|
|
|
|
|
return User(
|
|
return User(
|
|
@@ -197,17 +219,62 @@ class PostgresUserHandler(Handler):
|
|
is_verified=result["is_verified"],
|
|
is_verified=result["is_verified"],
|
|
created_at=result["created_at"],
|
|
created_at=result["created_at"],
|
|
updated_at=result["updated_at"],
|
|
updated_at=result["updated_at"],
|
|
- collection_ids=result["collection_ids"],
|
|
|
|
|
|
+ collection_ids=result["collection_ids"] or [],
|
|
hashed_password=hashed_password,
|
|
hashed_password=hashed_password,
|
|
|
|
+ limits_overrides=json.loads(result["limits_overrides"] or "{}"),
|
|
|
|
+ name=None,
|
|
|
|
+ bio=None,
|
|
|
|
+ profile_picture=None,
|
|
)
|
|
)
|
|
|
|
|
|
- async def update_user(self, user: User) -> User:
|
|
|
|
|
|
+ async def update_user(
|
|
|
|
+ self, user: User, merge_limits: bool = False
|
|
|
|
+ ) -> User:
|
|
|
|
+ """
|
|
|
|
+ Update user information including limits_overrides.
|
|
|
|
+
|
|
|
|
+ Args:
|
|
|
|
+ user: User object containing updated information
|
|
|
|
+ merge_limits: If True, will merge existing limits_overrides with new ones.
|
|
|
|
+ If False, will overwrite existing limits_overrides.
|
|
|
|
+
|
|
|
|
+ Returns:
|
|
|
|
+ Updated User object
|
|
|
|
+ """
|
|
|
|
+ # Get current user if we need to merge limits or get hashed password
|
|
|
|
+ current_user = None
|
|
|
|
+ try:
|
|
|
|
+ current_user = await self.get_user_by_id(user.id)
|
|
|
|
+ except R2RException:
|
|
|
|
+ raise R2RException(status_code=404, message="User not found")
|
|
|
|
+
|
|
|
|
+ # Merge or replace limits_overrides
|
|
|
|
+ final_limits = user.limits_overrides
|
|
|
|
+ if (
|
|
|
|
+ merge_limits
|
|
|
|
+ and current_user.limits_overrides
|
|
|
|
+ and user.limits_overrides
|
|
|
|
+ ):
|
|
|
|
+ final_limits = {
|
|
|
|
+ **current_user.limits_overrides,
|
|
|
|
+ **user.limits_overrides,
|
|
|
|
+ }
|
|
query = f"""
|
|
query = f"""
|
|
UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
|
|
UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
|
|
- SET email = $1, is_superuser = $2, is_active = $3, is_verified = $4, updated_at = NOW(),
|
|
|
|
- name = $5, profile_picture = $6, bio = $7, collection_ids = $8
|
|
|
|
- WHERE id = $9
|
|
|
|
- RETURNING id, email, is_superuser, is_active, is_verified, created_at, updated_at, name, profile_picture, bio, collection_ids
|
|
|
|
|
|
+ SET email = $1,
|
|
|
|
+ is_superuser = $2,
|
|
|
|
+ is_active = $3,
|
|
|
|
+ is_verified = $4,
|
|
|
|
+ updated_at = NOW(),
|
|
|
|
+ name = $5,
|
|
|
|
+ profile_picture = $6,
|
|
|
|
+ bio = $7,
|
|
|
|
+ collection_ids = $8,
|
|
|
|
+ limits_overrides = $9::jsonb
|
|
|
|
+ WHERE id = $10
|
|
|
|
+ RETURNING id, email, is_superuser, is_active, is_verified,
|
|
|
|
+ created_at, updated_at, name, profile_picture, bio,
|
|
|
|
+ collection_ids, limits_overrides, hashed_password
|
|
"""
|
|
"""
|
|
result = await self.connection_manager.fetchrow_query(
|
|
result = await self.connection_manager.fetchrow_query(
|
|
query,
|
|
query,
|
|
@@ -219,7 +286,8 @@ class PostgresUserHandler(Handler):
|
|
user.name,
|
|
user.name,
|
|
user.profile_picture,
|
|
user.profile_picture,
|
|
user.bio,
|
|
user.bio,
|
|
- user.collection_ids,
|
|
|
|
|
|
+ user.collection_ids or [], # Ensure null becomes empty array
|
|
|
|
+ json.dumps(final_limits), # Already handled null case
|
|
user.id,
|
|
user.id,
|
|
],
|
|
],
|
|
)
|
|
)
|
|
@@ -233,6 +301,9 @@ class PostgresUserHandler(Handler):
|
|
return User(
|
|
return User(
|
|
id=result["id"],
|
|
id=result["id"],
|
|
email=result["email"],
|
|
email=result["email"],
|
|
|
|
+ hashed_password=result[
|
|
|
|
+ "hashed_password"
|
|
|
|
+ ], # Include hashed_password
|
|
is_superuser=result["is_superuser"],
|
|
is_superuser=result["is_superuser"],
|
|
is_active=result["is_active"],
|
|
is_active=result["is_active"],
|
|
is_verified=result["is_verified"],
|
|
is_verified=result["is_verified"],
|
|
@@ -241,15 +312,23 @@ class PostgresUserHandler(Handler):
|
|
name=result["name"],
|
|
name=result["name"],
|
|
profile_picture=result["profile_picture"],
|
|
profile_picture=result["profile_picture"],
|
|
bio=result["bio"],
|
|
bio=result["bio"],
|
|
- collection_ids=result["collection_ids"],
|
|
|
|
|
|
+ collection_ids=result["collection_ids"]
|
|
|
|
+ or [], # Ensure null becomes empty array
|
|
|
|
+ limits_overrides=json.loads(
|
|
|
|
+ result["limits_overrides"] or "{}"
|
|
|
|
+ ), # Can be null
|
|
)
|
|
)
|
|
|
|
|
|
async def delete_user_relational(self, id: UUID) -> None:
|
|
async def delete_user_relational(self, id: UUID) -> None:
|
|
|
|
+ """Delete a user and update related records."""
|
|
# Get the collections the user belongs to
|
|
# Get the collections the user belongs to
|
|
- collection_query = f"""
|
|
|
|
- SELECT collection_ids FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
|
|
|
|
- WHERE id = $1
|
|
|
|
- """
|
|
|
|
|
|
+ collection_query, params = (
|
|
|
|
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
|
|
|
|
+ .select(["collection_ids"])
|
|
|
|
+ .where("id = $1")
|
|
|
|
+ .build()
|
|
|
|
+ )
|
|
|
|
+
|
|
collection_result = await self.connection_manager.fetchrow_query(
|
|
collection_result = await self.connection_manager.fetchrow_query(
|
|
collection_query, [id]
|
|
collection_query, [id]
|
|
)
|
|
)
|
|
@@ -257,20 +336,25 @@ class PostgresUserHandler(Handler):
|
|
if not collection_result:
|
|
if not collection_result:
|
|
raise R2RException(status_code=404, message="User not found")
|
|
raise R2RException(status_code=404, message="User not found")
|
|
|
|
|
|
- # Remove user from documents
|
|
|
|
- doc_update_query = f"""
|
|
|
|
- UPDATE {self._get_table_name('documents')}
|
|
|
|
- SET id = NULL
|
|
|
|
- WHERE id = $1
|
|
|
|
- """
|
|
|
|
|
|
+ # Update documents query
|
|
|
|
+ doc_update_query, doc_params = (
|
|
|
|
+ QueryBuilder(self._get_table_name("documents"))
|
|
|
|
+ .update({"id": None})
|
|
|
|
+ .where("id = $1")
|
|
|
|
+ .build()
|
|
|
|
+ )
|
|
|
|
+
|
|
await self.connection_manager.execute_query(doc_update_query, [id])
|
|
await self.connection_manager.execute_query(doc_update_query, [id])
|
|
|
|
|
|
- # Delete the user
|
|
|
|
- delete_query = f"""
|
|
|
|
- DELETE FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
|
|
|
|
- WHERE id = $1
|
|
|
|
- RETURNING id
|
|
|
|
- """
|
|
|
|
|
|
+ # Delete user query
|
|
|
|
+ delete_query, del_params = (
|
|
|
|
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
|
|
|
|
+ .delete()
|
|
|
|
+ .where("id = $1")
|
|
|
|
+ .returning(["id"])
|
|
|
|
+ .build()
|
|
|
|
+ )
|
|
|
|
+
|
|
result = await self.connection_manager.fetchrow_query(
|
|
result = await self.connection_manager.fetchrow_query(
|
|
delete_query, [id]
|
|
delete_query, [id]
|
|
)
|
|
)
|
|
@@ -288,24 +372,48 @@ class PostgresUserHandler(Handler):
|
|
query, [new_hashed_password, id]
|
|
query, [new_hashed_password, id]
|
|
)
|
|
)
|
|
|
|
|
|
- async def get_all_users(self) -> list[User]:
|
|
|
|
- query = f"""
|
|
|
|
- SELECT id, email, is_superuser, is_active, is_verified, created_at, updated_at, collection_ids
|
|
|
|
- FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
|
|
|
|
- """
|
|
|
|
- results = await self.connection_manager.fetch_query(query)
|
|
|
|
|
|
+ async def get_all_users(self) -> List[User]:
|
|
|
|
+ """Get all users with minimal information."""
|
|
|
|
+ query, params = (
|
|
|
|
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
|
|
|
|
+ .select(
|
|
|
|
+ [
|
|
|
|
+ "id",
|
|
|
|
+ "email",
|
|
|
|
+ "is_superuser",
|
|
|
|
+ "is_active",
|
|
|
|
+ "is_verified",
|
|
|
|
+ "created_at",
|
|
|
|
+ "updated_at",
|
|
|
|
+ "collection_ids",
|
|
|
|
+ "hashed_password",
|
|
|
|
+ "limits_overrides",
|
|
|
|
+ "name",
|
|
|
|
+ "bio",
|
|
|
|
+ "profile_picture",
|
|
|
|
+ ]
|
|
|
|
+ )
|
|
|
|
+ .build()
|
|
|
|
+ )
|
|
|
|
|
|
|
|
+ results = await self.connection_manager.fetch_query(query, params)
|
|
return [
|
|
return [
|
|
User(
|
|
User(
|
|
id=result["id"],
|
|
id=result["id"],
|
|
email=result["email"],
|
|
email=result["email"],
|
|
- hashed_password="null",
|
|
|
|
|
|
+ hashed_password=result["hashed_password"],
|
|
is_superuser=result["is_superuser"],
|
|
is_superuser=result["is_superuser"],
|
|
is_active=result["is_active"],
|
|
is_active=result["is_active"],
|
|
is_verified=result["is_verified"],
|
|
is_verified=result["is_verified"],
|
|
created_at=result["created_at"],
|
|
created_at=result["created_at"],
|
|
updated_at=result["updated_at"],
|
|
updated_at=result["updated_at"],
|
|
- collection_ids=result["collection_ids"],
|
|
|
|
|
|
+ collection_ids=result["collection_ids"] or [],
|
|
|
|
+ limits_overrides=json.loads(
|
|
|
|
+ result["limits_overrides"] or "{}"
|
|
|
|
+ ),
|
|
|
|
+ name=result["name"],
|
|
|
|
+ bio=result["bio"],
|
|
|
|
+ profile_picture=result["profile_picture"],
|
|
)
|
|
)
|
|
for result in results
|
|
for result in results
|
|
]
|
|
]
|
|
@@ -456,41 +564,44 @@ class PostgresUserHandler(Handler):
|
|
async def get_users_in_collection(
|
|
async def get_users_in_collection(
|
|
self, collection_id: UUID, offset: int, limit: int
|
|
self, collection_id: UUID, offset: int, limit: int
|
|
) -> dict[str, list[User] | int]:
|
|
) -> dict[str, list[User] | int]:
|
|
- """
|
|
|
|
- Get all users in a specific collection with pagination.
|
|
|
|
-
|
|
|
|
- Args:
|
|
|
|
- collection_id (UUID): The ID of the collection to get users from.
|
|
|
|
- offset (int): The number of users to skip.
|
|
|
|
- limit (int): The maximum number of users to return.
|
|
|
|
-
|
|
|
|
- Returns:
|
|
|
|
- List[User]: A list of User objects representing the users in the collection.
|
|
|
|
-
|
|
|
|
- Raises:
|
|
|
|
- R2RException: If the collection doesn't exist.
|
|
|
|
- """
|
|
|
|
- if not await self._collection_exists(collection_id): # type: ignore
|
|
|
|
|
|
+ """Get all users in a specific collection with pagination."""
|
|
|
|
+ if not await self._collection_exists(collection_id):
|
|
raise R2RException(status_code=404, message="Collection not found")
|
|
raise R2RException(status_code=404, message="Collection not found")
|
|
|
|
|
|
- query = f"""
|
|
|
|
- SELECT u.id, u.email, u.is_active, u.is_superuser, u.created_at, u.updated_at,
|
|
|
|
- u.is_verified, u.collection_ids, u.name, u.bio, u.profile_picture,
|
|
|
|
- COUNT(*) OVER() AS total_entries
|
|
|
|
- FROM {self._get_table_name(PostgresUserHandler.TABLE_NAME)} u
|
|
|
|
- WHERE $1 = ANY(u.collection_ids)
|
|
|
|
- ORDER BY u.name
|
|
|
|
- OFFSET $2
|
|
|
|
- """
|
|
|
|
|
|
+ query, params = (
|
|
|
|
+ QueryBuilder(self._get_table_name(self.TABLE_NAME))
|
|
|
|
+ .select(
|
|
|
|
+ [
|
|
|
|
+ "id",
|
|
|
|
+ "email",
|
|
|
|
+ "is_active",
|
|
|
|
+ "is_superuser",
|
|
|
|
+ "created_at",
|
|
|
|
+ "updated_at",
|
|
|
|
+ "is_verified",
|
|
|
|
+ "collection_ids",
|
|
|
|
+ "name",
|
|
|
|
+ "bio",
|
|
|
|
+ "profile_picture",
|
|
|
|
+ "hashed_password",
|
|
|
|
+ "limits_overrides",
|
|
|
|
+ "COUNT(*) OVER() AS total_entries",
|
|
|
|
+ ]
|
|
|
|
+ )
|
|
|
|
+ .where("$1 = ANY(collection_ids)")
|
|
|
|
+ .order_by("name")
|
|
|
|
+ .offset("$2")
|
|
|
|
+ .limit("$3" if limit != -1 else None)
|
|
|
|
+ .build()
|
|
|
|
+ )
|
|
|
|
|
|
conditions = [collection_id, offset]
|
|
conditions = [collection_id, offset]
|
|
if limit != -1:
|
|
if limit != -1:
|
|
- query += " LIMIT $3"
|
|
|
|
conditions.append(limit)
|
|
conditions.append(limit)
|
|
|
|
|
|
results = await self.connection_manager.fetch_query(query, conditions)
|
|
results = await self.connection_manager.fetch_query(query, conditions)
|
|
|
|
|
|
- users = [
|
|
|
|
|
|
+ users_list = [
|
|
User(
|
|
User(
|
|
id=row["id"],
|
|
id=row["id"],
|
|
email=row["email"],
|
|
email=row["email"],
|
|
@@ -499,24 +610,24 @@ class PostgresUserHandler(Handler):
|
|
created_at=row["created_at"],
|
|
created_at=row["created_at"],
|
|
updated_at=row["updated_at"],
|
|
updated_at=row["updated_at"],
|
|
is_verified=row["is_verified"],
|
|
is_verified=row["is_verified"],
|
|
- collection_ids=row["collection_ids"],
|
|
|
|
|
|
+ collection_ids=row["collection_ids"] or [],
|
|
name=row["name"],
|
|
name=row["name"],
|
|
bio=row["bio"],
|
|
bio=row["bio"],
|
|
profile_picture=row["profile_picture"],
|
|
profile_picture=row["profile_picture"],
|
|
- hashed_password=None,
|
|
|
|
- verification_code_expiry=None,
|
|
|
|
|
|
+ hashed_password=row["hashed_password"],
|
|
|
|
+ limits_overrides=json.loads(row["limits_overrides"] or "{}"),
|
|
)
|
|
)
|
|
for row in results
|
|
for row in results
|
|
]
|
|
]
|
|
|
|
|
|
total_entries = results[0]["total_entries"] if results else 0
|
|
total_entries = results[0]["total_entries"] if results else 0
|
|
-
|
|
|
|
- return {"results": users, "total_entries": total_entries}
|
|
|
|
|
|
+ return {"results": users_list, "total_entries": total_entries}
|
|
|
|
|
|
async def mark_user_as_superuser(self, id: UUID):
|
|
async def mark_user_as_superuser(self, id: UUID):
|
|
query = f"""
|
|
query = f"""
|
|
UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
|
|
UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
|
|
- SET is_superuser = TRUE, is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
|
|
|
|
|
|
+ SET is_superuser = TRUE, is_verified = TRUE,
|
|
|
|
+ verification_code = NULL, verification_code_expiry = NULL
|
|
WHERE id = $1
|
|
WHERE id = $1
|
|
"""
|
|
"""
|
|
await self.connection_manager.execute_query(query, [id])
|
|
await self.connection_manager.execute_query(query, [id])
|
|
@@ -542,7 +653,9 @@ class PostgresUserHandler(Handler):
|
|
async def mark_user_as_verified(self, id: UUID):
|
|
async def mark_user_as_verified(self, id: UUID):
|
|
query = f"""
|
|
query = f"""
|
|
UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
|
|
UPDATE {self._get_table_name(PostgresUserHandler.TABLE_NAME)}
|
|
- SET is_verified = TRUE, verification_code = NULL, verification_code_expiry = NULL
|
|
|
|
|
|
+ SET is_verified = TRUE,
|
|
|
|
+ verification_code = NULL,
|
|
|
|
+ verification_code_expiry = NULL
|
|
WHERE id = $1
|
|
WHERE id = $1
|
|
"""
|
|
"""
|
|
await self.connection_manager.execute_query(query, [id])
|
|
await self.connection_manager.execute_query(query, [id])
|
|
@@ -553,7 +666,9 @@ class PostgresUserHandler(Handler):
|
|
limit: int,
|
|
limit: int,
|
|
user_ids: Optional[list[UUID]] = None,
|
|
user_ids: Optional[list[UUID]] = None,
|
|
) -> dict[str, list[User] | int]:
|
|
) -> dict[str, list[User] | int]:
|
|
-
|
|
|
|
|
|
+ """
|
|
|
|
+ Return users with document usage and total entries.
|
|
|
|
+ """
|
|
query = f"""
|
|
query = f"""
|
|
WITH user_document_ids AS (
|
|
WITH user_document_ids AS (
|
|
SELECT
|
|
SELECT
|
|
@@ -604,36 +719,36 @@ class PostgresUserHandler(Handler):
|
|
params.append(user_ids)
|
|
params.append(user_ids)
|
|
|
|
|
|
results = await self.connection_manager.fetch_query(query, params)
|
|
results = await self.connection_manager.fetch_query(query, params)
|
|
|
|
+ if not results:
|
|
|
|
+ raise R2RException(status_code=404, message="No users found")
|
|
|
|
|
|
- users = [
|
|
|
|
- User(
|
|
|
|
- id=row["id"],
|
|
|
|
- email=row["email"],
|
|
|
|
- is_superuser=row["is_superuser"],
|
|
|
|
- is_active=row["is_active"],
|
|
|
|
- is_verified=row["is_verified"],
|
|
|
|
- name=row["name"],
|
|
|
|
- bio=row["bio"],
|
|
|
|
- created_at=row["created_at"],
|
|
|
|
- updated_at=row["updated_at"],
|
|
|
|
- collection_ids=row["collection_ids"] or [],
|
|
|
|
- num_files=row["num_files"],
|
|
|
|
- total_size_in_bytes=row["total_size_in_bytes"],
|
|
|
|
- document_ids=(
|
|
|
|
- []
|
|
|
|
- if row["document_ids"] is None
|
|
|
|
- else list(row["document_ids"])
|
|
|
|
- ),
|
|
|
|
|
|
+ users_list = []
|
|
|
|
+ for row in results:
|
|
|
|
+ users_list.append(
|
|
|
|
+ User(
|
|
|
|
+ id=row["id"],
|
|
|
|
+ email=row["email"],
|
|
|
|
+ is_superuser=row["is_superuser"],
|
|
|
|
+ is_active=row["is_active"],
|
|
|
|
+ is_verified=row["is_verified"],
|
|
|
|
+ name=row["name"],
|
|
|
|
+ bio=row["bio"],
|
|
|
|
+ created_at=row["created_at"],
|
|
|
|
+ updated_at=row["updated_at"],
|
|
|
|
+ profile_picture=row["profile_picture"],
|
|
|
|
+ collection_ids=row["collection_ids"] or [],
|
|
|
|
+ num_files=row["num_files"],
|
|
|
|
+ total_size_in_bytes=row["total_size_in_bytes"],
|
|
|
|
+ document_ids=(
|
|
|
|
+ list(row["document_ids"])
|
|
|
|
+ if row["document_ids"]
|
|
|
|
+ else []
|
|
|
|
+ ),
|
|
|
|
+ )
|
|
)
|
|
)
|
|
- for row in results
|
|
|
|
- ]
|
|
|
|
-
|
|
|
|
- if not users:
|
|
|
|
- raise R2RException(status_code=404, message="No users found")
|
|
|
|
|
|
|
|
total_entries = results[0]["total_entries"]
|
|
total_entries = results[0]["total_entries"]
|
|
-
|
|
|
|
- return {"results": users, "total_entries": total_entries}
|
|
|
|
|
|
+ return {"results": users_list, "total_entries": total_entries}
|
|
|
|
|
|
async def _collection_exists(self, collection_id: UUID) -> bool:
|
|
async def _collection_exists(self, collection_id: UUID) -> bool:
|
|
"""Check if a collection exists."""
|
|
"""Check if a collection exists."""
|
|
@@ -693,7 +808,7 @@ class PostgresUserHandler(Handler):
|
|
hashed_key: str,
|
|
hashed_key: str,
|
|
name: Optional[str] = None,
|
|
name: Optional[str] = None,
|
|
) -> UUID:
|
|
) -> UUID:
|
|
- """Store a new API key for a user"""
|
|
|
|
|
|
+ """Store a new API key for a user."""
|
|
query = f"""
|
|
query = f"""
|
|
INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
|
|
INSERT INTO {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
|
|
(user_id, public_key, hashed_key, name)
|
|
(user_id, public_key, hashed_key, name)
|
|
@@ -710,7 +825,10 @@ class PostgresUserHandler(Handler):
|
|
return result["id"]
|
|
return result["id"]
|
|
|
|
|
|
async def get_api_key_record(self, key_id: str) -> Optional[dict]:
|
|
async def get_api_key_record(self, key_id: str) -> Optional[dict]:
|
|
- """Get API key record and update updated_at"""
|
|
|
|
|
|
+ """
|
|
|
|
+ Get API key record by 'public_key' and update 'updated_at' to now.
|
|
|
|
+ Returns { "user_id", "hashed_key" } or None if not found.
|
|
|
|
+ """
|
|
query = f"""
|
|
query = f"""
|
|
UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
|
|
UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
|
|
SET updated_at = NOW()
|
|
SET updated_at = NOW()
|
|
@@ -726,7 +844,7 @@ class PostgresUserHandler(Handler):
|
|
}
|
|
}
|
|
|
|
|
|
async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
|
|
async def get_user_api_keys(self, user_id: UUID) -> list[dict]:
|
|
- """Get all API keys for a user"""
|
|
|
|
|
|
+ """Get all API keys for a user."""
|
|
query = f"""
|
|
query = f"""
|
|
SELECT id, public_key, name, created_at, updated_at
|
|
SELECT id, public_key, name, created_at, updated_at
|
|
FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
|
|
FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
|
|
@@ -745,7 +863,7 @@ class PostgresUserHandler(Handler):
|
|
]
|
|
]
|
|
|
|
|
|
async def delete_api_key(self, user_id: UUID, key_id: UUID) -> dict:
|
|
async def delete_api_key(self, user_id: UUID, key_id: UUID) -> dict:
|
|
- """Delete a specific API key"""
|
|
|
|
|
|
+ """Delete a specific API key."""
|
|
query = f"""
|
|
query = f"""
|
|
DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
|
|
DELETE FROM {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
|
|
WHERE id = $1 AND user_id = $2
|
|
WHERE id = $1 AND user_id = $2
|
|
@@ -766,7 +884,7 @@ class PostgresUserHandler(Handler):
|
|
async def update_api_key_name(
|
|
async def update_api_key_name(
|
|
self, user_id: UUID, key_id: UUID, name: str
|
|
self, user_id: UUID, key_id: UUID, name: str
|
|
) -> bool:
|
|
) -> bool:
|
|
- """Update the name of an API key"""
|
|
|
|
|
|
+ """Update the name of an existing API key."""
|
|
query = f"""
|
|
query = f"""
|
|
UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
|
|
UPDATE {self._get_table_name(PostgresUserHandler.API_KEYS_TABLE_NAME)}
|
|
SET name = $1, updated_at = NOW()
|
|
SET name = $1, updated_at = NOW()
|