management_service.py

import logging
import os
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from typing import IO, Any, BinaryIO, Optional, Tuple
from uuid import UUID

import toml

from core.base import (
    CollectionResponse,
    ConversationResponse,
    DocumentResponse,
    GenerationConfig,
    GraphConstructionStatus,
    Message,
    MessageResponse,
    Prompt,
    R2RException,
    StoreType,
    User,
)

from ..abstractions import R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()


class ManagementService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        super().__init__(
            config,
            providers,
        )

    async def app_settings(self):
        prompts = (
            await self.providers.database.prompts_handler.get_all_prompts()
        )
        config_toml = self.config.to_toml()
        config_dict = toml.loads(config_toml)
        try:
            project_name = os.environ["R2R_PROJECT_NAME"]
        except KeyError:
            project_name = ""
        return {
            "config": config_dict,
            "prompts": prompts,
            "r2r_project_name": project_name,
        }

    async def users_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
    ):
        return await self.providers.database.users_handler.get_users_overview(
            offset=offset,
            limit=limit,
            user_ids=user_ids,
        )

    async def delete_documents_and_chunks_by_filter(
        self,
        filters: dict[str, Any],
    ):
        """Delete chunks matching the given filters. If any affected documents
        are left empty (i.e., have no remaining chunks), delete those
        documents as well.

        Args:
            filters (dict[str, Any]): Filters specifying which chunks to delete.

        Returns:
            dict: A summary of what was deleted.
        """

        def transform_chunk_id_to_id(
            filters: dict[str, Any],
        ) -> dict[str, Any]:
            """Recursively rewrite `chunk_id` keys to `id`, for callers whose
            filters use `chunk_id` instead of `id`.
            """
            if isinstance(filters, dict):
                transformed = {}
                for key, value in filters.items():
                    if key == "chunk_id":
                        transformed["id"] = value
                    elif key in ["$and", "$or"]:
                        transformed[key] = [
                            transform_chunk_id_to_id(item) for item in value
                        ]
                    else:
                        transformed[key] = transform_chunk_id_to_id(value)
                return transformed
            return filters
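
        # Illustrative example (hypothetical UUIDs): a filter such as
        #   {"$and": [{"chunk_id": {"$eq": "123e..."}},
        #             {"owner_id": {"$eq": "456e..."}}]}
        # comes back as
        #   {"$and": [{"id": {"$eq": "123e..."}},
        #             {"owner_id": {"$eq": "456e..."}}]};
        # leaf values (under "$eq") are returned unchanged.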

        # Transform filters if needed.
        transformed_filters = transform_chunk_id_to_id(filters)

        # Find chunks that match the filters before deleting
        interim_results = (
            await self.providers.database.chunks_handler.list_chunks(
                filters=transformed_filters,
                offset=0,
                limit=1_000,
                include_vectors=False,
            )
        )
        results = interim_results["results"]
        while interim_results["total_entries"] == 1_000:
            # If we hit the limit, we need to paginate to get all results
            interim_results = (
                await self.providers.database.chunks_handler.list_chunks(
                    filters=transformed_filters,
                    offset=interim_results["offset"] + 1_000,
                    limit=1_000,
                    include_vectors=False,
                )
            )
            results.extend(interim_results["results"])

        document_ids = set()
        owner_id = None

        if "$and" in filters:
            for condition in filters["$and"]:
                if "owner_id" in condition and "$eq" in condition["owner_id"]:
                    owner_id = condition["owner_id"]["$eq"]
                elif (
                    "document_id" in condition
                    and "$eq" in condition["document_id"]
                ):
                    document_ids.add(UUID(condition["document_id"]["$eq"]))
        elif "document_id" in filters:
            doc_id = filters["document_id"]
            if isinstance(doc_id, str):
                document_ids.add(UUID(doc_id))
            elif isinstance(doc_id, UUID):
                document_ids.add(doc_id)
            elif isinstance(doc_id, dict) and "$eq" in doc_id:
                value = doc_id["$eq"]
                document_ids.add(
                    UUID(value) if isinstance(value, str) else value
                )
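
        # The two filter shapes handled above (values hypothetical):
        #   {"$and": [{"owner_id": {"$eq": "<uuid>"}},
        #             {"document_id": {"$eq": "<uuid>"}}]}
        # and the bare form
        #   {"document_id": "<uuid>"}   # str, UUID, or {"$eq": ...} accepted
        # Other shapes simply leave `document_ids` empty here; affected
        # documents are still picked up from the delete results below.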

        # Delete matching chunks from the database
        delete_results = await self.providers.database.chunks_handler.delete(
            transformed_filters
        )

        # Extract the document_ids that were affected.
        affected_doc_ids = {
            UUID(info["document_id"])
            for info in delete_results.values()
            if info.get("document_id")
        }
        document_ids.update(affected_doc_ids)

        # Check if the document still has any chunks left
        docs_to_delete = []
        for doc_id in document_ids:
            documents_overview_response = await self.providers.database.documents_handler.get_documents_overview(
                offset=0, limit=1, filter_document_ids=[doc_id]
            )
            if not documents_overview_response["results"]:
                raise R2RException(
                    status_code=404, message="Document not found"
                )
            document = documents_overview_response["results"][0]

            for collection_id in document.collection_ids:
                await self.providers.database.collections_handler.decrement_collection_document_count(
                    collection_id=collection_id
                )

            if owner_id and str(document.owner_id) != owner_id:
                raise R2RException(
                    status_code=404,
                    message="Document not found or insufficient permissions",
                )
            docs_to_delete.append(doc_id)

        # Delete documents that no longer have associated chunks
        for doc_id in docs_to_delete:
            # Delete related entities & relationships if needed:
            await self.providers.database.graphs_handler.entities.delete(
                parent_id=doc_id,
                store_type=StoreType.DOCUMENTS,
            )
            await self.providers.database.graphs_handler.relationships.delete(
                parent_id=doc_id,
                store_type=StoreType.DOCUMENTS,
            )

            # Finally, delete the document from documents_overview:
            await self.providers.database.documents_handler.delete(
                document_id=doc_id
            )

        return {
            "success": True,
            "deleted_chunks_count": len(delete_results),
            "deleted_documents_count": len(docs_to_delete),
            "deleted_document_ids": [str(d) for d in docs_to_delete],
        }

    async def download_file(
        self, document_id: UUID
    ) -> Optional[Tuple[str, BinaryIO, int]]:
        if result := await self.providers.file.retrieve_file(document_id):
            return result
        return None

    async def export_files(
        self,
        document_ids: Optional[list[UUID]] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> tuple[str, BinaryIO, int]:
        return await self.providers.file.retrieve_files_as_zip(
            document_ids=document_ids,
            start_date=start_date,
            end_date=end_date,
        )

    async def export_collections(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.collections_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_documents(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.documents_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_document_entities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.entities.export_to_csv(
            parent_id=id,
            store_type=StoreType.DOCUMENTS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_document_relationships(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.relationships.export_to_csv(
            parent_id=id,
            store_type=StoreType.DOCUMENTS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_conversations(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.conversations_handler.export_conversations_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_graph_entities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.entities.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_graph_relationships(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.relationships.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_graph_communities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.communities.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_messages(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.conversations_handler.export_messages_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def export_users(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.users_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    async def documents_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
        collection_ids: Optional[list[UUID]] = None,
        document_ids: Optional[list[UUID]] = None,
        owner_only: bool = False,
    ):
        return await self.providers.database.documents_handler.get_documents_overview(
            offset=offset,
            limit=limit,
            filter_document_ids=document_ids,
            filter_user_ids=user_ids,
            filter_collection_ids=collection_ids,
            owner_only=owner_only,
        )

    async def update_document_metadata(
        self,
        document_id: UUID,
        metadata: list[dict],
        overwrite: bool = False,
    ):
        return await self.providers.database.documents_handler.update_document_metadata(
            document_id=document_id,
            metadata=metadata,
            overwrite=overwrite,
        )

    async def list_document_chunks(
        self,
        document_id: UUID,
        offset: int,
        limit: int,
        include_vectors: bool = False,
    ):
        return (
            await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=offset,
                limit=limit,
                include_vectors=include_vectors,
            )
        )

    async def assign_document_to_collection(
        self, document_id: UUID, collection_id: UUID
    ):
        await self.providers.database.chunks_handler.assign_document_chunks_to_collection(
            document_id, collection_id
        )
        await self.providers.database.collections_handler.assign_document_to_collection_relational(
            document_id, collection_id
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=collection_id,
            status_type="graph_sync_status",
            status=GraphConstructionStatus.OUTDATED,
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=collection_id,
            status_type="graph_cluster_status",
            status=GraphConstructionStatus.OUTDATED,
        )
        return {"message": "Document assigned to collection successfully"}

    async def remove_document_from_collection(
        self, document_id: UUID, collection_id: UUID
    ):
        await self.providers.database.collections_handler.remove_document_from_collection_relational(
            document_id, collection_id
        )
        await self.providers.database.chunks_handler.remove_document_from_collection_vector(
            document_id, collection_id
        )
        # await self.providers.database.graphs_handler.delete_node_via_document_id(
        #     document_id, collection_id
        # )
        return None

    def _process_relationships(
        self, relationships: list[Tuple[str, str, str]]
    ) -> Tuple[dict[str, list[str]], dict[str, dict[str, list[str]]]]:
        graph = defaultdict(list)
        grouped: dict[str, dict[str, list[str]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for subject, relation, obj in relationships:
            graph[subject].append(obj)
            grouped[subject][relation].append(obj)
            if obj not in graph:
                graph[obj] = []
        return dict(graph), dict(grouped)
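
    # Illustrative example (hypothetical triples): given
    #   [("marie", "studied", "radioactivity"), ("marie", "won", "nobel")]
    # the adjacency list is {"marie": ["radioactivity", "nobel"],
    # "radioactivity": [], "nobel": []} and the grouped view is
    # {"marie": {"studied": ["radioactivity"], "won": ["nobel"]}}.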

    def generate_output(
        self,
        grouped_relationships: dict[str, dict[str, list[str]]],
        graph: dict[str, list[str]],
        descriptions_dict: dict[str, str],
        print_descriptions: bool = True,
    ) -> list[str]:
        output = []

        # Print grouped relationships
        for subject, relations in grouped_relationships.items():
            output.append(f"\n== {subject} ==")
            if print_descriptions and subject in descriptions_dict:
                output.append(f"\tDescription: {descriptions_dict[subject]}")
            for relation, objects in relations.items():
                output.append(f" {relation}:")
                for obj in objects:
                    output.append(f" - {obj}")
                    if print_descriptions and obj in descriptions_dict:
                        output.append(
                            f" Description: {descriptions_dict[obj]}"
                        )

        # Print basic graph statistics
        output.extend(
            [
                "\n== Graph Statistics ==",
                f"Number of nodes: {len(graph)}",
                f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}",
                f"Number of connected components: {self._count_connected_components(graph)}",
            ]
        )

        # Find central nodes
        central_nodes = self._get_central_nodes(graph)
        output.extend(
            [
                "\n== Most Central Nodes ==",
                *(
                    f" {node}: {centrality:.4f}"
                    for node, centrality in central_nodes
                ),
            ]
        )
        return output

    def _count_connected_components(self, graph: dict[str, list[str]]) -> int:
        visited = set()
        components = 0

        def dfs(node):
            visited.add(node)
            for neighbor in graph[node]:
                if neighbor not in visited:
                    dfs(neighbor)

        for node in graph:
            if node not in visited:
                dfs(node)
                components += 1
        return components
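
    # Illustrative example (hypothetical graph): for
    #   {"a": ["b"], "b": [], "c": []}
    # the DFS reaches {a, b} from "a" and {c} from "c", so the count is 2.
    # Note that adjacency is directed as stored, so for some edge orderings
    # the count can differ from the true undirected component count unless
    # edges are symmetrized first.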

    def _get_central_nodes(
        self, graph: dict[str, list[str]]
    ) -> list[Tuple[str, float]]:
        degree = {node: len(neighbors) for node, neighbors in graph.items()}
        total_nodes = len(graph)
        if total_nodes <= 1:
            # Degree centrality divides by (n - 1); avoid a
            # ZeroDivisionError for empty or single-node graphs.
            return []
        centrality = {
            node: deg / (total_nodes - 1) for node, deg in degree.items()
        }
        return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]
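
    # Worked example (hypothetical graph): with
    #   {"a": ["b", "c"], "b": [], "c": []}
    # degrees are a=2, b=0, c=0 and n - 1 = 2, so degree centrality is
    # {"a": 1.0, "b": 0.0, "c": 0.0} and "a" ranks first among the top five.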

    async def create_collection(
        self,
        owner_id: UUID,
        name: Optional[str] = None,
        description: str | None = None,
    ) -> CollectionResponse:
        result = await self.providers.database.collections_handler.create_collection(
            owner_id=owner_id,
            name=name,
            description=description,
        )
        await self.providers.database.graphs_handler.create(
            collection_id=result.id,
            name=name,
            description=description,
        )
        return result

    async def update_collection(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        generate_description: bool = False,
    ) -> CollectionResponse:
        if generate_description:
            description = await self.summarize_collection(
                id=collection_id, offset=0, limit=100
            )
        return await self.providers.database.collections_handler.update_collection(
            collection_id=collection_id,
            name=name,
            description=description,
        )

    async def delete_collection(self, collection_id: UUID) -> bool:
        await self.providers.database.collections_handler.delete_collection_relational(
            collection_id
        )
        await self.providers.database.chunks_handler.delete_collection_vector(
            collection_id
        )
        try:
            await self.providers.database.graphs_handler.delete(
                collection_id=collection_id,
            )
        except Exception as e:
            logger.warning(
                f"Error deleting graph for collection {collection_id}: {e}"
            )
        return True

    async def collections_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
        document_ids: Optional[list[UUID]] = None,
        collection_ids: Optional[list[UUID]] = None,
        owner_only: bool = False,
    ) -> dict[str, list[CollectionResponse] | int]:
        return await self.providers.database.collections_handler.get_collections_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            filter_document_ids=document_ids,
            filter_collection_ids=collection_ids,
            owner_only=owner_only,
        )

    async def add_user_to_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        return (
            await self.providers.database.users_handler.add_user_to_collection(
                user_id, collection_id
            )
        )

    async def remove_user_from_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        return await self.providers.database.users_handler.remove_user_from_collection(
            user_id, collection_id
        )

    async def get_users_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[User] | int]:
        return await self.providers.database.users_handler.get_users_in_collection(
            collection_id, offset=offset, limit=limit
        )

    async def documents_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[DocumentResponse] | int]:
        return await self.providers.database.collections_handler.documents_in_collection(
            collection_id, offset=offset, limit=limit
        )

    async def summarize_collection(
        self, id: UUID, offset: int, limit: int
    ) -> str:
        documents_in_collection_response = await self.documents_in_collection(
            collection_id=id,
            offset=offset,
            limit=limit,
        )
        document_summaries = [
            document.summary
            for document in documents_in_collection_response["results"]  # type: ignore
        ]
        logger.info(
            f"Summarizing collection {id} with {len(document_summaries)} of {documents_in_collection_response['total_entries']} documents."
        )
        formatted_summaries = "\n\n".join(document_summaries)  # type: ignore
        messages = await self.providers.database.prompts_handler.get_message_payload(
            system_prompt_name=self.config.database.collection_summary_system_prompt,
            task_prompt_name=self.config.database.collection_summary_prompt,
            task_inputs={"document_summaries": formatted_summaries},
        )
        response = await self.providers.llm.aget_completion(
            messages=messages,
            generation_config=GenerationConfig(
                model=self.config.ingestion.document_summary_model
                or self.config.app.fast_llm
            ),
        )
        if collection_summary := response.choices[0].message.content:
            return collection_summary
        else:
            raise ValueError("Expected a generated response.")

    async def add_prompt(
        self, name: str, template: str, input_types: dict[str, str]
    ) -> dict:
        try:
            await self.providers.database.prompts_handler.add_prompt(
                name, template, input_types
            )
            return {"message": f"Prompt '{name}' added successfully."}
        except ValueError as e:
            raise R2RException(status_code=400, message=str(e)) from e

    async def get_cached_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        try:
            return {
                "message": (
                    await self.providers.database.prompts_handler.get_cached_prompt(
                        prompt_name=prompt_name,
                        inputs=inputs,
                        prompt_override=prompt_override,
                    )
                )
            }
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e

    async def get_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        try:
            return await self.providers.database.prompts_handler.get_prompt(  # type: ignore
                name=prompt_name,
                inputs=inputs,
                prompt_override=prompt_override,
            )
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e

    async def get_all_prompts(self) -> dict[str, Prompt]:
        return await self.providers.database.prompts_handler.get_all_prompts()

    async def update_prompt(
        self,
        name: str,
        template: Optional[str] = None,
        input_types: Optional[dict[str, str]] = None,
    ) -> dict:
        try:
            await self.providers.database.prompts_handler.update_prompt(
                name, template, input_types
            )
            return {"message": f"Prompt '{name}' updated successfully."}
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e

    async def delete_prompt(self, name: str) -> dict:
        try:
            await self.providers.database.prompts_handler.delete_prompt(name)
            return {"message": f"Prompt '{name}' deleted successfully."}
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e)) from e

    async def get_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> list[MessageResponse]:
        return await self.providers.database.conversations_handler.get_conversation(
            conversation_id=conversation_id,
            filter_user_ids=user_ids,
        )

    async def create_conversation(
        self,
        user_id: Optional[UUID] = None,
        name: Optional[str] = None,
    ) -> ConversationResponse:
        return await self.providers.database.conversations_handler.create_conversation(
            user_id=user_id,
            name=name,
        )

    async def conversations_overview(
        self,
        offset: int,
        limit: int,
        conversation_ids: Optional[list[UUID]] = None,
        user_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[dict] | int]:
        return await self.providers.database.conversations_handler.get_conversations_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            conversation_ids=conversation_ids,
        )

    async def add_message(
        self,
        conversation_id: UUID,
        content: Message,
        parent_id: Optional[UUID] = None,
        metadata: Optional[dict] = None,
    ) -> MessageResponse:
        return await self.providers.database.conversations_handler.add_message(
            conversation_id=conversation_id,
            content=content,
            parent_id=parent_id,
            metadata=metadata,
        )

    async def edit_message(
        self,
        message_id: UUID,
        new_content: Optional[str] = None,
        additional_metadata: Optional[dict] = None,
    ) -> dict[str, Any]:
        return (
            await self.providers.database.conversations_handler.edit_message(
                message_id=message_id,
                new_content=new_content,
                additional_metadata=additional_metadata or {},
            )
        )

    async def update_conversation(
        self, conversation_id: UUID, name: str
    ) -> ConversationResponse:
        return await self.providers.database.conversations_handler.update_conversation(
            conversation_id=conversation_id, name=name
        )

    async def delete_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> None:
        await (
            self.providers.database.conversations_handler.delete_conversation(
                conversation_id=conversation_id,
                filter_user_ids=user_ids,
            )
        )

    async def get_user_max_documents(self, user_id: UUID) -> int | None:
        # Fetch the user to see if they have any overrides stored
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if user.limits_overrides and "max_documents" in user.limits_overrides:
            return user.limits_overrides["max_documents"]
        return self.config.app.default_max_documents_per_user

    async def get_user_max_chunks(self, user_id: UUID) -> int | None:
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if user.limits_overrides and "max_chunks" in user.limits_overrides:
            return user.limits_overrides["max_chunks"]
        return self.config.app.default_max_chunks_per_user

    async def get_user_max_collections(self, user_id: UUID) -> int | None:
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        if (
            user.limits_overrides
            and "max_collections" in user.limits_overrides
        ):
            return user.limits_overrides["max_collections"]
        return self.config.app.default_max_collections_per_user

    async def get_max_upload_size_by_type(
        self, user_id: UUID, file_type_or_ext: str
    ) -> int:
        """Return the maximum allowed upload size (in bytes) for the given
        user and file type/extension. Respects user-level overrides if
        present, falling back to the system config. For example, a user's
        `limits_overrides` may look like:

        ```json
        {
            "limits_overrides": {
                "max_file_size": 20_000_000,
                "max_file_size_by_type":
                {
                    "pdf": 50_000_000,
                    "docx": 30_000_000
                },
                ...
            }
        }
        ```
        """
        # 1. Normalize extension
        ext = file_type_or_ext.lower().lstrip(".")

        # 2. Fetch user from DB to see if we have any overrides
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        user_overrides = user.limits_overrides or {}

        # 3. Check if there's a user-level override for "max_file_size_by_type"
        user_file_type_limits = user_overrides.get("max_file_size_by_type", {})
        if ext in user_file_type_limits:
            return user_file_type_limits[ext]

        # 4. If not, check if there's a user-level fallback "max_file_size"
        if "max_file_size" in user_overrides:
            return user_overrides["max_file_size"]

        # 5. If none exist at user level, use system config
        # Example config paths:
        system_type_limits = self.config.app.max_upload_size_by_type
        if ext in system_type_limits:
            return system_type_limits[ext]

        # 6. Otherwise, return the global default
        return self.config.app.default_max_upload_size
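
    # Resolution order, illustrated with the docstring's example overrides:
    # ".PDF" normalizes to "pdf" and matches the per-type override
    # (50_000_000); an extension absent from that map, say "csv", falls back
    # to the user-level "max_file_size" (20_000_000); with no user overrides
    # at all, the system per-type table applies, then default_max_upload_size.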

    async def get_all_user_limits(self, user_id: UUID) -> dict[str, Any]:
        """
        Return a dictionary containing:
          - The system default limits (from self.config.database.limits)
          - The user's overrides (from user.limits_overrides)
          - The final 'effective' set of limits after merging (overall)
          - The usage for each relevant limit (per-route usage, etc.)
        """
        # 1) Fetch the user
        user = await self.providers.database.users_handler.get_user_by_id(
            user_id
        )
        user_overrides = user.limits_overrides or {}

        # 2) Grab system defaults
        system_defaults = {
            "global_per_min": self.config.database.limits.global_per_min,
            "route_per_min": self.config.database.limits.route_per_min,
            "monthly_limit": self.config.database.limits.monthly_limit,
            # Add additional fields if your LimitSettings has them
        }

        # 3) Build the overall (global) "effective limits" ignoring any specific route
        overall_effective = (
            self.providers.database.limits_handler.determine_effective_limits(
                user, route=""
            )
        )

        # 4) Build usage data. We'll do top-level usage for global_per_min/monthly,
        #    then do route-by-route usage in a loop.
        usage: dict[str, Any] = {}
        now = datetime.now(timezone.utc)
        one_min_ago = now - timedelta(minutes=1)

        # (a) Global usage (per-minute)
        global_per_min_used = (
            await self.providers.database.limits_handler._count_requests(
                user_id, route=None, since=one_min_ago
            )
        )

        # (a2) Global usage (monthly) - i.e. usage across ALL routes
        global_monthly_used = await self.providers.database.limits_handler._count_monthly_requests(
            user_id, route=None
        )

        usage["global_per_min"] = {
            "used": global_per_min_used,
            "limit": overall_effective.global_per_min,
            "remaining": (
                overall_effective.global_per_min - global_per_min_used
                if overall_effective.global_per_min is not None
                else None
            ),
        }
        usage["monthly_limit"] = {
            "used": global_monthly_used,
            "limit": overall_effective.monthly_limit,
            "remaining": (
                overall_effective.monthly_limit - global_monthly_used
                if overall_effective.monthly_limit is not None
                else None
            ),
        }

        # (b) Route-level usage. We'll gather all routes from system + user overrides
        system_route_limits = (
            self.config.database.route_limits
        )  # dict[str, LimitSettings]
        user_route_overrides = user_overrides.get("route_overrides", {})
        route_keys = set(system_route_limits.keys()) | set(
            user_route_overrides.keys()
        )

        usage["routes"] = {}
        for route in route_keys:
            # 1) Get the final merged limits for this specific route
            route_effective = self.providers.database.limits_handler.determine_effective_limits(
                user, route
            )

            # 2) Count requests for the last minute on this route
            route_per_min_used = (
                await self.providers.database.limits_handler._count_requests(
                    user_id, route, one_min_ago
                )
            )

            # 3) Count route-specific monthly usage
            route_monthly_used = await self.providers.database.limits_handler._count_monthly_requests(
                user_id, route
            )

            usage["routes"][route] = {
                "route_per_min": {
                    "used": route_per_min_used,
                    "limit": route_effective.route_per_min,
                    "remaining": (
                        route_effective.route_per_min - route_per_min_used
                        if route_effective.route_per_min is not None
                        else None
                    ),
                },
                "monthly_limit": {
                    "used": route_monthly_used,
                    "limit": route_effective.monthly_limit,
                    "remaining": (
                        route_effective.monthly_limit - route_monthly_used
                        if route_effective.monthly_limit is not None
                        else None
                    ),
                },
            }

        max_documents = await self.get_user_max_documents(user_id)
        used_documents = (
            await self.providers.database.documents_handler.get_documents_overview(
                limit=1, offset=0, filter_user_ids=[user_id]
            )
        )["total_entries"]
        max_chunks = await self.get_user_max_chunks(user_id)
        used_chunks = (
            await self.providers.database.chunks_handler.list_chunks(
                limit=1, offset=0, filters={"owner_id": user_id}
            )
        )["total_entries"]
        max_collections = await self.get_user_max_collections(user_id)
        used_collections: int = (  # type: ignore
            await self.providers.database.collections_handler.get_collections_overview(
                limit=1, offset=0, filter_user_ids=[user_id]
            )
        )["total_entries"]

        storage_limits = {
            "chunks": {
                "limit": max_chunks,
                "used": used_chunks,
                "remaining": (
                    max_chunks - used_chunks
                    if max_chunks is not None
                    else None
                ),
            },
            "documents": {
                "limit": max_documents,
                "used": used_documents,
                "remaining": (
                    max_documents - used_documents
                    if max_documents is not None
                    else None
                ),
            },
            "collections": {
                "limit": max_collections,
                "used": used_collections,
                "remaining": (
                    max_collections - used_collections
                    if max_collections is not None
                    else None
                ),
            },
        }

        # 5) Return a structured response
        return {
            "storage_limits": storage_limits,
            "system_defaults": system_defaults,
            "user_overrides": user_overrides,
            "effective_limits": {
                "global_per_min": overall_effective.global_per_min,
                "route_per_min": overall_effective.route_per_min,
                "monthly_limit": overall_effective.monthly_limit,
            },
            "usage": usage,
        }