# graph_service.py

import asyncio
import logging
import math
import random
import re
import time
import uuid
import xml.etree.ElementTree as ET
from typing import Any, AsyncGenerator, Coroutine, Optional
from uuid import UUID
from xml.etree.ElementTree import Element

from core.base import (
    DocumentChunk,
    GraphExtraction,
    GraphExtractionStatus,
    R2RDocumentProcessingError,
)
from core.base.abstractions import (
    Community,
    Entity,
    GenerationConfig,
    GraphConstructionStatus,
    R2RException,
    Relationship,
    StoreType,
)
from core.base.api.models import GraphResponse

from ..abstractions import R2RProviders
from ..config import R2RConfig
from .base import Service

logger = logging.getLogger()

MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH = 128


async def _collect_async_results(result_gen: AsyncGenerator) -> list[Any]:
    """Collects all results from an async generator into a list."""
    results = []
    async for res in result_gen:
        results.append(res)
    return results
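
# A minimal usage sketch of _collect_async_results (illustrative only):
# any async generator can be drained into a list, e.g.
#
#     async def numbers():
#         for i in range(3):
#             yield i
#
#     results = await _collect_async_results(numbers())  # -> [0, 1, 2]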


class GraphService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
    ):
        super().__init__(
            config,
            providers,
        )

    async def create_entity(
        self,
        name: str,
        description: str,
        parent_id: UUID,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(description)
        )
        return await self.providers.database.graphs_handler.entities.create(
            name=name,
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            category=category,
            description=description,
            description_embedding=description_embedding,
            metadata=metadata,
        )

    async def update_entity(
        self,
        entity_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
        category: Optional[str] = None,
        metadata: Optional[dict] = None,
    ) -> Entity:
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return await self.providers.database.graphs_handler.entities.update(
            entity_id=entity_id,
            store_type=StoreType.GRAPHS,
            name=name,
            description=description,
            description_embedding=description_embedding,
            category=category,
            metadata=metadata,
        )

    async def delete_entity(
        self,
        parent_id: UUID,
        entity_id: UUID,
    ):
        return await self.providers.database.graphs_handler.entities.delete(
            parent_id=parent_id,
            entity_ids=[entity_id],
            store_type=StoreType.GRAPHS,
        )

    async def get_entities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        entity_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        return await self.providers.database.graphs_handler.get_entities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            entity_ids=entity_ids,
            entity_names=entity_names,
            include_embeddings=include_embeddings,
        )
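
    # A hedged usage sketch of the entity CRUD surface (the `graph_service`
    # and `graph_id` names are illustrative assumptions, not part of this
    # module):
    #
    #     entity = await graph_service.create_entity(
    #         name="Marie Curie",
    #         description="Physicist and chemist ...",
    #         parent_id=graph_id,
    #         category="PERSON",
    #     )
    #     entities, _count = await graph_service.get_entities(
    #         parent_id=graph_id, offset=0, limit=10
    #     )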

    async def create_relationship(
        self,
        subject: str,
        subject_id: UUID,
        predicate: str,
        object: str,
        object_id: UUID,
        parent_id: UUID,
        description: str | None = None,
        weight: float | None = 1.0,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        description_embedding = None
        if description:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return (
            await self.providers.database.graphs_handler.relationships.create(
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                parent_id=parent_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type=StoreType.GRAPHS,
            )
        )

    async def delete_relationship(
        self,
        parent_id: UUID,
        relationship_id: UUID,
    ):
        return (
            await self.providers.database.graphs_handler.relationships.delete(
                parent_id=parent_id,
                relationship_ids=[relationship_id],
                store_type=StoreType.GRAPHS,
            )
        )

    async def update_relationship(
        self,
        relationship_id: UUID,
        subject: Optional[str] = None,
        subject_id: Optional[UUID] = None,
        predicate: Optional[str] = None,
        object: Optional[str] = None,
        object_id: Optional[UUID] = None,
        description: Optional[str] = None,
        weight: Optional[float] = None,
        metadata: Optional[dict[str, Any] | str] = None,
    ) -> Relationship:
        description_embedding = None
        if description is not None:
            description_embedding = str(
                await self.providers.embedding.async_get_embedding(description)
            )
        return (
            await self.providers.database.graphs_handler.relationships.update(
                relationship_id=relationship_id,
                subject=subject,
                subject_id=subject_id,
                predicate=predicate,
                object=object,
                object_id=object_id,
                description=description,
                description_embedding=description_embedding,
                weight=weight,
                metadata=metadata,
                store_type=StoreType.GRAPHS,
            )
        )

    async def get_relationships(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        relationship_ids: Optional[list[UUID]] = None,
        entity_names: Optional[list[str]] = None,
    ):
        return await self.providers.database.graphs_handler.relationships.get(
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            offset=offset,
            limit=limit,
            relationship_ids=relationship_ids,
            entity_names=entity_names,
        )
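
    # Hedged sketch of creating a relationship between two existing entities
    # (the `curie_id`, `prize_id`, and `graph_id` UUIDs are illustrative):
    #
    #     rel = await graph_service.create_relationship(
    #         subject="Marie Curie",
    #         subject_id=curie_id,
    #         predicate="won",
    #         object="Nobel Prize",
    #         object_id=prize_id,
    #         parent_id=graph_id,
    #         description="Awarded the 1903 Nobel Prize in Physics.",
    #     )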

    async def create_community(
        self,
        parent_id: UUID,
        name: str,
        summary: str,
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
        description_embedding = str(
            await self.providers.embedding.async_get_embedding(summary)
        )
        return await self.providers.database.graphs_handler.communities.create(
            parent_id=parent_id,
            store_type=StoreType.GRAPHS,
            name=name,
            summary=summary,
            description_embedding=description_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )

    async def update_community(
        self,
        community_id: UUID,
        name: Optional[str],
        summary: Optional[str],
        findings: Optional[list[str]],
        rating: Optional[float],
        rating_explanation: Optional[str],
    ) -> Community:
        summary_embedding = None
        if summary is not None:
            summary_embedding = str(
                await self.providers.embedding.async_get_embedding(summary)
            )
        return await self.providers.database.graphs_handler.communities.update(
            community_id=community_id,
            store_type=StoreType.GRAPHS,
            name=name,
            summary=summary,
            summary_embedding=summary_embedding,
            findings=findings,
            rating=rating,
            rating_explanation=rating_explanation,
        )

    async def delete_community(
        self,
        parent_id: UUID,
        community_id: UUID,
    ) -> None:
        await self.providers.database.graphs_handler.communities.delete(
            parent_id=parent_id,
            community_id=community_id,
        )

    async def get_communities(
        self,
        parent_id: UUID,
        offset: int,
        limit: int,
        community_ids: Optional[list[UUID]] = None,
        community_names: Optional[list[str]] = None,
        include_embeddings: bool = False,
    ):
        # NOTE: `community_names` is accepted for API symmetry but is not
        # currently forwarded to the handler.
        return await self.providers.database.graphs_handler.get_communities(
            parent_id=parent_id,
            offset=offset,
            limit=limit,
            community_ids=community_ids,
            include_embeddings=include_embeddings,
        )

    async def list_graphs(
        self,
        offset: int,
        limit: int,
        graph_ids: Optional[list[UUID]] = None,
        collection_id: Optional[UUID] = None,
    ) -> dict[str, list[GraphResponse] | int]:
        return await self.providers.database.graphs_handler.list_graphs(
            offset=offset,
            limit=limit,
            filter_graph_ids=graph_ids,
            filter_collection_id=collection_id,
        )

    async def update_graph(
        self,
        collection_id: UUID,
        name: Optional[str] = None,
        description: Optional[str] = None,
    ) -> GraphResponse:
        return await self.providers.database.graphs_handler.update(
            collection_id=collection_id,
            name=name,
            description=description,
        )

    async def reset_graph(self, id: UUID) -> bool:
        await self.providers.database.graphs_handler.reset(
            parent_id=id,
        )
        await self.providers.database.documents_handler.set_workflow_status(
            id=id,
            status_type="graph_cluster_status",
            status=GraphConstructionStatus.PENDING,
        )
        return True

    async def get_document_ids_for_create_graph(
        self,
        collection_id: UUID,
        **kwargs,
    ):
        document_status_filter = [
            GraphExtractionStatus.PENDING,
            GraphExtractionStatus.FAILED,
        ]
        return await self.providers.database.documents_handler.get_document_ids_by_status(
            status_type="extraction_status",
            status=[str(ele) for ele in document_status_filter],
            collection_id=collection_id,
        )

    async def graph_search_results_entity_description(
        self,
        document_id: UUID,
        max_description_input_length: int,
        batch_size: int = 256,
        **kwargs,
    ):
        """An inline reimplementation of the old GraphDescriptionPipe logic,
        with no references to pipe objects.

        We:
        1) Count how many entities are in the document.
        2) Process them in batches of `batch_size`.
        3) For each batch, retrieve the entity map and call the LLM for any
           missing descriptions.
        """
        start_time = time.time()
        logger.info(
            f"GraphService: Running graph_search_results_entity_description for doc={document_id}"
        )

        # Count how many doc-entities exist
        entity_count = (
            await self.providers.database.graphs_handler.get_entity_count(
                document_id=document_id,
                distinct=True,
                entity_table_name="documents_entities",  # or whichever table
            )
        )
        logger.info(
            f"GraphService: Found {entity_count} doc-entities to describe."
        )

        all_results = []
        num_batches = math.ceil(entity_count / batch_size)
        for i in range(num_batches):
            offset = i * batch_size
            limit = batch_size
            logger.info(
                f"GraphService: describing batch {i + 1}/{num_batches}, offset={offset}, limit={limit}"
            )

            # Describe the entities in this batch, collecting the results
            # from the async generator into a list.
            gen = self._describe_entities_in_document_batch(
                document_id=document_id,
                offset=offset,
                limit=limit,
                max_description_input_length=max_description_input_length,
            )
            batch_results = await _collect_async_results(gen)
            all_results.append(batch_results)

        # Mark the doc's extraction status as success
        await self.providers.database.documents_handler.set_workflow_status(
            id=document_id,
            status_type="extraction_status",
            status=GraphExtractionStatus.SUCCESS,
        )
        logger.info(
            f"GraphService: Completed graph_search_results_entity_description for doc {document_id} in {time.time() - start_time:.2f}s."
        )
        return all_results
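
    # Batching sketch (illustrative numbers): with entity_count=600 and
    # batch_size=256, the loop above issues math.ceil(600 / 256) == 3
    # batches with (offset, limit) = (0, 256), (256, 256), (512, 256);
    # the final batch simply comes back with fewer rows.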

    async def _describe_entities_in_document_batch(
        self,
        document_id: UUID,
        offset: int,
        limit: int,
        max_description_input_length: int,
    ) -> AsyncGenerator[str, None]:
        """Core logic that replaces GraphDescriptionPipe._run_logic for a
        particular document/batch.

        Yields entity names (or some textual result) as each entity is
        updated.
        """
        start_time = time.time()
        logger.info(
            f"Started describing doc={document_id}, offset={offset}, limit={limit}"
        )

        # 1) Get the "entity map" from the DB
        entity_map = (
            await self.providers.database.graphs_handler.get_entity_map(
                offset=offset, limit=limit, document_id=document_id
            )
        )
        total_entities = len(entity_map)
        logger.info(
            f"_describe_entities_in_document_batch: got {total_entities} items in entity_map for doc={document_id}."
        )

        # 2) For each entity name in the map, gather sub-entities and relationships
        tasks: list[Coroutine[Any, Any, str]] = []
        tasks.extend(
            self._process_entity_for_description(
                entities=[
                    entity if isinstance(entity, Entity) else Entity(**entity)
                    for entity in entity_info["entities"]
                ],
                relationships=[
                    rel
                    if isinstance(rel, Relationship)
                    else Relationship(**rel)
                    for rel in entity_info["relationships"]
                ],
                document_id=document_id,
                max_description_input_length=max_description_input_length,
            )
            for entity_name, entity_info in entity_map.items()
        )

        # 3) Wait for all tasks, yielding as they complete
        idx = 0
        for coro in asyncio.as_completed(tasks):
            result = await coro
            idx += 1
            if idx % 100 == 0:
                logger.info(
                    f"_describe_entities_in_document_batch: {idx}/{total_entities} described for doc={document_id}"
                )
            yield result

        logger.info(
            f"Finished describing doc={document_id} batch offset={offset} in {time.time() - start_time:.2f}s."
        )

    async def _process_entity_for_description(
        self,
        entities: list[Entity],
        relationships: list[Relationship],
        document_id: UUID,
        max_description_input_length: int,
    ) -> str:
        """Adapted from the old process_entity function in
        GraphDescriptionPipe.

        If the entity has no description, call an LLM to create one, then
        store it. Returns the name of the top entity (more details could be
        stored as well).
        """

        def truncate_info(info_list: list[str], max_length: int) -> str:
            """Shuffles the info lines to keep them distinct, then
            accumulates lines until hitting max_length."""
            random.shuffle(info_list)
            truncated_info = ""
            current_length = 0
            for info in info_list:
                if current_length + len(info) > max_length:
                    break
                truncated_info += info + "\n"
                current_length += len(info)
            return truncated_info

        # Grab a doc-level summary (optional) to feed into the prompt
        response = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=1,
            filter_document_ids=[document_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        # Synthesize a minimal "entity info" string + relationship summary
        entity_info = [
            f"{e.name}, {e.description or 'NONE'}" for e in entities
        ]
        relationships_txt = [
            f"{i + 1}: {r.subject}, {r.object}, {r.predicate} - Summary: {r.description or ''}"
            for i, r in enumerate(relationships)
        ]

        # We describe only the first entity for simplicity, though all of
        # them could be described if needed.
        main_entity = entities[0]

        if not main_entity.description:
            # Only call the LLM if the entity is missing a description
            messages = await self.providers.database.prompts_handler.get_message_payload(
                task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt,
                task_inputs={
                    "document_summary": document_summary,
                    "entity_info": truncate_info(
                        entity_info, max_description_input_length
                    ),
                    "relationships_txt": truncate_info(
                        relationships_txt, max_description_input_length
                    ),
                },
            )

            # Call the LLM
            gen_config = (
                self.providers.database.config.graph_creation_settings.generation_config
                or GenerationConfig(model=self.config.app.fast_llm)
            )
            llm_resp = await self.providers.llm.aget_completion(
                messages=messages,
                generation_config=gen_config,
            )
            new_description = llm_resp.choices[0].message.content

            if not new_description:
                logger.error(
                    f"No LLM description returned for entity={main_entity.name}"
                )
                return main_entity.name

            # Create the embedding
            embed = (
                await self.providers.embedding.async_get_embeddings(
                    [new_description]
                )
            )[0]

            # Update the DB
            main_entity.description = new_description
            main_entity.description_embedding = embed

            # Upsert the entity into `documents_entities` (or your table)
            await self.providers.database.graphs_handler.add_entities(
                [main_entity],
                table_name="documents_entities",
            )

        return main_entity.name
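
    # Sketch of truncate_info's behaviour (illustrative values): given
    # info_list = ["aaaa", "bb", "cccc"] and max_length=7, the lines are
    # shuffled and then accumulated until the next line would exceed the
    # budget, so one possible result is "aaaa\nbb\n" (only the 6 characters
    # of line content are counted, not the appended newlines).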

    async def graph_search_results_clustering(
        self,
        collection_id: UUID,
        generation_config: GenerationConfig,
        leiden_params: dict,
        **kwargs,
    ):
        """Replacement for the old GraphClusteringPipe logic:

        1) Call perform_graph_clustering on the DB.
        2) Return the result.
        """
        logger.info(
            f"Running inline clustering for collection={collection_id} with params={leiden_params}"
        )
        return await self._perform_graph_clustering(
            collection_id=collection_id,
            generation_config=generation_config,
            leiden_params=leiden_params,
        )

    async def _perform_graph_clustering(
        self,
        collection_id: UUID,
        generation_config: GenerationConfig,
        leiden_params: dict,
    ) -> dict:
        """The actual clustering logic (previously in
        GraphClusteringPipe.cluster_graph_search_results)."""
        num_communities = await self.providers.database.graphs_handler.perform_graph_clustering(
            collection_id=collection_id,
            leiden_params=leiden_params,
        )
        return {"num_communities": num_communities}
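
    # Hedged usage sketch: `leiden_params` is passed straight through to the
    # graphs handler, so its accepted keys depend on the underlying Leiden
    # implementation. The keys below are illustrative assumptions, not a
    # documented contract:
    #
    #     result = await graph_service.graph_search_results_clustering(
    #         collection_id=collection_id,
    #         generation_config=generation_config,
    #         leiden_params={"max_cluster_size": 1000, "seed": 7272},
    #     )
    #     result["num_communities"]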

    async def graph_search_results_community_summary(
        self,
        offset: int,
        limit: int,
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID,
        leiden_params: Optional[dict] = None,
        **kwargs,
    ):
        """Replacement for the old GraphCommunitySummaryPipe logic.

        Summarizes communities after clustering. The summaries are produced
        by an internal async generator and collected into a list.
        """
        logger.info(
            f"Running inline community summaries for coll={collection_id}, offset={offset}, limit={limit}"
        )
        # Call an internal function that yields summaries
        gen = self._summarize_communities(
            offset=offset,
            limit=limit,
            max_summary_input_length=max_summary_input_length,
            generation_config=generation_config,
            collection_id=collection_id,
            leiden_params=leiden_params or {},
        )
        return await _collect_async_results(gen)

    async def _summarize_communities(
        self,
        offset: int,
        limit: int,
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID,
        leiden_params: dict,
    ) -> AsyncGenerator[dict, None]:
        """Performs the community summary logic from
        GraphCommunitySummaryPipe._run_logic.

        Yields each summary dictionary as it completes.
        """
        start_time = time.time()
        logger.info(
            f"Starting community summarization for collection={collection_id}"
        )

        # Get all entities & relationships
        (
            all_entities,
            _,
        ) = await self.providers.database.graphs_handler.get_entities(
            parent_id=collection_id,
            offset=0,
            limit=-1,
            include_embeddings=False,
        )
        (
            all_relationships,
            _,
        ) = await self.providers.database.graphs_handler.get_relationships(
            parent_id=collection_id,
            offset=0,
            limit=-1,
            include_embeddings=False,
        )

        # Optionally re-run the clustering to produce fresh community assignments
        (
            _,
            community_clusters,
        ) = await self.providers.database.graphs_handler._cluster_and_add_community_info(
            relationships=all_relationships,
            leiden_params=leiden_params,
            collection_id=collection_id,
        )

        # Group the clusters
        clusters: dict[Any, list[str]] = {}
        for item in community_clusters:
            cluster_id = item["cluster"]
            node_name = item["node"]
            clusters.setdefault(cluster_id, []).append(node_name)

        # Create an async job for each cluster
        tasks: list[Coroutine[Any, Any, dict]] = []
        tasks.extend(
            self._process_community_summary(
                community_id=uuid.uuid4(),
                nodes=nodes,
                all_entities=all_entities,
                all_relationships=all_relationships,
                max_summary_input_length=max_summary_input_length,
                generation_config=generation_config,
                collection_id=collection_id,
            )
            for nodes in clusters.values()
        )

        total_jobs = len(tasks)
        results_returned = 0
        total_errors = 0
        for coro in asyncio.as_completed(tasks):
            summary = await coro
            results_returned += 1
            if results_returned % 50 == 0:
                logger.info(
                    f"Community summaries: {results_returned}/{total_jobs} done in {time.time() - start_time:.2f}s"
                )
            if "error" in summary:
                total_errors += 1
            yield summary

        if total_errors > 0:
            logger.warning(
                f"{total_errors} communities failed summarization out of {total_jobs}"
            )
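
    # Grouping sketch (illustrative rows): `community_clusters` arrives as
    # flat {"cluster", "node"} records, e.g.
    #
    #     [{"cluster": 0, "node": "Marie Curie"},
    #      {"cluster": 0, "node": "Pierre Curie"},
    #      {"cluster": 1, "node": "Radium"}]
    #
    # which the setdefault loop above folds into
    #
    #     {0: ["Marie Curie", "Pierre Curie"], 1: ["Radium"]}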

    async def _process_community_summary(
        self,
        community_id: UUID,
        nodes: list[str],
        all_entities: list[Entity],
        all_relationships: list[Relationship],
        max_summary_input_length: int,
        generation_config: GenerationConfig,
        collection_id: UUID,
    ) -> dict:
        """Summarize a single community: gather all relevant
        entities/relationships, call the LLM to generate an XML block, parse
        it, and store the result as a community in the DB.

        (Equivalent to process_community in the old code.)
        """
        # Fetch the collection description (optional)
        response = await self.providers.database.collections_handler.get_collections_overview(
            offset=0,
            limit=1,
            filter_collection_ids=[collection_id],
        )
        collection_description = (
            response["results"][0].description if response["results"] else None  # type: ignore
        )

        # Filter out the relevant entities / relationships
        entities = [e for e in all_entities if e.name in nodes]
        relationships = [
            r
            for r in all_relationships
            if r.subject in nodes and r.object in nodes
        ]
        if not entities and not relationships:
            return {
                "community_id": community_id,
                "error": f"No data in this community (nodes={nodes})",
            }

        # Create the big input text for the LLM
        input_text = await self._community_summary_prompt(
            entities,
            relationships,
            max_summary_input_length,
        )

        # Attempt up to 3 times to parse
        for attempt in range(3):
            try:
                # Build the prompt
                messages = await self.providers.database.prompts_handler.get_message_payload(
                    task_prompt_name=self.providers.database.config.graph_enrichment_settings.graph_communities_prompt,
                    task_inputs={
                        "collection_description": collection_description,
                        "input_text": input_text,
                    },
                )
                llm_resp = await self.providers.llm.aget_completion(
                    messages=messages,
                    generation_config=generation_config,
                )
                llm_text = llm_resp.choices[0].message.content or ""

                # Find the <community>...</community> XML
                match = re.search(
                    r"<community>.*?</community>", llm_text, re.DOTALL
                )
                if not match:
                    raise ValueError(
                        "No <community> XML found in LLM response"
                    )

                xml_content = re.sub(
                    r"&(?!amp;|quot;|apos;|lt;|gt;)", "&amp;", match.group(0)
                ).strip()
                root = ET.fromstring(xml_content)

                # Extract the fields
                name_elem = root.find("name")
                summary_elem = root.find("summary")
                rating_elem = root.find("rating")
                rating_expl_elem = root.find("rating_explanation")
                findings_elem = root.find("findings")

                name = name_elem.text if name_elem is not None else ""
                summary = summary_elem.text if summary_elem is not None else ""
                rating = (
                    float(rating_elem.text)
                    if isinstance(rating_elem, Element) and rating_elem.text
                    else None  # treat a missing rating as unset
                )
                rating_explanation = (
                    rating_expl_elem.text
                    if rating_expl_elem is not None
                    else None
                )
                findings = (
                    [f.text for f in findings_elem.findall("finding")]
                    if findings_elem is not None
                    else []
                )

                # Build the embedding
                embed_text = (
                    "Summary:\n"
                    + (summary or "")
                    + "\n\nFindings:\n"
                    + "\n".join(
                        finding for finding in findings if finding is not None
                    )
                )
                embedding = await self.providers.embedding.async_get_embedding(
                    embed_text
                )

                # Build the Community object
                community = Community(
                    community_id=community_id,
                    collection_id=collection_id,
                    name=name,
                    summary=summary,
                    rating=rating,
                    rating_explanation=rating_explanation,
                    findings=findings,
                    description_embedding=embedding,
                )

                # Store it
                await self.providers.database.graphs_handler.add_community(
                    community
                )
                return {
                    "community_id": community_id,
                    "name": name,
                }
            except Exception as e:
                logger.error(
                    f"Error summarizing community {community_id}: {e}"
                )
                if attempt == 2:
                    return {"community_id": community_id, "error": str(e)}
                await asyncio.sleep(1)

        # Fallback
        return {"community_id": community_id, "error": "Failed after retries"}
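
    # The parser above expects the LLM to emit XML shaped roughly like this
    # (field values are illustrative):
    #
    #     <community>
    #         <name>Curie research circle</name>
    #         <summary>...</summary>
    #         <rating>8.5</rating>
    #         <rating_explanation>...</rating_explanation>
    #         <findings>
    #             <finding>...</finding>
    #             <finding>...</finding>
    #         </findings>
    #     </community>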

    async def _community_summary_prompt(
        self,
        entities: list[Entity],
        relationships: list[Relationship],
        max_summary_input_length: int,
    ) -> str:
        """Gathers the entity/relationship text, trying not to exceed
        `max_summary_input_length`."""
        # Group the entities and relationships by entity.name
        entity_map: dict[str, dict] = {}
        for e in entities:
            entity_map.setdefault(
                e.name, {"entities": [], "relationships": []}
            )
            entity_map[e.name]["entities"].append(e)
        for r in relationships:
            # Attach each relationship to its subject entity
            entity_map.setdefault(
                r.subject, {"entities": [], "relationships": []}
            )
            entity_map[r.subject]["relationships"].append(r)

        # Sort by number of relationships, most-connected first
        sorted_entries = sorted(
            entity_map.items(),
            key=lambda x: len(x[1]["relationships"]),
            reverse=True,
        )

        # Build up the prompt text
        prompt_chunks = []
        cur_len = 0
        for entity_name, data in sorted_entries:
            block = f"\nEntity: {entity_name}\nDescriptions:\n"
            block += "\n".join(
                f"{e.id},{(e.description or '')}" for e in data["entities"]
            )
            block += "\nRelationships:\n"
            block += "\n".join(
                f"{r.id},{r.subject},{r.object},{r.predicate},{r.description or ''}"
                for r in data["relationships"]
            )
            # Check the length budget
            if cur_len + len(block) > max_summary_input_length:
                prompt_chunks.append(
                    block[: max_summary_input_length - cur_len]
                )
                break
            else:
                prompt_chunks.append(block)
                cur_len += len(block)

        return "".join(prompt_chunks)

    async def delete(
        self,
        collection_id: UUID,
        **kwargs,
    ):
        return await self.providers.database.graphs_handler.delete(
            collection_id=collection_id,
        )

    async def graph_search_results_extraction(
        self,
        document_id: UUID,
        generation_config: GenerationConfig,
        entity_types: list[str],
        relation_types: list[str],
        chunk_merge_count: int,
        filter_out_existing_chunks: bool = True,
        total_tasks: Optional[int] = None,
        *args: Any,
        **kwargs: Any,
    ) -> AsyncGenerator[GraphExtraction | R2RDocumentProcessingError, None]:
        """The original "extract graph from doc" logic, inlined instead of
        referencing a pipe."""
        start_time = time.time()
        logger.info(
            f"Graph Extraction: Processing document {document_id} for graph extraction"
        )

        # Retrieve the chunks from the DB
        chunks = []
        limit = 100
        offset = 0
        while True:
            chunk_req = await self.providers.database.chunks_handler.list_document_chunks(
                document_id=document_id,
                offset=offset,
                limit=limit,
            )
            new_chunk_objs = [
                DocumentChunk(
                    id=chunk["id"],
                    document_id=chunk["document_id"],
                    owner_id=chunk["owner_id"],
                    collection_ids=chunk["collection_ids"],
                    data=chunk["text"],
                    metadata=chunk["metadata"],
                )
                for chunk in chunk_req["results"]
            ]
            chunks.extend(new_chunk_objs)
            if len(chunk_req["results"]) < limit:
                break
            offset += limit

        if not chunks:
            logger.info(f"No chunks found for document {document_id}")
            raise R2RException(
                message="No chunks found for document",
                status_code=404,
            )

        # Possibly filter out any chunks that have already been processed
        if filter_out_existing_chunks:
            existing_chunk_ids = await self.providers.database.graphs_handler.get_existing_document_entity_chunk_ids(
                document_id=document_id
            )
            before_count = len(chunks)
            chunks = [c for c in chunks if c.id not in existing_chunk_ids]
            logger.info(
                f"Filtered out {len(existing_chunk_ids)} existing chunk-IDs. {before_count}->{len(chunks)} remain."
            )
            if not chunks:
                return  # nothing left to yield

        # Sort by chunk_order if present
        chunks = sorted(
            chunks,
            key=lambda x: x.metadata.get("chunk_order", float("inf")),
        )

        # Group the chunks
        grouped_chunks = [
            chunks[i : i + chunk_merge_count]
            for i in range(0, len(chunks), chunk_merge_count)
        ]
        logger.info(
            f"Graph Extraction: Created {len(grouped_chunks)} tasks for doc={document_id}"
        )
        tasks = [
            asyncio.create_task(
                self._extract_graph_search_results_from_chunk_group(
                    chunk_group,
                    generation_config,
                    entity_types,
                    relation_types,
                )
            )
            for chunk_group in grouped_chunks
        ]

        completed_tasks = 0
        for t in asyncio.as_completed(tasks):
            try:
                yield await t
                completed_tasks += 1
                if completed_tasks % 100 == 0:
                    logger.info(
                        f"Graph Extraction: completed {completed_tasks}/{len(tasks)} tasks"
                    )
            except Exception as e:
                logger.error(f"Error extracting from chunk group: {e}")
                yield R2RDocumentProcessingError(
                    document_id=document_id,
                    error_message=str(e),
                )

        logger.info(
            f"Graph Extraction: done with {document_id}, time={time.time() - start_time:.2f}s"
        )
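
    # Hedged consumption sketch: the method above is an async generator, so
    # a caller can drain it with the module-level helper (variable names are
    # illustrative):
    #
    #     extractions = await _collect_async_results(
    #         graph_service.graph_search_results_extraction(
    #             document_id=document_id,
    #             generation_config=generation_config,
    #             entity_types=["PERSON", "ORGANIZATION"],
    #             relation_types=["WORKS_AT"],
    #             chunk_merge_count=4,
    #         )
    #     )
    #     # Each item is either a GraphExtraction or an
    #     # R2RDocumentProcessingError, so filter before storing:
    #     valid = [e for e in extractions if isinstance(e, GraphExtraction)]
    #     await graph_service.store_graph_search_results_extractions(valid)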

    async def _extract_graph_search_results_from_chunk_group(
        self,
        chunks: list[DocumentChunk],
        generation_config: GenerationConfig,
        entity_types: list[str],
        relation_types: list[str],
        retries: int = 5,
        delay: int = 2,
    ) -> GraphExtraction:
        """(Equivalent to _extract_graph_search_results in the old code.)

        Merges the chunk data, calls the LLM, parses the XML, and returns a
        GraphExtraction object.
        """
        combined_extraction: str = " ".join(
            [
                c.data.decode("utf-8") if isinstance(c.data, bytes) else c.data
                for c in chunks
                if c.data
            ]
        )

        # Possibly get a doc-level summary
        doc_id = chunks[0].document_id
        response = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=1,
            filter_document_ids=[doc_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        # Build the messages/prompt
        prompt_name = self.providers.database.config.graph_creation_settings.graph_extraction_prompt
        messages = (
            await self.providers.database.prompts_handler.get_message_payload(
                task_prompt_name=prompt_name,
                task_inputs={
                    "document_summary": document_summary or "",
                    "input": combined_extraction,
                    "entity_types": "\n".join(entity_types),
                    "relation_types": "\n".join(relation_types),
                },
            )
        )

        for attempt in range(retries):
            try:
                resp = await self.providers.llm.aget_completion(
                    messages, generation_config=generation_config
                )
                graph_search_results_str = resp.choices[0].message.content
                if not graph_search_results_str:
                    raise R2RException(
                        "No extraction found in LLM response.",
                        400,
                    )
                logger.info(generation_config)
                logger.info(graph_search_results_str)

                # Parse the XML
                (
                    entities,
                    relationships,
                ) = await self._parse_graph_search_results_extraction_xml(
                    graph_search_results_str, chunks
                )
                return GraphExtraction(
                    entities=entities, relationships=relationships
                )
            except Exception as e:
                if attempt < retries - 1:
                    await asyncio.sleep(delay)
                    continue
                else:
                    logger.error(
                        f"All extraction attempts for doc={doc_id} and chunks {[chunk.id for chunk in chunks]} failed with error:\n{e}"
                    )
                    return GraphExtraction(entities=[], relationships=[])

        return GraphExtraction(entities=[], relationships=[])

    async def _parse_graph_search_results_extraction_xml(
        self, response_str: str, chunks: list[DocumentChunk]
    ) -> tuple[list[Entity], list[Relationship]]:
        """Helper to parse the LLM's XML format, handle edge cases/cleanup,
        and produce Entities/Relationships."""

        def sanitize_xml(r: str) -> str:
            # Remove markdown fences
            r = re.sub(r"```xml|```", "", r)
            # Remove XML processing instructions and userStyle tags
            r = re.sub(r"<\?.*?\?>", "", r)
            r = re.sub(r"<userStyle>.*?</userStyle>", "", r)
            # Replace bare `&` with `&amp;`
            r = re.sub(r"&(?!amp;|quot;|apos;|lt;|gt;)", "&amp;", r)
            # Also remove <root> if it appears
            r = r.replace("<root>", "").replace("</root>", "")
            return r.strip()

        cleaned_xml = sanitize_xml(response_str)
        wrapped = f"<root>{cleaned_xml}</root>"
        try:
            root = ET.fromstring(wrapped, parser=ET.XMLParser(encoding="utf-8"))
        except ET.ParseError:
            raise R2RException(
                f"Failed to parse XML:\nData: {wrapped}", 400
            ) from None

        entities_elems = root.findall(".//entity")
        if (
            len(response_str) > MIN_VALID_GRAPH_EXTRACTION_RESPONSE_LENGTH
            and len(entities_elems) == 0
        ):
            raise R2RException(
                f"No <entity> found in LLM XML, possibly malformed. Response excerpt: {response_str[:300]}",
                400,
            )

        # Build the entity objects
        doc_id = chunks[0].document_id
        chunk_ids = [c.id for c in chunks]
        entities_list: list[Entity] = []
        for element in entities_elems:
            name_attr = element.get("name")
            type_elem = element.find("type")
            desc_elem = element.find("description")
            category = type_elem.text if type_elem is not None else None
            desc = desc_elem.text if desc_elem is not None else None
            desc_embed = await self.providers.embedding.async_get_embedding(
                desc or ""
            )
            ent = Entity(
                category=category,
                description=desc,
                name=name_attr,
                parent_id=doc_id,
                chunk_ids=chunk_ids,
                description_embedding=desc_embed,
                attributes={},
            )
            entities_list.append(ent)

        # Build the relationship objects
        relationships_list: list[Relationship] = []
        rel_elems = root.findall(".//relationship")
        for r_elem in rel_elems:
            source_elem = r_elem.find("source")
            target_elem = r_elem.find("target")
            type_elem = r_elem.find("type")
            desc_elem = r_elem.find("description")
            weight_elem = r_elem.find("weight")
            try:
                subject = source_elem.text if source_elem is not None else ""
                object_ = target_elem.text if target_elem is not None else ""
                predicate = type_elem.text if type_elem is not None else ""
                desc = desc_elem.text if desc_elem is not None else ""
                weight = (
                    float(weight_elem.text)
                    if isinstance(weight_elem, Element) and weight_elem.text
                    else None  # treat a missing weight as unset
                )
                embed = await self.providers.embedding.async_get_embedding(
                    desc or ""
                )
                rel = Relationship(
                    subject=subject,
                    predicate=predicate,
                    object=object_,
                    description=desc,
                    weight=weight,
                    parent_id=doc_id,
                    chunk_ids=chunk_ids,
                    attributes={},
                    description_embedding=embed,
                )
                relationships_list.append(rel)
            except Exception:
                continue

        return entities_list, relationships_list
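
    # The extraction parser above accepts XML shaped roughly like this
    # (values are illustrative):
    #
    #     <entity name="Marie Curie">
    #         <type>PERSON</type>
    #         <description>Physicist and chemist ...</description>
    #     </entity>
    #     <relationship>
    #         <source>Marie Curie</source>
    #         <target>Nobel Prize</target>
    #         <type>won</type>
    #         <description>Awarded the 1903 Nobel Prize in Physics.</description>
    #         <weight>0.9</weight>
    #     </relationship>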

    async def store_graph_search_results_extractions(
        self,
        graph_search_results_extractions: list[GraphExtraction],
    ):
        """Stores a batch of knowledge graph extractions in the DB."""
        for extraction in graph_search_results_extractions:
            # Map name -> id after creation
            entities_id_map = {}
            for e in extraction.entities:
                if e.parent_id is not None:
                    result = await self.providers.database.graphs_handler.entities.create(
                        name=e.name,
                        parent_id=e.parent_id,
                        store_type=StoreType.DOCUMENTS,
                        category=e.category,
                        description=e.description,
                        description_embedding=e.description_embedding,
                        chunk_ids=e.chunk_ids,
                        metadata=e.metadata,
                    )
                    entities_id_map[e.name] = result.id
                else:
                    logger.warning(f"Skipping entity with None parent_id: {e}")

            # Insert the relationships
            for rel in extraction.relationships:
                subject_id = entities_id_map.get(rel.subject)
                object_id = entities_id_map.get(rel.object)
                parent_id = rel.parent_id

                if any(
                    id is None for id in (subject_id, object_id, parent_id)
                ):
                    logger.warning(f"Missing ID for relationship: {rel}")
                    continue

                assert isinstance(subject_id, UUID)
                assert isinstance(object_id, UUID)
                assert isinstance(parent_id, UUID)
                await self.providers.database.graphs_handler.relationships.create(
                    subject=rel.subject,
                    subject_id=subject_id,
                    predicate=rel.predicate,
                    object=rel.object,
                    object_id=object_id,
                    parent_id=parent_id,
                    description=rel.description,
                    description_embedding=rel.description_embedding,
                    weight=rel.weight,
                    metadata=rel.metadata,
                    store_type=StoreType.DOCUMENTS,
                )

    async def deduplicate_document_entities(
        self,
        document_id: UUID,
    ):
        """Inlined from the old code: merges duplicates by name, calls the
        LLM for a new consolidated description, and updates the record."""
        merged_results = await self.providers.database.entities_handler.merge_duplicate_name_blocks(
            parent_id=document_id,
            store_type=StoreType.DOCUMENTS,
        )

        # Grab the doc summary
        response = await self.providers.database.documents_handler.get_documents_overview(
            offset=0,
            limit=1,
            filter_document_ids=[document_id],
        )
        document_summary = (
            response["results"][0].summary if response["results"] else None
        )

        # For each merged entity
        for original_entities, merged_entity in merged_results:
            # Summarize the originals with the LLM
            entity_info = "\n".join(
                e.description for e in original_entities if e.description
            )
            messages = await self.providers.database.prompts_handler.get_message_payload(
                task_prompt_name=self.providers.database.config.graph_creation_settings.graph_entity_description_prompt,
                task_inputs={
                    "document_summary": document_summary,
                    "entity_info": f"{merged_entity.name}\n{entity_info}",
                    "relationships_txt": "",
                },
            )
            gen_config = (
                self.config.database.graph_creation_settings.generation_config
                or GenerationConfig(model=self.config.app.fast_llm)
            )
            resp = await self.providers.llm.aget_completion(
                messages, generation_config=gen_config
            )
            new_description = resp.choices[0].message.content

            new_embedding = await self.providers.embedding.async_get_embedding(
                new_description or ""
            )

            if merged_entity.id is not None:
                await self.providers.database.graphs_handler.entities.update(
                    entity_id=merged_entity.id,
                    store_type=StoreType.DOCUMENTS,
                    description=new_description,
                    description_embedding=str(new_embedding),
                )
            else:
                logger.warning("Skipping update for entity with None id")