import logging
import os
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import IO, Any, BinaryIO, Optional, Tuple
from uuid import UUID

import toml

from core.base import (
    CollectionResponse,
    ConversationResponse,
    DocumentResponse,
    GenerationConfig,
    GraphConstructionStatus,
    Message,
    MessageResponse,
    Prompt,
    R2RException,
    StoreType,
    User,
)

from ..abstractions import R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()
class ManagementService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        super().__init__(
            config,
            providers,
        )

    async def app_settings(self):
        """Return the active configuration, all registered prompts, and the
        project name."""
        prompts = (
            await self.providers.database.prompts_handler.get_all_prompts()
        )
        config_toml = self.config.to_toml()
        config_dict = toml.loads(config_toml)
        # Fall back to an empty string when the env var is unset.
        project_name = os.getenv("R2R_PROJECT_NAME", "")
        return {
            "config": config_dict,
            "prompts": prompts,
            "r2r_project_name": project_name,
        }
    async def users_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
    ):
        return await self.providers.database.users_handler.get_users_overview(
            offset=offset,
            limit=limit,
            user_ids=user_ids,
        )
    async def delete_documents_and_chunks_by_filter(
        self,
        filters: dict[str, Any],
    ):
        """Delete chunks matching the given filters. If any affected documents
        are left with no remaining chunks, delete those documents as well.

        Args:
            filters (dict[str, Any]): Filters specifying which chunks to delete.

        Returns:
            dict: A summary of what was deleted.
        """

        def transform_chunk_id_to_id(
            filters: dict[str, Any],
        ) -> dict[str, Any]:
            """Recursively rewrite `chunk_id` keys to `id`, so callers may
            filter on either name."""
            if isinstance(filters, dict):
                transformed = {}
                for key, value in filters.items():
                    if key == "chunk_id":
                        transformed["id"] = value
                    elif key in ["$and", "$or"]:
                        transformed[key] = [
                            transform_chunk_id_to_id(item) for item in value
                        ]
                    else:
                        transformed[key] = transform_chunk_id_to_id(value)
                return transformed
            return filters

        # Normalize the filters before querying or deleting.
        transformed_filters = transform_chunk_id_to_id(filters)

        # Find the chunks that match the filters before deleting them.
        interim_results = (
            await self.providers.database.chunks_handler.list_chunks(
                filters=transformed_filters,
                offset=0,
                limit=1_000,
                include_vectors=False,
            )
        )
        results = interim_results["results"]
        while interim_results["total_entries"] == 1_000:
            # We hit the page limit, so paginate until all results are in.
            interim_results = (
                await self.providers.database.chunks_handler.list_chunks(
                    filters=transformed_filters,
                    offset=interim_results["offset"] + 1_000,
                    limit=1_000,
                    include_vectors=False,
                )
            )
            results.extend(interim_results["results"])

        # Collect document ids (and an owner id, if present) from the filters.
        document_ids = set()
        owner_id = None
        if "$and" in filters:
            for condition in filters["$and"]:
                if "owner_id" in condition and "$eq" in condition["owner_id"]:
                    owner_id = condition["owner_id"]["$eq"]
                elif (
                    "document_id" in condition
                    and "$eq" in condition["document_id"]
                ):
                    document_ids.add(UUID(condition["document_id"]["$eq"]))
        elif "document_id" in filters:
            doc_id = filters["document_id"]
            if isinstance(doc_id, str):
                document_ids.add(UUID(doc_id))
            elif isinstance(doc_id, UUID):
                document_ids.add(doc_id)
            elif isinstance(doc_id, dict) and "$eq" in doc_id:
                value = doc_id["$eq"]
                document_ids.add(
                    UUID(value) if isinstance(value, str) else value
                )

        # Delete the matching chunks from the database.
        delete_results = await self.providers.database.chunks_handler.delete(
            transformed_filters
        )

        # Extract the document ids that were affected by the deletion.
        affected_doc_ids = {
            UUID(info["document_id"])
            for info in delete_results.values()
            if info.get("document_id")
        }
        document_ids.update(affected_doc_ids)

        # Verify each affected document exists and the caller may touch it.
        docs_to_delete = []
        for doc_id in document_ids:
            documents_overview_response = await self.providers.database.documents_handler.get_documents_overview(
                offset=0, limit=1, filter_document_ids=[doc_id]
            )
            if not documents_overview_response["results"]:
                raise R2RException(
                    status_code=404, message="Document not found"
                )
            document = documents_overview_response["results"][0]
            for collection_id in document.collection_ids:
                await self.providers.database.collections_handler.decrement_collection_document_count(
                    collection_id=collection_id
                )
            if owner_id and str(document.owner_id) != owner_id:
                raise R2RException(
                    status_code=404,
                    message="Document not found or insufficient permissions",
                )
            docs_to_delete.append(doc_id)

        # Delete the documents that no longer have associated chunks.
        for doc_id in docs_to_delete:
            # Delete related entities & relationships if needed:
            await self.providers.database.graphs_handler.entities.delete(
                parent_id=doc_id,
                store_type=StoreType.DOCUMENTS,
            )
            await self.providers.database.graphs_handler.relationships.delete(
                parent_id=doc_id,
                store_type=StoreType.DOCUMENTS,
            )
            # Finally, delete the document from documents_overview:
            await self.providers.database.documents_handler.delete(
                document_id=doc_id
            )

        return {
            "success": True,
            "deleted_chunks_count": len(delete_results),
            "deleted_documents_count": len(docs_to_delete),
            "deleted_document_ids": [str(d) for d in docs_to_delete],
        }
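    # Illustrative usage (a sketch, not executed here; `service` is assumed
    # to be an initialized ManagementService). Filtering on `chunk_id` works
    # because filters are normalized to `id` first, so
    #
    #     await service.delete_documents_and_chunks_by_filter(
    #         filters={"chunk_id": {"$eq": "some-chunk-uuid"}}
    #     )
    #
    # behaves the same as passing {"id": {"$eq": "some-chunk-uuid"}} and
    # returns a summary such as
    #
    #     {"success": True, "deleted_chunks_count": 1,
    #      "deleted_documents_count": 0, "deleted_document_ids": []}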
    async def download_file(
        self, document_id: UUID
    ) -> Optional[Tuple[str, BinaryIO, int]]:
        if result := await self.providers.file.retrieve_file(document_id):
            return result
        return None

    async def export_files(
        self,
        document_ids: Optional[list[UUID]] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> tuple[str, BinaryIO, int]:
        return await self.providers.file.retrieve_files_as_zip(
            document_ids=document_ids,
            start_date=start_date,
            end_date=end_date,
        )
    async def export_collections(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.collections_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_documents(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.documents_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_document_entities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.entities.export_to_csv(
            parent_id=id,
            store_type=StoreType.DOCUMENTS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_document_relationships(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.relationships.export_to_csv(
            parent_id=id,
            store_type=StoreType.DOCUMENTS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_conversations(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.conversations_handler.export_conversations_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_graph_entities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.entities.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_graph_relationships(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.relationships.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_graph_communities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.communities.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_messages(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.conversations_handler.export_messages_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_users(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.users_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
    async def documents_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
        collection_ids: Optional[list[UUID]] = None,
        document_ids: Optional[list[UUID]] = None,
        owner_only: bool = False,
    ):
        return await self.providers.database.documents_handler.get_documents_overview(
            offset=offset,
            limit=limit,
            filter_document_ids=document_ids,
            filter_user_ids=user_ids,
            filter_collection_ids=collection_ids,
            owner_only=owner_only,
        )

    async def update_document_metadata(
        self,
        document_id: UUID,
        metadata: list[dict],
        overwrite: bool = False,
    ):
        return await self.providers.database.documents_handler.update_document_metadata(
            document_id=document_id,
            metadata=metadata,
            overwrite=overwrite,
        )

    async def list_document_chunks(
        self,
        document_id: UUID,
        offset: int,
        limit: int,
        include_vectors: bool = False,
    ):
        return (
            await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=offset,
                limit=limit,
                include_vectors=include_vectors,
            )
        )
    async def assign_document_to_collection(
        self, document_id: UUID, collection_id: UUID
    ):
        await self.providers.database.chunks_handler.assign_document_chunks_to_collection(
            document_id, collection_id
        )
        await self.providers.database.collections_handler.assign_document_to_collection_relational(
            document_id, collection_id
        )
        # Mark the collection's graph state as outdated so it gets rebuilt.
        await self.providers.database.documents_handler.set_workflow_status(
            id=collection_id,
            status_type="graph_sync_status",
            status=GraphConstructionStatus.OUTDATED,
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=collection_id,
            status_type="graph_cluster_status",
            status=GraphConstructionStatus.OUTDATED,
        )
        return {"message": "Document assigned to collection successfully"}

    async def remove_document_from_collection(
        self, document_id: UUID, collection_id: UUID
    ):
        await self.providers.database.collections_handler.remove_document_from_collection_relational(
            document_id, collection_id
        )
        await self.providers.database.chunks_handler.remove_document_from_collection_vector(
            document_id, collection_id
        )
        # await self.providers.database.graphs_handler.delete_node_via_document_id(
        #     document_id, collection_id
        # )
        return None
    def _process_relationships(
        self, relationships: list[Tuple[str, str, str]]
    ) -> Tuple[dict[str, list[str]], dict[str, dict[str, list[str]]]]:
        """Build an adjacency list and a subject -> relation -> objects map
        from (subject, relation, object) triples."""
        graph = defaultdict(list)
        grouped: dict[str, dict[str, list[str]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for subject, relation, obj in relationships:
            graph[subject].append(obj)
            grouped[subject][relation].append(obj)
            # Ensure objects appear as nodes even when they have no out-edges.
            if obj not in graph:
                graph[obj] = []
        return dict(graph), dict(grouped)
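    # Illustrative example (a sketch, not executed here; `service` is an
    # initialized ManagementService):
    #
    #     graph, grouped = service._process_relationships(
    #         [("Alice", "manages", "Bob"), ("Alice", "mentors", "Carol")]
    #     )
    #     # graph   == {"Alice": ["Bob", "Carol"], "Bob": [], "Carol": []}
    #     # grouped == {"Alice": {"manages": ["Bob"], "mentors": ["Carol"]}}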
    def generate_output(
        self,
        grouped_relationships: dict[str, dict[str, list[str]]],
        graph: dict[str, list[str]],
        descriptions_dict: dict[str, str],
        print_descriptions: bool = True,
    ) -> list[str]:
        output = []

        # Render the grouped relationships.
        for subject, relations in grouped_relationships.items():
            output.append(f"\n== {subject} ==")
            if print_descriptions and subject in descriptions_dict:
                output.append(f"\tDescription: {descriptions_dict[subject]}")
            for relation, objects in relations.items():
                output.append(f"  {relation}:")
                for obj in objects:
                    output.append(f"    - {obj}")
                    if print_descriptions and obj in descriptions_dict:
                        output.append(
                            f"      Description: {descriptions_dict[obj]}"
                        )

        # Append basic graph statistics.
        output.extend(
            [
                "\n== Graph Statistics ==",
                f"Number of nodes: {len(graph)}",
                f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}",
                f"Number of connected components: {self._count_connected_components(graph)}",
            ]
        )

        # Append the most central nodes.
        central_nodes = self._get_central_nodes(graph)
        output.extend(
            [
                "\n== Most Central Nodes ==",
                *(
                    f"  {node}: {centrality:.4f}"
                    for node, centrality in central_nodes
                ),
            ]
        )
        return output
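    # Example of the rendered text (illustrative, values invented):
    #
    #     == Alice ==
    #       manages:
    #         - Bob
    #
    #     == Graph Statistics ==
    #     Number of nodes: 3
    #     Number of edges: 2
    #     Number of connected components: 1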
    def _count_connected_components(self, graph: dict[str, list[str]]) -> int:
        """Count components by depth-first search over outgoing edges,
        starting a new component at each unvisited node."""
        visited = set()
        components = 0

        def dfs(node):
            visited.add(node)
            for neighbor in graph[node]:
                if neighbor not in visited:
                    dfs(neighbor)

        for node in graph:
            if node not in visited:
                dfs(node)
                components += 1
        return components
    def _get_central_nodes(
        self, graph: dict[str, list[str]]
    ) -> list[Tuple[str, float]]:
        """Return the five nodes with the highest degree centrality, i.e.
        out-degree normalized by (number of nodes - 1)."""
        total_nodes = len(graph)
        if total_nodes <= 1:
            # Guard against division by zero for empty or single-node graphs.
            return []
        degree = {node: len(neighbors) for node, neighbors in graph.items()}
        centrality = {
            node: deg / (total_nodes - 1) for node, deg in degree.items()
        }
        return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]
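    # Degree centrality here is C(v) = deg(v) / (n - 1), so for a 3-node
    # graph where "Alice" points to "Bob" and "Carol" (illustrative sketch;
    # `service` is an initialized ManagementService):
    #
    #     service._get_central_nodes(
    #         {"Alice": ["Bob", "Carol"], "Bob": [], "Carol": []}
    #     )
    #     # -> [("Alice", 1.0), ("Bob", 0.0), ("Carol", 0.0)]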
    async def create_collection(
        self,
        owner_id: UUID,
        name: Optional[str] = None,
        description: str | None = None,
    ) -> CollectionResponse:
        result = await self.providers.database.collections_handler.create_collection(
            owner_id=owner_id,
            name=name,
            description=description,
        )
        # Each collection gets a companion graph keyed by the collection id.
        await self.providers.database.graphs_handler.create(
            collection_id=result.id,
            name=name,
            description=description,
        )
        return result

    async def update_collection(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        generate_description: bool = False,
    ) -> CollectionResponse:
        if generate_description:
            description = await self.summarize_collection(
                id=collection_id, offset=0, limit=100
            )
        return await self.providers.database.collections_handler.update_collection(
            collection_id=collection_id,
            name=name,
            description=description,
        )

    async def delete_collection(self, collection_id: UUID) -> bool:
        await self.providers.database.collections_handler.delete_collection_relational(
            collection_id
        )
        await self.providers.database.chunks_handler.delete_collection_vector(
            collection_id
        )
        try:
            await self.providers.database.graphs_handler.delete(
                collection_id=collection_id,
            )
        except Exception as e:
            logger.warning(
                f"Error deleting graph for collection {collection_id}: {e}"
            )
        return True
    async def collections_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
        document_ids: Optional[list[UUID]] = None,
        collection_ids: Optional[list[UUID]] = None,
        owner_only: bool = False,
    ) -> dict[str, list[CollectionResponse] | int]:
        return await self.providers.database.collections_handler.get_collections_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            filter_document_ids=document_ids,
            filter_collection_ids=collection_ids,
            owner_only=owner_only,
        )

    async def add_user_to_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        return (
            await self.providers.database.users_handler.add_user_to_collection(
                user_id, collection_id
            )
        )

    async def remove_user_from_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        return await self.providers.database.users_handler.remove_user_from_collection(
            user_id, collection_id
        )

    async def get_users_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[User] | int]:
        return await self.providers.database.users_handler.get_users_in_collection(
            collection_id, offset=offset, limit=limit
        )

    async def documents_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[DocumentResponse] | int]:
        return await self.providers.database.collections_handler.documents_in_collection(
            collection_id, offset=offset, limit=limit
        )
    async def summarize_collection(
        self, id: UUID, offset: int, limit: int
    ) -> str:
        documents_in_collection_response = await self.documents_in_collection(
            collection_id=id,
            offset=offset,
            limit=limit,
        )

        document_summaries = [
            document.summary
            for document in documents_in_collection_response["results"]  # type: ignore
        ]
        logger.info(
            f"Summarizing collection {id} with {len(document_summaries)} of {documents_in_collection_response['total_entries']} documents."
        )

        formatted_summaries = "\n\n".join(document_summaries)  # type: ignore

        messages = await self.providers.database.prompts_handler.get_message_payload(
            system_prompt_name=self.config.database.collection_summary_system_prompt,
            task_prompt_name=self.config.database.collection_summary_prompt,
            task_inputs={"document_summaries": formatted_summaries},
        )

        response = await self.providers.llm.aget_completion(
            messages=messages,
            generation_config=GenerationConfig(
                model=self.config.ingestion.document_summary_model
                or self.config.app.fast_llm
            ),
        )

        if collection_summary := response.choices[0].message.content:
            return collection_summary
        else:
            raise ValueError("Expected a generated response.")
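    # Illustrative flow (a sketch, not executed here): the collection summary
    # joins the per-document summaries and runs them through the configured
    # prompts, e.g.
    #
    #     summary = await service.summarize_collection(
    #         id=collection_id, offset=0, limit=100
    #     )
    #
    # which is what update_collection(..., generate_description=True) invokes
    # before persisting the description.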
    async def add_prompt(
        self, name: str, template: str, input_types: dict[str, str]
    ) -> dict:
        try:
            await self.providers.database.prompts_handler.add_prompt(
                name, template, input_types
            )
            return f"Prompt '{name}' added successfully."  # type: ignore
        except ValueError as e:
            raise R2RException(status_code=400, message=str(e)) from e

    async def get_cached_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        try:
            return {
                "message": (
                    await self.providers.database.prompts_handler.get_cached_prompt(
                        prompt_name=prompt_name,
                        inputs=inputs,
                        prompt_override=prompt_override,
                    )
                )
            }
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e

    async def get_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        try:
            return await self.providers.database.prompts_handler.get_prompt(  # type: ignore
                name=prompt_name,
                inputs=inputs,
                prompt_override=prompt_override,
            )
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e

    async def get_all_prompts(self) -> dict[str, Prompt]:
        return await self.providers.database.prompts_handler.get_all_prompts()

    async def update_prompt(
        self,
        name: str,
        template: Optional[str] = None,
        input_types: Optional[dict[str, str]] = None,
    ) -> dict:
        try:
            await self.providers.database.prompts_handler.update_prompt(
                name, template, input_types
            )
            return f"Prompt '{name}' updated successfully."  # type: ignore
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e

    async def delete_prompt(self, name: str) -> dict:
        try:
            await self.providers.database.prompts_handler.delete_prompt(name)
            return {"message": f"Prompt '{name}' deleted successfully."}
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e
    async def get_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> list[MessageResponse]:
        return await self.providers.database.conversations_handler.get_conversation(
            conversation_id=conversation_id,
            filter_user_ids=user_ids,
        )

    async def create_conversation(
        self,
        user_id: Optional[UUID] = None,
        name: Optional[str] = None,
    ) -> ConversationResponse:
        return await self.providers.database.conversations_handler.create_conversation(
            user_id=user_id,
            name=name,
        )

    async def conversations_overview(
        self,
        offset: int,
        limit: int,
        conversation_ids: Optional[list[UUID]] = None,
        user_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[dict] | int]:
        return await self.providers.database.conversations_handler.get_conversations_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            conversation_ids=conversation_ids,
        )

    async def add_message(
        self,
        conversation_id: UUID,
        content: Message,
        parent_id: Optional[UUID] = None,
        metadata: Optional[dict] = None,
    ) -> MessageResponse:
        return await self.providers.database.conversations_handler.add_message(
            conversation_id=conversation_id,
            content=content,
            parent_id=parent_id,
            metadata=metadata,
        )

    async def edit_message(
        self,
        message_id: UUID,
        new_content: Optional[str] = None,
        additional_metadata: Optional[dict] = None,
    ) -> dict[str, Any]:
        return (
            await self.providers.database.conversations_handler.edit_message(
                message_id=message_id,
                new_content=new_content,
                additional_metadata=additional_metadata or {},
            )
        )

    async def update_conversation(
        self, conversation_id: UUID, name: str
    ) -> ConversationResponse:
        return await self.providers.database.conversations_handler.update_conversation(
            conversation_id=conversation_id, name=name
        )

    async def delete_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> None:
        await (
            self.providers.database.conversations_handler.delete_conversation(
                conversation_id=conversation_id,
                filter_user_ids=user_ids,
            )
        )
    async def get_user_max_documents(self, user_id: UUID) -> int | None:
        # Fetch the user to see if they have any overrides stored.
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if user.limits_overrides and "max_documents" in user.limits_overrides:
            return user.limits_overrides["max_documents"]
        return self.config.app.default_max_documents_per_user

    async def get_user_max_chunks(self, user_id: UUID) -> int | None:
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if user.limits_overrides and "max_chunks" in user.limits_overrides:
            return user.limits_overrides["max_chunks"]
        return self.config.app.default_max_chunks_per_user

    async def get_user_max_collections(self, user_id: UUID) -> int | None:
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if (
            user.limits_overrides
            and "max_collections" in user.limits_overrides
        ):
            return user.limits_overrides["max_collections"]
        return self.config.app.default_max_collections_per_user
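    # The overrides consulted above live on the user record; an illustrative
    # shape (keys invented for the example):
    #
    #     user.limits_overrides == {
    #         "max_documents": 1_000,
    #         "max_chunks": 100_000,
    #         "max_collections": 20,
    #     }
    #
    # Any absent key falls back to the corresponding
    # default_max_*_per_user value from the app config.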
    async def get_max_upload_size_by_type(
        self, user_id: UUID, file_type_or_ext: str
    ) -> int:
        """Return the maximum allowed upload size (in bytes) for the given
        user's file type/extension. Respects user-level overrides if present,
        falling back to the system config, e.g.:

        ```json
        {
            "limits_overrides": {
                "max_file_size": 20_000_000,
                "max_file_size_by_type": {
                    "pdf": 50_000_000,
                    "docx": 30_000_000
                },
                ...
            }
        }
        ```
        """
        # 1. Normalize the extension (lowercase, strip a leading dot).
        ext = file_type_or_ext.lower().lstrip(".")

        # 2. Fetch the user from the DB to see if we have any overrides.
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        user_overrides = user.limits_overrides or {}

        # 3. Check for a user-level override for "max_file_size_by_type".
        user_file_type_limits = user_overrides.get("max_file_size_by_type", {})
        if ext in user_file_type_limits:
            return user_file_type_limits[ext]

        # 4. If not, check for a user-level fallback "max_file_size".
        if "max_file_size" in user_overrides:
            return user_overrides["max_file_size"]

        # 5. If no user-level override exists, consult the system config.
        system_type_limits = self.config.app.max_upload_size_by_type
        if ext in system_type_limits:
            return system_type_limits[ext]

        # 6. Otherwise, return the global default.
        return self.config.app.default_max_upload_size
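    # Resolution order, illustrated (a sketch; `service` is an initialized
    # ManagementService and the override values are invented). For a user
    # whose limits_overrides are
    # {"max_file_size": 20_000_000,
    #  "max_file_size_by_type": {"pdf": 50_000_000}}:
    #
    #     await service.get_max_upload_size_by_type(user_id, ".pdf")  # 50_000_000
    #     await service.get_max_upload_size_by_type(user_id, "docx")  # 20_000_000
    #
    # A user with no overrides falls through to
    # config.app.max_upload_size_by_type and then
    # config.app.default_max_upload_size.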
    async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:
        """Return a dictionary containing:

        - the system default limits (from self.config.database.limits),
        - the user's overrides (from user.limits_overrides),
        - the final "effective" set of limits after merging, and
        - the usage for each relevant limit (global, monthly, per-route).
        """
        # 1) Fetch the user.
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        user_overrides = user.limits_overrides or {}

        # 2) Grab the system defaults.
        system_defaults = {
            "global_per_min": self.config.database.limits.global_per_min,
            "route_per_min": self.config.database.limits.route_per_min,
            "monthly_limit": self.config.database.limits.monthly_limit,
            # Add additional fields if your LimitSettings has them.
        }

        # 3) Build the overall (global) effective limits, ignoring any
        #    specific route.
        overall_effective = (
            self.providers.database.limits_handler.determine_effective_limits(
                user, route=""
            )
        )

        # 4) Build the usage data: top-level usage for global_per_min and
        #    monthly_limit, then route-by-route usage in a loop.
        usage: dict[str, Any] = {}
        now = datetime.now(timezone.utc)
        one_min_ago = now - timedelta(minutes=1)

        # (a) Global usage (per minute).
        global_per_min_used = (
            await self.providers.database.limits_handler._count_requests(
                user_id, route=None, since=one_min_ago
            )
        )
        # (a2) Global usage (monthly), i.e. usage across ALL routes.
        global_monthly_used = await self.providers.database.limits_handler._count_monthly_requests(
            user_id, route=None
        )
        usage["global_per_min"] = {
            "used": global_per_min_used,
            "limit": overall_effective.global_per_min,
            "remaining": (
                overall_effective.global_per_min - global_per_min_used
                if overall_effective.global_per_min is not None
                else None
            ),
        }
        usage["monthly_limit"] = {
            "used": global_monthly_used,
            "limit": overall_effective.monthly_limit,
            "remaining": (
                overall_effective.monthly_limit - global_monthly_used
                if overall_effective.monthly_limit is not None
                else None
            ),
        }

        # (b) Route-level usage, gathering routes from both the system config
        #     and the user's overrides.
        system_route_limits = (
            self.config.database.route_limits
        )  # dict[str, LimitSettings]
        user_route_overrides = user_overrides.get("route_overrides", {})
        route_keys = set(system_route_limits.keys()) | set(
            user_route_overrides.keys()
        )

        usage["routes"] = {}
        for route in route_keys:
            # 1) Get the final merged limits for this specific route.
            route_effective = self.providers.database.limits_handler.determine_effective_limits(
                user, route
            )
            # 2) Count requests on this route over the last minute.
            route_per_min_used = (
                await self.providers.database.limits_handler._count_requests(
                    user_id, route, one_min_ago
                )
            )
            # 3) Count route-specific monthly usage.
            route_monthly_used = await self.providers.database.limits_handler._count_monthly_requests(
                user_id, route
            )
            usage["routes"][route] = {
                "route_per_min": {
                    "used": route_per_min_used,
                    "limit": route_effective.route_per_min,
                    "remaining": (
                        route_effective.route_per_min - route_per_min_used
                        if route_effective.route_per_min is not None
                        else None
                    ),
                },
                "monthly_limit": {
                    "used": route_monthly_used,
                    "limit": route_effective.monthly_limit,
                    "remaining": (
                        route_effective.monthly_limit - route_monthly_used
                        if route_effective.monthly_limit is not None
                        else None
                    ),
                },
            }

        # Storage limits: documents, chunks, and collections.
        max_documents = await self.get_user_max_documents(user_id)
        used_documents = (
            await self.providers.database.documents_handler.get_documents_overview(
                limit=1, offset=0, filter_user_ids=[user_id]
            )
        )["total_entries"]
        max_chunks = await self.get_user_max_chunks(user_id)
        used_chunks = (
            await self.providers.database.chunks_handler.list_chunks(
                limit=1, offset=0, filters={"owner_id": user_id}
            )
        )["total_entries"]
        max_collections = await self.get_user_max_collections(user_id)
        used_collections: int = (  # type: ignore
            await self.providers.database.collections_handler.get_collections_overview(
                limit=1, offset=0, filter_user_ids=[user_id]
            )
        )["total_entries"]

        storage_limits = {
            "chunks": {
                "limit": max_chunks,
                "used": used_chunks,
                "remaining": (
                    max_chunks - used_chunks
                    if max_chunks is not None
                    else None
                ),
            },
            "documents": {
                "limit": max_documents,
                "used": used_documents,
                "remaining": (
                    max_documents - used_documents
                    if max_documents is not None
                    else None
                ),
            },
            "collections": {
                "limit": max_collections,
                "used": used_collections,
                "remaining": (
                    max_collections - used_collections
                    if max_collections is not None
                    else None
                ),
            },
        }

        # 5) Return a structured response.
        return {
            "storage_limits": storage_limits,
            "system_defaults": system_defaults,
            "user_overrides": user_overrides,
            "effective_limits": {
                "global_per_min": overall_effective.global_per_min,
                "route_per_min": overall_effective.route_per_min,
                "monthly_limit": overall_effective.monthly_limit,
            },
            "usage": usage,
        }
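    # Example of the returned payload shape (illustrative, values invented):
    #
    #     {
    #         "storage_limits": {"documents": {"limit": 100, "used": 4,
    #                                          "remaining": 96}, ...},
    #         "system_defaults": {"global_per_min": 60, ...},
    #         "user_overrides": {},
    #         "effective_limits": {"global_per_min": 60,
    #                              "route_per_min": 20,
    #                              "monthly_limit": 10_000},
    #         "usage": {"global_per_min": {"used": 3, "limit": 60,
    #                                      "remaining": 57}, ...},
    #     }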