# pipe to extract nodes/relationships etc
import asyncio
import logging
import random
import time
from typing import Any, AsyncGenerator
from uuid import UUID

from core.base import AsyncState, CompletionProvider, EmbeddingProvider
from core.base.pipes.base_pipe import AsyncPipe

from ...database.postgres import PostgresDatabaseProvider
  11. logger = logging.getLogger()
  12. class GraphDescriptionPipe(AsyncPipe):
  13. """
  14. The pipe takes input a list of nodes and extracts description from them.
  15. """
  16. class Input(AsyncPipe.Input):
  17. message: dict[str, Any]
  18. def __init__(
  19. self,
  20. database_provider: PostgresDatabaseProvider,
  21. llm_provider: CompletionProvider,
  22. embedding_provider: EmbeddingProvider,
  23. config: AsyncPipe.PipeConfig,
  24. *args,
  25. **kwargs,
  26. ):
  27. super().__init__(
  28. config=config,
  29. )
  30. self.database_provider = database_provider
  31. self.llm_provider = llm_provider
  32. self.embedding_provider = embedding_provider
  33. async def _run_logic( # type: ignore
  34. self,
  35. input: AsyncPipe.Input,
  36. state: AsyncState,
  37. run_id: UUID,
  38. *args: Any,
  39. **kwargs: Any,
  40. ) -> AsyncGenerator[Any, None]:
  41. """
  42. Extracts description from the input.
  43. """
  44. start_time = time.time()
  45. def truncate_info(info_list, max_length):
  46. random.shuffle(info_list)
  47. truncated_info = ""
  48. current_length = 0
  49. for info in info_list:
  50. if current_length + len(info) > max_length:
  51. break
  52. truncated_info += info + "\n"
  53. current_length += len(info)
  54. return truncated_info
  55. async def process_entity(
  56. entities,
  57. relationships,
  58. max_description_input_length,
  59. document_id: UUID,
  60. ):
  61. response = await self.database_provider.documents_handler.get_documents_overview( # type: ignore
  62. offset=0,
  63. limit=1,
  64. filter_document_ids=[document_id],
  65. )
  66. document_summary = (
  67. response["results"][0].summary if response["results"] else None
  68. )
  69. entity_info = [
  70. f"{entity.name}, {entity.description}" for entity in entities
  71. ]
  72. relationships_txt = [
  73. f"{i+1}: {relationship.subject}, {relationship.object}, {relationship.predicate} - Summary: {relationship.description}"
  74. for i, relationship in enumerate(relationships)
  75. ]
  76. # potentially slow at scale, but set to avoid duplicates
  77. unique_chunk_ids = set()
  78. for entity in entities:
  79. for chunk_id in entity.chunk_ids:
  80. unique_chunk_ids.add(chunk_id)
  81. out_entity = entities[0]
  82. if not out_entity.description:
  83. out_entity.description = (
  84. (
  85. await self.llm_provider.aget_completion(
  86. messages=await self.database_provider.prompts_handler.get_message_payload(
  87. task_prompt_name=self.database_provider.config.graph_creation_settings.graph_entity_description_prompt,
  88. task_inputs={
  89. "document_summary": document_summary,
  90. "entity_info": truncate_info(
  91. entity_info,
  92. max_description_input_length,
  93. ),
  94. "relationships_txt": truncate_info(
  95. relationships_txt,
  96. max_description_input_length,
  97. ),
  98. },
  99. ),
  100. generation_config=self.database_provider.config.graph_creation_settings.generation_config,
  101. )
  102. )
  103. .choices[0]
  104. .message.content
  105. )
  106. if not out_entity.description:
  107. logger.error(
  108. f"No description for entity {out_entity.name}"
  109. )
  110. return out_entity.name
  111. out_entity.description_embedding = (
  112. await self.embedding_provider.async_get_embeddings(
  113. [out_entity.description]
  114. )
  115. )[0]
  116. # upsert the entity and its embedding
  117. await self.database_provider.graphs_handler.add_entities(
  118. [out_entity],
  119. table_name="documents_entities",
  120. )
  121. return out_entity.name
  122. offset = input.message["offset"]
  123. limit = input.message["limit"]
  124. document_id = input.message["document_id"]
  125. logger = input.message["logger"]
  126. logger.info(
  127. f"GraphDescriptionPipe: Getting entity map for document {document_id}",
  128. )
  129. entity_map = (
  130. await self.database_provider.graphs_handler.get_entity_map(
  131. offset, limit, document_id
  132. )
  133. )
  134. total_entities = len(entity_map)
  135. logger.info(
  136. f"GraphDescriptionPipe: Got entity map for document {document_id}, total entities: {total_entities}, time from start: {time.time() - start_time:.2f} seconds",
  137. )
  138. workflows = []
  139. for _, (entity_name, entity_info) in enumerate(entity_map.items()):
  140. try:
  141. workflows.append(
  142. process_entity(
  143. entities=entity_info["entities"],
  144. relationships=entity_info["relationships"],
  145. max_description_input_length=input.message[
  146. "max_description_input_length"
  147. ],
  148. document_id=document_id,
  149. )
  150. )
  151. except Exception as e:
  152. logger.error(f"Error processing entity {entity_name}: {e}")
  153. completed_entities = 0
  154. for result in asyncio.as_completed(workflows):
  155. if completed_entities % 100 == 0:
  156. logger.info(
  157. f"GraphDescriptionPipe: Completed {completed_entities+1} of {total_entities} entities for document {document_id}",
  158. )
  159. yield await result
  160. completed_entities += 1
  161. logger.info(
  162. f"GraphDescriptionPipe: Processed {total_entities} entities for document {document_id}, time from start: {time.time() - start_time:.2f} seconds",
  163. )