# community_summary.py
import asyncio
import json
import logging
import random
import time
from collections import defaultdict
from typing import Any, AsyncGenerator
from uuid import UUID, uuid4

from core.base import (
    AsyncPipe,
    AsyncState,
    Community,
    CompletionProvider,
    EmbeddingProvider,
    GenerationConfig,
)
from core.base.abstractions import Entity, Relationship

from ...database.postgres import PostgresDatabaseProvider
  18. logger = logging.getLogger()
  19. class GraphCommunitySummaryPipe(AsyncPipe):
  20. """
  21. Clusters entities and relationships into communities within the knowledge graph using hierarchical Leiden algorithm.
  22. """
  23. def __init__(
  24. self,
  25. database_provider: PostgresDatabaseProvider,
  26. llm_provider: CompletionProvider,
  27. embedding_provider: EmbeddingProvider,
  28. config: AsyncPipe.PipeConfig,
  29. *args,
  30. **kwargs,
  31. ):
  32. """
  33. Initializes the KG clustering pipe with necessary components and configurations.
  34. """
  35. super().__init__(
  36. config=config
  37. or AsyncPipe.PipeConfig(name="graph_community_summary_pipe"),
  38. )
  39. self.database_provider = database_provider
  40. self.llm_provider = llm_provider
  41. self.embedding_provider = embedding_provider
  42. async def community_summary_prompt(
  43. self,
  44. entities: list[Entity],
  45. relationships: list[Relationship],
  46. max_summary_input_length: int,
  47. ):
  48. entity_map: dict[str, dict[str, list[Any]]] = {}
  49. for entity in entities:
  50. if not entity.name in entity_map:
  51. entity_map[entity.name] = {"entities": [], "relationships": []} # type: ignore
  52. entity_map[entity.name]["entities"].append(entity) # type: ignore
  53. for relationship in relationships:
  54. if not relationship.subject in entity_map:
  55. entity_map[relationship.subject] = { # type: ignore
  56. "entities": [],
  57. "relationships": [],
  58. }
  59. entity_map[relationship.subject]["relationships"].append( # type: ignore
  60. relationship
  61. )
  62. # sort in descending order of relationship count
  63. sorted_entity_map = sorted(
  64. entity_map.items(),
  65. key=lambda x: len(x[1]["relationships"]),
  66. reverse=True,
  67. )
  68. async def _get_entity_descriptions_string(
  69. entities: list, max_count: int = 100
  70. ):
  71. # randomly sample max_count entities if there are duplicates. This will become a map reduce job later.
  72. sampled_entities = (
  73. random.sample(entities, max_count)
  74. if len(entities) > max_count
  75. else entities
  76. )
  77. return "\n".join(
  78. f"{entity.id},{entity.description}"
  79. for entity in sampled_entities
  80. )
  81. async def _get_relationships_string(
  82. relationships: list, max_count: int = 100
  83. ):
  84. sampled_relationships = (
  85. random.sample(relationships, max_count)
  86. if len(relationships) > max_count
  87. else relationships
  88. )
  89. return "\n".join(
  90. f"{relationship.id},{relationship.subject},{relationship.object},{relationship.predicate},{relationship.description}"
  91. for relationship in sampled_relationships
  92. )
  93. prompt = ""
  94. for entity_name, entity_data in sorted_entity_map:
  95. entity_descriptions = await _get_entity_descriptions_string(
  96. entity_data["entities"]
  97. )
  98. relationships = await _get_relationships_string(
  99. entity_data["relationships"]
  100. )
  101. prompt += f"""
  102. Entity: {entity_name}
  103. Descriptions:
  104. {entity_descriptions}
  105. Relationships:
  106. {relationships}
  107. """
  108. if len(prompt) > max_summary_input_length:
  109. logger.info(
  110. f"Community summary prompt was created of length {len(prompt)}, trimming to {max_summary_input_length} characters."
  111. )
  112. # open a file and write the prompt to it
  113. prompt = prompt[:max_summary_input_length]
  114. break
  115. return prompt
  116. async def process_community(
  117. self,
  118. community_id: UUID,
  119. max_summary_input_length: int,
  120. generation_config: GenerationConfig,
  121. collection_id: UUID,
  122. nodes: list[str],
  123. all_entities: list[Entity],
  124. all_relationships: list[Relationship],
  125. ) -> dict:
  126. """
  127. Process a community by summarizing it and creating a summary embedding and storing it to a database.
  128. """
  129. response = await self.database_provider.collections_handler.get_collections_overview( # type: ignore
  130. offset=0,
  131. limit=1,
  132. filter_collection_ids=[collection_id],
  133. )
  134. collection_description = (
  135. response["results"][0].description if response["results"] else None # type: ignore
  136. )
  137. entities = [entity for entity in all_entities if entity.name in nodes]
  138. relationships = [
  139. relationship
  140. for relationship in all_relationships
  141. if relationship.subject in nodes and relationship.object in nodes
  142. ]
  143. if not entities and not relationships:
  144. raise ValueError(
  145. f"Community {community_id} has no entities or relationships."
  146. )
  147. input_text = await self.community_summary_prompt(
  148. entities,
  149. relationships,
  150. max_summary_input_length,
  151. )
  152. for attempt in range(3):
  153. description = (
  154. (
  155. await self.llm_provider.aget_completion(
  156. messages=await self.database_provider.prompts_handler.get_message_payload(
  157. task_prompt_name=self.database_provider.config.graph_enrichment_settings.graphrag_communities,
  158. task_inputs={
  159. "collection_description": collection_description,
  160. "input_text": input_text,
  161. },
  162. ),
  163. generation_config=generation_config,
  164. )
  165. )
  166. .choices[0]
  167. .message.content
  168. )
  169. try:
  170. if description and description.startswith("```json"):
  171. description = (
  172. description.strip("```json").strip("```").strip()
  173. )
  174. else:
  175. raise ValueError(
  176. f"Failed to generate a summary for community {community_id}"
  177. )
  178. description_dict = json.loads(description)
  179. name = description_dict["name"]
  180. summary = description_dict["summary"]
  181. findings = description_dict["findings"]
  182. rating = description_dict["rating"]
  183. rating_explanation = description_dict["rating_explanation"]
  184. break
  185. except Exception as e:
  186. if attempt == 2:
  187. logger.error(
  188. f"GraphCommunitySummaryPipe: Error generating community summary for community {community_id}: {e}"
  189. )
  190. return {
  191. "community_id": community_id,
  192. "error": str(e),
  193. }
  194. community = Community(
  195. community_id=community_id,
  196. collection_id=collection_id,
  197. name=name,
  198. summary=summary,
  199. rating=rating,
  200. rating_explanation=rating_explanation,
  201. findings=findings,
  202. description_embedding=await self.embedding_provider.async_get_embedding(
  203. "Summary:\n"
  204. + summary
  205. + "\n\nFindings:\n"
  206. + "\n".join(findings)
  207. ),
  208. )
  209. await self.database_provider.graphs_handler.add_community(community)
  210. return {
  211. "community_id": community.community_id,
  212. "name": community.name,
  213. }
  214. async def _run_logic( # type: ignore
  215. self,
  216. input: AsyncPipe.Input,
  217. state: AsyncState,
  218. run_id: UUID,
  219. *args: Any,
  220. **kwargs: Any,
  221. ) -> AsyncGenerator[dict, None]:
  222. """
  223. Executes the KG community summary pipe: summarizing communities.
  224. """
  225. start_time = time.time()
  226. offset = input.message["offset"]
  227. limit = input.message["limit"]
  228. generation_config = input.message["generation_config"]
  229. max_summary_input_length = input.message["max_summary_input_length"]
  230. collection_id = input.message.get("collection_id", None)
  231. clustering_mode = input.message.get("clustering_mode", None)
  232. community_summary_jobs = []
  233. logger = input.message.get("logger", logging.getLogger())
  234. # check which community summaries exist and don't run them again
  235. logger.info(
  236. f"GraphCommunitySummaryPipe: Checking if community summaries exist for communities {offset} to {offset + limit}"
  237. )
  238. (
  239. all_entities,
  240. _,
  241. ) = await self.database_provider.graphs_handler.get_entities(
  242. parent_id=collection_id,
  243. offset=0,
  244. limit=-1,
  245. include_embeddings=False,
  246. )
  247. (
  248. all_relationships,
  249. _,
  250. ) = await self.database_provider.graphs_handler.get_relationships(
  251. parent_id=collection_id,
  252. offset=0,
  253. limit=-1,
  254. include_embeddings=False,
  255. )
  256. # Perform clustering
  257. leiden_params = input.message.get("leiden_params", {})
  258. (
  259. _,
  260. community_clusters,
  261. ) = await self.database_provider.graphs_handler._cluster_and_add_community_info(
  262. relationships=all_relationships,
  263. relationship_ids_cache={},
  264. leiden_params=leiden_params,
  265. collection_id=collection_id,
  266. clustering_mode=clustering_mode,
  267. )
  268. # Organize clusters
  269. clusters: dict[Any, Any] = {}
  270. for item in community_clusters:
  271. cluster_id = (
  272. item["cluster"]
  273. if clustering_mode == "remote"
  274. else item.cluster
  275. )
  276. if cluster_id not in clusters:
  277. clusters[cluster_id] = []
  278. clusters[cluster_id].append(
  279. item["node"] if clustering_mode == "remote" else item.node
  280. )
  281. # Now, process the clusters
  282. for _, nodes in clusters.items():
  283. community_summary_jobs.append(
  284. self.process_community(
  285. community_id=uuid4(),
  286. nodes=nodes,
  287. all_entities=all_entities,
  288. all_relationships=all_relationships,
  289. max_summary_input_length=max_summary_input_length,
  290. generation_config=generation_config,
  291. collection_id=collection_id,
  292. )
  293. )
  294. total_jobs = len(community_summary_jobs)
  295. total_errors = 0
  296. completed_community_summary_jobs = 0
  297. for community_summary in asyncio.as_completed(community_summary_jobs):
  298. summary = await community_summary
  299. completed_community_summary_jobs += 1
  300. if completed_community_summary_jobs % 50 == 0:
  301. logger.info(
  302. f"GraphCommunitySummaryPipe: {completed_community_summary_jobs}/{total_jobs} community summaries completed, elapsed time: {time.time() - start_time:.2f} seconds"
  303. )
  304. if "error" in summary:
  305. logger.error(
  306. f"GraphCommunitySummaryPipe: Error generating community summary for community {summary['community_id']}: {summary['error']}"
  307. )
  308. total_errors += 1
  309. continue
  310. yield summary
  311. if total_errors > 0:
  312. raise ValueError(
  313. f"GraphCommunitySummaryPipe: Failed to generate community summaries for {total_errors} out of {total_jobs} communities. Please rerun the job if there are too many failures."
  314. )