# graph_service.py

import asyncio
import json
import logging
import math
import re
import time
from typing import Any, AsyncGenerator, Optional
from uuid import UUID

from core.base import (
    DocumentChunk,
    KGExtraction,
    KGExtractionStatus,
    R2RDocumentProcessingError,
    RunManager,
)
from core.base.abstractions import (
    Community,
    Entity,
    GenerationConfig,
    KGCreationSettings,
    KGEnrichmentSettings,
    KGEnrichmentStatus,
    KGEntityDeduplicationSettings,
    KGEntityDeduplicationType,
    R2RException,
    Relationship,
)
from core.base.api.models import GraphResponse
from core.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()

MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH = 128
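
# `_collect_results` drains an async generator into a plain list, serializing
# pipe results via `.json()` when they expose one. The length threshold above
# is used in `_extract_kg` below: a response longer than it that still yields
# zero entities is treated as an LLM formatting failure.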
async def _collect_results(result_gen: AsyncGenerator) -> list[dict]:
    results = []
    async for res in result_gen:
        results.append(res.json() if hasattr(res, "json") else res)
    return results


# TODO - Fix naming convention to read `KGService` instead of `GraphService`
# this will require a minor change in how services are registered.
class GraphService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipes: R2RPipes,
        pipelines: R2RPipelines,
        agents: R2RAgents,
        run_manager: RunManager,
    ):
        super().__init__(
            config,
            providers,
            pipes,
            pipelines,
            agents,
            run_manager,
        )

    @telemetry_event("kg_relationships_extraction")
    async def kg_relationships_extraction(
        self,
        document_id: UUID,
        generation_config: GenerationConfig,
        chunk_merge_count: int,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        **kwargs,
    ):
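        """
        Extract entities and relationships from a document and persist them
        via the graph storage pipe. The document's `extraction_status` is set
        to PROCESSING up front, and to FAILED if either pipe raises.
        """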
        try:
            logger.info(
                f"KGService: Processing document {document_id} for KG extraction"
            )

            await self.providers.database.documents_handler.set_workflow_status(
                id=document_id,
                status_type="extraction_status",
                status=KGExtractionStatus.PROCESSING,
            )

            relationships = await self.pipes.graph_extraction_pipe.run(
                input=self.pipes.graph_extraction_pipe.Input(
                    message={
                        "document_id": document_id,
                        "generation_config": generation_config,
                        "chunk_merge_count": chunk_merge_count,
                        "max_knowledge_relationships": max_knowledge_relationships,
                        "entity_types": entity_types,
                        "relation_types": relation_types,
                        "logger": logger,
                    }
                ),
                state=None,
                run_manager=self.run_manager,
            )

            logger.info(
                f"KGService: Finished processing document {document_id} for KG extraction"
            )

            result_gen = await self.pipes.graph_storage_pipe.run(
                input=self.pipes.graph_storage_pipe.Input(
                    message=relationships
                ),
                state=None,
                run_manager=self.run_manager,
            )
        except Exception as e:
            logger.error(f"KGService: Error in kg_extraction: {e}")
            await self.providers.database.documents_handler.set_workflow_status(
                id=document_id,
                status_type="extraction_status",
                status=KGExtractionStatus.FAILED,
            )
            raise

        return await _collect_results(result_gen)

    @telemetry_event("create_entity")
    async def create_entity(
        self,
        name: str,
        description: str,
        parent_id: UUID,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
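        """
        Embed `description` and create the entity in the "graphs" store. Note
        that the embedding vector is stringified before being handed to the
        database layer.
        """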
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(description)
        )
        return await self.providers.database.graphs_handler.entities.create(
            name=name,
            parent_id=parent_id,
            store_type="graphs",  # type: ignore
            category=category,
            description=description,
            description_embedding=description_embedding,
            metadata=metadata,
        )

    @telemetry_event("update_entity")
    async def update_entity(
        self,
        entity_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return await self.providers.database.graphs_handler.entities.update(
            entity_id=entity_id,
            store_type="graphs",  # type: ignore
            name=name,
            description=description,
            description_embedding=description_embedding,
            category=category,
            metadata=metadata,
        )

    @telemetry_event("delete_entity")
    async def delete_entity(
        self,
        parent_id: UUID,
        entity_id: UUID,
    ):
        return await self.providers.database.graphs_handler.entities.delete(
            parent_id=parent_id,
            entity_ids=[entity_id],
            store_type="graphs",  # type: ignore
        )

    @telemetry_event("get_entities")
    async def get_entities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        return await self.providers.database.graphs_handler.get_entities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            entity_ids=entity_ids,
            entity_names=entity_names,
            include_embeddings=include_embeddings,
        )

    @telemetry_event("create_relationship")
    async def create_relationship(
        self,
        subject: str,
        subject_id: UUID,
        predicate: str,
        object: str,
        object_id: UUID,
        parent_id: UUID,
        description: str | None = None,
        weight: float | None = 1.0,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
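        """
        Create a relationship between two entities in the "graphs" store,
        embedding `description` when one is provided.
        """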
        description_embedding = None
        if description:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return (
            await self.providers.database.graphs_handler.relationships.create(
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                parent_id=parent_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type="graphs",  # type: ignore
            )
        )

    @telemetry_event("delete_relationship")
    async def delete_relationship(
        self,
        parent_id: UUID,
        relationship_id: UUID,
    ):
        return (
            await self.providers.database.graphs_handler.relationships.delete(
                parent_id=parent_id,
                relationship_ids=[relationship_id],
                store_type="graphs",  # type: ignore
            )
        )

    @telemetry_event("update_relationship")
    async def update_relationship(
        self,
        relationship_id: UUID,
        subject: Optional[str] = None,
        subject_id: Optional[UUID] = None,
        predicate: Optional[str] = None,
        object: Optional[str] = None,
        object_id: Optional[UUID] = None,
        description: Optional[str] = None,
        weight: Optional[float] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return (
            await self.providers.database.graphs_handler.relationships.update(
                relationship_id=relationship_id,
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type="graphs",  # type: ignore
            )
        )

    @telemetry_event("get_relationships")
    async def get_relationships(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
    ):
        return await self.providers.database.graphs_handler.relationships.get(
            parent_id=parent_id,
            store_type="graphs",  # type: ignore
            offset=offset,
            limit=limit,
            relationship_ids=relationship_ids,
            entity_names=entity_names,
        )

    @telemetry_event("create_community")
    async def create_community(
        self,
        parent_id: UUID,
        name: str,
        summary: str,
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
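        """
        Create a community record; its `summary` is embedded, presumably to
        support similarity search over communities.
        """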
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(summary)
        )
        return await self.providers.database.graphs_handler.communities.create(
            parent_id=parent_id,
            store_type="graphs",  # type: ignore
            name=name,
            summary=summary,
            description_embedding=description_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )

    @telemetry_event("update_community")
    async def update_community(
        self,
        community_id: UUID,
        name: Optional[str],
        summary: Optional[str],
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
        summary_embedding = None
        if summary is not None:
            summary_embedding = str(
                await self.providers.embedding.async_get_embedding(summary)
            )
        return await self.providers.database.graphs_handler.communities.update(
            community_id=community_id,
            store_type="graphs",  # type: ignore
            name=name,
            summary=summary,
            summary_embedding=summary_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )

    @telemetry_event("delete_community")
    async def delete_community(
        self,
        parent_id: UUID,
        community_id: UUID,
    ) -> None:
        await self.providers.database.graphs_handler.communities.delete(
            parent_id=parent_id,
            community_id=community_id,
        )

    @telemetry_event("list_communities")
    async def list_communities(
        self,
        collection_id: UUID,
        offset: int,
        limit: int,
    ):
        return await self.providers.database.graphs_handler.communities.get(
            parent_id=collection_id,
            store_type="graphs",  # type: ignore
            offset=offset,
            limit=limit,
        )

    @telemetry_event("get_communities")
    async def get_communities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        community_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        return await self.providers.database.graphs_handler.get_communities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            community_ids=community_ids,
            include_embeddings=include_embeddings,
        )

    # @telemetry_event("create_new_graph")
    # async def create_new_graph(
    #     self,
    #     collection_id: UUID,
    #     user_id: UUID,
    #     name: Optional[str],
    #     description: str = "",
    # ) -> GraphResponse:
    #     return await self.providers.database.graphs_handler.create(
    #         collection_id=collection_id,
    #         user_id=user_id,
    #         name=name,
    #         description=description,
    #         graph_id=collection_id,
    #     )

    async def list_graphs(
        self,
        offset: int,
        limit: int,
        # user_ids: Optional[list[UUID]] = None,
        graph_ids: Optional[list[UUID]] = None,
        collection_id: Optional[UUID] = None,
    ) -> dict[str, list[GraphResponse] | int]:
        return await self.providers.database.graphs_handler.list_graphs(
            offset=offset,
            limit=limit,
            # filter_user_ids=user_ids,
            filter_graph_ids=graph_ids,
            filter_collection_id=collection_id,
        )

    @telemetry_event("update_graph")
    async def update_graph(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> GraphResponse:
        return await self.providers.database.graphs_handler.update(
            collection_id=collection_id,
            name=name,
            description=description,
        )

    @telemetry_event("reset_graph_v3")
    async def reset_graph_v3(self, id: UUID) -> bool:
        await self.providers.database.graphs_handler.reset(
            parent_id=id,
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=id,
            status_type="graph_cluster_status",
            status=KGEnrichmentStatus.PENDING,
        )
        return True

    @telemetry_event("get_document_ids_for_create_graph")
    async def get_document_ids_for_create_graph(
        self,
        collection_id: UUID,
        force_kg_creation: bool = False,
        **kwargs,
    ):
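        """
        Return the ids of documents in the collection whose extraction_status
        is PENDING or FAILED; PROCESSING documents are included as well when
        `force_kg_creation` is set.
        """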
        document_status_filter = [
            KGExtractionStatus.PENDING,
            KGExtractionStatus.FAILED,
        ]
        if force_kg_creation:
            document_status_filter += [
                KGExtractionStatus.PROCESSING,
            ]
        return await self.providers.database.documents_handler.get_document_ids_by_status(
            status_type="extraction_status",
            status=[str(ele) for ele in document_status_filter],
            collection_id=collection_id,
        )

    @telemetry_event("kg_entity_description")
    async def kg_entity_description(
        self,
        document_id: UUID,
        max_description_input_length: int,
        **kwargs,
    ):
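        """
        Generate descriptions for a document's extracted entities in fixed
        batches of 256, then mark the document's extraction_status SUCCESS.
        """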
        start_time = time.time()

        logger.info(
            f"KGService: Running kg_entity_description for document {document_id}"
        )

        entity_count = (
            await self.providers.database.graphs_handler.get_entity_count(
                document_id=document_id,
                distinct=True,
                entity_table_name="documents_entities",
            )
        )

        logger.info(
            f"KGService: Found {entity_count} entities in document {document_id}"
        )

        # TODO - Do not hardcode the batch size; make it a runtime-configurable
        # parameter with a server-side default. For now, process 256 entities
        # at a time.
        num_batches = math.ceil(entity_count / 256)
        logger.info(
            f"Calling `kg_entity_description` on document {document_id} with an entity count of {entity_count} and total batches of {num_batches}"
        )
        all_results = []
        for i in range(num_batches):
            logger.info(
                f"KGService: Running kg_entity_description for batch {i + 1}/{num_batches} for document {document_id}"
            )

            node_descriptions = await self.pipes.graph_description_pipe.run(
                input=self.pipes.graph_description_pipe.Input(
                    message={
                        "offset": i * 256,
                        "limit": 256,
                        "max_description_input_length": max_description_input_length,
                        "document_id": document_id,
                        "logger": logger,
                    }
                ),
                state=None,
                run_manager=self.run_manager,
            )

            all_results.append(await _collect_results(node_descriptions))

            logger.info(
                f"KGService: Completed kg_entity_description for batch {i + 1}/{num_batches} for document {document_id}"
            )

        await self.providers.database.documents_handler.set_workflow_status(
            id=document_id,
            status_type="extraction_status",
            status=KGExtractionStatus.SUCCESS,
        )

        logger.info(
            f"KGService: Completed kg_entity_description for document {document_id} in {time.time() - start_time:.2f} seconds",
        )
        return all_results

    @telemetry_event("kg_clustering")
    async def kg_clustering(
        self,
        collection_id: UUID,
        # graph_id: UUID,
        generation_config: GenerationConfig,
        leiden_params: dict,
        **kwargs,
    ):
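        """
        Cluster the collection's graph into communities via the clustering
        pipe, forwarding `leiden_params` and the configured clustering mode.
        """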
        logger.info(
            f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
        )

        clustering_result = await self.pipes.graph_clustering_pipe.run(
            input=self.pipes.graph_clustering_pipe.Input(
                message={
                    "collection_id": collection_id,
                    "generation_config": generation_config,
                    "leiden_params": leiden_params,
                    "logger": logger,
                    "clustering_mode": self.config.database.graph_creation_settings.clustering_mode,
                }
            ),
            state=None,
            run_manager=self.run_manager,
        )
        return await _collect_results(clustering_result)

    @telemetry_event("kg_community_summary")
    async def kg_community_summary(
        self,
        offset: int,
        limit: int,
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID | None,
        # graph_id: UUID | None,
        **kwargs,
    ):
        summary_results = await self.pipes.graph_community_summary_pipe.run(
            input=self.pipes.graph_community_summary_pipe.Input(
                message={
                    "offset": offset,
                    "limit": limit,
                    "generation_config": generation_config,
                    "max_summary_input_length": max_summary_input_length,
                    "collection_id": collection_id,
                    # "graph_id": graph_id,
                    "logger": logger,
                }
            ),
            state=None,
            run_manager=self.run_manager,
        )
        return await _collect_results(summary_results)

    @telemetry_event("delete_graph_for_documents")
    async def delete_graph_for_documents(
        self,
        document_ids: list[UUID],
        **kwargs,
    ):
        # TODO: Implement this, as it needs some checks.
        raise NotImplementedError

    @telemetry_event("delete_graph")
    async def delete_graph(
        self,
        collection_id: UUID,
        cascade: bool,
        **kwargs,
    ):
        return await self.delete(collection_id=collection_id, cascade=cascade)

    @telemetry_event("delete")
    async def delete(
        self,
        collection_id: UUID,
        cascade: bool,
        **kwargs,
    ):
        return await self.providers.database.graphs_handler.delete(
            collection_id=collection_id,
            cascade=cascade,
        )

    @telemetry_event("get_creation_estimate")
    async def get_creation_estimate(
        self,
        graph_creation_settings: KGCreationSettings,
        document_id: Optional[UUID] = None,
        collection_id: Optional[UUID] = None,
        **kwargs,
    ):
        return (
            await self.providers.database.graphs_handler.get_creation_estimate(
                document_id=document_id,
                collection_id=collection_id,
                graph_creation_settings=graph_creation_settings,
            )
        )

    @telemetry_event("get_enrichment_estimate")
    async def get_enrichment_estimate(
        self,
        collection_id: Optional[UUID] = None,
        graph_id: Optional[UUID] = None,
        graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
        **kwargs,
    ):
        if graph_id is None and collection_id is None:
            raise ValueError(
                "Either graph_id or collection_id must be provided"
            )

        return await self.providers.database.graphs_handler.get_enrichment_estimate(
            collection_id=collection_id,
            graph_id=graph_id,
            graph_enrichment_settings=graph_enrichment_settings,
        )

    @telemetry_event("get_deduplication_estimate")
    async def get_deduplication_estimate(
        self,
        collection_id: UUID,
        kg_deduplication_settings: KGEntityDeduplicationSettings,
        **kwargs,
    ):
        return await self.providers.database.graphs_handler.get_deduplication_estimate(
            collection_id=collection_id,
            kg_deduplication_settings=kg_deduplication_settings,
        )

    @telemetry_event("kg_entity_deduplication")
    async def kg_entity_deduplication(
        self,
        collection_id: UUID,
        graph_id: UUID,
        graph_entity_deduplication_type: KGEntityDeduplicationType,
        graph_entity_deduplication_prompt: str,
        generation_config: GenerationConfig,
        **kwargs,
    ):
        deduplication_results = await self.pipes.graph_deduplication_pipe.run(
            input=self.pipes.graph_deduplication_pipe.Input(
                message={
                    "collection_id": collection_id,
                    "graph_id": graph_id,
                    "graph_entity_deduplication_type": graph_entity_deduplication_type,
                    "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt,
                    "generation_config": generation_config,
                    **kwargs,
                }
            ),
            state=None,
            run_manager=self.run_manager,
        )
        return await _collect_results(deduplication_results)

    @telemetry_event("kg_entity_deduplication_summary")
    async def kg_entity_deduplication_summary(
        self,
        collection_id: UUID,
        offset: int,
        limit: int,
        graph_entity_deduplication_type: KGEntityDeduplicationType,
        graph_entity_deduplication_prompt: str,
        generation_config: GenerationConfig,
        **kwargs,
    ):
        logger.info(
            f"Running kg_entity_deduplication_summary for collection {collection_id} with settings {kwargs}"
        )
        deduplication_summary_results = await self.pipes.graph_deduplication_summary_pipe.run(
            input=self.pipes.graph_deduplication_summary_pipe.Input(
                message={
                    "collection_id": collection_id,
                    "offset": offset,
                    "limit": limit,
                    "graph_entity_deduplication_type": graph_entity_deduplication_type,
                    "graph_entity_deduplication_prompt": graph_entity_deduplication_prompt,
                    "generation_config": generation_config,
                }
            ),
            state=None,
            run_manager=self.run_manager,
        )
        return await _collect_results(deduplication_summary_results)

    async def kg_extraction(  # type: ignore
        self,
        document_id: UUID,
        generation_config: GenerationConfig,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        chunk_merge_count: int,
        filter_out_existing_chunks: bool = True,
        total_tasks: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[KGExtraction | R2RDocumentProcessingError, None]:
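        """
        Page through the document's chunks, optionally skipping chunks whose
        entities were already extracted, group the remainder into batches of
        `chunk_merge_count`, and run `_extract_kg` over the groups
        concurrently, yielding each KGExtraction (or a processing error) as
        it completes.
        """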
        start_time = time.time()

        logger.info(
            f"GraphExtractionPipe: Processing document {document_id} for KG extraction",
        )

        # Fetch the document's chunks page by page.
        limit = 100
        offset = 0
        chunks = []
        while True:
            chunk_req = await self.providers.database.chunks_handler.list_document_chunks(  # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
                document_id=document_id,
                offset=offset,
                limit=limit,
            )

            chunks.extend(
                [
                    DocumentChunk(
                        id=chunk["id"],
                        document_id=chunk["document_id"],
                        owner_id=chunk["owner_id"],
                        collection_ids=chunk["collection_ids"],
                        data=chunk["text"],
                        metadata=chunk["metadata"],
                    )
                    for chunk in chunk_req["results"]
                ]
            )

            if len(chunk_req["results"]) < limit:
                break

            offset += limit

        logger.info(f"Found {len(chunks)} chunks for document {document_id}")
        if len(chunks) == 0:
            logger.info(f"No chunks found for document {document_id}")
            raise R2RException(
                message="No chunks found for document",
                status_code=404,
            )

        if filter_out_existing_chunks:
            existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids(
                document_id=document_id
            )
            chunks = [
                chunk for chunk in chunks if chunk.id not in existing_chunk_ids
            ]
            logger.info(
                f"Filtered out {len(existing_chunk_ids)} existing chunks, remaining {len(chunks)} chunks for document {document_id}"
            )

            if len(chunks) == 0:
                logger.info(f"No extractions left for document {document_id}")
                return

        logger.info(
            f"GraphExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds",
        )

        # Sort the chunks by the chunk_order field in metadata, ascending.
        chunks = sorted(
            chunks,
            key=lambda x: x.metadata.get("chunk_order", float("inf")),
        )

        # Group the chunks into batches of chunk_merge_count.
        grouped_chunks = [
            chunks[i : i + chunk_merge_count]
            for i in range(0, len(chunks), chunk_merge_count)
        ]

        logger.info(
            f"GraphExtractionPipe: Created {len(grouped_chunks)} KG extraction tasks for document, time from start: {time.time() - start_time:.2f} seconds",
        )

        tasks = [
            asyncio.create_task(
                self._extract_kg(
                    chunks=chunk_group,
                    generation_config=generation_config,
                    max_knowledge_relationships=max_knowledge_relationships,
                    entity_types=entity_types,
                    relation_types=relation_types,
                    task_id=task_id,
                    total_tasks=len(grouped_chunks),
                )
            )
            for task_id, chunk_group in enumerate(grouped_chunks)
        ]

        completed_tasks = 0
        total_tasks = len(tasks)

        logger.info(
            f"GraphExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
        )

        for completed_task in asyncio.as_completed(tasks):
            try:
                yield await completed_task
                completed_tasks += 1
                if completed_tasks % 100 == 0:
                    logger.info(
                        f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
                    )
            except Exception as e:
                logger.error(f"Error in Extracting KG Relationships: {e}")
                yield R2RDocumentProcessingError(
                    document_id=document_id,
                    error_message=str(e),
                )

        logger.info(
            f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
        )

    async def _extract_kg(
        self,
        chunks: list[DocumentChunk],
        generation_config: GenerationConfig,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        retries: int = 5,
        delay: int = 2,
        task_id: Optional[int] = None,
        total_tasks: Optional[int] = None,
    ) -> KGExtraction:
        """
        Extracts entities and relationships from a group of chunks, retrying
        on failure; returns an empty KGExtraction if all attempts fail.
        """
        # Combine all chunks into a single input string.
        combined_extraction: str = " ".join([chunk.data for chunk in chunks])  # type: ignore

        response = await self.providers.database.documents_handler.get_documents_overview(  # type: ignore
            offset=0,
            limit=1,
            filter_document_ids=[chunks[0].document_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        messages = await self.providers.database.prompts_handler.get_message_payload(
            task_prompt_name=self.providers.database.config.graph_creation_settings.graphrag_relationships_extraction_few_shot,
            task_inputs={
                "document_summary": document_summary,
                "input": combined_extraction,
                "max_knowledge_relationships": max_knowledge_relationships,
                "entity_types": "\n".join(entity_types),
                "relation_types": "\n".join(relation_types),
            },
        )

        for attempt in range(retries):
            try:
                response = await self.providers.llm.aget_completion(
                    messages,
                    generation_config=generation_config,
                )

                kg_extraction = response.choices[0].message.content

                if not kg_extraction:
                    raise R2RException(
                        "No knowledge graph extraction found in the response string, the selected LLM likely failed to format its response correctly.",
                        400,
                    )

                entity_pattern = (
                    r'\("entity"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\)'
                )
                relationship_pattern = r'\("relationship"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}(\d+(?:\.\d+)?)\)'
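
                # The few-shot prompt is expected to elicit tuples delimited
                # by four literal dollar signs, e.g.
                #   ("entity"$$$$<name>$$$$<category>$$$$<description>)
                #   ("relationship"$$$$<subject>$$$$<object>$$$$<predicate>$$$$<description>$$$$<weight>)
                # which the patterns above capture field by field.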
                async def parse_fn(response_str: str) -> Any:
                    entities = re.findall(entity_pattern, response_str)

                    if (
                        len(kg_extraction)
                        > MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH
                        and len(entities) == 0
                    ):
                        raise R2RException(
                            f"No entities found in the response string, the selected LLM likely failed to format its response correctly. {response_str}",
                            400,
                        )

                    relationships = re.findall(
                        relationship_pattern, response_str
                    )

                    entities_arr = []
                    for entity in entities:
                        entity_value = entity[0]
                        entity_category = entity[1]
                        entity_description = entity[2]
                        description_embedding = (
                            await self.providers.embedding.async_get_embedding(
                                entity_description
                            )
                        )
                        entities_arr.append(
                            Entity(
                                category=entity_category,
                                description=entity_description,
                                name=entity_value,
                                parent_id=chunks[0].document_id,
                                chunk_ids=[chunk.id for chunk in chunks],
                                description_embedding=description_embedding,
                                attributes={},
                            )
                        )

                    relations_arr = []
                    for relationship in relationships:
                        subject = relationship[0]
                        object = relationship[1]
                        predicate = relationship[2]
                        description = relationship[3]
                        weight = float(relationship[4])
                        relationship_embedding = (
                            await self.providers.embedding.async_get_embedding(
                                description
                            )
                        )

                        # check if subject and object are in entities_dict
                        relations_arr.append(
                            Relationship(
                                subject=subject,
                                predicate=predicate,
                                object=object,
                                description=description,
                                weight=weight,
                                parent_id=chunks[0].document_id,
                                chunk_ids=[chunk.id for chunk in chunks],
                                attributes={},
                                description_embedding=relationship_embedding,
                            )
                        )

                    return entities_arr, relations_arr

                entities, relationships = await parse_fn(kg_extraction)
                return KGExtraction(
                    entities=entities,
                    relationships=relationships,
                )
            except (
                Exception,
                json.JSONDecodeError,
                KeyError,
                IndexError,
                R2RException,
            ) as e:
                if attempt < retries - 1:
                    await asyncio.sleep(delay)
                else:
                    logger.warning(
                        f"Failed after retries for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}"
                    )

        logger.info(
            f"GraphExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
        )

        return KGExtraction(
            entities=[],
            relationships=[],
        )

    async def store_kg_extractions(
        self,
        kg_extractions: list[KGExtraction],
    ):
        """
        Stores a batch of knowledge graph extractions in the graph database.
        """
        for extraction in kg_extractions:
            entities_id_map = {}
            for entity in extraction.entities:
                result = await self.providers.database.graphs_handler.entities.create(
                    name=entity.name,
                    parent_id=entity.parent_id,
                    store_type="documents",  # type: ignore
                    category=entity.category,
                    description=entity.description,
                    description_embedding=entity.description_embedding,
                    chunk_ids=entity.chunk_ids,
                    metadata=entity.metadata,
                )
                entities_id_map[entity.name] = result.id
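
            # Link each relationship to the entity rows created above;
            # subjects or objects missing from the map resolve to None.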
            if extraction.relationships:
                for relationship in extraction.relationships:
                    await self.providers.database.graphs_handler.relationships.create(
                        subject=relationship.subject,
                        subject_id=entities_id_map.get(relationship.subject),
                        predicate=relationship.predicate,
                        object=relationship.object,
                        object_id=entities_id_map.get(relationship.object),
                        parent_id=relationship.parent_id,
                        description=relationship.description,
                        description_embedding=relationship.description_embedding,
                        weight=relationship.weight,
                        metadata=relationship.metadata,
                        store_type="documents",  # type: ignore
                    )