management_service.py

import logging
import os
from collections import defaultdict
from typing import Any, BinaryIO, Optional, Tuple
from uuid import UUID

import toml

from core.base import (
    CollectionResponse,
    ConversationResponse,
    DocumentResponse,
    GenerationConfig,
    KGEnrichmentStatus,
    Message,
    Prompt,
    R2RException,
    RunManager,
    User,
)
from core.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()


class ManagementService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipes: R2RPipes,
        pipelines: R2RPipelines,
        agents: R2RAgents,
        run_manager: RunManager,
    ):
        super().__init__(
            config,
            providers,
            pipes,
            pipelines,
            agents,
            run_manager,
        )

    @telemetry_event("AppSettings")
    async def app_settings(self):
        prompts = (
            await self.providers.database.prompts_handler.get_all_prompts()
        )
        config_toml = self.config.to_toml()
        config_dict = toml.loads(config_toml)
        return {
            "config": config_dict,
            "prompts": prompts,
            "r2r_project_name": os.environ["R2R_PROJECT_NAME"],
            # "r2r_version": get_version("r2r"),
        }

    @telemetry_event("UsersOverview")
    async def users_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
    ):
        return await self.providers.database.users_handler.get_users_overview(
            offset=offset,
            limit=limit,
            user_ids=user_ids,
        )

    async def delete_documents_and_chunks_by_filter(
        self,
        filters: dict[str, Any],
    ):
        """
        Delete chunks matching the given filters. If any documents are left
        with no remaining chunks, delete those documents as well.

        Args:
            filters (dict[str, Any]): Filters specifying which chunks to delete.

        Returns:
            dict: A summary of what was deleted.
        """

        def transform_chunk_id_to_id(
            filters: dict[str, Any]
        ) -> dict[str, Any]:
            """
            Example transformation function if your filters use `chunk_id`
            instead of `id`. Recursively transforms `chunk_id` to `id`.
            """
            if isinstance(filters, dict):
                transformed = {}
                for key, value in filters.items():
                    if key == "chunk_id":
                        transformed["id"] = value
                    elif key in ["$and", "$or"]:
                        transformed[key] = [
                            transform_chunk_id_to_id(item) for item in value
                        ]
                    else:
                        transformed[key] = transform_chunk_id_to_id(value)
                return transformed
            return filters
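
        # Illustrative transformation (hypothetical filter values):
        #   {"$and": [{"chunk_id": {"$eq": "..."}}, {"metadata.tag": "x"}]}
        #   becomes
        #   {"$and": [{"id": {"$eq": "..."}}, {"metadata.tag": "x"}]}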

        # 1. (Optional) Validate the input filters based on your rules.
        #    E.g., check if filters is not empty, allowed fields, etc.
        # validate_filters(filters)

        # 2. Transform filters if needed. For example, if `chunk_id` is used,
        #    map it to `id`, or similar transformations.
        transformed_filters = transform_chunk_id_to_id(filters)

        # 3. Find out which chunks match these filters *before* deleting,
        #    so we know which documents are affected.
        page_size = 1_000  # Arbitrary page size; tune as needed
        offset = 0
        interim_results = (
            await self.providers.database.chunks_handler.list_chunks(
                filters=transformed_filters,
                offset=offset,
                limit=page_size,
                include_vectors=False,
            )
        )
        if interim_results["page_info"]["total_entries"] == 0:
            raise R2RException(
                status_code=404, message="No entries found for deletion."
            )
        results = interim_results["results"]
        # If we filled a whole page, keep paginating until all matches
        # have been collected.
        while len(interim_results["results"]) == page_size:
            offset += page_size
            interim_results = (
                await self.providers.database.chunks_handler.list_chunks(
                    filters=transformed_filters,
                    offset=offset,
                    limit=page_size,
                    include_vectors=False,
                )
            )
            results.extend(interim_results["results"])

        matched_chunk_docs = {UUID(chunk["document_id"]) for chunk in results}

        # If no chunks match, return a no-op result.
        if not matched_chunk_docs:
            return {
                "success": False,
                "message": "No chunks match the given filters.",
            }

        # 4. Delete the matching chunks from the database.
        delete_results = await self.providers.database.chunks_handler.delete(
            transformed_filters
        )

        # 5. From `delete_results`, extract the document_ids that were
        #    affected; it maps chunk_id to details including `document_id`.
        affected_doc_ids = {
            UUID(info["document_id"])
            for info in delete_results.values()
            if info.get("document_id")
        }

        # 6. For each affected document, check if it has any chunks left.
        docs_to_delete = []
        for doc_id in affected_doc_ids:
            remaining = await self.providers.database.chunks_handler.list_document_chunks(
                document_id=doc_id,
                offset=0,
                limit=1,  # Just need to know if there's at least one left
                include_vectors=False,
            )
            # If no remaining chunks, we should delete the document.
            if remaining["total_entries"] == 0:
                docs_to_delete.append(doc_id)

        # 7. Delete documents that no longer have associated chunks, along
        #    with their entities and relationships in the knowledge graph.
        for doc_id in docs_to_delete:
            await self.providers.database.graphs_handler.entities.delete(
                parent_id=doc_id, store_type="documents"
            )
            await self.providers.database.graphs_handler.relationships.delete(
                parent_id=doc_id, store_type="documents"
            )
            # Finally, delete the document from documents_overview:
            await self.providers.database.documents_handler.delete(
                document_id=doc_id
            )

        # 8. Return a summary of what happened.
        return {
            "success": True,
            "deleted_chunks_count": len(delete_results),
            "deleted_documents_count": len(docs_to_delete),
            "deleted_document_ids": [str(d) for d in docs_to_delete],
        }
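
    # Usage sketch (hypothetical caller and filter shape; assumes an
    # initialized ManagementService instance named `service`):
    #
    #     summary = await service.delete_documents_and_chunks_by_filter(
    #         filters={"document_id": {"$eq": str(document_id)}}
    #     )
    #     # -> {"success": True, "deleted_chunks_count": ..., ...}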

    @telemetry_event("DownloadFile")
    async def download_file(
        self, document_id: UUID
    ) -> Optional[Tuple[str, BinaryIO, int]]:
        if result := await self.providers.database.files_handler.retrieve_file(
            document_id
        ):
            return result
        return None

    @telemetry_event("DocumentsOverview")
    async def documents_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
        collection_ids: Optional[list[UUID]] = None,
        document_ids: Optional[list[UUID]] = None,
    ):
        return await self.providers.database.documents_handler.get_documents_overview(
            offset=offset,
            limit=limit,
            filter_document_ids=document_ids,
            filter_user_ids=user_ids,
            filter_collection_ids=collection_ids,
        )

    @telemetry_event("DocumentChunks")
    async def list_document_chunks(
        self,
        document_id: UUID,
        offset: int,
        limit: int,
        include_vectors: bool = False,
    ):
        return (
            await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=offset,
                limit=limit,
                include_vectors=include_vectors,
            )
        )

    @telemetry_event("AssignDocumentToCollection")
    async def assign_document_to_collection(
        self, document_id: UUID, collection_id: UUID
    ):
        await self.providers.database.chunks_handler.assign_document_chunks_to_collection(
            document_id, collection_id
        )
        await self.providers.database.collections_handler.assign_document_to_collection_relational(
            document_id, collection_id
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=collection_id,
            status_type="graph_sync_status",
            status=KGEnrichmentStatus.OUTDATED,
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=collection_id,
            status_type="graph_cluster_status",
            status=KGEnrichmentStatus.OUTDATED,
        )
        return {"message": "Document assigned to collection successfully"}

    @telemetry_event("RemoveDocumentFromCollection")
    async def remove_document_from_collection(
        self, document_id: UUID, collection_id: UUID
    ):
        await self.providers.database.collections_handler.remove_document_from_collection_relational(
            document_id, collection_id
        )
        await self.providers.database.chunks_handler.remove_document_from_collection_vector(
            document_id, collection_id
        )
        # await self.providers.database.graphs_handler.delete_node_via_document_id(
        #     document_id, collection_id
        # )
        return None

    def _process_relationships(
        self, relationships: list[Tuple[str, str, str]]
    ) -> Tuple[dict[str, list[str]], dict[str, dict[str, list[str]]]]:
        graph = defaultdict(list)
        grouped: dict[str, dict[str, list[str]]] = defaultdict(
            lambda: defaultdict(list)
        )
        for subject, relation, obj in relationships:
            graph[subject].append(obj)
            grouped[subject][relation].append(obj)
            if obj not in graph:
                graph[obj] = []
        return dict(graph), dict(grouped)
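
    # Illustrative input/output (hypothetical triples):
    #   relationships = [("acme", "employs", "alice"), ("acme", "owns", "labs")]
    #   graph   -> {"acme": ["alice", "labs"], "alice": [], "labs": []}
    #   grouped -> {"acme": {"employs": ["alice"], "owns": ["labs"]}}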

    def generate_output(
        self,
        grouped_relationships: dict[str, dict[str, list[str]]],
        graph: dict[str, list[str]],
        descriptions_dict: dict[str, str],
        print_descriptions: bool = True,
    ) -> list[str]:
        output = []

        # Print grouped relationships
        for subject, relations in grouped_relationships.items():
            output.append(f"\n== {subject} ==")
            if print_descriptions and subject in descriptions_dict:
                output.append(f"\tDescription: {descriptions_dict[subject]}")
            for relation, objects in relations.items():
                output.append(f"  {relation}:")
                for obj in objects:
                    output.append(f"    - {obj}")
                    if print_descriptions and obj in descriptions_dict:
                        output.append(
                            f"      Description: {descriptions_dict[obj]}"
                        )

        # Print basic graph statistics
        output.extend(
            [
                "\n== Graph Statistics ==",
                f"Number of nodes: {len(graph)}",
                f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}",
                f"Number of connected components: {self._count_connected_components(graph)}",
            ]
        )

        # Find central nodes
        central_nodes = self._get_central_nodes(graph)
        output.extend(
            [
                "\n== Most Central Nodes ==",
                *(
                    f"  {node}: {centrality:.4f}"
                    for node, centrality in central_nodes
                ),
            ]
        )
        return output

    def _count_connected_components(self, graph: dict[str, list[str]]) -> int:
        # Treat edges as undirected so the count does not depend on edge
        # direction or on the iteration order of the nodes.
        undirected: dict[str, set[str]] = {
            node: set(neighbors) for node, neighbors in graph.items()
        }
        for node, neighbors in graph.items():
            for neighbor in neighbors:
                undirected.setdefault(neighbor, set()).add(node)

        visited = set()
        components = 0

        def dfs(node):
            visited.add(node)
            for neighbor in undirected[node]:
                if neighbor not in visited:
                    dfs(neighbor)

        for node in undirected:
            if node not in visited:
                dfs(node)
                components += 1
        return components

    def _get_central_nodes(
        self, graph: dict[str, list[str]]
    ) -> list[Tuple[str, float]]:
        total_nodes = len(graph)
        if total_nodes <= 1:
            # Degree centrality is undefined for fewer than two nodes;
            # avoid division by zero.
            return []
        degree = {node: len(neighbors) for node, neighbors in graph.items()}
        centrality = {
            node: deg / (total_nodes - 1) for node, deg in degree.items()
        }
        return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]
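
    # Degree centrality here is out-degree normalized by (n - 1). Worked
    # example (hypothetical): in a 4-node graph where "acme" points at the
    # other three nodes, its centrality is 3 / (4 - 1) = 1.0.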

    @telemetry_event("CreateCollection")
    async def create_collection(
        self,
        owner_id: UUID,
        name: Optional[str] = None,
        description: str = "",
    ) -> CollectionResponse:
        result = await self.providers.database.collections_handler.create_collection(
            owner_id=owner_id,
            name=name,
            description=description,
        )
        # Create the collection's companion graph alongside it.
        await self.providers.database.graphs_handler.create(
            collection_id=result.id,
            name=name,
            description=description,
        )
        return result

    @telemetry_event("UpdateCollection")
    async def update_collection(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        generate_description: bool = False,
    ) -> CollectionResponse:
        if generate_description:
            description = await self.summarize_collection(
                id=collection_id, offset=0, limit=100
            )
        return await self.providers.database.collections_handler.update_collection(
            collection_id=collection_id,
            name=name,
            description=description,
        )

    @telemetry_event("DeleteCollection")
    async def delete_collection(self, collection_id: UUID) -> bool:
        await self.providers.database.collections_handler.delete_collection_relational(
            collection_id
        )
        await self.providers.database.chunks_handler.delete_collection_vector(
            collection_id
        )
        return True

    @telemetry_event("ListCollections")
    async def collections_overview(
        self,
        offset: int,
        limit: int,
        user_ids: Optional[list[UUID]] = None,
        document_ids: Optional[list[UUID]] = None,
        collection_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[CollectionResponse] | int]:
        return await self.providers.database.collections_handler.get_collections_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            filter_document_ids=document_ids,
            filter_collection_ids=collection_ids,
        )

    @telemetry_event("AddUserToCollection")
    async def add_user_to_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        return (
            await self.providers.database.users_handler.add_user_to_collection(
                user_id, collection_id
            )
        )

    @telemetry_event("RemoveUserFromCollection")
    async def remove_user_from_collection(
        self, user_id: UUID, collection_id: UUID
    ) -> bool:
        return await self.providers.database.users_handler.remove_user_from_collection(
            user_id, collection_id
        )

    @telemetry_event("GetUsersInCollection")
    async def get_users_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[User] | int]:
        return await self.providers.database.users_handler.get_users_in_collection(
            collection_id, offset=offset, limit=limit
        )

    @telemetry_event("GetDocumentsInCollection")
    async def documents_in_collection(
        self, collection_id: UUID, offset: int = 0, limit: int = 100
    ) -> dict[str, list[DocumentResponse] | int]:
        return await self.providers.database.collections_handler.documents_in_collection(
            collection_id, offset=offset, limit=limit
        )

    @telemetry_event("SummarizeCollection")
    async def summarize_collection(
        self, id: UUID, offset: int, limit: int
    ) -> str:
        documents_in_collection_response = await self.documents_in_collection(
            collection_id=id,
            offset=offset,
            limit=limit,
        )
        document_summaries = [
            document.summary
            for document in documents_in_collection_response["results"]
            if document.summary  # Skip documents that have no summary yet
        ]
        logger.info(
            f"Summarizing collection {id} with {len(document_summaries)} of "
            f"{documents_in_collection_response['total_entries']} documents."
        )
        formatted_summaries = "\n\n".join(document_summaries)
        messages = await self.providers.database.prompts_handler.get_message_payload(
            system_prompt_name=self.config.database.collection_summary_system_prompt,
            task_prompt_name=self.config.database.collection_summary_task_prompt,
            task_inputs={"document_summaries": formatted_summaries},
        )
        response = await self.providers.llm.aget_completion(
            messages=messages,
            generation_config=GenerationConfig(
                model=self.config.ingestion.document_summary_model
            ),
        )
        collection_summary = response.choices[0].message.content
        if not collection_summary:
            raise ValueError("Expected a generated response.")
        return collection_summary
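
    # Usage sketch (hypothetical identifiers; assumes an initialized
    # ManagementService instance named `service`):
    #
    #     summary = await service.summarize_collection(
    #         id=collection_id, offset=0, limit=100
    #     )
    #     # The summary is generated by the configured LLM from the first
    #     # `limit` document summaries in the collection.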

    @telemetry_event("AddPrompt")
    async def add_prompt(
        self, name: str, template: str, input_types: dict[str, str]
    ) -> dict:
        try:
            await self.providers.database.prompts_handler.add_prompt(
                name, template, input_types
            )
            return f"Prompt '{name}' added successfully."  # type: ignore
        except ValueError as e:
            raise R2RException(status_code=400, message=str(e))

    @telemetry_event("GetPrompt")
    async def get_cached_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        try:
            return {
                "message": (
                    await self.providers.database.prompts_handler.get_cached_prompt(
                        prompt_name, inputs, prompt_override
                    )
                )
            }
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e))

    @telemetry_event("GetPrompt")
    async def get_prompt(
        self,
        prompt_name: str,
        inputs: Optional[dict[str, Any]] = None,
        prompt_override: Optional[str] = None,
    ) -> dict:
        try:
            return await self.providers.database.prompts_handler.get_prompt(  # type: ignore
                name=prompt_name,
                inputs=inputs,
                prompt_override=prompt_override,
            )
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e))

    @telemetry_event("GetAllPrompts")
    async def get_all_prompts(self) -> dict[str, Prompt]:
        return await self.providers.database.prompts_handler.get_all_prompts()

    @telemetry_event("UpdatePrompt")
    async def update_prompt(
        self,
        name: str,
        template: Optional[str] = None,
        input_types: Optional[dict[str, str]] = None,
    ) -> dict:
        try:
            await self.providers.database.prompts_handler.update_prompt(
                name, template, input_types
            )
            return f"Prompt '{name}' updated successfully."  # type: ignore
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e))

    @telemetry_event("DeletePrompt")
    async def delete_prompt(self, name: str) -> dict:
        try:
            await self.providers.database.prompts_handler.delete_prompt(name)
            return {"message": f"Prompt '{name}' deleted successfully."}
        except ValueError as e:
            raise R2RException(status_code=404, message=str(e))

    @telemetry_event("GetConversation")
    async def get_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> Tuple[str, list[Message], list[dict]]:
        return await self.providers.database.conversations_handler.get_conversation(
            conversation_id=conversation_id,
            filter_user_ids=user_ids,
        )

    @telemetry_event("CreateConversation")
    async def create_conversation(
        self,
        user_id: Optional[UUID] = None,
        name: Optional[str] = None,
    ) -> ConversationResponse:
        return await self.providers.database.conversations_handler.create_conversation(
            user_id=user_id,
            name=name,
        )

    @telemetry_event("ConversationsOverview")
    async def conversations_overview(
        self,
        offset: int,
        limit: int,
        conversation_ids: Optional[list[UUID]] = None,
        user_ids: Optional[list[UUID]] = None,
    ) -> dict[str, list[dict] | int]:
        return await self.providers.database.conversations_handler.get_conversations_overview(
            offset=offset,
            limit=limit,
            filter_user_ids=user_ids,
            conversation_ids=conversation_ids,
        )

    @telemetry_event("AddMessage")
    async def add_message(
        self,
        conversation_id: UUID,
        content: Message,
        parent_id: Optional[UUID] = None,
        metadata: Optional[dict] = None,
    ) -> str:
        return await self.providers.database.conversations_handler.add_message(
            conversation_id=conversation_id,
            content=content,
            parent_id=parent_id,
            metadata=metadata,
        )

    @telemetry_event("EditMessage")
    async def edit_message(
        self,
        message_id: UUID,
        new_content: Optional[str] = None,
        additional_metadata: Optional[dict] = None,
    ) -> dict[str, Any]:
        return (
            await self.providers.database.conversations_handler.edit_message(
                message_id=message_id,
                new_content=new_content,
                additional_metadata=additional_metadata or {},
            )
        )

    @telemetry_event("UpdateConversation")
    async def update_conversation(
        self, conversation_id: UUID, name: str
    ) -> ConversationResponse:
        return await self.providers.database.conversations_handler.update_conversation(
            conversation_id=conversation_id, name=name
        )

    @telemetry_event("DeleteConversation")
    async def delete_conversation(
        self,
        conversation_id: UUID,
        user_ids: Optional[list[UUID]] = None,
    ) -> None:
        await self.providers.database.conversations_handler.delete_conversation(
            conversation_id=conversation_id,
            filter_user_ids=user_ids,
        )

    async def get_user_max_documents(self, user_id: UUID) -> int:
        return self.config.app.default_max_documents_per_user

    async def get_user_max_chunks(self, user_id: UUID) -> int:
        return self.config.app.default_max_chunks_per_user

    async def get_user_max_collections(self, user_id: UUID) -> int:
        return self.config.app.default_max_collections_per_user