deduplication.py
import json
import logging
from typing import Any
from uuid import UUID

from core.base import AsyncState
from core.base.abstractions import Entity, KGEntityDeduplicationType
from core.base.pipes import AsyncPipe
from core.database import PostgresDatabaseProvider
from core.providers import (
    LiteLLMCompletionProvider,
    LiteLLMEmbeddingProvider,
    OllamaEmbeddingProvider,
    OpenAICompletionProvider,
    OpenAIEmbeddingProvider,
)

logger = logging.getLogger()

class GraphDeduplicationPipe(AsyncPipe):
    def __init__(
        self,
        config: AsyncPipe.PipeConfig,
        database_provider: PostgresDatabaseProvider,
        llm_provider: OpenAICompletionProvider | LiteLLMCompletionProvider,
        embedding_provider: (
            LiteLLMEmbeddingProvider
            | OpenAIEmbeddingProvider
            | OllamaEmbeddingProvider
        ),
        **kwargs,
    ):
        super().__init__(
            config=config
            or AsyncPipe.PipeConfig(name="graph_deduplication_pipe"),
        )
        self.database_provider = database_provider
        self.llm_provider = llm_provider
        self.embedding_provider = embedding_provider

    async def _get_entities(
        self, graph_id: UUID | None, collection_id: UUID | None
    ):
        if collection_id is not None:
            return await self.database_provider.graphs_handler.get_entities(
                collection_id=collection_id, offset=0, limit=-1
            )
        elif graph_id is not None:
            # TODO: remove the tuple return type
            return (
                await self.database_provider.graphs_handler.entities.get(
                    id=graph_id,
                    offset=0,
                    limit=-1,
                )
            )[0]
        else:
            raise ValueError(
                "Either graph_id or collection_id must be provided"
            )
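
    # Note: get_entities returns the entity list directly, while entities.get
    # returns a tuple whose first element is that list, hence the [0] above
    # (see the TODO).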

    async def kg_named_entity_deduplication(
        self, graph_id: UUID | None, collection_id: UUID | None, **kwargs
    ):
        import numpy as np

        entities = await self._get_entities(graph_id, collection_id)
        logger.info(
            f"GraphDeduplicationPipe: Got {len(entities)} entities for {graph_id or collection_id}"
        )

        # deduplicate entities by name
        deduplicated_entities: dict[str, dict[str, list[str]]] = {}
        deduplication_source_keys = [
            "description",
            "chunk_ids",
            "document_id",
            # "description_embedding",
        ]
        deduplication_target_keys = [
            "description",
            "chunk_ids",
            "document_ids",
            # "description_embedding",
        ]
        deduplication_keys = list(
            zip(deduplication_source_keys, deduplication_target_keys)
        )
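        # Each (source_key, target_key) pair maps a field on an individual
        # entity to the list it is merged into on the deduplicated entity:
        # list-valued fields such as chunk_ids are extended element by
        # element, while scalar fields such as document_id are appended to
        # the merged document_ids list.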
        for entity in entities:
            if entity.name not in deduplicated_entities:
                deduplicated_entities[entity.name] = {
                    target_key: [] for _, target_key in deduplication_keys
                }
                # deduplicated_entities[entity.name]['total_entries'] = 0
                # deduplicated_entities[entity.name]['description_embedding'] = np.zeros(len(json.loads(entity.description_embedding)))
            for source_key, target_key in deduplication_keys:
                value = getattr(entity, source_key)
                # if source_key == "description_embedding":
                #     deduplicated_entities[entity.name]['total_entries'] += 1
                #     deduplicated_entities[entity.name][target_key] += np.array(json.loads(value))
                if isinstance(value, list):
                    deduplicated_entities[entity.name][target_key].extend(
                        value
                    )
                else:
                    deduplicated_entities[entity.name][target_key].append(
                        value
                    )

        # upsert deduplicated entities in the collection_entity table
        deduplicated_entities_list = [
            Entity(
                name=name,
                # description="\n".join(entity["description"]),
                # description_embedding=json.dumps((entity["description_embedding"] / entity['total_entries']).tolist()),
                collection_id=collection_id,
                graph_id=graph_id,
                chunk_ids=list(set(entity["chunk_ids"])),
                document_ids=list(set(entity["document_ids"])),
                attributes={},
            )
            for name, entity in deduplicated_entities.items()
        ]
        logger.info(
            f"GraphDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for {graph_id or collection_id}"
        )
        await self.database_provider.graphs_handler.add_entities(
            deduplicated_entities_list,
            table_name="collection_entity",
        )

        yield {
            "result": f"successfully deduplicated {len(entities)} entities to {len(deduplicated_entities)} entities for {graph_id or collection_id}",
            "num_entities": len(deduplicated_entities),
        }
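
    # Illustrative example (not executed): two entities both named
    # "Acme Corp", one from document d1 and one from document d2, merge into
    # a single Entity(name="Acme Corp", document_ids=[d1, d2]) whose
    # chunk_ids is the union of both entities' chunk id lists.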

    async def kg_description_entity_deduplication(
        self, graph_id: UUID | None, collection_id: UUID | None, **kwargs
    ):
        from sklearn.cluster import DBSCAN

        entities = await self._get_entities(graph_id, collection_id)
        for entity in entities:
            entity.description_embedding = json.loads(
                entity.description_embedding
            )

        deduplication_source_keys = [
            "chunk_ids",
            "document_id",
            "attributes",
        ]
        deduplication_target_keys = [
            "chunk_ids",
            "document_ids",
            "attributes",
        ]
        deduplication_keys = list(
            zip(deduplication_source_keys, deduplication_target_keys)
        )

        embeddings = [entity.description_embedding for entity in entities]
        logger.info(
            f"GraphDeduplicationPipe: Running DBSCAN clustering on {len(embeddings)} embeddings"
        )
        # TODO: make eps a config, make it very strict for now
        clustering = DBSCAN(eps=0.1, min_samples=2, metric="cosine").fit(
            embeddings
        )
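        # With metric="cosine", eps=0.1 treats two entities as neighbors when
        # the cosine distance between their description embeddings is at most
        # 0.1 (cosine similarity >= 0.9); min_samples=2 means a cluster needs
        # at least two such entities. Everything else is labeled -1 (noise)
        # and excluded from the merge below.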
        labels = clustering.labels_

        # Log clustering results
        n_clusters = len(set(labels)) - (1 if -1 in labels else 0)
        n_noise = list(labels).count(-1)
        logger.info(
            f"GraphDeduplicationPipe: Found {n_clusters} clusters and {n_noise} noise points"
        )

        # entities that share a cluster label are merged into one entity
        deduplicated_entities: dict[int, list] = {}
        for idx, label in enumerate(labels):
            if label != -1:
                if label not in deduplicated_entities:
                    deduplicated_entities[label] = []
                deduplicated_entities[label].append(entities[idx])

        # upsert deduplicated entities in the collection_entity table
        deduplicated_entities_list = []
        for label, cluster_entities in deduplicated_entities.items():
            longest_name = ""
            descriptions = []
            aliases = set()
            for entity in cluster_entities:
                aliases.add(entity.name)
                descriptions.append(entity.description)
                if len(entity.name) > len(longest_name):
                    longest_name = entity.name
            descriptions.sort(key=len, reverse=True)
            description = "\n".join(descriptions[:5])

            # Collect all chunk and document IDs from entities in the cluster
            chunk_ids = set()
            document_ids = set()
            for entity in cluster_entities:
                if entity.chunk_ids:
                    chunk_ids.update(entity.chunk_ids)
                if entity.document_id:
                    document_ids.add(entity.document_id)
            chunk_ids_list = list(chunk_ids)
            document_ids_list = list(document_ids)

            deduplicated_entities_list.append(
                Entity(
                    name=longest_name,
                    description=description,
                    graph_id=graph_id,
                    collection_id=collection_id,
                    chunk_ids=chunk_ids_list,
                    document_ids=document_ids_list,
                    attributes={
                        "aliases": list(aliases),
                    },
                )
            )

        logger.info(
            f"GraphDeduplicationPipe: Upserting {len(deduplicated_entities_list)} deduplicated entities for {graph_id or collection_id}"
        )
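        # conflict_columns presumably drives an ON CONFLICT upsert inside
        # add_entities, so re-running deduplication for the same graph
        # replaces prior merged rows instead of duplicating them (an
        # assumption based on the parameter name, not verified against the
        # handler).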
        await self.database_provider.graphs_handler.add_entities(
            deduplicated_entities_list,
            table_name="collection_entity",
            conflict_columns=["name", "graph_id", "attributes"],
        )

        yield {
            "result": f"successfully deduplicated {len(entities)} entities to {len(deduplicated_entities)} entities for {graph_id or collection_id}",
            "num_entities": len(deduplicated_entities),
        }

    # async def kg_llm_entity_deduplication(
    #     self, graph_id: UUID, collection_id: UUID, **kwargs
    # ):
    #     # TODO: implement LLM based entity deduplication
    #     raise NotImplementedError(
    #         "LLM entity deduplication is not implemented yet"
    #     )

    async def _run_logic(
        self,
        input: AsyncPipe.Input,
        state: AsyncState,
        run_id: UUID,
        *args: Any,
        **kwargs: Any,
    ):
        # TODO: figure out why the return type AsyncGenerator[dict, None] is not working
        graph_id = input.message.get("graph_id", None)
        collection_id = input.message.get("collection_id", None)
        if graph_id and collection_id:
            raise ValueError(
                "graph_id and collection_id cannot both be provided"
            )

        graph_entity_deduplication_type = input.message[
            "graph_entity_deduplication_type"
        ]
        if (
            graph_entity_deduplication_type
            == KGEntityDeduplicationType.BY_NAME
        ):
            async for result in self.kg_named_entity_deduplication(
                graph_id=graph_id, collection_id=collection_id, **kwargs
            ):
                yield result
        elif (
            graph_entity_deduplication_type
            == KGEntityDeduplicationType.BY_DESCRIPTION
        ):
            async for result in self.kg_description_entity_deduplication(
                graph_id=graph_id, collection_id=collection_id, **kwargs
            ):
                yield result
        elif (
            graph_entity_deduplication_type == KGEntityDeduplicationType.BY_LLM
        ):
            raise NotImplementedError(
                "LLM entity deduplication is not implemented yet"
            )
        else:
            raise ValueError(
                f"Invalid graph_entity_deduplication_type: {graph_entity_deduplication_type}"
            )
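
# Example usage (illustrative sketch, not part of the pipe itself): the
# provider objects and collection_id below are hypothetical stand-ins, and
# uuid4 would come from the standard-library uuid module. Construct the
# providers however your deployment wires up R2R.
#
#     pipe = GraphDeduplicationPipe(
#         config=AsyncPipe.PipeConfig(name="graph_deduplication_pipe"),
#         database_provider=database_provider,
#         llm_provider=llm_provider,
#         embedding_provider=embedding_provider,
#     )
#     async for result in pipe._run_logic(
#         input=AsyncPipe.Input(
#             message={
#                 "collection_id": collection_id,
#                 "graph_entity_deduplication_type": KGEntityDeduplicationType.BY_NAME,
#             }
#         ),
#         state=AsyncState(),
#         run_id=uuid4(),
#     ):
#         logger.info(result["result"])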