import logging
import os
from collections import defaultdict
from datetime import datetime
from typing import IO, Any, BinaryIO, Optional, Tuple
from uuid import UUID

import toml

from core.base import (
    CollectionResponse,
    ConversationResponse,
    DocumentResponse,
    GenerationConfig,
    KGEnrichmentStatus,
    Message,
    Prompt,
    R2RException,
    RunManager,
    StoreType,
    User,
)
from core.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()


class ManagementService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipes: R2RPipes,
        pipelines: R2RPipelines,
        agents: R2RAgents,
        run_manager: RunManager,
    ):
        super().__init__(
            config,
            providers,
            pipes,
            pipelines,
            agents,
            run_manager,
        )

    @telemetry_event("AppSettings")
    async def app_settings(self):
        prompts = (
            await self.providers.database.prompts_handler.get_all_prompts()
        )
        config_toml = self.config.to_toml()
        config_dict = toml.loads(config_toml)
        return {
            "config": config_dict,
            "prompts": prompts,
            "r2r_project_name": os.environ["R2R_PROJECT_NAME"],
            # "r2r_version": get_version("r2r"),
        }
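
    # Usage sketch (illustrative; assumes an initialized ManagementService
    # bound to `service` and R2R_PROJECT_NAME set in the environment):
    #
    #   settings = await service.app_settings()
    #   settings["config"]            # parsed TOML config as a dict
    #   settings["prompts"]           # all stored prompts
    #   settings["r2r_project_name"]  # taken from the environment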

    @telemetry_event("UsersOverview")
    async def users_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
    ):
        return await self.providers.database.users_handler.get_users_overview(
            offset=offset,
            limit=limit,
            user_ids=user_ids,
        )

    async def delete_documents_and_chunks_by_filter(
        self,
        filters: dict[str, Any],
    ):
        """
        Delete chunks matching the given filters. If any documents are left
        with no remaining chunks, delete those documents as well.

        Args:
            filters (dict[str, Any]): Filters specifying which chunks to
                delete.

        Returns:
            dict: A summary of what was deleted.
        """

        def transform_chunk_id_to_id(
            filters: dict[str, Any]
        ) -> dict[str, Any]:
            """
            Recursively rewrite `chunk_id` keys to `id`, descending into
            `$and`/`$or` clauses, for filters written against `chunk_id`.
            """
            if isinstance(filters, dict):
                transformed = {}
                for key, value in filters.items():
                    if key == "chunk_id":
                        transformed["id"] = value
                    elif key in ["$and", "$or"]:
                        transformed[key] = [
                            transform_chunk_id_to_id(item) for item in value
                        ]
                    else:
                        transformed[key] = transform_chunk_id_to_id(value)
                return transformed
            return filters

        # 1. (Optional) Validate the input filters based on your rules,
        #    e.g. check that filters is non-empty and uses allowed fields.
        # validate_filters(filters)

        # 2. Transform filters if needed, e.g. map `chunk_id` to `id`.
        transformed_filters = transform_chunk_id_to_id(filters)

        # 3. Find out which chunks match these filters *before* deleting,
        #    so we know which documents are affected.
        page_size = 1_000
        interim_results = (
            await self.providers.database.chunks_handler.list_chunks(
                filters=transformed_filters,
                offset=0,
                limit=page_size,
                include_vectors=False,
            )
        )
        total_entries = interim_results["page_info"]["total_entries"]
        if total_entries == 0:
            raise R2RException(
                status_code=404, message="No entries found for deletion."
            )
        results = interim_results["results"]
        # Paginate until every matching chunk has been collected. (The
        # previous loop condition, `total_entries == 1_000`, never changed
        # between iterations and could spin forever.)
        while len(results) < total_entries:
            interim_results = (
                await self.providers.database.chunks_handler.list_chunks(
                    filters=transformed_filters,
                    offset=len(results),
                    limit=page_size,
                    include_vectors=False,
                )
            )
            if not interim_results["results"]:
                break
            results.extend(interim_results["results"])

        matched_chunk_docs = {UUID(chunk["document_id"]) for chunk in results}
        # If no chunks match, return a no-op result.
        if not matched_chunk_docs:
            return {
                "success": False,
                "message": "No chunks match the given filters.",
            }

        # 4. Delete the matching chunks from the database.
        delete_results = await self.providers.database.chunks_handler.delete(
            transformed_filters
        )

        # 5. From `delete_results`, extract the affected document_ids.
        #    The delete results map chunk_id to details including `document_id`.
        affected_doc_ids = {
            UUID(info["document_id"])
            for info in delete_results.values()
            if info.get("document_id")
        }

        # 6. For each affected document, check whether any chunks remain.
        docs_to_delete = []
        for doc_id in affected_doc_ids:
            remaining = await self.providers.database.chunks_handler.list_document_chunks(
                document_id=doc_id,
                offset=0,
                limit=1,  # Just need to know if at least one is left
                include_vectors=False,
            )
            # If no chunks remain, the document should be deleted too.
            if remaining["total_entries"] == 0:
                docs_to_delete.append(doc_id)

        # 7. Delete documents that no longer have associated chunks, along
        #    with their graph entities and relationships.
        for doc_id in docs_to_delete:
            await self.providers.database.graphs_handler.entities.delete(
                parent_id=doc_id,
                store_type=StoreType.DOCUMENTS,
            )
            await self.providers.database.graphs_handler.relationships.delete(
                parent_id=doc_id,
                store_type=StoreType.DOCUMENTS,
            )
            # Finally, delete the document from documents_overview:
            await self.providers.database.documents_handler.delete(
                document_id=doc_id
            )

        # 8. Return a summary of what happened.
        return {
            "success": True,
            "deleted_chunks_count": len(delete_results),
            "deleted_documents_count": len(docs_to_delete),
            "deleted_document_ids": [str(d) for d in docs_to_delete],
        }
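
    # Usage sketch (illustrative; assumes the `$and`/`$eq` filter operators
    # accepted by the chunks handler, and a ManagementService bound to
    # `service`; the filter fields are hypothetical):
    #
    #   summary = await service.delete_documents_and_chunks_by_filter(
    #       filters={
    #           "$and": [
    #               {"owner_id": {"$eq": str(owner_id)}},
    #               {"chunk_id": {"$eq": str(chunk_id)}},  # rewritten to `id`
    #           ]
    #       }
    #   )
    #   # summary -> {"success": True, "deleted_chunks_count": ..., ...}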

    @telemetry_event("DownloadFile")
    async def download_file(
        self, document_id: UUID
    ) -> Optional[Tuple[str, BinaryIO, int]]:
        if result := await self.providers.database.files_handler.retrieve_file(
            document_id
        ):
            return result
        return None
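
    # Usage sketch (illustrative): the returned tuple is presumably
    # (file_name, file_obj, file_size), so a caller might do:
    #
    #   if downloaded := await service.download_file(document_id):
    #       name, file_obj, size = downloaded
    #       with open(name, "wb") as out:
    #           out.write(file_obj.read())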

    @telemetry_event("ExportFiles")
    async def export_files(
        self,
        document_ids: Optional[list[UUID]] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> tuple[str, BinaryIO, int]:
        return (
            await self.providers.database.files_handler.retrieve_files_as_zip(
                document_ids=document_ids,
                start_date=start_date,
                end_date=end_date,
            )
        )

    @telemetry_event("ExportCollections")
    async def export_collections(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.collections_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    @telemetry_event("ExportDocuments")
    async def export_documents(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.documents_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    @telemetry_event("ExportDocumentEntities")
    async def export_document_entities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.entities.export_to_csv(
            parent_id=id,
            store_type=StoreType.DOCUMENTS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    @telemetry_event("ExportDocumentRelationships")
    async def export_document_relationships(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.relationships.export_to_csv(
            parent_id=id,
            store_type=StoreType.DOCUMENTS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    @telemetry_event("ExportConversations")
    async def export_conversations(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.conversations_handler.export_conversations_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    @telemetry_event("ExportGraphEntities")
    async def export_graph_entities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.entities.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    @telemetry_event("ExportGraphRelationships")
    async def export_graph_relationships(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.relationships.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    @telemetry_event("ExportGraphCommunities")
    async def export_graph_communities(
        self,
        id: UUID,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.graphs_handler.communities.export_to_csv(
            parent_id=id,
            store_type=StoreType.GRAPHS,
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    @telemetry_event("ExportMessages")
    async def export_messages(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.conversations_handler.export_messages_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )

    @telemetry_event("ExportUsers")
    async def export_users(
        self,
        columns: Optional[list[str]] = None,
        filters: Optional[dict] = None,
        include_header: bool = True,
    ) -> tuple[str, IO]:
        return await self.providers.database.users_handler.export_to_csv(
            columns=columns,
            filters=filters,
            include_header=include_header,
        )
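
    # Usage sketch (illustrative) for the CSV exporters above, which all
    # return a (file_name, file_like) tuple; the column names are
    # hypothetical:
    #
    #   name, csv_file = await service.export_documents(
    #       columns=["id", "title", "created_at"],
    #       include_header=True,
    #   )
    #   with open(name, "w") as out:
    #       out.write(csv_file.read())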

    @telemetry_event("DocumentsOverview")
    async def documents_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
        collection_ids: Optional[list[UUID]] = None,
        document_ids: Optional[list[UUID]] = None,
    ):
        return await self.providers.database.documents_handler.get_documents_overview(
            offset=offset,
            limit=limit,
            filter_document_ids=document_ids,
            filter_user_ids=user_ids,
            filter_collection_ids=collection_ids,
        )

    @telemetry_event("DocumentChunks")
    async def list_document_chunks(
        self,
        document_id: UUID,
        offset: int,
        limit: int,
        include_vectors: bool = False,
    ):
        return (
            await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=offset,
                limit=limit,
                include_vectors=include_vectors,
            )
        )

    @telemetry_event("AssignDocumentToCollection")
    async def assign_document_to_collection(
        self, document_id: UUID, collection_id: UUID
    ):
        await self.providers.database.chunks_handler.assign_document_chunks_to_collection(
            document_id, collection_id
        )
        await self.providers.database.collections_handler.assign_document_to_collection_relational(
            document_id, collection_id
        )
        # Adding a document invalidates the collection's graph, so mark it
        # for re-sync and re-clustering.
        await self.providers.database.documents_handler.set_workflow_status(
            id=collection_id,
            status_type="graph_sync_status",
            status=KGEnrichmentStatus.OUTDATED,
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=collection_id,
            status_type="graph_cluster_status",
            status=KGEnrichmentStatus.OUTDATED,
        )
        return {"message": "Document assigned to collection successfully"}

    @telemetry_event("RemoveDocumentFromCollection")
    async def remove_document_from_collection(
        self, document_id: UUID, collection_id: UUID
    ):
        await self.providers.database.collections_handler.remove_document_from_collection_relational(
            document_id, collection_id
        )
        await self.providers.database.chunks_handler.remove_document_from_collection_vector(
            document_id, collection_id
        )
        # await self.providers.database.graphs_handler.delete_node_via_document_id(
        #     document_id, collection_id
        # )
        return None

    def _process_relationships(
        self, relationships: list[Tuple[str, str, str]]
    ) -> Tuple[dict[str, list[str]], dict[str, dict[str, list[str]]]]:
        graph = defaultdict(list)
        grouped: dict[str, dict[str, list[str]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for subject, relation, obj in relationships:
            graph[subject].append(obj)
            grouped[subject][relation].append(obj)
            # Ensure objects with no outgoing edges still appear as nodes.
            if obj not in graph:
                graph[obj] = []
        return dict(graph), dict(grouped)
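
    # Worked example (illustrative): given
    #   [("A", "knows", "B"), ("B", "knows", "C")]
    # _process_relationships returns
    #   graph   = {"A": ["B"], "B": ["C"], "C": []}
    #   grouped = {"A": {"knows": ["B"]}, "B": {"knows": ["C"]}}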

    def generate_output(
        self,
        grouped_relationships: dict[str, dict[str, list[str]]],
        graph: dict[str, list[str]],
        descriptions_dict: dict[str, str],
        print_descriptions: bool = True,
    ) -> list[str]:
        output = []

        # Print grouped relationships
        for subject, relations in grouped_relationships.items():
            output.append(f"\n== {subject} ==")
            if print_descriptions and subject in descriptions_dict:
                output.append(f"\tDescription: {descriptions_dict[subject]}")
            for relation, objects in relations.items():
                output.append(f"  {relation}:")
                for obj in objects:
                    output.append(f"    - {obj}")
                    if print_descriptions and obj in descriptions_dict:
                        output.append(
                            f"      Description: {descriptions_dict[obj]}"
                        )

        # Print basic graph statistics
        output.extend(
            [
                "\n== Graph Statistics ==",
                f"Number of nodes: {len(graph)}",
                f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}",
                f"Number of connected components: {self._count_connected_components(graph)}",
            ]
        )

        # Find central nodes
        central_nodes = self._get_central_nodes(graph)
        output.extend(
            [
                "\n== Most Central Nodes ==",
                *(
                    f"  {node}: {centrality:.4f}"
                    for node, centrality in central_nodes
                ),
            ]
        )
        return output
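
    # Illustrative output for the worked example above (descriptions
    # omitted):
    #
    #   == A ==
    #     knows:
    #       - B
    #   == B ==
    #     knows:
    #       - C
    #   == Graph Statistics ==
    #   Number of nodes: 3
    #   Number of edges: 2
    #   Number of connected components: 1
    #   == Most Central Nodes ==
    #     A: 0.5000
    #     B: 0.5000
    #     C: 0.0000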

    def _count_connected_components(self, graph: dict[str, list[str]]) -> int:
        visited = set()
        components = 0

        def dfs(node):
            visited.add(node)
            for neighbor in graph[node]:
                if neighbor not in visited:
                    dfs(neighbor)

        for node in graph:
            if node not in visited:
                dfs(node)
                components += 1
        return components

    def _get_central_nodes(
        self, graph: dict[str, list[str]]
    ) -> list[Tuple[str, float]]:
        total_nodes = len(graph)
        # Guard against division by zero for empty or single-node graphs.
        if total_nodes <= 1:
            return []
        degree = {node: len(neighbors) for node, neighbors in graph.items()}
        centrality = {
            node: deg / (total_nodes - 1) for node, deg in degree.items()
        }
        return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]
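
    # Worked example (illustrative): for graph = {"A": ["B", "C"],
    # "B": ["C"], "C": []}, total_nodes is 3, so out-degree centrality is
    # degree / (3 - 1):
    #   A -> 2 / 2 = 1.0, B -> 1 / 2 = 0.5, C -> 0 / 2 = 0.0
    # giving [("A", 1.0), ("B", 0.5), ("C", 0.0)].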

    @telemetry_event("CreateCollection")
    async def create_collection(
        self,
        owner_id: UUID,
        name: Optional[str] = None,
        description: str = "",
    ) -> CollectionResponse:
        result = await self.providers.database.collections_handler.create_collection(
            owner_id=owner_id,
            name=name,
            description=description,
        )
        # Every collection gets a companion graph with the same name and
        # description.
        await self.providers.database.graphs_handler.create(
            collection_id=result.id,
            name=name,
            description=description,
        )
        return result

    @telemetry_event("UpdateCollection")
    async def update_collection(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        generate_description: bool = False,
    ) -> CollectionResponse:
        if generate_description:
            description = await self.summarize_collection(
                id=collection_id, offset=0, limit=100
            )
        return await self.providers.database.collections_handler.update_collection(
            collection_id=collection_id,
            name=name,
            description=description,
        )

    @telemetry_event("DeleteCollection")
    async def delete_collection(self, collection_id: UUID) -> bool:
        await self.providers.database.collections_handler.delete_collection_relational(
            collection_id
        )
        await self.providers.database.chunks_handler.delete_collection_vector(
            collection_id
        )
        return True

    @telemetry_event("ListCollections")
    async def collections_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
        document_ids: Optional[list[UUID]] = None,
        collection_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[CollectionResponse] | int]:
        return await self.providers.database.collections_handler.get_collections_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            filter_document_ids=document_ids,
            filter_collection_ids=collection_ids,
        )

    @telemetry_event("AddUserToCollection")
    async def add_user_to_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        return (
            await self.providers.database.users_handler.add_user_to_collection(
                user_id, collection_id
            )
        )

    @telemetry_event("RemoveUserFromCollection")
    async def remove_user_from_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        return await self.providers.database.users_handler.remove_user_from_collection(
            user_id, collection_id
        )

    @telemetry_event("GetUsersInCollection")
    async def get_users_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[User] | int]:
        return await self.providers.database.users_handler.get_users_in_collection(
            collection_id, offset=offset, limit=limit
        )

    @telemetry_event("GetDocumentsInCollection")
    async def documents_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[DocumentResponse] | int]:
        return await self.providers.database.collections_handler.documents_in_collection(
            collection_id, offset=offset, limit=limit
        )

    @telemetry_event("SummarizeCollection")
    async def summarize_collection(
        self, id: UUID, offset: int, limit: int
    ) -> str:
        documents_in_collection_response = await self.documents_in_collection(
            collection_id=id,
            offset=offset,
            limit=limit,
        )
        # Skip documents without a summary so the join below cannot fail
        # on None.
        document_summaries = [
            document.summary
            for document in documents_in_collection_response["results"]
            if document.summary
        ]
        logger.info(
            f"Summarizing collection {id} with {len(document_summaries)} of "
            f"{documents_in_collection_response['total_entries']} documents."
        )
        formatted_summaries = "\n\n".join(document_summaries)
        messages = await self.providers.database.prompts_handler.get_message_payload(
            system_prompt_name=self.config.database.collection_summary_system_prompt,
            task_prompt_name=self.config.database.collection_summary_task_prompt,
            task_inputs={"document_summaries": formatted_summaries},
        )
        response = await self.providers.llm.aget_completion(
            messages=messages,
            generation_config=GenerationConfig(
                model=self.config.ingestion.document_summary_model
            ),
        )
        collection_summary = response.choices[0].message.content
        if not collection_summary:
            raise ValueError("Expected a generated response.")
        return collection_summary
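
    # Usage sketch (illustrative): callers usually reach this through
    # update_collection, which summarizes the first 100 documents:
    #
    #   updated = await service.update_collection(
    #       collection_id=collection_id, generate_description=True
    #   )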

    @telemetry_event("AddPrompt")
    async def add_prompt(
        self, name: str, template: str, input_types: dict[str, str]
    ) -> dict:
        try:
            await self.providers.database.prompts_handler.add_prompt(
                name, template, input_types
            )
            return f"Prompt '{name}' added successfully."  # type: ignore
        except ValueError as e:
            raise R2RException(status_code=400, message=str(e))

    @telemetry_event("GetPrompt")
    async def get_cached_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        try:
            return {
                "message": (
                    await self.providers.database.prompts_handler.get_cached_prompt(
                        prompt_name=prompt_name,
                        inputs=inputs,
                        prompt_override=prompt_override,
                    )
                )
            }
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e))

    @telemetry_event("GetPrompt")
    async def get_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        try:
            return await self.providers.database.prompts_handler.get_prompt(  # type: ignore
                name=prompt_name,
                inputs=inputs,
                prompt_override=prompt_override,
            )
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e))

    @telemetry_event("GetAllPrompts")
    async def get_all_prompts(self) -> dict[str, Prompt]:
        return await self.providers.database.prompts_handler.get_all_prompts()

    @telemetry_event("UpdatePrompt")
    async def update_prompt(
        self,
        name: str,
        template: Optional[str] = None,
        input_types: Optional[dict[str, str]] = None,
    ) -> dict:
        try:
            await self.providers.database.prompts_handler.update_prompt(
                name, template, input_types
            )
            return f"Prompt '{name}' updated successfully."  # type: ignore
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e))

    @telemetry_event("DeletePrompt")
    async def delete_prompt(self, name: str) -> dict:
        try:
            await self.providers.database.prompts_handler.delete_prompt(name)
            return {"message": f"Prompt '{name}' deleted successfully."}
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e))
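
    # Usage sketch (illustrative; the prompt name, template, and inputs are
    # made up):
    #
    #   await service.add_prompt(
    #       name="greeting",
    #       template="Say hello to {name}.",
    #       input_types={"name": "str"},
    #   )
    #   rendered = await service.get_cached_prompt(
    #       prompt_name="greeting", inputs={"name": "Ada"}
    #   )
    #   await service.delete_prompt("greeting")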

    @telemetry_event("GetConversation")
    async def get_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> Tuple[str, list[Message], list[dict]]:
        return await self.providers.database.conversations_handler.get_conversation(
            conversation_id=conversation_id,
            filter_user_ids=user_ids,
        )

    @telemetry_event("CreateConversation")
    async def create_conversation(
        self,
        user_id: Optional[UUID] = None,
        name: Optional[str] = None,
    ) -> ConversationResponse:
        return await self.providers.database.conversations_handler.create_conversation(
            user_id=user_id,
            name=name,
        )

    @telemetry_event("ConversationsOverview")
    async def conversations_overview(
        self,
        offset: int,
        limit: int,
        conversation_ids: Optional[list[UUID]] = None,
        user_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[dict] | int]:
        return await self.providers.database.conversations_handler.get_conversations_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            conversation_ids=conversation_ids,
        )

    @telemetry_event("AddMessage")
    async def add_message(
        self,
        conversation_id: UUID,
        content: Message,
        parent_id: Optional[UUID] = None,
        metadata: Optional[dict] = None,
    ) -> str:
        return await self.providers.database.conversations_handler.add_message(
            conversation_id=conversation_id,
            content=content,
            parent_id=parent_id,
            metadata=metadata,
        )

    @telemetry_event("EditMessage")
    async def edit_message(
        self,
        message_id: UUID,
        new_content: Optional[str] = None,
        additional_metadata: Optional[dict] = None,
    ) -> dict[str, Any]:
        return (
            await self.providers.database.conversations_handler.edit_message(
                message_id=message_id,
                new_content=new_content,
                additional_metadata=additional_metadata or {},
            )
        )

    @telemetry_event("UpdateConversation")
    async def update_conversation(
        self, conversation_id: UUID, name: str
    ) -> ConversationResponse:
        return await self.providers.database.conversations_handler.update_conversation(
            conversation_id=conversation_id, name=name
        )

    @telemetry_event("DeleteConversation")
    async def delete_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> None:
        await self.providers.database.conversations_handler.delete_conversation(
            conversation_id=conversation_id,
            filter_user_ids=user_ids,
        )
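
    # Usage sketch (illustrative) of a conversation round trip; the
    # UUID(...) cast assumes add_message returns the new message id as a
    # string, per its annotation:
    #
    #   conversation = await service.create_conversation(user_id=user.id)
    #   message_id = await service.add_message(
    #       conversation_id=conversation.id,
    #       content=Message(role="user", content="Hello"),
    #   )
    #   await service.edit_message(
    #       message_id=UUID(message_id), new_content="Hello there"
    #   )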

    async def get_user_max_documents(self, user_id: UUID) -> int | None:
        # Per-user overrides are not implemented here; these three getters
        # fall back to the app-wide defaults.
        return self.config.app.default_max_documents_per_user

    async def get_user_max_chunks(self, user_id: UUID) -> int | None:
        return self.config.app.default_max_chunks_per_user

    async def get_user_max_collections(self, user_id: UUID) -> int | None:
        return self.config.app.default_max_collections_per_user