# graph_service.py

import asyncio
import json
import logging
import math
import re
import time
from typing import Any, AsyncGenerator, Optional
from uuid import UUID

from core.base import (
    DocumentChunk,
    KGExtraction,
    KGExtractionStatus,
    R2RDocumentProcessingError,
    RunManager,
)
from core.base.abstractions import (
    Community,
    Entity,
    GenerationConfig,
    KGCreationSettings,
    KGEnrichmentSettings,
    KGEnrichmentStatus,
    R2RException,
    Relationship,
    StoreType,
)
from core.base.api.models import GraphResponse
from core.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()

MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH = 128


async def _collect_results(result_gen: AsyncGenerator) -> list[dict]:
    """Drain an async generator into a list, JSON-serializing results that support it."""
    results = []
    async for res in result_gen:
        results.append(res.json() if hasattr(res, "json") else res)
    return results


# TODO - Fix naming convention to read `KGService` instead of `GraphService`
# this will require a minor change in how services are registered.
class GraphService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipes: R2RPipes,
        pipelines: R2RPipelines,
        agents: R2RAgents,
        run_manager: RunManager,
    ):
        super().__init__(
            config,
            providers,
            pipes,
            pipelines,
            agents,
            run_manager,
        )

    @telemetry_event("kg_relationships_extraction")
    async def kg_relationships_extraction(
        self,
        document_id: UUID,
        generation_config: GenerationConfig,
        chunk_merge_count: int,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        **kwargs,
    ):
        try:
            logger.info(
                f"KGService: Processing document {document_id} for KG extraction"
            )
            await self.providers.database.documents_handler.set_workflow_status(
                id=document_id,
                status_type="extraction_status",
                status=KGExtractionStatus.PROCESSING,
            )
            relationships = await self.pipes.graph_extraction_pipe.run(
                input=self.pipes.graph_extraction_pipe.Input(
                    message={
                        "document_id": document_id,
                        "generation_config": generation_config,
                        "chunk_merge_count": chunk_merge_count,
                        "max_knowledge_relationships": max_knowledge_relationships,
                        "entity_types": entity_types,
                        "relation_types": relation_types,
                        "logger": logger,
                    }
                ),
                state=None,
                run_manager=self.run_manager,
            )
            logger.info(
                f"KGService: Finished processing document {document_id} for KG extraction"
            )
            result_gen = await self.pipes.graph_storage_pipe.run(
                input=self.pipes.graph_storage_pipe.Input(
                    message=relationships
                ),
                state=None,
                run_manager=self.run_manager,
            )
        except Exception as e:
            logger.error(f"KGService: Error in kg_extraction: {e}")
            await self.providers.database.documents_handler.set_workflow_status(
                id=document_id,
                status_type="extraction_status",
                status=KGExtractionStatus.FAILED,
            )
            raise e
        return await _collect_results(result_gen)
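
    # Usage sketch (illustrative; assumes an initialized GraphService bound to
    # `service` and a document that has already been ingested):
    #
    #     results = await service.kg_relationships_extraction(
    #         document_id=doc_id,            # hypothetical UUID
    #         generation_config=gen_cfg,     # hypothetical GenerationConfig
    #         chunk_merge_count=4,
    #         max_knowledge_relationships=100,
    #         entity_types=[],
    #         relation_types=[],
    #     )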

    @telemetry_event("create_entity")
    async def create_entity(
        self,
        name: str,
        description: str,
        parent_id: UUID,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(description)
        )
        return await self.providers.database.graphs_handler.entities.create(
            name=name,
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            category=category,
            description=description,
            description_embedding=description_embedding,
            metadata=metadata,
        )
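
    # Usage sketch (illustrative values; `graph_id` is a hypothetical UUID):
    #
    #     entity = await service.create_entity(
    #         name="Ada Lovelace",
    #         description="Mathematician and early computing pioneer",
    #         parent_id=graph_id,
    #         category="person",
    #     )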

    @telemetry_event("update_entity")
    async def update_entity(
        self,
        entity_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return await self.providers.database.graphs_handler.entities.update(
            entity_id=entity_id,
            store_type=StoreType.GRAPHS,
            name=name,
            description=description,
            description_embedding=description_embedding,
            category=category,
            metadata=metadata,
        )

    @telemetry_event("delete_entity")
    async def delete_entity(
        self,
        parent_id: UUID,
        entity_id: UUID,
    ):
        return await self.providers.database.graphs_handler.entities.delete(
            parent_id=parent_id,
            entity_ids=[entity_id],
            store_type=StoreType.GRAPHS,
        )

    @telemetry_event("get_entities")
    async def get_entities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        return await self.providers.database.graphs_handler.get_entities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            entity_ids=entity_ids,
            entity_names=entity_names,
            include_embeddings=include_embeddings,
        )

    @telemetry_event("create_relationship")
    async def create_relationship(
        self,
        subject: str,
        subject_id: UUID,
        predicate: str,
        object: str,
        object_id: UUID,
        parent_id: UUID,
        description: str | None = None,
        weight: float | None = 1.0,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        description_embedding = None
        if description:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return (
            await self.providers.database.graphs_handler.relationships.create(
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                parent_id=parent_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type=StoreType.GRAPHS,
            )
        )
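
    # Usage sketch (illustrative; the entity UUIDs are hypothetical and would
    # normally come from prior create_entity calls):
    #
    #     rel = await service.create_relationship(
    #         subject="Ada Lovelace",
    #         subject_id=ada_id,
    #         predicate="collaborated_with",
    #         object="Charles Babbage",
    #         object_id=babbage_id,
    #         parent_id=graph_id,
    #     )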

    @telemetry_event("delete_relationship")
    async def delete_relationship(
        self,
        parent_id: UUID,
        relationship_id: UUID,
    ):
        return (
            await self.providers.database.graphs_handler.relationships.delete(
                parent_id=parent_id,
                relationship_ids=[relationship_id],
                store_type=StoreType.GRAPHS,
            )
        )

    @telemetry_event("update_relationship")
    async def update_relationship(
        self,
        relationship_id: UUID,
        subject: Optional[str] = None,
        subject_id: Optional[UUID] = None,
        predicate: Optional[str] = None,
        object: Optional[str] = None,
        object_id: Optional[UUID] = None,
        description: Optional[str] = None,
        weight: Optional[float] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return (
            await self.providers.database.graphs_handler.relationships.update(
                relationship_id=relationship_id,
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type=StoreType.GRAPHS,
            )
        )

    @telemetry_event("get_relationships")
    async def get_relationships(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
    ):
        return await self.providers.database.graphs_handler.relationships.get(
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            offset=offset,
            limit=limit,
            relationship_ids=relationship_ids,
            entity_names=entity_names,
        )

    @telemetry_event("create_community")
    async def create_community(
        self,
        parent_id: UUID,
        name: str,
        summary: str,
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(summary)
        )
        return await self.providers.database.graphs_handler.communities.create(
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            name=name,
            summary=summary,
            description_embedding=description_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )
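
    # Usage sketch (illustrative; the summary is embedded so communities can be
    # retrieved by similarity later):
    #
    #     community = await service.create_community(
    #         parent_id=graph_id,  # hypothetical UUID
    #         name="Early computing",
    #         summary="People and machines of 19th-century computing",
    #         findings=["The Analytical Engine predates electronic computers"],
    #         rating=8.0,
    #         rating_explanation="Well supported by the source documents",
    #     )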

    @telemetry_event("update_community")
    async def update_community(
        self,
        community_id: UUID,
        name: Optional[str],
        summary: Optional[str],
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
        summary_embedding = None
        if summary is not None:
            summary_embedding = str(
                await self.providers.embedding.async_get_embedding(summary)
            )
        return await self.providers.database.graphs_handler.communities.update(
            community_id=community_id,
            store_type=StoreType.GRAPHS,
            name=name,
            summary=summary,
            summary_embedding=summary_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )

    @telemetry_event("delete_community")
    async def delete_community(
        self,
        parent_id: UUID,
        community_id: UUID,
    ) -> None:
        await self.providers.database.graphs_handler.communities.delete(
            parent_id=parent_id,
            community_id=community_id,
        )

    @telemetry_event("list_communities")
    async def list_communities(
        self,
        collection_id: UUID,
        offset: int,
        limit: int,
    ):
        return await self.providers.database.graphs_handler.communities.get(
            parent_id=collection_id,
            store_type=StoreType.GRAPHS,
            offset=offset,
            limit=limit,
        )

    @telemetry_event("get_communities")
    async def get_communities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        community_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        # NOTE: community_names is accepted but not currently forwarded to the
        # handler, so filtering by name has no effect here.
        return await self.providers.database.graphs_handler.get_communities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            community_ids=community_ids,
            include_embeddings=include_embeddings,
        )

    # @telemetry_event("create_new_graph")
    # async def create_new_graph(
    #     self,
    #     collection_id: UUID,
    #     user_id: UUID,
    #     name: Optional[str],
    #     description: str = "",
    # ) -> GraphResponse:
    #     return await self.providers.database.graphs_handler.create(
    #         collection_id=collection_id,
    #         user_id=user_id,
    #         name=name,
    #         description=description,
    #         graph_id=collection_id,
    #     )

    async def list_graphs(
        self,
        offset: int,
        limit: int,
        # user_ids: Optional[list[UUID]] = None,
        graph_ids: Optional[list[UUID]] = None,
        collection_id: Optional[UUID] = None,
    ) -> dict[str, list[GraphResponse] | int]:
        return await self.providers.database.graphs_handler.list_graphs(
            offset=offset,
            limit=limit,
            # filter_user_ids=user_ids,
            filter_graph_ids=graph_ids,
            filter_collection_id=collection_id,
        )

    @telemetry_event("update_graph")
    async def update_graph(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> GraphResponse:
        return await self.providers.database.graphs_handler.update(
            collection_id=collection_id,
            name=name,
            description=description,
        )

    @telemetry_event("reset_graph_v3")
    async def reset_graph_v3(self, id: UUID) -> bool:
        await self.providers.database.graphs_handler.reset(
            parent_id=id,
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=id,
            status_type="graph_cluster_status",
            status=KGEnrichmentStatus.PENDING,
        )
        return True

    @telemetry_event("get_document_ids_for_create_graph")
    async def get_document_ids_for_create_graph(
        self,
        collection_id: UUID,
        force_kg_creation: bool = False,
        **kwargs,
    ):
        document_status_filter = [
            KGExtractionStatus.PENDING,
            KGExtractionStatus.FAILED,
        ]
        if force_kg_creation:
            document_status_filter += [
                KGExtractionStatus.PROCESSING,
            ]
        return await self.providers.database.documents_handler.get_document_ids_by_status(
            status_type="extraction_status",
            status=[str(ele) for ele in document_status_filter],
            collection_id=collection_id,
        )
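
    # Selection sketch: by default only PENDING and FAILED documents are
    # (re)processed; force_kg_creation additionally sweeps up documents still
    # marked PROCESSING (e.g., left over from an interrupted run).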

    @telemetry_event("kg_entity_description")
    async def kg_entity_description(
        self,
        document_id: UUID,
        max_description_input_length: int,
        **kwargs,
    ):
        start_time = time.time()

        logger.info(
            f"KGService: Running kg_entity_description for document {document_id}"
        )

        entity_count = (
            await self.providers.database.graphs_handler.get_entity_count(
                document_id=document_id,
                distinct=True,
                entity_table_name="documents_entities",
            )
        )

        logger.info(
            f"KGService: Found {entity_count} entities in document {document_id}"
        )

        # TODO - Do not hardcode the batch size;
        # make it a configurable parameter at runtime & server-side defaults.
        # Process 256 entities at a time.
        batch_size = 256
        num_batches = math.ceil(entity_count / batch_size)
        logger.info(
            f"Calling `kg_entity_description` on document {document_id} with an entity count of {entity_count} and total batches of {num_batches}"
        )
        all_results = []
        for i in range(num_batches):
            logger.info(
                f"KGService: Running kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
            )

            node_descriptions = await self.pipes.graph_description_pipe.run(
                input=self.pipes.graph_description_pipe.Input(
                    message={
                        "offset": i * batch_size,
                        "limit": batch_size,
                        "max_description_input_length": max_description_input_length,
                        "document_id": document_id,
                        "logger": logger,
                    }
                ),
                state=None,
                run_manager=self.run_manager,
            )

            all_results.append(await _collect_results(node_descriptions))

            logger.info(
                f"KGService: Completed kg_entity_description for batch {i+1}/{num_batches} for document {document_id}"
            )

        await self.providers.database.documents_handler.set_workflow_status(
            id=document_id,
            status_type="extraction_status",
            status=KGExtractionStatus.SUCCESS,
        )

        logger.info(
            f"KGService: Completed kg_entity_description for document {document_id} in {time.time() - start_time:.2f} seconds",
        )
        return all_results
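
    # Batching sketch: with the 256-entity batch size above, a document with
    # 600 entities yields math.ceil(600 / 256) == 3 batches, covering offsets
    # 0, 256, and 512.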

    @telemetry_event("kg_clustering")
    async def kg_clustering(
        self,
        collection_id: UUID,
        # graph_id: UUID,
        generation_config: GenerationConfig,
        leiden_params: dict,
        **kwargs,
    ):
        logger.info(
            f"Running ClusteringPipe for collection {collection_id} with settings {leiden_params}"
        )
        clustering_result = await self.pipes.graph_clustering_pipe.run(
            input=self.pipes.graph_clustering_pipe.Input(
                message={
                    "collection_id": collection_id,
                    "generation_config": generation_config,
                    "leiden_params": leiden_params,
                    "logger": logger,
                    "clustering_mode": self.config.database.graph_creation_settings.clustering_mode,
                }
            ),
            state=None,
            run_manager=self.run_manager,
        )
        return await _collect_results(clustering_result)
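
    # Parameter sketch (illustrative): leiden_params is passed through to the
    # clustering pipe unchanged, e.g. {"max_cluster_size": 1000, "random_seed": 42};
    # the accepted keys depend on the configured clustering backend.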

    @telemetry_event("kg_community_summary")
    async def kg_community_summary(
        self,
        offset: int,
        limit: int,
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID | None,
        # graph_id: UUID | None,
        **kwargs,
    ):
        summary_results = await self.pipes.graph_community_summary_pipe.run(
            input=self.pipes.graph_community_summary_pipe.Input(
                message={
                    "offset": offset,
                    "limit": limit,
                    "generation_config": generation_config,
                    "max_summary_input_length": max_summary_input_length,
                    "collection_id": collection_id,
                    # "graph_id": graph_id,
                    "logger": logger,
                }
            ),
            state=None,
            run_manager=self.run_manager,
        )
        return await _collect_results(summary_results)

    @telemetry_event("delete_graph_for_documents")
    async def delete_graph_for_documents(
        self,
        document_ids: list[UUID],
        **kwargs,
    ):
        # TODO: Implement this, as it needs some checks.
        raise NotImplementedError

    @telemetry_event("delete_graph")
    async def delete_graph(
        self,
        collection_id: UUID,
    ):
        return await self.delete(collection_id=collection_id)

    @telemetry_event("delete")
    async def delete(
        self,
        collection_id: UUID,
        **kwargs,
    ):
        return await self.providers.database.graphs_handler.delete(
            collection_id=collection_id,
        )

    @telemetry_event("get_creation_estimate")
    async def get_creation_estimate(
        self,
        graph_creation_settings: KGCreationSettings,
        document_id: Optional[UUID] = None,
        collection_id: Optional[UUID] = None,
        **kwargs,
    ):
        return (
            await self.providers.database.graphs_handler.get_creation_estimate(
                document_id=document_id,
                collection_id=collection_id,
                graph_creation_settings=graph_creation_settings,
            )
        )

    @telemetry_event("get_enrichment_estimate")
    async def get_enrichment_estimate(
        self,
        collection_id: Optional[UUID] = None,
        graph_id: Optional[UUID] = None,
        graph_enrichment_settings: KGEnrichmentSettings = KGEnrichmentSettings(),
        **kwargs,
    ):
        if graph_id is None and collection_id is None:
            raise ValueError(
                "Either graph_id or collection_id must be provided"
            )
        return await self.providers.database.graphs_handler.get_enrichment_estimate(
            collection_id=collection_id,
            graph_id=graph_id,
            graph_enrichment_settings=graph_enrichment_settings,
        )

    async def kg_extraction(  # type: ignore
        self,
        document_id: UUID,
        generation_config: GenerationConfig,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        chunk_merge_count: int,
        filter_out_existing_chunks: bool = True,
        total_tasks: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[KGExtraction | R2RDocumentProcessingError, None]:
        start_time = time.time()

        logger.info(
            f"GraphExtractionPipe: Processing document {document_id} for KG extraction",
        )

        # Fetch all chunks for the document, paginating through the chunks handler.
        limit = 100
        offset = 0
        chunks = []
        while True:
            # FIXME: This was using the pagination defaults from before...
            # We need to review if this is as intended.
            chunk_req = await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=offset,
                limit=limit,
            )

            chunks.extend(
                [
                    DocumentChunk(
                        id=chunk["id"],
                        document_id=chunk["document_id"],
                        owner_id=chunk["owner_id"],
                        collection_ids=chunk["collection_ids"],
                        data=chunk["text"],
                        metadata=chunk["metadata"],
                    )
                    for chunk in chunk_req["results"]
                ]
            )

            if len(chunk_req["results"]) < limit:
                break

            offset += limit

        logger.info(f"Found {len(chunks)} chunks for document {document_id}")
        if len(chunks) == 0:
            logger.info(f"No chunks found for document {document_id}")
            raise R2RException(
                message="No chunks found for document",
                status_code=404,
            )

        if filter_out_existing_chunks:
            existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids(
                document_id=document_id
            )
            chunks = [
                chunk for chunk in chunks if chunk.id not in existing_chunk_ids
            ]
            logger.info(
                f"Filtered out {len(existing_chunk_ids)} existing chunks, remaining {len(chunks)} chunks for document {document_id}"
            )

            if len(chunks) == 0:
                logger.info(f"No extractions left for document {document_id}")
                return

        logger.info(
            f"GraphExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds",
        )

        # Sort the extractions according to the chunk_order field in metadata, in ascending order.
        chunks = sorted(
            chunks,
            key=lambda x: x.metadata.get("chunk_order", float("inf")),
        )

        # Group these extractions into groups of chunk_merge_count.
        grouped_chunks = [
            chunks[i : i + chunk_merge_count]
            for i in range(0, len(chunks), chunk_merge_count)
        ]

        logger.info(
            f"GraphExtractionPipe: Extracting KG Relationships for document and created {len(grouped_chunks)} tasks, time from start: {time.time() - start_time:.2f} seconds",
        )

        tasks = [
            asyncio.create_task(
                self._extract_kg(
                    chunks=chunk_group,
                    generation_config=generation_config,
                    max_knowledge_relationships=max_knowledge_relationships,
                    entity_types=entity_types,
                    relation_types=relation_types,
                    task_id=task_id,
                    total_tasks=len(grouped_chunks),
                )
            )
            for task_id, chunk_group in enumerate(grouped_chunks)
        ]

        completed_tasks = 0
        total_tasks = len(tasks)

        logger.info(
            f"GraphExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
        )

        for completed_task in asyncio.as_completed(tasks):
            try:
                yield await completed_task
                completed_tasks += 1
                if completed_tasks % 100 == 0:
                    logger.info(
                        f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
                    )
            except Exception as e:
                logger.error(f"Error in Extracting KG Relationships: {e}")
                yield R2RDocumentProcessingError(
                    document_id=document_id,
                    error_message=str(e),
                )

        logger.info(
            f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
        )
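
    # Consumption sketch (illustrative): kg_extraction is an async generator
    # that interleaves successful KGExtraction results with per-task errors.
    #
    #     async for result in service.kg_extraction(
    #         document_id=doc_id,            # hypothetical UUID
    #         generation_config=gen_cfg,     # hypothetical GenerationConfig
    #         max_knowledge_relationships=100,
    #         entity_types=[],
    #         relation_types=[],
    #         chunk_merge_count=4,
    #     ):
    #         if isinstance(result, R2RDocumentProcessingError):
    #             ...  # handle the failed chunk group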

    async def _extract_kg(
        self,
        chunks: list[DocumentChunk],
        generation_config: GenerationConfig,
        max_knowledge_relationships: int,
        entity_types: list[str],
        relation_types: list[str],
        retries: int = 5,
        delay: int = 2,
        task_id: Optional[int] = None,
        total_tasks: Optional[int] = None,
    ) -> KGExtraction:
        """
        Extracts NER relationships from an extraction with retries.
        """
        # Combine all extractions into a single string.
        combined_extraction: str = " ".join([chunk.data for chunk in chunks])  # type: ignore

        response = await self.providers.database.documents_handler.get_documents_overview(  # type: ignore
            offset=0,
            limit=1,
            filter_document_ids=[chunks[0].document_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        messages = await self.providers.database.prompts_handler.get_message_payload(
            task_prompt_name=self.providers.database.config.graph_creation_settings.graphrag_relationships_extraction_few_shot,
            task_inputs={
                "document_summary": document_summary,
                "input": combined_extraction,
                "max_knowledge_relationships": max_knowledge_relationships,
                "entity_types": "\n".join(entity_types),
                "relation_types": "\n".join(relation_types),
            },
        )

        for attempt in range(retries):
            try:
                response = await self.providers.llm.aget_completion(
                    messages,
                    generation_config=generation_config,
                )

                kg_extraction = response.choices[0].message.content

                if not kg_extraction:
                    raise R2RException(
                        "No knowledge graph extraction found in the response string, the selected LLM likely failed to format its response correctly.",
                        400,
                    )

                entity_pattern = (
                    r'\("entity"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\)'
                )
                relationship_pattern = r'\("relationship"\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}([^$]+)\${4}(\d+(?:\.\d+)?)\)'
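
                # Expected LLM output format (illustrative example; "$$$$" is
                # the literal field delimiter the patterns above match, and the
                # relationship fields are subject, object, predicate,
                # description, weight):
                #     ("entity"$$$$Ada Lovelace$$$$person$$$$Early programmer)
                #     ("relationship"$$$$Ada Lovelace$$$$Charles Babbage$$$$collaborated_with$$$$Worked together on the Analytical Engine$$$$8)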

                async def parse_fn(response_str: str) -> Any:
                    entities = re.findall(entity_pattern, response_str)

                    if (
                        len(kg_extraction)
                        > MIN_VALID_KG_EXTRACTION_RESPONSE_LENGTH
                        and len(entities) == 0
                    ):
                        raise R2RException(
                            f"No entities found in the response string, the selected LLM likely failed to format its response correctly. {response_str}",
                            400,
                        )

                    relationships = re.findall(
                        relationship_pattern, response_str
                    )

                    entities_arr = []
                    for entity in entities:
                        entity_value = entity[0]
                        entity_category = entity[1]
                        entity_description = entity[2]
                        description_embedding = (
                            await self.providers.embedding.async_get_embedding(
                                entity_description
                            )
                        )
                        entities_arr.append(
                            Entity(
                                category=entity_category,
                                description=entity_description,
                                name=entity_value,
                                parent_id=chunks[0].document_id,
                                chunk_ids=[chunk.id for chunk in chunks],
                                description_embedding=description_embedding,
                                attributes={},
                            )
                        )

                    relations_arr = []
                    for relationship in relationships:
                        subject = relationship[0]
                        object = relationship[1]
                        predicate = relationship[2]
                        description = relationship[3]
                        weight = float(relationship[4])
                        relationship_embedding = (
                            await self.providers.embedding.async_get_embedding(
                                description
                            )
                        )

                        # Check if subject and object are in entities_dict.
                        relations_arr.append(
                            Relationship(
                                subject=subject,
                                predicate=predicate,
                                object=object,
                                description=description,
                                weight=weight,
                                parent_id=chunks[0].document_id,
                                chunk_ids=[chunk.id for chunk in chunks],
                                attributes={},
                                description_embedding=relationship_embedding,
                            )
                        )

                    return entities_arr, relations_arr

                entities, relationships = await parse_fn(kg_extraction)
                return KGExtraction(
                    entities=entities,
                    relationships=relationships,
                )
            except (
                Exception,
                json.JSONDecodeError,
                KeyError,
                IndexError,
                R2RException,
            ) as e:
                if attempt < retries - 1:
                    await asyncio.sleep(delay)
                else:
                    logger.warning(
                        f"Failed after retries for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}"
                    )

        logger.info(
            f"GraphExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
        )

        return KGExtraction(
            entities=[],
            relationships=[],
        )
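
    # Retry sketch: with the defaults retries=5 and delay=2, a persistently
    # failing chunk group is attempted five times with roughly eight seconds of
    # total sleep, after which an empty KGExtraction is returned so the
    # document-level pipeline can continue.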

    async def store_kg_extractions(
        self,
        kg_extractions: list[KGExtraction],
    ):
        """
        Stores a batch of knowledge graph extractions in the graph database.
        """
        for extraction in kg_extractions:
            entities_id_map = {}
            for entity in extraction.entities:
                result = await self.providers.database.graphs_handler.entities.create(
                    name=entity.name,
                    parent_id=entity.parent_id,
                    store_type=StoreType.DOCUMENTS,
                    category=entity.category,
                    description=entity.description,
                    description_embedding=entity.description_embedding,
                    chunk_ids=entity.chunk_ids,
                    metadata=entity.metadata,
                )
                entities_id_map[entity.name] = result.id

            if extraction.relationships:
                for relationship in extraction.relationships:
                    await self.providers.database.graphs_handler.relationships.create(
                        subject=relationship.subject,
                        subject_id=entities_id_map.get(relationship.subject),
                        predicate=relationship.predicate,
                        object=relationship.object,
                        object_id=entities_id_map.get(relationship.object),
                        parent_id=relationship.parent_id,
                        description=relationship.description,
                        description_embedding=relationship.description_embedding,
                        weight=relationship.weight,
                        metadata=relationship.metadata,
                        store_type=StoreType.DOCUMENTS,
                    )
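
    # Linking sketch: entities are stored first so entities_id_map can resolve
    # each relationship's subject/object name to the newly created entity UUID;
    # names absent from the map fall back to a None subject_id/object_id.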