# Pipe that generates descriptions for graph entities extracted from documents.
import asyncio
import logging
import random
import time
from typing import Any, AsyncGenerator
from uuid import UUID

from core.base import AsyncState, CompletionProvider, EmbeddingProvider
from core.base.pipes.base_pipe import AsyncPipe

from ...database.postgres import PostgresDatabaseProvider

logger = logging.getLogger()


class GraphDescriptionPipe(AsyncPipe):
    """
    Takes a map of entities (and their relationships) as input and generates
    a description for each entity.
    """

    class Input(AsyncPipe.Input):
        message: dict[str, Any]

    def __init__(
        self,
        database_provider: PostgresDatabaseProvider,
        llm_provider: CompletionProvider,
        embedding_provider: EmbeddingProvider,
        config: AsyncPipe.PipeConfig,
        *args,
        **kwargs,
    ):
        super().__init__(
            config=config,
        )
        self.database_provider = database_provider
        self.llm_provider = llm_provider
        self.embedding_provider = embedding_provider

    async def _run_logic(  # type: ignore
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[Any, None]:
        """
        Generates, embeds, and upserts a description for each entity in the
        requested window, yielding entity names as they complete.
        """
        start_time = time.time()
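
        # Helper that keeps a prompt input within the character budget.
        # Shuffling first means truncation drops a random subset of facts
        # rather than always the same tail; note it mutates info_list in
        # place.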
        def truncate_info(info_list, max_length):
            random.shuffle(info_list)
            truncated_info = ""
            current_length = 0
            for info in info_list:
                if current_length + len(info) > max_length:
                    break
                truncated_info += info + "\n"
                current_length += len(info)

            return truncated_info
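
        # Describes one deduplicated entity: all extractions of the same
        # entity arrive together; an LLM description is generated if one is
        # missing, then the description is embedded and upserted.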
        async def process_entity(
            entities,
            relationships,
            max_description_input_length,
            document_id: UUID,
        ):
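            # Fetch the parent document's summary so the description prompt
            # is grounded in document-level context.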
            response = await self.database_provider.documents_handler.get_documents_overview(  # type: ignore
                offset=0,
                limit=1,
                filter_document_ids=[document_id],
            )
            document_summary = (
                response["results"][0].summary if response["results"] else None
            )

            entity_info = [
                f"{entity.name}, {entity.description}" for entity in entities
            ]
            relationships_txt = [
                f"{i + 1}: {relationship.subject}, {relationship.object}, {relationship.predicate} - Summary: {relationship.description}"
                for i, relationship in enumerate(relationships)
            ]

            # Collect the distinct chunk ids backing this entity; a set
            # avoids duplicates, though this may be slow at scale. (The set
            # is not used further in this function.)
            unique_chunk_ids = set()
            for entity in entities:
                for chunk_id in entity.chunk_ids:
                    unique_chunk_ids.add(chunk_id)

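            # Duplicates of an entity share one description, so generate it
            # once on the first occurrence (if it does not already have one).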
            out_entity = entities[0]
            if not out_entity.description:
                completion = await self.llm_provider.aget_completion(
                    messages=await self.database_provider.prompts_handler.get_message_payload(
                        task_prompt_name=self.database_provider.config.graph_creation_settings.graph_entity_description_prompt,
                        task_inputs={
                            "document_summary": document_summary,
                            "entity_info": truncate_info(
                                entity_info,
                                max_description_input_length,
                            ),
                            "relationships_txt": truncate_info(
                                relationships_txt,
                                max_description_input_length,
                            ),
                        },
                    ),
                    generation_config=self.database_provider.config.graph_creation_settings.generation_config,
                )
                out_entity.description = completion.choices[0].message.content

                if not out_entity.description:
                    logger.error(
                        f"No description for entity {out_entity.name}"
                    )
                    return out_entity.name

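            # Compute an embedding for the final description (one embedding
            # request per entity).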
            out_entity.description_embedding = (
                await self.embedding_provider.async_get_embeddings(
                    [out_entity.description]
                )
            )[0]

            # Upsert the entity together with its description embedding.
            await self.database_provider.graphs_handler.add_entities(
                [out_entity],
                table_name="documents_entities",
            )

            return out_entity.name

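        # Unpack the work unit. Note that `logger` below shadows the
        # module-level logger with the run-scoped logger supplied in the
        # message; process_entity above picks this one up via closure.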
        offset = input.message["offset"]
        limit = input.message["limit"]
        document_id = input.message["document_id"]
        logger = input.message["logger"]

        logger.info(
            f"GraphDescriptionPipe: Getting entity map for document {document_id}",
        )

        entity_map = (
            await self.database_provider.graphs_handler.get_entity_map(
                offset, limit, document_id
            )
        )
        total_entities = len(entity_map)

        logger.info(
            f"GraphDescriptionPipe: Got entity map for document {document_id}, total entities: {total_entities}, time from start: {time.time() - start_time:.2f} seconds",
        )
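
        # Build one coroutine per entity. Calling process_entity here only
        # creates the coroutine; nothing runs (and no exception can be
        # raised) until it is awaited in the as_completed loop below, which
        # is why error handling lives there rather than here.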
        workflows = []
        for entity_info in entity_map.values():
            workflows.append(
                process_entity(
                    entities=entity_info["entities"],
                    relationships=entity_info["relationships"],
                    max_description_input_length=input.message[
                        "max_description_input_length"
                    ],
                    document_id=document_id,
                )
            )
        completed_entities = 0
        for result in asyncio.as_completed(workflows):
            try:
                yield await result
            except Exception as e:
                logger.error(
                    f"Error processing entity for document {document_id}: {e}"
                )
            completed_entities += 1
            if completed_entities % 100 == 0:
                logger.info(
                    f"GraphDescriptionPipe: Completed {completed_entities} of {total_entities} entities for document {document_id}",
                )

        logger.info(
            f"GraphDescriptionPipe: Processed {total_entities} entities for document {document_id}, time from start: {time.time() - start_time:.2f} seconds",
        )
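

# Example usage (a minimal sketch, not part of the pipe itself): the provider
# objects, `document_id`, and `run_id` below are assumed to come from the
# surrounding application, the message values are illustrative, and the exact
# PipeConfig fields may differ by version.
#
#     pipe = GraphDescriptionPipe(
#         database_provider=database_provider,
#         llm_provider=llm_provider,
#         embedding_provider=embedding_provider,
#         config=AsyncPipe.PipeConfig(name="graph_description_pipe"),
#     )
#     async for entity_name in pipe._run_logic(
#         input=GraphDescriptionPipe.Input(
#             message={
#                 "offset": 0,
#                 "limit": 100,
#                 "document_id": document_id,
#                 "logger": logging.getLogger(__name__),
#                 "max_description_input_length": 65536,
#             }
#         ),
#         state=AsyncState(),
#         run_id=run_id,
#     ):
#         print(entity_name)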