# deduplication_summary.py
  1. import asyncio
  2. import logging
  3. from typing import Any, Optional, Union
  4. from uuid import UUID
  5. from core.base import AsyncState
  6. from core.base.abstractions import Entity, GenerationConfig
  7. from core.base.pipes import AsyncPipe
  8. from core.database import PostgresDatabaseProvider
  9. from core.providers import ( # PostgresDatabaseProvider,
  10. LiteLLMCompletionProvider,
  11. LiteLLMEmbeddingProvider,
  12. OllamaEmbeddingProvider,
  13. OpenAICompletionProvider,
  14. OpenAIEmbeddingProvider,
  15. )
# Module-level logger.
# NOTE(review): this attaches to the root logger; `logging.getLogger(__name__)`
# is the usual convention — confirm the root logger is intended here.
logger = logging.getLogger()
  17. class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
  18. class Input(AsyncPipe.Input):
  19. message: dict
  20. def __init__(
  21. self,
  22. database_provider: PostgresDatabaseProvider,
  23. llm_provider: Union[
  24. LiteLLMCompletionProvider, OpenAICompletionProvider
  25. ],
  26. embedding_provider: Union[
  27. LiteLLMEmbeddingProvider,
  28. OpenAIEmbeddingProvider,
  29. OllamaEmbeddingProvider,
  30. ],
  31. config: AsyncPipe.PipeConfig,
  32. **kwargs,
  33. ):
  34. super().__init__(config=config, **kwargs)
  35. self.database_provider = database_provider
  36. self.llm_provider = llm_provider
  37. self.embedding_provider = embedding_provider
  38. async def _merge_entity_descriptions_llm_prompt(
  39. self,
  40. entity_name: str,
  41. entity_descriptions: list[str],
  42. generation_config: GenerationConfig,
  43. ) -> Entity:
  44. # find the index until the length is less than 1024
  45. index = 0
  46. description_length = 0
  47. while index < len(entity_descriptions) and not (
  48. len(entity_descriptions[index]) + description_length
  49. > self.database_provider.config.graph_entity_deduplication_settings.max_description_input_length
  50. ):
  51. description_length += len(entity_descriptions[index])
  52. index += 1
  53. completion = await self.llm_provider.aget_completion(
  54. messages=await self.database_provider.prompts_handler.get_message_payload(
  55. task_prompt_name=self.database_provider.config.graph_entity_deduplication_settings.graph_entity_deduplication_prompt,
  56. task_inputs={
  57. "entity_name": entity_name,
  58. "entity_descriptions": "\n".join(
  59. entity_descriptions[:index]
  60. ),
  61. },
  62. ),
  63. generation_config=GenerationConfig(**generation_config), # type: ignore
  64. )
  65. # get the $$description$$
  66. try:
  67. description = completion.choices[0].message.content or ""
  68. description = description.split("$$")[1]
  69. except:
  70. logger.error(
  71. f"Failed to generate a summary for entity {entity_name}."
  72. )
  73. return Entity(name=entity_name, description=description)
  74. async def _merge_entity_descriptions(
  75. self,
  76. entity_name: str,
  77. entity_descriptions: list[str],
  78. generation_config: GenerationConfig,
  79. ) -> Entity:
  80. # TODO: Expose this as a hyperparameter
  81. if len(entity_descriptions) <= 5:
  82. return Entity(
  83. name=entity_name, description="\n".join(entity_descriptions)
  84. )
  85. else:
  86. return await self._merge_entity_descriptions_llm_prompt(
  87. entity_name, entity_descriptions, generation_config
  88. )
  89. async def _prepare_and_upsert_entities(
  90. self, entities_batch: list[Entity], graph_id: UUID
  91. ) -> Any:
  92. embeddings = await self.embedding_provider.async_get_embeddings(
  93. [entity.description or "" for entity in entities_batch]
  94. )
  95. for i, entity in enumerate(entities_batch):
  96. entity.description_embedding = str(embeddings[i]) # type: ignore
  97. entity.graph_id = graph_id
  98. logger.info(
  99. f"Upserting {len(entities_batch)} entities for graph {graph_id}"
  100. )
  101. await self.database_provider.graphs_handler.update_entity_descriptions(
  102. entities_batch
  103. )
  104. logger.info(
  105. f"Upserted {len(entities_batch)} entities for graph {graph_id}"
  106. )
  107. for entity in entities_batch:
  108. yield entity
  109. async def _get_entities(
  110. self,
  111. graph_id: Optional[UUID],
  112. collection_id: Optional[UUID],
  113. offset: int,
  114. limit: int,
  115. level,
  116. ):
  117. if graph_id is not None:
  118. return await self.database_provider.graphs_handler.entities.get(
  119. parent_id=graph_id,
  120. offset=offset,
  121. limit=limit,
  122. level=level,
  123. )
  124. elif collection_id is not None:
  125. return await self.database_provider.graphs_handler.get_entities(
  126. parent_id=collection_id,
  127. offset=offset,
  128. limit=limit,
  129. )
  130. else:
  131. raise ValueError(
  132. "Either graph_id or collection_id must be provided"
  133. )
  134. async def _run_logic(
  135. self,
  136. input: AsyncPipe.Input,
  137. state: AsyncState,
  138. run_id: UUID,
  139. *args: Any,
  140. **kwargs: Any,
  141. ):
  142. # TODO: figure out why the return type AsyncGenerator[dict, None] is not working
  143. graph_id = input.message.get("graph_id", None)
  144. collection_id = input.message.get("collection_id", None)
  145. offset = input.message["offset"]
  146. limit = input.message["limit"]
  147. graph_entity_deduplication_type = input.message[
  148. "graph_entity_deduplication_type"
  149. ]
  150. graph_entity_deduplication_prompt = input.message[
  151. "graph_entity_deduplication_prompt"
  152. ]
  153. generation_config = input.message["generation_config"]
  154. logger.info(
  155. f"Running kg_entity_deduplication_summary for graph {graph_id} with settings graph_entity_deduplication_type: {graph_entity_deduplication_type}, graph_entity_deduplication_prompt: {graph_entity_deduplication_prompt}, generation_config: {generation_config}"
  156. )
  157. entities = await self._get_entities(
  158. graph_id,
  159. collection_id,
  160. offset,
  161. limit, # type: ignore
  162. )
  163. entity_names = [entity.name for entity in entities]
  164. entity_descriptions = (
  165. await self.database_provider.graphs_handler.get_entities(
  166. parent_id=collection_id,
  167. entity_names=entity_names,
  168. offset=offset,
  169. limit=limit,
  170. )
  171. )["entities"]
  172. entity_descriptions_dict: dict[str, list[str]] = {}
  173. for entity_description in entity_descriptions:
  174. if entity_description.name not in entity_descriptions_dict:
  175. entity_descriptions_dict[entity_description.name] = []
  176. entity_descriptions_dict[entity_description.name].append(
  177. entity_description.description
  178. )
  179. logger.info(
  180. f"Retrieved {len(entity_descriptions)} entity descriptions for graph {graph_id}"
  181. )
  182. tasks = []
  183. entities_batch = []
  184. for entity in entities:
  185. tasks.append(
  186. self._merge_entity_descriptions(
  187. entity.name,
  188. entity_descriptions_dict[entity.name],
  189. generation_config,
  190. )
  191. )
  192. if len(tasks) == 32:
  193. entities_batch = await asyncio.gather(*tasks)
  194. # prepare and upsert entities
  195. async for result in self._prepare_and_upsert_entities(
  196. entities_batch, graph_id
  197. ):
  198. yield result
  199. tasks = []
  200. if tasks:
  201. entities_batch = await asyncio.gather(*tasks)
  202. for entity in entities_batch:
  203. yield entity
  204. # prepare and upsert entities
  205. async for result in self._prepare_and_upsert_entities(
  206. entities_batch, graph_id
  207. ):
  208. yield result