# kg_service.py

import asyncio
import json
import logging
import math
import re
import time
from typing import Any, AsyncGenerator, Optional
from uuid import UUID

from core.base import (
    DocumentChunk,
    KGExtraction,
    KGExtractionStatus,
    R2RDocumentProcessingError,
    RunManager,
)
from core.base.abstractions import (
    Community,
    Entity,
    GenerationConfig,
    KGCreationSettings,
    KGEnrichmentSettings,
    KGEnrichmentStatus,
    KGEntityDeduplicationSettings,
    KGEntityDeduplicationType,
    R2RException,
    Relationship,
)
from core.base.api.models import GraphResponse
from core.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()

MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH = 128


async def _collect_results(result_gen: AsyncGenerator) -> list[dict]:
    results = []
    async for res in result_gen:
        results.append(res.json() if hasattr(res, "json") else res)
    return results


# TODO - Fix naming convention to read `KGService` instead of `KgService`
# this will require a minor change in how services are registered.
class KgService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipes: R2RPipes,
        pipelines: R2RPipelines,
        agents: R2RAgents,
        run_manager: RunManager,
    ):
        super().__init__(
            config,
            providers,
            pipes,
            pipelines,
            agents,
            run_manager,
        )

    @telemetry_event("kg_relationships_extraction")
    async def kg_relationships_extraction(
        self,
        document_id: UUID,
        generation_config: GenerationConfig,
        chunk_merge_count: int,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        **kwargs,
    ):
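        """
        Extract entities and relationships for a single document.

        Marks the document as PROCESSING, runs the relationships-extraction
        pipe followed by the storage pipe, and returns the collected storage
        results; on failure the document is marked FAILED and the error is
        re-raised.
        """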
        try:
            logger.info(
                f"KGService: Processing document {document_id} for KG extraction"
            )
            await self.providers.database.documents_handler.set_workflow_status(
                id=document_id,
                status_type="extraction_status",
                status=KGExtractionStatus.PROCESSING,
            )
            relationships = await self.pipes.kg_relationships_extraction_pipe.run(
                input=self.pipes.kg_relationships_extraction_pipe.Input(
                    message={
                        "document_id": document_id,
                        "generation_config": generation_config,
                        "chunk_merge_count": chunk_merge_count,
                        "max_knowledge_relationships": max_knowledge_relationships,
                        "entity_types": entity_types,
                        "relation_types": relation_types,
                        "logger": logger,
                    }
                ),
                state=None,
                run_manager=self.run_manager,
            )
            logger.info(
                f"KGService: Finished processing document {document_id} for KG extraction"
            )
            result_gen = await self.pipes.kg_storage_pipe.run(
                input=self.pipes.kg_storage_pipe.Input(message=relationships),
                state=None,
                run_manager=self.run_manager,
            )
        except Exception as e:
            logger.error(f"KGService: Error in kg_extraction: {e}")
            await self.providers.database.documents_handler.set_workflow_status(
                id=document_id,
                status_type="extraction_status",
                status=KGExtractionStatus.FAILED,
            )
            raise e
        return await _collect_results(result_gen)
    @telemetry_event("create_entity")
    async def create_entity(
        self,
        name: str,
        description: str,
        parent_id: UUID,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(description)
        )
        return await self.providers.database.graphs_handler.entities.create(
            name=name,
            parent_id=parent_id,
            store_type="graphs",  # type: ignore
            category=category,
            description=description,
            description_embedding=description_embedding,
            metadata=metadata,
        )

    @telemetry_event("update_entity")
    async def update_entity(
        self,
        entity_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return await self.providers.database.graphs_handler.entities.update(
            entity_id=entity_id,
            store_type="graphs",  # type: ignore
            name=name,
            description=description,
            description_embedding=description_embedding,
            category=category,
            metadata=metadata,
        )

    @telemetry_event("delete_entity")
    async def delete_entity(
        self,
        parent_id: UUID,
        entity_id: UUID,
    ):
        return await self.providers.database.graphs_handler.entities.delete(
            parent_id=parent_id,
            entity_ids=[entity_id],
            store_type="graphs",  # type: ignore
        )

    @telemetry_event("get_entities")
    async def get_entities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        return await self.providers.database.graphs_handler.get_entities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            entity_ids=entity_ids,
            entity_names=entity_names,
            include_embeddings=include_embeddings,
        )

    @telemetry_event("create_relationship")
    async def create_relationship(
        self,
        subject: str,
        subject_id: UUID,
        predicate: str,
        object: str,
        object_id: UUID,
        parent_id: UUID,
        description: str | None = None,
        weight: float | None = 1.0,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        description_embedding = None
        if description:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return (
            await self.providers.database.graphs_handler.relationships.create(
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                parent_id=parent_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type="graphs",  # type: ignore
            )
        )

    @telemetry_event("delete_relationship")
    async def delete_relationship(
        self,
        parent_id: UUID,
        relationship_id: UUID,
    ):
        return (
            await self.providers.database.graphs_handler.relationships.delete(
                parent_id=parent_id,
                relationship_ids=[relationship_id],
                store_type="graphs",  # type: ignore
            )
        )

    @telemetry_event("update_relationship")
    async def update_relationship(
        self,
        relationship_id: UUID,
        subject: Optional[str] = None,
        subject_id: Optional[UUID] = None,
        predicate: Optional[str] = None,
        object: Optional[str] = None,
        object_id: Optional[UUID] = None,
        description: Optional[str] = None,
        weight: Optional[float] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return (
            await self.providers.database.graphs_handler.relationships.update(
                relationship_id=relationship_id,
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type="graphs",  # type: ignore
            )
        )

    @telemetry_event("get_relationships")
    async def get_relationships(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
    ):
        return await self.providers.database.graphs_handler.relationships.get(
            parent_id=parent_id,
            store_type="graphs",  # type: ignore
            offset=offset,
            limit=limit,
            relationship_ids=relationship_ids,
            entity_names=entity_names,
        )

    @telemetry_event("create_community")
    async def create_community(
        self,
        parent_id: UUID,
        name: str,
        summary: str,
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(summary)
        )
        return await self.providers.database.graphs_handler.communities.create(
            parent_id=parent_id,
            store_type="graphs",  # type: ignore
            name=name,
            summary=summary,
            description_embedding=description_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )

    @telemetry_event("update_community")
    async def update_community(
        self,
        community_id: UUID,
        name: Optional[str],
        summary: Optional[str],
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
        summary_embedding = None
        if summary is not None:
            summary_embedding = str(
                await self.providers.embedding.async_get_embedding(summary)
            )
        return await self.providers.database.graphs_handler.communities.update(
            community_id=community_id,
            store_type="graphs",  # type: ignore
            name=name,
            summary=summary,
            summary_embedding=summary_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )

    @telemetry_event("delete_community")
    async def delete_community(
        self,
        parent_id: UUID,
        community_id: UUID,
    ) -> None:
        await self.providers.database.graphs_handler.communities.delete(
            parent_id=parent_id,
            community_id=community_id,
        )

    @telemetry_event("list_communities")
    async def list_communities(
        self,
        collection_id: UUID,
        offset: int,
        limit: int,
    ):
        return await self.providers.database.graphs_handler.communities.get(
            parent_id=collection_id,
            store_type="graphs",  # type: ignore
            offset=offset,
            limit=limit,
        )

    @telemetry_event("get_communities")
    async def get_communities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        community_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        return await self.providers.database.graphs_handler.get_communities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            community_ids=community_ids,
            include_embeddings=include_embeddings,
        )

    # @telemetry_event("create_new_graph")
    # async def create_new_graph(
    #     self,
    #     collection_id: UUID,
    #     user_id: UUID,
    #     name: Optional[str],
    #     description: str = "",
    # ) -> GraphResponse:
    #     return await self.providers.database.graphs_handler.create(
    #         collection_id=collection_id,
    #         user_id=user_id,
    #         name=name,
    #         description=description,
    #         graph_id=collection_id,
    #     )

    async def list_graphs(
        self,
        offset: int,
        limit: int,
        # user_ids: Optional[list[UUID]] = None,
        graph_ids: Optional[list[UUID]] = None,
        collection_id: Optional[UUID] = None,
    ) -> dict[str, list[GraphResponse] | int]:
        return await self.providers.database.graphs_handler.list_graphs(
            offset=offset,
            limit=limit,
            # filter_user_ids=user_ids,
            filter_graph_ids=graph_ids,
            filter_collection_id=collection_id,
        )

    @telemetry_event("update_graph")
    async def update_graph(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> GraphResponse:
        return await self.providers.database.graphs_handler.update(
            collection_id=collection_id,
            name=name,
            description=description,
        )

    @telemetry_event("reset_graph_v3")
    async def reset_graph_v3(self, id: UUID) -> bool:
        await self.providers.database.graphs_handler.reset(
            parent_id=id,
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=id,
            status_type="graph_cluster_status",
            status=KGEnrichmentStatus.PENDING,
        )
        return True

    @telemetry_event("get_document_ids_for_create_graph")
    async def get_document_ids_for_create_graph(
        self,
        collection_id: UUID,
        force_kg_creation: bool = False,
        **kwargs,
    ):
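        """
        Return the IDs of documents in the collection whose extraction status
        is PENDING or FAILED (plus PROCESSING when `force_kg_creation` is set),
        i.e. the documents that still need KG extraction.
        """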
        document_status_filter = [
            KGExtractionStatus.PENDING,
            KGExtractionStatus.FAILED,
        ]
        if force_kg_creation:
            document_status_filter += [
                KGExtractionStatus.PROCESSING,
            ]
        return await self.providers.database.documents_handler.get_document_ids_by_status(
            status_type="extraction_status",
            status=[str(ele) for ele in document_status_filter],
            collection_id=collection_id,
        )

    @telemetry_event("kg_entity_description")
    async def kg_entity_description(
        self,
        document_id: UUID,
        max_description_input_length: int,
        **kwargs,
    ):
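        """
        Generate descriptions for the entities extracted from a document,
        processing them in batches of 256, then mark the document's
        extraction status as SUCCESS.
        """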
        start_time = time.time()
        logger.info(
            f"KGService: Running kg_entity_description for document {document_id}"
        )
        entity_count = (
            await self.providers.database.graphs_handler.get_entity_count(
                document_id=document_id,
                distinct=True,
                entity_table_name="documents_entities",
            )
        )
        logger.info(
            f"KGService: Found {entity_count} entities in document {document_id}"
        )
        # TODO - Do not hardcode the batch size,
        # make it a configurable parameter at runtime & server-side defaults
        # process 256 entities at a time
        num_batches = math.ceil(entity_count / 256)
        logger.info(
            f"Calling `kg_entity_description` on document {document_id} with an entity count of {entity_count} and total batches of {num_batches}"
        )
        all_results = []
        for i in range(num_batches):
            logger.info(
                f"KGService: Running kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
            )
            node_descriptions = await self.pipes.kg_entity_description_pipe.run(
                input=self.pipes.kg_entity_description_pipe.Input(
                    message={
                        "offset": i * 256,
                        "limit": 256,
                        "max_description_input_length": max_description_input_length,
                        "document_id": document_id,
                        "logger": logger,
                    }
                ),
                state=None,
                run_manager=self.run_manager,
            )
            all_results.append(await _collect_results(node_descriptions))
            logger.info(
                f"KGService: Completed kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
            )
        await self.providers.database.documents_handler.set_workflow_status(
            id=document_id,
            status_type="extraction_status",
            status=KGExtractionStatus.SUCCESS,
        )
        logger.info(
            f"KGService: Completed kg_entity_description for document {document_id} in {time.time() - start_time:.2f} seconds",
        )
        return all_results

    @telemetry_event("kg_clustering")
    async def kg_clustering(
        self,
        collection_id: UUID,
        # graph_id: UUID,
        generation_config: GenerationConfig,
        leiden_params: dict,
        **kwargs,
    ):
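        """
        Run the clustering pipe (driven by `leiden_params`) over the
        collection's graph and return the collected clustering results.
        """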
        logger.info(
            f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
        )
        clustering_result = await self.pipes.kg_clustering_pipe.run(
            input=self.pipes.kg_clustering_pipe.Input(
                message={
                    "collection_id": collection_id,
                    "generation_config": generation_config,
                    "leiden_params": leiden_params,
                    "logger": logger,
                    "clustering_mode": self.config.database.graph_creation_settings.clustering_mode,
                }
            ),
            state=None,
            run_manager=self.run_manager,
        )
        return await _collect_results(clustering_result)

    @telemetry_event("kg_community_summary")
    async def kg_community_summary(
        self,
        offset: int,
        limit: int,
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID | None,
        # graph_id: UUID | None,
        **kwargs,
    ):
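        """
        Run the community-summary pipe for a page of communities in the
        collection and return the collected summaries.
        """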
        summary_results = await self.pipes.kg_community_summary_pipe.run(
            input=self.pipes.kg_community_summary_pipe.Input(
                message={
                    "offset": offset,
                    "limit": limit,
                    "generation_config": generation_config,
                    "max_summary_input_length": max_summary_input_length,
                    "collection_id": collection_id,
                    # "graph_id": graph_id,
                    "logger": logger,
                }
            ),
            state=None,
            run_manager=self.run_manager,
        )
        return await _collect_results(summary_results)

    @telemetry_event("delete_graph_for_documents")
    async def delete_graph_for_documents(
        self,
        document_ids: list[UUID],
        **kwargs,
    ):
        # TODO: Implement this, as it needs some checks.
        raise NotImplementedError

    @telemetry_event("delete_graph")
    async def delete_graph(
        self,
        collection_id: UUID,
        cascade: bool,
        **kwargs,
    ):
        return await self.delete_graph_for_collection(
            collection_id=collection_id, cascade=cascade
        )

    @telemetry_event("delete_graph_for_collection")
    async def delete_graph_for_collection(
        self,
        collection_id: UUID,
        cascade: bool,
        **kwargs,
    ):
        return await self.providers.database.graphs_handler.delete_graph_for_collection(
            collection_id=collection_id,
            cascade=cascade,
        )

    @telemetry_event("delete_node_via_document_id")
    async def delete_node_via_document_id(
        self,
        document_id: UUID,
        collection_id: UUID,
        **kwargs,
    ):
        return await self.providers.database.graphs_handler.delete_node_via_document_id(
            document_id=document_id,
            collection_id=collection_id,
        )

    @telemetry_event("get_creation_estimate")
    async def get_creation_estimate(
        self,
        graph_creation_settings: KGCreationSettings,
        document_id: Optional[UUID] = None,
        collection_id: Optional[UUID] = None,
        **kwargs,
    ):
        return (
            await self.providers.database.graphs_handler.get_creation_estimate(
                document_id=document_id,
                collection_id=collection_id,
                graph_creation_settings=graph_creation_settings,
            )
        )

    @telemetry_event("get_enrichment_estimate")
    async def get_enrichment_estimate(
        self,
        collection_id: Optional[UUID] = None,
        graph_id: Optional[UUID] = None,
        graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
        **kwargs,
    ):
        if graph_id is None and collection_id is None:
            raise ValueError(
                "Either graph_id or collection_id must be provided"
            )
        return await self.providers.database.graphs_handler.get_enrichment_estimate(
            collection_id=collection_id,
            graph_id=graph_id,
            graph_enrichment_settings=graph_enrichment_settings,
        )

    @telemetry_event("get_deduplication_estimate")
    async def get_deduplication_estimate(
        self,
        collection_id: UUID,
        kg_deduplication_settings: KGEntityDeduplicationSettings,
        **kwargs,
    ):
        return await self.providers.database.graphs_handler.get_deduplication_estimate(
            collection_id=collection_id,
            kg_deduplication_settings=kg_deduplication_settings,
        )

    @telemetry_event("kg_entity_deduplication")
    async def kg_entity_deduplication(
        self,
        collection_id: UUID,
        graph_id: UUID,
        graph_entity_deduplication_type: KGEntityDeduplicationType,
        graph_entity_deduplication_prompt: str,
        generation_config: GenerationConfig,
        **kwargs,
    ):
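        """
        Run the entity-deduplication pipe for the given collection/graph and
        return the collected results.
        """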
        deduplication_results = await self.pipes.kg_entity_deduplication_pipe.run(
            input=self.pipes.kg_entity_deduplication_pipe.Input(
                message={
                    "collection_id": collection_id,
                    "graph_id": graph_id,
                    "graph_entity_deduplication_type": graph_entity_deduplication_type,
                    "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt,
                    "generation_config": generation_config,
                    **kwargs,
                }
            ),
            state=None,
            run_manager=self.run_manager,
        )
        return await _collect_results(deduplication_results)

    @telemetry_event("kg_entity_deduplication_summary")
    async def kg_entity_deduplication_summary(
        self,
        collection_id: UUID,
        offset: int,
        limit: int,
        graph_entity_deduplication_type: KGEntityDeduplicationType,
        graph_entity_deduplication_prompt: str,
        generation_config: GenerationConfig,
        **kwargs,
    ):
        logger.info(
            f"Running kg_entity_deduplication_summary for collection {collection_id} with settings {kwargs}"
        )
        deduplication_summary_results = await self.pipes.kg_entity_deduplication_summary_pipe.run(
            input=self.pipes.kg_entity_deduplication_summary_pipe.Input(
                message={
                    "collection_id": collection_id,
                    "offset": offset,
                    "limit": limit,
                    "graph_entity_deduplication_type": graph_entity_deduplication_type,
                    "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt,
                    "generation_config": generation_config,
                }
            ),
            state=None,
            run_manager=self.run_manager,
        )
        return await _collect_results(deduplication_summary_results)

    async def kg_extraction(  # type: ignore
        self,
        document_id: UUID,
        generation_config: GenerationConfig,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        chunk_merge_count: int,
        filter_out_existing_chunks: bool = True,
        total_tasks: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[KGExtraction | R2RDocumentProcessingError, None]:
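        """
        Stream KG extractions for a document: page through its chunks,
        optionally skip chunks that already have extracted entities, group
        the remainder into batches of `chunk_merge_count`, and yield a
        `KGExtraction` (or an `R2RDocumentProcessingError`) per completed
        batch.
        """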
        start_time = time.time()
        logger.info(
            f"KGExtractionPipe: Processing document {document_id} for KG extraction",
        )
        # Then create the extractions from the results
        limit = 100
        offset = 0
        chunks = []
        while True:
            chunk_req = await self.providers.database.chunks_handler.list_document_chunks(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
                document_id=document_id,
                offset=offset,
                limit=limit,
            )
            chunks.extend(
                [
                    DocumentChunk(
                        id=chunk["id"],
                        document_id=chunk["document_id"],
                        owner_id=chunk["owner_id"],
                        collection_ids=chunk["collection_ids"],
                        data=chunk["text"],
                        metadata=chunk["metadata"],
                    )
                    for chunk in chunk_req["results"]
                ]
            )
            if len(chunk_req["results"]) < limit:
                break
            offset += limit
        logger.info(f"Found {len(chunks)} chunks for document {document_id}")
        if len(chunks) == 0:
            logger.info(f"No chunks found for document {document_id}")
            raise R2RException(
                message="No chunks found for document",
                status_code=404,
            )
        if filter_out_existing_chunks:
            existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids(
                document_id=document_id
            )
            chunks = [
                chunk for chunk in chunks if chunk.id not in existing_chunk_ids
            ]
            logger.info(
                f"Filtered out {len(existing_chunk_ids)} existing chunks, remaining {len(chunks)} chunks for document {document_id}"
            )
            if len(chunks) == 0:
                logger.info(f"No extractions left for document {document_id}")
                return
        logger.info(
            f"KGExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds",
        )
        # sort the extractions according to the chunk_order field in metadata, in ascending order
        chunks = sorted(
            chunks,
            key=lambda x: x.metadata.get("chunk_order", float("inf")),
        )
        # group these extractions into groups of chunk_merge_count
        grouped_chunks = [
            chunks[i : i + chunk_merge_count]
            for i in range(0, len(chunks), chunk_merge_count)
        ]
        logger.info(
            f"KGExtractionPipe: Extracting KG Relationships for document and created {len(grouped_chunks)} tasks, time from start: {time.time() - start_time:.2f} seconds",
        )
        tasks = [
            asyncio.create_task(
                self._extract_kg(
                    chunks=chunk_group,
                    generation_config=generation_config,
                    max_knowledge_relationships=max_knowledge_relationships,
                    entity_types=entity_types,
                    relation_types=relation_types,
                    task_id=task_id,
                    total_tasks=len(grouped_chunks),
                )
            )
            for task_id, chunk_group in enumerate(grouped_chunks)
        ]
        completed_tasks = 0
        total_tasks = len(tasks)
        logger.info(
            f"KGExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
        )
        for completed_task in asyncio.as_completed(tasks):
            try:
                yield await completed_task
                completed_tasks += 1
                if completed_tasks % 100 == 0:
                    logger.info(
                        f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
                    )
            except Exception as e:
                logger.error(f"Error in Extracting KG Relationships: {e}")
                yield R2RDocumentProcessingError(
                    document_id=document_id,
                    error_message=str(e),
                )
        logger.info(
            f"KGExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
        )

    async def _extract_kg(
        self,
        chunks: list[DocumentChunk],
        generation_config: GenerationConfig,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        retries: int = 5,
        delay: int = 2,
        task_id: Optional[int] = None,
        total_tasks: Optional[int] = None,
    ) -> KGExtraction:
        """
        Extracts NER relationships from an extraction, with retries.
        """
        # combine all extractions into a single string
        combined_extraction: str = " ".join([chunk.data for chunk in chunks])  # type: ignore
        response = await self.providers.database.documents_handler.get_documents_overview(  # type: ignore
            offset=0,
            limit=1,
            filter_document_ids=[chunks[0].document_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )
        messages = await self.providers.database.prompts_handler.get_message_payload(
            task_prompt_name=self.providers.database.config.graph_creation_settings.graphrag_relationships_extraction_few_shot,
            task_inputs={
                "document_summary": document_summary,
                "input": combined_extraction,
                "max_knowledge_relationships": max_knowledge_relationships,
                "entity_types": "\n".join(entity_types),
                "relation_types": "\n".join(relation_types),
            },
        )
        for attempt in range(retries):
            try:
                response = await self.providers.llm.aget_completion(
                    messages,
                    generation_config=generation_config,
                )
                kg_extraction = response.choices[0].message.content
                if not kg_extraction:
                    raise R2RException(
                        "No knowledge graph extraction found in the response string, the selected LLM likely failed to format its response correctly.",
                        400,
                    )
                entity_pattern = (
                    r'\("entity"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\)'
                )
                relationship_pattern = r'\("relationship"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}(\d+(?:\.\d+)?)\)'

                async def parse_fn(response_str: str) -> Any:
                    entities = re.findall(entity_pattern, response_str)
                    if (
                        len(kg_extraction)
                        > MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH
                        and len(entities) == 0
                    ):
                        raise R2RException(
  865. f"No entities found in the response string, the selected LLM likely failed to format it's response correctly. {response_str}",
                            400,
                        )
                    relationships = re.findall(
                        relationship_pattern, response_str
                    )
                    entities_arr = []
                    for entity in entities:
                        entity_value = entity[0]
                        entity_category = entity[1]
                        entity_description = entity[2]
                        description_embedding = (
                            await self.providers.embedding.async_get_embedding(
                                entity_description
                            )
                        )
                        entities_arr.append(
                            Entity(
                                category=entity_category,
                                description=entity_description,
                                name=entity_value,
                                parent_id=chunks[0].document_id,
                                chunk_ids=[chunk.id for chunk in chunks],
                                description_embedding=description_embedding,
                                attributes={},
                            )
                        )
                    relations_arr = []
                    for relationship in relationships:
                        subject = relationship[0]
                        object = relationship[1]
                        predicate = relationship[2]
                        description = relationship[3]
                        weight = float(relationship[4])
                        relationship_embedding = (
                            await self.providers.embedding.async_get_embedding(
                                description
                            )
                        )
                        # check if subject and object are in entities_dict
                        relations_arr.append(
                            Relationship(
                                subject=subject,
                                predicate=predicate,
                                object=object,
                                description=description,
                                weight=weight,
                                parent_id=chunks[0].document_id,
                                chunk_ids=[chunk.id for chunk in chunks],
                                attributes={},
                                description_embedding=relationship_embedding,
                            )
                        )
                    return entities_arr, relations_arr

                entities, relationships = await parse_fn(kg_extraction)
                return KGExtraction(
                    entities=entities,
                    relationships=relationships,
                )
            except (
                Exception,
                json.JSONDecodeError,
                KeyError,
                IndexError,
                R2RException,
            ) as e:
                if attempt < retries - 1:
                    await asyncio.sleep(delay)
                else:
                    logger.warning(
  935. f"Failed after retries with for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}"
                    )
        logger.info(
            f"KGExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
        )
        return KGExtraction(
            entities=[],
            relationships=[],
        )

    async def store_kg_extractions(
        self,
        kg_extractions: list[KGExtraction],
    ):
        """
        Stores a batch of knowledge graph extractions in the graph database.
        """
        for extraction in kg_extractions:
            entities_id_map = {}
            for entity in extraction.entities:
                result = await self.providers.database.graphs_handler.entities.create(
                    name=entity.name,
                    parent_id=entity.parent_id,
                    store_type="documents",  # type: ignore
                    category=entity.category,
                    description=entity.description,
                    description_embedding=entity.description_embedding,
                    chunk_ids=entity.chunk_ids,
                    metadata=entity.metadata,
                )
                entities_id_map[entity.name] = result.id
            if extraction.relationships:
                for relationship in extraction.relationships:
                    await self.providers.database.graphs_handler.relationships.create(
                        subject=relationship.subject,
                        subject_id=entities_id_map.get(relationship.subject),
                        predicate=relationship.predicate,
                        object=relationship.object,
                        object_id=entities_id_map.get(relationship.object),
                        parent_id=relationship.parent_id,
                        description=relationship.description,
                        description_embedding=relationship.description_embedding,
                        weight=relationship.weight,
                        metadata=relationship.metadata,
                        store_type="documents",  # type: ignore
                    )
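

# Illustrative usage sketch (not part of this module): assuming the R2R config,
# providers, pipes, pipelines, agents, and run manager have already been wired
# up elsewhere, a caller might drive the per-document extraction flow roughly
# as below. The names `service`, `document_id`, and `generation_config` are
# placeholders, and the argument values are arbitrary examples.
#
#   service = KgService(config, providers, pipes, pipelines, agents, run_manager)
#   await service.kg_relationships_extraction(
#       document_id=document_id,
#       generation_config=generation_config,
#       chunk_merge_count=4,
#       max_knowledge_relationships=100,
#       entity_types=[],
#       relation_types=[],
#   )
#   await service.kg_entity_description(
#       document_id=document_id,
#       max_description_input_length=65536,
#   )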