- import asyncio
- import logging
- from typing import Any, Optional, Union
- from uuid import UUID
- from core.base import AsyncState
- from core.base.abstractions import Entity, GenerationConfig
- from core.base.pipes import AsyncPipe
- from core.database import PostgresDatabaseProvider
- from core.providers import ( # PostgresDatabaseProvider,
- LiteLLMCompletionProvider,
- LiteLLMEmbeddingProvider,
- OllamaEmbeddingProvider,
- OpenAICompletionProvider,
- OpenAIEmbeddingProvider,
- )
# Module-level logger; using __name__ (instead of the root logger) lets log
# configuration target this module specifically and matches stdlib convention.
logger = logging.getLogger(__name__)
class GraphDeduplicationSummaryPipe(AsyncPipe[Any]):
    """Pipe that merges duplicated graph-entity descriptions.

    For every entity surfaced by the deduplication step, all stored
    descriptions sharing that entity name are collected and merged — by
    simple concatenation when there are few, or by an LLM summary when
    there are many — then re-embedded and upserted back into the graph.
    """

    # Number of merge tasks gathered before each embed/upsert round-trip.
    _BATCH_SIZE = 32

    class Input(AsyncPipe.Input):
        # Expected keys: graph_id or collection_id, offset, limit,
        # graph_entity_deduplication_type, graph_entity_deduplication_prompt,
        # generation_config.
        message: dict

    def __init__(
        self,
        database_provider: PostgresDatabaseProvider,
        llm_provider: Union[
            LiteLLMCompletionProvider, OpenAICompletionProvider
        ],
        embedding_provider: Union[
            LiteLLMEmbeddingProvider,
            OpenAIEmbeddingProvider,
            OllamaEmbeddingProvider,
        ],
        config: AsyncPipe.PipeConfig,
        **kwargs,
    ):
        """Store the database, LLM, and embedding providers on the pipe."""
        super().__init__(config=config, **kwargs)
        self.database_provider = database_provider
        self.llm_provider = llm_provider
        self.embedding_provider = embedding_provider

    async def _merge_entity_descriptions_llm_prompt(
        self,
        entity_name: str,
        entity_descriptions: list[str],
        generation_config: GenerationConfig,
    ) -> Entity:
        """Summarize multiple descriptions of one entity via the LLM.

        Only a greedy prefix of ``entity_descriptions`` whose combined
        length fits the configured ``max_description_input_length`` is sent
        to the model.
        """
        settings = (
            self.database_provider.config.graph_entity_deduplication_settings
        )
        max_input_length = settings.max_description_input_length

        # Take descriptions greedily until the length budget is exhausted.
        index = 0
        description_length = 0
        while (
            index < len(entity_descriptions)
            and description_length + len(entity_descriptions[index])
            <= max_input_length
        ):
            description_length += len(entity_descriptions[index])
            index += 1

        completion = await self.llm_provider.aget_completion(
            messages=await self.database_provider.prompts_handler.get_message_payload(
                task_prompt_name=settings.graph_entity_deduplication_prompt,
                task_inputs={
                    "entity_name": entity_name,
                    "entity_descriptions": "\n".join(
                        entity_descriptions[:index]
                    ),
                },
            ),
            # generation_config arrives as a dict despite the annotation.
            generation_config=GenerationConfig(**generation_config),  # type: ignore
        )

        # The prompt asks the model to wrap the summary in $$...$$ markers.
        # If the markers are absent we fall back to the raw completion text
        # (or ""). Previously a bare `except:` combined with a failure on the
        # first assignment left `description` unbound and the `return` below
        # raised UnboundLocalError.
        description = ""
        try:
            description = completion.choices[0].message.content or ""
            description = description.split("$$")[1]
        except Exception:
            logger.error(
                f"Failed to generate a summary for entity {entity_name}."
            )
        return Entity(name=entity_name, description=description)

    async def _merge_entity_descriptions(
        self,
        entity_name: str,
        entity_descriptions: list[str],
        generation_config: GenerationConfig,
    ) -> Entity:
        """Merge duplicate descriptions, using the LLM only when needed.

        Few descriptions are simply concatenated; larger sets are
        summarized by the LLM.
        """
        # TODO: Expose this threshold as a hyperparameter.
        if len(entity_descriptions) <= 5:
            return Entity(
                name=entity_name, description="\n".join(entity_descriptions)
            )
        return await self._merge_entity_descriptions_llm_prompt(
            entity_name, entity_descriptions, generation_config
        )

    async def _prepare_and_upsert_entities(
        self, entities_batch: list[Entity], graph_id: UUID
    ) -> Any:
        """Embed the merged descriptions, upsert the batch, and yield each entity."""
        embeddings = await self.embedding_provider.async_get_embeddings(
            [entity.description or "" for entity in entities_batch]
        )
        for entity, embedding in zip(entities_batch, embeddings):
            entity.description_embedding = str(embedding)  # type: ignore
            entity.graph_id = graph_id

        logger.info(
            f"Upserting {len(entities_batch)} entities for graph {graph_id}"
        )
        await self.database_provider.graphs_handler.update_entity_descriptions(
            entities_batch
        )
        logger.info(
            f"Upserted {len(entities_batch)} entities for graph {graph_id}"
        )
        for entity in entities_batch:
            yield entity

    async def _get_entities(
        self,
        graph_id: Optional[UUID],
        collection_id: Optional[UUID],
        offset: int,
        limit: int,
        level=None,
    ):
        """Fetch a page of entities for either a graph or a collection.

        ``level`` now defaults to ``None``: the only call site in
        ``_run_logic`` passes four positional arguments, which previously
        raised TypeError because ``level`` was required.

        Raises:
            ValueError: if neither ``graph_id`` nor ``collection_id`` is given.
        """
        if graph_id is not None:
            return await self.database_provider.graphs_handler.entities.get(
                parent_id=graph_id,
                offset=offset,
                limit=limit,
                level=level,
            )
        if collection_id is not None:
            return await self.database_provider.graphs_handler.get_entities(
                parent_id=collection_id,
                offset=offset,
                limit=limit,
            )
        raise ValueError("Either graph_id or collection_id must be provided")

    async def _run_logic(
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: UUID,
        *args: Any,
        **kwargs: Any,
    ):
        """Deduplicate entity descriptions in batches and yield merged entities."""
        # TODO: figure out why the return type AsyncGenerator[dict, None]
        # is not working.
        graph_id = input.message.get("graph_id", None)
        collection_id = input.message.get("collection_id", None)
        offset = input.message["offset"]
        limit = input.message["limit"]
        graph_entity_deduplication_type = input.message[
            "graph_entity_deduplication_type"
        ]
        graph_entity_deduplication_prompt = input.message[
            "graph_entity_deduplication_prompt"
        ]
        generation_config = input.message["generation_config"]

        logger.info(
            f"Running kg_entity_deduplication_summary for graph {graph_id} with settings graph_entity_deduplication_type: {graph_entity_deduplication_type}, graph_entity_deduplication_prompt: {graph_entity_deduplication_prompt}, generation_config: {generation_config}"
        )

        entities = await self._get_entities(
            graph_id,
            collection_id,
            offset,
            limit,  # type: ignore
        )
        entity_names = [entity.name for entity in entities]

        entity_descriptions = (
            await self.database_provider.graphs_handler.get_entities(
                parent_id=collection_id,
                entity_names=entity_names,
                offset=offset,
                limit=limit,
            )
        )["entities"]

        # Group every stored description under its entity name.
        entity_descriptions_dict: dict[str, list[str]] = {}
        for entity_description in entity_descriptions:
            entity_descriptions_dict.setdefault(
                entity_description.name, []
            ).append(entity_description.description)

        logger.info(
            f"Retrieved {len(entity_descriptions)} entity descriptions for graph {graph_id}"
        )

        tasks = []
        for entity in entities:
            tasks.append(
                self._merge_entity_descriptions(
                    entity.name,
                    entity_descriptions_dict[entity.name],
                    generation_config,
                )
            )
            if len(tasks) == self._BATCH_SIZE:
                entities_batch = await asyncio.gather(*tasks)
                async for result in self._prepare_and_upsert_entities(
                    entities_batch, graph_id
                ):
                    yield result
                tasks = []

        # Flush the final partial batch. Previously these entities were
        # yielded twice — once raw (before their embeddings were set) and
        # once again after upserting — unlike the full-batch path above.
        if tasks:
            entities_batch = await asyncio.gather(*tasks)
            async for result in self._prepare_and_upsert_entities(
                entities_batch, graph_id
            ):
                yield result