import asyncio
import json
import logging
import random
import time
from typing import Any, AsyncGenerator
from uuid import UUID, uuid4

from core.base import (
    AsyncPipe,
    AsyncState,
    Community,
    CompletionProvider,
    EmbeddingProvider,
    GenerationConfig,
)
from core.base.abstractions import Entity, Relationship

from ...database.postgres import PostgresDatabaseProvider

logger = logging.getLogger()


class GraphCommunitySummaryPipe(AsyncPipe):
    """
    Clusters entities and relationships into communities within the knowledge
    graph using the hierarchical Leiden algorithm, then summarizes each
    community with an LLM.
    """

    def __init__(
        self,
        database_provider: PostgresDatabaseProvider,
        llm_provider: CompletionProvider,
        embedding_provider: EmbeddingProvider,
        config: AsyncPipe.PipeConfig,
        *args,
        **kwargs,
    ):
        """
        Initializes the KG clustering pipe with the necessary components and
        configurations.
        """
        super().__init__(
            config=config
            or AsyncPipe.PipeConfig(name="graph_community_summary_pipe"),
        )
        self.database_provider = database_provider
        self.llm_provider = llm_provider
        self.embedding_provider = embedding_provider

    async def community_summary_prompt(
        self,
        entities: list[Entity],
        relationships: list[Relationship],
        max_summary_input_length: int,
    ):
        """
        Builds the LLM input text for a community from its entities and
        relationships.
        """
        # Group each entity's descriptions and relationships by entity name.
        entity_map: dict[str, dict[str, list[Any]]] = {}
        for entity in entities:
            if entity.name not in entity_map:
                entity_map[entity.name] = {"entities": [], "relationships": []}  # type: ignore
            entity_map[entity.name]["entities"].append(entity)  # type: ignore

        for relationship in relationships:
            if relationship.subject not in entity_map:
                entity_map[relationship.subject] = {  # type: ignore
                    "entities": [],
                    "relationships": [],
                }
            entity_map[relationship.subject]["relationships"].append(  # type: ignore
                relationship
            )

        # Sort in descending order of relationship count.
        sorted_entity_map = sorted(
            entity_map.items(),
            key=lambda x: len(x[1]["relationships"]),
            reverse=True,
        )
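        # Because entries are sorted by relationship count, the most densely
        # connected entities are rendered first and are the last to be lost
        # if the prompt is trimmed to max_summary_input_length below.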

        async def _get_entity_descriptions_string(
            entities: list, max_count: int = 100
        ):
            # Randomly sample up to max_count entities when there are too
            # many. This will become a map-reduce job later.
            sampled_entities = (
                random.sample(entities, max_count)
                if len(entities) > max_count
                else entities
            )
            return "\n".join(
                f"{entity.id},{entity.description}"
                for entity in sampled_entities
            )

        async def _get_relationships_string(
            relationships: list, max_count: int = 100
        ):
            sampled_relationships = (
                random.sample(relationships, max_count)
                if len(relationships) > max_count
                else relationships
            )
            return "\n".join(
                f"{relationship.id},{relationship.subject},{relationship.object},{relationship.predicate},{relationship.description}"
                for relationship in sampled_relationships
            )

        prompt = ""
        for entity_name, entity_data in sorted_entity_map:
            entity_descriptions = await _get_entity_descriptions_string(
                entity_data["entities"]
            )
            relationships_string = await _get_relationships_string(
                entity_data["relationships"]
            )

            prompt += f"""
            Entity: {entity_name}
            Descriptions:
                {entity_descriptions}
            Relationships:
                {relationships_string}
            """

            if len(prompt) > max_summary_input_length:
                logger.info(
                    f"Community summary prompt reached {len(prompt)} characters, trimming to {max_summary_input_length}."
                )
                prompt = prompt[:max_summary_input_length]
                break

        return prompt

    async def process_community(
        self,
        community_id: UUID,
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID,
        nodes: list[str],
        all_entities: list[Entity],
        all_relationships: list[Relationship],
    ) -> dict:
        """
        Processes a community by summarizing it, creating a summary embedding,
        and storing both in the database.
        """
        response = await self.database_provider.collections_handler.get_collections_overview(  # type: ignore
            offset=0,
            limit=1,
            filter_collection_ids=[collection_id],
        )
        collection_description = (
            response["results"][0].description if response["results"] else None  # type: ignore
        )

        # Restrict the graph to the entities and relationships that fall
        # inside this community's node set.
        entities = [entity for entity in all_entities if entity.name in nodes]
        relationships = [
            relationship
            for relationship in all_relationships
            if relationship.subject in nodes and relationship.object in nodes
        ]

        if not entities and not relationships:
            raise ValueError(
                f"Community {community_id} has no entities or relationships."
            )

        input_text = await self.community_summary_prompt(
            entities,
            relationships,
            max_summary_input_length,
        )
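
        # Up to three attempts: the LLM is prompted (via the
        # graphrag_communities prompt template) for a ```json-fenced object
        # containing name, summary, findings, rating, and rating_explanation;
        # a parsing failure triggers a retry.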
        for attempt in range(3):
            description = (
                (
                    await self.llm_provider.aget_completion(
                        messages=await self.database_provider.prompts_handler.get_message_payload(
                            task_prompt_name=self.database_provider.config.graph_enrichment_settings.graphrag_communities,
                            task_inputs={
                                "collection_description": collection_description,
                                "input_text": input_text,
                            },
                        ),
                        generation_config=generation_config,
                    )
                )
                .choices[0]
                .message.content
            )

            try:
                if description and description.startswith("```json"):
                    # Strip the Markdown code fence before parsing the JSON.
                    description = (
                        description.strip()
                        .removeprefix("```json")
                        .removesuffix("```")
                        .strip()
                    )
                else:
                    raise ValueError(
                        f"Failed to generate a summary for community {community_id}"
                    )

                description_dict = json.loads(description)
                name = description_dict["name"]
                summary = description_dict["summary"]
                findings = description_dict["findings"]
                rating = description_dict["rating"]
                rating_explanation = description_dict["rating_explanation"]
                break
            except Exception as e:
                if attempt == 2:
                    logger.error(
                        f"GraphCommunitySummaryPipe: Error generating community summary for community {community_id}: {e}"
                    )
                    return {
                        "community_id": community_id,
                        "error": str(e),
                    }

        community = Community(
            community_id=community_id,
            collection_id=collection_id,
            name=name,
            summary=summary,
            rating=rating,
            rating_explanation=rating_explanation,
            findings=findings,
            description_embedding=await self.embedding_provider.async_get_embedding(
                "Summary:\n"
                + summary
                + "\n\nFindings:\n"
                + "\n".join(findings)
            ),
        )

        await self.database_provider.graphs_handler.add_community(community)

        return {
            "community_id": community.community_id,
            "name": community.name,
        }
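
    # The flow below: load every entity and relationship in the collection,
    # cluster the relationship graph via the graphs handler (hierarchical
    # Leiden), then fan out one process_community job per cluster and await
    # them concurrently.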
    async def _run_logic(  # type: ignore
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: UUID,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[dict, None]:
        """
        Executes the KG community summary pipe: summarizing communities.
        """
        start_time = time.time()

        offset = input.message["offset"]
        limit = input.message["limit"]
        generation_config = input.message["generation_config"]
        max_summary_input_length = input.message["max_summary_input_length"]
        collection_id = input.message.get("collection_id", None)
        clustering_mode = input.message.get("clustering_mode", None)
        community_summary_jobs = []
        logger = input.message.get("logger", logging.getLogger())

        # Check which community summaries exist and don't run them again.
        logger.info(
            f"GraphCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}"
        )

        (
            all_entities,
            _,
        ) = await self.database_provider.graphs_handler.get_entities(
            parent_id=collection_id,
            offset=0,
            limit=-1,
            include_embeddings=False,
        )

        (
            all_relationships,
            _,
        ) = await self.database_provider.graphs_handler.get_relationships(
            parent_id=collection_id,
            offset=0,
            limit=-1,
            include_embeddings=False,
        )

        # Perform clustering.
        leiden_params = input.message.get("leiden_params", {})
        (
            _,
            community_clusters,
        ) = await self.database_provider.graphs_handler._cluster_and_add_community_info(
            relationships=all_relationships,
            relationship_ids_cache={},
            leiden_params=leiden_params,
            collection_id=collection_id,
            clustering_mode=clustering_mode,
        )

        # Organize clusters: remote clustering returns dicts, local clustering
        # returns objects with attributes.
        clusters: dict[Any, Any] = {}
        for item in community_clusters:
            cluster_id = (
                item["cluster"]
                if clustering_mode == "remote"
                else item.cluster
            )
            if cluster_id not in clusters:
                clusters[cluster_id] = []
            clusters[cluster_id].append(
                item["node"] if clustering_mode == "remote" else item.node
            )

        # Now, process the clusters.
        for _, nodes in clusters.items():
            community_summary_jobs.append(
                self.process_community(
                    community_id=uuid4(),
                    nodes=nodes,
                    all_entities=all_entities,
                    all_relationships=all_relationships,
                    max_summary_input_length=max_summary_input_length,
                    generation_config=generation_config,
                    collection_id=collection_id,
                )
            )

        total_jobs = len(community_summary_jobs)
        total_errors = 0
        completed_community_summary_jobs = 0
        for community_summary in asyncio.as_completed(community_summary_jobs):
            summary = await community_summary
            completed_community_summary_jobs += 1

            if completed_community_summary_jobs % 50 == 0:
                logger.info(
                    f"GraphCommunitySummaryPipe: {completed_community_summary_jobs}/{total_jobs} community summaries completed, elapsed time: {time.time() - start_time:.2f} seconds"
                )

            if "error" in summary:
                logger.error(
                    f"GraphCommunitySummaryPipe: Error generating community summary for community {summary['community_id']}: {summary['error']}"
                )
                total_errors += 1
                continue

            yield summary

        if total_errors > 0:
            raise ValueError(
                f"GraphCommunitySummaryPipe: Failed to generate community summaries for {total_errors} out of {total_jobs} communities. Please rerun the job if there are too many failures."
            )
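
# A minimal sketch of the message this pipe expects, inferred from the keys
# read in `_run_logic` above. The payload is assembled by the calling
# orchestration code, so the values shown here are illustrative only:
#
#     input = AsyncPipe.Input(
#         message={
#             "offset": 0,
#             "limit": 100,
#             "generation_config": GenerationConfig(model="openai/gpt-4o-mini"),
#             "max_summary_input_length": 65536,
#             "collection_id": collection_id,
#             "clustering_mode": "local",
#             "leiden_params": {},
#         }
#     )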