kg.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  1. from typing import Optional, Union
  2. from uuid import UUID
  3. from ..models import (
  4. KGCreationSettings,
  5. KGEnrichmentSettings,
  6. KGEntityDeduplicationSettings,
  7. KGRunType,
  8. )
  9. class KGMixins:
  10. async def create_graph(
  11. self,
  12. collection_id: Optional[Union[UUID, str]] = None,
  13. run_type: Optional[Union[str, KGRunType]] = None,
  14. kg_creation_settings: Optional[Union[dict, KGCreationSettings]] = None,
  15. run_with_orchestration: Optional[bool] = None,
  16. ) -> dict:
  17. """
  18. Create a graph from the given settings.
  19. Args:
  20. collection_id (Optional[Union[UUID, str]]): The ID of the collection to create the graph for.
  21. run_type (Optional[Union[str, KGRunType]]): The type of run to perform.
  22. kg_creation_settings (Optional[Union[dict, KGCreationSettings]]): Settings for the graph creation process.
  23. """
  24. if isinstance(kg_creation_settings, KGCreationSettings):
  25. kg_creation_settings = kg_creation_settings.model_dump()
  26. data = {
  27. "collection_id": str(collection_id) if collection_id else None,
  28. "run_type": str(run_type) if run_type else None,
  29. "kg_creation_settings": kg_creation_settings or {},
  30. "run_with_orchestration": run_with_orchestration or True,
  31. }
  32. return await self._make_request("POST", "create_graph", json=data) # type: ignore
  33. async def enrich_graph(
  34. self,
  35. collection_id: Optional[Union[UUID, str]] = None,
  36. run_type: Optional[Union[str, KGRunType]] = None,
  37. kg_enrichment_settings: Optional[
  38. Union[dict, KGEnrichmentSettings]
  39. ] = None,
  40. run_with_orchestration: Optional[bool] = None,
  41. ) -> dict:
  42. """
  43. Perform graph enrichment over the entire graph.
  44. Args:
  45. collection_id (Optional[Union[UUID, str]]): The ID of the collection to enrich the graph for.
  46. run_type (Optional[Union[str, KGRunType]]): The type of run to perform.
  47. kg_enrichment_settings (Optional[Union[dict, KGEnrichmentSettings]]): Settings for the graph enrichment process.
  48. Returns:
  49. Results of the graph enrichment process.
  50. """
  51. if isinstance(kg_enrichment_settings, KGEnrichmentSettings):
  52. kg_enrichment_settings = kg_enrichment_settings.model_dump()
  53. data = {
  54. "collection_id": str(collection_id) if collection_id else None,
  55. "run_type": str(run_type) if run_type else None,
  56. "kg_enrichment_settings": kg_enrichment_settings or {},
  57. "run_with_orchestration": run_with_orchestration or True,
  58. }
  59. return await self._make_request("POST", "enrich_graph", json=data) # type: ignore
  60. async def get_entities(
  61. self,
  62. collection_id: Optional[Union[UUID, str]] = None,
  63. entity_level: Optional[str] = None,
  64. entity_ids: Optional[list[str]] = None,
  65. offset: Optional[int] = None,
  66. limit: Optional[int] = None,
  67. ) -> dict:
  68. """
  69. Retrieve entities from the knowledge graph.
  70. Args:
  71. collection_id (str): The ID of the collection to retrieve entities from.
  72. offset (int): The offset for pagination.
  73. limit (int): The limit for pagination.
  74. entity_level (Optional[str]): The level of entity to filter by.
  75. entity_ids (Optional[List[str]]): Optional list of entity IDs to filter by.
  76. Returns:
  77. dict: A dictionary containing the retrieved entities and total count.
  78. """
  79. params = {
  80. "collection_id": collection_id,
  81. "entity_level": entity_level,
  82. "entity_ids": entity_ids,
  83. "offset": offset,
  84. "limit": limit,
  85. }
  86. params = {k: v for k, v in params.items() if v is not None}
  87. return await self._make_request("GET", "entities", params=params) # type: ignore
  88. async def get_triples(
  89. self,
  90. collection_id: Optional[Union[UUID, str]] = None,
  91. entity_names: Optional[list[str]] = None,
  92. relationship_ids: Optional[list[str]] = None,
  93. offset: Optional[int] = None,
  94. limit: Optional[int] = None,
  95. ) -> dict:
  96. """
  97. Retrieve relationships from the knowledge graph.
  98. Args:
  99. collection_id (str): The ID of the collection to retrieve relationships from.
  100. offset (int): The offset for pagination.
  101. limit (int): The limit for pagination.
  102. entity_names (Optional[List[str]]): Optional list of entity names to filter by.
  103. relationship_ids (Optional[List[str]]): Optional list of relationship IDs to filter by.
  104. Returns:
  105. dict: A dictionary containing the retrieved relationships and total count.
  106. """
  107. params = {
  108. "collection_id": collection_id,
  109. "entity_names": entity_names,
  110. "relationship_ids": relationship_ids,
  111. "offset": offset,
  112. "limit": limit,
  113. }
  114. params = {k: v for k, v in params.items() if v is not None}
  115. return await self._make_request("GET", "relationships", params=params) # type: ignore
  116. async def get_communities(
  117. self,
  118. collection_id: Optional[Union[UUID, str]] = None,
  119. levels: Optional[list[int]] = None,
  120. community_ids: Optional[list[UUID]] = None,
  121. offset: Optional[int] = None,
  122. limit: Optional[int] = None,
  123. ) -> dict:
  124. """
  125. Retrieve communities from the knowledge graph.
  126. Args:
  127. collection_id (str): The ID of the collection to retrieve communities from.
  128. offset (int): The offset for pagination.
  129. limit (int): The limit for pagination.
  130. levels (Optional[List[int]]): Optional list of levels to filter by.
  131. community_ids (Optional[List[int]]): Optional list of community numbers to filter by.
  132. Returns:
  133. dict: A dictionary containing the retrieved communities.
  134. """
  135. params = {
  136. "collection_id": collection_id,
  137. "levels": levels,
  138. "community_ids": community_ids,
  139. "offset": offset,
  140. "limit": limit,
  141. }
  142. params = {k: v for k, v in params.items() if v is not None}
  143. return await self._make_request("GET", "communities", params=params) # type: ignore
  144. async def get_tuned_prompt(
  145. self,
  146. prompt_name: str,
  147. collection_id: Optional[str] = None,
  148. documents_offset: Optional[int] = 0,
  149. documents_limit: Optional[int] = 100,
  150. chunk_offset: Optional[int] = 0,
  151. chunk_limit: Optional[int] = 100,
  152. ) -> dict:
  153. """
  154. Tune the GraphRAG prompt for a given collection.
  155. The tuning process provides an LLM with chunks from each document in the collection. The relative sample size can therefore be controlled by adjusting the document and chunk limits.
  156. Args:
  157. prompt_name (str): The name of the prompt to tune.
  158. collection_id (str): The ID of the collection to tune the prompt for.
  159. documents_offset (Optional[int]): The offset for pagination of documents.
  160. documents_limit (Optional[int]): The limit for pagination of documents.
  161. chunk_offset (Optional[int]): The offset for pagination of chunks.
  162. chunk_limit (Optional[int]): The limit for pagination of chunks.
  163. Returns:
  164. dict: A dictionary containing the tuned prompt.
  165. """
  166. params = {
  167. "prompt_name": prompt_name,
  168. "collection_id": collection_id,
  169. "documents_offset": documents_offset,
  170. "documents_limit": documents_limit,
  171. "chunk_offset": chunk_offset,
  172. "chunk_limit": chunk_limit,
  173. }
  174. params = {k: v for k, v in params.items() if v is not None}
  175. return await self._make_request("GET", "tuned_prompt", params=params) # type: ignore
  176. async def deduplicate_entities(
  177. self,
  178. collection_id: Optional[Union[UUID, str]] = None,
  179. run_type: Optional[Union[str, KGRunType]] = None,
  180. deduplication_settings: Optional[
  181. Union[dict, KGEntityDeduplicationSettings]
  182. ] = None,
  183. ):
  184. """
  185. Deduplicate entities in the knowledge graph.
  186. Args:
  187. collection_id (Optional[Union[UUID, str]]): The ID of the collection to deduplicate entities for.
  188. run_type (Optional[Union[str, KGRunType]]): The type of run to perform.
  189. deduplication_settings (Optional[Union[dict, KGEntityDeduplicationSettings]]): Settings for the deduplication process.
  190. """
  191. if isinstance(deduplication_settings, KGEntityDeduplicationSettings):
  192. deduplication_settings = deduplication_settings.model_dump()
  193. data = {
  194. "collection_id": str(collection_id) if collection_id else None,
  195. "run_type": str(run_type) if run_type else None,
  196. "deduplication_settings": deduplication_settings or {},
  197. }
  198. return await self._make_request( # type: ignore
  199. "POST", "deduplicate_entities", json=data
  200. )
  201. async def delete_graph_for_collection(
  202. self, collection_id: Union[UUID, str], cascade: bool = False
  203. ) -> dict:
  204. """
  205. Delete the graph for a given collection.
  206. Args:
  207. collection_id (Union[UUID, str]): The ID of the collection to delete the graph for.
  208. cascade (bool): Whether to cascade the deletion, and delete entities and relationships belonging to the collection.
  209. NOTE: Setting this flag to true will delete entities and relationships for documents that are shared across multiple collections. Do not set this flag unless you are absolutely sure that you want to delete the entities and relationships for all documents in the collection.
  210. """
  211. data = {
  212. "collection_id": str(collection_id),
  213. "cascade": cascade,
  214. }
  215. return await self._make_request("DELETE", "delete_graph_for_collection", json=data) # type: ignore